mirror of
https://github.com/mountain-loop/yaak.git
synced 2026-03-26 11:21:16 +01:00
Split up HTTP sending logic (#320)
This commit is contained in:
78
src-tauri/yaak-http/src/chained_reader.rs
Normal file
78
src-tauri/yaak-http/src/chained_reader.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
|
||||
/// A stream that chains multiple AsyncRead sources together
|
||||
pub(crate) struct ChainedReader {
|
||||
readers: Vec<ReaderType>,
|
||||
current_index: usize,
|
||||
current_reader: Option<Box<dyn AsyncRead + Send + Unpin + 'static>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum ReaderType {
|
||||
Bytes(Vec<u8>),
|
||||
FilePath(String),
|
||||
}
|
||||
|
||||
impl ChainedReader {
|
||||
pub(crate) fn new(readers: Vec<ReaderType>) -> Self {
|
||||
Self { readers, current_index: 0, current_reader: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ChainedReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
loop {
|
||||
// Try to read from current reader if we have one
|
||||
if let Some(ref mut reader) = self.current_reader {
|
||||
let before_len = buf.filled().len();
|
||||
return match Pin::new(reader).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if buf.filled().len() == before_len && buf.remaining() > 0 {
|
||||
// Current reader is exhausted, move to next
|
||||
self.current_reader = None;
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
};
|
||||
}
|
||||
|
||||
// We need to get the next reader
|
||||
if self.current_index >= self.readers.len() {
|
||||
// No more readers
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Get the next reader
|
||||
let reader_type = self.readers[self.current_index].clone();
|
||||
self.current_index += 1;
|
||||
|
||||
match reader_type {
|
||||
ReaderType::Bytes(bytes) => {
|
||||
self.current_reader = Some(Box::new(io::Cursor::new(bytes)));
|
||||
}
|
||||
ReaderType::FilePath(path) => {
|
||||
// We need to handle file opening synchronously in poll_read
|
||||
// This is a limitation - we'll use blocking file open
|
||||
match std::fs::File::open(&path) {
|
||||
Ok(file) => {
|
||||
// Convert std File to tokio File
|
||||
let tokio_file = tokio::fs::File::from_std(file);
|
||||
self.current_reader = Some(Box::new(tokio_file));
|
||||
}
|
||||
Err(e) => return Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,9 @@
|
||||
use crate::dns::LocalhostResolver;
|
||||
use crate::error::Result;
|
||||
use log::{debug, info, warn};
|
||||
use reqwest::redirect::Policy;
|
||||
use reqwest::{Client, Proxy};
|
||||
use reqwest::{Client, Proxy, redirect};
|
||||
use reqwest_cookie_store::CookieStoreMutex;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use yaak_tls::{ClientCertificateConfig, get_tls_config};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -29,11 +27,9 @@ pub enum HttpConnectionProxySetting {
|
||||
#[derive(Clone)]
|
||||
pub struct HttpConnectionOptions {
|
||||
pub id: String,
|
||||
pub follow_redirects: bool,
|
||||
pub validate_certificates: bool,
|
||||
pub proxy: HttpConnectionProxySetting,
|
||||
pub cookie_provider: Option<Arc<CookieStoreMutex>>,
|
||||
pub timeout: Option<Duration>,
|
||||
pub client_certificate: Option<ClientCertificateConfig>,
|
||||
}
|
||||
|
||||
@@ -41,9 +37,11 @@ impl HttpConnectionOptions {
|
||||
pub(crate) fn build_client(&self) -> Result<Client> {
|
||||
let mut client = Client::builder()
|
||||
.connection_verbose(true)
|
||||
.gzip(true)
|
||||
.brotli(true)
|
||||
.deflate(true)
|
||||
.redirect(redirect::Policy::none())
|
||||
// Decompression is handled by HttpTransaction, not reqwest
|
||||
.no_gzip()
|
||||
.no_brotli()
|
||||
.no_deflate()
|
||||
.referer(false)
|
||||
.tls_info(true);
|
||||
|
||||
@@ -55,12 +53,6 @@ impl HttpConnectionOptions {
|
||||
// Configure DNS resolver
|
||||
client = client.dns_resolver(LocalhostResolver::new());
|
||||
|
||||
// Configure redirects
|
||||
client = client.redirect(match self.follow_redirects {
|
||||
true => Policy::limited(10), // TODO: Handle redirects natively
|
||||
false => Policy::none(),
|
||||
});
|
||||
|
||||
// Configure cookie provider
|
||||
if let Some(p) = &self.cookie_provider {
|
||||
client = client.cookie_provider(Arc::clone(&p));
|
||||
@@ -79,11 +71,6 @@ impl HttpConnectionOptions {
|
||||
}
|
||||
}
|
||||
|
||||
// Configure timeout
|
||||
if let Some(d) = self.timeout {
|
||||
client = client.timeout(d);
|
||||
}
|
||||
|
||||
info!(
|
||||
"Building new HTTP client validate_certificates={} client_cert={}",
|
||||
self.validate_certificates,
|
||||
|
||||
188
src-tauri/yaak-http/src/decompress.rs
Normal file
188
src-tauri/yaak-http/src/decompress.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
use crate::error::{Error, Result};
|
||||
use async_compression::tokio::bufread::{
|
||||
BrotliDecoder, DeflateDecoder as AsyncDeflateDecoder, GzipDecoder,
|
||||
ZstdDecoder as AsyncZstdDecoder,
|
||||
};
|
||||
use flate2::read::{DeflateDecoder, GzDecoder};
|
||||
use std::io::Read;
|
||||
use tokio::io::{AsyncBufRead, AsyncRead};
|
||||
|
||||
/// Supported compression encodings
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ContentEncoding {
|
||||
Gzip,
|
||||
Deflate,
|
||||
Brotli,
|
||||
Zstd,
|
||||
Identity,
|
||||
}
|
||||
|
||||
impl ContentEncoding {
|
||||
/// Parse a Content-Encoding header value into an encoding type.
|
||||
/// Returns Identity for unknown or missing encodings.
|
||||
pub fn from_header(value: Option<&str>) -> Self {
|
||||
match value.map(|s| s.trim().to_lowercase()).as_deref() {
|
||||
Some("gzip") | Some("x-gzip") => ContentEncoding::Gzip,
|
||||
Some("deflate") => ContentEncoding::Deflate,
|
||||
Some("br") => ContentEncoding::Brotli,
|
||||
Some("zstd") => ContentEncoding::Zstd,
|
||||
_ => ContentEncoding::Identity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of decompression, containing both the decompressed data and size info
|
||||
#[derive(Debug)]
|
||||
pub struct DecompressResult {
|
||||
pub data: Vec<u8>,
|
||||
pub compressed_size: u64,
|
||||
pub decompressed_size: u64,
|
||||
}
|
||||
|
||||
/// Decompress data based on the Content-Encoding.
|
||||
/// Returns the original data unchanged if encoding is Identity or unknown.
|
||||
pub fn decompress(data: Vec<u8>, encoding: ContentEncoding) -> Result<DecompressResult> {
|
||||
let compressed_size = data.len() as u64;
|
||||
|
||||
let decompressed = match encoding {
|
||||
ContentEncoding::Identity => data,
|
||||
ContentEncoding::Gzip => decompress_gzip(&data)?,
|
||||
ContentEncoding::Deflate => decompress_deflate(&data)?,
|
||||
ContentEncoding::Brotli => decompress_brotli(&data)?,
|
||||
ContentEncoding::Zstd => decompress_zstd(&data)?,
|
||||
};
|
||||
|
||||
let decompressed_size = decompressed.len() as u64;
|
||||
|
||||
Ok(DecompressResult { data: decompressed, compressed_size, decompressed_size })
|
||||
}
|
||||
|
||||
fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut decoder = GzDecoder::new(data);
|
||||
let mut decompressed = Vec::new();
|
||||
decoder
|
||||
.read_to_end(&mut decompressed)
|
||||
.map_err(|e| Error::DecompressionError(format!("gzip decompression failed: {}", e)))?;
|
||||
Ok(decompressed)
|
||||
}
|
||||
|
||||
fn decompress_deflate(data: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut decoder = DeflateDecoder::new(data);
|
||||
let mut decompressed = Vec::new();
|
||||
decoder
|
||||
.read_to_end(&mut decompressed)
|
||||
.map_err(|e| Error::DecompressionError(format!("deflate decompression failed: {}", e)))?;
|
||||
Ok(decompressed)
|
||||
}
|
||||
|
||||
fn decompress_brotli(data: &[u8]) -> Result<Vec<u8>> {
|
||||
let mut decompressed = Vec::new();
|
||||
brotli::BrotliDecompress(&mut std::io::Cursor::new(data), &mut decompressed)
|
||||
.map_err(|e| Error::DecompressionError(format!("brotli decompression failed: {}", e)))?;
|
||||
Ok(decompressed)
|
||||
}
|
||||
|
||||
fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>> {
|
||||
zstd::stream::decode_all(std::io::Cursor::new(data))
|
||||
.map_err(|e| Error::DecompressionError(format!("zstd decompression failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Create a streaming decompressor that wraps an async reader.
|
||||
/// Returns an AsyncRead that decompresses data on-the-fly.
|
||||
pub fn streaming_decoder<R: AsyncBufRead + Unpin + Send + 'static>(
|
||||
reader: R,
|
||||
encoding: ContentEncoding,
|
||||
) -> Box<dyn AsyncRead + Unpin + Send> {
|
||||
match encoding {
|
||||
ContentEncoding::Identity => Box::new(reader),
|
||||
ContentEncoding::Gzip => Box::new(GzipDecoder::new(reader)),
|
||||
ContentEncoding::Deflate => Box::new(AsyncDeflateDecoder::new(reader)),
|
||||
ContentEncoding::Brotli => Box::new(BrotliDecoder::new(reader)),
|
||||
ContentEncoding::Zstd => Box::new(AsyncZstdDecoder::new(reader)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use flate2::Compression;
|
||||
use flate2::write::GzEncoder;
|
||||
use std::io::Write;
|
||||
|
||||
#[test]
|
||||
fn test_content_encoding_from_header() {
|
||||
assert_eq!(ContentEncoding::from_header(Some("gzip")), ContentEncoding::Gzip);
|
||||
assert_eq!(ContentEncoding::from_header(Some("x-gzip")), ContentEncoding::Gzip);
|
||||
assert_eq!(ContentEncoding::from_header(Some("GZIP")), ContentEncoding::Gzip);
|
||||
assert_eq!(ContentEncoding::from_header(Some("deflate")), ContentEncoding::Deflate);
|
||||
assert_eq!(ContentEncoding::from_header(Some("br")), ContentEncoding::Brotli);
|
||||
assert_eq!(ContentEncoding::from_header(Some("zstd")), ContentEncoding::Zstd);
|
||||
assert_eq!(ContentEncoding::from_header(Some("identity")), ContentEncoding::Identity);
|
||||
assert_eq!(ContentEncoding::from_header(Some("unknown")), ContentEncoding::Identity);
|
||||
assert_eq!(ContentEncoding::from_header(None), ContentEncoding::Identity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompress_identity() {
|
||||
let data = b"hello world".to_vec();
|
||||
let result = decompress(data.clone(), ContentEncoding::Identity).unwrap();
|
||||
assert_eq!(result.data, data);
|
||||
assert_eq!(result.compressed_size, 11);
|
||||
assert_eq!(result.decompressed_size, 11);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompress_gzip() {
|
||||
// Compress some data with gzip
|
||||
let original = b"hello world, this is a test of gzip compression";
|
||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
|
||||
encoder.write_all(original).unwrap();
|
||||
let compressed = encoder.finish().unwrap();
|
||||
|
||||
let result = decompress(compressed.clone(), ContentEncoding::Gzip).unwrap();
|
||||
assert_eq!(result.data, original);
|
||||
assert_eq!(result.compressed_size, compressed.len() as u64);
|
||||
assert_eq!(result.decompressed_size, original.len() as u64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompress_deflate() {
|
||||
// Compress some data with deflate
|
||||
let original = b"hello world, this is a test of deflate compression";
|
||||
let mut encoder = flate2::write::DeflateEncoder::new(Vec::new(), Compression::default());
|
||||
encoder.write_all(original).unwrap();
|
||||
let compressed = encoder.finish().unwrap();
|
||||
|
||||
let result = decompress(compressed.clone(), ContentEncoding::Deflate).unwrap();
|
||||
assert_eq!(result.data, original);
|
||||
assert_eq!(result.compressed_size, compressed.len() as u64);
|
||||
assert_eq!(result.decompressed_size, original.len() as u64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompress_brotli() {
|
||||
// Compress some data with brotli
|
||||
let original = b"hello world, this is a test of brotli compression";
|
||||
let mut compressed = Vec::new();
|
||||
let mut writer = brotli::CompressorWriter::new(&mut compressed, 4096, 4, 22);
|
||||
writer.write_all(original).unwrap();
|
||||
drop(writer);
|
||||
|
||||
let result = decompress(compressed.clone(), ContentEncoding::Brotli).unwrap();
|
||||
assert_eq!(result.data, original);
|
||||
assert_eq!(result.compressed_size, compressed.len() as u64);
|
||||
assert_eq!(result.decompressed_size, original.len() as u64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompress_zstd() {
|
||||
// Compress some data with zstd
|
||||
let original = b"hello world, this is a test of zstd compression";
|
||||
let compressed = zstd::stream::encode_all(std::io::Cursor::new(original), 3).unwrap();
|
||||
|
||||
let result = decompress(compressed.clone(), ContentEncoding::Zstd).unwrap();
|
||||
assert_eq!(result.data, original);
|
||||
assert_eq!(result.compressed_size, compressed.len() as u64);
|
||||
assert_eq!(result.decompressed_size, original.len() as u64);
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,21 @@ pub enum Error {
|
||||
|
||||
#[error(transparent)]
|
||||
TlsError(#[from] yaak_tls::error::Error),
|
||||
|
||||
#[error("Request failed with {0:?}")]
|
||||
RequestError(String),
|
||||
|
||||
#[error("Request canceled")]
|
||||
RequestCanceledError,
|
||||
|
||||
#[error("Timeout of {0:?} reached")]
|
||||
RequestTimeout(std::time::Duration),
|
||||
|
||||
#[error("Decompression error: {0}")]
|
||||
DecompressionError(String),
|
||||
|
||||
#[error("Failed to read response body: {0}")]
|
||||
BodyReadError(String),
|
||||
}
|
||||
|
||||
impl Serialize for Error {
|
||||
|
||||
@@ -2,11 +2,17 @@ use crate::manager::HttpConnectionManager;
|
||||
use tauri::plugin::{Builder, TauriPlugin};
|
||||
use tauri::{Manager, Runtime};
|
||||
|
||||
mod chained_reader;
|
||||
pub mod client;
|
||||
pub mod decompress;
|
||||
pub mod dns;
|
||||
pub mod error;
|
||||
pub mod manager;
|
||||
pub mod path_placeholders;
|
||||
mod proto;
|
||||
pub mod sender;
|
||||
pub mod transaction;
|
||||
pub mod types;
|
||||
|
||||
pub fn init<R: Runtime>() -> TauriPlugin<R> {
|
||||
Builder::new("yaak-http")
|
||||
|
||||
@@ -2,7 +2,7 @@ use yaak_models::models::HttpUrlParameter;
|
||||
|
||||
pub fn apply_path_placeholders(
|
||||
url: &str,
|
||||
parameters: Vec<HttpUrlParameter>,
|
||||
parameters: &Vec<HttpUrlParameter>,
|
||||
) -> (String, Vec<HttpUrlParameter>) {
|
||||
let mut new_parameters = Vec::new();
|
||||
|
||||
@@ -18,7 +18,7 @@ pub fn apply_path_placeholders(
|
||||
|
||||
// Remove as param if it modified the URL
|
||||
if old_url_string == *url {
|
||||
new_parameters.push(p);
|
||||
new_parameters.push(p.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ mod placeholder_tests {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let (url, url_parameters) = apply_path_placeholders(&req.url, req.url_parameters);
|
||||
let (url, url_parameters) = apply_path_placeholders(&req.url, &req.url_parameters);
|
||||
|
||||
// Pattern match back to access it
|
||||
assert_eq!(url, "example.com/aaa/bar");
|
||||
|
||||
29
src-tauri/yaak-http/src/proto.rs
Normal file
29
src-tauri/yaak-http/src/proto.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use reqwest::Url;
|
||||
use std::str::FromStr;
|
||||
|
||||
pub(crate) fn ensure_proto(url_str: &str) -> String {
|
||||
if url_str.is_empty() {
|
||||
return "".to_string();
|
||||
}
|
||||
|
||||
if url_str.starts_with("http://") || url_str.starts_with("https://") {
|
||||
return url_str.to_string();
|
||||
}
|
||||
|
||||
// Url::from_str will fail without a proto, so add one
|
||||
let parseable_url = format!("http://{}", url_str);
|
||||
if let Ok(u) = Url::from_str(parseable_url.as_str()) {
|
||||
match u.host() {
|
||||
Some(host) => {
|
||||
let h = host.to_string();
|
||||
// These TLDs force HTTPS
|
||||
if h.ends_with(".app") || h.ends_with(".dev") || h.ends_with(".page") {
|
||||
return format!("https://{url_str}");
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
format!("http://{url_str}")
|
||||
}
|
||||
409
src-tauri/yaak-http/src/sender.rs
Normal file
409
src-tauri/yaak-http/src/sender.rs
Normal file
@@ -0,0 +1,409 @@
|
||||
use crate::decompress::{ContentEncoding, streaming_decoder};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::types::{SendableBody, SendableHttpRequest};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::{Client, Method, Version};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::pin::Pin;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct HttpResponseTiming {
|
||||
pub headers: Duration,
|
||||
pub body: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum HttpResponseEvent {
|
||||
Setting(String, String),
|
||||
Info(String),
|
||||
SendUrl { method: String, path: String },
|
||||
ReceiveUrl { version: Version, status: String },
|
||||
HeaderUp(String, String),
|
||||
HeaderDown(String, String),
|
||||
HeaderUpDone,
|
||||
HeaderDownDone,
|
||||
}
|
||||
|
||||
impl Display for HttpResponseEvent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
HttpResponseEvent::Setting(name, value) => write!(f, "* Setting {}={}", name, value),
|
||||
HttpResponseEvent::Info(s) => write!(f, "* {}", s),
|
||||
HttpResponseEvent::SendUrl { method, path } => write!(f, "> {} {}", method, path),
|
||||
HttpResponseEvent::ReceiveUrl { version, status } => {
|
||||
write!(f, "< {} {}", version_to_str(version), status)
|
||||
}
|
||||
HttpResponseEvent::HeaderUp(name, value) => write!(f, "> {}: {}", name, value),
|
||||
HttpResponseEvent::HeaderUpDone => write!(f, ">"),
|
||||
HttpResponseEvent::HeaderDown(name, value) => write!(f, "< {}: {}", name, value),
|
||||
HttpResponseEvent::HeaderDownDone => write!(f, "<"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics about the body after consumption
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct BodyStats {
|
||||
/// Size of the body as received over the wire (before decompression)
|
||||
pub size_compressed: u64,
|
||||
/// Size of the body after decompression
|
||||
pub size_decompressed: u64,
|
||||
}
|
||||
|
||||
/// Type alias for the body stream
|
||||
type BodyStream = Pin<Box<dyn AsyncRead + Send>>;
|
||||
|
||||
/// HTTP response with deferred body consumption.
|
||||
/// Headers are available immediately after send(), body can be consumed in different ways.
|
||||
/// Note: Debug is manually implemented since BodyStream doesn't implement Debug.
|
||||
pub struct HttpResponse {
|
||||
/// HTTP status code
|
||||
pub status: u16,
|
||||
/// HTTP status reason phrase (e.g., "OK", "Not Found")
|
||||
pub status_reason: Option<String>,
|
||||
/// Response headers
|
||||
pub headers: HashMap<String, String>,
|
||||
/// Content-Length from headers (may differ from actual body size)
|
||||
pub content_length: Option<u64>,
|
||||
/// Final URL (after redirects)
|
||||
pub url: String,
|
||||
/// Remote address of the server
|
||||
pub remote_addr: Option<String>,
|
||||
/// HTTP version (e.g., "HTTP/1.1", "HTTP/2")
|
||||
pub version: Option<String>,
|
||||
/// Timing information
|
||||
pub timing: HttpResponseTiming,
|
||||
|
||||
/// The body stream (consumed when calling bytes(), text(), write_to_file(), or drain())
|
||||
body_stream: Option<BodyStream>,
|
||||
/// Content-Encoding for decompression
|
||||
encoding: ContentEncoding,
|
||||
/// Start time for timing the body read
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for HttpResponse {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("HttpResponse")
|
||||
.field("status", &self.status)
|
||||
.field("status_reason", &self.status_reason)
|
||||
.field("headers", &self.headers)
|
||||
.field("content_length", &self.content_length)
|
||||
.field("url", &self.url)
|
||||
.field("remote_addr", &self.remote_addr)
|
||||
.field("version", &self.version)
|
||||
.field("timing", &self.timing)
|
||||
.field("body_stream", &"<stream>")
|
||||
.field("encoding", &self.encoding)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpResponse {
|
||||
/// Create a new HttpResponse with an unconsumed body stream
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
status: u16,
|
||||
status_reason: Option<String>,
|
||||
headers: HashMap<String, String>,
|
||||
content_length: Option<u64>,
|
||||
url: String,
|
||||
remote_addr: Option<String>,
|
||||
version: Option<String>,
|
||||
timing: HttpResponseTiming,
|
||||
body_stream: BodyStream,
|
||||
encoding: ContentEncoding,
|
||||
start_time: Instant,
|
||||
) -> Self {
|
||||
Self {
|
||||
status,
|
||||
status_reason,
|
||||
headers,
|
||||
content_length,
|
||||
url,
|
||||
remote_addr,
|
||||
version,
|
||||
timing,
|
||||
body_stream: Some(body_stream),
|
||||
encoding,
|
||||
start_time,
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume the body and return it as bytes (loads entire body into memory).
|
||||
/// Also decompresses the body if Content-Encoding is set.
|
||||
pub async fn bytes(mut self) -> Result<(Vec<u8>, BodyStats, HttpResponseTiming)> {
|
||||
let stream = self.body_stream.take().ok_or_else(|| {
|
||||
Error::RequestError("Response body has already been consumed".to_string())
|
||||
})?;
|
||||
|
||||
let buf_reader = BufReader::new(stream);
|
||||
let mut decoder = streaming_decoder(buf_reader, self.encoding);
|
||||
|
||||
let mut decompressed = Vec::new();
|
||||
let mut bytes_read = 0u64;
|
||||
|
||||
// Read through the decoder in chunks to track compressed size
|
||||
let mut buf = [0u8; 8192];
|
||||
loop {
|
||||
match decoder.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
decompressed.extend_from_slice(&buf[..n]);
|
||||
bytes_read += n as u64;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(Error::BodyReadError(e.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut timing = self.timing.clone();
|
||||
timing.body = self.start_time.elapsed();
|
||||
|
||||
let stats = BodyStats {
|
||||
// For now, we can't easily track compressed size when streaming through decoder
|
||||
// Use content_length as an approximation, or decompressed size if identity encoding
|
||||
size_compressed: self.content_length.unwrap_or(bytes_read),
|
||||
size_decompressed: decompressed.len() as u64,
|
||||
};
|
||||
|
||||
Ok((decompressed, stats, timing))
|
||||
}
|
||||
|
||||
/// Consume the body and return it as a UTF-8 string.
|
||||
pub async fn text(self) -> Result<(String, BodyStats, HttpResponseTiming)> {
|
||||
let (bytes, stats, timing) = self.bytes().await?;
|
||||
let text = String::from_utf8(bytes)
|
||||
.map_err(|e| Error::RequestError(format!("Response is not valid UTF-8: {}", e)))?;
|
||||
Ok((text, stats, timing))
|
||||
}
|
||||
|
||||
/// Take the body stream for manual consumption.
|
||||
/// Returns an AsyncRead that decompresses on-the-fly if Content-Encoding is set.
|
||||
/// The caller is responsible for reading and processing the stream.
|
||||
pub fn into_body_stream(mut self) -> Result<Box<dyn AsyncRead + Unpin + Send>> {
|
||||
let stream = self.body_stream.take().ok_or_else(|| {
|
||||
Error::RequestError("Response body has already been consumed".to_string())
|
||||
})?;
|
||||
|
||||
let buf_reader = BufReader::new(stream);
|
||||
let decoder = streaming_decoder(buf_reader, self.encoding);
|
||||
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
/// Discard the body without reading it (useful for redirects).
|
||||
pub async fn drain(mut self) -> Result<HttpResponseTiming> {
|
||||
let stream = self.body_stream.take().ok_or_else(|| {
|
||||
Error::RequestError("Response body has already been consumed".to_string())
|
||||
})?;
|
||||
|
||||
// Just read and discard all bytes
|
||||
let mut reader = stream;
|
||||
let mut buf = [0u8; 8192];
|
||||
loop {
|
||||
match reader.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
return Err(Error::RequestError(format!(
|
||||
"Failed to drain response body: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut timing = self.timing.clone();
|
||||
timing.body = self.start_time.elapsed();
|
||||
|
||||
Ok(timing)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for sending HTTP requests
|
||||
#[async_trait]
|
||||
pub trait HttpSender: Send + Sync {
|
||||
/// Send an HTTP request and return the response with headers.
|
||||
/// The body is not consumed until you call bytes(), text(), write_to_file(), or drain().
|
||||
async fn send(
|
||||
&self,
|
||||
request: SendableHttpRequest,
|
||||
events: &mut Vec<HttpResponseEvent>,
|
||||
) -> Result<HttpResponse>;
|
||||
}
|
||||
|
||||
/// Reqwest-based implementation of HttpSender
|
||||
pub struct ReqwestSender {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl ReqwestSender {
|
||||
/// Create a new ReqwestSender with a default client
|
||||
pub fn new() -> Result<Self> {
|
||||
let client = Client::builder().build().map_err(Error::Client)?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
/// Create a new ReqwestSender with a custom client
|
||||
pub fn with_client(client: Client) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpSender for ReqwestSender {
|
||||
async fn send(
|
||||
&self,
|
||||
request: SendableHttpRequest,
|
||||
events: &mut Vec<HttpResponseEvent>,
|
||||
) -> Result<HttpResponse> {
|
||||
// Parse the HTTP method
|
||||
let method = Method::from_bytes(request.method.as_bytes())
|
||||
.map_err(|e| Error::RequestError(format!("Invalid HTTP method: {}", e)))?;
|
||||
|
||||
// Build the request
|
||||
let mut req_builder = self.client.request(method, &request.url);
|
||||
|
||||
// Add headers
|
||||
for header in request.headers {
|
||||
req_builder = req_builder.header(&header.0, &header.1);
|
||||
}
|
||||
|
||||
// Configure timeout
|
||||
if let Some(d) = request.options.timeout
|
||||
&& !d.is_zero()
|
||||
{
|
||||
req_builder = req_builder.timeout(d);
|
||||
}
|
||||
|
||||
// Add body
|
||||
match request.body {
|
||||
None => {}
|
||||
Some(SendableBody::Bytes(bytes)) => {
|
||||
req_builder = req_builder.body(bytes);
|
||||
}
|
||||
Some(SendableBody::Stream(stream)) => {
|
||||
// Convert AsyncRead stream to reqwest Body
|
||||
let stream = tokio_util::io::ReaderStream::new(stream);
|
||||
let body = reqwest::Body::wrap_stream(stream);
|
||||
req_builder = req_builder.body(body);
|
||||
}
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
let mut timing = HttpResponseTiming::default();
|
||||
|
||||
// Send the request
|
||||
let sendable_req = req_builder.build()?;
|
||||
events.push(HttpResponseEvent::Setting(
|
||||
"timeout".to_string(),
|
||||
if request.options.timeout.unwrap_or_default().is_zero() {
|
||||
"Infinity".to_string()
|
||||
} else {
|
||||
format!("{:?}", request.options.timeout)
|
||||
},
|
||||
));
|
||||
|
||||
events.push(HttpResponseEvent::SendUrl {
|
||||
path: sendable_req.url().path().to_string(),
|
||||
method: sendable_req.method().to_string(),
|
||||
});
|
||||
|
||||
for (name, value) in sendable_req.headers() {
|
||||
events.push(HttpResponseEvent::HeaderUp(
|
||||
name.to_string(),
|
||||
value.to_str().unwrap_or_default().to_string(),
|
||||
));
|
||||
}
|
||||
events.push(HttpResponseEvent::HeaderUpDone);
|
||||
events.push(HttpResponseEvent::Info("Sending request to server".to_string()));
|
||||
|
||||
// Map some errors to our own, so they look nicer
|
||||
let response = self.client.execute(sendable_req).await.map_err(|e| {
|
||||
if reqwest::Error::is_timeout(&e) {
|
||||
Error::RequestTimeout(
|
||||
request.options.timeout.unwrap_or(Duration::from_secs(0)).clone(),
|
||||
)
|
||||
} else {
|
||||
Error::Client(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let status_reason = response.status().canonical_reason().map(|s| s.to_string());
|
||||
let url = response.url().to_string();
|
||||
let remote_addr = response.remote_addr().map(|a| a.to_string());
|
||||
let version = Some(version_to_str(&response.version()));
|
||||
|
||||
events.push(HttpResponseEvent::ReceiveUrl {
|
||||
version: response.version(),
|
||||
status: response.status().to_string(),
|
||||
});
|
||||
|
||||
timing.headers = start.elapsed();
|
||||
|
||||
// Extract content length
|
||||
let content_length = response.content_length();
|
||||
|
||||
// Extract headers
|
||||
let mut headers = HashMap::new();
|
||||
for (key, value) in response.headers() {
|
||||
if let Ok(v) = value.to_str() {
|
||||
events.push(HttpResponseEvent::HeaderDown(key.to_string(), v.to_string()));
|
||||
headers.insert(key.to_string(), v.to_string());
|
||||
}
|
||||
}
|
||||
events.push(HttpResponseEvent::HeaderDownDone);
|
||||
|
||||
// Determine content encoding for decompression
|
||||
// HTTP headers are case-insensitive, so we need to search for any casing
|
||||
let encoding = ContentEncoding::from_header(
|
||||
headers
|
||||
.iter()
|
||||
.find(|(k, _)| k.eq_ignore_ascii_case("content-encoding"))
|
||||
.map(|(_, v)| v.as_str()),
|
||||
);
|
||||
|
||||
// Get the byte stream instead of loading into memory
|
||||
let byte_stream = response.bytes_stream();
|
||||
|
||||
// Convert the stream to an AsyncRead
|
||||
let stream_reader = StreamReader::new(
|
||||
byte_stream.map(|result| result.map_err(|e| std::io::Error::other(e))),
|
||||
);
|
||||
|
||||
let body_stream: BodyStream = Box::pin(stream_reader);
|
||||
|
||||
Ok(HttpResponse::new(
|
||||
status,
|
||||
status_reason,
|
||||
headers,
|
||||
content_length,
|
||||
url,
|
||||
remote_addr,
|
||||
version,
|
||||
timing,
|
||||
body_stream,
|
||||
encoding,
|
||||
start,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn version_to_str(version: &Version) -> String {
|
||||
match *version {
|
||||
Version::HTTP_09 => "HTTP/0.9".to_string(),
|
||||
Version::HTTP_10 => "HTTP/1.0".to_string(),
|
||||
Version::HTTP_11 => "HTTP/1.1".to_string(),
|
||||
Version::HTTP_2 => "HTTP/2".to_string(),
|
||||
Version::HTTP_3 => "HTTP/3".to_string(),
|
||||
_ => "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
385
src-tauri/yaak-http/src/transaction.rs
Normal file
385
src-tauri/yaak-http/src/transaction.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
use crate::error::Result;
|
||||
use crate::sender::{HttpResponse, HttpResponseEvent, HttpSender};
|
||||
use crate::types::SendableHttpRequest;
|
||||
use tokio::sync::watch::Receiver;
|
||||
|
||||
/// HTTP Transaction that manages the lifecycle of a request, including redirect handling
|
||||
pub struct HttpTransaction<S: HttpSender> {
|
||||
sender: S,
|
||||
max_redirects: usize,
|
||||
}
|
||||
|
||||
impl<S: HttpSender> HttpTransaction<S> {
|
||||
/// Create a new transaction with default settings
|
||||
pub fn new(sender: S) -> Self {
|
||||
Self { sender, max_redirects: 10 }
|
||||
}
|
||||
|
||||
/// Create a new transaction with custom max redirects
|
||||
pub fn with_max_redirects(sender: S, max_redirects: usize) -> Self {
|
||||
Self { sender, max_redirects }
|
||||
}
|
||||
|
||||
/// Execute the request with cancellation support.
|
||||
/// Returns an HttpResponse with unconsumed body - caller decides how to consume it.
|
||||
pub async fn execute_with_cancellation(
|
||||
&self,
|
||||
request: SendableHttpRequest,
|
||||
mut cancelled_rx: Receiver<bool>,
|
||||
) -> Result<(HttpResponse, Vec<HttpResponseEvent>)> {
|
||||
let mut redirect_count = 0;
|
||||
let mut current_url = request.url;
|
||||
let mut current_method = request.method;
|
||||
let mut current_headers = request.headers;
|
||||
let mut current_body = request.body;
|
||||
let mut events = Vec::new();
|
||||
|
||||
loop {
|
||||
// Check for cancellation before each request
|
||||
if *cancelled_rx.borrow() {
|
||||
return Err(crate::error::Error::RequestCanceledError);
|
||||
}
|
||||
|
||||
// Build request for this iteration
|
||||
let req = SendableHttpRequest {
|
||||
url: current_url.clone(),
|
||||
method: current_method.clone(),
|
||||
headers: current_headers.clone(),
|
||||
body: current_body,
|
||||
options: request.options.clone(),
|
||||
};
|
||||
|
||||
// Send the request
|
||||
events.push(HttpResponseEvent::Setting(
|
||||
"redirects".to_string(),
|
||||
request.options.follow_redirects.to_string(),
|
||||
));
|
||||
|
||||
// Execute with cancellation support
|
||||
let response = tokio::select! {
|
||||
result = self.sender.send(req, &mut events) => result?,
|
||||
_ = cancelled_rx.changed() => {
|
||||
return Err(crate::error::Error::RequestCanceledError);
|
||||
}
|
||||
};
|
||||
|
||||
if !Self::is_redirect(response.status) {
|
||||
// Not a redirect - return the response for caller to consume body
|
||||
return Ok((response, events));
|
||||
}
|
||||
|
||||
if !request.options.follow_redirects {
|
||||
// Redirects disabled - return the redirect response as-is
|
||||
return Ok((response, events));
|
||||
}
|
||||
|
||||
// Check if we've exceeded max redirects
|
||||
if redirect_count >= self.max_redirects {
|
||||
// Drain the response before returning error
|
||||
let _ = response.drain().await;
|
||||
return Err(crate::error::Error::RequestError(format!(
|
||||
"Maximum redirect limit ({}) exceeded",
|
||||
self.max_redirects
|
||||
)));
|
||||
}
|
||||
|
||||
// Extract Location header before draining (headers are available immediately)
|
||||
// HTTP headers are case-insensitive, so we need to search for any casing
|
||||
let location = response
|
||||
.headers
|
||||
.iter()
|
||||
.find(|(k, _)| k.eq_ignore_ascii_case("location"))
|
||||
.map(|(_, v)| v.clone())
|
||||
.ok_or_else(|| {
|
||||
crate::error::Error::RequestError(
|
||||
"Redirect response missing Location header".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Also get status before draining
|
||||
let status = response.status;
|
||||
|
||||
events.push(HttpResponseEvent::Info("Ignoring the response body".to_string()));
|
||||
|
||||
// Drain the redirect response body before following
|
||||
response.drain().await?;
|
||||
|
||||
// Update the request URL
|
||||
current_url = if location.starts_with("http://") || location.starts_with("https://") {
|
||||
// Absolute URL
|
||||
location
|
||||
} else if location.starts_with('/') {
|
||||
// Absolute path - need to extract base URL from current request
|
||||
let base_url = Self::extract_base_url(¤t_url)?;
|
||||
format!("{}{}", base_url, location)
|
||||
} else {
|
||||
// Relative path - need to resolve relative to current path
|
||||
let base_path = Self::extract_base_path(¤t_url)?;
|
||||
format!("{}/{}", base_path, location)
|
||||
};
|
||||
|
||||
events.push(HttpResponseEvent::Info(format!(
|
||||
"Issuing redirect {} to: {}",
|
||||
redirect_count + 1,
|
||||
current_url
|
||||
)));
|
||||
|
||||
// Handle method changes for certain redirect codes
|
||||
if status == 303 {
|
||||
// 303 See Other always changes to GET
|
||||
if current_method != "GET" {
|
||||
current_method = "GET".to_string();
|
||||
events.push(HttpResponseEvent::Info("Changing method to GET".to_string()));
|
||||
}
|
||||
// Remove content-related headers
|
||||
current_headers.retain(|h| {
|
||||
let name_lower = h.0.to_lowercase();
|
||||
!name_lower.starts_with("content-") && name_lower != "transfer-encoding"
|
||||
});
|
||||
} else if status == 301 || status == 302 {
|
||||
// For 301/302, change POST to GET (common browser behavior)
|
||||
// but keep other methods as-is
|
||||
if current_method == "POST" {
|
||||
events.push(HttpResponseEvent::Info("Changing method to GET".to_string()));
|
||||
current_method = "GET".to_string();
|
||||
// Remove content-related headers
|
||||
current_headers.retain(|h| {
|
||||
let name_lower = h.0.to_lowercase();
|
||||
!name_lower.starts_with("content-") && name_lower != "transfer-encoding"
|
||||
});
|
||||
}
|
||||
}
|
||||
// For 307 and 308, the method and body are preserved
|
||||
|
||||
// Reset body for next iteration (since it was moved in the send call)
|
||||
// For redirects that change method to GET or for all redirects since body was consumed
|
||||
current_body = None;
|
||||
|
||||
redirect_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a status code indicates a redirect
|
||||
fn is_redirect(status: u16) -> bool {
|
||||
matches!(status, 301 | 302 | 303 | 307 | 308)
|
||||
}
|
||||
|
||||
/// Extract the base URL (scheme + host) from a full URL
|
||||
fn extract_base_url(url: &str) -> Result<String> {
|
||||
// Find the position after "://"
|
||||
let scheme_end = url.find("://").ok_or_else(|| {
|
||||
crate::error::Error::RequestError(format!("Invalid URL format: {}", url))
|
||||
})?;
|
||||
|
||||
// Find the first '/' after the scheme
|
||||
let path_start = url[scheme_end + 3..].find('/');
|
||||
|
||||
if let Some(idx) = path_start {
|
||||
Ok(url[..scheme_end + 3 + idx].to_string())
|
||||
} else {
|
||||
// No path, return entire URL
|
||||
Ok(url.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the base path (everything except the last segment) from a URL
|
||||
fn extract_base_path(url: &str) -> Result<String> {
|
||||
if let Some(last_slash) = url.rfind('/') {
|
||||
// Don't include the trailing slash if it's part of the host
|
||||
if url[..last_slash].ends_with("://") || url[..last_slash].ends_with(':') {
|
||||
Ok(url.to_string())
|
||||
} else {
|
||||
Ok(url[..last_slash].to_string())
|
||||
}
|
||||
} else {
|
||||
Ok(url.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::decompress::ContentEncoding;
|
||||
use crate::sender::{HttpResponseEvent, HttpResponseTiming, HttpSender};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Mock sender for testing
|
||||
struct MockSender {
|
||||
responses: Arc<Mutex<Vec<MockResponse>>>,
|
||||
}
|
||||
|
||||
struct MockResponse {
|
||||
status: u16,
|
||||
headers: HashMap<String, String>,
|
||||
body: Vec<u8>,
|
||||
}
|
||||
|
||||
impl MockSender {
|
||||
fn new(responses: Vec<MockResponse>) -> Self {
|
||||
Self { responses: Arc::new(Mutex::new(responses)) }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpSender for MockSender {
|
||||
async fn send(
|
||||
&self,
|
||||
_request: SendableHttpRequest,
|
||||
_events: &mut Vec<HttpResponseEvent>,
|
||||
) -> Result<HttpResponse> {
|
||||
let mut responses = self.responses.lock().await;
|
||||
if responses.is_empty() {
|
||||
Err(crate::error::Error::RequestError("No more mock responses".to_string()))
|
||||
} else {
|
||||
let mock = responses.remove(0);
|
||||
// Create a simple in-memory stream from the body
|
||||
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
|
||||
Box::pin(std::io::Cursor::new(mock.body));
|
||||
Ok(HttpResponse::new(
|
||||
mock.status,
|
||||
None, // status_reason
|
||||
mock.headers,
|
||||
None, // content_length
|
||||
"https://example.com".to_string(), // url
|
||||
None, // remote_addr
|
||||
Some("HTTP/1.1".to_string()), // version
|
||||
HttpResponseTiming::default(),
|
||||
body_stream,
|
||||
ContentEncoding::Identity,
|
||||
Instant::now(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transaction_no_redirect() {
|
||||
let response = MockResponse { status: 200, headers: HashMap::new(), body: b"OK".to_vec() };
|
||||
let sender = MockSender::new(vec![response]);
|
||||
let transaction = HttpTransaction::new(sender);
|
||||
|
||||
let request = SendableHttpRequest {
|
||||
url: "https://example.com".to_string(),
|
||||
method: "GET".to_string(),
|
||||
headers: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let (_tx, rx) = tokio::sync::watch::channel(false);
|
||||
let (result, _) = transaction.execute_with_cancellation(request, rx).await.unwrap();
|
||||
assert_eq!(result.status, 200);
|
||||
|
||||
// Consume the body to verify it
|
||||
let (body, _, _) = result.bytes().await.unwrap();
|
||||
assert_eq!(body, b"OK");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transaction_single_redirect() {
|
||||
let mut redirect_headers = HashMap::new();
|
||||
redirect_headers.insert("Location".to_string(), "https://example.com/new".to_string());
|
||||
|
||||
let responses = vec![
|
||||
MockResponse { status: 302, headers: redirect_headers, body: vec![] },
|
||||
MockResponse { status: 200, headers: HashMap::new(), body: b"Final".to_vec() },
|
||||
];
|
||||
|
||||
let sender = MockSender::new(responses);
|
||||
let transaction = HttpTransaction::new(sender);
|
||||
|
||||
let request = SendableHttpRequest {
|
||||
url: "https://example.com/old".to_string(),
|
||||
method: "GET".to_string(),
|
||||
options: crate::types::SendableHttpRequestOptions {
|
||||
follow_redirects: true,
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let (_tx, rx) = tokio::sync::watch::channel(false);
|
||||
let (result, _) = transaction.execute_with_cancellation(request, rx).await.unwrap();
|
||||
assert_eq!(result.status, 200);
|
||||
|
||||
let (body, _, _) = result.bytes().await.unwrap();
|
||||
assert_eq!(body, b"Final");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transaction_max_redirects_exceeded() {
|
||||
let mut redirect_headers = HashMap::new();
|
||||
redirect_headers.insert("Location".to_string(), "https://example.com/loop".to_string());
|
||||
|
||||
// Create more redirects than allowed
|
||||
let responses: Vec<MockResponse> = (0..12)
|
||||
.map(|_| MockResponse { status: 302, headers: redirect_headers.clone(), body: vec![] })
|
||||
.collect();
|
||||
|
||||
let sender = MockSender::new(responses);
|
||||
let transaction = HttpTransaction::with_max_redirects(sender, 10);
|
||||
|
||||
let request = SendableHttpRequest {
|
||||
url: "https://example.com/start".to_string(),
|
||||
method: "GET".to_string(),
|
||||
options: crate::types::SendableHttpRequestOptions {
|
||||
follow_redirects: true,
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let (_tx, rx) = tokio::sync::watch::channel(false);
|
||||
let result = transaction.execute_with_cancellation(request, rx).await;
|
||||
if let Err(crate::error::Error::RequestError(msg)) = result {
|
||||
assert!(msg.contains("Maximum redirect limit"));
|
||||
} else {
|
||||
panic!("Expected RequestError with max redirect message. Got {result:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_redirect() {
|
||||
assert!(HttpTransaction::<MockSender>::is_redirect(301));
|
||||
assert!(HttpTransaction::<MockSender>::is_redirect(302));
|
||||
assert!(HttpTransaction::<MockSender>::is_redirect(303));
|
||||
assert!(HttpTransaction::<MockSender>::is_redirect(307));
|
||||
assert!(HttpTransaction::<MockSender>::is_redirect(308));
|
||||
assert!(!HttpTransaction::<MockSender>::is_redirect(200));
|
||||
assert!(!HttpTransaction::<MockSender>::is_redirect(404));
|
||||
assert!(!HttpTransaction::<MockSender>::is_redirect(500));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_base_url() {
|
||||
let result =
|
||||
HttpTransaction::<MockSender>::extract_base_url("https://example.com/path/to/resource");
|
||||
assert_eq!(result.unwrap(), "https://example.com");
|
||||
|
||||
let result = HttpTransaction::<MockSender>::extract_base_url("http://localhost:8080/api");
|
||||
assert_eq!(result.unwrap(), "http://localhost:8080");
|
||||
|
||||
let result = HttpTransaction::<MockSender>::extract_base_url("invalid-url");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_base_path() {
|
||||
let result = HttpTransaction::<MockSender>::extract_base_path(
|
||||
"https://example.com/path/to/resource",
|
||||
);
|
||||
assert_eq!(result.unwrap(), "https://example.com/path/to");
|
||||
|
||||
let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/single");
|
||||
assert_eq!(result.unwrap(), "https://example.com");
|
||||
|
||||
let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/");
|
||||
assert_eq!(result.unwrap(), "https://example.com");
|
||||
}
|
||||
}
|
||||
975
src-tauri/yaak-http/src/types.rs
Normal file
975
src-tauri/yaak-http/src/types.rs
Normal file
@@ -0,0 +1,975 @@
|
||||
use crate::chained_reader::{ChainedReader, ReaderType};
|
||||
use crate::error::Error::RequestError;
|
||||
use crate::error::Result;
|
||||
use crate::path_placeholders::apply_path_placeholders;
|
||||
use crate::proto::ensure_proto;
|
||||
use bytes::Bytes;
|
||||
use log::warn;
|
||||
use std::collections::BTreeMap;
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncRead;
|
||||
use yaak_common::serde::{get_bool, get_str, get_str_map};
|
||||
use yaak_models::models::HttpRequest;
|
||||
|
||||
pub(crate) const MULTIPART_BOUNDARY: &str = "------YaakFormBoundary";
|
||||
|
||||
pub enum SendableBody {
|
||||
Bytes(Bytes),
|
||||
Stream(Pin<Box<dyn AsyncRead + Send + 'static>>),
|
||||
}
|
||||
|
||||
enum SendableBodyWithMeta {
|
||||
Bytes(Bytes),
|
||||
Stream {
|
||||
data: Pin<Box<dyn AsyncRead + Send + 'static>>,
|
||||
content_length: Option<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<SendableBodyWithMeta> for SendableBody {
|
||||
fn from(value: SendableBodyWithMeta) -> Self {
|
||||
match value {
|
||||
SendableBodyWithMeta::Bytes(b) => SendableBody::Bytes(b),
|
||||
SendableBodyWithMeta::Stream { data, .. } => SendableBody::Stream(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SendableHttpRequest {
|
||||
pub url: String,
|
||||
pub method: String,
|
||||
pub headers: Vec<(String, String)>,
|
||||
pub body: Option<SendableBody>,
|
||||
pub options: SendableHttpRequestOptions,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct SendableHttpRequestOptions {
|
||||
pub timeout: Option<Duration>,
|
||||
pub follow_redirects: bool,
|
||||
}
|
||||
|
||||
impl SendableHttpRequest {
|
||||
pub async fn from_http_request(
|
||||
r: &HttpRequest,
|
||||
options: SendableHttpRequestOptions,
|
||||
) -> Result<Self> {
|
||||
let initial_headers = build_headers(r);
|
||||
let (body, headers) = build_body(&r.method, &r.body_type, &r.body, initial_headers).await?;
|
||||
|
||||
Ok(Self {
|
||||
url: build_url(r),
|
||||
method: r.method.to_uppercase(),
|
||||
headers,
|
||||
body: body.into(),
|
||||
options,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert_header(&mut self, header: (String, String)) {
|
||||
if let Some(existing) =
|
||||
self.headers.iter_mut().find(|h| h.0.to_lowercase() == header.0.to_lowercase())
|
||||
{
|
||||
existing.1 = header.1;
|
||||
} else {
|
||||
self.headers.push(header);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn append_query_params(url: &str, params: Vec<(String, String)>) -> String {
|
||||
let url_string = url.to_string();
|
||||
if params.is_empty() {
|
||||
return url.to_string();
|
||||
}
|
||||
|
||||
// Build query string
|
||||
let query_string = params
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
format!("{}={}", urlencoding::encode(name), urlencoding::encode(value))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
// Split URL into parts: base URL, query, and fragment
|
||||
let (base_and_query, fragment) = if let Some(hash_pos) = url_string.find('#') {
|
||||
let (before_hash, after_hash) = url_string.split_at(hash_pos);
|
||||
(before_hash.to_string(), Some(after_hash.to_string()))
|
||||
} else {
|
||||
(url_string, None)
|
||||
};
|
||||
|
||||
// Now handle query parameters on the base URL (without fragment)
|
||||
let mut result = if base_and_query.contains('?') {
|
||||
// Check if there's already a query string after the '?'
|
||||
let parts: Vec<&str> = base_and_query.splitn(2, '?').collect();
|
||||
if parts.len() == 2 && !parts[1].trim().is_empty() {
|
||||
// Append with & if there are existing parameters
|
||||
format!("{}&{}", base_and_query, query_string)
|
||||
} else {
|
||||
// Just append the new parameters directly (URL ends with '?')
|
||||
format!("{}{}", base_and_query, query_string)
|
||||
}
|
||||
} else {
|
||||
// No existing query parameters, add with '?'
|
||||
format!("{}?{}", base_and_query, query_string)
|
||||
};
|
||||
|
||||
// Re-append the fragment if it exists
|
||||
if let Some(fragment) = fragment {
|
||||
result.push_str(&fragment);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn build_url(r: &HttpRequest) -> String {
|
||||
let (url_string, params) = apply_path_placeholders(&ensure_proto(&r.url), &r.url_parameters);
|
||||
append_query_params(
|
||||
&url_string,
|
||||
params
|
||||
.iter()
|
||||
.filter(|p| p.enabled && !p.name.is_empty())
|
||||
.map(|p| (p.name.clone(), p.value.clone()))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn build_headers(r: &HttpRequest) -> Vec<(String, String)> {
|
||||
r.headers
|
||||
.iter()
|
||||
.filter_map(|h| {
|
||||
if h.enabled && !h.name.is_empty() {
|
||||
Some((h.name.clone(), h.value.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn build_body(
|
||||
method: &str,
|
||||
body_type: &Option<String>,
|
||||
body: &BTreeMap<String, serde_json::Value>,
|
||||
headers: Vec<(String, String)>,
|
||||
) -> Result<(Option<SendableBody>, Vec<(String, String)>)> {
|
||||
let body_type = match &body_type {
|
||||
None => return Ok((None, headers)),
|
||||
Some(t) => t,
|
||||
};
|
||||
|
||||
let (body, content_type) = match body_type.as_str() {
|
||||
"binary" => (build_binary_body(&body).await?, None),
|
||||
"graphql" => (build_graphql_body(&method, &body), Some("application/json".to_string())),
|
||||
"application/x-www-form-urlencoded" => {
|
||||
(build_form_body(&body), Some("application/x-www-form-urlencoded".to_string()))
|
||||
}
|
||||
"multipart/form-data" => build_multipart_body(&body, &headers).await?,
|
||||
_ if body.contains_key("text") => (build_text_body(&body), None),
|
||||
t => {
|
||||
warn!("Unsupported body type: {}", t);
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
|
||||
// Add or update the Content-Type header
|
||||
let mut headers = headers;
|
||||
if let Some(ct) = content_type {
|
||||
if let Some(existing) = headers.iter_mut().find(|h| h.0.to_lowercase() == "content-type") {
|
||||
existing.1 = ct;
|
||||
} else {
|
||||
headers.push(("Content-Type".to_string(), ct));
|
||||
}
|
||||
}
|
||||
|
||||
// Check if Transfer-Encoding: chunked is already set
|
||||
let has_chunked_encoding = headers.iter().any(|h| {
|
||||
h.0.to_lowercase() == "transfer-encoding" && h.1.to_lowercase().contains("chunked")
|
||||
});
|
||||
|
||||
// Add a Content-Length header only if chunked encoding is not being used
|
||||
if !has_chunked_encoding {
|
||||
let content_length = match body {
|
||||
Some(SendableBodyWithMeta::Bytes(ref bytes)) => Some(bytes.len()),
|
||||
Some(SendableBodyWithMeta::Stream { content_length, .. }) => content_length,
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(cl) = content_length {
|
||||
headers.push(("Content-Length".to_string(), cl.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok((body.map(|b| b.into()), headers))
|
||||
}
|
||||
|
||||
fn build_form_body(body: &BTreeMap<String, serde_json::Value>) -> Option<SendableBodyWithMeta> {
|
||||
let form_params = match body.get("form").map(|f| f.as_array()) {
|
||||
Some(Some(f)) => f,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let mut body = String::new();
|
||||
for p in form_params {
|
||||
let enabled = get_bool(p, "enabled", true);
|
||||
let name = get_str(p, "name");
|
||||
if !enabled || name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let value = get_str(p, "value");
|
||||
if !body.is_empty() {
|
||||
body.push('&');
|
||||
}
|
||||
body.push_str(&urlencoding::encode(&name));
|
||||
body.push('=');
|
||||
body.push_str(&urlencoding::encode(&value));
|
||||
}
|
||||
|
||||
if body.is_empty() { None } else { Some(SendableBodyWithMeta::Bytes(Bytes::from(body))) }
|
||||
}
|
||||
|
||||
async fn build_binary_body(
|
||||
body: &BTreeMap<String, serde_json::Value>,
|
||||
) -> Result<Option<SendableBodyWithMeta>> {
|
||||
let file_path = match body.get("filePath").map(|f| f.as_str()) {
|
||||
Some(Some(f)) => f,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
// Open a file for streaming
|
||||
let content_length = tokio::fs::metadata(file_path)
|
||||
.await
|
||||
.map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))?
|
||||
.len();
|
||||
|
||||
let file = tokio::fs::File::open(file_path)
|
||||
.await
|
||||
.map_err(|e| RequestError(format!("Failed to open file: {}", e)))?;
|
||||
|
||||
Ok(Some(SendableBodyWithMeta::Stream {
|
||||
data: Box::pin(file),
|
||||
content_length: Some(content_length as usize),
|
||||
}))
|
||||
}
|
||||
|
||||
fn build_text_body(body: &BTreeMap<String, serde_json::Value>) -> Option<SendableBodyWithMeta> {
|
||||
let text = get_str_map(body, "text");
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(SendableBodyWithMeta::Bytes(Bytes::from(text.to_string())))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_graphql_body(
|
||||
method: &str,
|
||||
body: &BTreeMap<String, serde_json::Value>,
|
||||
) -> Option<SendableBodyWithMeta> {
|
||||
let query = get_str_map(body, "query");
|
||||
let variables = get_str_map(body, "variables");
|
||||
|
||||
if method.to_lowercase() == "get" {
|
||||
// GraphQL GET requests use query parameters, not a body
|
||||
return None;
|
||||
}
|
||||
|
||||
let body = if variables.trim().is_empty() {
|
||||
format!(r#"{{"query":{}}}"#, serde_json::to_string(&query).unwrap_or_default())
|
||||
} else {
|
||||
format!(
|
||||
r#"{{"query":{},"variables":{}}}"#,
|
||||
serde_json::to_string(&query).unwrap_or_default(),
|
||||
variables
|
||||
)
|
||||
};
|
||||
|
||||
Some(SendableBodyWithMeta::Bytes(Bytes::from(body)))
|
||||
}
|
||||
|
||||
async fn build_multipart_body(
|
||||
body: &BTreeMap<String, serde_json::Value>,
|
||||
headers: &Vec<(String, String)>,
|
||||
) -> Result<(Option<SendableBodyWithMeta>, Option<String>)> {
|
||||
let boundary = extract_boundary_from_headers(headers);
|
||||
|
||||
let form_params = match body.get("form").map(|f| f.as_array()) {
|
||||
Some(Some(f)) => f,
|
||||
_ => return Ok((None, None)),
|
||||
};
|
||||
|
||||
// Build a list of readers for streaming and calculate total content length
|
||||
let mut readers: Vec<ReaderType> = Vec::new();
|
||||
let mut has_content = false;
|
||||
let mut total_size: usize = 0;
|
||||
|
||||
for p in form_params {
|
||||
let enabled = get_bool(p, "enabled", true);
|
||||
let name = get_str(p, "name");
|
||||
if !enabled || name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
has_content = true;
|
||||
|
||||
// Add boundary delimiter
|
||||
let boundary_bytes = format!("--{}\r\n", boundary).into_bytes();
|
||||
total_size += boundary_bytes.len();
|
||||
readers.push(ReaderType::Bytes(boundary_bytes));
|
||||
|
||||
let file_path = get_str(p, "file");
|
||||
let value = get_str(p, "value");
|
||||
let content_type = get_str(p, "contentType");
|
||||
|
||||
if file_path.is_empty() {
|
||||
// Text field
|
||||
let header =
|
||||
format!("Content-Disposition: form-data; name=\"{}\"\r\n\r\n{}", name, value);
|
||||
let header_bytes = header.into_bytes();
|
||||
total_size += header_bytes.len();
|
||||
readers.push(ReaderType::Bytes(header_bytes));
|
||||
} else {
|
||||
// File field - validate that file exists first
|
||||
if !tokio::fs::try_exists(file_path).await.unwrap_or(false) {
|
||||
return Err(RequestError(format!("File not found: {}", file_path)));
|
||||
}
|
||||
|
||||
// Get file size for content length calculation
|
||||
let file_metadata = tokio::fs::metadata(file_path)
|
||||
.await
|
||||
.map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))?;
|
||||
let file_size = file_metadata.len() as usize;
|
||||
|
||||
let filename = get_str(p, "filename");
|
||||
let filename = if filename.is_empty() {
|
||||
std::path::Path::new(file_path)
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("file")
|
||||
} else {
|
||||
filename
|
||||
};
|
||||
|
||||
// Add content type
|
||||
let mime_type = if !content_type.is_empty() {
|
||||
content_type.to_string()
|
||||
} else {
|
||||
// Guess mime type from file extension
|
||||
mime_guess::from_path(file_path).first_or_octet_stream().to_string()
|
||||
};
|
||||
|
||||
let header = format!(
|
||||
"Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\nContent-Type: {}\r\n\r\n",
|
||||
name, filename, mime_type
|
||||
);
|
||||
let header_bytes = header.into_bytes();
|
||||
total_size += header_bytes.len();
|
||||
total_size += file_size;
|
||||
readers.push(ReaderType::Bytes(header_bytes));
|
||||
|
||||
// Add a file path for streaming
|
||||
readers.push(ReaderType::FilePath(file_path.to_string()));
|
||||
}
|
||||
|
||||
let line_ending = b"\r\n".to_vec();
|
||||
total_size += line_ending.len();
|
||||
readers.push(ReaderType::Bytes(line_ending));
|
||||
}
|
||||
|
||||
if has_content {
|
||||
// Add the final boundary
|
||||
let final_boundary = format!("--{}--\r\n", boundary).into_bytes();
|
||||
total_size += final_boundary.len();
|
||||
readers.push(ReaderType::Bytes(final_boundary));
|
||||
|
||||
let content_type = format!("multipart/form-data; boundary={}", boundary);
|
||||
let stream = ChainedReader::new(readers);
|
||||
Ok((
|
||||
Some(SendableBodyWithMeta::Stream {
|
||||
data: Box::pin(stream),
|
||||
content_length: Some(total_size),
|
||||
}),
|
||||
Some(content_type),
|
||||
))
|
||||
} else {
|
||||
Ok((None, None))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_boundary_from_headers(headers: &Vec<(String, String)>) -> String {
|
||||
headers
|
||||
.iter()
|
||||
.find(|h| h.0.to_lowercase() == "content-type")
|
||||
.and_then(|h| {
|
||||
// Extract boundary from the Content-Type header (e.g., "multipart/form-data; boundary=xyz")
|
||||
h.1.split(';')
|
||||
.find(|part| part.trim().starts_with("boundary="))
|
||||
.and_then(|boundary_part| boundary_part.split('=').nth(1))
|
||||
.map(|b| b.trim().to_string())
|
||||
})
|
||||
.unwrap_or_else(|| MULTIPART_BOUNDARY.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use bytes::Bytes;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use yaak_models::models::{HttpRequest, HttpUrlParameter};
|
||||
|
||||
#[test]
|
||||
fn test_build_url_no_params() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api".to_string(),
|
||||
url_parameters: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_params() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api".to_string(),
|
||||
url_parameters: vec![
|
||||
HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
},
|
||||
HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "baz".to_string(),
|
||||
value: "qux".to_string(),
|
||||
id: None,
|
||||
},
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?foo=bar&baz=qux");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_disabled_params() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api".to_string(),
|
||||
url_parameters: vec![
|
||||
HttpUrlParameter {
|
||||
enabled: false,
|
||||
name: "disabled".to_string(),
|
||||
value: "value".to_string(),
|
||||
id: None,
|
||||
},
|
||||
HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "enabled".to_string(),
|
||||
value: "value".to_string(),
|
||||
id: None,
|
||||
},
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?enabled=value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_existing_query() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api?existing=param".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "new".to_string(),
|
||||
value: "value".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?existing=param&new=value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_empty_existing_query() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api?".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "new".to_string(),
|
||||
value: "value".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?new=value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_special_chars() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "special chars!@#".to_string(),
|
||||
value: "value with spaces & symbols".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(
|
||||
result,
|
||||
"https://example.com/api?special%20chars%21%40%23=value%20with%20spaces%20%26%20symbols"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_adds_protocol() {
|
||||
let r = HttpRequest {
|
||||
url: "example.com/api".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
// ensure_proto defaults to http:// for regular domains
|
||||
assert_eq!(result, "http://example.com/api?foo=bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_adds_https_for_dev_domain() {
|
||||
let r = HttpRequest {
|
||||
url: "example.dev/api".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
// .dev domains force https
|
||||
assert_eq!(result, "https://example.dev/api?foo=bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_fragment() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api#section".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?foo=bar#section");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_existing_query_and_fragment() {
|
||||
let r = HttpRequest {
|
||||
url: "https://yaak.app?foo=bar#some-hash".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "baz".to_string(),
|
||||
value: "qux".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://yaak.app?foo=bar&baz=qux#some-hash");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_empty_query_and_fragment() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api?#section".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?foo=bar#section");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_fragment_containing_special_chars() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com#section/with/slashes?and=fake&query".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "real".to_string(),
|
||||
value: "param".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com?real=param#section/with/slashes?and=fake&query");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_preserves_empty_fragment() {
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com/api#".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
assert_eq!(result, "https://example.com/api?foo=bar#");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_url_with_multiple_fragments() {
|
||||
// Testing edge case where the URL has multiple # characters (though technically invalid)
|
||||
let r = HttpRequest {
|
||||
url: "https://example.com#section#subsection".to_string(),
|
||||
url_parameters: vec![HttpUrlParameter {
|
||||
enabled: true,
|
||||
name: "foo".to_string(),
|
||||
value: "bar".to_string(),
|
||||
id: None,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = build_url(&r);
|
||||
// Should treat everything after first # as fragment
|
||||
assert_eq!(result, "https://example.com?foo=bar#section#subsection");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_body() {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("text".to_string(), json!("Hello, World!"));
|
||||
|
||||
let result = build_text_body(&body);
|
||||
match result {
|
||||
Some(SendableBodyWithMeta::Bytes(bytes)) => {
|
||||
assert_eq!(bytes, Bytes::from("Hello, World!"))
|
||||
}
|
||||
_ => panic!("Expected Some(SendableBody::Bytes)"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_body_empty() {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("text".to_string(), json!(""));
|
||||
|
||||
let result = build_text_body(&body);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_body_missing() {
|
||||
let body = BTreeMap::new();
|
||||
|
||||
let result = build_text_body(&body);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_form_urlencoded_body() -> Result<()> {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert(
|
||||
"form".to_string(),
|
||||
json!([
|
||||
{ "enabled": true, "name": "basic", "value": "aaa"},
|
||||
{ "enabled": true, "name": "fUnkey Stuff!$*#(", "value": "*)%&#$)@ *$#)@&"},
|
||||
{ "enabled": false, "name": "disabled", "value": "won't show"},
|
||||
]),
|
||||
);
|
||||
|
||||
let result = build_form_body(&body);
|
||||
match result {
|
||||
Some(SendableBodyWithMeta::Bytes(bytes)) => {
|
||||
let expected = "basic=aaa&fUnkey%20Stuff%21%24%2A%23%28=%2A%29%25%26%23%24%29%40%20%2A%24%23%29%40%26";
|
||||
assert_eq!(bytes, Bytes::from(expected));
|
||||
}
|
||||
_ => panic!("Expected Some(SendableBody::Bytes)"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_form_urlencoded_body_missing_form() {
|
||||
let body = BTreeMap::new();
|
||||
let result = build_form_body(&body);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_binary_body() -> Result<()> {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("filePath".to_string(), json!("./tests/test.txt"));
|
||||
|
||||
let result = build_binary_body(&body).await?;
|
||||
assert!(matches!(result, Some(SendableBodyWithMeta::Stream { .. })));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_binary_body_file_not_found() {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("filePath".to_string(), json!("./nonexistent/file.txt"));
|
||||
|
||||
let result = build_binary_body(&body).await;
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert!(matches!(e, RequestError(_)));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_graphql_body_with_variables() {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("query".to_string(), json!("{ user(id: $id) { name } }"));
|
||||
body.insert("variables".to_string(), json!(r#"{"id": "123"}"#));
|
||||
|
||||
let result = build_graphql_body("POST", &body);
|
||||
match result {
|
||||
Some(SendableBodyWithMeta::Bytes(bytes)) => {
|
||||
let expected =
|
||||
r#"{"query":"{ user(id: $id) { name } }","variables":{"id": "123"}}"#;
|
||||
assert_eq!(bytes, Bytes::from(expected));
|
||||
}
|
||||
_ => panic!("Expected Some(SendableBody::Bytes)"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_graphql_body_without_variables() {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("query".to_string(), json!("{ users { name } }"));
|
||||
body.insert("variables".to_string(), json!(""));
|
||||
|
||||
let result = build_graphql_body("POST", &body);
|
||||
match result {
|
||||
Some(SendableBodyWithMeta::Bytes(bytes)) => {
|
||||
let expected = r#"{"query":"{ users { name } }"}"#;
|
||||
assert_eq!(bytes, Bytes::from(expected));
|
||||
}
|
||||
_ => panic!("Expected Some(SendableBody::Bytes)"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_graphql_body_get_method() {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("query".to_string(), json!("{ users { name } }"));
|
||||
|
||||
let result = build_graphql_body("GET", &body);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multipart_body_text_fields() -> Result<()> {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert(
|
||||
"form".to_string(),
|
||||
json!([
|
||||
{ "enabled": true, "name": "field1", "value": "value1", "file": "" },
|
||||
{ "enabled": true, "name": "field2", "value": "value2", "file": "" },
|
||||
{ "enabled": false, "name": "disabled", "value": "won't show", "file": "" },
|
||||
]),
|
||||
);
|
||||
|
||||
let (result, content_type) = build_multipart_body(&body, &vec![]).await?;
|
||||
assert!(content_type.is_some());
|
||||
|
||||
match result {
|
||||
Some(SendableBodyWithMeta::Stream { data: mut stream, content_length }) => {
|
||||
// Read the entire stream to verify content
|
||||
let mut buf = Vec::new();
|
||||
use tokio::io::AsyncReadExt;
|
||||
stream.read_to_end(&mut buf).await.expect("Failed to read stream");
|
||||
let body_str = String::from_utf8_lossy(&buf);
|
||||
assert_eq!(
|
||||
body_str,
|
||||
"--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"field1\"\r\n\r\nvalue1\r\n--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"field2\"\r\n\r\nvalue2\r\n--------YaakFormBoundary--\r\n",
|
||||
);
|
||||
assert_eq!(content_length, Some(body_str.len()));
|
||||
}
|
||||
_ => panic!("Expected Some(SendableBody::Stream)"),
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
content_type.unwrap(),
|
||||
format!("multipart/form-data; boundary={}", MULTIPART_BOUNDARY)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multipart_body_with_file() -> Result<()> {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert(
|
||||
"form".to_string(),
|
||||
json!([
|
||||
{ "enabled": true, "name": "file_field", "file": "./tests/test.txt", "filename": "custom.txt", "contentType": "text/plain" },
|
||||
]),
|
||||
);
|
||||
|
||||
let (result, content_type) = build_multipart_body(&body, &vec![]).await?;
|
||||
assert!(content_type.is_some());
|
||||
|
||||
match result {
|
||||
Some(SendableBodyWithMeta::Stream { data: mut stream, content_length }) => {
|
||||
// Read the entire stream to verify content
|
||||
let mut buf = Vec::new();
|
||||
use tokio::io::AsyncReadExt;
|
||||
stream.read_to_end(&mut buf).await.expect("Failed to read stream");
|
||||
let body_str = String::from_utf8_lossy(&buf);
|
||||
assert_eq!(
|
||||
body_str,
|
||||
"--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"file_field\"; filename=\"custom.txt\"\r\nContent-Type: text/plain\r\n\r\nThis is a test file!\n\r\n--------YaakFormBoundary--\r\n"
|
||||
);
|
||||
assert_eq!(content_length, Some(body_str.len()));
|
||||
}
|
||||
_ => panic!("Expected Some(SendableBody::Stream)"),
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
content_type.unwrap(),
|
||||
format!("multipart/form-data; boundary={}", MULTIPART_BOUNDARY)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multipart_body_empty() -> Result<()> {
|
||||
let body = BTreeMap::new();
|
||||
let (result, content_type) = build_multipart_body(&body, &vec![]).await?;
|
||||
assert!(result.is_none());
|
||||
assert_eq!(content_type, None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_boundary_from_headers_with_custom_boundary() {
|
||||
let headers = vec![(
|
||||
"Content-Type".to_string(),
|
||||
"multipart/form-data; boundary=customBoundary123".to_string(),
|
||||
)];
|
||||
let boundary = extract_boundary_from_headers(&headers);
|
||||
assert_eq!(boundary, "customBoundary123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_boundary_from_headers_default() {
|
||||
let headers = vec![("Accept".to_string(), "*/*".to_string())];
|
||||
let boundary = extract_boundary_from_headers(&headers);
|
||||
assert_eq!(boundary, MULTIPART_BOUNDARY);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_boundary_from_headers_no_boundary_in_content_type() {
|
||||
let headers = vec![("Content-Type".to_string(), "multipart/form-data".to_string())];
|
||||
let boundary = extract_boundary_from_headers(&headers);
|
||||
assert_eq!(boundary, MULTIPART_BOUNDARY);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_boundary_case_insensitive() {
|
||||
let headers = vec![(
|
||||
"Content-Type".to_string(),
|
||||
"multipart/form-data; boundary=myBoundary".to_string(),
|
||||
)];
|
||||
let boundary = extract_boundary_from_headers(&headers);
|
||||
assert_eq!(boundary, "myBoundary");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_content_length_with_chunked_encoding() -> Result<()> {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("text".to_string(), json!("Hello, World!"));
|
||||
|
||||
// Headers with Transfer-Encoding: chunked
|
||||
let headers = vec![("Transfer-Encoding".to_string(), "chunked".to_string())];
|
||||
|
||||
let (_, result_headers) =
|
||||
build_body("POST", &Some("text/plain".to_string()), &body, headers).await?;
|
||||
|
||||
// Verify that Content-Length is NOT present when Transfer-Encoding: chunked is set
|
||||
let has_content_length =
|
||||
result_headers.iter().any(|h| h.0.to_lowercase() == "content-length");
|
||||
assert!(!has_content_length, "Content-Length should not be present with chunked encoding");
|
||||
|
||||
// Verify that the Transfer-Encoding header is still present
|
||||
let has_chunked = result_headers.iter().any(|h| {
|
||||
h.0.to_lowercase() == "transfer-encoding" && h.1.to_lowercase().contains("chunked")
|
||||
});
|
||||
assert!(has_chunked, "Transfer-Encoding: chunked should be preserved");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_content_length_without_chunked_encoding() -> Result<()> {
|
||||
let mut body = BTreeMap::new();
|
||||
body.insert("text".to_string(), json!("Hello, World!"));
|
||||
|
||||
// Headers without Transfer-Encoding: chunked
|
||||
let headers = vec![];
|
||||
|
||||
let (_, result_headers) =
|
||||
build_body("POST", &Some("text/plain".to_string()), &body, headers).await?;
|
||||
|
||||
// Verify that Content-Length IS present when Transfer-Encoding: chunked is NOT set
|
||||
let content_length_header =
|
||||
result_headers.iter().find(|h| h.0.to_lowercase() == "content-length");
|
||||
assert!(
|
||||
content_length_header.is_some(),
|
||||
"Content-Length should be present without chunked encoding"
|
||||
);
|
||||
assert_eq!(
|
||||
content_length_header.unwrap().1,
|
||||
"13",
|
||||
"Content-Length should match the body size"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user