Implement cancel

This commit is contained in:
Gregory Schier
2024-02-02 00:18:37 -08:00
parent b526ea506b
commit 4e781b752d
7 changed files with 121 additions and 133 deletions

View File

@@ -1,5 +1,5 @@
use std::fs::{create_dir_all, File};
use std::fs;
use std::fs::{create_dir_all, File};
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
@@ -7,19 +7,19 @@ use std::sync::Arc;
use std::time::Duration;
use base64::Engine;
use http::{HeaderMap, HeaderName, HeaderValue, Method};
use http::header::{ACCEPT, USER_AGENT};
use http::{HeaderMap, HeaderName, HeaderValue, Method};
use log::{error, info, warn};
use reqwest::{multipart, Url};
use reqwest::redirect::Policy;
use sqlx::{Pool, Sqlite};
use reqwest::{multipart, Url};
use sqlx::types::{Json, JsonValue};
use sqlx::{Pool, Sqlite};
use tauri::{AppHandle, Wry};
use crate::{emit_side_effect, models, render, response_err};
pub async fn send_http_request(
app_handle: &AppHandle<Wry>,
app_handle: AppHandle<Wry>,
db: &Pool<Sqlite>,
request: models::HttpRequest,
response: &models::HttpResponse,
@@ -88,7 +88,7 @@ pub async fn send_http_request(
let url = match Url::from_str(url_string.as_str()) {
Ok(u) => u,
Err(e) => {
return response_err(response, e.to_string(), app_handle, db).await;
return response_err(response, e.to_string(), app_handle.clone(), db).await;
}
};
@@ -293,7 +293,7 @@ pub async fn send_http_request(
let sendable_req = match request_builder.build() {
Ok(r) => r,
Err(e) => {
return response_err(response, e.to_string(), app_handle, db).await;
return response_err(response, e.to_string(), app_handle.clone(), db).await;
}
};
@@ -366,7 +366,7 @@ pub async fn send_http_request(
.await
.expect("Failed to update response");
if !request.id.is_empty() {
emit_side_effect(app_handle, "updated_model", &response);
emit_side_effect(app_handle.clone(), "updated_model", &response);
}
// Copy response to download path, if specified

View File

@@ -95,14 +95,6 @@ async fn cmd_grpc_reflect(endpoint: &str) -> Result<Vec<ServiceDefinition>, Stri
Ok(grpc::callable(&uri).await)
}
async fn cmd_grpc_cancel(
id: &str,
grpc_handle: State<'_, Mutex<GrpcManager>>,
) -> Result<(), String> {
// grpc_handle.lock().await.cancel(id).await.unwrap()
Ok(())
}
#[tauri::command]
async fn cmd_grpc_call_unary(
endpoint: &str,
@@ -147,6 +139,8 @@ async fn cmd_grpc_server_streaming(
app_handle: AppHandle<Wry>,
grpc_handle: State<'_, Mutex<GrpcManager>>,
) -> Result<String, String> {
let (cancelled_tx, mut cancelled_rx) = tokio::sync::watch::channel(false);
let uri = safe_uri(endpoint).map_err(|e| e.to_string())?;
let conn_id = generate_id(Some("grpc"));
@@ -157,44 +151,76 @@ async fn cmd_grpc_server_streaming(
.await
.unwrap();
loop {
match stream.message().await {
Ok(Some(item)) => {
let item = serde_json::to_string_pretty(&item).unwrap();
println!("GOT MESSAGE {:?}", item);
println!("Sending message {}", item);
emit_side_effect(&app_handle, "grpc_message", item);
}
Ok(None) => {
// sleep for a bit
println!("NO MESSAGE YET");
sleep(std::time::Duration::from_millis(100)).await;
}
Err(e) => {
return Err(e.to_string());
}
}
#[derive(serde::Deserialize)]
enum GrpcMessage {
Message(String),
Commit,
Cancel,
}
// while let Some(item) = stream.message() {
// println!("GOT MESSAGE");
// // if grpc_handle.lock().await.is_cancelled(&conn_id) {
// // break;
// // }
//
// match item {
// Ok(item) => {
// let item = serde_json::to_string_pretty(&item).unwrap();
// println!("Sending message {}", item);
// emit_side_effect(&app_handle, "grpc_message", item);
// }
// Err(e) => println!("\terror: {}", e),
// }
// // let foo = stream.trailers().await.unwrap();
// break;
// }
let cb = {
let cancelled_rx = cancelled_rx.clone();
move |ev: tauri::Event| {
if *cancelled_rx.borrow() {
// Stream is cancelled
return;
}
match serde_json::from_str::<GrpcMessage>(ev.payload().unwrap()) {
Ok(GrpcMessage::Message(msg)) => {
println!("Received message: {}", msg);
}
Ok(GrpcMessage::Commit) => {
println!("Received commit");
// TODO: Commit client streaming stream
}
Ok(GrpcMessage::Cancel) => {
println!("Received cancel");
cancelled_tx.send_replace(true);
}
Err(e) => {
error!("Failed to parse gRPC message: {:?}", e);
}
}
}
};
let event_handler = app_handle.listen_global("grpc_message_in", cb);
let app_handle2 = app_handle.clone();
let grpc_listen = async move {
loop {
match stream.next().await {
Some(Ok(item)) => {
let item = serde_json::to_string_pretty(&item).unwrap();
app_handle2
.emit_all("grpc_message", item)
.expect("Failed to emit");
}
Some(Err(e)) => {
error!("gRPC stream error: {:?}", e);
// TODO: Handle error
}
None => {
info!("gRPC stream closed by sender");
break;
}
}
}
};
tauri::async_runtime::spawn(async move {
tokio::select! {
_ = grpc_listen => {
debug!("gRPC listen finished");
},
_ = cancelled_rx.changed() => {
debug!("gRPC connection cancelled");
},
}
app_handle.unlisten(event_handler);
});
println!("DONE");
Ok(conn_id)
}
@@ -228,7 +254,7 @@ async fn cmd_send_ephemeral_request(
// let cookie_jar_id2 = cookie_jar_id.unwrap_or("").to_string();
send_http_request(
&app_handle,
app_handle,
db,
request,
&response,
@@ -381,7 +407,7 @@ async fn cmd_export_data(
#[tauri::command]
async fn cmd_send_request(
window: Window<Wry>,
app_handle: AppHandle<Wry>,
db_state: State<'_, Mutex<Pool<Sqlite>>>,
request_id: &str,
environment_id: Option<&str>,
@@ -389,7 +415,6 @@ async fn cmd_send_request(
download_dir: Option<&str>,
) -> Result<HttpResponse, String> {
let db = &*db_state.lock().await;
let app_handle = window.app_handle();
let request = get_request(db, request_id)
.await
@@ -436,10 +461,10 @@ async fn cmd_send_request(
None
};
emit_side_effect(&app_handle, "created_model", response.clone());
emit_side_effect(app_handle.clone(), "created_model", response.clone());
send_http_request(
&app_handle,
app_handle,
db,
request.clone(),
&response,
@@ -453,7 +478,7 @@ async fn cmd_send_request(
async fn response_err(
response: &HttpResponse,
error: String,
app_handle: &AppHandle<Wry>,
app_handle: AppHandle<Wry>,
db: &Pool<Sqlite>,
) -> Result<HttpResponse, String> {
let mut response = response.clone();
@@ -1000,24 +1025,25 @@ async fn cmd_check_for_updates(
fn main() {
tauri::Builder::default()
.plugin(tauri_plugin_window_state::Builder::default().build())
.plugin(
tauri_plugin_log::Builder::default()
.targets([LogTarget::LogDir, LogTarget::Stdout, LogTarget::Webview])
.level_for("tao", log::LevelFilter::Info)
.level_for("sqlx", log::LevelFilter::Warn)
.level_for("hyper", log::LevelFilter::Info)
.level_for("tracing", log::LevelFilter::Info)
.level_for("reqwest", log::LevelFilter::Info)
.level_for("tokio_util", log::LevelFilter::Info)
.level_for("cookie_store", log::LevelFilter::Info)
.level_for("h2", log::LevelFilter::Info)
.level_for("tower", log::LevelFilter::Info)
.level_for("hyper", log::LevelFilter::Info)
.level_for("hyper_rustls", log::LevelFilter::Info)
.level_for("reqwest", log::LevelFilter::Info)
.level_for("sqlx", log::LevelFilter::Warn)
.level_for("tao", log::LevelFilter::Info)
.level_for("tokio_util", log::LevelFilter::Info)
.level_for("tonic", log::LevelFilter::Info)
.level_for("tower", log::LevelFilter::Info)
.level_for("tracing", log::LevelFilter::Info)
.with_colors(ColoredLevelConfig::default())
.level(log::LevelFilter::Trace)
.build(),
)
.plugin(tauri_plugin_window_state::Builder::default().build())
.setup(|app| {
let app_data_dir = app.path_resolver().app_data_dir().unwrap();
let app_config_dir = app.path_resolver().app_config_dir().unwrap();
@@ -1292,7 +1318,7 @@ fn emit_and_return<S: Serialize + Clone, E>(
}
/// Emit an event to all windows, used for side-effects where there is no source window to attribute. This
fn emit_side_effect<S: Serialize + Clone>(app_handle: &AppHandle<Wry>, event: &str, payload: S) {
fn emit_side_effect<S: Serialize + Clone>(app_handle: AppHandle<Wry>, event: &str, payload: S) {
app_handle.emit_all(event, &payload).unwrap();
}