Merge main into proxy foundation

This commit is contained in:
Gregory Schier
2026-05-07 14:16:35 -07:00
39 changed files with 512 additions and 346 deletions

View File

@@ -37,7 +37,7 @@ pub struct MethodDefinition {
static SERIALIZE_OPTIONS: &'static SerializeOptions =
&SerializeOptions::new().skip_default_fields(false).stringify_64_bit_integers(false);
pub fn serialize_message(msg: &DynamicMessage) -> Result<String, String> {
pub(crate) fn serialize_dynamic_message_json(msg: &DynamicMessage) -> Result<String, String> {
let mut buf = Vec::new();
let mut se = serde_json::Serializer::pretty(&mut buf);
msg.serialize_with_options(&mut se, SERIALIZE_OPTIONS).map_err(|e| e.to_string())?;

View File

@@ -2,7 +2,8 @@ use crate::codec::DynamicCodec;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::reflection::{
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path, reflect_types_for_message,
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path,
reflect_types_for_dynamic_message, reflect_types_for_message,
};
use crate::transport::get_transport;
use crate::{MethodDefinition, ServiceDefinition, json_schema};
@@ -11,8 +12,11 @@ use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::{info, warn};
pub use prost_reflect::DynamicMessage;
use prost_reflect::ReflectMessage;
use prost_reflect::prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
use serde_json::Deserializer;
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;
@@ -115,6 +119,38 @@ impl GrpcConnection {
Ok(client.unary(req, path, codec).await?)
}
pub async fn serialize_message(
&self,
message: &DynamicMessage,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<String> {
let message = if self.use_reflection {
reflect_types_for_dynamic_message(
self.pool.clone(),
&self.uri,
message,
metadata,
client_cert,
)
.await?;
let message_name = message.descriptor().full_name().to_string();
let message_desc = {
let pool = self.pool.read().await;
pool.get_message_by_name(&message_name)
.ok_or(GenericError(format!("Failed to find message {message_name}")))?
};
let mut message_with_updated_pool = DynamicMessage::new(message_desc);
message_with_updated_pool.merge(message.encode_to_vec().as_slice())?;
Cow::Owned(message_with_updated_pool)
} else {
Cow::Borrowed(message)
};
crate::serialize_dynamic_message_json(message.as_ref()).map_err(GenericError)
}
pub async fn streaming<F>(
&self,
service: &str,

View File

@@ -7,7 +7,7 @@ use anyhow::anyhow;
use async_recursion::async_recursion;
use log::{debug, info, warn};
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, ReflectMessage, Value};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use std::collections::{BTreeMap, HashSet};
use std::env::temp_dir;
@@ -233,6 +233,83 @@ pub(crate) async fn reflect_types_for_message(
Ok(())
}
pub(crate) async fn reflect_types_for_dynamic_message(
pool: Arc<RwLock<DescriptorPool>>,
uri: &Uri,
message: &DynamicMessage,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<()> {
let mut extra_types = HashSet::new();
collect_any_types_from_dynamic_message(message, &mut extra_types);
if extra_types.is_empty() {
return Ok(());
}
let mut client = AutoReflectionClient::new(uri, false, client_cert)?;
for extra_type in extra_types {
{
let guard = pool.read().await;
if guard.get_message_by_name(&extra_type).is_some() {
continue;
}
}
info!("Adding response file descriptor for {:?} from reflection", extra_type);
let req = MessageRequest::FileContainingSymbol(extra_type.clone().into());
let resp = match client.send_reflection_request(req, metadata).await {
Ok(r) => r,
Err(e) => {
return Err(GenericError(format!(
"Error sending reflection request for response @type \"{extra_type}\": {e:?}",
)));
}
};
let files = match resp {
MessageResponse::FileDescriptorResponse(resp) => resp.file_descriptor_proto,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
{
let mut guard = pool.write().await;
add_file_descriptors_to_pool(files, &mut *guard, &mut client, metadata).await;
}
}
Ok(())
}
fn collect_any_types_from_dynamic_message(message: &DynamicMessage, out: &mut HashSet<String>) {
if message.descriptor().full_name() == "google.protobuf.Any" {
if let Some(Value::String(type_url)) = message.get_field_by_name("type_url").as_deref() {
if let Some(full_name) = type_url.rsplit_once('/').map(|(_, name)| name) {
out.insert(full_name.to_string());
}
}
}
for (_, value) in message.fields() {
collect_any_types_from_value(value, out);
}
}
fn collect_any_types_from_value(value: &Value, out: &mut HashSet<String>) {
match value {
Value::Message(message) => collect_any_types_from_dynamic_message(message, out),
Value::List(values) => {
for value in values {
collect_any_types_from_value(value, out);
}
}
Value::Map(values) => {
for value in values.values() {
collect_any_types_from_value(value, out);
}
}
_ => {}
}
}
#[async_recursion]
pub(crate) async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>,

View File

@@ -1,11 +1,36 @@
use crate::dns::LocalhostResolver;
use crate::error::Result;
use log::{debug, info, warn};
use reqwest::{Client, Proxy, redirect};
use reqwest::{Client, ClientBuilder, Proxy, redirect};
use std::sync::Arc;
use yaak_models::models::DnsOverride;
use yaak_tls::{ClientCertificateConfig, get_tls_config};
pub const HTTP2_MAX_RESPONSE_HEADER_LIST_SIZE: u32 = 1024 * 1024;
fn client_builder() -> ClientBuilder {
Client::builder().http2_max_header_list_size(HTTP2_MAX_RESPONSE_HEADER_LIST_SIZE)
}
#[derive(Clone)]
pub struct ConfiguredClient {
inner: Client,
}
impl ConfiguredClient {
pub(crate) fn build_default() -> Result<Self> {
Ok(Self { inner: client_builder().build()? })
}
pub(crate) fn from_inner(inner: Client) -> Self {
Self { inner }
}
pub(crate) fn inner(&self) -> &Client {
&self.inner
}
}
/// Build a native-tls connector for maximum compatibility when certificate
/// validation is disabled. Unlike rustls, native-tls uses the OS TLS stack
/// (Secure Transport on macOS, SChannel on Windows, OpenSSL on Linux) which
@@ -87,8 +112,8 @@ impl HttpConnectionOptions {
/// Build a reqwest Client and return it along with the DNS resolver.
/// The resolver is returned separately so it can be configured per-request
/// to emit DNS timing events to the appropriate channel.
pub(crate) fn build_client(&self) -> Result<(Client, Arc<LocalhostResolver>)> {
let mut client = Client::builder()
pub(crate) fn build_client(&self) -> Result<(ConfiguredClient, Arc<LocalhostResolver>)> {
let mut client = client_builder()
.connection_verbose(true)
.redirect(redirect::Policy::none())
// Decompression is handled by HttpTransaction, not reqwest
@@ -108,8 +133,7 @@ impl HttpConnectionOptions {
client = client.use_preconfigured_tls(config);
} else {
// Use native TLS for maximum compatibility (supports TLS 1.0+)
let connector =
build_native_tls_connector(self.client_certificate.clone())?;
let connector = build_native_tls_connector(self.client_certificate.clone())?;
client = client.use_preconfigured_tls(connector);
}
@@ -136,7 +160,7 @@ impl HttpConnectionOptions {
self.client_certificate.is_some()
);
Ok((client.build()?, resolver))
Ok((ConfiguredClient::from_inner(client.build()?), resolver))
}
}

View File

@@ -124,6 +124,30 @@ impl CookieStore {
}
}
/// Get a stored cookie value by name, optionally scoped to an exact stored domain.
pub fn get_cookie_value_from_jar(
cookies: impl IntoIterator<Item = Cookie>,
name: &str,
domain: Option<&str>,
) -> Option<String> {
let domain = domain.and_then(normalize_cookie_domain_filter);
cookies.into_iter().find_map(|cookie| {
let (cookie_name, value) = parse_cookie_name_value(&cookie.raw_cookie)?;
if cookie_name != name {
return None;
}
if let Some(domain) = domain.as_deref() {
if !cookie_domain_matches_filter(&cookie.domain, domain) {
return None;
}
}
Some(value)
})
}
/// Parse name=value from a cookie string (raw_cookie format)
fn parse_cookie_name_value(raw_cookie: &str) -> Option<(String, String)> {
// The raw_cookie typically looks like "name=value" or "name=value; attr1; attr2=..."
@@ -135,6 +159,20 @@ fn parse_cookie_name_value(raw_cookie: &str) -> Option<(String, String)> {
if name.is_empty() { None } else { Some((name, value)) }
}
fn normalize_cookie_domain_filter(domain: &str) -> Option<String> {
let domain = domain.trim().trim_start_matches('.').to_lowercase();
if domain.is_empty() { None } else { Some(domain) }
}
fn cookie_domain_matches_filter(cookie_domain: &CookieDomain, domain: &str) -> bool {
match cookie_domain {
CookieDomain::HostOnly(cookie_domain) | CookieDomain::Suffix(cookie_domain) => {
normalize_cookie_domain_filter(cookie_domain).is_some_and(|d| d == domain)
}
CookieDomain::NotPresent | CookieDomain::Empty => false,
}
}
/// Parse a Set-Cookie header into a Cookie
fn parse_set_cookie(header_value: &str, request_url: &Url) -> Option<Cookie> {
let parsed = cookie::Cookie::parse(header_value).ok()?;
@@ -278,6 +316,15 @@ fn is_localhost(domain: &str) -> bool {
mod tests {
use super::*;
fn cookie(raw_cookie: &str, domain: CookieDomain) -> Cookie {
Cookie {
raw_cookie: raw_cookie.to_string(),
domain,
expires: CookieExpires::SessionEnd,
path: ("/".to_string(), false),
}
}
#[test]
fn test_parse_cookie_name_value() {
assert_eq!(
@@ -387,6 +434,52 @@ mod tests {
assert_eq!(store.get_all_cookies().len(), 1);
}
#[test]
fn test_get_cookie_value_preserves_name_only_first_match() {
let cookies = vec![
cookie("co-auth=", CookieDomain::HostOnly("foo.example.com".to_string())),
cookie("co-auth=token", CookieDomain::Suffix("example.com".to_string())),
];
assert_eq!(get_cookie_value_from_jar(cookies, "co-auth", None), Some("".to_string()));
}
#[test]
fn test_get_cookie_value_matches_domain() {
let cookies = vec![
cookie("co-auth=", CookieDomain::HostOnly("foo.example.com".to_string())),
cookie("co-auth=token", CookieDomain::Suffix("example.com".to_string())),
];
assert_eq!(
get_cookie_value_from_jar(cookies, "co-auth", Some("example.com")),
Some("token".to_string())
);
}
#[test]
fn test_get_cookie_value_normalizes_domain_filter() {
let cookies = vec![cookie(
"co-auth=token",
CookieDomain::Suffix("Example.COM".to_string()),
)];
assert_eq!(
get_cookie_value_from_jar(cookies, "co-auth", Some(" .example.com ")),
Some("token".to_string())
);
}
#[test]
fn test_get_cookie_value_requires_exact_stored_domain_match() {
let cookies = vec![cookie(
"co-auth=token",
CookieDomain::HostOnly("foo.example.com".to_string()),
)];
assert_eq!(get_cookie_value_from_jar(cookies, "co-auth", Some("example.com")), None);
}
#[test]
fn test_is_single_component_domain() {
// Single-component domains (TLDs)

View File

@@ -1,7 +1,6 @@
use crate::client::HttpConnectionOptions;
use crate::client::{ConfiguredClient, HttpConnectionOptions};
use crate::dns::LocalhostResolver;
use crate::error::Result;
use reqwest::Client;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -10,7 +9,7 @@ use tokio::sync::RwLock;
/// A cached HTTP client along with its DNS resolver.
/// The resolver is needed to set the event sender per-request.
pub struct CachedClient {
pub client: Client,
pub client: ConfiguredClient,
pub resolver: Arc<LocalhostResolver>,
}

View File

@@ -5,7 +5,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use futures_util::StreamExt;
use http_body::{Body as HttpBody, Frame, SizeHint};
use reqwest::{Client, Method, Version};
use reqwest::{Method, Version};
use std::fmt::Display;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -411,18 +411,18 @@ pub trait HttpSender: Send + Sync {
/// Reqwest-based implementation of HttpSender
pub struct ReqwestSender {
client: Client,
client: crate::client::ConfiguredClient,
}
impl ReqwestSender {
/// Create a new ReqwestSender with a default client
pub fn new() -> Result<Self> {
let client = Client::builder().build().map_err(Error::Client)?;
let client = crate::client::ConfiguredClient::build_default()?;
Ok(Self { client })
}
/// Create a new ReqwestSender with a custom client
pub fn with_client(client: Client) -> Self {
/// Create a new ReqwestSender with a configured client
pub fn with_client(client: crate::client::ConfiguredClient) -> Self {
Self { client }
}
}
@@ -444,7 +444,7 @@ impl HttpSender for ReqwestSender {
.map_err(|e| Error::RequestError(format!("Invalid HTTP method: {}", e)))?;
// Build the request
let mut req_builder = self.client.request(method, &request.url);
let mut req_builder = self.client.inner().request(method, &request.url);
// Add headers
for header in request.headers {
@@ -513,7 +513,7 @@ impl HttpSender for ReqwestSender {
send_event(HttpResponseEvent::Info("Sending request to server".to_string()));
// Map some errors to our own, so they look nicer
let response = self.client.execute(sendable_req).await.map_err(|e| {
let response = self.client.inner().execute(sendable_req).await.map_err(|e| {
if reqwest::Error::is_timeout(&e) {
Error::RequestTimeout(
request.options.timeout.unwrap_or(Duration::from_secs(0)).clone(),

View File

@@ -226,10 +226,8 @@ async fn build_body(
let (body, content_type) = match body_type.as_str() {
"binary" => (build_binary_body(&body).await?, None),
"graphql" => (build_graphql_body(&method, &body), Some("application/json".to_string())),
"application/x-www-form-urlencoded" => {
(build_form_body(&body), Some("application/x-www-form-urlencoded".to_string()))
}
"graphql" => (build_graphql_body(&method, &body), None),
"application/x-www-form-urlencoded" => (build_form_body(&body), None),
"multipart/form-data" => build_multipart_body(&body, &headers).await?,
_ if body.contains_key("text") => (build_text_body(&body, body_type), None),
t => {

View File

@@ -144,9 +144,10 @@ export function duplicateModel<M extends AnyModel["model"], T extends ExtractMod
throw new Error("Failed to duplicate null model");
}
// If the model has a name, try to duplicate it with a name that doesn't conflict
let name = "name" in model ? resolvedModelName(model) : undefined;
if (name != null) {
// If the model has an explicit (non-empty) name, try to duplicate it with a name that doesn't conflict.
// When the name is empty, keep it empty so the display falls back to the URL.
let name = "name" in model ? model.name : undefined;
if (name) {
const existingModels = listModels(model.model);
for (let i = 0; i < 100; i++) {
const hasConflict = existingModels.some((m) => {

View File

@@ -396,7 +396,7 @@ description?: string, };
export type GenericCompletionOption = { label: string, detail?: string, info?: string, type?: CompletionOptionType, boost?: number, };
export type GetCookieValueRequest = { name: string, };
export type GetCookieValueRequest = { name: string, domain?: string | null, };
export type GetCookieValueResponse = { value: string | null, };

View File

@@ -307,6 +307,9 @@ pub struct ListCookieNamesResponse {
#[ts(export, export_to = "gen_events.ts")]
pub struct GetCookieValueRequest {
pub name: String,
#[ts(optional = nullable)]
pub domain: Option<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, TS)]

View File

@@ -1,5 +1,5 @@
/* tslint:disable */
/* eslint-disable */
export function unescape_template(template: string): any;
export function escape_template(template: string): any;
export function parse_template(template: string): any;
export function unescape_template(template: string): any;

View File

@@ -161,6 +161,20 @@ function takeFromExternrefTable0(idx) {
wasm.__externref_table_dealloc(idx);
return value;
}
/**
* @param {string} template
* @returns {any}
*/
export function unescape_template(template) {
const ptr0 = passStringToWasm0(template, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
const len0 = WASM_VECTOR_LEN;
const ret = wasm.unescape_template(ptr0, len0);
if (ret[2]) {
throw takeFromExternrefTable0(ret[1]);
}
return takeFromExternrefTable0(ret[0]);
}
/**
* @param {string} template
* @returns {any}
@@ -189,20 +203,6 @@ export function parse_template(template) {
return takeFromExternrefTable0(ret[0]);
}
/**
* @param {string} template
* @returns {any}
*/
export function unescape_template(template) {
const ptr0 = passStringToWasm0(template, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
const len0 = WASM_VECTOR_LEN;
const ret = wasm.unescape_template(ptr0, len0);
if (ret[2]) {
throw takeFromExternrefTable0(ret[1]);
}
return takeFromExternrefTable0(ret[0]);
}
export function __wbg_new_405e22f390576ce2() {
const ret = new Object();
return ret;

Binary file not shown.