bidi hacked!

This commit is contained in:
Gregory Schier
2024-02-02 01:10:54 -08:00
parent c83d904cf0
commit 50866abda4
7 changed files with 219 additions and 30 deletions

1
src-tauri/Cargo.lock generated
View File

@@ -6247,6 +6247,7 @@ dependencies = [
"tauri-plugin-log",
"tauri-plugin-window-state",
"tokio",
"tokio-stream",
"uuid",
"window-shadows",
]

View File

@@ -61,6 +61,7 @@ datetime = "0.5.2"
window-shadows = "0.2.2"
reqwest_cookie_store = "0.6.0"
grpc = { path = "./grpc" }
tokio-stream = "0.1.14"
[features]
# by default Tauri runs in production mode

View File

@@ -3,12 +3,15 @@ use std::collections::HashMap;
use hyper::client::HttpConnector;
use hyper::Client;
use hyper_rustls::HttpsConnector;
use prost_reflect::{DescriptorPool, DynamicMessage};
use prost_reflect::DescriptorPool;
pub use prost_reflect::DynamicMessage;
use serde_json::Deserializer;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tonic::body::BoxBody;
use tonic::transport::Uri;
use tonic::{IntoRequest, Streaming};
use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
use crate::codec::DynamicCodec;
use crate::proto::{fill_pool, method_desc_to_path};
@@ -52,6 +55,39 @@ impl GrpcConnection {
Ok(response_json)
}
pub async fn bidi_streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<String>,
) -> Result<Streaming<DynamicMessage>> {
let service = self.pool.get_service_by_name(service).unwrap();
let method = &service.methods().find(|m| m.name() == method).unwrap();
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let method2 = method.clone();
let req = stream
.map(move |s| {
let mut deserializer = Deserializer::from_str(&s);
let req_message = DynamicMessage::deserialize(method2.input(), &mut deserializer)
.map_err(|e| e.to_string())
.unwrap();
deserializer.end().unwrap();
req_message
})
.into_streaming_request();
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.unwrap();
Ok(client
.streaming(req, path, codec)
.await
.map_err(|s| s.to_string())?
.into_inner())
}
pub async fn server_streaming(
&self,
service: &str,
@@ -120,6 +156,20 @@ impl GrpcManager {
.await
}
pub async fn bidi_streaming(
&mut self,
id: &str,
uri: Uri,
service: &str,
method: &str,
stream: ReceiverStream<String>,
) -> Result<Streaming<DynamicMessage>> {
println!("Bidi streaming {}", id);
self.connect(id, uri)
.await
.bidi_streaming(service, method, stream)
.await
}
pub async fn connect(&mut self, id: &str, uri: Uri) -> GrpcConnection {
let (pool, conn) = fill_pool(&uri).await;
let connection = GrpcConnection { pool, conn, uri };

View File

@@ -130,6 +130,104 @@ async fn cmd_grpc_client_streaming(
Ok(())
}
#[tauri::command]
async fn cmd_grpc_bidi_streaming(
endpoint: &str,
service: &str,
method: &str,
app_handle: AppHandle<Wry>,
grpc_handle: State<'_, Mutex<GrpcManager>>,
) -> Result<String, String> {
let (in_msg_tx, mut in_msg_rx) = tauri::async_runtime::channel::<String>(16);
let maybe_in_msg_tx = Mutex::new(Some(in_msg_tx.clone()));
let (cancelled_tx, mut cancelled_rx) = tokio::sync::watch::channel(false);
let uri = safe_uri(endpoint).map_err(|e| e.to_string())?;
let conn_id = generate_id(Some("grpc"));
let in_msg_stream = tokio_stream::wrappers::ReceiverStream::new(in_msg_rx);
let mut stream = grpc_handle
.lock()
.await
.bidi_streaming(&conn_id, uri, service, method, in_msg_stream)
.await
.unwrap();
#[derive(serde::Deserialize)]
enum GrpcMessage {
Message(String),
Commit,
Cancel,
}
let cb = {
let cancelled_rx = cancelled_rx.clone();
move |ev: tauri::Event| {
if *cancelled_rx.borrow() {
// Stream is cancelled
return;
}
match serde_json::from_str::<GrpcMessage>(ev.payload().unwrap()) {
Ok(GrpcMessage::Message(msg)) => {
println!("Received message: {}", msg);
in_msg_tx.try_send(msg).unwrap();
}
Ok(GrpcMessage::Commit) => {
println!("Received commit");
// TODO: Commit client streaming stream
}
Ok(GrpcMessage::Cancel) => {
println!("Received cancel");
cancelled_tx.send_replace(true);
}
Err(e) => {
error!("Failed to parse gRPC message: {:?}", e);
}
}
}
};
let event_handler = app_handle.listen_global("grpc_message_in", cb);
let app_handle2 = app_handle.clone();
let grpc_listen = async move {
loop {
match stream.next().await {
Some(Ok(item)) => {
let item = serde_json::to_string_pretty(&item).unwrap();
app_handle2
.emit_all("grpc_message", item)
.expect("Failed to emit");
}
Some(Err(e)) => {
error!("gRPC stream error: {:?}", e);
// TODO: Handle error
}
None => {
info!("gRPC stream closed by sender");
break;
}
}
}
};
tauri::async_runtime::spawn(async move {
tokio::select! {
_ = grpc_listen => {
debug!("gRPC listen finished");
},
_ = cancelled_rx.changed() => {
debug!("gRPC connection cancelled");
},
}
app_handle.unlisten(event_handler);
});
Ok(conn_id)
}
#[tauri::command]
async fn cmd_grpc_server_streaming(
endpoint: &str,
@@ -1119,6 +1217,7 @@ fn main() {
cmd_grpc_call_unary,
cmd_grpc_client_streaming,
cmd_grpc_server_streaming,
cmd_grpc_bidi_streaming,
cmd_grpc_reflect,
cmd_import_data,
cmd_list_cookie_jars,

View File

@@ -2,7 +2,7 @@ import useResizeObserver from '@react-hook/resize-observer';
import classNames from 'classnames';
import { format } from 'date-fns';
import type { CSSProperties, FormEvent } from 'react';
import React, { useRef, useCallback, useEffect, useMemo, useState } from 'react';
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useAlert } from '../hooks/useAlert';
import type { GrpcMessage } from '../hooks/useGrpc';
import { useGrpc } from '../hooks/useGrpc';
@@ -15,7 +15,6 @@ import { Icon } from './core/Icon';
import { IconButton } from './core/IconButton';
import { JsonAttributeTree } from './core/JsonAttributeTree';
import { RadioDropdown } from './core/RadioDropdown';
import { Select } from './core/Select';
import { Separator } from './core/Separator';
import { SplitLayout } from './core/SplitLayout';
import { HStack, VStack } from './core/Stacks';
@@ -55,6 +54,10 @@ export function GrpcConnectionLayout({ style }: Props) {
return s.methods.find((m) => m.name === method.value);
}, [grpc.schema, method.value, service.value]);
const handleCancel = useCallback(() => {
grpc.cancel.mutateAsync().catch(console.error);
}, [grpc.cancel]);
const handleConnect = useCallback(
async (e: FormEvent) => {
e.preventDefault();
@@ -92,6 +95,7 @@ export function GrpcConnectionLayout({ style }: Props) {
[
activeMethod,
alert,
grpc.bidiStreaming,
grpc.serverStreaming,
grpc.unary,
message.value,
@@ -195,9 +199,11 @@ export function GrpcConnectionLayout({ style }: Props) {
size="sm"
title="to-do"
hotkeyAction="request.send"
onClick={handleConnect}
onClick={grpc.isStreaming ? handleCancel : handleConnect}
icon={
!activeMethod?.clientStreaming && activeMethod?.serverStreaming
grpc.isStreaming
? 'x'
: !activeMethod?.clientStreaming && activeMethod?.serverStreaming
? 'arrowDownToDot'
: activeMethod?.clientStreaming && !activeMethod?.serverStreaming
? 'arrowUpFromDot'
@@ -206,26 +212,29 @@ export function GrpcConnectionLayout({ style }: Props) {
: 'sendHorizontal'
}
/>
<IconButton
className="border border-highlight"
size="sm"
title="to-do"
onClick={async () => {
await grpc.cancel.mutateAsync();
}}
icon="trash"
/>
{activeMethod?.clientStreaming && (
<IconButton
className="border border-highlight"
size="sm"
title="to-do"
hotkeyAction="request.send"
onClick={() => grpc.send.mutateAsync({ message: message.value ?? '' })}
icon="sendHorizontal"
/>
)}
</HStack>
</div>
<GrpcEditor
forceUpdateKey={[service, method].join('::')}
url={url.value ?? ''}
defaultValue={message.value}
onChange={message.set}
service={service.value ?? null}
method={method.value ?? null}
className="bg-gray-50"
/>
{!service.isLoading && !method.isLoading && (
<GrpcEditor
forceUpdateKey={[service, method].join('::')}
url={url.value ?? ''}
defaultValue={message.value}
onChange={message.set}
service={service.value ?? null}
method={method.value ?? null}
className="bg-gray-50"
/>
)}
</VStack>
)}
rightSlot={() =>
@@ -259,8 +268,20 @@ export function GrpcConnectionLayout({ style }: Props) {
)}
>
<Icon
className={m.isServer ? 'text-blue-600' : 'text-green-600'}
icon={m.isServer ? 'arrowBigDownDash' : 'arrowBigUpDash'}
className={
m.type === 'server'
? 'text-blue-600'
: m.type === 'client'
? 'text-green-600'
: 'text-gray-600'
}
icon={
m.type === 'server'
? 'arrowBigDownDash'
: m.type === 'client'
? 'arrowBigUpDash'
: 'info'
}
/>
<div className="w-full truncate text-gray-800 text-xs">{m.message}</div>
<div className="text-gray-600 text-2xs" title={m.time.toISOString()}>

View File

@@ -47,6 +47,7 @@ const icons = {
arrowUp: lucide.ArrowUpIcon,
arrowBigDownDash: lucide.ArrowBigDownDashIcon,
arrowBigUpDash: lucide.ArrowBigUpDashIcon,
info: lucide.InfoIcon,
x: lucide.XIcon,
empty: (props: HTMLAttributes<HTMLSpanElement>) => <span {...props} />,

View File

@@ -1,7 +1,9 @@
import { useMutation, useQuery } from '@tanstack/react-query';
import { invoke } from '@tauri-apps/api';
import { message } from '@tauri-apps/api/dialog';
import { emit } from '@tauri-apps/api/event';
import { useState } from 'react';
import { send } from 'vite';
import { useListenToTauriEvent } from './useListenToTauriEvent';
interface ReflectResponseService {
@@ -12,7 +14,7 @@ interface ReflectResponseService {
export interface GrpcMessage {
message: string;
time: Date;
isServer: boolean;
type: 'server' | 'client' | 'info';
}
export function useGrpc(url: string | null) {
@@ -23,7 +25,7 @@ export function useGrpc(url: string | null) {
(event) => {
setMessages((prev) => [
...prev,
{ message: event.payload, time: new Date(), isServer: true },
{ message: event.payload, time: new Date(), type: 'server' },
]);
},
[],
@@ -50,7 +52,7 @@ export function useGrpc(url: string | null) {
mutationFn: async ({ service, method, message }) => {
if (url === null) throw new Error('No URL provided');
setMessages([
{ isServer: false, message: JSON.stringify(JSON.parse(message)), time: new Date() },
{ type: 'client', message: JSON.stringify(JSON.parse(message)), time: new Date() },
]);
const id: string = await invoke('cmd_grpc_server_streaming', {
endpoint: url,
@@ -70,22 +72,34 @@ export function useGrpc(url: string | null) {
mutationKey: ['grpc_bidi_streaming', url],
mutationFn: async ({ service, method, message }) => {
if (url === null) throw new Error('No URL provided');
setMessages([]);
const id: string = await invoke('cmd_grpc_bidi_streaming', {
endpoint: url,
service,
method,
message,
});
setMessages([{ type: 'info', message: `Started connection ${id}`, time: new Date() }]);
setActiveConnectionId(id);
},
});
const send = useMutation({
mutationKey: ['grpc_send', url],
mutationFn: async ({ message }: { message: string }) => {
await emit('grpc_message_in', { Message: message });
setMessages((m) => [...m, { type: 'client', message, time: new Date() }]);
},
});
const cancel = useMutation({
mutationKey: ['grpc_cancel', url],
mutationFn: async () => {
await emit('grpc_message_in', 'Cancel');
setActiveConnectionId(null);
setMessages((m) => [
...m,
{ type: 'info', message: 'Cancelled by client', time: new Date() },
]);
},
});
@@ -104,5 +118,7 @@ export function useGrpc(url: string | null) {
schema: reflect.data,
cancel,
messages,
isStreaming: activeConnectionId !== null,
send,
};
}