mirror of
https://github.com/mountain-loop/yaak.git
synced 2026-03-31 06:23:08 +02:00
Add initial yaak-proxy crate
This commit is contained in:
114
crates/yaak-proxy/src/body.rs
Normal file
114
crates/yaak-proxy/src/body.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::mpsc as std_mpsc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
|
||||
use bytes::Bytes;
|
||||
use hyper::body::{Body, Frame};
|
||||
|
||||
use crate::ProxyEvent;
|
||||
|
||||
/// A body wrapper that emits `ResponseBodyChunk` per frame and
|
||||
/// `ResponseBodyComplete` when the stream finishes.
|
||||
pub struct MeasuredBody<B> {
|
||||
inner: B,
|
||||
request_id: u64,
|
||||
bytes_count: u64,
|
||||
chunks: Vec<Bytes>,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
start: Instant,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
impl<B> MeasuredBody<B> {
|
||||
pub fn new(
|
||||
inner: B,
|
||||
request_id: u64,
|
||||
start: Instant,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
request_id,
|
||||
bytes_count: 0,
|
||||
chunks: Vec::new(),
|
||||
event_tx,
|
||||
start,
|
||||
finished: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn send_complete(&mut self) {
|
||||
if !self.finished {
|
||||
self.finished = true;
|
||||
let body = if self.chunks.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut buf = Vec::with_capacity(self.bytes_count as usize);
|
||||
for chunk in self.chunks.drain(..) {
|
||||
buf.extend_from_slice(&chunk);
|
||||
}
|
||||
Some(buf)
|
||||
};
|
||||
let _ = self.event_tx.send(ProxyEvent::ResponseBodyComplete {
|
||||
id: self.request_id,
|
||||
body,
|
||||
size: self.bytes_count,
|
||||
elapsed_ms: self.start.elapsed().as_millis() as u64,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Body for MeasuredBody<B>
|
||||
where
|
||||
B: Body<Data = Bytes> + Unpin,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
type Data = Bytes;
|
||||
type Error = B::Error;
|
||||
|
||||
fn poll_frame(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
|
||||
let inner = Pin::new(&mut self.inner);
|
||||
match inner.poll_frame(cx) {
|
||||
Poll::Ready(Some(Ok(frame))) => {
|
||||
if let Some(data) = frame.data_ref() {
|
||||
let len = data.len();
|
||||
self.bytes_count += len as u64;
|
||||
self.chunks.push(data.clone());
|
||||
let _ = self.event_tx.send(ProxyEvent::ResponseBodyChunk {
|
||||
id: self.request_id,
|
||||
bytes: len,
|
||||
});
|
||||
}
|
||||
Poll::Ready(Some(Ok(frame)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
self.send_complete();
|
||||
Poll::Ready(Some(Err(e)))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
self.send_complete();
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_end_stream(&self) -> bool {
|
||||
self.inner.is_end_stream()
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> hyper::body::SizeHint {
|
||||
self.inner.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Drop for MeasuredBody<B> {
|
||||
fn drop(&mut self) {
|
||||
self.send_complete();
|
||||
}
|
||||
}
|
||||
82
crates/yaak-proxy/src/cert.rs
Normal file
82
crates/yaak-proxy/src/cert.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use rcgen::{BasicConstraints, Certificate, CertificateParams, IsCa, KeyPair, KeyUsagePurpose};
|
||||
use rustls::ServerConfig;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
|
||||
pub struct CertificateAuthority {
|
||||
ca_cert: Certificate,
|
||||
ca_cert_der: CertificateDer<'static>,
|
||||
ca_key: KeyPair,
|
||||
cache: Mutex<HashMap<String, Arc<ServerConfig>>>,
|
||||
}
|
||||
|
||||
impl CertificateAuthority {
|
||||
pub fn new() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut params = CertificateParams::default();
|
||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
params.key_usages.push(KeyUsagePurpose::KeyCertSign);
|
||||
params.key_usages.push(KeyUsagePurpose::CrlSign);
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, "Debug Proxy CA");
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::OrganizationName, "Debug Proxy");
|
||||
|
||||
let key = KeyPair::generate()?;
|
||||
let ca_cert = params.self_signed(&key)?;
|
||||
let ca_cert_der = ca_cert.der().clone();
|
||||
|
||||
Ok(Self {
|
||||
ca_cert,
|
||||
ca_cert_der,
|
||||
ca_key: key,
|
||||
cache: Mutex::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ca_pem(&self) -> String {
|
||||
pem::encode(&pem::Pem::new("CERTIFICATE", self.ca_cert_der.to_vec()))
|
||||
}
|
||||
|
||||
pub fn server_config(
|
||||
&self,
|
||||
domain: &str,
|
||||
) -> Result<Arc<ServerConfig>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
{
|
||||
let cache = self.cache.lock().unwrap();
|
||||
if let Some(config) = cache.get(domain) {
|
||||
return Ok(config.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()])?;
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let leaf_key = KeyPair::generate()?;
|
||||
let leaf_cert = params.signed_by(&leaf_key, &self.ca_cert, &self.ca_key)?;
|
||||
|
||||
let cert_der = leaf_cert.der().clone();
|
||||
let key_der = leaf_key.serialize_der();
|
||||
|
||||
let mut config = ServerConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
vec![cert_der, self.ca_cert_der.clone()],
|
||||
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der)),
|
||||
)?;
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
|
||||
let config = Arc::new(config);
|
||||
self.cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(domain.to_string(), config.clone());
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
32
crates/yaak-proxy/src/connection.rs
Normal file
32
crates/yaak-proxy/src/connection.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::sync::mpsc as std_mpsc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::ProxyEvent;
|
||||
use crate::cert::CertificateAuthority;
|
||||
use crate::request::handle_request;
|
||||
|
||||
pub(crate) async fn handle_connection(
|
||||
stream: TcpStream,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
ca: Arc<CertificateAuthority>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tx = event_tx.clone();
|
||||
http1::Builder::new()
|
||||
.preserve_header_case(true)
|
||||
.title_case_headers(true)
|
||||
.serve_connection(
|
||||
hyper_util::rt::TokioIo::new(stream),
|
||||
service_fn(move |req| {
|
||||
let tx = tx.clone();
|
||||
let ca = ca.clone();
|
||||
async move { handle_request(req, tx, ca).await }
|
||||
}),
|
||||
)
|
||||
.with_upgrades()
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
168
crates/yaak-proxy/src/lib.rs
Normal file
168
crates/yaak-proxy/src/lib.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
pub mod body;
|
||||
pub mod cert;
|
||||
mod connection;
|
||||
mod request;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::mpsc as std_mpsc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cert::CertificateAuthority;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use connection::handle_connection;
|
||||
|
||||
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
/// Granular events emitted during request/response lifecycle.
|
||||
/// Each event carries a request `id` so consumers can correlate events.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyEvent {
|
||||
/// A new request has been received from the client.
|
||||
RequestStart {
|
||||
id: u64,
|
||||
method: String,
|
||||
url: String,
|
||||
http_version: String,
|
||||
},
|
||||
/// A request header sent to the upstream server.
|
||||
RequestHeader { id: u64, name: String, value: String },
|
||||
/// The full request body (buffered before forwarding).
|
||||
RequestBody { id: u64, body: Vec<u8> },
|
||||
/// Response headers received from upstream.
|
||||
ResponseStart {
|
||||
id: u64,
|
||||
status: u16,
|
||||
http_version: String,
|
||||
elapsed_ms: u64,
|
||||
},
|
||||
/// A response header received from the upstream server.
|
||||
ResponseHeader { id: u64, name: String, value: String },
|
||||
/// A chunk of the response body was received (emitted per-frame).
|
||||
ResponseBodyChunk { id: u64, bytes: usize },
|
||||
/// The response body stream has completed.
|
||||
ResponseBodyComplete {
|
||||
id: u64,
|
||||
body: Option<Vec<u8>>,
|
||||
size: u64,
|
||||
elapsed_ms: u64,
|
||||
},
|
||||
/// The upstream request failed.
|
||||
Error { id: u64, error: String },
|
||||
}
|
||||
|
||||
/// Accumulated view of a proxied request, built from `ProxyEvent`s.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CapturedRequest {
|
||||
pub id: u64,
|
||||
pub method: String,
|
||||
pub url: String,
|
||||
pub status: Option<u16>,
|
||||
pub elapsed_ms: Option<u64>,
|
||||
pub http_version: String,
|
||||
pub remote_http_version: Option<String>,
|
||||
pub request_headers: Vec<(String, String)>,
|
||||
pub request_body: Option<Vec<u8>>,
|
||||
pub response_headers: Vec<(String, String)>,
|
||||
pub response_body: Option<Vec<u8>>,
|
||||
pub response_body_size: u64,
|
||||
pub state: RequestState,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum RequestState {
|
||||
Sending,
|
||||
Receiving,
|
||||
Complete,
|
||||
Error,
|
||||
}
|
||||
|
||||
pub struct ProxyHandle {
|
||||
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
thread_handle: Option<std::thread::JoinHandle<()>>,
|
||||
pub event_rx: std_mpsc::Receiver<ProxyEvent>,
|
||||
pub port: u16,
|
||||
pub ca_pem: String,
|
||||
}
|
||||
|
||||
impl Drop for ProxyHandle {
|
||||
fn drop(&mut self) {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
if let Some(handle) = self.thread_handle.take() {
|
||||
let _ = handle.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_proxy(port: u16) -> Result<ProxyHandle, String> {
|
||||
let ca = CertificateAuthority::new().map_err(|e| format!("Failed to create CA: {e}"))?;
|
||||
let ca_pem = ca.ca_pem();
|
||||
let ca = Arc::new(ca);
|
||||
|
||||
let (event_tx, event_rx) = std_mpsc::channel();
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
|
||||
let (ready_tx, ready_rx) = std_mpsc::channel();
|
||||
|
||||
let thread_handle = std::thread::spawn(move || {
|
||||
let rt = match tokio::runtime::Runtime::new() {
|
||||
Ok(rt) => rt,
|
||||
Err(e) => {
|
||||
let _ = ready_tx.send(Err(format!("Failed to create runtime: {e}")));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
rt.block_on(async move {
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||
let listener = match TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = ready_tx.send(Err(format!("Failed to bind: {e}")));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let bound_port = listener.local_addr().unwrap().port();
|
||||
let _ = ready_tx.send(Ok(bound_port));
|
||||
|
||||
let mut shutdown_rx = shutdown_rx;
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, _addr)) => {
|
||||
let tx = event_tx.clone();
|
||||
let ca = ca.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(stream, tx, ca).await {
|
||||
eprintln!("Connection error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => eprintln!("Accept error: {e}"),
|
||||
}
|
||||
}
|
||||
_ = &mut shutdown_rx => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
match ready_rx.recv() {
|
||||
Ok(Ok(bound_port)) => Ok(ProxyHandle {
|
||||
shutdown_tx: Some(shutdown_tx),
|
||||
thread_handle: Some(thread_handle),
|
||||
event_rx,
|
||||
port: bound_port,
|
||||
ca_pem,
|
||||
}),
|
||||
Ok(Err(e)) => Err(e),
|
||||
Err(_) => Err("Proxy thread died before binding".into()),
|
||||
}
|
||||
}
|
||||
390
crates/yaak-proxy/src/request.rs
Normal file
390
crates/yaak-proxy/src/request.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc as std_mpsc;
|
||||
use std::time::Instant;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::HeaderMap;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode, Uri};
|
||||
use hyper_util::client::legacy::Client;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto;
|
||||
use rustls::ClientConfig;
|
||||
use rustls::pki_types::ServerName;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::body::MeasuredBody;
|
||||
use crate::cert::CertificateAuthority;
|
||||
use crate::{ProxyEvent, REQUEST_ID};
|
||||
|
||||
type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
|
||||
|
||||
fn full_body(bytes: Bytes) -> BoxBody {
|
||||
Full::new(bytes).map_err(|never| match never {}).boxed()
|
||||
}
|
||||
|
||||
fn measured_incoming(
|
||||
incoming: Incoming,
|
||||
id: u64,
|
||||
start: Instant,
|
||||
tx: std_mpsc::Sender<ProxyEvent>,
|
||||
) -> BoxBody {
|
||||
MeasuredBody::new(incoming, id, start, tx).boxed()
|
||||
}
|
||||
|
||||
fn version_str(v: hyper::Version) -> String {
|
||||
match v {
|
||||
hyper::Version::HTTP_09 => "HTTP/0.9",
|
||||
hyper::Version::HTTP_10 => "HTTP/1.0",
|
||||
hyper::Version::HTTP_11 => "HTTP/1.1",
|
||||
hyper::Version::HTTP_2 => "HTTP/2",
|
||||
hyper::Version::HTTP_3 => "HTTP/3",
|
||||
_ => "unknown",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn emit_request_events(
|
||||
tx: &std_mpsc::Sender<ProxyEvent>,
|
||||
id: u64,
|
||||
headers: &HeaderMap,
|
||||
body: &Option<Vec<u8>>,
|
||||
) {
|
||||
for (name, value) in headers.iter() {
|
||||
let _ = tx.send(ProxyEvent::RequestHeader {
|
||||
id,
|
||||
name: name.to_string(),
|
||||
value: value.to_str().unwrap_or("<binary>").to_string(),
|
||||
});
|
||||
}
|
||||
if let Some(body) = body {
|
||||
let _ = tx.send(ProxyEvent::RequestBody {
|
||||
id,
|
||||
body: body.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_response_events(
|
||||
tx: &std_mpsc::Sender<ProxyEvent>,
|
||||
id: u64,
|
||||
resp: &Response<Incoming>,
|
||||
start: &Instant,
|
||||
) {
|
||||
let _ = tx.send(ProxyEvent::ResponseStart {
|
||||
id,
|
||||
status: resp.status().as_u16(),
|
||||
http_version: version_str(resp.version()),
|
||||
elapsed_ms: start.elapsed().as_millis() as u64,
|
||||
});
|
||||
for (name, value) in resp.headers().iter() {
|
||||
let _ = tx.send(ProxyEvent::ResponseHeader {
|
||||
id,
|
||||
name: name.to_string(),
|
||||
value: value.to_str().unwrap_or("<binary>").to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_request(
|
||||
req: Request<Incoming>,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
ca: Arc<CertificateAuthority>,
|
||||
) -> Result<Response<BoxBody>, Infallible> {
|
||||
let result = if req.method() == Method::CONNECT {
|
||||
handle_connect(req, event_tx, ca).await
|
||||
} else {
|
||||
handle_http(req, event_tx).await
|
||||
};
|
||||
match result {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => {
|
||||
eprintln!("Proxy error: {e}");
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::BAD_GATEWAY)
|
||||
.body(full_body(Bytes::from(format!("Proxy error: {e}"))))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_http(
|
||||
req: Request<Incoming>,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
) -> Result<Response<BoxBody>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let id = REQUEST_ID.fetch_add(1, Ordering::Relaxed);
|
||||
let method = req.method().to_string();
|
||||
let uri = req.uri().to_string();
|
||||
let http_version = version_str(req.version());
|
||||
let start = Instant::now();
|
||||
|
||||
let _ = event_tx.send(ProxyEvent::RequestStart {
|
||||
id,
|
||||
method,
|
||||
url: uri.clone(),
|
||||
http_version,
|
||||
});
|
||||
|
||||
let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build_http();
|
||||
|
||||
let (parts, body) = req.into_parts();
|
||||
let body_bytes = body.collect().await?.to_bytes();
|
||||
let request_body = if body_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(body_bytes.to_vec())
|
||||
};
|
||||
emit_request_events(&event_tx, id, &parts.headers, &request_body);
|
||||
|
||||
let outgoing_req = Request::from_parts(parts, Full::new(body_bytes));
|
||||
|
||||
match client.request(outgoing_req).await {
|
||||
Ok(resp) => {
|
||||
emit_response_events(&event_tx, id, &resp, &start);
|
||||
|
||||
let (parts, body) = resp.into_parts();
|
||||
Ok(Response::from_parts(
|
||||
parts,
|
||||
measured_incoming(body, id, start, event_tx),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = event_tx.send(ProxyEvent::Error {
|
||||
id,
|
||||
error: e.to_string(),
|
||||
});
|
||||
Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
req: Request<Incoming>,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
ca: Arc<CertificateAuthority>,
|
||||
) -> Result<Response<BoxBody>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let authority = req
|
||||
.uri()
|
||||
.authority()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_default();
|
||||
let (host, port) = parse_host_port(&authority);
|
||||
|
||||
let server_config = ca.server_config(&host)?;
|
||||
let acceptor = TlsAcceptor::from(server_config);
|
||||
|
||||
let target_addr = format!("{host}:{port}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = match hyper::upgrade::on(req).await {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
eprintln!("CONNECT upgrade failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let tls_stream = match acceptor
|
||||
.accept(hyper_util::rt::TokioIo::new(upgraded))
|
||||
.await
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
eprintln!("TLS accept failed for {host}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let tx = event_tx.clone();
|
||||
let host_for_requests = host.clone();
|
||||
let mut builder = auto::Builder::new(TokioExecutor::new());
|
||||
builder
|
||||
.http1()
|
||||
.preserve_header_case(true)
|
||||
.title_case_headers(true);
|
||||
if let Err(e) = builder
|
||||
.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(tls_stream),
|
||||
service_fn(move |req| {
|
||||
let tx = tx.clone();
|
||||
let host = host_for_requests.clone();
|
||||
let target_addr = target_addr.clone();
|
||||
async move { handle_tunneled_request(req, tx, &host, &target_addr).await }
|
||||
}),
|
||||
)
|
||||
.await
|
||||
{
|
||||
eprintln!("MITM connection error for {host}: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Response::new(full_body(Bytes::new())))
|
||||
}
|
||||
|
||||
async fn handle_tunneled_request(
|
||||
req: Request<Incoming>,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
host: &str,
|
||||
target_addr: &str,
|
||||
) -> Result<Response<BoxBody>, Infallible> {
|
||||
let result = forward_https(req, event_tx, host, target_addr).await;
|
||||
match result {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => {
|
||||
eprintln!("HTTPS forward error: {e:?}");
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::BAD_GATEWAY)
|
||||
.body(full_body(Bytes::from(format!("Proxy error: {e}"))))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum HttpSender {
|
||||
H1(hyper::client::conn::http1::SendRequest<Full<Bytes>>),
|
||||
H2(hyper::client::conn::http2::SendRequest<Full<Bytes>>),
|
||||
}
|
||||
|
||||
impl HttpSender {
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
req: Request<Full<Bytes>>,
|
||||
) -> Result<Response<Incoming>, hyper::Error> {
|
||||
match self {
|
||||
HttpSender::H1(s) => s.send_request(req).await,
|
||||
HttpSender::H2(s) => s.send_request(req).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn forward_https(
|
||||
req: Request<Incoming>,
|
||||
event_tx: std_mpsc::Sender<ProxyEvent>,
|
||||
host: &str,
|
||||
target_addr: &str,
|
||||
) -> Result<Response<BoxBody>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let id = REQUEST_ID.fetch_add(1, Ordering::Relaxed);
|
||||
let method = req.method().to_string();
|
||||
let http_version = version_str(req.version());
|
||||
let path = req
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|pq| pq.to_string())
|
||||
.unwrap_or_else(|| "/".into());
|
||||
let uri_str = format!("https://{host}{path}");
|
||||
let start = Instant::now();
|
||||
|
||||
let _ = event_tx.send(ProxyEvent::RequestStart {
|
||||
id,
|
||||
method,
|
||||
url: uri_str.clone(),
|
||||
http_version,
|
||||
});
|
||||
|
||||
// Connect to upstream with TLS
|
||||
let tcp_stream = TcpStream::connect(target_addr).await?;
|
||||
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
for cert in rustls_native_certs::load_native_certs().certs {
|
||||
let _ = root_store.add(cert);
|
||||
}
|
||||
|
||||
let mut tls_config =
|
||||
ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));
|
||||
let server_name = ServerName::try_from(host.to_string())?;
|
||||
let tls_stream = connector.connect(server_name, tcp_stream).await?;
|
||||
|
||||
let negotiated_h2 = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.alpn_protocol()
|
||||
.map_or(false, |p| p == b"h2");
|
||||
|
||||
let io = hyper_util::rt::TokioIo::new(tls_stream);
|
||||
|
||||
let mut sender = if negotiated_h2 {
|
||||
let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.handshake(io)
|
||||
.await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
eprintln!("Upstream h2 connection error: {e}");
|
||||
}
|
||||
});
|
||||
HttpSender::H2(sender)
|
||||
} else {
|
||||
let (sender, conn) = hyper::client::conn::http1::Builder::new()
|
||||
.preserve_header_case(true)
|
||||
.title_case_headers(true)
|
||||
.handshake(io)
|
||||
.await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
eprintln!("Upstream h1 connection error: {e}");
|
||||
}
|
||||
});
|
||||
HttpSender::H1(sender)
|
||||
};
|
||||
|
||||
// Capture request metadata
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let body_bytes = body.collect().await?.to_bytes();
|
||||
let request_body = if body_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(body_bytes.to_vec())
|
||||
};
|
||||
emit_request_events(&event_tx, id, &parts.headers, &request_body);
|
||||
|
||||
if negotiated_h2 {
|
||||
// HTTP/2 requires absolute-form URI with scheme + authority
|
||||
parts.uri = uri_str.parse::<Uri>()?;
|
||||
} else {
|
||||
parts.uri = path.parse::<Uri>()?;
|
||||
}
|
||||
|
||||
if !parts.headers.contains_key(hyper::header::HOST) {
|
||||
parts.headers.insert(hyper::header::HOST, host.parse()?);
|
||||
}
|
||||
|
||||
let outgoing = Request::from_parts(parts, Full::new(body_bytes));
|
||||
|
||||
match sender.send_request(outgoing).await {
|
||||
Ok(resp) => {
|
||||
emit_response_events(&event_tx, id, &resp, &start);
|
||||
|
||||
let (parts, body) = resp.into_parts();
|
||||
Ok(Response::from_parts(
|
||||
parts,
|
||||
measured_incoming(body, id, start, event_tx),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = event_tx.send(ProxyEvent::Error {
|
||||
id,
|
||||
error: e.to_string(),
|
||||
});
|
||||
Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_host_port(authority: &str) -> (String, u16) {
|
||||
if let Some((host, port_str)) = authority.rsplit_once(':') {
|
||||
if let Ok(port) = port_str.parse::<u16>() {
|
||||
return (host.to_string(), port);
|
||||
}
|
||||
}
|
||||
(authority.to_string(), 443)
|
||||
}
|
||||
Reference in New Issue
Block a user