Files
yaak-mountain-loop/src-tauri/yaak-http/src/transaction.rs
2025-12-28 08:07:42 -08:00

392 lines
15 KiB
Rust

use crate::error::Result;
use crate::sender::{HttpResponse, HttpResponseEvent, HttpSender, RedirectBehavior};
use crate::types::SendableHttpRequest;
use tokio::sync::mpsc;
use tokio::sync::watch::Receiver;
/// HTTP Transaction that manages the lifecycle of a request, including redirect handling
pub struct HttpTransaction<S: HttpSender> {
sender: S,
max_redirects: usize,
}
impl<S: HttpSender> HttpTransaction<S> {
/// Create a new transaction with default settings
pub fn new(sender: S) -> Self {
Self { sender, max_redirects: 10 }
}
/// Create a new transaction with custom max redirects
pub fn with_max_redirects(sender: S, max_redirects: usize) -> Self {
Self { sender, max_redirects }
}
/// Execute the request with cancellation support.
/// Returns an HttpResponse with unconsumed body - caller decides how to consume it.
/// Events are sent through the provided channel.
pub async fn execute_with_cancellation(
&self,
request: SendableHttpRequest,
mut cancelled_rx: Receiver<bool>,
event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
let mut redirect_count = 0;
let mut current_url = request.url;
let mut current_method = request.method;
let mut current_headers = request.headers;
let mut current_body = request.body;
// Helper to send events (ignores errors if receiver is dropped or channel is full)
let send_event = |event: HttpResponseEvent| {
let _ = event_tx.try_send(event);
};
loop {
// Check for cancellation before each request
if *cancelled_rx.borrow() {
return Err(crate::error::Error::RequestCanceledError);
}
// Build request for this iteration
let req = SendableHttpRequest {
url: current_url.clone(),
method: current_method.clone(),
headers: current_headers.clone(),
body: current_body,
options: request.options.clone(),
};
// Send the request
send_event(HttpResponseEvent::Setting(
"redirects".to_string(),
request.options.follow_redirects.to_string(),
));
// Execute with cancellation support
let response = tokio::select! {
result = self.sender.send(req, event_tx.clone()) => result?,
_ = cancelled_rx.changed() => {
return Err(crate::error::Error::RequestCanceledError);
}
};
if !Self::is_redirect(response.status) {
// Not a redirect - return the response for caller to consume body
return Ok(response);
}
if !request.options.follow_redirects {
// Redirects disabled - return the redirect response as-is
return Ok(response);
}
// Check if we've exceeded max redirects
if redirect_count >= self.max_redirects {
// Drain the response before returning error
let _ = response.drain().await;
return Err(crate::error::Error::RequestError(format!(
"Maximum redirect limit ({}) exceeded",
self.max_redirects
)));
}
// Extract Location header before draining (headers are available immediately)
// HTTP headers are case-insensitive, so we need to search for any casing
let location = response
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("location"))
.map(|(_, v)| v.clone())
.ok_or_else(|| {
crate::error::Error::RequestError(
"Redirect response missing Location header".to_string(),
)
})?;
// Also get status before draining
let status = response.status;
send_event(HttpResponseEvent::Info("Ignoring the response body".to_string()));
// Drain the redirect response body before following
response.drain().await?;
// Update the request URL
current_url = if location.starts_with("http://") || location.starts_with("https://") {
// Absolute URL
location
} else if location.starts_with('/') {
// Absolute path - need to extract base URL from current request
let base_url = Self::extract_base_url(&current_url)?;
format!("{}{}", base_url, location)
} else {
// Relative path - need to resolve relative to current path
let base_path = Self::extract_base_path(&current_url)?;
format!("{}/{}", base_path, location)
};
// Determine redirect behavior based on status code and method
let behavior = if status == 303 {
// 303 See Other always changes to GET
RedirectBehavior::DropBody
} else if (status == 301 || status == 302) && current_method == "POST" {
// For 301/302, change POST to GET (common browser behavior)
RedirectBehavior::DropBody
} else {
// For 307 and 308, the method and body are preserved
// Also for 301/302 with non-POST methods
RedirectBehavior::Preserve
};
send_event(HttpResponseEvent::Redirect {
url: current_url.clone(),
status,
behavior: behavior.clone(),
});
// Handle method changes for certain redirect codes
if matches!(behavior, RedirectBehavior::DropBody) {
if current_method != "GET" {
current_method = "GET".to_string();
}
// Remove content-related headers
current_headers.retain(|h| {
let name_lower = h.0.to_lowercase();
!name_lower.starts_with("content-") && name_lower != "transfer-encoding"
});
}
// Reset body for next iteration (since it was moved in the send call)
// For redirects that change method to GET or for all redirects since body was consumed
current_body = None;
redirect_count += 1;
}
}
/// Check if a status code indicates a redirect
fn is_redirect(status: u16) -> bool {
matches!(status, 301 | 302 | 303 | 307 | 308)
}
/// Extract the base URL (scheme + host) from a full URL
fn extract_base_url(url: &str) -> Result<String> {
// Find the position after "://"
let scheme_end = url.find("://").ok_or_else(|| {
crate::error::Error::RequestError(format!("Invalid URL format: {}", url))
})?;
// Find the first '/' after the scheme
let path_start = url[scheme_end + 3..].find('/');
if let Some(idx) = path_start {
Ok(url[..scheme_end + 3 + idx].to_string())
} else {
// No path, return entire URL
Ok(url.to_string())
}
}
/// Extract the base path (everything except the last segment) from a URL
fn extract_base_path(url: &str) -> Result<String> {
if let Some(last_slash) = url.rfind('/') {
// Don't include the trailing slash if it's part of the host
if url[..last_slash].ends_with("://") || url[..last_slash].ends_with(':') {
Ok(url.to_string())
} else {
Ok(url[..last_slash].to_string())
}
} else {
Ok(url.to_string())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::decompress::ContentEncoding;
use crate::sender::{HttpResponseEvent, HttpSender};
use async_trait::async_trait;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::sync::Mutex;
/// Mock sender for testing
struct MockSender {
responses: Arc<Mutex<Vec<MockResponse>>>,
}
struct MockResponse {
status: u16,
headers: HashMap<String, String>,
body: Vec<u8>,
}
impl MockSender {
fn new(responses: Vec<MockResponse>) -> Self {
Self { responses: Arc::new(Mutex::new(responses)) }
}
}
#[async_trait]
impl HttpSender for MockSender {
async fn send(
&self,
_request: SendableHttpRequest,
_event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
let mut responses = self.responses.lock().await;
if responses.is_empty() {
Err(crate::error::Error::RequestError("No more mock responses".to_string()))
} else {
let mock = responses.remove(0);
// Create a simple in-memory stream from the body
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
Box::pin(std::io::Cursor::new(mock.body));
Ok(HttpResponse::new(
mock.status,
None, // status_reason
mock.headers,
HashMap::new(),
None, // content_length
"https://example.com".to_string(), // url
None, // remote_addr
Some("HTTP/1.1".to_string()), // version
body_stream,
ContentEncoding::Identity,
))
}
}
}
#[tokio::test]
async fn test_transaction_no_redirect() {
let response = MockResponse { status: 200, headers: HashMap::new(), body: b"OK".to_vec() };
let sender = MockSender::new(vec![response]);
let transaction = HttpTransaction::new(sender);
let request = SendableHttpRequest {
url: "https://example.com".to_string(),
method: "GET".to_string(),
headers: vec![],
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await.unwrap();
assert_eq!(result.status, 200);
// Consume the body to verify it
let (body, _) = result.bytes().await.unwrap();
assert_eq!(body, b"OK");
}
#[tokio::test]
async fn test_transaction_single_redirect() {
let mut redirect_headers = HashMap::new();
redirect_headers.insert("Location".to_string(), "https://example.com/new".to_string());
let responses = vec![
MockResponse { status: 302, headers: redirect_headers, body: vec![] },
MockResponse { status: 200, headers: HashMap::new(), body: b"Final".to_vec() },
];
let sender = MockSender::new(responses);
let transaction = HttpTransaction::new(sender);
let request = SendableHttpRequest {
url: "https://example.com/old".to_string(),
method: "GET".to_string(),
options: crate::types::SendableHttpRequestOptions {
follow_redirects: true,
..Default::default()
},
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await.unwrap();
assert_eq!(result.status, 200);
let (body, _) = result.bytes().await.unwrap();
assert_eq!(body, b"Final");
}
#[tokio::test]
async fn test_transaction_max_redirects_exceeded() {
let mut redirect_headers = HashMap::new();
redirect_headers.insert("Location".to_string(), "https://example.com/loop".to_string());
// Create more redirects than allowed
let responses: Vec<MockResponse> = (0..12)
.map(|_| MockResponse { status: 302, headers: redirect_headers.clone(), body: vec![] })
.collect();
let sender = MockSender::new(responses);
let transaction = HttpTransaction::with_max_redirects(sender, 10);
let request = SendableHttpRequest {
url: "https://example.com/start".to_string(),
method: "GET".to_string(),
options: crate::types::SendableHttpRequestOptions {
follow_redirects: true,
..Default::default()
},
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await;
if let Err(crate::error::Error::RequestError(msg)) = result {
assert!(msg.contains("Maximum redirect limit"));
} else {
panic!("Expected RequestError with max redirect message. Got {result:?}");
}
}
#[test]
fn test_is_redirect() {
assert!(HttpTransaction::<MockSender>::is_redirect(301));
assert!(HttpTransaction::<MockSender>::is_redirect(302));
assert!(HttpTransaction::<MockSender>::is_redirect(303));
assert!(HttpTransaction::<MockSender>::is_redirect(307));
assert!(HttpTransaction::<MockSender>::is_redirect(308));
assert!(!HttpTransaction::<MockSender>::is_redirect(200));
assert!(!HttpTransaction::<MockSender>::is_redirect(404));
assert!(!HttpTransaction::<MockSender>::is_redirect(500));
}
#[test]
fn test_extract_base_url() {
let result =
HttpTransaction::<MockSender>::extract_base_url("https://example.com/path/to/resource");
assert_eq!(result.unwrap(), "https://example.com");
let result = HttpTransaction::<MockSender>::extract_base_url("http://localhost:8080/api");
assert_eq!(result.unwrap(), "http://localhost:8080");
let result = HttpTransaction::<MockSender>::extract_base_url("invalid-url");
assert!(result.is_err());
}
#[test]
fn test_extract_base_path() {
let result = HttpTransaction::<MockSender>::extract_base_path(
"https://example.com/path/to/resource",
);
assert_eq!(result.unwrap(), "https://example.com/path/to");
let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/single");
assert_eq!(result.unwrap(), "https://example.com");
let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/");
assert_eq!(result.unwrap(), "https://example.com");
}
}