Websocket Support (#159)

This commit is contained in:
Gregory Schier
2025-01-31 09:00:11 -08:00
committed by GitHub
parent d411713502
commit c8be8082c5
122 changed files with 5090 additions and 616 deletions

View File

@@ -0,0 +1,330 @@
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::manager::WebsocketManager;
use crate::render::render_request;
use chrono::Utc;
use log::info;
use std::str::FromStr;
use tauri::http::{HeaderMap, HeaderName};
use tauri::{AppHandle, Manager, Runtime, State, WebviewWindow};
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::Message;
use yaak_models::models::{
HttpResponseHeader, WebsocketConnection, WebsocketConnectionState, WebsocketEvent,
WebsocketEventType, WebsocketRequest,
};
use yaak_models::queries;
use yaak_models::queries::{
get_base_environment, get_cookie_jar, get_environment, get_websocket_connection,
get_websocket_request, upsert_websocket_connection, upsert_websocket_event, UpdateSource,
};
use yaak_plugins::events::{
CallHttpAuthenticationRequest, HttpHeader, RenderPurpose, WindowContext,
};
use yaak_plugins::manager::PluginManager;
use yaak_plugins::template_callback::PluginTemplateCallback;
#[tauri::command]
pub(crate) async fn upsert_request<R: Runtime>(
request: WebsocketRequest,
w: WebviewWindow<R>,
) -> Result<WebsocketRequest> {
Ok(queries::upsert_websocket_request(&w, request, &UpdateSource::Window).await?)
}
#[tauri::command]
pub(crate) async fn delete_request<R: Runtime>(
request_id: &str,
w: WebviewWindow<R>,
) -> Result<WebsocketRequest> {
Ok(queries::delete_websocket_request(&w, request_id, &UpdateSource::Window).await?)
}
#[tauri::command]
pub(crate) async fn delete_connection<R: Runtime>(
connection_id: &str,
w: WebviewWindow<R>,
) -> Result<WebsocketConnection> {
Ok(queries::delete_websocket_connection(&w, connection_id, &UpdateSource::Window).await?)
}
#[tauri::command]
pub(crate) async fn delete_connections<R: Runtime>(
request_id: &str,
w: WebviewWindow<R>,
) -> Result<()> {
Ok(queries::delete_all_websocket_connections(&w, request_id, &UpdateSource::Window).await?)
}
#[tauri::command]
pub(crate) async fn list_events<R: Runtime>(
connection_id: &str,
app_handle: AppHandle<R>,
) -> Result<Vec<WebsocketEvent>> {
Ok(queries::list_websocket_events(&app_handle, connection_id).await?)
}
#[tauri::command]
pub(crate) async fn list_requests<R: Runtime>(
workspace_id: &str,
app_handle: AppHandle<R>,
) -> Result<Vec<WebsocketRequest>> {
Ok(queries::list_websocket_requests(&app_handle, workspace_id).await?)
}
#[tauri::command]
pub(crate) async fn list_connections<R: Runtime>(
workspace_id: &str,
app_handle: AppHandle<R>,
) -> Result<Vec<WebsocketConnection>> {
Ok(queries::list_websocket_connections_for_workspace(&app_handle, workspace_id).await?)
}
#[tauri::command]
pub(crate) async fn send<R: Runtime>(
connection_id: &str,
environment_id: Option<&str>,
window: WebviewWindow<R>,
ws_manager: State<'_, Mutex<WebsocketManager>>,
) -> Result<WebsocketConnection> {
let connection = get_websocket_connection(&window, connection_id).await?;
let unrendered_request = get_websocket_request(&window, &connection.request_id)
.await?
.ok_or(GenericError("WebSocket Request not found".to_string()))?;
let environment = match environment_id {
Some(id) => Some(get_environment(&window, id).await?),
None => None,
};
let base_environment = get_base_environment(&window, &unrendered_request.workspace_id).await?;
let request = render_request(
&unrendered_request,
&base_environment,
environment.as_ref(),
&PluginTemplateCallback::new(
window.app_handle(),
&WindowContext::from_window(&window),
RenderPurpose::Send,
),
)
.await;
let mut ws_manager = ws_manager.lock().await;
ws_manager.send(&connection.id, Message::Text(request.message.clone().into())).await?;
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection.id.clone(),
request_id: request.id.clone(),
workspace_id: connection.workspace_id.clone(),
is_server: false,
message_type: WebsocketEventType::Text,
message: request.message.into(),
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
Ok(connection)
}
#[tauri::command]
pub(crate) async fn close<R: Runtime>(
connection_id: &str,
window: WebviewWindow<R>,
ws_manager: State<'_, Mutex<WebsocketManager>>,
) -> Result<WebsocketConnection> {
let connection = get_websocket_connection(&window, connection_id).await?;
let request = get_websocket_request(&window, &connection.request_id)
.await?
.ok_or(GenericError("WebSocket Request not found".to_string()))?;
let mut ws_manager = ws_manager.lock().await;
ws_manager.send(&connection.id, Message::Close(None)).await?;
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection.id.clone(),
request_id: request.id.clone(),
workspace_id: request.workspace_id.clone(),
is_server: false,
message_type: WebsocketEventType::Close,
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
let connection = upsert_websocket_connection(
&window,
&WebsocketConnection {
state: WebsocketConnectionState::Closed,
elapsed: Utc::now()
.naive_utc()
.signed_duration_since(connection.created_at)
.num_milliseconds() as i32,
..connection.clone()
},
&UpdateSource::Window,
)
.await?;
Ok(connection)
}
#[tauri::command]
pub(crate) async fn connect<R: Runtime>(
request_id: &str,
environment_id: Option<&str>,
cookie_jar_id: Option<&str>,
window: WebviewWindow<R>,
plugin_manager: State<'_, PluginManager>,
ws_manager: State<'_, Mutex<WebsocketManager>>,
) -> Result<WebsocketConnection> {
let unrendered_request = get_websocket_request(&window, request_id)
.await?
.ok_or(GenericError("Failed to find GRPC request".to_string()))?;
let environment = match environment_id {
Some(id) => Some(get_environment(&window, id).await?),
None => None,
};
let base_environment = get_base_environment(&window, &unrendered_request.workspace_id).await?;
let request = render_request(
&unrendered_request,
&base_environment,
environment.as_ref(),
&PluginTemplateCallback::new(
window.app_handle(),
&WindowContext::from_window(&window),
RenderPurpose::Send,
),
)
.await;
let mut headers = HeaderMap::new();
if let Some(auth_name) = request.authentication_type.clone() {
let auth = request.authentication.clone();
let plugin_req = CallHttpAuthenticationRequest {
context_id: format!("{:x}", md5::compute(request_id.to_string())),
values: serde_json::from_value(serde_json::to_value(&auth).unwrap()).unwrap(),
method: "POST".to_string(),
url: request.url.clone(),
headers: request
.headers
.clone()
.into_iter()
.map(|h| HttpHeader {
name: h.name,
value: h.value,
})
.collect(),
};
let plugin_result =
plugin_manager.call_http_authentication(&window, &auth_name, plugin_req).await?;
for header in plugin_result.set_headers {
headers.insert(
HeaderName::from_str(&header.name).unwrap(),
HeaderValue::from_str(&header.value).unwrap(),
);
}
}
// TODO: Handle cookies
let _cookie_jar = match cookie_jar_id {
Some(id) => Some(get_cookie_jar(&window, id).await?),
None => None,
};
let connection = upsert_websocket_connection(
&window,
&WebsocketConnection {
workspace_id: request.workspace_id.clone(),
request_id: request_id.to_string(),
..Default::default()
},
&UpdateSource::Window,
)
.await?;
let (receive_tx, mut receive_rx) = mpsc::channel::<Message>(128);
let mut ws_manager = ws_manager.lock().await;
{
let connection_id = connection.id.clone();
let request_id = request.id.to_string();
let workspace_id = request.workspace_id.clone();
let window = window.clone();
tokio::spawn(async move {
while let Some(message) = receive_rx.recv().await {
upsert_websocket_event(
&window,
WebsocketEvent {
connection_id: connection_id.clone(),
request_id: request_id.clone(),
workspace_id: workspace_id.clone(),
is_server: true,
message_type: match message {
Message::Text(_) => WebsocketEventType::Text,
Message::Binary(_) => WebsocketEventType::Binary,
Message::Ping(_) => WebsocketEventType::Ping,
Message::Pong(_) => WebsocketEventType::Pong,
Message::Close(_) => WebsocketEventType::Close,
// Raw frame will never happen during a read
Message::Frame(_) => WebsocketEventType::Frame,
},
message: message.into_data().into(),
..Default::default()
},
&UpdateSource::Window,
)
.await
.unwrap();
}
info!("Websocket connection closed");
});
}
let response = match ws_manager.connect(&connection.id, &request.url, headers, receive_tx).await
{
Ok(r) => r,
Err(e) => {
return Ok(upsert_websocket_connection(
&window,
&WebsocketConnection {
error: Some(format!("{e:?}")),
..connection
},
&UpdateSource::Window,
)
.await?);
}
};
let response_headers = response
.headers()
.into_iter()
.map(|(name, value)| HttpResponseHeader {
name: name.to_string(),
value: value.to_str().unwrap().to_string(),
})
.collect::<Vec<HttpResponseHeader>>();
let connection = upsert_websocket_connection(
&window,
&WebsocketConnection {
state: WebsocketConnectionState::Connected,
headers: response_headers,
status: response.status().as_u16() as i32,
url: request.url.clone(),
..connection
},
&UpdateSource::Window,
)
.await?;
Ok(connection)
}

View File

@@ -0,0 +1,80 @@
use log::info;
use rustls::crypto::ring;
use rustls::ClientConfig;
use rustls_platform_verifier::BuilderVerifierExt;
use std::sync::Arc;
use tauri::http::HeaderMap;
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::handshake::client::Response;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::{
connect_async_tls_with_config, Connector, MaybeTlsStream, WebSocketStream,
};
pub(crate) async fn ws_connect(
url: &str,
headers: HeaderMap<HeaderValue>,
) -> crate::error::Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response)> {
info!("Connecting to WS {url}");
let arc_crypto_provider = Arc::new(ring::default_provider());
let config = ClientConfig::builder_with_provider(arc_crypto_provider)
.with_safe_default_protocol_versions()
.unwrap()
.with_platform_verifier()
.with_no_client_auth();
let mut req = url.into_client_request()?;
let req_headers = req.headers_mut();
for (name, value) in headers {
if let Some(name) = name {
req_headers.insert(name, value);
}
}
let (stream, response) = connect_async_tls_with_config(
req,
Some(WebSocketConfig::default()),
false,
Some(Connector::Rustls(Arc::new(config))),
)
.await?;
Ok((stream, response))
}
#[cfg(test)]
mod tests {
use crate::connect::ws_connect;
use crate::error::Result;
use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::Message;
#[tokio::test]
async fn test_connection() -> Result<()> {
let (stream, response) = ws_connect("wss://echo.websocket.org/", Default::default()).await?;
assert_eq!(response.status(), 101);
let (mut write, mut read) = stream.split();
let task = tokio::spawn(async move {
while let Some(Ok(message)) = read.next().await {
if message.is_text() && message.to_text().unwrap() == "Hello" {
return message;
}
}
panic!("Didn't receive text message");
});
write.send(Message::Text("Hello".into())).await?;
let task = timeout(Duration::from_secs(3), task);
let message = task.await.unwrap().unwrap();
assert_eq!(message.into_text().unwrap(), "Hello");
Ok(())
}
}

