diff --git a/crates/yaak/src/send.rs b/crates/yaak/src/send.rs index 9082c2c5..01b7239d 100644 --- a/crates/yaak/src/send.rs +++ b/crates/yaak/src/send.rs @@ -3,8 +3,11 @@ use async_trait::async_trait; use log::warn; use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::atomic::{AtomicI32, Ordering}; use std::time::Instant; use thiserror::Error; +use tokio::fs::File; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc; use tokio::sync::watch; use yaak_crypto::manager::EncryptionManager; @@ -14,11 +17,12 @@ use yaak_http::client::{ use yaak_http::cookies::CookieStore; use yaak_http::manager::HttpConnectionManager; use yaak_http::sender::{HttpResponseEvent as SenderHttpResponseEvent, ReqwestSender}; +use yaak_http::tee_reader::TeeReader; use yaak_http::transaction::HttpTransaction; use yaak_http::types::{ SendableBody, SendableHttpRequest, SendableHttpRequestOptions, append_query_params, }; -use yaak_models::blob_manager::BlobManager; +use yaak_models::blob_manager::{BlobManager, BodyChunk}; use yaak_models::models::{ ClientCertificate, CookieJar, DnsOverride, Environment, HttpRequest, HttpResponse, HttpResponseEvent, HttpResponseHeader, HttpResponseState, ProxySetting, ProxySettingAuth, @@ -34,6 +38,8 @@ use yaak_templates::{RenderOptions, TemplateCallback}; use yaak_tls::find_client_certificate; const HTTP_EVENT_CHANNEL_CAPACITY: usize = 100; +const REQUEST_BODY_CHUNK_SIZE: usize = 1024 * 1024; +const RESPONSE_PROGRESS_UPDATE_INTERVAL_MS: u128 = 100; #[derive(Debug, Error)] pub enum SendHttpRequestError { @@ -233,6 +239,7 @@ pub struct SendHttpRequestByIdParams<'a, T: TemplateCallback> { pub cookie_jar_id: Option, pub response_dir: &'a Path, pub emit_events_to: Option>, + pub cancelled_rx: Option>, pub prepare_sendable_request: Option<&'a dyn PrepareSendableRequest>, pub executor: Option<&'a dyn SendRequestExecutor>, } @@ -248,6 +255,7 @@ pub struct SendHttpRequestParams<'a, T: TemplateCallback> { pub cookie_jar_id: Option, pub 
response_dir: &'a Path, pub emit_events_to: Option>, + pub cancelled_rx: Option>, pub auth_context_id: Option, pub existing_response: Option, pub prepare_sendable_request: Option<&'a dyn PrepareSendableRequest>, @@ -389,6 +397,7 @@ pub async fn send_http_request_with_plugins( cookie_jar_id: params.cookie_jar_id, response_dir: params.response_dir, emit_events_to: params.emit_events_to, + cancelled_rx: params.cancelled_rx, auth_context_id: None, existing_response: params.existing_response, prepare_sendable_request: Some(&auth_hook), @@ -418,6 +427,7 @@ pub async fn send_http_request_by_id( cookie_jar_id: params.cookie_jar_id, response_dir: params.response_dir, emit_events_to: params.emit_events_to, + cancelled_rx: params.cancelled_rx, existing_response: None, prepare_sendable_request: params.prepare_sendable_request, executor: params.executor, @@ -499,6 +509,35 @@ pub async fn send_http_request( response.id = generate_prefixed_id("rs"); } + let request_body_id = format!("{}.request", response.id); + let mut request_body_capture_task = None; + let mut request_body_capture_error = None; + if persist_response { + match sendable_request.body.as_mut() { + Some(SendableBody::Bytes(bytes)) => { + if let Err(err) = persist_request_body_bytes( + params.blob_manager, + &request_body_id, + bytes.as_ref(), + ) { + request_body_capture_error = Some(err); + } + } + Some(SendableBody::Stream { data, .. 
}) => { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); + let inner = std::mem::replace(data, Box::pin(tokio::io::empty())); + let tee_reader = TeeReader::new(inner, tx); + *data = Box::pin(tee_reader); + let blob_manager = params.blob_manager.clone(); + let body_id = request_body_id.clone(); + request_body_capture_task = Some(tokio::spawn(async move { + persist_request_body_stream(blob_manager, body_id, rx).await + })); + } + None => {} + } + } + let (event_tx, mut event_rx) = mpsc::channel::(HTTP_EVENT_CHANNEL_CAPACITY); let event_query_manager = params.query_manager.clone(); @@ -506,8 +545,14 @@ pub async fn send_http_request( let event_workspace_id = params.request.workspace_id.clone(); let event_update_source = params.update_source.clone(); let emit_events_to = params.emit_events_to.clone(); + let dns_elapsed = Arc::new(AtomicI32::new(0)); + let event_dns_elapsed = dns_elapsed.clone(); let event_handle = tokio::spawn(async move { while let Some(event) = event_rx.recv().await { + if let SenderHttpResponseEvent::DnsResolved { duration, .. 
} = &event { + event_dns_elapsed.store(u64_to_i32(*duration), Ordering::Relaxed); + } + if persist_response { let db_event = HttpResponseEvent::new( &event_response_id, @@ -533,7 +578,9 @@ pub async fn send_http_request( let started_at = Instant::now(); let request_started_url = sendable_request.url.clone(); - let http_response = match executor.send(sendable_request, event_tx, cookie_store.clone()).await + let mut http_response = match executor + .send(sendable_request, event_tx, cookie_store.clone()) + .await { Ok(response) => response, Err(err) => { @@ -552,6 +599,9 @@ pub async fn send_http_request( if let Err(join_err) = event_handle.await { warn!("Failed to join response event task: {}", join_err); } + if let Some(task) = request_body_capture_task.take() { + let _ = task.await; + } return Err(SendHttpRequestError::SendRequest(err)); } }; @@ -565,6 +615,7 @@ pub async fn send_http_request( url: http_response.url.clone(), remote_addr: http_response.remote_addr.clone(), version: http_response.version.clone(), + elapsed_dns: dns_elapsed.load(Ordering::Relaxed), headers: http_response .headers .iter() @@ -581,19 +632,12 @@ pub async fn send_http_request( response = params .query_manager .connect() - .upsert_http_response( - &connected_response, - ¶ms.update_source, - params.blob_manager, - ) + .upsert_http_response(&connected_response, ¶ms.update_source, params.blob_manager) .map_err(SendHttpRequestError::PersistResponse)?; } else { response = connected_response; } - let (response_body, body_stats) = - http_response.bytes().await.map_err(SendHttpRequestError::ReadResponseBody)?; - std::fs::create_dir_all(params.response_dir).map_err(|source| { SendHttpRequestError::CreateResponseDirectory { path: params.response_dir.to_path_buf(), @@ -602,16 +646,137 @@ pub async fn send_http_request( })?; let body_path = params.response_dir.join(&response.id); - std::fs::write(&body_path, &response_body).map_err(|source| { - SendHttpRequestError::WriteResponseBody { path: 
body_path.clone(), source } - })?; + let mut file = + File::options().create(true).truncate(true).write(true).open(&body_path).await.map_err( + |source| SendHttpRequestError::WriteResponseBody { path: body_path.clone(), source }, + )?; + let mut body_stream = + http_response.into_body_stream().map_err(SendHttpRequestError::ReadResponseBody)?; + let mut response_body = Vec::new(); + let mut body_read_error = None; + let mut written_bytes: usize = 0; + let mut last_progress_update = started_at; + let mut cancelled_rx = params.cancelled_rx.clone(); + loop { + let read_result = if let Some(cancelled_rx) = cancelled_rx.as_mut() { + if *cancelled_rx.borrow() { + break; + } + + tokio::select! { + biased; + _ = cancelled_rx.changed() => { + None + } + result = body_stream.read_buf(&mut response_body) => { + Some(result) + } + } + } else { + Some(body_stream.read_buf(&mut response_body).await) + }; + + let Some(read_result) = read_result else { + break; + }; + + match read_result { + Ok(0) => break, + Ok(n) => { + written_bytes += n; + let start_idx = response_body.len() - n; + file.write_all(&response_body[start_idx..]).await.map_err(|source| { + SendHttpRequestError::WriteResponseBody { path: body_path.clone(), source } + })?; + + let now = Instant::now(); + let should_update = now.duration_since(last_progress_update).as_millis() + >= RESPONSE_PROGRESS_UPDATE_INTERVAL_MS; + if should_update { + let elapsed = duration_to_i32(started_at.elapsed()); + let progress_response = HttpResponse { + elapsed, + content_length: Some(usize_to_i32(written_bytes)), + elapsed_dns: dns_elapsed.load(Ordering::Relaxed), + ..response.clone() + }; + if persist_response { + response = params + .query_manager + .connect() + .upsert_http_response( + &progress_response, + ¶ms.update_source, + params.blob_manager, + ) + .map_err(SendHttpRequestError::PersistResponse)?; + } else { + response = progress_response; + } + last_progress_update = now; + } + } + Err(err) => { + body_read_error = 
Some(SendHttpRequestError::ReadResponseBody( + yaak_http::error::Error::BodyReadError(err.to_string()), + )); + break; + } + } + } + + file.flush().await.map_err(|source| SendHttpRequestError::WriteResponseBody { + path: body_path.clone(), + source, + })?; + drop(body_stream); + + if let Some(task) = request_body_capture_task.take() { + match task.await { + Ok(Ok(total)) => { + response.request_content_length = Some(usize_to_i32(total)); + } + Ok(Err(err)) => request_body_capture_error = Some(err), + Err(err) => request_body_capture_error = Some(err.to_string()), + } + } + + if let Some(err) = request_body_capture_error.take() { + response.error = Some(append_error_message( + response.error.take(), + format!("Request succeeded but failed to store request body: {err}"), + )); + } + + if let Err(join_err) = event_handle.await { + warn!("Failed to join response event task: {}", join_err); + } + + if let Some(err) = body_read_error { + if persist_response { + let _ = persist_response_error( + params.query_manager, + params.blob_manager, + ¶ms.update_source, + &response, + started_at, + err.to_string(), + request_started_url, + ); + } + persist_cookie_jar(params.query_manager, cookie_jar.as_mut(), cookie_store.as_ref())?; + return Err(err); + } + + let compressed_length = http_response.content_length.unwrap_or(written_bytes as u64); let final_response = HttpResponse { body_path: Some(body_path.to_string_lossy().to_string()), - content_length: Some(usize_to_i32(response_body.len())), - content_length_compressed: Some(u64_to_i32(body_stats.size_compressed)), + content_length: Some(usize_to_i32(written_bytes)), + content_length_compressed: Some(u64_to_i32(compressed_length)), elapsed: duration_to_i32(started_at.elapsed()), elapsed_headers: headers_elapsed, + elapsed_dns: dns_elapsed.load(Ordering::Relaxed), state: HttpResponseState::Closed, ..response }; @@ -625,14 +790,60 @@ pub async fn send_http_request( response = final_response; } - if let Err(join_err) = 
event_handle.await { - warn!("Failed to join response event task: {}", join_err); - } persist_cookie_jar(params.query_manager, cookie_jar.as_mut(), cookie_store.as_ref())?; Ok(SendHttpRequestResult { rendered_request, response, response_body }) } /* persist_request_body_bytes: stores an in-memory request body through blob_manager as BodyChunk values keyed by body_id. The slice is cut into pieces of at most REQUEST_BODY_CHUNK_SIZE bytes and chunk_index counts up from 0. An empty slice returns Ok(()) before connect() is ever called. Any insert_chunk error is converted to a String and returned at once; chunks inserted before the failure are not removed by this function. */ +fn persist_request_body_bytes( + blob_manager: &BlobManager, + body_id: &str, + bytes: &[u8], +) -> std::result::Result<(), String> { + if bytes.is_empty() { + return Ok(()); + } + + /* connect() is called once and its result is reused for every chunk */ let blob_ctx = blob_manager.connect(); + let mut offset = 0; + let mut chunk_index: i32 = 0; + while offset < bytes.len() { + /* min() clamps the last chunk to the end of the slice */ let end = std::cmp::min(offset + REQUEST_BODY_CHUNK_SIZE, bytes.len()); + let chunk = BodyChunk::new(body_id, chunk_index, bytes[offset..end].to_vec()); + blob_ctx.insert_chunk(&chunk).map_err(|e| e.to_string())?; + chunk_index += 1; + offset = end; + } + Ok(()) +} + /* persist_request_body_stream: drains rx until recv() yields None, inserting every non-empty buffer as the next BodyChunk for body_id. Returns the sum of data.len() over all received buffers. An insert_chunk error is converted to a String and returned immediately, which ends the loop and drops rx. NOTE(review): the element type of the receiver and the type arguments of the returned Result are missing from this text, presumably a byte vector and (usize, String); confirm against the real file. */ +async fn persist_request_body_stream( + blob_manager: BlobManager, + body_id: String, + mut rx: tokio::sync::mpsc::UnboundedReceiver>, +) -> std::result::Result { + let mut chunk_index: i32 = 0; + let mut total_bytes = 0usize; + while let Some(data) = rx.recv().await { + total_bytes += data.len(); + /* empty buffers add nothing to the total and do not consume a chunk index */ if data.is_empty() { + continue; + } + let chunk = BodyChunk::new(&body_id, chunk_index, data); + /* connect() is called anew for each chunk here, unlike the byte-slice variant */ blob_manager.connect().insert_chunk(&chunk).map_err(|e| e.to_string())?; + chunk_index += 1; + } + + Ok(total_bytes) +} + /* append_error_message: returns message unchanged when existing_error is None; otherwise returns the existing text followed by a semicolon, a space, and message. NOTE(review): the type argument of Option is missing from this text, presumably String; confirm. */ +fn append_error_message(existing_error: Option, message: String) -> String { + match existing_error { + Some(existing) => format!("{existing}; {message}"), + None => message, + } +} + fn resolve_environment_chain( query_manager: &QueryManager, request: &HttpRequest,