fix ws connection state (#175)

Co-authored-by: Gregory Schier <gschier1990@gmail.com>
This commit is contained in:
Hao Xiang
2025-03-09 00:03:16 +08:00
committed by GitHub
parent f4d0371060
commit cdce2ac53a
17 changed files with 260 additions and 184 deletions

View File

@@ -36,20 +36,21 @@ use yaak_models::models::{
ModelType, Plugin, Settings, WebsocketRequest, Workspace, WorkspaceMeta,
};
use yaak_models::queries::{
batch_upsert, cancel_pending_grpc_connections, cancel_pending_responses,
create_default_http_response, delete_all_grpc_connections,
delete_all_grpc_connections_for_workspace, delete_all_http_responses_for_request,
delete_all_http_responses_for_workspace, delete_all_websocket_connections_for_workspace,
delete_cookie_jar, delete_environment, delete_folder, delete_grpc_connection,
delete_grpc_request, delete_http_request, delete_http_response, delete_plugin,
delete_workspace, duplicate_folder, duplicate_grpc_request, duplicate_http_request,
ensure_base_environment, generate_model_id, get_base_environment, get_cookie_jar,
get_environment, get_folder, get_grpc_connection, get_grpc_request, get_http_request,
get_http_response, get_key_value_raw, get_or_create_settings, get_or_create_workspace_meta,
get_plugin, get_workspace, get_workspace_export_resources, list_cookie_jars, list_environments,
list_folders, list_grpc_connections_for_workspace, list_grpc_events, list_grpc_requests,
list_http_requests, list_http_responses_for_workspace, list_key_values_raw, list_plugins,
list_workspaces, set_key_value_raw, update_response_if_id, update_settings, upsert_cookie_jar,
batch_upsert, cancel_pending_grpc_connections, cancel_pending_http_responses,
cancel_pending_websocket_connections, create_default_http_response,
delete_all_grpc_connections, delete_all_grpc_connections_for_workspace,
delete_all_http_responses_for_request, delete_all_http_responses_for_workspace,
delete_all_websocket_connections_for_workspace, delete_cookie_jar, delete_environment,
delete_folder, delete_grpc_connection, delete_grpc_request, delete_http_request,
delete_http_response, delete_plugin, delete_workspace, duplicate_folder,
duplicate_grpc_request, duplicate_http_request, ensure_base_environment, generate_model_id,
get_base_environment, get_cookie_jar, get_environment, get_folder, get_grpc_connection,
get_grpc_request, get_http_request, get_http_response, get_key_value_raw,
get_or_create_settings, get_or_create_workspace_meta, get_plugin, get_workspace,
get_workspace_export_resources, list_cookie_jars, list_environments, list_folders,
list_grpc_connections_for_workspace, list_grpc_events, list_grpc_requests, list_http_requests,
list_http_responses_for_workspace, list_key_values_raw, list_plugins, list_workspaces,
set_key_value_raw, update_response_if_id, update_settings, upsert_cookie_jar,
upsert_environment, upsert_folder, upsert_grpc_connection, upsert_grpc_event,
upsert_grpc_request, upsert_http_request, upsert_plugin, upsert_workspace,
upsert_workspace_meta, BatchUpsertResult, UpdateSource,
@@ -367,7 +368,8 @@ async fn cmd_grpc_go<R: Runtime>(
RenderPurpose::Send,
),
)
.await.expect("Failed to render template")
.await
.expect("Failed to render template")
})
});
let d_msg: DynamicMessage = match deserialize_message(msg.as_str(), method_desc)
@@ -1921,8 +1923,9 @@ pub fn run() {
// Cancel pending requests
let h = app_handle.clone();
tauri::async_runtime::block_on(async move {
let _ = cancel_pending_responses(&h).await;
let _ = cancel_pending_http_responses(&h).await;
let _ = cancel_pending_grpc_connections(&h).await;
let _ = cancel_pending_websocket_connections(&h).await;
});
}
RunEvent::WindowEvent {

View File

@@ -64,11 +64,11 @@ export type UpdateSource = "sync" | "window" | "plugin" | "background" | "import
export type WebsocketConnection = { model: "websocket_connection", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, elapsed: number, error: string | null, headers: Array<HttpResponseHeader>, state: WebsocketConnectionState, status: number, url: string, };
export type WebsocketConnectionState = "initialized" | "connected" | "closed";
export type WebsocketConnectionState = "initialized" | "connected" | "closing" | "closed";
export type WebsocketEvent = { model: "websocket_event", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, connectionId: string, isServer: boolean, message: Array<number>, messageType: WebsocketEventType, };
export type WebsocketEventType = "binary" | "close" | "frame" | "ping" | "pong" | "text";
export type WebsocketEventType = "binary" | "close" | "frame" | "open" | "ping" | "pong" | "text";
export type WebsocketMessageType = "text" | "binary";

View File

@@ -549,6 +549,7 @@ impl<'s> TryFrom<&Row<'s>> for HttpRequest {
pub enum WebsocketConnectionState {
Initialized,
Connected,
Closing,
Closed,
}
@@ -714,6 +715,7 @@ pub enum WebsocketEventType {
Binary,
Close,
Frame,
Open,
Ping,
Pong,
Text,

View File

@@ -6,9 +6,9 @@ use crate::models::{
GrpcRequestIden, HttpRequest, HttpRequestIden, HttpResponse, HttpResponseHeader,
HttpResponseIden, HttpResponseState, KeyValue, KeyValueIden, ModelType, Plugin, PluginIden,
PluginKeyValue, PluginKeyValueIden, Settings, SettingsIden, SyncState, SyncStateIden,
WebsocketConnection, WebsocketConnectionIden, WebsocketEvent, WebsocketEventIden,
WebsocketRequest, WebsocketRequestIden, Workspace, WorkspaceIden, WorkspaceMeta,
WorkspaceMetaIden,
WebsocketConnection, WebsocketConnectionIden, WebsocketConnectionState, WebsocketEvent,
WebsocketEventIden, WebsocketRequest, WebsocketRequestIden, Workspace, WorkspaceIden,
WorkspaceMeta, WorkspaceMetaIden,
};
use crate::plugin::SqliteConnection;
use chrono::{NaiveDateTime, Utc};
@@ -2143,6 +2143,21 @@ pub async fn create_http_response<R: Runtime>(
Ok(m)
}
pub async fn cancel_pending_websocket_connections<R: Runtime>(mgr: &impl Manager<R>) -> Result<()> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let closed = serde_json::to_value(&WebsocketConnectionState::Closed)?;
let (sql, params) = Query::update()
.table(WebsocketConnectionIden::Table)
.values([(WebsocketConnectionIden::State, closed.as_str().into())])
.cond_where(Expr::col(WebsocketConnectionIden::State).ne(closed.as_str()))
.build_rusqlite(SqliteQueryBuilder);
let mut stmt = db.prepare(sql.as_str())?;
stmt.execute(&*params.as_params())?;
Ok(())
}
pub async fn cancel_pending_grpc_connections(app: &AppHandle) -> Result<()> {
let dbm = &*app.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -2158,7 +2173,7 @@ pub async fn cancel_pending_grpc_connections(app: &AppHandle) -> Result<()> {
Ok(())
}
pub async fn cancel_pending_responses(app: &AppHandle) -> Result<()> {
pub async fn cancel_pending_http_responses(app: &AppHandle) -> Result<()> {
let dbm = &*app.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();

View File

@@ -75,7 +75,8 @@ impl PluginManager {
// Handle when client plugin runtime disconnects
tauri::async_runtime::spawn(async move {
while let Some(_) = client_disconnect_rx.recv().await {
info!("Plugin runtime client disconnected! TODO: Handle this case");
// Happens when the app is closed
info!("Plugin runtime client disconnected");
}
});

View File

@@ -2,7 +2,6 @@ use crate::error::Error::GenericError;
use crate::error::Result;
use crate::manager::WebsocketManager;
use crate::render::render_request;
use chrono::Utc;
use log::{info, warn};
use std::str::FromStr;
use tauri::http::{HeaderMap, HeaderName};
@@ -147,42 +146,21 @@ pub(crate) async fn close<R: Runtime>(
ws_manager: State<'_, Mutex<WebsocketManager>>,
) -> Result<WebsocketConnection> {
let connection = get_websocket_connection(&window, connection_id).await?;
let request = get_websocket_request(&window, &connection.request_id)
.await?
.ok_or(GenericError("WebSocket Request not found".to_string()))?;
let mut ws_manager = ws_manager.lock().await;
if let Err(e) = ws_manager.send(&connection.id, Message::Close(None)).await {
warn!("Failed to close WebSocket connection: {e:?}");
};
upsert_websocket_event(
let connection = upsert_websocket_connection(
&window,
WebsocketEvent {
connection_id: connection.id.clone(),
request_id: request.id.clone(),
workspace_id: request.workspace_id.clone(),
is_server: false,
message_type: WebsocketEventType::Close,
..Default::default()
&WebsocketConnection {
state: WebsocketConnectionState::Closing,
..connection
},
&UpdateSource::Window,
)
.await
.unwrap();
let connection = upsert_websocket_connection(
&window,
&WebsocketConnection {
state: WebsocketConnectionState::Closed,
elapsed: Utc::now()
.naive_utc()
.signed_duration_since(connection.created_at)
.num_milliseconds() as i32,
..connection.clone()
},
&UpdateSource::Window,
)
.await?;
let mut ws_manager = ws_manager.lock().await;
if let Err(e) = ws_manager.close(&connection.id).await {
warn!("Failed to close WebSocket connection: {e:?}");
};
Ok(connection)
}
@@ -264,42 +242,6 @@ pub(crate) async fn connect<R: Runtime>(
let (receive_tx, mut receive_rx) = mpsc::channel::<Message>(128);
let mut ws_manager = ws_manager.lock().await;
{
let connection_id = connection.id.clone();
let request_id = request.id.to_string();
let workspace_id = request.workspace_id.clone();
let window = window.clone();
tokio::spawn(async move {
while let Some(message) = receive_rx.recv().await {
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection_id.clone(),
request_id: request_id.clone(),
workspace_id: workspace_id.clone(),
is_server: true,
message_type: match message {
Message::Text(_) => WebsocketEventType::Text,
Message::Binary(_) => WebsocketEventType::Binary,
Message::Ping(_) => WebsocketEventType::Ping,
Message::Pong(_) => WebsocketEventType::Pong,
Message::Close(_) => WebsocketEventType::Close,
// Raw frame will never happen during a read
Message::Frame(_) => WebsocketEventType::Frame,
},
message: message.into_data().into(),
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
}
info!("Websocket connection closed");
});
}
let (url, url_parameters) = apply_path_placeholders(&request.url, request.url_parameters);
// Add URL parameters to URL
@@ -331,6 +273,21 @@ pub(crate) async fn connect<R: Runtime>(
}
};
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection.id.clone(),
request_id: request.id.clone(),
workspace_id: connection.workspace_id.clone(),
is_server: false,
message_type: WebsocketEventType::Open,
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
let response_headers = response
.headers()
.into_iter()
@@ -353,5 +310,74 @@ pub(crate) async fn connect<R: Runtime>(
)
.await?;
{
let connection_id = connection.id.clone();
let request_id = request.id.to_string();
let workspace_id = request.workspace_id.clone();
let window = window.clone();
let connection = connection.clone();
let mut has_written_close = false;
tokio::spawn(async move {
while let Some(message) = receive_rx.recv().await {
if let Message::Close(_) = message {
has_written_close = true;
}
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection_id.clone(),
request_id: request_id.clone(),
workspace_id: workspace_id.clone(),
is_server: true,
message_type: match message {
Message::Text(_) => WebsocketEventType::Text,
Message::Binary(_) => WebsocketEventType::Binary,
Message::Ping(_) => WebsocketEventType::Ping,
Message::Pong(_) => WebsocketEventType::Pong,
Message::Close(_) => WebsocketEventType::Close,
// Raw frame will never happen during a read
Message::Frame(_) => WebsocketEventType::Frame,
},
message: message.into_data().into(),
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
}
info!("Websocket connection closed");
if !has_written_close {
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection_id.clone(),
request_id: request_id.clone(),
workspace_id: workspace_id.clone(),
is_server: true,
message_type: WebsocketEventType::Close,
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
}
upsert_websocket_connection(
&window,
&WebsocketConnection {
workspace_id: request.workspace_id.clone(),
request_id: request_id.to_string(),
state: WebsocketConnectionState::Closed,
..connection
},
&UpdateSource::Window,
)
.await
.unwrap();
});
}
Ok(connection)
}

View File

@@ -41,40 +41,4 @@ pub(crate) async fn ws_connect(
)
.await?;
Ok((stream, response))
}
#[cfg(test)]
mod tests {
use crate::connect::ws_connect;
use crate::error::Result;
use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::Message;
#[tokio::test]
async fn test_connection() -> Result<()> {
let (stream, response) = ws_connect("wss://echo.websocket.org/", Default::default()).await?;
assert_eq!(response.status(), 101);
let (mut write, mut read) = stream.split();
let task = tokio::spawn(async move {
while let Some(Ok(message)) = read.next().await {
if message.is_text() && message.to_text().unwrap() == "Hello" {
return message;
}
}
panic!("Didn't receive text message");
});
write.send(Message::Text("Hello".into())).await?;
let task = timeout(Duration::from_secs(3), task);
let message = task.await.unwrap().unwrap();
assert_eq!(message.into_text().unwrap(), "Hello");
Ok(())
}
}
}

View File

@@ -5,7 +5,7 @@ mod manager;
mod render;
use crate::commands::{
connect, close, delete_connection, delete_connections, delete_request, duplicate_request,
close, connect, delete_connection, delete_connections, delete_request, duplicate_request,
list_connections, list_events, list_requests, send, upsert_request,
};
use crate::manager::WebsocketManager;
@@ -31,7 +31,6 @@ pub fn init<R: Runtime>() -> TauriPlugin<R> {
.setup(|app, _api| {
let manager = WebsocketManager::new();
app.manage(Mutex::new(manager));
Ok(())
})
.build()

View File

@@ -2,7 +2,7 @@ use crate::connect::ws_connect;
use crate::error::Result;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use log::debug;
use log::{debug, warn};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpStream;
@@ -32,19 +32,27 @@ impl WebsocketManager {
headers: HeaderMap<HeaderValue>,
receive_tx: mpsc::Sender<Message>,
) -> Result<Response> {
let connections = self.connections.clone();
let connection_id = id.to_string();
let tx = receive_tx.clone();
let (stream, response) = ws_connect(url, headers).await?;
let (write, mut read) = stream.split();
self.connections.lock().await.insert(id.to_string(), write);
let tx = receive_tx.clone();
connections.lock().await.insert(id.to_string(), write);
tauri::async_runtime::spawn(async move {
while let Some(Ok(message)) = read.next().await {
debug!("Received websocket message {message:?}");
if message.is_close() {
return;
while let Some(msg) = read.next().await {
match msg {
Err(e) => {
warn!("Broken websocket connection: {}", e);
break;
}
Ok(message) => tx.send(message).await.unwrap(),
}
tx.send(message).await.unwrap();
}
debug!("Connection {} closed", connection_id);
connections.lock().await.remove(&connection_id);
});
Ok(response)
}
@@ -59,4 +67,15 @@ impl WebsocketManager {
connection.send(msg).await?;
Ok(())
}
pub async fn close(&mut self, id: &str) -> Result<()> {
debug!("Closing websocket");
let mut connections = self.connections.lock().await;
let connection = match connections.get_mut(id) {
None => return Ok(()),
Some(c) => c,
};
connection.close().await?;
Ok(())
}
}