Plugin runtime v2 (#62)

This commit is contained in:
Gregory Schier
2024-08-08 21:30:59 -07:00
committed by GitHub
parent f967820f12
commit 063e6cf00c
64 changed files with 1539 additions and 705 deletions

3
src-tauri/Cargo.lock generated
View File

@@ -7632,9 +7632,12 @@ dependencies = [
"serde_json",
"tauri",
"tauri-plugin-shell",
"thiserror",
"tokio",
"tonic 0.12.1",
"tonic-build",
"ts-rs",
"yaak_models",
]
[[package]]

View File

@@ -29,7 +29,6 @@ openssl-sys = { version = "0.9", features = ["vendored"] } # For Ubuntu installa
grpc = { path = "./grpc" }
templates = { path = "./templates" }
yaak_plugin_runtime = { path = "yaak_plugin_runtime" }
yaak_models = { path = "yaak_models" }
anyhow = "1.0.86"
base64 = "0.22.0"
chrono = { version = "0.4.31", features = ["serde"] }
@@ -56,10 +55,12 @@ tauri-plugin-updater = "2.0.0-rc.0"
tauri-plugin-window-state = "2.0.0-rc.0"
tokio = { version = "1.36.0", features = ["sync"] }
tokio-stream = "0.1.15"
yaak_models = {workspace = true}
uuid = "1.7.0"
thiserror = "1.0.61"
mime_guess = "2.0.5"
[workspace.dependencies]
yaak_models = { path = "yaak_models" }
tauri = { version = "2.0.0-rc.0", features = ["devtools", "protocol-asset"] }
tauri-plugin-shell = "2.0.0-rc.0"

View File

@@ -35,9 +35,9 @@ use ::grpc::{deserialize_message, serialize_message, Code, ServiceDefinition};
use yaak_plugin_runtime::manager::PluginManager;
use crate::analytics::{AnalyticsAction, AnalyticsResource};
use crate::export_resources::{get_workspace_export_resources, WorkspaceExportResources};
use crate::grpc::metadata_to_map;
use crate::http_request::send_http_request;
use crate::export_resources::{get_workspace_export_resources, ImportResult, WorkspaceExportResources};
use crate::notifications::YaakNotifier;
use crate::render::{render_request, variables_from_environment};
use crate::updates::{UpdateMode, YaakUpdater};
@@ -61,9 +61,9 @@ use yaak_models::queries::{
};
mod analytics;
mod export_resources;
mod grpc;
mod http_request;
mod export_resources;
mod notifications;
mod render;
#[cfg(target_os = "macos")]
@@ -102,13 +102,13 @@ struct AppMetaData {
async fn cmd_metadata(app_handle: AppHandle) -> Result<AppMetaData, ()> {
let app_data_dir = app_handle.path().app_data_dir().unwrap();
let app_log_dir = app_handle.path().app_log_dir().unwrap();
return Ok(AppMetaData {
Ok(AppMetaData {
is_dev: is_dev(),
version: app_handle.package_info().version.to_string(),
name: app_handle.package_info().name.to_string(),
app_data_dir: app_data_dir.to_string_lossy().to_string(),
app_log_dir: app_log_dir.to_string_lossy().to_string(),
});
})
}
#[tauri::command]
@@ -720,7 +720,8 @@ async fn cmd_filter_response(
response_id: &str,
plugin_manager: State<'_, Mutex<PluginManager>>,
filter: &str,
) -> Result<String, String> {
) -> Result<Vec<Value>, String> {
println!("FILTERING? {filter}");
let response = get_http_response(&w, response_id)
.await
.expect("Failed to get response");
@@ -743,9 +744,10 @@ async fn cmd_filter_response(
plugin_manager
.lock()
.await
.run_response_filter(filter, &body, &content_type)
.run_filter(filter, &body, &content_type)
.await
.map(|r| r.data)
.map(|r| r.items)
.map_err(|e| e.to_string())
}
#[tauri::command]
@@ -753,29 +755,16 @@ async fn cmd_import_data(
w: WebviewWindow,
plugin_manager: State<'_, Mutex<PluginManager>>,
file_path: &str,
_workspace_id: &str,
) -> Result<WorkspaceExportResources, String> {
let file =
read_to_string(file_path).unwrap_or_else(|_| panic!("Unable to read file {}", file_path));
let file_contents = file.as_str();
let import_response = plugin_manager
let (import_result, plugin_name) = plugin_manager
.lock()
.await
.run_import(file_contents)
.await?;
let import_result: ImportResult =
serde_json::from_str(import_response.data.as_str()).map_err(|e| e.to_string())?;
// TODO: Track the plugin that ran, maybe return the run info in the plugin response?
let plugin_name = import_response.info.unwrap_or_default().plugin;
info!("Imported data using {}", plugin_name);
analytics::track_event(
&w.app_handle(),
AnalyticsResource::App,
AnalyticsAction::Import,
Some(json!({ "plugin": plugin_name })),
)
.await;
.await
.map_err(|e| e.to_string())?;
let mut imported_resources = WorkspaceExportResources::default();
let mut id_map: HashMap<String, String> = HashMap::new();
@@ -806,7 +795,9 @@ async fn cmd_import_data(
}
}
for mut v in import_result.resources.workspaces {
let resources = import_result.resources;
for mut v in resources.workspaces {
v.id = maybe_gen_id(v.id.as_str(), ModelType::TypeWorkspace, &mut id_map);
let x = upsert_workspace(&w, v).await.map_err(|e| e.to_string())?;
imported_resources.workspaces.push(x.clone());
@@ -816,7 +807,7 @@ async fn cmd_import_data(
imported_resources.workspaces.len()
);
for mut v in import_result.resources.environments {
for mut v in resources.environments {
v.id = maybe_gen_id(v.id.as_str(), ModelType::TypeEnvironment, &mut id_map);
v.workspace_id = maybe_gen_id(
v.workspace_id.as_str(),
@@ -831,7 +822,7 @@ async fn cmd_import_data(
imported_resources.environments.len()
);
for mut v in import_result.resources.folders {
for mut v in resources.folders {
v.id = maybe_gen_id(v.id.as_str(), ModelType::TypeFolder, &mut id_map);
v.workspace_id = maybe_gen_id(
v.workspace_id.as_str(),
@@ -844,7 +835,7 @@ async fn cmd_import_data(
}
info!("Imported {} folders", imported_resources.folders.len());
for mut v in import_result.resources.http_requests {
for mut v in resources.http_requests {
v.id = maybe_gen_id(v.id.as_str(), ModelType::TypeHttpRequest, &mut id_map);
v.workspace_id = maybe_gen_id(
v.workspace_id.as_str(),
@@ -862,7 +853,7 @@ async fn cmd_import_data(
imported_resources.http_requests.len()
);
for mut v in import_result.resources.grpc_requests {
for mut v in resources.grpc_requests {
v.id = maybe_gen_id(v.id.as_str(), ModelType::TypeGrpcRequest, &mut id_map);
v.workspace_id = maybe_gen_id(
v.workspace_id.as_str(),
@@ -880,6 +871,14 @@ async fn cmd_import_data(
imported_resources.grpc_requests.len()
);
analytics::track_event(
&w.app_handle(),
AnalyticsResource::App,
AnalyticsAction::Import,
Some(json!({ "plugin": plugin_name })),
)
.await;
Ok(imported_resources)
}
@@ -901,14 +900,14 @@ async fn cmd_request_to_curl(
.await
.map_err(|e| e.to_string())?;
let rendered = render_request(&request, &workspace, environment.as_ref());
let request_json = serde_json::to_string(&rendered).map_err(|e| e.to_string())?;
let import_response = plugin_manager
.lock()
.await
.run_export_curl(request_json.as_str())
.await?;
Ok(import_response.data)
.run_export_curl(&rendered)
.await
.map_err(|e| e.to_string())?;
Ok(import_response.content)
}
#[tauri::command]
@@ -916,10 +915,23 @@ async fn cmd_curl_to_request(
command: &str,
plugin_manager: State<'_, Mutex<PluginManager>>,
workspace_id: &str,
w: WebviewWindow,
) -> Result<HttpRequest, String> {
let import_response = plugin_manager.lock().await.run_import(command).await?;
let import_result: ImportResult =
serde_json::from_str(import_response.data.as_str()).map_err(|e| e.to_string())?;
let (import_result, plugin_name) = plugin_manager
.lock()
.await
.run_import(command)
.await
.map_err(|e| e.to_string())?;
analytics::track_event(
&w.app_handle(),
AnalyticsResource::App,
AnalyticsAction::Import,
Some(json!({ "plugin": plugin_name })),
)
.await;
import_result
.resources
.http_requests
@@ -946,6 +958,7 @@ async fn cmd_export_data(
.write(true)
.open(export_path)
.expect("Unable to create file");
serde_json::to_writer_pretty(&f, &export_data)
.map_err(|e| e.to_string())
.expect("Failed to write");
@@ -1590,6 +1603,7 @@ pub fn run() {
.level_for("cookie_store", log::LevelFilter::Info)
.level_for("h2", log::LevelFilter::Info)
.level_for("hyper", log::LevelFilter::Info)
.level_for("hyper_util", log::LevelFilter::Info)
.level_for("hyper_rustls", log::LevelFilter::Info)
.level_for("reqwest", log::LevelFilter::Info)
.level_for("sqlx", log::LevelFilter::Warn)
@@ -1615,8 +1629,8 @@ pub fn run() {
.plugin(tauri_plugin_dialog::init())
.plugin(tauri_plugin_os::init())
.plugin(tauri_plugin_fs::init())
.plugin(yaak_models::Builder::default().build())
.plugin(yaak_plugin_runtime::init());
.plugin(yaak_models::plugin::Builder::default().build())
.plugin(yaak_plugin_runtime::plugin::init());
#[cfg(target_os = "macos")]
{

View File

@@ -0,0 +1,13 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("SQL error")]
SqlError(#[from] rusqlite::Error),
#[error("JSON error")]
JsonError(#[from] serde_json::Error),
#[error("unknown error")]
Unknown,
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,57 +1,5 @@
use std::env::current_dir;
use std::fs::create_dir_all;
use r2d2;
use r2d2_sqlite;
use log::info;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use serde::Deserialize;
use tauri::async_runtime::Mutex;
use tauri::plugin::TauriPlugin;
use tauri::{is_dev, plugin, Manager, Runtime};
pub mod models;
pub mod queries;
mod error;
pub struct SqliteConnection(Mutex<Pool<SqliteConnectionManager>>);
#[derive(Default, Deserialize)]
pub struct PluginConfig {
// Nothing yet (will be configurable in tauri.conf.json
}
/// Tauri SQL plugin builder.
#[derive(Default)]
pub struct Builder {
// Nothing Yet
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
pub fn build<R: Runtime>(&self) -> TauriPlugin<R, Option<PluginConfig>> {
plugin::Builder::<R, Option<PluginConfig>>::new("yaak_models")
.setup(|app, _api| {
let app_path = match is_dev() {
true => current_dir().unwrap(),
false => app.path().app_data_dir().unwrap(),
};
create_dir_all(app_path.clone()).expect("Problem creating App directory!");
let db_file_path = app_path.join("db.sqlite");
info!("Opening SQLite DB at {db_file_path:?}");
let manager = SqliteConnectionManager::file(db_file_path);
let pool = Pool::new(manager).unwrap();
app.manage(SqliteConnection(Mutex::new(pool)));
Ok(())
})
.build()
}
}
pub mod plugin;

View File

@@ -8,7 +8,7 @@ use ts_rs::TS;
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../src-web/lib/gen/")]
#[ts(export)]
pub struct Settings {
pub id: String,
#[ts(type = "\"settings\"")]
@@ -72,7 +72,7 @@ impl<'s> TryFrom<&Row<'s>> for Settings {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct Workspace {
pub id: String,
#[ts(type = "\"workspace\"")]
@@ -140,7 +140,7 @@ impl Workspace {
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export, export_to = "../../../src-web/lib/gen/")]
#[ts(export)]
enum CookieDomain {
HostOnly(String),
Suffix(String),
@@ -149,14 +149,14 @@ enum CookieDomain {
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export, export_to = "../../../src-web/lib/gen/")]
#[ts(export)]
enum CookieExpires {
AtUtc(String),
SessionEnd,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[ts(export, export_to = "../../../src-web/lib/gen/")]
#[ts(export)]
pub struct Cookie {
raw_cookie: String,
domain: CookieDomain,
@@ -166,7 +166,7 @@ pub struct Cookie {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../src-web/lib/gen/")]
#[ts(export)]
pub struct CookieJar {
pub id: String,
#[ts(type = "\"cookie_jar\"")]
@@ -210,7 +210,7 @@ impl<'s> TryFrom<&Row<'s>> for CookieJar {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct Environment {
pub id: String,
pub workspace_id: String,
@@ -254,7 +254,7 @@ impl<'s> TryFrom<&Row<'s>> for Environment {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct EnvironmentVariable {
#[serde(default = "default_true")]
#[ts(optional, as = "Option<bool>")]
@@ -265,7 +265,7 @@ pub struct EnvironmentVariable {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct Folder {
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
@@ -311,7 +311,7 @@ impl<'s> TryFrom<&Row<'s>> for Folder {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct HttpRequestHeader {
#[serde(default = "default_true")]
#[ts(optional, as = "Option<bool>")]
@@ -322,7 +322,7 @@ pub struct HttpRequestHeader {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct HttpUrlParameter {
#[serde(default = "default_true")]
#[ts(optional, as = "Option<bool>")]
@@ -333,7 +333,7 @@ pub struct HttpUrlParameter {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct HttpRequest {
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
@@ -348,7 +348,7 @@ pub struct HttpRequest {
pub url_parameters: Vec<HttpUrlParameter>,
#[serde(default = "default_http_request_method")]
pub method: String,
#[ts(type = "Record<string, any>")]
#[ts(type = "Record<string, any>")]
pub body: HashMap<String, Value>,
pub body_type: Option<String>,
#[ts(type = "Record<string, any>")]
@@ -410,7 +410,7 @@ impl<'s> TryFrom<&Row<'s>> for HttpRequest {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct HttpResponseHeader {
pub name: String,
pub value: String,
@@ -418,7 +418,7 @@ pub struct HttpResponseHeader {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct HttpResponse {
pub id: String,
#[ts(type = "\"http_response\"")]
@@ -501,7 +501,7 @@ impl HttpResponse {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct GrpcMetadataEntry {
#[serde(default = "default_true")]
#[ts(optional, as = "Option<bool>")]
@@ -512,7 +512,7 @@ pub struct GrpcMetadataEntry {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct GrpcRequest {
pub id: String,
#[ts(type = "\"grpc_request\"")]
@@ -582,7 +582,7 @@ impl<'s> TryFrom<&Row<'s>> for GrpcRequest {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct GrpcConnection {
pub id: String,
#[ts(type = "\"grpc_connection\"")]
@@ -644,7 +644,7 @@ impl<'s> TryFrom<&Row<'s>> for GrpcConnection {
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, TS)]
#[serde(rename_all = "snake_case")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub enum GrpcEventType {
Info,
Error,
@@ -662,7 +662,7 @@ impl Default for GrpcEventType {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct GrpcEvent {
pub id: String,
#[ts(type = "\"grpc_event\"")]
@@ -720,7 +720,7 @@ impl<'s> TryFrom<&Row<'s>> for GrpcEvent {
#[derive(Debug, Clone, Serialize, Deserialize, Default, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export, export_to = "../../../plugin-runtime-types/src/gen/")]
#[ts(export)]
pub struct KeyValue {
#[ts(type = "\"key_value\"")]
pub model: String,

View File

@@ -0,0 +1,56 @@
use log::info;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use serde::Deserialize;
use std::env::current_dir;
use std::fs::create_dir_all;
use std::time::Duration;
use tauri::async_runtime::Mutex;
use tauri::plugin::TauriPlugin;
use tauri::{is_dev, plugin, Manager, Runtime};
pub struct SqliteConnection(pub Mutex<Pool<SqliteConnectionManager>>);
#[derive(Default, Deserialize)]
pub struct PluginConfig {
// Nothing yet (will be configurable in tauri.conf.json
}
/// Tauri SQL plugin builder.
#[derive(Default)]
pub struct Builder {
// Nothing Yet
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
pub fn build<R: Runtime>(&self) -> TauriPlugin<R, Option<PluginConfig>> {
plugin::Builder::<R, Option<PluginConfig>>::new("yaak_models")
.setup(|app, _api| {
let app_path = match is_dev() {
true => current_dir().unwrap(),
false => app.path().app_data_dir().unwrap(),
};
create_dir_all(app_path.clone()).expect("Problem creating App directory!");
let db_file_path = app_path.join("db.sqlite");
info!("Opening SQLite DB at {db_file_path:?}");
let manager = SqliteConnectionManager::file(db_file_path);
let pool = Pool::builder()
.max_size(1000) // Up from 10 (just in case)
.connection_timeout(Duration::from_secs(10))
.build(manager)
.unwrap();
app.manage(SqliteConnection(Mutex::new(pool)));
Ok(())
})
.build()
}
}

View File

@@ -1,6 +1,14 @@
use std::fs;
use log::error;
use crate::error::Result;
use crate::models::{
CookieJar, CookieJarIden, Environment, EnvironmentIden, Folder, FolderIden, GrpcConnection,
GrpcConnectionIden, GrpcEvent, GrpcEventIden, GrpcRequest, GrpcRequestIden, HttpRequest,
HttpRequestIden, HttpResponse, HttpResponseHeader, HttpResponseIden, KeyValue, KeyValueIden,
ModelType, Settings, SettingsIden, Workspace, WorkspaceIden,
};
use crate::plugin::SqliteConnection;
use log::{debug, error};
use rand::distributions::{Alphanumeric, DistString};
use sea_query::ColumnRef::Asterisk;
use sea_query::Keyword::CurrentTimestamp;
@@ -8,25 +16,6 @@ use sea_query::{Cond, Expr, OnConflict, Order, Query, SqliteQueryBuilder};
use sea_query_rusqlite::RusqliteBinder;
use serde::Serialize;
use tauri::{AppHandle, Emitter, Manager, WebviewWindow, Wry};
use thiserror::Error;
use crate::models::{
CookieJar, CookieJarIden, Environment, EnvironmentIden, Folder, FolderIden, GrpcConnection,
GrpcConnectionIden, GrpcEvent, GrpcEventIden, GrpcRequest, GrpcRequestIden, HttpRequest,
HttpRequestIden, HttpResponse, HttpResponseHeader, HttpResponseIden, KeyValue, KeyValueIden,
ModelType, Settings, SettingsIden, Workspace, WorkspaceIden,
};
use crate::SqliteConnection;
#[derive(Error, Debug)]
pub enum DBError {
#[error("SQL error")]
SqlError(#[from] rusqlite::Error),
#[error("JSON error")]
JsonError(#[from] serde_json::Error),
#[error("unknown error")]
Unknown,
}
pub async fn set_key_value_string(
mgr: &impl Manager<Wry>,
@@ -96,9 +85,10 @@ pub async fn set_key_value_raw(
key: &str,
value: &str,
) -> (KeyValue, bool) {
let existing = get_key_value_raw(mgr, namespace, key).await;
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let existing = get_key_value_raw(mgr, namespace, key).await;
let (sql, params) = Query::insert()
.into_table(KeyValueIden::Table)
.columns([
@@ -153,7 +143,7 @@ pub async fn get_key_value_raw(
.ok()
}
pub async fn list_workspaces(mgr: &impl Manager<Wry>) -> Result<Vec<Workspace>, DBError> {
pub async fn list_workspaces(mgr: &impl Manager<Wry>) -> Result<Vec<Workspace>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -165,7 +155,7 @@ pub async fn list_workspaces(mgr: &impl Manager<Wry>) -> Result<Vec<Workspace>,
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn get_workspace(mgr: &impl Manager<Wry>, id: &str) -> Result<Workspace, DBError> {
pub async fn get_workspace(mgr: &impl Manager<Wry>, id: &str) -> Result<Workspace> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -177,10 +167,7 @@ pub async fn get_workspace(mgr: &impl Manager<Wry>, id: &str) -> Result<Workspac
Ok(stmt.query_row(&*params.as_params(), |row| row.try_into())?)
}
pub async fn upsert_workspace(
window: &WebviewWindow,
workspace: Workspace,
) -> Result<Workspace, DBError> {
pub async fn upsert_workspace(window: &WebviewWindow, workspace: Workspace) -> Result<Workspace> {
let id = match workspace.id.as_str() {
"" => generate_model_id(ModelType::TypeWorkspace),
_ => workspace.id.to_string(),
@@ -235,10 +222,11 @@ pub async fn upsert_workspace(
Ok(emit_upserted_model(window, m))
}
pub async fn delete_workspace(window: &WebviewWindow, id: &str) -> Result<Workspace, DBError> {
pub async fn delete_workspace(window: &WebviewWindow, id: &str) -> Result<Workspace> {
let workspace = get_workspace(window, id).await?;
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let workspace = get_workspace(window, id).await?;
let (sql, params) = Query::delete()
.from_table(WorkspaceIden::Table)
@@ -253,7 +241,7 @@ pub async fn delete_workspace(window: &WebviewWindow, id: &str) -> Result<Worksp
emit_deleted_model(window, workspace)
}
pub async fn get_cookie_jar(mgr: &impl Manager<Wry>, id: &str) -> Result<CookieJar, DBError> {
pub async fn get_cookie_jar(mgr: &impl Manager<Wry>, id: &str) -> Result<CookieJar> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -269,7 +257,7 @@ pub async fn get_cookie_jar(mgr: &impl Manager<Wry>, id: &str) -> Result<CookieJ
pub async fn list_cookie_jars(
mgr: &impl Manager<Wry>,
workspace_id: &str,
) -> Result<Vec<CookieJar>, DBError> {
) -> Result<Vec<CookieJar>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -282,7 +270,7 @@ pub async fn list_cookie_jars(
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn delete_cookie_jar(window: &WebviewWindow, id: &str) -> Result<CookieJar, DBError> {
pub async fn delete_cookie_jar(window: &WebviewWindow, id: &str) -> Result<CookieJar> {
let cookie_jar = get_cookie_jar(window, id).await?;
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -296,16 +284,13 @@ pub async fn delete_cookie_jar(window: &WebviewWindow, id: &str) -> Result<Cooki
emit_deleted_model(window, cookie_jar)
}
pub async fn duplicate_grpc_request(
window: &WebviewWindow,
id: &str,
) -> Result<GrpcRequest, DBError> {
pub async fn duplicate_grpc_request(window: &WebviewWindow, id: &str) -> Result<GrpcRequest> {
let mut request = get_grpc_request(window, id).await?.clone();
request.id = "".to_string();
upsert_grpc_request(window, &request).await
}
pub async fn delete_grpc_request(window: &WebviewWindow, id: &str) -> Result<GrpcRequest, DBError> {
pub async fn delete_grpc_request(window: &WebviewWindow, id: &str) -> Result<GrpcRequest> {
let req = get_grpc_request(window, id).await?;
let dbm = &*window.app_handle().state::<SqliteConnection>();
@@ -322,15 +307,15 @@ pub async fn delete_grpc_request(window: &WebviewWindow, id: &str) -> Result<Grp
pub async fn upsert_grpc_request(
window: &WebviewWindow,
request: &GrpcRequest,
) -> Result<GrpcRequest, DBError> {
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
) -> Result<GrpcRequest> {
let id = match request.id.as_str() {
"" => generate_model_id(ModelType::TypeGrpcRequest),
_ => request.id.to_string(),
};
let trimmed_name = request.name.trim();
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::insert()
.into_table(GrpcRequestIden::Table)
.columns([
@@ -396,7 +381,7 @@ pub async fn upsert_grpc_request(
Ok(emit_upserted_model(window, m))
}
pub async fn get_grpc_request(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcRequest, DBError> {
pub async fn get_grpc_request(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcRequest> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -412,7 +397,7 @@ pub async fn get_grpc_request(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcR
pub async fn list_grpc_requests(
mgr: &impl Manager<Wry>,
workspace_id: &str,
) -> Result<Vec<GrpcRequest>, DBError> {
) -> Result<Vec<GrpcRequest>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -428,13 +413,13 @@ pub async fn list_grpc_requests(
pub async fn upsert_grpc_connection(
window: &WebviewWindow,
connection: &GrpcConnection,
) -> Result<GrpcConnection, DBError> {
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
) -> Result<GrpcConnection> {
let id = match connection.id.as_str() {
"" => generate_model_id(ModelType::TypeGrpcConnection),
_ => connection.id.to_string(),
};
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::insert()
.into_table(GrpcConnectionIden::Table)
.columns([
@@ -487,10 +472,7 @@ pub async fn upsert_grpc_connection(
Ok(emit_upserted_model(window, m))
}
pub async fn get_grpc_connection(
mgr: &impl Manager<Wry>,
id: &str,
) -> Result<GrpcConnection, DBError> {
pub async fn get_grpc_connection(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcConnection> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -505,7 +487,7 @@ pub async fn get_grpc_connection(
pub async fn list_grpc_connections(
mgr: &impl Manager<Wry>,
request_id: &str,
) -> Result<Vec<GrpcConnection>, DBError> {
) -> Result<Vec<GrpcConnection>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -520,10 +502,7 @@ pub async fn list_grpc_connections(
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn delete_grpc_connection(
window: &WebviewWindow,
id: &str,
) -> Result<GrpcConnection, DBError> {
pub async fn delete_grpc_connection(window: &WebviewWindow, id: &str) -> Result<GrpcConnection> {
let resp = get_grpc_connection(window, id).await?;
let dbm = &*window.app_handle().state::<SqliteConnection>();
@@ -538,27 +517,21 @@ pub async fn delete_grpc_connection(
emit_deleted_model(window, resp)
}
pub async fn delete_all_grpc_connections(
window: &WebviewWindow,
request_id: &str,
) -> Result<(), DBError> {
pub async fn delete_all_grpc_connections(window: &WebviewWindow, request_id: &str) -> Result<()> {
for r in list_grpc_connections(window, request_id).await? {
delete_grpc_connection(window, &r.id).await?;
}
Ok(())
}
pub async fn upsert_grpc_event(
window: &WebviewWindow,
event: &GrpcEvent,
) -> Result<GrpcEvent, DBError> {
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
pub async fn upsert_grpc_event(window: &WebviewWindow, event: &GrpcEvent) -> Result<GrpcEvent> {
let id = match event.id.as_str() {
"" => generate_model_id(ModelType::TypeGrpcEvent),
_ => event.id.to_string(),
};
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::insert()
.into_table(GrpcEventIden::Table)
.columns([
@@ -607,7 +580,7 @@ pub async fn upsert_grpc_event(
Ok(emit_upserted_model(window, m))
}
pub async fn get_grpc_event(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcEvent, DBError> {
pub async fn get_grpc_event(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcEvent> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -622,7 +595,7 @@ pub async fn get_grpc_event(mgr: &impl Manager<Wry>, id: &str) -> Result<GrpcEve
pub async fn list_grpc_events(
mgr: &impl Manager<Wry>,
connection_id: &str,
) -> Result<Vec<GrpcEvent>, DBError> {
) -> Result<Vec<GrpcEvent>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -640,7 +613,7 @@ pub async fn list_grpc_events(
pub async fn upsert_cookie_jar(
window: &WebviewWindow,
cookie_jar: &CookieJar,
) -> Result<CookieJar, DBError> {
) -> Result<CookieJar> {
let id = match cookie_jar.id.as_str() {
"" => generate_model_id(ModelType::TypeCookieJar),
_ => cookie_jar.id.to_string(),
@@ -688,7 +661,7 @@ pub async fn upsert_cookie_jar(
pub async fn list_environments(
mgr: &impl Manager<Wry>,
workspace_id: &str,
) -> Result<Vec<Environment>, DBError> {
) -> Result<Vec<Environment>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -703,11 +676,12 @@ pub async fn list_environments(
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn delete_environment(window: &WebviewWindow, id: &str) -> Result<Environment, DBError> {
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
pub async fn delete_environment(window: &WebviewWindow, id: &str) -> Result<Environment> {
let env = get_environment(window, id).await?;
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::delete()
.from_table(EnvironmentIden::Table)
.cond_where(Expr::col(EnvironmentIden::Id).eq(id))
@@ -717,7 +691,7 @@ pub async fn delete_environment(window: &WebviewWindow, id: &str) -> Result<Envi
emit_deleted_model(window, env)
}
async fn get_settings(mgr: &impl Manager<Wry>) -> Result<Settings, DBError> {
async fn get_settings(mgr: &impl Manager<Wry>) -> Result<Settings> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -752,10 +726,7 @@ pub async fn get_or_create_settings(mgr: &impl Manager<Wry>) -> Settings {
.expect("Failed to insert Settings")
}
pub async fn update_settings(
window: &WebviewWindow,
settings: Settings,
) -> Result<Settings, DBError> {
pub async fn update_settings(window: &WebviewWindow, settings: Settings) -> Result<Settings> {
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -810,7 +781,7 @@ pub async fn update_settings(
pub async fn upsert_environment(
window: &WebviewWindow,
environment: Environment,
) -> Result<Environment, DBError> {
) -> Result<Environment> {
let id = match environment.id.as_str() {
"" => generate_model_id(ModelType::TypeEnvironment),
_ => environment.id.to_string(),
@@ -857,7 +828,7 @@ pub async fn upsert_environment(
Ok(emit_upserted_model(window, m))
}
pub async fn get_environment(mgr: &impl Manager<Wry>, id: &str) -> Result<Environment, DBError> {
pub async fn get_environment(mgr: &impl Manager<Wry>, id: &str) -> Result<Environment> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -870,7 +841,7 @@ pub async fn get_environment(mgr: &impl Manager<Wry>, id: &str) -> Result<Enviro
Ok(stmt.query_row(&*params.as_params(), |row| row.try_into())?)
}
pub async fn get_folder(mgr: &impl Manager<Wry>, id: &str) -> Result<Folder, DBError> {
pub async fn get_folder(mgr: &impl Manager<Wry>, id: &str) -> Result<Folder> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -883,10 +854,7 @@ pub async fn get_folder(mgr: &impl Manager<Wry>, id: &str) -> Result<Folder, DBE
Ok(stmt.query_row(&*params.as_params(), |row| row.try_into())?)
}
pub async fn list_folders(
mgr: &impl Manager<Wry>,
workspace_id: &str,
) -> Result<Vec<Folder>, DBError> {
pub async fn list_folders(mgr: &impl Manager<Wry>, workspace_id: &str) -> Result<Vec<Folder>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -901,8 +869,9 @@ pub async fn list_folders(
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn delete_folder(window: &WebviewWindow, id: &str) -> Result<Folder, DBError> {
pub async fn delete_folder(window: &WebviewWindow, id: &str) -> Result<Folder> {
let folder = get_folder(window, id).await?;
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -915,7 +884,7 @@ pub async fn delete_folder(window: &WebviewWindow, id: &str) -> Result<Folder, D
emit_deleted_model(window, folder)
}
pub async fn upsert_folder(window: &WebviewWindow, r: Folder) -> Result<Folder, DBError> {
pub async fn upsert_folder(window: &WebviewWindow, r: Folder) -> Result<Folder> {
let id = match r.id.as_str() {
"" => generate_model_id(ModelType::TypeFolder),
_ => r.id.to_string(),
@@ -941,6 +910,7 @@ pub async fn upsert_folder(window: &WebviewWindow, r: Folder) -> Result<Folder,
CurrentTimestamp.into(),
CurrentTimestamp.into(),
r.workspace_id.as_str().into(),
r.folder_id.as_ref().map(|s| s.as_str()).into(),
trimmed_name.into(),
r.sort_priority.into(),
])
@@ -962,19 +932,13 @@ pub async fn upsert_folder(window: &WebviewWindow, r: Folder) -> Result<Folder,
Ok(emit_upserted_model(window, m))
}
pub async fn duplicate_http_request(
window: &WebviewWindow,
id: &str,
) -> Result<HttpRequest, DBError> {
pub async fn duplicate_http_request(window: &WebviewWindow, id: &str) -> Result<HttpRequest> {
let mut request = get_http_request(window, id).await?.clone();
request.id = "".to_string();
upsert_http_request(window, request).await
}
pub async fn upsert_http_request(
window: &WebviewWindow,
r: HttpRequest,
) -> Result<HttpRequest, DBError> {
pub async fn upsert_http_request(window: &WebviewWindow, r: HttpRequest) -> Result<HttpRequest> {
let id = match r.id.as_str() {
"" => generate_model_id(ModelType::TypeHttpRequest),
_ => r.id.to_string(),
@@ -1050,7 +1014,7 @@ pub async fn upsert_http_request(
pub async fn list_http_requests(
mgr: &impl Manager<Wry>,
workspace_id: &str,
) -> Result<Vec<HttpRequest>, DBError> {
) -> Result<Vec<HttpRequest>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -1064,7 +1028,7 @@ pub async fn list_http_requests(
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn get_http_request(mgr: &impl Manager<Wry>, id: &str) -> Result<HttpRequest, DBError> {
pub async fn get_http_request(mgr: &impl Manager<Wry>, id: &str) -> Result<HttpRequest> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -1077,7 +1041,7 @@ pub async fn get_http_request(mgr: &impl Manager<Wry>, id: &str) -> Result<HttpR
Ok(stmt.query_row(&*params.as_params(), |row| row.try_into())?)
}
pub async fn delete_http_request(window: &WebviewWindow, id: &str) -> Result<HttpRequest, DBError> {
pub async fn delete_http_request(window: &WebviewWindow, id: &str) -> Result<HttpRequest> {
let req = get_http_request(window, id).await?;
// DB deletes will cascade but this will delete the files
@@ -1108,7 +1072,7 @@ pub async fn create_http_response(
headers: Vec<HttpResponseHeader>,
version: Option<&str>,
remote_addr: Option<&str>,
) -> Result<HttpResponse, DBError> {
) -> Result<HttpResponse> {
let req = get_http_request(window, request_id).await?;
let id = generate_model_id(ModelType::TypeHttpResponse);
let dbm = &*window.app_handle().state::<SqliteConnection>();
@@ -1158,7 +1122,7 @@ pub async fn create_http_response(
Ok(emit_upserted_model(window, m))
}
pub async fn cancel_pending_grpc_connections(app: &AppHandle) -> Result<(), DBError> {
pub async fn cancel_pending_grpc_connections(app: &AppHandle) -> Result<()> {
let dbm = &*app.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -1172,7 +1136,7 @@ pub async fn cancel_pending_grpc_connections(app: &AppHandle) -> Result<(), DBEr
Ok(())
}
pub async fn cancel_pending_responses(app: &AppHandle) -> Result<(), DBError> {
pub async fn cancel_pending_responses(app: &AppHandle) -> Result<()> {
let dbm = &*app.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -1192,7 +1156,7 @@ pub async fn cancel_pending_responses(app: &AppHandle) -> Result<(), DBError> {
pub async fn update_response_if_id(
window: &WebviewWindow,
response: &HttpResponse,
) -> Result<HttpResponse, DBError> {
) -> Result<HttpResponse> {
if response.id.is_empty() {
Ok(response.clone())
} else {
@@ -1203,7 +1167,7 @@ pub async fn update_response_if_id(
pub async fn update_response(
window: &WebviewWindow,
response: &HttpResponse,
) -> Result<HttpResponse, DBError> {
) -> Result<HttpResponse> {
let dbm = &*window.app_handle().state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -1254,7 +1218,7 @@ pub async fn update_response(
Ok(emit_upserted_model(window, m))
}
pub async fn get_http_response(mgr: &impl Manager<Wry>, id: &str) -> Result<HttpResponse, DBError> {
pub async fn get_http_response(mgr: &impl Manager<Wry>, id: &str) -> Result<HttpResponse> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -1266,10 +1230,7 @@ pub async fn get_http_response(mgr: &impl Manager<Wry>, id: &str) -> Result<Http
Ok(stmt.query_row(&*params.as_params(), |row| row.try_into())?)
}
pub async fn delete_http_response(
window: &WebviewWindow,
id: &str,
) -> Result<HttpResponse, DBError> {
pub async fn delete_http_response(window: &WebviewWindow, id: &str) -> Result<HttpResponse> {
let resp = get_http_response(window, id).await?;
// Delete the body file if it exists
@@ -1290,10 +1251,7 @@ pub async fn delete_http_response(
emit_deleted_model(window, resp)
}
pub async fn delete_all_http_responses(
window: &WebviewWindow,
request_id: &str,
) -> Result<(), DBError> {
pub async fn delete_all_http_responses(window: &WebviewWindow, request_id: &str) -> Result<()> {
for r in list_responses(window, request_id, None).await? {
delete_http_response(window, &r.id).await?;
}
@@ -1304,7 +1262,7 @@ pub async fn list_responses(
mgr: &impl Manager<Wry>,
request_id: &str,
limit: Option<i64>,
) -> Result<Vec<HttpResponse>, DBError> {
) -> Result<Vec<HttpResponse>> {
let limit_unwrapped = limit.unwrap_or_else(|| i64::MAX);
let dbm = mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
@@ -1323,7 +1281,7 @@ pub async fn list_responses(
pub async fn list_responses_by_workspace_id(
mgr: &impl Manager<Wry>,
workspace_id: &str,
) -> Result<Vec<HttpResponse>, DBError> {
) -> Result<Vec<HttpResponse>> {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await.get().unwrap();
let (sql, params) = Query::select()
@@ -1337,6 +1295,12 @@ pub async fn list_responses_by_workspace_id(
Ok(items.map(|v| v.unwrap()).collect())
}
pub async fn debug_pool(mgr: &impl Manager<Wry>) {
let dbm = &*mgr.state::<SqliteConnection>();
let db = dbm.0.lock().await;
debug!("Debug database state: {:?}", db.state());
}
pub fn generate_model_id(model: ModelType) -> String {
let id = generate_id();
format!("{}_{}", model.id_prefix(), id)
@@ -1363,7 +1327,7 @@ fn emit_upserted_model<M: Serialize + Clone>(window: &WebviewWindow, model: M) -
model
}
fn emit_deleted_model<M: Serialize + Clone, E>(window: &WebviewWindow, model: M) -> Result<M, E> {
fn emit_deleted_model<M: Serialize + Clone>(window: &WebviewWindow, model: M) -> Result<M> {
let payload = ModelPayload {
model: model.clone(),
window_label: window.label().to_string(),

View File

@@ -17,6 +17,9 @@ tauri = { workspace = true }
tauri-plugin-shell = { workspace = true }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "process"] }
tonic = "0.12.1"
ts-rs = "9.0.1"
thiserror = "1.0.63"
yaak_models = {workspace = true}
[build-dependencies]
tonic-build = "0.12.1"

View File

@@ -1,4 +1,9 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Tell ts-rs where to generate types to
println!("cargo:rustc-env=TS_RS_EXPORT_DIR=../../plugin-runtime-types/src/gen");
// Compile protobuf types
tonic_build::compile_protos("../../proto/plugins/runtime.proto")?;
Ok(())
}

View File

@@ -0,0 +1,40 @@
use thiserror::Error;
use tokio::io;
use tokio::sync::mpsc::error::SendError;
use crate::server::plugin_runtime::EventStreamEvent;
#[derive(Error, Debug)]
pub enum Error {
#[error("IO error")]
IoErr(#[from] io::Error),
#[error("Tauri error")]
TauriErr(#[from] tauri::Error),
#[error("Tauri shell error")]
TauriShellErr(#[from] tauri_plugin_shell::Error),
#[error("Grpc transport error")]
GrpcTransportErr(#[from] tonic::transport::Error),
#[error("Grpc send error")]
GrpcSendErr(#[from] SendError<tonic::Result<EventStreamEvent>>),
#[error("JSON error")]
JsonErr(#[from] serde_json::Error),
#[error("Plugin not found error")]
PluginNotFoundErr(String),
#[error("unknown error")]
MissingCallbackIdErr(String),
#[error("Missing callback ID error")]
MissingCallbackErr(String),
#[error("No plugins found")]
NoPluginsErr(String),
#[error("Plugin error")]
PluginErr(String),
#[error("Unknown error")]
UnknownErr(String),
}
impl Into<String> for Error {
fn into(self) -> String {
todo!()
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,132 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use ts_rs::TS;
use yaak_models::models::{Environment, Folder, GrpcRequest, HttpRequest, HttpResponse, Workspace};
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct InternalEvent {
pub id: String,
pub plugin_ref_id: String,
pub reply_id: Option<String>,
pub payload: InternalEventPayload,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
#[ts(export)]
pub enum InternalEventPayload {
BootRequest(BootRequest),
BootResponse(BootResponse),
ImportRequest(ImportRequest),
ImportResponse(ImportResponse),
FilterRequest(FilterRequest),
FilterResponse(FilterResponse),
ExportHttpRequestRequest(ExportHttpRequestRequest),
ExportHttpRequestResponse(ExportHttpRequestResponse),
/// Returned when a plugin doesn't get run, just so the server
/// has something to listen for
EmptyResponse(EmptyResponse),
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default)]
#[ts(export, type = "{}")]
pub struct EmptyResponse {}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct BootRequest {
pub dir: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct BootResponse {
pub name: String,
pub version: String,
pub capabilities: Vec<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct ImportRequest {
pub content: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct ImportResponse {
pub resources: ImportResources,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct FilterRequest {
pub content: String,
pub filter: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct FilterResponse {
pub items: Vec<Value>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct ExportHttpRequestRequest {
pub http_request: HttpRequest,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct ExportHttpRequestResponse {
pub content: String,
}
// TODO: Migrate plugins to return this type
// #[derive(Debug, Clone, Serialize, Deserialize, TS)]
// #[serde(rename_all = "camelCase", untagged)]
// #[ts(export)]
// pub enum ExportableModel {
// Workspace(Workspace),
// Environment(Environment),
// Folder(Folder),
// HttpRequest(HttpRequest),
// GrpcRequest(GrpcRequest),
// }
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]
#[serde(default, rename_all = "camelCase")]
#[ts(export)]
pub struct ImportResources {
pub workspaces: Vec<Workspace>,
pub environments: Vec<Environment>,
pub folders: Vec<Folder>,
pub http_requests: Vec<HttpRequest>,
pub grpc_requests: Vec<GrpcRequest>,
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub enum Model {
Workspace(Workspace),
Environment(Environment),
Folder(Folder),
HttpRequest(HttpRequest),
HttpResponse(HttpResponse),
GrpcRequest(GrpcRequest),
}

View File

@@ -1,41 +1,6 @@
extern crate core;
use crate::manager::PluginManager;
use log::info;
use std::process::exit;
use tauri::plugin::{Builder, TauriPlugin};
use tauri::{Manager, RunEvent, Runtime, State};
use tokio::sync::Mutex;
pub mod error;
mod events;
pub mod manager;
mod nodejs;
pub mod plugin_runtime {
tonic::include_proto!("yaak.plugins.runtime");
}
pub fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::new("yaak_plugin_runtime")
.setup(|app, _| {
tauri::async_runtime::block_on(async move {
let manager = PluginManager::new(&app).await;
let manager_state = Mutex::new(manager);
app.manage(manager_state);
Ok(())
})
})
.on_event(|app, e| match e {
// TODO: Also exit when app is force-quit (eg. cmd+r in IntelliJ runner)
RunEvent::ExitRequested { api, .. } => {
api.prevent_exit();
tauri::async_runtime::block_on(async move {
info!("Exiting plugin runtime due to app exit");
let manager: State<Mutex<PluginManager>> = app.state();
manager.lock().await.cleanup().await;
exit(0);
});
}
_ => {}
})
.build()
}
pub mod plugin;
mod server;

View File

@@ -0,0 +1,26 @@
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// let dir = env::var("YAAK_PLUGINS_DIR").expect("YAAK_PLUGINS_DIR not set");
//
// let plugin_dirs: Vec<String> = match read_dir(dir) {
// Ok(result) => {
// let mut dirs: Vec<String> = vec![];
// for entry_result in result {
// match entry_result {
// Ok(entry) => {
// if entry.path().is_dir() {
// dirs.push(entry.path().to_string_lossy().to_string())
// }
// }
// Err(_) => {
// continue;
// }
// }
// };
// dirs
// }
// Err(_) => vec![],
// };
// start_server(plugin_dirs).await.unwrap();
Ok(())
}

View File

@@ -1,34 +1,38 @@
use log::{debug, info};
use std::time::Duration;
use tauri::{AppHandle, Manager, Runtime};
use tokio::sync::watch::Sender;
use tonic::transport::Channel;
use crate::nodejs::node_start;
use crate::plugin_runtime::plugin_runtime_client::PluginRuntimeClient;
use crate::plugin_runtime::{
HookExportRequest, HookImportRequest, HookResponse, HookResponseFilterRequest,
use crate::error::Result;
use crate::events::{
ExportHttpRequestRequest, ExportHttpRequestResponse, FilterRequest, FilterResponse,
ImportRequest, ImportResponse, InternalEventPayload,
};
use crate::error::Error::PluginErr;
use crate::nodejs::start_nodejs_plugin_runtime;
use crate::plugin::start_server;
use crate::server::PluginRuntimeGrpcServer;
use std::time::Duration;
use tauri::{AppHandle, Runtime};
use tokio::sync::watch::Sender;
use yaak_models::models::HttpRequest;
pub struct PluginManager {
client: PluginRuntimeClient<Channel>,
kill_tx: Sender<bool>,
server: PluginRuntimeGrpcServer,
}
impl PluginManager {
pub async fn new<R: Runtime>(app_handle: &AppHandle<R>) -> PluginManager {
let temp_dir = app_handle.path().temp_dir().unwrap();
pub async fn new<R: Runtime>(
app_handle: &AppHandle<R>,
plugin_dirs: Vec<String>,
) -> PluginManager {
let (server, addr) = start_server(plugin_dirs)
.await
.expect("Failed to start plugin runtime server");
let (kill_tx, kill_rx) = tokio::sync::watch::channel(false);
let start_resp = node_start(app_handle, &temp_dir, &kill_rx).await;
info!("Connecting to gRPC client at {}", start_resp.addr);
start_nodejs_plugin_runtime(app_handle, addr, &kill_rx)
.await
.expect("Failed to start plugin runtime");
let client = match PluginRuntimeClient::connect(start_resp.addr.clone()).await {
Ok(v) => v,
Err(err) => panic!("{}", err.to_string()),
};
PluginManager { client, kill_tx }
PluginManager { kill_tx, server }
}
pub async fn cleanup(&mut self) {
@@ -38,49 +42,81 @@ impl PluginManager {
tokio::time::sleep(Duration::from_millis(500)).await;
}
pub async fn run_import(&mut self, data: &str) -> Result<HookResponse, String> {
let response = self
.client
.hook_import(tonic::Request::new(HookImportRequest {
data: data.to_string(),
pub async fn run_import(&mut self, content: &str) -> Result<(ImportResponse, String)> {
let reply_events = self
.server
.send_and_wait(&InternalEventPayload::ImportRequest(ImportRequest {
content: content.to_string(),
}))
.await
.map_err(|e| e.message().to_string())?;
.await?;
Ok(response.into_inner())
// TODO: Don't just return the first valid response
for event in reply_events {
match event.payload {
InternalEventPayload::ImportResponse(resp) => {
let ref_id = event.plugin_ref_id.as_str();
let plugin = self.server.plugin_by_ref_id(ref_id).await?;
let plugin_name = plugin.name().await;
return Ok((resp, plugin_name));
}
_ => {}
}
}
Err(PluginErr("No import responses found".to_string()))
}
pub async fn run_export_curl(&mut self, request: &str) -> Result<HookResponse, String> {
let response = self
.client
.hook_export(tonic::Request::new(HookExportRequest {
request: request.to_string(),
}))
.await
.map_err(|e| e.message().to_string())?;
pub async fn run_export_curl(
&mut self,
request: &HttpRequest,
) -> Result<ExportHttpRequestResponse> {
let event = self
.server
.send_to_plugin_and_wait(
"exporter-curl",
&InternalEventPayload::ExportHttpRequestRequest(ExportHttpRequestRequest {
http_request: request.to_owned(),
}),
)
.await?;
Ok(response.into_inner())
match event.payload {
InternalEventPayload::ExportHttpRequestResponse(resp) => Ok(resp),
InternalEventPayload::EmptyResponse(_) => {
Err(PluginErr("Export returned empty".to_string()))
}
e => Err(PluginErr(format!("Export returned invalid event {:?}", e))),
}
}
pub async fn run_response_filter(
pub async fn run_filter(
&mut self,
filter: &str,
body: &str,
content: &str,
content_type: &str,
) -> Result<HookResponse, String> {
debug!("Running plugin filter");
let response = self
.client
.hook_response_filter(tonic::Request::new(HookResponseFilterRequest {
filter: filter.to_string(),
body: body.to_string(),
content_type: content_type.to_string(),
}))
.await
.map_err(|e| e.message().to_string())?;
) -> Result<FilterResponse> {
let plugin_name = match content_type {
"application/json" => "filter-jsonpath",
_ => "filter-xpath",
};
let result = response.into_inner();
debug!("Ran plugin response filter {}", result.data);
Ok(result)
let event = self
.server
.send_to_plugin_and_wait(
plugin_name,
&InternalEventPayload::FilterRequest(FilterRequest {
filter: filter.to_string(),
content: content.to_string(),
}),
)
.await?;
match event.payload {
InternalEventPayload::FilterResponse(resp) => Ok(resp),
InternalEventPayload::EmptyResponse(_) => {
Err(PluginErr("Filter returned empty".to_string()))
}
e => Err(PluginErr(format!("Export returned invalid event {:?}", e))),
}
}
}

View File

@@ -1,14 +1,12 @@
use std::path::PathBuf;
use std::time::Duration;
use std::net::SocketAddr;
use crate::error::Result;
use log::info;
use rand::distributions::{Alphanumeric, DistString};
use serde;
use serde::Deserialize;
use tauri::path::BaseDirectory;
use tauri::{AppHandle, Manager, Runtime};
use tauri_plugin_shell::process::CommandEvent;
use tauri_plugin_shell::ShellExt;
use tokio::fs;
use tokio::sync::watch::Receiver;
#[derive(Deserialize, Default)]
@@ -17,57 +15,48 @@ struct PortFile {
port: i32,
}
pub struct StartResp {
pub addr: String,
}
pub async fn node_start<R: Runtime>(
pub async fn start_nodejs_plugin_runtime<R: Runtime>(
app: &AppHandle<R>,
temp_dir: &PathBuf,
addr: SocketAddr,
kill_rx: &Receiver<bool>,
) -> StartResp {
let port_file_path = temp_dir.join(Alphanumeric.sample_string(&mut rand::thread_rng(), 10));
let plugins_dir = app
.path()
.resolve("plugins", BaseDirectory::Resource)
.expect("failed to resolve plugin directory resource");
) -> Result<()> {
let plugin_runtime_main = app
.path()
.resolve("plugin-runtime", BaseDirectory::Resource)
.expect("failed to resolve plugin runtime resource")
.resolve("plugin-runtime", BaseDirectory::Resource)?
.join("index.cjs");
// HACK: Remove UNC prefix for Windows paths to pass to sidecar
let plugins_dir = dunce::simplified(plugins_dir.as_path())
.to_string_lossy()
.to_string();
let plugin_runtime_main = dunce::simplified(plugin_runtime_main.as_path())
.to_string_lossy()
.to_string();
info!(
"Starting plugin runtime\n → port_file={}\n → plugins_dir={}\n → runtime_dir={}",
port_file_path.to_string_lossy(),
plugins_dir,
plugin_runtime_main,
);
info!("Starting plugin runtime main={}", plugin_runtime_main);
let cmd = app
.shell()
.sidecar("yaaknode")
.expect("yaaknode not found")
.env("YAAK_GRPC_PORT_FILE_PATH", port_file_path.clone())
.env("YAAK_PLUGINS_DIR", plugins_dir)
.sidecar("yaaknode")?
.env("PORT", addr.port().to_string())
.args(&[plugin_runtime_main]);
println!("Waiting on plugin runtime");
let (_, child) = cmd.spawn().expect("yaaknode failed to start");
let (mut child_rx, child) = cmd.spawn()?;
println!("Spawned plugin runtime");
let mut kill_rx = kill_rx.clone();
tokio::spawn(async move {
while let Some(event) = child_rx.recv().await {
match event {
CommandEvent::Stderr(line) => {
print!("{}", String::from_utf8(line).unwrap());
}
CommandEvent::Stdout(line) => {
print!("{}", String::from_utf8(line).unwrap());
}
_ => {}
}
}
});
// Check on child
tokio::spawn(async move {
kill_rx
@@ -77,26 +66,7 @@ pub async fn node_start<R: Runtime>(
info!("Killing plugin runtime");
child.kill().expect("Failed to kill plugin runtime");
info!("Killed plugin runtime");
return;
});
let start = std::time::Instant::now();
let port_file_contents = loop {
if start.elapsed().as_millis() > 30000 {
panic!("Failed to read port file in time");
}
match fs::read_to_string(port_file_path.clone()).await {
Ok(s) => break s,
Err(_) => {
tokio::time::sleep(Duration::from_millis(500)).await;
}
}
};
let port_file: PortFile = serde_json::from_str(port_file_contents.as_str()).unwrap();
info!("Started plugin runtime on :{}", port_file.port);
let addr = format!("http://localhost:{}", port_file.port);
StartResp { addr }
Ok(())
}

View File

@@ -0,0 +1,112 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::exit;
use std::time::Duration;
use log::info;
use tauri::path::BaseDirectory;
use tauri::plugin::{Builder, TauriPlugin};
use tauri::{Manager, RunEvent, Runtime, State};
use tokio::fs::read_dir;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tonic::codegen::tokio_stream;
use tonic::transport::Server;
use crate::error::Result;
use crate::events::{InternalEvent, InternalEventPayload};
use crate::manager::PluginManager;
use crate::server::plugin_runtime::plugin_runtime_server::PluginRuntimeServer;
use crate::server::PluginRuntimeGrpcServer;
pub fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::new("yaak_plugin_runtime")
.setup(|app, _| {
let plugins_dir = app
.path()
.resolve("plugins", BaseDirectory::Resource)
.expect("failed to resolve plugin directory resource");
tauri::async_runtime::block_on(async move {
let plugin_dirs = read_plugins_dir(&plugins_dir)
.await
.expect("Failed to read plugins dir");
let manager = PluginManager::new(&app, plugin_dirs).await;
let manager_state = Mutex::new(manager);
app.manage(manager_state);
Ok(())
})
})
.on_event(|app, e| match e {
// TODO: Also exit when app is force-quit (eg. cmd+r in IntelliJ runner)
RunEvent::ExitRequested { api, .. } => {
api.prevent_exit();
tauri::async_runtime::block_on(async move {
info!("Exiting plugin runtime due to app exit");
let manager: State<Mutex<PluginManager>> = app.state();
manager.lock().await.cleanup().await;
exit(0);
});
}
_ => {}
})
.build()
}
pub async fn start_server(
plugin_dirs: Vec<String>,
) -> Result<(PluginRuntimeGrpcServer, SocketAddr)> {
println!("Starting plugin server with {plugin_dirs:?}");
let server = PluginRuntimeGrpcServer::new(plugin_dirs);
let svc = PluginRuntimeServer::new(server.clone());
let listen_addr = match option_env!("PORT") {
None => "localhost:0".to_string(),
Some(port) => format!("localhost:{port}"),
};
{
let server = server.clone();
tokio::spawn(async move {
let (rx_id, mut rx) = server.subscribe().await;
while let Some(event) = rx.recv().await {
match event.clone() {
InternalEvent {
payload: InternalEventPayload::BootResponse(resp),
plugin_ref_id,
..
} => {
server.boot_plugin(plugin_ref_id.as_str(), &resp).await;
}
_ => {}
};
}
server.unsubscribe(rx_id).await;
});
};
let listener = TcpListener::bind(listen_addr).await?;
let addr = listener.local_addr()?;
println!("Starting gRPC plugin server on {addr}");
tokio::spawn(async move {
Server::builder()
.timeout(Duration::from_secs(10))
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.expect("grpc plugin runtime server failed to start");
});
Ok((server, addr))
}
async fn read_plugins_dir(dir: &PathBuf) -> Result<Vec<String>> {
let mut result = read_dir(dir).await?;
let mut dirs: Vec<String> = vec![];
while let Ok(Some(entry)) = result.next_entry().await {
if entry.path().is_dir() {
dirs.push(entry.path().to_string_lossy().to_string())
}
}
Ok(dirs)
}

View File

@@ -0,0 +1,448 @@
use log::info;
use rand::distributions::{Alphanumeric, DistString};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{mpsc, Mutex};
use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
use tonic::codegen::tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
use crate::error::Error::{NoPluginsErr, PluginNotFoundErr};
use crate::error::Result;
use crate::events::{BootRequest, BootResponse, InternalEvent, InternalEventPayload};
use crate::server::plugin_runtime::plugin_runtime_server::PluginRuntime;
use plugin_runtime::EventStreamEvent;
use yaak_models::queries::generate_id;
pub mod plugin_runtime {
tonic::include_proto!("yaak.plugins.runtime");
}
type ResponseStream =
Pin<Box<dyn Stream<Item = std::result::Result<EventStreamEvent, Status>> + Send>>;
#[derive(Clone)]
pub struct PluginHandle {
dir: String,
to_plugin_tx: Arc<Mutex<mpsc::Sender<tonic::Result<EventStreamEvent>>>>,
ref_id: String,
boot_resp: Arc<Mutex<Option<BootResponse>>>,
}
impl PluginHandle {
pub async fn name(&self) -> String {
match &*self.boot_resp.lock().await {
None => "__NOT_BOOTED__".to_string(),
Some(r) => r.name.to_owned(),
}
}
pub fn build_event_to_send(
&self,
payload: &InternalEventPayload,
reply_id: Option<String>,
) -> InternalEvent {
InternalEvent {
id: gen_id(),
plugin_ref_id: self.ref_id.clone(),
reply_id,
payload: payload.clone(),
}
}
pub async fn send(&self, event: &InternalEvent) -> Result<()> {
info!("Sending event {} {:?}", event.id, self.name().await);
self.to_plugin_tx
.lock()
.await
.send(Ok(EventStreamEvent {
event: serde_json::to_string(&event)?,
}))
.await?;
Ok(())
}
pub async fn boot(&self, resp: &BootResponse) {
let mut boot_resp = self.boot_resp.lock().await;
*boot_resp = Some(resp.clone());
}
}
#[derive(Clone)]
pub struct PluginRuntimeGrpcServer {
plugin_ref_to_plugin: Arc<Mutex<HashMap<String, PluginHandle>>>,
callback_to_plugin_ref: Arc<Mutex<HashMap<String, String>>>,
subscribers: Arc<Mutex<HashMap<String, mpsc::Sender<InternalEvent>>>>,
plugin_dirs: Vec<String>,
}
impl PluginRuntimeGrpcServer {
pub fn new(plugin_dirs: Vec<String>) -> Self {
PluginRuntimeGrpcServer {
plugin_ref_to_plugin: Arc::new(Mutex::new(HashMap::new())),
callback_to_plugin_ref: Arc::new(Mutex::new(HashMap::new())),
subscribers: Arc::new(Mutex::new(HashMap::new())),
plugin_dirs,
}
}
pub async fn subscribe(&self) -> (String, Receiver<InternalEvent>) {
let (tx, rx) = mpsc::channel(128);
let id = generate_id();
self.subscribers.lock().await.insert(id.clone(), tx);
(id, rx)
}
pub async fn unsubscribe(&self, rx_id: String) {
self.subscribers.lock().await.remove(rx_id.as_str());
}
pub async fn remove_plugins(&self, plugin_ids: Vec<String>) {
for plugin_id in plugin_ids {
self.remove_plugin(plugin_id.as_str()).await;
}
}
pub async fn remove_plugin(&self, id: &str) {
match self.plugin_ref_to_plugin.lock().await.remove(id) {
None => {
println!("Tried to remove non-existing plugin {}", id);
}
Some(plugin) => {
println!("Removed plugin {} {}", id, plugin.name().await);
}
};
}
pub async fn boot_plugin(&self, id: &str, resp: &BootResponse) {
match self.plugin_ref_to_plugin.lock().await.get(id) {
None => {
println!("Tried booting non-existing plugin {}", id);
}
Some(plugin) => {
plugin.clone().boot(resp).await;
}
}
}
pub async fn add_plugin(
&self,
dir: &str,
tx: mpsc::Sender<tonic::Result<EventStreamEvent>>,
) -> PluginHandle {
let ref_id = gen_id();
let plugin_handle = PluginHandle {
ref_id: ref_id.clone(),
dir: dir.to_string(),
to_plugin_tx: Arc::new(Mutex::new(tx)),
boot_resp: Arc::new(Mutex::new(None)),
};
let _ = self
.plugin_ref_to_plugin
.lock()
.await
.insert(ref_id, plugin_handle.clone());
plugin_handle
}
// pub async fn callback(
// &self,
// source_event: InternalEvent,
// payload: InternalEventPayload,
// ) -> Result<InternalEvent> {
// let reply_id = match source_event.clone().reply_id {
// None => {
// let msg = format!("Source event missing reply Id {:?}", source_event.clone());
// return Err(MissingCallbackIdErr(msg));
// }
// Some(id) => id,
// };
//
// let callbacks = self.callbacks.lock().await;
// let plugin_name = match callbacks.get(reply_id.as_str()) {
// None => {
// let msg = format!("Callback not found {:?}", source_event);
// return Err(MissingCallbackErr(msg));
// }
// Some(n) => n,
// };
//
// let plugins = self.plugins.lock().await;
// let plugin = match plugins.get(plugin_name) {
// None => {
// let msg = format!(
// "Plugin not found {plugin_name}. Choices were {:?}",
// plugins.keys()
// );
// return Err(UnknownPluginErr(msg));
// }
// Some(n) => n,
// };
//
// plugin.send(&payload, Some(reply_id)).await
// }
pub async fn plugin_by_ref_id(&self, ref_id: &str) -> Result<PluginHandle> {
let plugins = self.plugin_ref_to_plugin.lock().await;
if plugins.is_empty() {
return Err(NoPluginsErr("Send failed because no plugins exist".into()));
}
match plugins.get(ref_id) {
None => {
let msg = format!("Failed to find plugin for id {ref_id}");
Err(PluginNotFoundErr(msg))
}
Some(p) => Ok(p.to_owned()),
}
}
pub async fn plugin_by_name(&self, plugin_name: &str) -> Result<PluginHandle> {
let plugins = self.plugin_ref_to_plugin.lock().await;
if plugins.is_empty() {
return Err(NoPluginsErr("Send failed because no plugins exist".into()));
}
for p in plugins.values() {
if p.name().await == plugin_name {
return Ok(p.to_owned());
}
}
let msg = format!("Failed to find plugin for {plugin_name}");
Err(PluginNotFoundErr(msg))
}
pub async fn send_to_plugin(
&self,
plugin_name: &str,
payload: InternalEventPayload,
) -> Result<InternalEvent> {
let plugins = self.plugin_ref_to_plugin.lock().await;
if plugins.is_empty() {
return Err(NoPluginsErr("Send failed because no plugins exist".into()));
}
let mut plugin = None;
for p in plugins.values() {
if p.name().await == plugin_name {
plugin = Some(p);
break;
}
}
match plugin {
Some(plugin) => {
let event = plugin.build_event_to_send(&payload, None);
plugin.send(&event).await?;
Ok(event)
}
None => {
let msg = format!("Failed to find plugin for {plugin_name}");
Err(PluginNotFoundErr(msg))
}
}
}
pub async fn send_to_plugin_and_wait(
&self,
plugin_name: &str,
payload: &InternalEventPayload,
) -> Result<InternalEvent> {
let plugin = self.plugin_by_name(plugin_name).await?;
let events = self.send_to_plugins_and_wait(payload, vec![plugin]).await?;
Ok(events.first().unwrap().to_owned())
}
pub async fn send_and_wait(
&self,
payload: &InternalEventPayload,
) -> Result<Vec<InternalEvent>> {
let plugins = self
.plugin_ref_to_plugin
.lock()
.await
.values()
.cloned()
.collect();
self.send_to_plugins_and_wait(payload, plugins).await
}
async fn send_to_plugins_and_wait(
&self,
payload: &InternalEventPayload,
plugins: Vec<PluginHandle>,
) -> Result<Vec<InternalEvent>> {
// 1. Build the events with IDs and everything
let events_to_send = plugins
.iter()
.map(|p| p.build_event_to_send(payload, None))
.collect::<Vec<InternalEvent>>();
// 2. Spawn thread to subscribe to incoming events and check reply ids
let server = self.clone();
let send_events_fut = {
let events_to_send = events_to_send.clone();
tokio::spawn(async move {
let (rx_id, mut rx) = server.subscribe().await;
let mut found_events = Vec::new();
while let Some(event) = rx.recv().await {
if events_to_send
.iter()
.find(|e| Some(e.id.to_owned()) == event.reply_id)
.is_some()
{
found_events.push(event.clone());
};
if found_events.len() == events_to_send.len() {
break;
}
}
server.unsubscribe(rx_id).await;
found_events
})
};
// 3. Send the events
for event in events_to_send {
let plugin = plugins
.iter()
.find(|p| p.ref_id == event.plugin_ref_id)
.expect("Didn't find plugin in list");
plugin.send(&event).await?
}
// 4. Join on the spawned thread
let events = send_events_fut.await.expect("Thread didn't succeed");
Ok(events)
}
pub async fn send(&self, payload: InternalEventPayload) -> Result<Vec<InternalEvent>> {
let mut events: Vec<InternalEvent> = Vec::new();
let plugins = self.plugin_ref_to_plugin.lock().await;
if plugins.is_empty() {
return Err(NoPluginsErr("Send failed because no plugins exist".into()));
}
for ph in plugins.values() {
let event = ph.build_event_to_send(&payload, None);
self.send_to_plugin_handle(ph, &event).await?;
events.push(event);
}
Ok(events)
}
async fn send_to_plugin_handle(
&self,
plugin: &PluginHandle,
event: &InternalEvent,
) -> Result<()> {
plugin.send(event).await
}
async fn load_plugins(
&self,
to_plugin_tx: mpsc::Sender<tonic::Result<EventStreamEvent>>,
plugin_dirs: Vec<String>,
) -> Vec<String> {
let mut plugin_ids = Vec::new();
for dir in plugin_dirs {
let plugin = self.add_plugin(dir.as_str(), to_plugin_tx.clone()).await;
plugin_ids.push(plugin.clone().ref_id);
let event = plugin.build_event_to_send(
&InternalEventPayload::BootRequest(BootRequest {
dir: dir.to_string(),
}),
None,
);
if let Err(e) = plugin.send(&event).await {
// TODO: Error handling
println!(
"Failed boot plugin {} at {} -> {}",
plugin.ref_id, plugin.dir, e
)
} else {
println!("Loaded plugin {} at {}", plugin.ref_id, plugin.dir)
}
}
plugin_ids
}
}
#[tonic::async_trait]
impl PluginRuntime for PluginRuntimeGrpcServer {
type EventStreamStream = ResponseStream;
async fn event_stream(
&self,
req: Request<Streaming<EventStreamEvent>>,
) -> tonic::Result<Response<Self::EventStreamStream>> {
let mut in_stream = req.into_inner();
let (to_plugin_tx, to_plugin_rx) = mpsc::channel(128);
let plugin_ids = self
.load_plugins(to_plugin_tx, self.plugin_dirs.clone())
.await;
let callbacks = self.callback_to_plugin_ref.clone();
let server = self.clone();
tokio::spawn(async move {
while let Some(result) = in_stream.next().await {
match result {
Ok(v) => {
let event: InternalEvent = match serde_json::from_str(v.event.as_str()) {
Ok(pe) => pe,
Err(e) => {
println!("Failed to deserialize event {e:?} -> {}", v.event);
continue;
}
};
let plugin_ref_id = event.plugin_ref_id.clone();
let reply_id = event.reply_id.clone();
let subscribers = server.subscribers.lock().await;
for tx in subscribers.values() {
// Emit event to the channel for server to handle
if let Err(e) = tx.try_send(event.clone()) {
println!("Failed to send to server channel. Receiver probably isn't listening: {:?}", e);
}
}
// Add to callbacks if there's a reply_id
if let Some(reply_id) = reply_id {
callbacks.lock().await.insert(reply_id, plugin_ref_id);
}
}
Err(err) => {
// TODO: Better error handling
println!("gRPC server error {err}");
break;
}
};
}
server.remove_plugins(plugin_ids).await;
});
// Write the same data that was received
let out_stream = ReceiverStream::new(to_plugin_rx);
Ok(Response::new(
Box::pin(out_stream) as Self::EventStreamStream
))
}
}
fn gen_id() -> String {
Alphanumeric.sample_string(&mut rand::thread_rng(), 5)
}