From 67000af7f919221860122ef02c0546aba02841e4 Mon Sep 17 00:00:00 2001 From: Gregory Schier Date: Thu, 1 Feb 2024 20:29:32 -0800 Subject: [PATCH] Better connection management --- src-tauri/Cargo.lock | 41 +++++++++++++++++ src-tauri/grpc/Cargo.toml | 2 + src-tauri/grpc/src/lib.rs | 41 +---------------- src-tauri/grpc/src/manager.rs | 29 ++++-------- src-tauri/grpc/src/proto.rs | 51 +++++++++++++++------ src-tauri/src/main.rs | 45 ++++++++++++++++-- src-web/components/GrpcConnectionLayout.tsx | 8 +++- src-web/hooks/useGrpc.ts | 37 +++++++++++++-- 8 files changed, 173 insertions(+), 81 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index c798866a..68019756 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1679,6 +1679,8 @@ name = "grpc" version = "0.1.0" dependencies = [ "anyhow", + "hyper", + "hyper-rustls", "log", "once_cell", "prost", @@ -1949,6 +1951,22 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http", + "hyper", + "log", + "rustls", + "rustls-native-certs", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-timeout" version = "0.4.1" @@ -3795,11 +3813,24 @@ version = "0.21.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" dependencies = [ + "log", "ring", "rustls-webpki", "sct", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -5081,6 +5112,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" diff --git a/src-tauri/grpc/Cargo.toml b/src-tauri/grpc/Cargo.toml index 88137351..640896ed 100644 --- a/src-tauri/grpc/Cargo.toml +++ b/src-tauri/grpc/Cargo.toml @@ -16,3 +16,5 @@ prost-reflect = { version = "0.12.0", features = ["serde", "derive"] } log = "0.4.20" once_cell = { version = "1.19.0", features = [] } anyhow = "1.0.79" +hyper = { version = "0.14"} +hyper-rustls = { version = "0.24.0", features = ["http2"] } diff --git a/src-tauri/grpc/src/lib.rs b/src-tauri/grpc/src/lib.rs index fa746cff..a452b1c0 100644 --- a/src-tauri/grpc/src/lib.rs +++ b/src-tauri/grpc/src/lib.rs @@ -10,8 +10,8 @@ use crate::proto::{fill_pool, method_desc_to_path}; mod codec; mod json_schema; -mod proto; pub mod manager; +mod proto; pub fn serialize_options() -> SerializeOptions { SerializeOptions::new().skip_default_fields(false) @@ -64,7 +64,7 @@ pub async fn client_streaming( DynamicMessage::deserialize(input_message, &mut deserializer).map_err(|e| e.to_string())?; deserializer.end().unwrap(); - let mut client = tonic::client::Grpc::new(conn); + let mut client = tonic::client::Grpc::with_origin(conn, uri.clone()); println!( "\n---------- SENDING -----------------\n{}", @@ -84,43 +84,6 @@ pub async fn client_streaming( Ok(response_json) } -pub async fn server_streaming( - uri: &Uri, - service: &str, - method: &str, - message_json: &str, -) -> Result>, String> { - let (pool, conn) = fill_pool(uri).await; - - let service = pool.get_service_by_name(service).unwrap(); - let method = &service.methods().find(|m| m.name() == method).unwrap(); - let input_message = method.input(); - - let mut deserializer = Deserializer::from_str(message_json); - let req_message = - DynamicMessage::deserialize(input_message, &mut deserializer).map_err(|e| e.to_string())?; - deserializer.end().unwrap(); - - let mut client = tonic::client::Grpc::new(conn); - - println!( - "\n---------- SENDING -----------------\n{}", - serde_json::to_string_pretty(&req_message).expect("json") - ); - - let req = req_message.into_request(); - let path = method_desc_to_path(method); - let codec = DynamicCodec::new(method.clone()); - client.ready().await.unwrap(); - - let resp = client.server_streaming(req, path, codec).await.unwrap(); - // let response_json = serde_json::to_string_pretty(&resp.into_inner()).expect("json to string"); - // println!("\n---------- RECEIVING ---------------\n{}", response_json,); - - // Ok(response_json) - Ok(resp) -} - pub async fn callable(uri: &Uri) -> Vec { let (pool, _) = fill_pool(uri).await; diff --git a/src-tauri/grpc/src/manager.rs b/src-tauri/grpc/src/manager.rs index 43cc7edd..a3ec03e0 100644 --- a/src-tauri/grpc/src/manager.rs +++ b/src-tauri/grpc/src/manager.rs @@ -1,10 +1,13 @@ use std::collections::HashMap; +use hyper::client::HttpConnector; +use hyper::Client; +use hyper_rustls::HttpsConnector; use prost_reflect::{DescriptorPool, DynamicMessage}; use serde_json::Deserializer; use tokio::sync::mpsc; -use tokio_stream::StreamExt; -use tonic::transport::{Channel, Uri}; +use tonic::body::BoxBody; +use tonic::transport::Uri; use tonic::{IntoRequest, Streaming}; use crate::codec::DynamicCodec; @@ -15,7 +18,8 @@ type Result = std::result::Result; #[derive(Clone)] pub struct GrpcConnection { pool: DescriptorPool, - conn: Channel, + conn: Client, BoxBody>, + pub uri: Uri, } impl GrpcConnection { @@ -29,7 +33,7 @@ impl GrpcConnection { .map_err(|e| e.to_string())?; deserializer.end().unwrap(); - let mut client = tonic::client::Grpc::new(self.conn.clone()); + let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone()); println!( "\n---------- SENDING -----------------\n{}", @@ -63,7 +67,7 @@ impl GrpcConnection { .map_err(|e| e.to_string())?; deserializer.end().unwrap(); - let mut client = tonic::client::Grpc::new(self.conn.clone()); + let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone()); println!( "\n---------- SENDING -----------------\n{}", @@ -115,24 +119,11 @@ impl GrpcManager { .await .server_streaming(service, method, message) .await - - // while let Some(item) = stream.next().await { - // match item { - // Ok(item) => { - // let item = serde_json::to_string_pretty(&item).unwrap(); - // println!("Sending message {}", item); - // self.send.send(item).await.unwrap() - // } - // Err(e) => println!("\terror: {}", e), - // } - // } - - // Ok(()) } pub async fn connect(&mut self, id: &str, uri: Uri) -> GrpcConnection { let (pool, conn) = fill_pool(&uri).await; - let connection = GrpcConnection { pool, conn }; + let connection = GrpcConnection { pool, conn, uri }; self.connections.insert(id.to_string(), connection.clone()); connection } diff --git a/src-tauri/grpc/src/proto.rs b/src-tauri/grpc/src/proto.rs index dbd59a55..e59dfbd9 100644 --- a/src-tauri/grpc/src/proto.rs +++ b/src-tauri/grpc/src/proto.rs @@ -1,27 +1,48 @@ use std::ops::Deref; use std::str::FromStr; + use anyhow::anyhow; +use hyper::client::HttpConnector; +use hyper::Client; +use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use prost::Message; use prost_reflect::{DescriptorPool, MethodDescriptor}; use prost_types::FileDescriptorProto; use tokio_stream::StreamExt; +use tonic::body::BoxBody; use tonic::codegen::http::uri::PathAndQuery; +use tonic::transport::Uri; use tonic::Request; -use tonic::transport::{Channel, Uri}; use tonic_reflection::pb::server_reflection_client::ServerReflectionClient; use tonic_reflection::pb::server_reflection_request::MessageRequest; use tonic_reflection::pb::server_reflection_response::MessageResponse; use tonic_reflection::pb::ServerReflectionRequest; -pub async fn fill_pool(uri: &Uri) -> (DescriptorPool, Channel) { +pub async fn fill_pool( + uri: &Uri, +) -> ( + DescriptorPool, + Client, BoxBody>, +) { let mut pool = DescriptorPool::new(); - let conn = tonic::transport::Endpoint::new(uri.clone()) - .unwrap() - .connect() - .await - .unwrap(); + let connector = HttpsConnectorBuilder::new().with_native_roots(); + let connector = connector.https_or_http().enable_http2().wrap_connector({ + let mut http_connector = HttpConnector::new(); + http_connector.enforce_http(false); + http_connector + }); + let transport = Client::builder() + .pool_max_idle_per_host(0) + .http2_only(true) + .build(connector); - let mut client = ServerReflectionClient::new(conn.clone()); + println!( + "URI uri={} host={:?} authority={:?}", + uri, + uri.host(), + uri.authority() + ); + let mut client = ServerReflectionClient::with_origin(transport.clone(), uri.clone()); let services = list_services(&mut client).await; for service in services { @@ -31,10 +52,12 @@ pub async fn fill_pool(uri: &Uri) -> (DescriptorPool, Channel) { file_descriptor_set_from_service_name(&service, &mut pool, &mut client).await; } - (pool, conn) + (pool, transport) } -async fn list_services(reflect_client: &mut ServerReflectionClient) -> Vec { +async fn list_services( + reflect_client: &mut ServerReflectionClient, BoxBody>>, +) -> Vec { let response = send_reflection_request(reflect_client, MessageRequest::ListServices("".into())).await; @@ -53,13 +76,13 @@ async fn list_services(reflect_client: &mut ServerReflectionClient) -> async fn file_descriptor_set_from_service_name( service_name: &str, pool: &mut DescriptorPool, - client: &mut ServerReflectionClient, + client: &mut ServerReflectionClient, BoxBody>>, ) { let response = send_reflection_request( client, MessageRequest::FileContainingSymbol(service_name.into()), ) - .await; + .await; let file_descriptor_response = match response { MessageResponse::FileDescriptorResponse(resp) => resp, @@ -82,7 +105,7 @@ async fn file_descriptor_set_from_service_name( async fn file_descriptor_set_by_filename( filename: &str, pool: &mut DescriptorPool, - client: &mut ServerReflectionClient, + client: &mut ServerReflectionClient, BoxBody>>, ) { // We already fetched this file if let Some(_) = pool.get_file_by_name(filename) { @@ -104,7 +127,7 @@ async fn file_descriptor_set_by_filename( } async fn send_reflection_request( - client: &mut ServerReflectionClient, + client: &mut ServerReflectionClient, BoxBody>>, message: MessageRequest, ) -> MessageResponse { let reflection_request = ServerReflectionRequest { diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index ec8895cc..ba06b93a 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -17,7 +17,7 @@ use std::str::FromStr; use ::http::uri::InvalidUri; use ::http::Uri; use fern::colors::ColoredLevelConfig; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use log::{debug, error, info, warn}; use rand::random; use serde::Serialize; @@ -95,6 +95,14 @@ async fn cmd_grpc_reflect(endpoint: &str) -> Result, Stri Ok(grpc::callable(&uri).await) } +async fn cmd_grpc_cancel( + id: &str, + grpc_handle: State<'_, Mutex>, +) -> Result<(), String> { + // grpc_handle.lock().await.cancel(id).await.unwrap() + Ok(()) +} + #[tauri::command] async fn cmd_grpc_call_unary( endpoint: &str, @@ -149,17 +157,44 @@ async fn cmd_grpc_server_streaming( .await .unwrap(); - while let Some(item) = stream.next().await { - match item { - Ok(item) => { + loop { + match stream.message().await { + Ok(Some(item)) => { let item = serde_json::to_string_pretty(&item).unwrap(); + println!("GOT MESSAGE {:?}", item); println!("Sending message {}", item); emit_side_effect(&app_handle, "grpc_message", item); } - Err(e) => println!("\terror: {}", e), + Ok(None) => { + // sleep for a bit + println!("NO MESSAGE YET"); + sleep(std::time::Duration::from_millis(100)).await; + } + Err(e) => { + return Err(e.to_string()); + } } } + // while let Some(item) = stream.message() { + // println!("GOT MESSAGE"); + // // if grpc_handle.lock().await.is_cancelled(&conn_id) { + // // break; + // // } + // + // match item { + // Ok(item) => { + // let item = serde_json::to_string_pretty(&item).unwrap(); + // println!("Sending message {}", item); + // emit_side_effect(&app_handle, "grpc_message", item); + // } + // Err(e) => println!("\terror: {}", e), + // } + // // let foo = stream.trailers().await.unwrap(); + // break; + // } + + println!("DONE"); Ok(conn_id) } diff --git a/src-web/components/GrpcConnectionLayout.tsx b/src-web/components/GrpcConnectionLayout.tsx index 2c559c5c..9685ae97 100644 --- a/src-web/components/GrpcConnectionLayout.tsx +++ b/src-web/components/GrpcConnectionLayout.tsx @@ -67,7 +67,13 @@ export function GrpcConnectionLayout({ style }: Props) { body: 'Service or method not selected', }); } - if (activeMethod.serverStreaming && !activeMethod.clientStreaming) { + if (activeMethod.clientStreaming && activeMethod.serverStreaming) { + await grpc.bidiStreaming.mutateAsync({ + service: service.value ?? 'n/a', + method: method.value ?? 'n/a', + message: message.value ?? '', + }); + } else if (activeMethod.serverStreaming && !activeMethod.clientStreaming) { await grpc.serverStreaming.mutateAsync({ service: service.value ?? 'n/a', method: method.value ?? 'n/a', diff --git a/src-web/hooks/useGrpc.ts b/src-web/hooks/useGrpc.ts index 8d001cee..30127eb1 100644 --- a/src-web/hooks/useGrpc.ts +++ b/src-web/hooks/useGrpc.ts @@ -16,6 +16,7 @@ export interface GrpcMessage { export function useGrpc(url: string | null) { const [messages, setMessages] = useState([]); + const [activeConnectionId, setActiveConnectionId] = useState(null); useListenToTauriEvent( 'grpc_message', (event) => { @@ -40,7 +41,7 @@ export function useGrpc(url: string | null) { }); const serverStreaming = useMutation< - string, + void, string, { service: string; method: string; message: string } >({ @@ -50,12 +51,40 @@ export function useGrpc(url: string | null) { setMessages([ { isServer: false, message: JSON.stringify(JSON.parse(message)), time: new Date() }, ]); - return (await invoke('cmd_grpc_server_streaming', { + const id: string = await invoke('cmd_grpc_server_streaming', { endpoint: url, service, method, message, - })) as string; + }); + setActiveConnectionId(id); + }, + }); + + const bidiStreaming = useMutation< + void, + string, + { service: string; method: string; message: string } + >({ + mutationKey: ['grpc_bidi_streaming', url], + mutationFn: async ({ service, method, message }) => { + if (url === null) throw new Error('No URL provided'); + setMessages([]); + const id: string = await invoke('cmd_grpc_bidi_streaming', { + endpoint: url, + service, + method, + message, + }); + setActiveConnectionId(id); + }, + }); + + const cancel = useMutation({ + mutationKey: ['grpc_cancel', url], + mutationFn: async () => { + await invoke('cmd_grpc_cancel', { id: activeConnectionId }); + setActiveConnectionId(null); }, }); @@ -71,7 +100,9 @@ export function useGrpc(url: string | null) { return { unary, serverStreaming, + bidiStreaming, schema: reflect.data, + cancel, messages, }; }