From 32f22aad67de5f26babac505e212310dfe2fc3a3 Mon Sep 17 00:00:00 2001 From: Gregory Schier Date: Fri, 6 Mar 2026 06:58:45 -0800 Subject: [PATCH] Add initial yaak-proxy crate --- Cargo.lock | 74 ++++++ Cargo.toml | 2 + crates/yaak-proxy/Cargo.toml | 18 ++ crates/yaak-proxy/src/body.rs | 114 ++++++++ crates/yaak-proxy/src/cert.rs | 82 ++++++ crates/yaak-proxy/src/connection.rs | 32 +++ crates/yaak-proxy/src/lib.rs | 168 ++++++++++++ crates/yaak-proxy/src/request.rs | 390 ++++++++++++++++++++++++++++ tsconfig.json | 5 +- 9 files changed, 882 insertions(+), 3 deletions(-) create mode 100644 crates/yaak-proxy/Cargo.toml create mode 100644 crates/yaak-proxy/src/body.rs create mode 100644 crates/yaak-proxy/src/cert.rs create mode 100644 crates/yaak-proxy/src/connection.rs create mode 100644 crates/yaak-proxy/src/lib.rs create mode 100644 crates/yaak-proxy/src/request.rs diff --git a/Cargo.lock b/Cargo.lock index 5a06b414..455e8cc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -477,6 +477,28 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-lc-rs" +version = "1.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.7.9" @@ -2192,6 +2214,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fsevent-sys" version = "4.1.0" @@ -5115,6 +5143,16 @@ dependencies = [ "hmac", ] +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -5955,6 +5993,19 @@ dependencies = [ "cipher", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.5.12" @@ -6688,6 +6739,8 @@ version = "0.23.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a9586e9ee2b4f8fab52a0048ca7334d7024eef48e2cb9407e3497bb7cab7fa7" dependencies = [ + "aws-lc-rs", + "log 0.4.29", "once_cell", "ring", "rustls-pki-types", @@ -6760,6 +6813,7 @@ version = "0.103.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e10b3f4191e8a80e6b43eebabfac91e5dcecebb27a71f04e820c47ec41d314bf" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -10486,6 +10540,23 @@ dependencies = [ "zip-extract", ] +[[package]] +name = "yaak-proxy" +version = "0.1.0" +dependencies = [ + "bytes", + "http", + "http-body-util", + "hyper", + "hyper-util", + "pem", + "rcgen", + "rustls", + "rustls-native-certs", + "tokio", + "tokio-rustls", +] + [[package]] name = "yaak-sse" version = "0.1.0" @@ -10582,6 +10653,9 @@ name = "yasna" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] [[package]] name = "yoke" diff --git a/Cargo.toml b/Cargo.toml index ac0e884f..e47b8be2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "crates/yaak-tls", "crates/yaak-ws", "crates/yaak-api", + "crates/yaak-proxy", # CLI crates "crates-cli/yaak-cli", # Tauri-specific crates @@ -63,6 +64,7 @@ yaak-templates = { path = "crates/yaak-templates" } yaak-tls = { path = "crates/yaak-tls" } yaak-ws = { path = "crates/yaak-ws" } yaak-api = { path = "crates/yaak-api" } +yaak-proxy = { path = "crates/yaak-proxy" } # Internal crates - Tauri-specific yaak-fonts = { path = "crates-tauri/yaak-fonts" } diff --git a/crates/yaak-proxy/Cargo.toml b/crates/yaak-proxy/Cargo.toml new file mode 100644 index 00000000..d297b511 --- /dev/null +++ b/crates/yaak-proxy/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "yaak-proxy" +version = "0.1.0" +edition = "2024" +publish = false + +[dependencies] +hyper = { version = "1", features = ["http1", "http2", "server", "client"] } +hyper-util = { version = "0.1", features = ["tokio", "server-auto", "client-legacy"] } +http-body-util = "0.1" +http = "1" +bytes = "1" +tokio = { workspace = true, features = ["rt-multi-thread", "net", "sync", "macros", "time", "io-util"] } +rcgen = "0.13" +rustls = { workspace = true, features = ["ring"] } +rustls-native-certs = "0.8" +tokio-rustls = "0.26" +pem = "3" diff --git a/crates/yaak-proxy/src/body.rs b/crates/yaak-proxy/src/body.rs new file mode 100644 index 00000000..216206de --- /dev/null +++ b/crates/yaak-proxy/src/body.rs @@ -0,0 +1,114 @@ +use std::pin::Pin; +use std::sync::mpsc as std_mpsc; +use std::task::{Context, Poll}; +use std::time::Instant; + +use bytes::Bytes; +use hyper::body::{Body, Frame}; + +use crate::ProxyEvent; + +/// A body wrapper that emits `ResponseBodyChunk` per frame and +/// `ResponseBodyComplete` when the stream finishes. +pub struct MeasuredBody { + inner: B, + request_id: u64, + bytes_count: u64, + chunks: Vec, + event_tx: std_mpsc::Sender, + start: Instant, + finished: bool, +} + +impl MeasuredBody { + pub fn new( + inner: B, + request_id: u64, + start: Instant, + event_tx: std_mpsc::Sender, + ) -> Self { + Self { + inner, + request_id, + bytes_count: 0, + chunks: Vec::new(), + event_tx, + start, + finished: false, + } + } + + fn send_complete(&mut self) { + if !self.finished { + self.finished = true; + let body = if self.chunks.is_empty() { + None + } else { + let mut buf = Vec::with_capacity(self.bytes_count as usize); + for chunk in self.chunks.drain(..) { + buf.extend_from_slice(&chunk); + } + Some(buf) + }; + let _ = self.event_tx.send(ProxyEvent::ResponseBodyComplete { + id: self.request_id, + body, + size: self.bytes_count, + elapsed_ms: self.start.elapsed().as_millis() as u64, + }); + } + } +} + +impl Body for MeasuredBody +where + B: Body + Unpin, + B::Error: std::error::Error + Send + Sync + 'static, +{ + type Data = Bytes; + type Error = B::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let inner = Pin::new(&mut self.inner); + match inner.poll_frame(cx) { + Poll::Ready(Some(Ok(frame))) => { + if let Some(data) = frame.data_ref() { + let len = data.len(); + self.bytes_count += len as u64; + self.chunks.push(data.clone()); + let _ = self.event_tx.send(ProxyEvent::ResponseBodyChunk { + id: self.request_id, + bytes: len, + }); + } + Poll::Ready(Some(Ok(frame))) + } + Poll::Ready(Some(Err(e))) => { + self.send_complete(); + Poll::Ready(Some(Err(e))) + } + Poll::Ready(None) => { + self.send_complete(); + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> hyper::body::SizeHint { + self.inner.size_hint() + } +} + +impl Drop for MeasuredBody { + fn drop(&mut self) { + self.send_complete(); + } +} diff --git a/crates/yaak-proxy/src/cert.rs b/crates/yaak-proxy/src/cert.rs new file mode 100644 index 00000000..7636057c --- /dev/null +++ b/crates/yaak-proxy/src/cert.rs @@ -0,0 +1,82 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use rcgen::{BasicConstraints, Certificate, CertificateParams, IsCa, KeyPair, KeyUsagePurpose}; +use rustls::ServerConfig; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + +pub struct CertificateAuthority { + ca_cert: Certificate, + ca_cert_der: CertificateDer<'static>, + ca_key: KeyPair, + cache: Mutex>>, +} + +impl CertificateAuthority { + pub fn new() -> Result> { + let mut params = CertificateParams::default(); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + params.key_usages.push(KeyUsagePurpose::KeyCertSign); + params.key_usages.push(KeyUsagePurpose::CrlSign); + params + .distinguished_name + .push(rcgen::DnType::CommonName, "Debug Proxy CA"); + params + .distinguished_name + .push(rcgen::DnType::OrganizationName, "Debug Proxy"); + + let key = KeyPair::generate()?; + let ca_cert = params.self_signed(&key)?; + let ca_cert_der = ca_cert.der().clone(); + + Ok(Self { + ca_cert, + ca_cert_der, + ca_key: key, + cache: Mutex::new(HashMap::new()), + }) + } + + pub fn ca_pem(&self) -> String { + pem::encode(&pem::Pem::new("CERTIFICATE", self.ca_cert_der.to_vec())) + } + + pub fn server_config( + &self, + domain: &str, + ) -> Result, Box> { + { + let cache = self.cache.lock().unwrap(); + if let Some(config) = cache.get(domain) { + return Ok(config.clone()); + } + } + + let mut params = CertificateParams::new(vec![domain.to_string()])?; + params + .distinguished_name + .push(rcgen::DnType::CommonName, domain); + + let leaf_key = KeyPair::generate()?; + let leaf_cert = params.signed_by(&leaf_key, &self.ca_cert, &self.ca_key)?; + + let cert_der = leaf_cert.der().clone(); + let key_der = leaf_key.serialize_der(); + + let mut config = ServerConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider())) + .with_safe_default_protocol_versions()? + .with_no_client_auth() + .with_single_cert( + vec![cert_der, self.ca_cert_der.clone()], + PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der)), + )?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + let config = Arc::new(config); + self.cache + .lock() + .unwrap() + .insert(domain.to_string(), config.clone()); + Ok(config) + } +} diff --git a/crates/yaak-proxy/src/connection.rs b/crates/yaak-proxy/src/connection.rs new file mode 100644 index 00000000..77e6ee6d --- /dev/null +++ b/crates/yaak-proxy/src/connection.rs @@ -0,0 +1,32 @@ +use std::sync::mpsc as std_mpsc; +use std::sync::Arc; + +use hyper::server::conn::http1; +use hyper::service::service_fn; +use tokio::net::TcpStream; + +use crate::ProxyEvent; +use crate::cert::CertificateAuthority; +use crate::request::handle_request; + +pub(crate) async fn handle_connection( + stream: TcpStream, + event_tx: std_mpsc::Sender, + ca: Arc, +) -> Result<(), Box> { + let tx = event_tx.clone(); + http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .serve_connection( + hyper_util::rt::TokioIo::new(stream), + service_fn(move |req| { + let tx = tx.clone(); + let ca = ca.clone(); + async move { handle_request(req, tx, ca).await } + }), + ) + .with_upgrades() + .await + .map_err(|e| Box::new(e) as Box) +} diff --git a/crates/yaak-proxy/src/lib.rs b/crates/yaak-proxy/src/lib.rs new file mode 100644 index 00000000..b16d8043 --- /dev/null +++ b/crates/yaak-proxy/src/lib.rs @@ -0,0 +1,168 @@ +pub mod body; +pub mod cert; +mod connection; +mod request; + +use std::net::SocketAddr; +use std::sync::atomic::AtomicU64; +use std::sync::mpsc as std_mpsc; +use std::sync::Arc; + +use cert::CertificateAuthority; +use tokio::net::TcpListener; + +use connection::handle_connection; + +static REQUEST_ID: AtomicU64 = AtomicU64::new(1); + +/// Granular events emitted during request/response lifecycle. +/// Each event carries a request `id` so consumers can correlate events. +#[derive(Debug, Clone)] +pub enum ProxyEvent { + /// A new request has been received from the client. + RequestStart { + id: u64, + method: String, + url: String, + http_version: String, + }, + /// A request header sent to the upstream server. + RequestHeader { id: u64, name: String, value: String }, + /// The full request body (buffered before forwarding). + RequestBody { id: u64, body: Vec }, + /// Response headers received from upstream. + ResponseStart { + id: u64, + status: u16, + http_version: String, + elapsed_ms: u64, + }, + /// A response header received from the upstream server. + ResponseHeader { id: u64, name: String, value: String }, + /// A chunk of the response body was received (emitted per-frame). + ResponseBodyChunk { id: u64, bytes: usize }, + /// The response body stream has completed. + ResponseBodyComplete { + id: u64, + body: Option>, + size: u64, + elapsed_ms: u64, + }, + /// The upstream request failed. + Error { id: u64, error: String }, +} + +/// Accumulated view of a proxied request, built from `ProxyEvent`s. +#[derive(Debug, Clone)] +pub struct CapturedRequest { + pub id: u64, + pub method: String, + pub url: String, + pub status: Option, + pub elapsed_ms: Option, + pub http_version: String, + pub remote_http_version: Option, + pub request_headers: Vec<(String, String)>, + pub request_body: Option>, + pub response_headers: Vec<(String, String)>, + pub response_body: Option>, + pub response_body_size: u64, + pub state: RequestState, + pub error: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RequestState { + Sending, + Receiving, + Complete, + Error, +} + +pub struct ProxyHandle { + shutdown_tx: Option>, + thread_handle: Option>, + pub event_rx: std_mpsc::Receiver, + pub port: u16, + pub ca_pem: String, +} + +impl Drop for ProxyHandle { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + if let Some(handle) = self.thread_handle.take() { + let _ = handle.join(); + } + } +} + +pub fn start_proxy(port: u16) -> Result { + let ca = CertificateAuthority::new().map_err(|e| format!("Failed to create CA: {e}"))?; + let ca_pem = ca.ca_pem(); + let ca = Arc::new(ca); + + let (event_tx, event_rx) = std_mpsc::channel(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let (ready_tx, ready_rx) = std_mpsc::channel(); + + let thread_handle = std::thread::spawn(move || { + let rt = match tokio::runtime::Runtime::new() { + Ok(rt) => rt, + Err(e) => { + let _ = ready_tx.send(Err(format!("Failed to create runtime: {e}"))); + return; + } + }; + + rt.block_on(async move { + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = match TcpListener::bind(addr).await { + Ok(l) => l, + Err(e) => { + let _ = ready_tx.send(Err(format!("Failed to bind: {e}"))); + return; + } + }; + + let bound_port = listener.local_addr().unwrap().port(); + let _ = ready_tx.send(Ok(bound_port)); + + let mut shutdown_rx = shutdown_rx; + loop { + tokio::select! { + result = listener.accept() => { + match result { + Ok((stream, _addr)) => { + let tx = event_tx.clone(); + let ca = ca.clone(); + tokio::spawn(async move { + if let Err(e) = handle_connection(stream, tx, ca).await { + eprintln!("Connection error: {e}"); + } + }); + } + Err(e) => eprintln!("Accept error: {e}"), + } + } + _ = &mut shutdown_rx => { + break; + } + } + } + }); + }); + + match ready_rx.recv() { + Ok(Ok(bound_port)) => Ok(ProxyHandle { + shutdown_tx: Some(shutdown_tx), + thread_handle: Some(thread_handle), + event_rx, + port: bound_port, + ca_pem, + }), + Ok(Err(e)) => Err(e), + Err(_) => Err("Proxy thread died before binding".into()), + } +} diff --git a/crates/yaak-proxy/src/request.rs b/crates/yaak-proxy/src/request.rs new file mode 100644 index 00000000..7285af5c --- /dev/null +++ b/crates/yaak-proxy/src/request.rs @@ -0,0 +1,390 @@ +use std::convert::Infallible; +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::sync::mpsc as std_mpsc; +use std::time::Instant; + +use bytes::Bytes; +use http_body_util::{BodyExt, Full}; +use hyper::body::Incoming; +use hyper::header::HeaderMap; +use hyper::service::service_fn; +use hyper::{Method, Request, Response, StatusCode, Uri}; +use hyper_util::client::legacy::Client; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto; +use rustls::ClientConfig; +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; +use tokio_rustls::TlsAcceptor; + +use crate::body::MeasuredBody; +use crate::cert::CertificateAuthority; +use crate::{ProxyEvent, REQUEST_ID}; + +type BoxBody = http_body_util::combinators::BoxBody; + +fn full_body(bytes: Bytes) -> BoxBody { + Full::new(bytes).map_err(|never| match never {}).boxed() +} + +fn measured_incoming( + incoming: Incoming, + id: u64, + start: Instant, + tx: std_mpsc::Sender, +) -> BoxBody { + MeasuredBody::new(incoming, id, start, tx).boxed() +} + +fn version_str(v: hyper::Version) -> String { + match v { + hyper::Version::HTTP_09 => "HTTP/0.9", + hyper::Version::HTTP_10 => "HTTP/1.0", + hyper::Version::HTTP_11 => "HTTP/1.1", + hyper::Version::HTTP_2 => "HTTP/2", + hyper::Version::HTTP_3 => "HTTP/3", + _ => "unknown", + } + .to_string() +} + +fn emit_request_events( + tx: &std_mpsc::Sender, + id: u64, + headers: &HeaderMap, + body: &Option>, +) { + for (name, value) in headers.iter() { + let _ = tx.send(ProxyEvent::RequestHeader { + id, + name: name.to_string(), + value: value.to_str().unwrap_or("").to_string(), + }); + } + if let Some(body) = body { + let _ = tx.send(ProxyEvent::RequestBody { + id, + body: body.clone(), + }); + } +} + +fn emit_response_events( + tx: &std_mpsc::Sender, + id: u64, + resp: &Response, + start: &Instant, +) { + let _ = tx.send(ProxyEvent::ResponseStart { + id, + status: resp.status().as_u16(), + http_version: version_str(resp.version()), + elapsed_ms: start.elapsed().as_millis() as u64, + }); + for (name, value) in resp.headers().iter() { + let _ = tx.send(ProxyEvent::ResponseHeader { + id, + name: name.to_string(), + value: value.to_str().unwrap_or("").to_string(), + }); + } +} + +pub(crate) async fn handle_request( + req: Request, + event_tx: std_mpsc::Sender, + ca: Arc, +) -> Result, Infallible> { + let result = if req.method() == Method::CONNECT { + handle_connect(req, event_tx, ca).await + } else { + handle_http(req, event_tx).await + }; + match result { + Ok(resp) => Ok(resp), + Err(e) => { + eprintln!("Proxy error: {e}"); + Ok(Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(full_body(Bytes::from(format!("Proxy error: {e}")))) + .unwrap()) + } + } +} + +async fn handle_http( + req: Request, + event_tx: std_mpsc::Sender, +) -> Result, Box> { + let id = REQUEST_ID.fetch_add(1, Ordering::Relaxed); + let method = req.method().to_string(); + let uri = req.uri().to_string(); + let http_version = version_str(req.version()); + let start = Instant::now(); + + let _ = event_tx.send(ProxyEvent::RequestStart { + id, + method, + url: uri.clone(), + http_version, + }); + + let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build_http(); + + let (parts, body) = req.into_parts(); + let body_bytes = body.collect().await?.to_bytes(); + let request_body = if body_bytes.is_empty() { + None + } else { + Some(body_bytes.to_vec()) + }; + emit_request_events(&event_tx, id, &parts.headers, &request_body); + + let outgoing_req = Request::from_parts(parts, Full::new(body_bytes)); + + match client.request(outgoing_req).await { + Ok(resp) => { + emit_response_events(&event_tx, id, &resp, &start); + + let (parts, body) = resp.into_parts(); + Ok(Response::from_parts( + parts, + measured_incoming(body, id, start, event_tx), + )) + } + Err(e) => { + let _ = event_tx.send(ProxyEvent::Error { + id, + error: e.to_string(), + }); + Err(Box::new(e) as Box) + } + } +} + +async fn handle_connect( + req: Request, + event_tx: std_mpsc::Sender, + ca: Arc, +) -> Result, Box> { + let authority = req + .uri() + .authority() + .map(|a| a.to_string()) + .unwrap_or_default(); + let (host, port) = parse_host_port(&authority); + + let server_config = ca.server_config(&host)?; + let acceptor = TlsAcceptor::from(server_config); + + let target_addr = format!("{host}:{port}"); + + tokio::spawn(async move { + let upgraded = match hyper::upgrade::on(req).await { + Ok(u) => u, + Err(e) => { + eprintln!("CONNECT upgrade failed: {e}"); + return; + } + }; + + let tls_stream = match acceptor + .accept(hyper_util::rt::TokioIo::new(upgraded)) + .await + { + Ok(s) => s, + Err(e) => { + eprintln!("TLS accept failed for {host}: {e}"); + return; + } + }; + + let tx = event_tx.clone(); + let host_for_requests = host.clone(); + let mut builder = auto::Builder::new(TokioExecutor::new()); + builder + .http1() + .preserve_header_case(true) + .title_case_headers(true); + if let Err(e) = builder + .serve_connection_with_upgrades( + hyper_util::rt::TokioIo::new(tls_stream), + service_fn(move |req| { + let tx = tx.clone(); + let host = host_for_requests.clone(); + let target_addr = target_addr.clone(); + async move { handle_tunneled_request(req, tx, &host, &target_addr).await } + }), + ) + .await + { + eprintln!("MITM connection error for {host}: {e}"); + } + }); + + Ok(Response::new(full_body(Bytes::new()))) +} + +async fn handle_tunneled_request( + req: Request, + event_tx: std_mpsc::Sender, + host: &str, + target_addr: &str, +) -> Result, Infallible> { + let result = forward_https(req, event_tx, host, target_addr).await; + match result { + Ok(resp) => Ok(resp), + Err(e) => { + eprintln!("HTTPS forward error: {e:?}"); + Ok(Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(full_body(Bytes::from(format!("Proxy error: {e}")))) + .unwrap()) + } + } +} + +enum HttpSender { + H1(hyper::client::conn::http1::SendRequest>), + H2(hyper::client::conn::http2::SendRequest>), +} + +impl HttpSender { + async fn send_request( + &mut self, + req: Request>, + ) -> Result, hyper::Error> { + match self { + HttpSender::H1(s) => s.send_request(req).await, + HttpSender::H2(s) => s.send_request(req).await, + } + } +} + +async fn forward_https( + req: Request, + event_tx: std_mpsc::Sender, + host: &str, + target_addr: &str, +) -> Result, Box> { + let id = REQUEST_ID.fetch_add(1, Ordering::Relaxed); + let method = req.method().to_string(); + let http_version = version_str(req.version()); + let path = req + .uri() + .path_and_query() + .map(|pq| pq.to_string()) + .unwrap_or_else(|| "/".into()); + let uri_str = format!("https://{host}{path}"); + let start = Instant::now(); + + let _ = event_tx.send(ProxyEvent::RequestStart { + id, + method, + url: uri_str.clone(), + http_version, + }); + + // Connect to upstream with TLS + let tcp_stream = TcpStream::connect(target_addr).await?; + + let mut root_store = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().certs { + let _ = root_store.add(cert); + } + + let mut tls_config = + ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider())) + .with_safe_default_protocol_versions()? + .with_root_certificates(root_store) + .with_no_client_auth(); + tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config)); + let server_name = ServerName::try_from(host.to_string())?; + let tls_stream = connector.connect(server_name, tcp_stream).await?; + + let negotiated_h2 = tls_stream + .get_ref() + .1 + .alpn_protocol() + .map_or(false, |p| p == b"h2"); + + let io = hyper_util::rt::TokioIo::new(tls_stream); + + let mut sender = if negotiated_h2 { + let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(io) + .await?; + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("Upstream h2 connection error: {e}"); + } + }); + HttpSender::H2(sender) + } else { + let (sender, conn) = hyper::client::conn::http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("Upstream h1 connection error: {e}"); + } + }); + HttpSender::H1(sender) + }; + + // Capture request metadata + let (mut parts, body) = req.into_parts(); + let body_bytes = body.collect().await?.to_bytes(); + let request_body = if body_bytes.is_empty() { + None + } else { + Some(body_bytes.to_vec()) + }; + emit_request_events(&event_tx, id, &parts.headers, &request_body); + + if negotiated_h2 { + // HTTP/2 requires absolute-form URI with scheme + authority + parts.uri = uri_str.parse::()?; + } else { + parts.uri = path.parse::()?; + } + + if !parts.headers.contains_key(hyper::header::HOST) { + parts.headers.insert(hyper::header::HOST, host.parse()?); + } + + let outgoing = Request::from_parts(parts, Full::new(body_bytes)); + + match sender.send_request(outgoing).await { + Ok(resp) => { + emit_response_events(&event_tx, id, &resp, &start); + + let (parts, body) = resp.into_parts(); + Ok(Response::from_parts( + parts, + measured_incoming(body, id, start, event_tx), + )) + } + Err(e) => { + let _ = event_tx.send(ProxyEvent::Error { + id, + error: e.to_string(), + }); + Err(Box::new(e) as Box) + } + } +} + +fn parse_host_port(authority: &str) -> (String, u16) { + if let Some((host, port_str)) = authority.rsplit_once(':') { + if let Ok(port) = port_str.parse::() { + return (host.to_string(), port); + } + } + (authority.to_string(), 443) +} diff --git a/tsconfig.json b/tsconfig.json index f02cc3d3..bef3b63d 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -5,7 +5,6 @@ "useDefineForClassFields": true, "allowJs": false, "skipLibCheck": true, - "esModuleInterop": false, "allowSyntheticDefaultImports": true, "strict": true, "noUncheckedIndexedAccess": true, @@ -15,6 +14,6 @@ "resolveJsonModule": true, "isolatedModules": true, "noEmit": true, - "jsx": "react-jsx" - } + "jsx": "react-jsx", + }, }