From e27ed9beccf7f3ab447e9a03a70e900ce38e0b31 Mon Sep 17 00:00:00 2001 From: Gregory Schier Date: Fri, 2 Feb 2024 01:10:54 -0800 Subject: [PATCH] bidi hacked! --- src-tauri/Cargo.lock | 1 + src-tauri/Cargo.toml | 1 + src-tauri/grpc/src/manager.rs | 54 ++++++++++- src-tauri/src/main.rs | 99 +++++++++++++++++++++ src-web/components/GrpcConnectionLayout.tsx | 69 +++++++++----- src-web/components/core/Icon.tsx | 1 + src-web/hooks/useGrpc.ts | 24 ++++- 7 files changed, 219 insertions(+), 30 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 4c6f07e1..7d92cec2 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -6247,6 +6247,7 @@ dependencies = [ "tauri-plugin-log", "tauri-plugin-window-state", "tokio", + "tokio-stream", "uuid", "window-shadows", ] diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index da9dfecf..45c74430 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -61,6 +61,7 @@ datetime = "0.5.2" window-shadows = "0.2.2" reqwest_cookie_store = "0.6.0" grpc = { path = "./grpc" } +tokio-stream = "0.1.14" [features] # by default Tauri runs in production mode diff --git a/src-tauri/grpc/src/manager.rs b/src-tauri/grpc/src/manager.rs index 3debec82..611dc450 100644 --- a/src-tauri/grpc/src/manager.rs +++ b/src-tauri/grpc/src/manager.rs @@ -3,12 +3,15 @@ use std::collections::HashMap; use hyper::client::HttpConnector; use hyper::Client; use hyper_rustls::HttpsConnector; -use prost_reflect::{DescriptorPool, DynamicMessage}; +use prost_reflect::DescriptorPool; +pub use prost_reflect::DynamicMessage; use serde_json::Deserializer; use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::{Stream, StreamExt}; use tonic::body::BoxBody; use tonic::transport::Uri; -use tonic::{IntoRequest, Streaming}; +use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; use crate::codec::DynamicCodec; use crate::proto::{fill_pool, method_desc_to_path}; @@ -52,6 +55,39 @@ impl GrpcConnection { Ok(response_json) } + + pub async fn bidi_streaming( + &self, + service: &str, + method: &str, + stream: ReceiverStream, + ) -> Result> { + let service = self.pool.get_service_by_name(service).unwrap(); + let method = &service.methods().find(|m| m.name() == method).unwrap(); + + let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone()); + + let method2 = method.clone(); + let req = stream + .map(move |s| { + let mut deserializer = Deserializer::from_str(&s); + let req_message = DynamicMessage::deserialize(method2.input(), &mut deserializer) + .map_err(|e| e.to_string()) + .unwrap(); + deserializer.end().unwrap(); + req_message + }) + .into_streaming_request(); + let path = method_desc_to_path(method); + let codec = DynamicCodec::new(method.clone()); + client.ready().await.unwrap(); + Ok(client + .streaming(req, path, codec) + .await + .map_err(|s| s.to_string())? + .into_inner()) + } + pub async fn server_streaming( &self, service: &str, @@ -120,6 +156,20 @@ impl GrpcManager { .await } + pub async fn bidi_streaming( + &mut self, + id: &str, + uri: Uri, + service: &str, + method: &str, + stream: ReceiverStream, + ) -> Result> { + println!("Bidi streaming {}", id); + self.connect(id, uri) + .await + .bidi_streaming(service, method, stream) + .await + } pub async fn connect(&mut self, id: &str, uri: Uri) -> GrpcConnection { let (pool, conn) = fill_pool(&uri).await; let connection = GrpcConnection { pool, conn, uri }; diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 651b02fd..966e0e51 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -130,6 +130,104 @@ async fn cmd_grpc_client_streaming( Ok(()) } +#[tauri::command] +async fn cmd_grpc_bidi_streaming( + endpoint: &str, + service: &str, + method: &str, + app_handle: AppHandle, + grpc_handle: State<'_, Mutex>, +) -> Result { + let (in_msg_tx, mut in_msg_rx) = tauri::async_runtime::channel::(16); + let maybe_in_msg_tx = Mutex::new(Some(in_msg_tx.clone())); + let (cancelled_tx, mut cancelled_rx) = tokio::sync::watch::channel(false); + + let uri = safe_uri(endpoint).map_err(|e| e.to_string())?; + let conn_id = generate_id(Some("grpc")); + + let in_msg_stream = tokio_stream::wrappers::ReceiverStream::new(in_msg_rx); + + let mut stream = grpc_handle + .lock() + .await + .bidi_streaming(&conn_id, uri, service, method, in_msg_stream) + .await + .unwrap(); + + #[derive(serde::Deserialize)] + enum GrpcMessage { + Message(String), + Commit, + Cancel, + } + + let cb = { + let cancelled_rx = cancelled_rx.clone(); + + move |ev: tauri::Event| { + if *cancelled_rx.borrow() { + // Stream is cancelled + return; + } + + match serde_json::from_str::(ev.payload().unwrap()) { + Ok(GrpcMessage::Message(msg)) => { + println!("Received message: {}", msg); + in_msg_tx.try_send(msg).unwrap(); + } + Ok(GrpcMessage::Commit) => { + println!("Received commit"); + // TODO: Commit client streaming stream + } + Ok(GrpcMessage::Cancel) => { + println!("Received cancel"); + cancelled_tx.send_replace(true); + } + Err(e) => { + error!("Failed to parse gRPC message: {:?}", e); + } + } + } + }; + let event_handler = app_handle.listen_global("grpc_message_in", cb); + + let app_handle2 = app_handle.clone(); + let grpc_listen = async move { + loop { + match stream.next().await { + Some(Ok(item)) => { + let item = serde_json::to_string_pretty(&item).unwrap(); + app_handle2 + .emit_all("grpc_message", item) + .expect("Failed to emit"); + } + Some(Err(e)) => { + error!("gRPC stream error: {:?}", e); + // TODO: Handle error + } + None => { + info!("gRPC stream closed by sender"); + break; + } + } + } + }; + + tauri::async_runtime::spawn(async move { + tokio::select! { + _ = grpc_listen => { + debug!("gRPC listen finished"); + }, + _ = cancelled_rx.changed() => { + debug!("gRPC connection cancelled"); + }, + } + app_handle.unlisten(event_handler); + }); + + Ok(conn_id) +} + #[tauri::command] async fn cmd_grpc_server_streaming( endpoint: &str, @@ -1119,6 +1217,7 @@ fn main() { cmd_grpc_call_unary, cmd_grpc_client_streaming, cmd_grpc_server_streaming, + cmd_grpc_bidi_streaming, cmd_grpc_reflect, cmd_import_data, cmd_list_cookie_jars, diff --git a/src-web/components/GrpcConnectionLayout.tsx b/src-web/components/GrpcConnectionLayout.tsx index fa09cc1a..ce9111a2 100644 --- a/src-web/components/GrpcConnectionLayout.tsx +++ b/src-web/components/GrpcConnectionLayout.tsx @@ -2,7 +2,7 @@ import useResizeObserver from '@react-hook/resize-observer'; import classNames from 'classnames'; import { format } from 'date-fns'; import type { CSSProperties, FormEvent } from 'react'; -import React, { useRef, useCallback, useEffect, useMemo, useState } from 'react'; +import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useAlert } from '../hooks/useAlert'; import type { GrpcMessage } from '../hooks/useGrpc'; import { useGrpc } from '../hooks/useGrpc'; @@ -15,7 +15,6 @@ import { Icon } from './core/Icon'; import { IconButton } from './core/IconButton'; import { JsonAttributeTree } from './core/JsonAttributeTree'; import { RadioDropdown } from './core/RadioDropdown'; -import { Select } from './core/Select'; import { Separator } from './core/Separator'; import { SplitLayout } from './core/SplitLayout'; import { HStack, VStack } from './core/Stacks'; @@ -55,6 +54,10 @@ export function GrpcConnectionLayout({ style }: Props) { return s.methods.find((m) => m.name === method.value); }, [grpc.schema, method.value, service.value]); + const handleCancel = useCallback(() => { + grpc.cancel.mutateAsync().catch(console.error); + }, [grpc.cancel]); + const handleConnect = useCallback( async (e: FormEvent) => { e.preventDefault(); @@ -92,6 +95,7 @@ export function GrpcConnectionLayout({ style }: Props) { [ activeMethod, alert, + grpc.bidiStreaming, grpc.serverStreaming, grpc.unary, message.value, @@ -195,9 +199,11 @@ export function GrpcConnectionLayout({ style }: Props) { size="sm" title="to-do" hotkeyAction="request.send" - onClick={handleConnect} + onClick={grpc.isStreaming ? handleCancel : handleConnect} icon={ - !activeMethod?.clientStreaming && activeMethod?.serverStreaming + grpc.isStreaming + ? 'x' + : !activeMethod?.clientStreaming && activeMethod?.serverStreaming ? 'arrowDownToDot' : activeMethod?.clientStreaming && !activeMethod?.serverStreaming ? 'arrowUpFromDot' @@ -206,26 +212,29 @@ export function GrpcConnectionLayout({ style }: Props) { : 'sendHorizontal' } /> - { - await grpc.cancel.mutateAsync(); - }} - icon="trash" - /> + {activeMethod?.clientStreaming && ( + grpc.send.mutateAsync({ message: message.value ?? '' })} + icon="sendHorizontal" + /> + )} - + {!service.isLoading && !method.isLoading && ( + + )} )} rightSlot={() => @@ -259,8 +268,20 @@ export function GrpcConnectionLayout({ style }: Props) { )} >
{m.message}
diff --git a/src-web/components/core/Icon.tsx b/src-web/components/core/Icon.tsx index f5bda426..68b380d0 100644 --- a/src-web/components/core/Icon.tsx +++ b/src-web/components/core/Icon.tsx @@ -47,6 +47,7 @@ const icons = { arrowUp: lucide.ArrowUpIcon, arrowBigDownDash: lucide.ArrowBigDownDashIcon, arrowBigUpDash: lucide.ArrowBigUpDashIcon, + info: lucide.InfoIcon, x: lucide.XIcon, empty: (props: HTMLAttributes) => , diff --git a/src-web/hooks/useGrpc.ts b/src-web/hooks/useGrpc.ts index 5e61e281..2099333a 100644 --- a/src-web/hooks/useGrpc.ts +++ b/src-web/hooks/useGrpc.ts @@ -1,7 +1,9 @@ import { useMutation, useQuery } from '@tanstack/react-query'; import { invoke } from '@tauri-apps/api'; +import { message } from '@tauri-apps/api/dialog'; import { emit } from '@tauri-apps/api/event'; import { useState } from 'react'; +import { send } from 'vite'; import { useListenToTauriEvent } from './useListenToTauriEvent'; interface ReflectResponseService { @@ -12,7 +14,7 @@ interface ReflectResponseService { export interface GrpcMessage { message: string; time: Date; - isServer: boolean; + type: 'server' | 'client' | 'info'; } export function useGrpc(url: string | null) { @@ -23,7 +25,7 @@ export function useGrpc(url: string | null) { (event) => { setMessages((prev) => [ ...prev, - { message: event.payload, time: new Date(), isServer: true }, + { message: event.payload, time: new Date(), type: 'server' }, ]); }, [], @@ -50,7 +52,7 @@ export function useGrpc(url: string | null) { mutationFn: async ({ service, method, message }) => { if (url === null) throw new Error('No URL provided'); setMessages([ - { isServer: false, message: JSON.stringify(JSON.parse(message)), time: new Date() }, + { type: 'client', message: JSON.stringify(JSON.parse(message)), time: new Date() }, ]); const id: string = await invoke('cmd_grpc_server_streaming', { endpoint: url, @@ -70,22 +72,34 @@ export function useGrpc(url: string | null) { mutationKey: ['grpc_bidi_streaming', url], mutationFn: async ({ service, method, message }) => { if (url === null) throw new Error('No URL provided'); - setMessages([]); const id: string = await invoke('cmd_grpc_bidi_streaming', { endpoint: url, service, method, message, }); + setMessages([{ type: 'info', message: `Started connection ${id}`, time: new Date() }]); setActiveConnectionId(id); }, }); + const send = useMutation({ + mutationKey: ['grpc_send', url], + mutationFn: async ({ message }: { message: string }) => { + await emit('grpc_message_in', { Message: message }); + setMessages((m) => [...m, { type: 'client', message, time: new Date() }]); + }, + }); + const cancel = useMutation({ mutationKey: ['grpc_cancel', url], mutationFn: async () => { await emit('grpc_message_in', 'Cancel'); setActiveConnectionId(null); + setMessages((m) => [ + ...m, + { type: 'info', message: 'Cancelled by client', time: new Date() }, + ]); }, }); @@ -104,5 +118,7 @@ export function useGrpc(url: string | null) { schema: reflect.data, cancel, messages, + isStreaming: activeConnectionId !== null, + send, }; }