View File

@@ -0,0 +1,29 @@
use serde::{Serialize, Serializer};
use thiserror::Error;
use tokio_tungstenite::tungstenite;
#[derive(Error, Debug)]
pub enum Error {
#[error("WebSocket error: {0}")]
WebSocketErr(#[from] tungstenite::Error),
#[error("Model error: {0}")]
ModelError(#[from] yaak_models::error::Error),
#[error("Plugin error: {0}")]
PluginError(#[from] yaak_plugins::error::Error),
#[error("WebSocket error: {0}")]
GenericError(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,37 @@
mod cmd;
mod connect;
mod error;
mod manager;
mod render;
use crate::cmd::{
close, connect, delete_connection, delete_connections, delete_request, list_connections,
list_events, list_requests, send, upsert_request,
};
use crate::manager::WebsocketManager;
use tauri::plugin::{Builder, TauriPlugin};
use tauri::{generate_handler, Manager, Runtime};
use tokio::sync::Mutex;
pub fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::new("yaak-ws")
.invoke_handler(generate_handler![
close,
connect,
delete_connection,
delete_connections,
delete_request,
list_connections,
list_events,
list_requests,
send,
upsert_request,
])
.setup(|app, _api| {
let manager = WebsocketManager::new();
app.manage(Mutex::new(manager));
Ok(())
})
.build()
}

View File

@@ -0,0 +1,62 @@
use crate::connect::ws_connect;
use crate::error::Result;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use log::debug;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::handshake::client::Response;
use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
#[derive(Clone)]
pub struct WebsocketManager {
connections:
Arc<Mutex<HashMap<String, SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>>,
}
impl WebsocketManager {
pub fn new() -> Self {
WebsocketManager {
connections: Default::default(),
}
}
pub async fn connect(
&mut self,
id: &str,
url: &str,
headers: HeaderMap<HeaderValue>,
receive_tx: mpsc::Sender<Message>,
) -> Result<Response> {
let (stream, response) = ws_connect(url, headers).await?;
let (write, mut read) = stream.split();
self.connections.lock().await.insert(id.to_string(), write);
let tx = receive_tx.clone();
tauri::async_runtime::spawn(async move {
while let Some(Ok(message)) = read.next().await {
debug!("Received websocket message {message:?}");
if message.is_close() {
return;
}
tx.send(message).await.unwrap();
}
});
Ok(response)
}
pub async fn send(&mut self, id: &str, msg: Message) -> Result<()> {
debug!("Send websocket message {msg:?}");
let mut connections = self.connections.lock().await;
let connection = match connections.get_mut(id) {
None => return Ok(()),
Some(c) => c,
};
connection.send(msg).await?;
Ok(())
}
}

View File

@@ -0,0 +1,40 @@
use std::collections::BTreeMap;
use yaak_models::models::{Environment, HttpRequestHeader, WebsocketRequest};
use yaak_models::render::make_vars_hashmap;
use yaak_templates::{parse_and_render, render_json_value_raw, TemplateCallback};
pub async fn render_request<T: TemplateCallback>(
r: &WebsocketRequest,
base_environment: &Environment,
environment: Option<&Environment>,
cb: &T,
) -> WebsocketRequest {
let vars = &make_vars_hashmap(base_environment, environment);
let mut headers = Vec::new();
for p in r.headers.clone() {
headers.push(HttpRequestHeader {
enabled: p.enabled,
name: parse_and_render(&p.name, vars, cb).await,
value: parse_and_render(&p.value, vars, cb).await,
id: p.id,
})
}
let mut authentication = BTreeMap::new();
for (k, v) in r.authentication.clone() {
authentication.insert(k, render_json_value_raw(v, vars, cb).await);
}
let url = parse_and_render(r.url.as_str(), vars, cb).await;
let message = parse_and_render(&r.message.clone(), vars, cb).await;
WebsocketRequest {
url,
headers,
authentication,
message,
..r.to_owned()
}
}