mirror of
https://github.com/mountain-loop/yaak.git
synced 2026-04-19 23:31:21 +02:00
Restore send parity in shared HTTP pipeline (#400)
This commit is contained in:
@@ -3,8 +3,11 @@ use async_trait::async_trait;
|
||||
use log::warn;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicI32, Ordering};
|
||||
use std::time::Instant;
|
||||
use thiserror::Error;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use yaak_crypto::manager::EncryptionManager;
|
||||
@@ -14,11 +17,12 @@ use yaak_http::client::{
|
||||
use yaak_http::cookies::CookieStore;
|
||||
use yaak_http::manager::HttpConnectionManager;
|
||||
use yaak_http::sender::{HttpResponseEvent as SenderHttpResponseEvent, ReqwestSender};
|
||||
use yaak_http::tee_reader::TeeReader;
|
||||
use yaak_http::transaction::HttpTransaction;
|
||||
use yaak_http::types::{
|
||||
SendableBody, SendableHttpRequest, SendableHttpRequestOptions, append_query_params,
|
||||
};
|
||||
use yaak_models::blob_manager::BlobManager;
|
||||
use yaak_models::blob_manager::{BlobManager, BodyChunk};
|
||||
use yaak_models::models::{
|
||||
ClientCertificate, CookieJar, DnsOverride, Environment, HttpRequest, HttpResponse,
|
||||
HttpResponseEvent, HttpResponseHeader, HttpResponseState, ProxySetting, ProxySettingAuth,
|
||||
@@ -34,6 +38,8 @@ use yaak_templates::{RenderOptions, TemplateCallback};
|
||||
use yaak_tls::find_client_certificate;
|
||||
|
||||
const HTTP_EVENT_CHANNEL_CAPACITY: usize = 100;
|
||||
const REQUEST_BODY_CHUNK_SIZE: usize = 1024 * 1024;
|
||||
const RESPONSE_PROGRESS_UPDATE_INTERVAL_MS: u128 = 100;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SendHttpRequestError {
|
||||
@@ -233,6 +239,7 @@ pub struct SendHttpRequestByIdParams<'a, T: TemplateCallback> {
|
||||
pub cookie_jar_id: Option<String>,
|
||||
pub response_dir: &'a Path,
|
||||
pub emit_events_to: Option<mpsc::Sender<SenderHttpResponseEvent>>,
|
||||
pub cancelled_rx: Option<watch::Receiver<bool>>,
|
||||
pub prepare_sendable_request: Option<&'a dyn PrepareSendableRequest>,
|
||||
pub executor: Option<&'a dyn SendRequestExecutor>,
|
||||
}
|
||||
@@ -248,6 +255,7 @@ pub struct SendHttpRequestParams<'a, T: TemplateCallback> {
|
||||
pub cookie_jar_id: Option<String>,
|
||||
pub response_dir: &'a Path,
|
||||
pub emit_events_to: Option<mpsc::Sender<SenderHttpResponseEvent>>,
|
||||
pub cancelled_rx: Option<watch::Receiver<bool>>,
|
||||
pub auth_context_id: Option<String>,
|
||||
pub existing_response: Option<HttpResponse>,
|
||||
pub prepare_sendable_request: Option<&'a dyn PrepareSendableRequest>,
|
||||
@@ -389,6 +397,7 @@ pub async fn send_http_request_with_plugins(
|
||||
cookie_jar_id: params.cookie_jar_id,
|
||||
response_dir: params.response_dir,
|
||||
emit_events_to: params.emit_events_to,
|
||||
cancelled_rx: params.cancelled_rx,
|
||||
auth_context_id: None,
|
||||
existing_response: params.existing_response,
|
||||
prepare_sendable_request: Some(&auth_hook),
|
||||
@@ -418,6 +427,7 @@ pub async fn send_http_request_by_id<T: TemplateCallback>(
|
||||
cookie_jar_id: params.cookie_jar_id,
|
||||
response_dir: params.response_dir,
|
||||
emit_events_to: params.emit_events_to,
|
||||
cancelled_rx: params.cancelled_rx,
|
||||
existing_response: None,
|
||||
prepare_sendable_request: params.prepare_sendable_request,
|
||||
executor: params.executor,
|
||||
@@ -499,6 +509,35 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
response.id = generate_prefixed_id("rs");
|
||||
}
|
||||
|
||||
let request_body_id = format!("{}.request", response.id);
|
||||
let mut request_body_capture_task = None;
|
||||
let mut request_body_capture_error = None;
|
||||
if persist_response {
|
||||
match sendable_request.body.as_mut() {
|
||||
Some(SendableBody::Bytes(bytes)) => {
|
||||
if let Err(err) = persist_request_body_bytes(
|
||||
params.blob_manager,
|
||||
&request_body_id,
|
||||
bytes.as_ref(),
|
||||
) {
|
||||
request_body_capture_error = Some(err);
|
||||
}
|
||||
}
|
||||
Some(SendableBody::Stream { data, .. }) => {
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
|
||||
let inner = std::mem::replace(data, Box::pin(tokio::io::empty()));
|
||||
let tee_reader = TeeReader::new(inner, tx);
|
||||
*data = Box::pin(tee_reader);
|
||||
let blob_manager = params.blob_manager.clone();
|
||||
let body_id = request_body_id.clone();
|
||||
request_body_capture_task = Some(tokio::spawn(async move {
|
||||
persist_request_body_stream(blob_manager, body_id, rx).await
|
||||
}));
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
let (event_tx, mut event_rx) =
|
||||
mpsc::channel::<SenderHttpResponseEvent>(HTTP_EVENT_CHANNEL_CAPACITY);
|
||||
let event_query_manager = params.query_manager.clone();
|
||||
@@ -506,8 +545,14 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
let event_workspace_id = params.request.workspace_id.clone();
|
||||
let event_update_source = params.update_source.clone();
|
||||
let emit_events_to = params.emit_events_to.clone();
|
||||
let dns_elapsed = Arc::new(AtomicI32::new(0));
|
||||
let event_dns_elapsed = dns_elapsed.clone();
|
||||
let event_handle = tokio::spawn(async move {
|
||||
while let Some(event) = event_rx.recv().await {
|
||||
if let SenderHttpResponseEvent::DnsResolved { duration, .. } = &event {
|
||||
event_dns_elapsed.store(u64_to_i32(*duration), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
if persist_response {
|
||||
let db_event = HttpResponseEvent::new(
|
||||
&event_response_id,
|
||||
@@ -533,7 +578,9 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
let started_at = Instant::now();
|
||||
let request_started_url = sendable_request.url.clone();
|
||||
|
||||
let http_response = match executor.send(sendable_request, event_tx, cookie_store.clone()).await
|
||||
let mut http_response = match executor
|
||||
.send(sendable_request, event_tx, cookie_store.clone())
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
@@ -552,6 +599,9 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
if let Err(join_err) = event_handle.await {
|
||||
warn!("Failed to join response event task: {}", join_err);
|
||||
}
|
||||
if let Some(task) = request_body_capture_task.take() {
|
||||
let _ = task.await;
|
||||
}
|
||||
return Err(SendHttpRequestError::SendRequest(err));
|
||||
}
|
||||
};
|
||||
@@ -565,6 +615,7 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
url: http_response.url.clone(),
|
||||
remote_addr: http_response.remote_addr.clone(),
|
||||
version: http_response.version.clone(),
|
||||
elapsed_dns: dns_elapsed.load(Ordering::Relaxed),
|
||||
headers: http_response
|
||||
.headers
|
||||
.iter()
|
||||
@@ -581,19 +632,12 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
response = params
|
||||
.query_manager
|
||||
.connect()
|
||||
.upsert_http_response(
|
||||
&connected_response,
|
||||
¶ms.update_source,
|
||||
params.blob_manager,
|
||||
)
|
||||
.upsert_http_response(&connected_response, ¶ms.update_source, params.blob_manager)
|
||||
.map_err(SendHttpRequestError::PersistResponse)?;
|
||||
} else {
|
||||
response = connected_response;
|
||||
}
|
||||
|
||||
let (response_body, body_stats) =
|
||||
http_response.bytes().await.map_err(SendHttpRequestError::ReadResponseBody)?;
|
||||
|
||||
std::fs::create_dir_all(params.response_dir).map_err(|source| {
|
||||
SendHttpRequestError::CreateResponseDirectory {
|
||||
path: params.response_dir.to_path_buf(),
|
||||
@@ -602,16 +646,137 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
})?;
|
||||
|
||||
let body_path = params.response_dir.join(&response.id);
|
||||
std::fs::write(&body_path, &response_body).map_err(|source| {
|
||||
SendHttpRequestError::WriteResponseBody { path: body_path.clone(), source }
|
||||
})?;
|
||||
let mut file =
|
||||
File::options().create(true).truncate(true).write(true).open(&body_path).await.map_err(
|
||||
|source| SendHttpRequestError::WriteResponseBody { path: body_path.clone(), source },
|
||||
)?;
|
||||
let mut body_stream =
|
||||
http_response.into_body_stream().map_err(SendHttpRequestError::ReadResponseBody)?;
|
||||
let mut response_body = Vec::new();
|
||||
let mut body_read_error = None;
|
||||
let mut written_bytes: usize = 0;
|
||||
let mut last_progress_update = started_at;
|
||||
let mut cancelled_rx = params.cancelled_rx.clone();
|
||||
|
||||
loop {
|
||||
let read_result = if let Some(cancelled_rx) = cancelled_rx.as_mut() {
|
||||
if *cancelled_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
biased;
|
||||
_ = cancelled_rx.changed() => {
|
||||
None
|
||||
}
|
||||
result = body_stream.read_buf(&mut response_body) => {
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Some(body_stream.read_buf(&mut response_body).await)
|
||||
};
|
||||
|
||||
let Some(read_result) = read_result else {
|
||||
break;
|
||||
};
|
||||
|
||||
match read_result {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
written_bytes += n;
|
||||
let start_idx = response_body.len() - n;
|
||||
file.write_all(&response_body[start_idx..]).await.map_err(|source| {
|
||||
SendHttpRequestError::WriteResponseBody { path: body_path.clone(), source }
|
||||
})?;
|
||||
|
||||
let now = Instant::now();
|
||||
let should_update = now.duration_since(last_progress_update).as_millis()
|
||||
>= RESPONSE_PROGRESS_UPDATE_INTERVAL_MS;
|
||||
if should_update {
|
||||
let elapsed = duration_to_i32(started_at.elapsed());
|
||||
let progress_response = HttpResponse {
|
||||
elapsed,
|
||||
content_length: Some(usize_to_i32(written_bytes)),
|
||||
elapsed_dns: dns_elapsed.load(Ordering::Relaxed),
|
||||
..response.clone()
|
||||
};
|
||||
if persist_response {
|
||||
response = params
|
||||
.query_manager
|
||||
.connect()
|
||||
.upsert_http_response(
|
||||
&progress_response,
|
||||
¶ms.update_source,
|
||||
params.blob_manager,
|
||||
)
|
||||
.map_err(SendHttpRequestError::PersistResponse)?;
|
||||
} else {
|
||||
response = progress_response;
|
||||
}
|
||||
last_progress_update = now;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
body_read_error = Some(SendHttpRequestError::ReadResponseBody(
|
||||
yaak_http::error::Error::BodyReadError(err.to_string()),
|
||||
));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file.flush().await.map_err(|source| SendHttpRequestError::WriteResponseBody {
|
||||
path: body_path.clone(),
|
||||
source,
|
||||
})?;
|
||||
drop(body_stream);
|
||||
|
||||
if let Some(task) = request_body_capture_task.take() {
|
||||
match task.await {
|
||||
Ok(Ok(total)) => {
|
||||
response.request_content_length = Some(usize_to_i32(total));
|
||||
}
|
||||
Ok(Err(err)) => request_body_capture_error = Some(err),
|
||||
Err(err) => request_body_capture_error = Some(err.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(err) = request_body_capture_error.take() {
|
||||
response.error = Some(append_error_message(
|
||||
response.error.take(),
|
||||
format!("Request succeeded but failed to store request body: {err}"),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(join_err) = event_handle.await {
|
||||
warn!("Failed to join response event task: {}", join_err);
|
||||
}
|
||||
|
||||
if let Some(err) = body_read_error {
|
||||
if persist_response {
|
||||
let _ = persist_response_error(
|
||||
params.query_manager,
|
||||
params.blob_manager,
|
||||
¶ms.update_source,
|
||||
&response,
|
||||
started_at,
|
||||
err.to_string(),
|
||||
request_started_url,
|
||||
);
|
||||
}
|
||||
persist_cookie_jar(params.query_manager, cookie_jar.as_mut(), cookie_store.as_ref())?;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
let compressed_length = http_response.content_length.unwrap_or(written_bytes as u64);
|
||||
let final_response = HttpResponse {
|
||||
body_path: Some(body_path.to_string_lossy().to_string()),
|
||||
content_length: Some(usize_to_i32(response_body.len())),
|
||||
content_length_compressed: Some(u64_to_i32(body_stats.size_compressed)),
|
||||
content_length: Some(usize_to_i32(written_bytes)),
|
||||
content_length_compressed: Some(u64_to_i32(compressed_length)),
|
||||
elapsed: duration_to_i32(started_at.elapsed()),
|
||||
elapsed_headers: headers_elapsed,
|
||||
elapsed_dns: dns_elapsed.load(Ordering::Relaxed),
|
||||
state: HttpResponseState::Closed,
|
||||
..response
|
||||
};
|
||||
@@ -625,14 +790,60 @@ pub async fn send_http_request<T: TemplateCallback>(
|
||||
response = final_response;
|
||||
}
|
||||
|
||||
if let Err(join_err) = event_handle.await {
|
||||
warn!("Failed to join response event task: {}", join_err);
|
||||
}
|
||||
persist_cookie_jar(params.query_manager, cookie_jar.as_mut(), cookie_store.as_ref())?;
|
||||
|
||||
Ok(SendHttpRequestResult { rendered_request, response, response_body })
|
||||
}
|
||||
|
||||
fn persist_request_body_bytes(
|
||||
blob_manager: &BlobManager,
|
||||
body_id: &str,
|
||||
bytes: &[u8],
|
||||
) -> std::result::Result<(), String> {
|
||||
if bytes.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let blob_ctx = blob_manager.connect();
|
||||
let mut offset = 0;
|
||||
let mut chunk_index: i32 = 0;
|
||||
while offset < bytes.len() {
|
||||
let end = std::cmp::min(offset + REQUEST_BODY_CHUNK_SIZE, bytes.len());
|
||||
let chunk = BodyChunk::new(body_id, chunk_index, bytes[offset..end].to_vec());
|
||||
blob_ctx.insert_chunk(&chunk).map_err(|e| e.to_string())?;
|
||||
chunk_index += 1;
|
||||
offset = end;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn persist_request_body_stream(
|
||||
blob_manager: BlobManager,
|
||||
body_id: String,
|
||||
mut rx: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
|
||||
) -> std::result::Result<usize, String> {
|
||||
let mut chunk_index: i32 = 0;
|
||||
let mut total_bytes = 0usize;
|
||||
while let Some(data) = rx.recv().await {
|
||||
total_bytes += data.len();
|
||||
if data.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let chunk = BodyChunk::new(&body_id, chunk_index, data);
|
||||
blob_manager.connect().insert_chunk(&chunk).map_err(|e| e.to_string())?;
|
||||
chunk_index += 1;
|
||||
}
|
||||
|
||||
Ok(total_bytes)
|
||||
}
|
||||
|
||||
fn append_error_message(existing_error: Option<String>, message: String) -> String {
|
||||
match existing_error {
|
||||
Some(existing) => format!("{existing}; {message}"),
|
||||
None => message,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_environment_chain(
|
||||
query_manager: &QueryManager,
|
||||
request: &HttpRequest,
|
||||
|
||||
Reference in New Issue
Block a user