Websockets for plugin runtime communication (#156)

This commit is contained in:
Gregory Schier
2025-01-20 10:55:53 -08:00
committed by GitHub
parent 095aaa5e92
commit b698a56549
54 changed files with 841 additions and 1185 deletions

View File

@@ -1,4 +1,4 @@
use crate::server::plugin_runtime::EventStreamEvent;
use crate::events::InternalEvent;
use thiserror::Error;
use tokio::io;
use tokio::sync::mpsc::error::SendError;
@@ -14,18 +14,15 @@ pub enum Error {
#[error("Tauri shell error: {0}")]
TauriShellErr(#[from] tauri_plugin_shell::Error),
#[error("Grpc transport error: {0}")]
GrpcTransportErr(#[from] tonic::transport::Error),
#[error("Grpc send error: {0}")]
GrpcSendErr(#[from] SendError<tonic::Result<EventStreamEvent>>),
GrpcSendErr(#[from] SendError<InternalEvent>),
#[error("JSON error: {0}")]
JsonErr(#[from] serde_json::Error),
#[error("Plugin not found: {0}")]
PluginNotFoundErr(String),
#[error("Auth plugin not found: {0}")]
AuthPluginNotFound(String),

View File

@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tauri::{Runtime, WebviewWindow};
use ts_rs::TS;
@@ -11,9 +12,24 @@ use yaak_models::models::{Environment, Folder, GrpcRequest, HttpRequest, HttpRes
pub struct InternalEvent {
pub id: String,
pub plugin_ref_id: String,
pub plugin_name: String,
pub reply_id: Option<String>,
pub payload: InternalEventPayload,
pub window_context: WindowContext,
pub payload: InternalEventPayload,
}
/// Special type used to deserialize everything but the payload. This is so we can
/// catch any plugin-related type errors, since payload is sent by the plugin author
/// and all other fields are sent by Yaak first-party code.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct InternalEventRawPayload {
pub id: String,
pub plugin_ref_id: String,
pub plugin_name: String,
pub reply_id: Option<String>,
pub window_context: WindowContext,
pub payload: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
@@ -95,6 +111,8 @@ pub enum InternalEventPayload {
/// Returned when a plugin doesn't get run, just so the server
/// has something to listen for
EmptyResponse(EmptyPayload),
ErrorResponse(ErrorResponse),
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
@@ -102,6 +120,13 @@ pub enum InternalEventPayload {
#[ts(export, type = "{}", export_to = "events.ts")]
pub struct EmptyPayload {}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default)]
#[ts(export, export_to = "events.ts")]
pub struct ErrorResponse {
pub error: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "events.ts")]
@@ -116,7 +141,6 @@ pub struct BootRequest {
pub struct BootResponse {
pub name: String,
pub version: String,
pub capabilities: Vec<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]

View File

@@ -9,7 +9,7 @@ pub mod events;
pub mod manager;
mod nodejs;
pub mod plugin_handle;
mod server;
mod server_ws;
mod util;
pub fn init<R: Runtime>() -> TauriPlugin<R> {

View File

@@ -12,8 +12,7 @@ use crate::events::{
};
use crate::nodejs::start_nodejs_plugin_runtime;
use crate::plugin_handle::PluginHandle;
use crate::server::plugin_runtime::plugin_runtime_server::PluginRuntimeServer;
use crate::server::PluginRuntimeServerImpl;
use crate::server_ws::PluginRuntimeServerWebsocket;
use log::{info, warn};
use std::collections::HashMap;
use std::env;
@@ -25,8 +24,6 @@ use tauri::{AppHandle, Manager, Runtime, WebviewWindow};
use tokio::fs::read_dir;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex};
use tonic::codegen::tokio_stream;
use tonic::transport::Server;
use yaak_models::queries::{generate_id, list_plugins};
#[derive(Clone)]
@@ -34,7 +31,7 @@ pub struct PluginManager {
subscribers: Arc<Mutex<HashMap<String, mpsc::Sender<InternalEvent>>>>,
plugins: Arc<Mutex<Vec<PluginHandle>>>,
kill_tx: tokio::sync::watch::Sender<bool>,
grpc_service: Arc<PluginRuntimeServerImpl>,
ws_service: Arc<PluginRuntimeServerWebsocket>,
}
#[derive(Clone)]
@@ -50,13 +47,13 @@ impl PluginManager {
let (client_disconnect_tx, mut client_disconnect_rx) = mpsc::channel(128);
let (client_connect_tx, mut client_connect_rx) = tokio::sync::watch::channel(false);
let grpc_service =
PluginRuntimeServerImpl::new(events_tx, client_disconnect_tx, client_connect_tx);
let ws_service =
PluginRuntimeServerWebsocket::new(events_tx, client_disconnect_tx, client_connect_tx);
let plugin_manager = PluginManager {
plugins: Arc::new(Mutex::new(Vec::new())),
subscribers: Arc::new(Mutex::new(HashMap::new())),
grpc_service: Arc::new(grpc_service.clone()),
ws_service: Arc::new(ws_service.clone()),
kill_tx: kill_server_tx,
};
@@ -79,14 +76,9 @@ impl PluginManager {
}
});
info!("Starting plugin server");
let svc = PluginRuntimeServer::new(grpc_service.to_owned())
.max_encoding_message_size(usize::MAX)
.max_decoding_message_size(usize::MAX);
let listen_addr = match option_env!("PORT") {
None => "localhost:0".to_string(),
let listen_addr = match option_env!("YAAK_PLUGIN_SERVER_PORT") {
Some(port) => format!("localhost:{port}"),
None => "localhost:0".to_string(),
};
let listener = tauri::async_runtime::block_on(async move {
TcpListener::bind(listen_addr).await.expect("Failed to bind TCP listener")
@@ -114,14 +106,9 @@ impl PluginManager {
};
// 1. Spawn server in the background
info!("Starting gRPC plugin server on {addr}");
info!("Starting plugin server on {addr}");
tauri::async_runtime::spawn(async move {
Server::builder()
.timeout(Duration::from_secs(10))
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.expect("grpc plugin runtime server failed to start");
ws_service.listen(listener).await;
});
// 2. Start Node.js runtime and initialize plugins
@@ -203,7 +190,7 @@ impl PluginManager {
watch: bool,
) -> Result<()> {
info!("Adding plugin by dir {dir}");
let maybe_tx = self.grpc_service.app_to_plugin_events_tx.lock().await;
let maybe_tx = self.ws_service.app_to_plugin_events_tx.lock().await;
let tx = match &*maybe_tx {
None => return Err(ClientNotInitializedErr),
Some(tx) => tx,
@@ -357,21 +344,23 @@ impl PluginManager {
.collect::<Vec<InternalEvent>>();
// 2. Spawn thread to subscribe to incoming events and check reply ids
let send_events_fut = {
let sub_events_fut = {
let events_to_send = events_to_send.clone();
tokio::spawn(async move {
let mut found_events = Vec::new();
while let Some(event) = rx.recv().await {
if events_to_send
let matched_sent_event = events_to_send
.iter()
.find(|e| Some(e.id.to_owned()) == event.reply_id)
.is_some()
{
.is_some();
if matched_sent_event {
found_events.push(event.clone());
};
if found_events.len() == events_to_send.len() {
let found_them_all = found_events.len() == events_to_send.len();
if found_them_all{
break;
}
}
@@ -390,7 +379,7 @@ impl PluginManager {
}
// 4. Join on the spawned thread
let events = send_events_fut.await.expect("Thread didn't succeed");
let events = sub_events_fut.await.expect("Thread didn't succeed");
// 5. Unsubscribe
self.unsubscribe(rx_id.as_str()).await;
@@ -502,7 +491,7 @@ impl PluginManager {
// Clone for mutability
let mut req = req.clone();
// Fill in default values
// Fill in default values
for arg in authentication.config.clone() {
let base = match arg {
FormInput::Text(a) => a.base,

View File

@@ -32,8 +32,8 @@ pub async fn start_nodejs_plugin_runtime<R: Runtime>(
let cmd = app
.shell()
.sidecar("yaaknode")?
.env("PORT", addr.port().to_string())
.args(&[plugin_runtime_main]);
.env("YAAK_PLUGIN_RUNTIME_PORT", addr.port().to_string())
.args(&[&plugin_runtime_main]);
let (mut child_rx, child) = cmd.spawn()?;
info!("Spawned plugin runtime");

View File

@@ -1,8 +1,8 @@
use crate::error::Result;
use crate::events::{BootResponse, InternalEvent, InternalEventPayload, WindowContext};
use crate::server::plugin_runtime::EventStreamEvent;
use crate::util::gen_id;
use log::info;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
@@ -10,12 +10,12 @@ use tokio::sync::{mpsc, Mutex};
pub struct PluginHandle {
pub ref_id: String,
pub dir: String,
pub(crate) to_plugin_tx: Arc<Mutex<mpsc::Sender<tonic::Result<EventStreamEvent>>>>,
pub(crate) to_plugin_tx: Arc<Mutex<mpsc::Sender<InternalEvent>>>,
pub(crate) boot_resp: Arc<Mutex<BootResponse>>,
}
impl PluginHandle {
pub fn new(dir: &str, tx: mpsc::Sender<tonic::Result<EventStreamEvent>>) -> Self {
pub fn new(dir: &str, tx: mpsc::Sender<InternalEvent>) -> Self {
let ref_id = gen_id();
PluginHandle {
@@ -46,9 +46,11 @@ impl PluginHandle {
payload: &InternalEventPayload,
reply_id: Option<String>,
) -> InternalEvent {
let dir = Path::new(&self.dir);
InternalEvent {
id: gen_id(),
plugin_ref_id: self.ref_id.clone(),
plugin_name: dir.file_name().unwrap().to_str().unwrap().to_string(),
reply_id,
payload: payload.clone(),
window_context,
@@ -63,13 +65,7 @@ impl PluginHandle {
}
pub(crate) async fn send(&self, event: &InternalEvent) -> Result<()> {
self.to_plugin_tx
.lock()
.await
.send(Ok(EventStreamEvent {
event: serde_json::to_string(event)?,
}))
.await?;
self.to_plugin_tx.lock().await.send(event.to_owned()).await?;
Ok(())
}

View File

@@ -1,99 +0,0 @@
use log::warn;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
use tonic::codegen::tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
use crate::events::InternalEvent;
use crate::server::plugin_runtime::plugin_runtime_server::PluginRuntime;
use plugin_runtime::EventStreamEvent;
pub mod plugin_runtime {
tonic::include_proto!("yaak.plugins.runtime");
}
type ResponseStream = Pin<Box<dyn Stream<Item = Result<EventStreamEvent, Status>> + Send>>;
#[derive(Clone)]
pub(crate) struct PluginRuntimeServerImpl {
pub(crate) app_to_plugin_events_tx:
Arc<Mutex<Option<mpsc::Sender<tonic::Result<EventStreamEvent>>>>>,
client_disconnect_tx: mpsc::Sender<bool>,
client_connect_tx: tokio::sync::watch::Sender<bool>,
plugin_to_app_events_tx: mpsc::Sender<InternalEvent>,
}
impl PluginRuntimeServerImpl {
pub fn new(
events_tx: mpsc::Sender<InternalEvent>,
disconnect_tx: mpsc::Sender<bool>,
connect_tx: tokio::sync::watch::Sender<bool>,
) -> Self {
PluginRuntimeServerImpl {
app_to_plugin_events_tx: Arc::new(Mutex::new(None)),
client_disconnect_tx: disconnect_tx,
client_connect_tx: connect_tx,
plugin_to_app_events_tx: events_tx,
}
}
}
#[tonic::async_trait]
impl PluginRuntime for PluginRuntimeServerImpl {
type EventStreamStream = ResponseStream;
async fn event_stream(
&self,
req: Request<Streaming<EventStreamEvent>>,
) -> tonic::Result<Response<Self::EventStreamStream>> {
let mut in_stream = req.into_inner();
let (to_plugin_tx, to_plugin_rx) = mpsc::channel::<tonic::Result<EventStreamEvent>>(128);
let mut app_to_plugin_events_tx = self.app_to_plugin_events_tx.lock().await;
*app_to_plugin_events_tx = Some(to_plugin_tx);
let plugin_to_app_events_tx = self.plugin_to_app_events_tx.clone();
let client_disconnect_tx = self.client_disconnect_tx.clone();
self.client_connect_tx.send(true).expect("Failed to send client ready event");
tokio::spawn(async move {
while let Some(result) = in_stream.next().await {
// Received event from plugin runtime
match result {
Ok(v) => {
let event: InternalEvent = match serde_json::from_str(v.event.as_str()) {
Ok(pe) => pe,
Err(e) => {
warn!("Failed to deserialize event {e:?} -> {}", v.event);
continue;
}
};
// Send event to subscribers
// Emit event to the channel for server to handle
if let Err(e) = plugin_to_app_events_tx.try_send(event.clone()) {
warn!("Failed to send to channel. Receiver probably isn't listening: {:?}", e);
}
}
Err(err) => {
// TODO: Better error handling
warn!("gRPC server error {err}");
break;
}
};
}
if let Err(e) = client_disconnect_tx.send(true).await {
warn!("Failed to send killed event {:?}", e);
}
});
// Write the same data that was received
let out_stream = ReceiverStream::new(to_plugin_rx);
Ok(Response::new(Box::pin(out_stream) as Self::EventStreamStream))
}
}

View File

@@ -0,0 +1,132 @@
use crate::events::{ErrorResponse, InternalEvent, InternalEventPayload, InternalEventRawPayload};
use futures_util::{SinkExt, StreamExt};
use log::{error, info, warn};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::accept_async_with_config;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;
#[derive(Clone)]
pub(crate) struct PluginRuntimeServerWebsocket {
pub(crate) app_to_plugin_events_tx: Arc<Mutex<Option<mpsc::Sender<InternalEvent>>>>,
client_disconnect_tx: mpsc::Sender<bool>,
client_connect_tx: tokio::sync::watch::Sender<bool>,
plugin_to_app_events_tx: mpsc::Sender<InternalEvent>,
}
impl PluginRuntimeServerWebsocket {
pub fn new(
events_tx: mpsc::Sender<InternalEvent>,
disconnect_tx: mpsc::Sender<bool>,
connect_tx: tokio::sync::watch::Sender<bool>,
) -> Self {
PluginRuntimeServerWebsocket {
app_to_plugin_events_tx: Arc::new(Mutex::new(None)),
client_disconnect_tx: disconnect_tx,
client_connect_tx: connect_tx,
plugin_to_app_events_tx: events_tx,
}
}
pub async fn listen(&self, listener: TcpListener) {
while let Ok((stream, _)) = listener.accept().await {
self.accept_connection(stream).await;
}
}
async fn accept_connection(&self, stream: TcpStream) {
let (to_plugin_tx, mut to_plugin_rx) = mpsc::channel::<InternalEvent>(128);
let mut app_to_plugin_events_tx = self.app_to_plugin_events_tx.lock().await;
*app_to_plugin_events_tx = Some(to_plugin_tx);
let plugin_to_app_events_tx = self.plugin_to_app_events_tx.clone();
let client_disconnect_tx = self.client_disconnect_tx.clone();
let client_connect_tx = self.client_connect_tx.clone();
let addr = stream.peer_addr().expect("connected streams should have a peer address");
let conf = WebSocketConfig::default();
let ws_stream = accept_async_with_config(stream, Some(conf))
.await
.expect("Error during the websocket handshake occurred");
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
tauri::async_runtime::spawn(async move {
client_connect_tx.send(true).expect("Failed to send client ready event");
info!("New plugin runtime websocket connection: {}", addr);
loop {
tokio::select! {
msg = ws_receiver.next() => {
let msg = match msg {
Some(Ok(msg)) => msg,
Some(Err(e)) => {
warn!("Websocket error {e:?}");
continue;
}
None => break,
};
// Skip non-text messages
if !msg.is_text() {
return;
}
let event = match serde_json::from_str::<InternalEventRawPayload>(&msg.into_text().unwrap()) {
Ok(e) => e,
Err(e) => {
error!("Failed to decode plugin event {e:?}");
continue;
}
};
// Parse everything but the payload so we can catch errors on that, specifically
let payload = serde_json::from_value::<InternalEventPayload>(event.payload)
.unwrap_or_else(|e| {
InternalEventPayload::ErrorResponse(ErrorResponse {
error: format!("Plugin error from {}: {e:?}", event.plugin_name),
})
});
let event = InternalEvent{
id: event.id,
payload,
plugin_ref_id: event.plugin_ref_id,
plugin_name: event.plugin_name,
window_context: event.window_context,
reply_id: event.reply_id,
};
// Send event to subscribers
// Emit event to the channel for server to handle
if let Err(e) = plugin_to_app_events_tx.try_send(event) {
warn!("Failed to send to channel. Receiver probably isn't listening: {:?}", e);
}
}
event_for_plugin = to_plugin_rx.recv() => {
match event_for_plugin {
None => {
error!("Plugin runtime client WS channel closed");
return;
},
Some(event) => {
let event_bytes = serde_json::to_string(&event).unwrap();
let msg = Message::text(event_bytes);
ws_sender.send(msg).await.unwrap();
}
}
}
}
}
if let Err(e) = client_disconnect_tx.send(true).await {
warn!("Failed to send killed event {:?}", e);
}
});
}
}