mirror of
https://github.com/mountain-loop/yaak.git
synced 2026-04-25 10:08:29 +02:00
Fixes for websocket closing
This commit is contained in:
@@ -283,10 +283,26 @@ pub(crate) async fn connect<R: Runtime>(
|
|||||||
let (receive_tx, mut receive_rx) = mpsc::channel::<Message>(128);
|
let (receive_tx, mut receive_rx) = mpsc::channel::<Message>(128);
|
||||||
let mut ws_manager = ws_manager.lock().await;
|
let mut ws_manager = ws_manager.lock().await;
|
||||||
|
|
||||||
let (url, url_parameters) = apply_path_placeholders(&request.url, request.url_parameters);
|
let (mut url, url_parameters) = apply_path_placeholders(&request.url, request.url_parameters);
|
||||||
|
if !url.starts_with("ws://") && !url.starts_with("wss://") {
|
||||||
|
url.insert_str(0, "ws://");
|
||||||
|
}
|
||||||
|
|
||||||
// Add URL parameters to URL
|
// Add URL parameters to URL
|
||||||
let mut url = Url::parse(&url).unwrap();
|
let mut url = match Url::parse(&url) {
|
||||||
|
Ok(url) => url,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(app_handle.db().upsert_websocket_connection(
|
||||||
|
&WebsocketConnection {
|
||||||
|
error: Some(format!("Failed to parse URL {}", e.to_string())),
|
||||||
|
state: WebsocketConnectionState::Closed,
|
||||||
|
..connection
|
||||||
|
},
|
||||||
|
&UpdateSource::from_window(&window),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
let valid_query_pairs = url_parameters
|
let valid_query_pairs = url_parameters
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|||||||
@@ -2,26 +2,29 @@ use crate::connect::ws_connect;
|
|||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use futures_util::stream::SplitSink;
|
use futures_util::stream::SplitSink;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use log::{debug, warn};
|
use log::{debug, info, warn};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
|
use tokio_tungstenite::tungstenite::Message;
|
||||||
use tokio_tungstenite::tungstenite::handshake::client::Response;
|
use tokio_tungstenite::tungstenite::handshake::client::Response;
|
||||||
use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue};
|
use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue};
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
|
||||||
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
|
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct WebsocketManager {
|
pub struct WebsocketManager {
|
||||||
connections:
|
connections:
|
||||||
Arc<Mutex<HashMap<String, SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>>,
|
Arc<Mutex<HashMap<String, SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>>,
|
||||||
|
read_tasks: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WebsocketManager {
|
impl WebsocketManager {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
WebsocketManager {
|
WebsocketManager {
|
||||||
connections: Default::default(),
|
connections: Default::default(),
|
||||||
|
read_tasks: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,28 +36,35 @@ impl WebsocketManager {
|
|||||||
receive_tx: mpsc::Sender<Message>,
|
receive_tx: mpsc::Sender<Message>,
|
||||||
validate_certificates: bool,
|
validate_certificates: bool,
|
||||||
) -> Result<Response> {
|
) -> Result<Response> {
|
||||||
let connections = self.connections.clone();
|
|
||||||
let connection_id = id.to_string();
|
|
||||||
let tx = receive_tx.clone();
|
let tx = receive_tx.clone();
|
||||||
|
|
||||||
let (stream, response) = ws_connect(url, headers, validate_certificates).await?;
|
let (stream, response) = ws_connect(url, headers, validate_certificates).await?;
|
||||||
let (write, mut read) = stream.split();
|
let (write, mut read) = stream.split();
|
||||||
|
|
||||||
connections.lock().await.insert(id.to_string(), write);
|
self.connections.lock().await.insert(id.to_string(), write);
|
||||||
|
|
||||||
tauri::async_runtime::spawn(async move {
|
let handle = {
|
||||||
while let Some(msg) = read.next().await {
|
let connection_id = id.to_string();
|
||||||
match msg {
|
let connections = self.connections.clone();
|
||||||
Err(e) => {
|
let read_tasks = self.read_tasks.clone();
|
||||||
warn!("Broken websocket connection: {}", e);
|
tokio::task::spawn(async move {
|
||||||
break;
|
while let Some(msg) = read.next().await {
|
||||||
|
match msg {
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Broken websocket connection: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Ok(message) => tx.send(message).await.unwrap(),
|
||||||
}
|
}
|
||||||
Ok(message) => tx.send(message).await.unwrap(),
|
|
||||||
}
|
}
|
||||||
}
|
debug!("Connection {} closed", connection_id);
|
||||||
debug!("Connection {} closed", connection_id);
|
connections.lock().await.remove(&connection_id);
|
||||||
connections.lock().await.remove(&connection_id);
|
read_tasks.lock().await.remove(&connection_id);
|
||||||
});
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
self.read_tasks.lock().await.insert(id.to_string(), handle);
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,13 +80,21 @@ impl WebsocketManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn close(&mut self, id: &str) -> Result<()> {
|
pub async fn close(&mut self, id: &str) -> Result<()> {
|
||||||
debug!("Closing websocket");
|
info!("Closing websocket");
|
||||||
let mut connections = self.connections.lock().await;
|
if let Some(mut connection) = self.connections.lock().await.remove(id) {
|
||||||
let connection = match connections.get_mut(id) {
|
// Wait a maximum of 1 second for the connection to close
|
||||||
None => return Ok(()),
|
if let Err(e) = connection.close().await {
|
||||||
Some(c) => c,
|
warn!("Failed to close websocket connection {e:?}");
|
||||||
};
|
};
|
||||||
connection.close().await?;
|
}
|
||||||
|
|
||||||
|
// Wait at short time for the server to close the connection, then stop
|
||||||
|
// reading.
|
||||||
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
|
if let Some(handle) = self.read_tasks.lock().await.remove(id) {
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user