diff --git a/src-tauri/grpc/src/lib.rs b/src-tauri/grpc/src/lib.rs index a452b1c0..cc318486 100644 --- a/src-tauri/grpc/src/lib.rs +++ b/src-tauri/grpc/src/lib.rs @@ -1,12 +1,10 @@ -use prost_reflect::{DynamicMessage, SerializeOptions}; +use prost_reflect::SerializeOptions; use serde::{Deserialize, Serialize}; -use serde_json::Deserializer; use tokio_stream::Stream; use tonic::transport::Uri; -use tonic::{IntoRequest, Response, Streaming}; +use tonic::IntoRequest; -use crate::codec::DynamicCodec; -use crate::proto::{fill_pool, method_desc_to_path}; +use crate::proto::fill_pool; mod codec; mod json_schema; @@ -33,57 +31,6 @@ pub struct MethodDefinition { pub server_streaming: bool, } -struct ClientStream {} - -impl Stream for ClientStream { - type Item = DynamicMessage; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - println!("poll_next"); - todo!() - } -} - -pub async fn client_streaming( - uri: &Uri, - service: &str, - method: &str, - message_json: &str, -) -> Result { - let (pool, conn) = fill_pool(uri).await; - - let service = pool.get_service_by_name(service).unwrap(); - let method = &service.methods().find(|m| m.name() == method).unwrap(); - let input_message = method.input(); - - let mut deserializer = Deserializer::from_str(message_json); - let req_message = - DynamicMessage::deserialize(input_message, &mut deserializer).map_err(|e| e.to_string())?; - deserializer.end().unwrap(); - - let mut client = tonic::client::Grpc::with_origin(conn, uri.clone()); - - println!( - "\n---------- SENDING -----------------\n{}", - serde_json::to_string_pretty(&req_message).expect("json") - ); - - let req = tonic::Request::new(ClientStream {}); - - let path = method_desc_to_path(method); - let codec = DynamicCodec::new(method.clone()); - client.ready().await.unwrap(); - - let resp = client.client_streaming(req, path, codec).await.unwrap(); - let response_json = serde_json::to_string_pretty(&resp.into_inner()).expect("json to string"); - println!("\n---------- RECEIVING ---------------\n{}", response_json,); - - Ok(response_json) -} - pub async fn callable(uri: &Uri) -> Vec { let (pool, _) = fill_pool(uri).await; diff --git a/src-tauri/grpc/src/manager.rs b/src-tauri/grpc/src/manager.rs index 4e91543f..14fd610e 100644 --- a/src-tauri/grpc/src/manager.rs +++ b/src-tauri/grpc/src/manager.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use hyper::client::connect::Connect; use hyper::client::HttpConnector; use hyper::Client; use hyper_rustls::HttpsConnector; @@ -8,7 +9,7 @@ pub use prost_reflect::DynamicMessage; use serde_json::Deserializer; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use tokio_stream::{Stream, StreamExt}; +use tokio_stream::StreamExt; use tonic::body::BoxBody; use tonic::transport::Uri; use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; @@ -88,6 +89,40 @@ impl GrpcConnection { .into_inner()) } + pub async fn client_streaming( + &self, + service: &str, + method: &str, + stream: ReceiverStream, + ) -> Result { + let service = self.pool.get_service_by_name(service).unwrap(); + let method = &service.methods().find(|m| m.name() == method).unwrap(); + let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone()); + + let req = { + let method = method.clone(); + stream + .map(move |s| { + let mut deserializer = Deserializer::from_str(&s); + let req_message = + DynamicMessage::deserialize(method.input(), &mut deserializer) + .map_err(|e| e.to_string()) + .unwrap(); + deserializer.end().unwrap(); + req_message + }) + .into_streaming_request() + }; + let path = method_desc_to_path(method); + let codec = DynamicCodec::new(method.clone()); + client.ready().await.unwrap(); + Ok(client + .client_streaming(req, path, codec) + .await + .map_err(|s| s.to_string())? + .into_inner()) + } + pub async fn server_streaming( &self, service: &str, @@ -156,6 +191,20 @@ impl GrpcManager { .await } + pub async fn client_streaming( + &mut self, + id: &str, + uri: Uri, + service: &str, + method: &str, + stream: ReceiverStream, + ) -> Result { + self.connect(id, uri) + .await + .client_streaming(service, method, stream) + .await + } + pub async fn streaming( &mut self, id: &str, @@ -169,6 +218,7 @@ impl GrpcManager { .streaming(service, method, stream) .await } + pub async fn connect(&mut self, id: &str, uri: Uri) -> GrpcConnection { let (pool, conn) = fill_pool(&uri).await; let connection = GrpcConnection { pool, conn, uri }; diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 253cf592..96f69812 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -177,19 +177,217 @@ async fn cmd_grpc_call_unary( #[tauri::command] async fn cmd_grpc_client_streaming( - endpoint: &str, - service: &str, - method: &str, - message: &str, - // app_handle: AppHandle, - // db_state: State<'_, Mutex>>, -) -> Result<(), String> { - let service = service.to_string(); - let method = method.to_string(); - let message = message.to_string(); - let uri = safe_uri(endpoint).map_err(|e| e.to_string())?; - tokio::spawn(async move { grpc::client_streaming(&uri, &service, &method, &message).await }); - Ok(()) + request_id: &str, + grpc_handle: State<'_, Mutex>, + app_handle: AppHandle, + db_state: State<'_, Mutex>>, +) -> Result { + println!("CLIENT STREAMING"); + let db = &*db_state.lock().await; + let req = get_grpc_request(db, request_id) + .await + .map_err(|e| e.to_string())?; + let conn = { + let req = req.clone(); + upsert_grpc_connection( + db, + &GrpcConnection { + workspace_id: req.workspace_id, + request_id: req.id, + ..Default::default() + }, + ) + .await + .map_err(|e| e.to_string())? + }; + println!("CREATED CONN: {}", conn.clone().id); + emit_side_effect(app_handle.clone(), "created_model", conn.clone()); + + { + let conn = conn.clone(); + let req = req.clone(); + let db = db.clone(); + emit_side_effect( + app_handle.clone(), + "created_model", + upsert_grpc_message( + &db, + &GrpcMessage { + message: "Initiating connection".to_string(), + workspace_id: req.workspace_id, + request_id: req.id, + connection_id: conn.id, + is_info: true, + ..Default::default() + }, + ) + .await + .expect("Failed to upsert message"), + ); + }; + + let (in_msg_tx, in_msg_rx) = tauri::async_runtime::channel::(16); + let maybe_in_msg_tx = std::sync::Mutex::new(Some(in_msg_tx.clone())); + let (cancelled_tx, mut cancelled_rx) = tokio::sync::watch::channel(false); + + let uri = safe_uri(&req.url).map_err(|e| e.to_string())?; + + let in_msg_stream = tokio_stream::wrappers::ReceiverStream::new(in_msg_rx); + + let (service, method) = { + let req = req.clone(); + match (req.service, req.method) { + (Some(service), Some(method)) => (service, method), + _ => return Err("Service and method are required".to_string()), + } + }; + + #[derive(serde::Deserialize)] + enum IncomingMsg { + Message(String), + Commit, + Cancel, + } + + let cb = { + let cancelled_rx = cancelled_rx.clone(); + let app_handle = app_handle.clone(); + let conn = conn.clone(); + let req = req.clone(); + + move |ev: tauri::Event| { + if *cancelled_rx.borrow() { + // Stream is cancelled + return; + } + + let mut maybe_in_msg_tx = maybe_in_msg_tx + .lock() + .expect("previous holder not to panic"); + let in_msg_tx = if let Some(in_msg_tx) = maybe_in_msg_tx.as_ref() { + in_msg_tx + } else { + // This would mean that the stream is already committed because + // we have already dropped the sending half + return; + }; + + match serde_json::from_str::(ev.payload().unwrap()) { + Ok(IncomingMsg::Message(msg)) => { + in_msg_tx.try_send(msg.clone()).unwrap(); + let app_handle = app_handle.clone(); + let req = req.clone(); + let conn = conn.clone(); + tauri::async_runtime::spawn(async move { + let db_state = app_handle.state::>>(); + let db = &*db_state.lock().await; + emit_side_effect( + app_handle.clone(), + "created_model", + upsert_grpc_message( + &db, + &GrpcMessage { + message: msg, + workspace_id: req.workspace_id, + request_id: req.id, + connection_id: conn.id, + ..Default::default() + }, + ) + .await + .expect("Failed to upsert message"), + ); + }); + } + Ok(IncomingMsg::Commit) => { + println!("Received commit"); + maybe_in_msg_tx.take(); + // in_msg_stream.close(); + } + Ok(IncomingMsg::Cancel) => { + println!("Received cancel"); + cancelled_tx.send_replace(true); + } + Err(e) => { + error!("Failed to parse gRPC message: {:?}", e); + } + } + } + }; + let event_handler = + app_handle.listen_global(format!("grpc_client_msg_{}", conn.id).as_str(), cb); + + let grpc_listen = { + let app_handle = app_handle.clone(); + let conn = conn.clone(); + let req = req.clone(); + async move { + let grpc_handle = app_handle.state::>(); + let db_state = app_handle.state::>>(); + println!("STARTING CLIENT STREAM"); + let msg = grpc_handle + .lock() + .await + .client_streaming(&conn.id, uri, &service, &method, in_msg_stream) + .await + .unwrap(); + let db = &*db_state.lock().await; + emit_side_effect( + app_handle.clone(), + "created_model", + upsert_grpc_message( + db, + &GrpcMessage { + message: msg.to_string(), + workspace_id: req.workspace_id, + request_id: req.id, + connection_id: conn.id, + is_server: true, + ..Default::default() + }, + ) + .await + .expect("Failed to upsert message"), + ); + } + }; + + println!("ENDED CLIENT STREAM"); + { + let conn = conn.clone(); + tauri::async_runtime::spawn(async move { + tokio::select! { + _ = grpc_listen => { + debug!("gRPC listen finished"); + }, + _ = cancelled_rx.changed() => { + debug!("gRPC connection cancelled"); + let db_state = app_handle.state::>>(); + let db = &*db_state.lock().await; + emit_side_effect( + app_handle.clone(), + "created_model", + upsert_grpc_message( + &db, + &GrpcMessage { + message: "Connection cancelled".to_string(), + workspace_id: req.workspace_id, + request_id: req.id, + connection_id: conn.id, + is_info: true, + ..Default::default() + }, + ) + .await + .expect("Failed to upsert message"), + ); + }, + } + app_handle.unlisten(event_handler); + }); + }; + + Ok(conn) } #[tauri::command] @@ -371,17 +569,39 @@ async fn cmd_grpc_streaming( } }; - tauri::async_runtime::spawn(async move { - tokio::select! { - _ = grpc_listen => { - debug!("gRPC listen finished"); - }, - _ = cancelled_rx.changed() => { - debug!("gRPC connection cancelled"); - }, - } - app_handle.unlisten(event_handler); - }); + { + let conn = conn.clone(); + tauri::async_runtime::spawn(async move { + tokio::select! { + _ = grpc_listen => { + debug!("gRPC listen finished"); + }, + _ = cancelled_rx.changed() => { + debug!("gRPC connection cancelled"); + let db_state = app_handle.state::>>(); + let db = &*db_state.lock().await; + emit_side_effect( + app_handle.clone(), + "created_model", + upsert_grpc_message( + &db, + &GrpcMessage { + message: "Connection cancelled".to_string(), + workspace_id: req.workspace_id, + request_id: req.id, + connection_id: conn.id, + is_info: true, + ..Default::default() + }, + ) + .await + .expect("Failed to upsert message"), + ); + }, + } + app_handle.unlisten(event_handler); + }); + }; Ok(conn.id) } diff --git a/src-web/components/GrpcConnectionLayout.tsx b/src-web/components/GrpcConnectionLayout.tsx index 199a1672..5d9f6664 100644 --- a/src-web/components/GrpcConnectionLayout.tsx +++ b/src-web/components/GrpcConnectionLayout.tsx @@ -48,6 +48,10 @@ export function GrpcConnectionLayout({ style }: Props) { grpc.cancel.mutateAsync().catch(console.error); }, [grpc.cancel]); + const handleCommit = useCallback(() => { + grpc.commit.mutateAsync().catch(console.error); + }, [grpc.commit]); + const handleConnect = useCallback( async (e: FormEvent) => { e.preventDefault(); @@ -62,13 +66,23 @@ export function GrpcConnectionLayout({ style }: Props) { } if (activeMethod.clientStreaming && activeMethod.serverStreaming) { await grpc.streaming.mutateAsync(activeRequest.id); - } else if (activeMethod.serverStreaming && !activeMethod.clientStreaming) { + } else if (!activeMethod.clientStreaming && activeMethod.serverStreaming) { await grpc.serverStreaming.mutateAsync(activeRequest.id); + } else if (activeMethod.clientStreaming && !activeMethod.serverStreaming) { + await grpc.clientStreaming.mutateAsync(activeRequest.id); } else { await grpc.unary.mutateAsync(activeRequest.id); } }, - [activeMethod, activeRequest, alert, grpc.streaming, grpc.serverStreaming, grpc.unary], + [ + activeMethod, + activeRequest, + alert, + grpc.streaming, + grpc.serverStreaming, + grpc.clientStreaming, + grpc.unary, + ], ); useEffect(() => { @@ -212,6 +226,17 @@ export function GrpcConnectionLayout({ style }: Props) { icon="sendHorizontal" /> )} + {activeMethod?.clientStreaming && + !activeMethod.serverStreaming && + grpc.isStreaming && ( + + )}
- + {activeMessage.isInfo ? ( + {activeMessage.message} + ) : ( + + )}
) diff --git a/src-web/hooks/useGrpc.ts b/src-web/hooks/useGrpc.ts index 22cb4937..32132ec2 100644 --- a/src-web/hooks/useGrpc.ts +++ b/src-web/hooks/useGrpc.ts @@ -34,6 +34,17 @@ export function useGrpc(url: string | null, requestId: string | null) { }, }); + const clientStreaming = useMutation({ + mutationKey: ['grpc_client_streaming', url], + mutationFn: async (requestId) => { + if (url === null) throw new Error('No URL provided'); + await messages.set([]); + const c = (await invoke('cmd_grpc_client_streaming', { requestId })) as GrpcConnection; + console.log('GOT CONNECTION', c); + setActiveConnectionId(c.id); + }, + }); + const serverStreaming = useMutation({ mutationKey: ['grpc_server_streaming', url], mutationFn: async (requestId) => { @@ -77,6 +88,14 @@ export function useGrpc(url: string | null, requestId: string | null) { }, }); + const commit = useMutation({ + mutationKey: ['grpc_commit', url], + mutationFn: async () => { + setActiveConnectionId(null); + await emit(`grpc_client_msg_${activeConnectionId}`, 'Commit'); + }, + }); + const reflect = useQuery({ queryKey: ['grpc_reflect', url ?? ''], queryFn: async () => { @@ -87,10 +106,12 @@ export function useGrpc(url: string | null, requestId: string | null) { return { unary, + clientStreaming, serverStreaming, streaming, services: reflect.data, cancel, + commit, isStreaming: activeConnectionId !== null, send, };