Decouple core Yaak logic from Tauri (#354)

Gregory Schier
2026-01-08 20:44:25 -08:00
committed by GitHub
parent 3bcc0b8356
commit ef80216ca1
465 changed files with 3052 additions and 6234 deletions

View File

@@ -0,0 +1,8 @@
[package]
name = "yaak-common"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
serde_json = { workspace = true }

View File

@@ -0,0 +1,2 @@
pub mod platform;
pub mod serde;

View File

@@ -0,0 +1,55 @@
use crate::platform::OperatingSystem::{Linux, MacOS, Unknown, Windows};
pub enum OperatingSystem {
Windows,
MacOS,
Linux,
Unknown,
}
pub fn get_os() -> OperatingSystem {
if cfg!(target_os = "windows") {
Windows
} else if cfg!(target_os = "macos") {
MacOS
} else if cfg!(target_os = "linux") {
Linux
} else {
Unknown
}
}
pub fn get_os_str() -> &'static str {
match get_os() {
Windows => "windows",
MacOS => "macos",
Linux => "linux",
Unknown => "unknown",
}
}
pub fn get_ua_platform() -> &'static str {
if cfg!(target_os = "windows") {
"Win"
} else if cfg!(target_os = "macos") {
"Mac"
} else if cfg!(target_os = "linux") {
"Linux"
} else {
"Unknown"
}
}
pub fn get_ua_arch() -> &'static str {
if cfg!(target_arch = "x86_64") {
"x86_64"
} else if cfg!(target_arch = "x86") {
"i386"
} else if cfg!(target_arch = "arm") {
"ARM"
} else if cfg!(target_arch = "aarch64") {
"ARM64"
} else {
"Unknown"
}
}
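
An illustrative sketch of branching on the detected OS (the default_shell helper is hypothetical, not part of this commit):

use yaak_common::platform::{get_os, OperatingSystem};

fn default_shell() -> &'static str {
    match get_os() {
        OperatingSystem::Windows => "powershell",
        _ => "sh",
    }
}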

View File

@@ -0,0 +1,23 @@
use serde_json::Value;
use std::collections::BTreeMap;
pub fn get_bool(v: &Value, key: &str, fallback: bool) -> bool {
match v.get(key) {
None => fallback,
Some(v) => v.as_bool().unwrap_or(fallback),
}
}
pub fn get_str<'a>(v: &'a Value, key: &str) -> &'a str {
match v.get(key) {
None => "",
Some(v) => v.as_str().unwrap_or_default(),
}
}
pub fn get_str_map<'a>(v: &'a BTreeMap<String, Value>, key: &str) -> &'a str {
match v.get(key) {
None => "",
Some(v) => v.as_str().unwrap_or_default(),
}
}
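
A minimal usage sketch (the JSON values are hypothetical):

use serde_json::json;

let v = json!({ "enabled": true, "name": "My Request" });
assert!(get_bool(&v, "enabled", false));
assert_eq!(get_str(&v, "name"), "My Request");
assert_eq!(get_str(&v, "missing"), ""); // absent keys fall back to ""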

View File

@@ -0,0 +1,9 @@
[package]
name = "yaak-core"
version = "0.0.0"
edition = "2024"
authors = ["Gregory Schier"]
publish = false
[dependencies]
thiserror = { workspace = true }

View File

@@ -0,0 +1,56 @@
use std::path::PathBuf;
/// Context for a workspace operation.
///
/// In Tauri, this is extracted from the WebviewWindow URL.
/// In CLI, this is constructed from command arguments or config.
#[derive(Debug, Clone, Default)]
pub struct WorkspaceContext {
pub workspace_id: Option<String>,
pub environment_id: Option<String>,
pub cookie_jar_id: Option<String>,
pub request_id: Option<String>,
}
impl WorkspaceContext {
pub fn new() -> Self {
Self::default()
}
pub fn with_workspace(mut self, workspace_id: impl Into<String>) -> Self {
self.workspace_id = Some(workspace_id.into());
self
}
pub fn with_environment(mut self, environment_id: impl Into<String>) -> Self {
self.environment_id = Some(environment_id.into());
self
}
pub fn with_cookie_jar(mut self, cookie_jar_id: impl Into<String>) -> Self {
self.cookie_jar_id = Some(cookie_jar_id.into());
self
}
pub fn with_request(mut self, request_id: impl Into<String>) -> Self {
self.request_id = Some(request_id.into());
self
}
}
/// Application context trait for accessing app-level resources.
///
/// This abstracts over Tauri's `AppHandle` for path resolution and app identity.
/// Implemented by Tauri's AppHandle and by CLI's own context struct.
pub trait AppContext: Send + Sync + Clone {
/// Returns the path to the application data directory.
/// This is where the database and other persistent data are stored.
fn app_data_dir(&self) -> PathBuf;
/// Returns the application identifier (e.g., "app.yaak.desktop").
/// Used for keyring access and other platform-specific features.
fn app_identifier(&self) -> &str;
/// Returns true if running in development mode.
fn is_dev(&self) -> bool;
}
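
As a hedged sketch of a non-Tauri implementor (the CliContext struct, its field, and its identifier are hypothetical, not part of this commit):

#[derive(Clone)]
struct CliContext {
    data_dir: PathBuf,
}

impl AppContext for CliContext {
    fn app_data_dir(&self) -> PathBuf {
        self.data_dir.clone()
    }
    fn app_identifier(&self) -> &str {
        "app.yaak.cli" // hypothetical identifier for a CLI build
    }
    fn is_dev(&self) -> bool {
        cfg!(debug_assertions)
    }
}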

View File

@@ -0,0 +1,15 @@
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
#[error("Missing required context: {0}")]
MissingContext(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}

View File

@@ -0,0 +1,10 @@
//! Core abstractions for Yaak that work without Tauri.
//!
//! This crate provides foundational types and traits that allow Yaak's
//! business logic to run in both Tauri (desktop app) and CLI contexts.
mod context;
mod error;
pub use context::{AppContext, WorkspaceContext};
pub use error::{Error, Result};

View File

@@ -0,0 +1,15 @@
[package]
name = "yaak-crypto"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
base32 = "0.5.1" # For encoding human-readable key
base64 = "0.22.1" # For encoding in the database
chacha20poly1305 = "0.10.1"
keyring = { workspace = true, features = ["apple-native", "windows-native", "sync-secret-service"] }
log = { workspace = true }
serde = { workspace = true, features = ["derive"] }
thiserror = { workspace = true }
yaak-models = { workspace = true }

View File

@@ -0,0 +1,13 @@
import { invoke } from '@tauri-apps/api/core';
export function enableEncryption(workspaceId: string) {
return invoke<void>('cmd_enable_encryption', { workspaceId });
}
export function revealWorkspaceKey(workspaceId: string) {
return invoke<string>('cmd_reveal_workspace_key', { workspaceId });
}
export function setWorkspaceKey(args: { workspaceId: string; key: string }) {
return invoke<void>('cmd_set_workspace_key', args);
}

View File

@@ -0,0 +1,6 @@
{
"name": "@yaakapp-internal/crypto",
"private": true,
"version": "1.0.0",
"main": "index.ts"
}

View File

@@ -0,0 +1,98 @@
use crate::error::Error::{DecryptionError, EncryptionError, InvalidEncryptedData};
use crate::error::Result;
use chacha20poly1305::aead::generic_array::typenum::Unsigned;
use chacha20poly1305::aead::{Aead, AeadCore, Key, KeyInit, OsRng};
use chacha20poly1305::XChaCha20Poly1305;
const ENCRYPTION_TAG: &str = "yA4k3nC";
const ENCRYPTION_VERSION: u8 = 1;
pub(crate) fn encrypt_data(data: &[u8], key: &Key<XChaCha20Poly1305>) -> Result<Vec<u8>> {
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let cipher = XChaCha20Poly1305::new(key);
let ciphered_data = cipher.encrypt(&nonce, data).map_err(|_| EncryptionError)?;
let mut data: Vec<u8> = Vec::new();
data.extend_from_slice(ENCRYPTION_TAG.as_bytes()); // Tag
data.push(ENCRYPTION_VERSION); // Version
data.extend_from_slice(nonce.as_slice()); // Nonce
data.extend_from_slice(&ciphered_data); // Ciphertext
Ok(data)
}
pub(crate) fn decrypt_data(cipher_data: &[u8], key: &Key<XChaCha20Poly1305>) -> Result<Vec<u8>> {
// Layout: Tag + Version + Nonce + ... ciphertext ...
let (tag, rest) =
cipher_data.split_at_checked(ENCRYPTION_TAG.len()).ok_or(InvalidEncryptedData)?;
if tag != ENCRYPTION_TAG.as_bytes() {
return Err(InvalidEncryptedData);
}
let (version, rest) = rest.split_at_checked(1).ok_or(InvalidEncryptedData)?;
if version[0] != ENCRYPTION_VERSION {
return Err(InvalidEncryptedData);
}
let nonce_bytes = <XChaCha20Poly1305 as AeadCore>::NonceSize::to_usize();
let (nonce, ciphered_data) = rest.split_at_checked(nonce_bytes).ok_or(InvalidEncryptedData)?;
let cipher = XChaCha20Poly1305::new(key);
cipher.decrypt(nonce.into(), ciphered_data).map_err(|_e| DecryptionError)
}
#[cfg(test)]
mod test {
use crate::encryption::{decrypt_data, encrypt_data};
use crate::error::Error::InvalidEncryptedData;
use crate::error::Result;
use chacha20poly1305::aead::OsRng;
use chacha20poly1305::{KeyInit, XChaCha20Poly1305};
#[test]
fn test_encrypt_decrypt() -> Result<()> {
let key = XChaCha20Poly1305::generate_key(OsRng);
let encrypted = encrypt_data("hello world".as_bytes(), &key)?;
let decrypted = decrypt_data(encrypted.as_slice(), &key)?;
assert_eq!(String::from_utf8(decrypted).unwrap(), "hello world");
Ok(())
}
#[test]
fn test_decrypt_empty() -> Result<()> {
let key = XChaCha20Poly1305::generate_key(OsRng);
let encrypted = encrypt_data(&[], &key)?;
// 7-byte tag + 1-byte version + 24-byte nonce + 16-byte AEAD tag = 48 bytes
assert_eq!(encrypted.len(), 48);
let decrypted = decrypt_data(encrypted.as_slice(), &key)?;
assert_eq!(String::from_utf8(decrypted).unwrap(), "");
Ok(())
}
#[test]
fn test_decrypt_bad_version() -> Result<()> {
let key = XChaCha20Poly1305::generate_key(OsRng);
let mut encrypted = encrypt_data("hello world".as_bytes(), &key)?;
encrypted[7] = 0; // corrupt the version byte (offset 7, right after the 7-byte tag)
let decrypted = decrypt_data(encrypted.as_slice(), &key);
assert!(matches!(decrypted, Err(InvalidEncryptedData)));
Ok(())
}
#[test]
fn test_decrypt_bad_tag() -> Result<()> {
let key = XChaCha20Poly1305::generate_key(OsRng);
let mut encrypted = encrypt_data("hello world".as_bytes(), &key)?;
encrypted[0] = 2; // corrupt the first byte of the tag
let decrypted = decrypt_data(encrypted.as_slice(), &key);
assert!(matches!(decrypted, Err(InvalidEncryptedData)));
Ok(())
}
#[test]
fn test_decrypt_unencrypted_data() -> Result<()> {
let key = XChaCha20Poly1305::generate_key(OsRng);
let decrypted = decrypt_data("123".as_bytes(), &key);
assert!(matches!(decrypted, Err(InvalidEncryptedData)));
Ok(())
}
}

View File

@@ -0,0 +1,50 @@
use serde::{Serialize, Serializer};
use std::io;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
DbError(#[from] yaak_models::error::Error),
#[error("Keyring error: {0}")]
KeyringError(#[from] keyring::Error),
#[error("Missing workspace encryption key")]
MissingWorkspaceKey,
#[error("Incorrect workspace key")]
IncorrectWorkspaceKey,
#[error("Failed to decrypt workspace key: {0}")]
WorkspaceKeyDecryptionError(String),
#[error("Crypto IO error: {0}")]
IoError(#[from] io::Error),
#[error("Failed to encrypt data")]
EncryptionError,
#[error("Failed to decrypt data")]
DecryptionError,
#[error("Invalid encrypted data")]
InvalidEncryptedData,
#[error("Invalid key provided")]
InvalidHumanKey,
#[error("Encryption error: {0}")]
GenericError(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,7 @@
extern crate core;
pub mod encryption;
pub mod error;
pub mod manager;
mod master_key;
mod workspace_key;

View File

@@ -0,0 +1,158 @@
use crate::error::Error::{
GenericError, IncorrectWorkspaceKey, MissingWorkspaceKey, WorkspaceKeyDecryptionError,
};
use crate::error::{Error, Result};
use crate::master_key::MasterKey;
use crate::workspace_key::WorkspaceKey;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use log::{info, warn};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use yaak_models::models::{EncryptedKey, Workspace, WorkspaceMeta};
use yaak_models::query_manager::QueryManager;
use yaak_models::util::{generate_id_of_length, UpdateSource};
const KEY_USER: &str = "encryption-key";
#[derive(Debug, Clone)]
pub struct EncryptionManager {
cached_master_key: Arc<Mutex<Option<MasterKey>>>,
cached_workspace_keys: Arc<Mutex<HashMap<String, WorkspaceKey>>>,
query_manager: QueryManager,
app_id: String,
}
impl EncryptionManager {
pub fn new(query_manager: QueryManager, app_id: impl Into<String>) -> Self {
Self {
cached_master_key: Default::default(),
cached_workspace_keys: Default::default(),
query_manager,
app_id: app_id.into(),
}
}
pub fn encrypt(&self, workspace_id: &str, data: &[u8]) -> Result<Vec<u8>> {
let workspace_secret = self.get_workspace_key(workspace_id)?;
workspace_secret.encrypt(data)
}
pub fn decrypt(&self, workspace_id: &str, data: &[u8]) -> Result<Vec<u8>> {
let workspace_secret = self.get_workspace_key(workspace_id)?;
workspace_secret.decrypt(data)
}
pub fn reveal_workspace_key(&self, workspace_id: &str) -> Result<String> {
let key = self.get_workspace_key(workspace_id)?;
key.to_human()
}
pub fn set_human_key(&self, workspace_id: &str, human_key: &str) -> Result<WorkspaceMeta> {
let wkey = WorkspaceKey::from_human(human_key)?;
let workspace = self.query_manager.connect().get_workspace(workspace_id)?;
let encryption_key_challenge = match workspace.encryption_key_challenge {
None => return self.set_workspace_key(workspace_id, &wkey),
Some(c) => c,
};
let encryption_key_challenge = match BASE64_STANDARD.decode(encryption_key_challenge) {
Ok(c) => c,
Err(_) => return Err(GenericError("Failed to decode workspace challenge".to_string())),
};
if wkey.decrypt(encryption_key_challenge.as_slice()).is_err() {
return Err(IncorrectWorkspaceKey);
}
self.set_workspace_key(workspace_id, &wkey)
}
pub(crate) fn set_workspace_key(
&self,
workspace_id: &str,
wkey: &WorkspaceKey,
) -> Result<WorkspaceMeta> {
info!("Created workspace key for {workspace_id}");
let encrypted_key = BASE64_STANDARD.encode(self.get_master_key()?.encrypt(wkey.raw_key())?);
let encrypted_key = EncryptedKey { encrypted_key };
let encryption_key_challenge = wkey.encrypt(generate_id_of_length(50).as_bytes())?;
let encryption_key_challenge = Some(BASE64_STANDARD.encode(encryption_key_challenge));
let workspace_meta = self.query_manager.with_tx::<WorkspaceMeta, Error>(|tx| {
let workspace = tx.get_workspace(workspace_id)?;
let workspace_meta = tx.get_or_create_workspace_meta(workspace_id)?;
tx.upsert_workspace(
&Workspace { encryption_key_challenge, ..workspace },
&UpdateSource::Background,
)?;
Ok(tx.upsert_workspace_meta(
&WorkspaceMeta { encryption_key: Some(encrypted_key.clone()), ..workspace_meta },
&UpdateSource::Background,
)?)
})?;
let mut cache = self.cached_workspace_keys.lock().unwrap();
cache.insert(workspace_id.to_string(), wkey.clone());
Ok(workspace_meta)
}
pub fn ensure_workspace_key(&self, workspace_id: &str) -> Result<WorkspaceMeta> {
let workspace_meta =
self.query_manager.connect().get_or_create_workspace_meta(workspace_id)?;
// Already exists
if workspace_meta.encryption_key.is_some() {
warn!("Tried to create workspace key when one already exists for {workspace_id}");
return Ok(workspace_meta);
}
let wkey = WorkspaceKey::create()?;
self.set_workspace_key(workspace_id, &wkey)
}
fn get_workspace_key(&self, workspace_id: &str) -> Result<WorkspaceKey> {
{
let cache = self.cached_workspace_keys.lock().unwrap();
if let Some(k) = cache.get(workspace_id) {
return Ok(k.clone());
}
};
let db = self.query_manager.connect();
let workspace_meta = db.get_or_create_workspace_meta(workspace_id)?;
let key = match workspace_meta.encryption_key {
None => return Err(MissingWorkspaceKey),
Some(k) => k,
};
let mkey = self.get_master_key()?;
let decoded_key = BASE64_STANDARD
.decode(key.encrypted_key)
.map_err(|e| WorkspaceKeyDecryptionError(e.to_string()))?;
let raw_key = mkey
.decrypt(decoded_key.as_slice())
.map_err(|e| WorkspaceKeyDecryptionError(e.to_string()))?;
let wkey = WorkspaceKey::from_raw_key(raw_key.as_slice());
Ok(wkey)
}
fn get_master_key(&self) -> Result<MasterKey> {
// NOTE: This locks the key for the entire function which seems wrong, but this prevents
// concurrent access from prompting the user for a keychain password multiple times.
let mut master_secret = self.cached_master_key.lock().unwrap();
if let Some(k) = master_secret.as_ref() {
return Ok(k.to_owned());
}
let mkey = MasterKey::get_or_create(&self.app_id, KEY_USER)?;
*master_secret = Some(mkey.clone());
Ok(mkey)
}
}
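
A rough usage sketch, assuming a QueryManager was constructed elsewhere (the workspace ID and plaintext are illustrative):

let manager = EncryptionManager::new(query_manager, "app.yaak.desktop");
manager.ensure_workspace_key("wk_123")?;
let ciphertext = manager.encrypt("wk_123", b"secret value")?;
assert_eq!(manager.decrypt("wk_123", &ciphertext)?, b"secret value");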

View File

@@ -0,0 +1,79 @@
use crate::encryption::{decrypt_data, encrypt_data};
use crate::error::Error::GenericError;
use crate::error::Result;
use base32::Alphabet;
use chacha20poly1305::aead::{Key, KeyInit, OsRng};
use chacha20poly1305::XChaCha20Poly1305;
use keyring::{Entry, Error};
use log::info;
const HUMAN_PREFIX: &str = "YKM_";
#[derive(Debug, Clone)]
pub(crate) struct MasterKey {
key: Key<XChaCha20Poly1305>,
}
impl MasterKey {
pub(crate) fn get_or_create(app_id: &str, user: &str) -> Result<Self> {
let id = format!("{app_id}.EncryptionKey");
let entry = Entry::new(&id, user)?;
let key = match entry.get_password() {
Ok(encoded) => {
let without_prefix = encoded.strip_prefix(HUMAN_PREFIX).unwrap_or(&encoded);
let key_bytes = base32::decode(Alphabet::Crockford {}, without_prefix)
.ok_or(GenericError("Failed to decode master key".to_string()))?;
Key::<XChaCha20Poly1305>::clone_from_slice(key_bytes.as_slice())
}
Err(Error::NoEntry) => {
info!("Creating new master key");
let key = XChaCha20Poly1305::generate_key(OsRng);
let encoded = base32::encode(Alphabet::Crockford {}, key.as_slice());
let with_prefix = format!("{HUMAN_PREFIX}{encoded}");
entry.set_password(&with_prefix)?;
key
}
Err(e) => return Err(GenericError(e.to_string())),
};
Ok(Self { key })
}
pub(crate) fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
encrypt_data(data, &self.key)
}
pub(crate) fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
decrypt_data(data, &self.key)
}
#[cfg(test)]
pub(crate) fn test_key() -> Self {
let key: Key<XChaCha20Poly1305> = Key::<XChaCha20Poly1305>::clone_from_slice(
"00000000000000000000000000000000".as_bytes(),
);
Self { key }
}
}
#[cfg(test)]
mod tests {
use crate::error::Result;
use crate::master_key::MasterKey;
#[test]
fn test_master_key() -> Result<()> {
// Test out the master key
let mkey = MasterKey::test_key();
let encrypted = mkey.encrypt("hello".as_bytes())?;
let decrypted = mkey.decrypt(encrypted.as_slice()).unwrap();
assert_eq!(decrypted, "hello".as_bytes().to_vec());
let mkey = MasterKey::test_key();
let decrypted = mkey.decrypt(encrypted.as_slice()).unwrap();
assert_eq!(decrypted, "hello".as_bytes().to_vec());
Ok(())
}
}

View File

@@ -0,0 +1,114 @@
use crate::encryption::{decrypt_data, encrypt_data};
use crate::error::Error::InvalidHumanKey;
use crate::error::Result;
use base32::Alphabet;
use chacha20poly1305::aead::{Key, KeyInit, OsRng};
use chacha20poly1305::{KeySizeUser, XChaCha20Poly1305};
#[derive(Debug, Clone)]
pub struct WorkspaceKey {
key: Key<XChaCha20Poly1305>,
}
const HUMAN_PREFIX: &str = "YK";
impl WorkspaceKey {
pub(crate) fn to_human(&self) -> Result<String> {
let encoded = base32::encode(Alphabet::Crockford {}, self.key.as_slice());
let with_prefix = format!("{HUMAN_PREFIX}{encoded}");
let with_separators = with_prefix
.chars()
.collect::<Vec<_>>()
.chunks(6)
.map(|chunk| chunk.iter().collect::<String>())
.collect::<Vec<_>>()
.join("-");
Ok(with_separators)
}
pub(crate) fn from_human(human_key: &str) -> Result<Self> {
let without_prefix = human_key.strip_prefix(HUMAN_PREFIX).unwrap_or(human_key);
let without_separators = without_prefix.replace("-", "");
let key =
base32::decode(Alphabet::Crockford {}, &without_separators).ok_or(InvalidHumanKey)?;
if key.len() != XChaCha20Poly1305::key_size() {
return Err(InvalidHumanKey);
}
Ok(Self::from_raw_key(key.as_slice()))
}
pub(crate) fn from_raw_key(key: &[u8]) -> Self {
Self { key: Key::<XChaCha20Poly1305>::clone_from_slice(key) }
}
pub(crate) fn raw_key(&self) -> &[u8] {
self.key.as_slice()
}
pub(crate) fn create() -> Result<Self> {
let key = XChaCha20Poly1305::generate_key(OsRng);
Ok(Self::from_raw_key(key.as_slice()))
}
pub(crate) fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
encrypt_data(data, &self.key)
}
pub(crate) fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
decrypt_data(data, &self.key)
}
#[cfg(test)]
pub(crate) fn test_key() -> Self {
Self::from_raw_key("f1a2d4b3c8e799af1456be3478a4c3f2".as_bytes())
}
}
#[cfg(test)]
mod tests {
use crate::error::Error::InvalidHumanKey;
use crate::error::Result;
use crate::workspace_key::WorkspaceKey;
#[test]
fn test_persisted_key() -> Result<()> {
let key = WorkspaceKey::test_key();
let encrypted = key.encrypt("hello".as_bytes())?;
assert_eq!(key.decrypt(encrypted.as_slice())?, "hello".as_bytes());
Ok(())
}
#[test]
fn test_human_format() -> Result<()> {
let key = WorkspaceKey::test_key();
let encrypted = key.encrypt("hello".as_bytes())?;
assert_eq!(key.decrypt(encrypted.as_slice())?, "hello".as_bytes());
let human = key.to_human()?;
assert_eq!(human, "YKCRRP-2CK46H-H36RSR-CMVKJE-B1CRRK-8D9PC9-JK6D1Q-71GK8R-SKCRS0");
assert_eq!(
WorkspaceKey::from_human(&human)?.decrypt(encrypted.as_slice())?,
"hello".as_bytes()
);
Ok(())
}
#[test]
fn test_from_human_invalid() -> Result<()> {
assert!(matches!(
WorkspaceKey::from_human(
"YKCRRP-2CK46H-H36RSR-CMVKJE-B1CRRK-8D9PC9-JK6D1Q-71GK8R-SKCRS0-H3X38D",
),
Err(InvalidHumanKey)
));
assert!(matches!(WorkspaceKey::from_human("bad-key",), Err(InvalidHumanKey)));
assert!(matches!(WorkspaceKey::from_human("",), Err(InvalidHumanKey)));
Ok(())
}
}

View File

@@ -0,0 +1,18 @@
[package]
name = "yaak-git"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
chrono = { workspace = true, features = ["serde"] }
git2 = { version = "0.20.0", features = ["vendored-libgit2", "vendored-openssl"] }
log = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
serde_yaml = "0.9.34"
thiserror = { workspace = true }
ts-rs = { workspace = true, features = ["chrono-impl", "serde-json-impl"] }
url = "2"
yaak-models = { workspace = true }
yaak-sync = { workspace = true }

crates/yaak-git/bindings/gen_git.ts generated Normal file
View File

@@ -0,0 +1,18 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { SyncModel } from "./gen_models";
export type GitAuthor = { name: string | null, email: string | null, };
export type GitCommit = { author: GitAuthor, when: string, message: string | null, };
export type GitRemote = { name: string, url: string | null, };
export type GitStatus = "untracked" | "conflict" | "current" | "modified" | "removed" | "renamed" | "type_change";
export type GitStatusEntry = { relaPath: string, status: GitStatus, staged: boolean, prev: SyncModel | null, next: SyncModel | null, };
export type GitStatusSummary = { path: string, headRef: string | null, headRefShorthand: string | null, entries: Array<GitStatusEntry>, origins: Array<string>, localBranches: Array<string>, remoteBranches: Array<string>, };
export type PullResult = { "type": "success", message: string, } | { "type": "up_to_date" } | { "type": "needs_credentials", url: string, error: string | null, };
export type PushResult = { "type": "success", message: string, } | { "type": "up_to_date" } | { "type": "needs_credentials", url: string, error: string | null, };

crates/yaak-git/bindings/gen_models.ts generated Normal file
View File

@@ -0,0 +1,21 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
export type Environment = { model: "environment", id: string, workspaceId: string, createdAt: string, updatedAt: string, name: string, public: boolean, parentModel: string, parentId: string | null, variables: Array<EnvironmentVariable>, color: string | null, sortPriority: number, };
export type EnvironmentVariable = { enabled?: boolean, name: string, value: string, id?: string, };
export type Folder = { model: "folder", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authentication: Record<string, any>, authenticationType: string | null, description: string, headers: Array<HttpRequestHeader>, name: string, sortPriority: number, };
export type GrpcRequest = { model: "grpc_request", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authenticationType: string | null, authentication: Record<string, any>, description: string, message: string, metadata: Array<HttpRequestHeader>, method: string | null, name: string, service: string | null, sortPriority: number, url: string, };
export type HttpRequest = { model: "http_request", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authentication: Record<string, any>, authenticationType: string | null, body: Record<string, any>, bodyType: string | null, description: string, headers: Array<HttpRequestHeader>, method: string, name: string, sortPriority: number, url: string, urlParameters: Array<HttpUrlParameter>, };
export type HttpRequestHeader = { enabled?: boolean, name: string, value: string, id?: string, };
export type HttpUrlParameter = { enabled?: boolean, name: string, value: string, id?: string, };
export type SyncModel = { "type": "workspace" } & Workspace | { "type": "environment" } & Environment | { "type": "folder" } & Folder | { "type": "http_request" } & HttpRequest | { "type": "grpc_request" } & GrpcRequest | { "type": "websocket_request" } & WebsocketRequest;
export type WebsocketRequest = { model: "websocket_request", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authentication: Record<string, any>, authenticationType: string | null, description: string, headers: Array<HttpRequestHeader>, message: string, name: string, sortPriority: number, url: string, urlParameters: Array<HttpUrlParameter>, };
export type Workspace = { model: "workspace", id: string, createdAt: string, updatedAt: string, authentication: Record<string, any>, authenticationType: string | null, description: string, headers: Array<HttpRequestHeader>, name: string, encryptionKeyChallenge: string | null, settingValidateCertificates: boolean, settingFollowRedirects: boolean, settingRequestTimeout: number, };

crates/yaak-git/index.ts Normal file
View File

@@ -0,0 +1,168 @@
import { useQuery } from '@tanstack/react-query';
import { invoke } from '@tauri-apps/api/core';
import { createFastMutation } from '@yaakapp/app/hooks/useFastMutation';
import { queryClient } from '@yaakapp/app/lib/queryClient';
import { useMemo } from 'react';
import { GitCommit, GitRemote, GitStatusSummary, PullResult, PushResult } from './bindings/gen_git';
export * from './bindings/gen_git';
export interface GitCredentials {
username: string;
password: string;
}
export interface GitCallbacks {
addRemote: () => Promise<GitRemote | null>;
promptCredentials: (
result: Extract<PushResult, { type: 'needs_credentials' }>,
) => Promise<GitCredentials | null>;
}
const onSuccess = () => queryClient.invalidateQueries({ queryKey: ['git'] });
export function useGit(dir: string, callbacks: GitCallbacks) {
const mutations = useMemo(() => gitMutations(dir, callbacks), [dir, callbacks]);
return [
{
remotes: useQuery<GitRemote[], string>({
queryKey: ['git', 'remotes', dir],
queryFn: () => getRemotes(dir),
}),
log: useQuery<GitCommit[], string>({
queryKey: ['git', 'log', dir],
queryFn: () => invoke('cmd_git_log', { dir }),
}),
status: useQuery<GitStatusSummary, string>({
refetchOnMount: true,
queryKey: ['git', 'status', dir],
queryFn: () => invoke('cmd_git_status', { dir }),
}),
},
mutations,
] as const;
}
export const gitMutations = (dir: string, callbacks: GitCallbacks) => {
const push = async () => {
const remotes = await getRemotes(dir);
if (remotes.length === 0) {
const remote = await callbacks.addRemote();
if (remote == null) throw new Error('No remote found');
}
const result = await invoke<PushResult>('cmd_git_push', { dir });
if (result.type !== 'needs_credentials') return result;
// Needs credentials, prompt for them
const creds = await callbacks.promptCredentials(result);
if (creds == null) throw new Error('Canceled');
await invoke('cmd_git_add_credential', {
dir,
remoteUrl: result.url,
username: creds.username,
password: creds.password,
});
// Push again
return invoke<PushResult>('cmd_git_push', { dir });
};
return {
init: createFastMutation<void, string, void>({
mutationKey: ['git', 'init'],
mutationFn: () => invoke('cmd_git_initialize', { dir }),
onSuccess,
}),
add: createFastMutation<void, string, { relaPaths: string[] }>({
mutationKey: ['git', 'add', dir],
mutationFn: (args) => invoke('cmd_git_add', { dir, ...args }),
onSuccess,
}),
addRemote: createFastMutation<GitRemote, string, GitRemote>({
mutationKey: ['git', 'add-remote'],
mutationFn: (args) => invoke('cmd_git_add_remote', { dir, ...args }),
onSuccess,
}),
rmRemote: createFastMutation<void, string, { name: string }>({
mutationKey: ['git', 'rm-remote', dir],
mutationFn: (args) => invoke('cmd_git_rm_remote', { dir, ...args }),
onSuccess,
}),
branch: createFastMutation<void, string, { branch: string }>({
mutationKey: ['git', 'branch', dir],
mutationFn: (args) => invoke('cmd_git_branch', { dir, ...args }),
onSuccess,
}),
mergeBranch: createFastMutation<void, string, { branch: string; force: boolean }>({
mutationKey: ['git', 'merge', dir],
mutationFn: (args) => invoke('cmd_git_merge_branch', { dir, ...args }),
onSuccess,
}),
deleteBranch: createFastMutation<void, string, { branch: string }>({
mutationKey: ['git', 'delete-branch', dir],
mutationFn: (args) => invoke('cmd_git_delete_branch', { dir, ...args }),
onSuccess,
}),
checkout: createFastMutation<string, string, { branch: string; force: boolean }>({
mutationKey: ['git', 'checkout', dir],
mutationFn: (args) => invoke('cmd_git_checkout', { dir, ...args }),
onSuccess,
}),
commit: createFastMutation<void, string, { message: string }>({
mutationKey: ['git', 'commit', dir],
mutationFn: (args) => invoke('cmd_git_commit', { dir, ...args }),
onSuccess,
}),
commitAndPush: createFastMutation<PushResult, string, { message: string }>({
mutationKey: ['git', 'commit_push', dir],
mutationFn: async (args) => {
await invoke('cmd_git_commit', { dir, ...args });
return push();
},
onSuccess,
}),
fetchAll: createFastMutation<string, string, void>({
mutationKey: ['git', 'fetch_all', dir],
mutationFn: () => invoke('cmd_git_fetch_all', { dir }),
onSuccess,
}),
push: createFastMutation<PushResult, string, void>({
mutationKey: ['git', 'push', dir],
mutationFn: push,
onSuccess,
}),
pull: createFastMutation<PullResult, string, void>({
mutationKey: ['git', 'pull', dir],
async mutationFn() {
const result = await invoke<PullResult>('cmd_git_pull', { dir });
if (result.type !== 'needs_credentials') return result;
// Needs credentials, prompt for them
const creds = await callbacks.promptCredentials(result);
if (creds == null) throw new Error('Canceled');
await invoke('cmd_git_add_credential', {
dir,
remoteUrl: result.url,
username: creds.username,
password: creds.password,
});
// Pull again
return invoke<PullResult>('cmd_git_pull', { dir });
},
onSuccess,
}),
unstage: createFastMutation<void, string, { relaPaths: string[] }>({
mutationKey: ['git', 'unstage', dir],
mutationFn: (args) => invoke('cmd_git_unstage', { dir, ...args }),
onSuccess,
}),
} as const;
};
async function getRemotes(dir: string) {
return invoke<GitRemote[]>('cmd_git_remotes', { dir });
}

View File

@@ -0,0 +1,6 @@
{
"name": "@yaakapp-internal/git",
"private": true,
"version": "1.0.0",
"main": "index.ts"
}

View File

@@ -0,0 +1,16 @@
use crate::error::Result;
use crate::repository::open_repo;
use git2::IndexAddOption;
use log::info;
use std::path::Path;
pub fn git_add(dir: &Path, rela_path: &Path) -> Result<()> {
let repo = open_repo(dir)?;
let mut index = repo.index()?;
info!("Staging file {rela_path:?} to {dir:?}");
index.add_all(&[rela_path], IndexAddOption::DEFAULT, None)?;
index.write()?;
Ok(())
}

View File

@@ -0,0 +1,38 @@
use crate::error::Result;
use std::path::Path;
use std::process::{Command, Stdio};
use crate::error::Error::GitNotFound;
#[cfg(target_os = "windows")]
use std::os::windows::process::CommandExt;
#[cfg(target_os = "windows")]
const CREATE_NO_WINDOW: u32 = 0x0800_0000;
pub(crate) fn new_binary_command(dir: &Path) -> Result<Command> {
// 1. Probe that `git` exists and is runnable
let mut probe = Command::new("git");
probe.arg("--version").stdin(Stdio::null()).stdout(Stdio::null()).stderr(Stdio::null());
#[cfg(target_os = "windows")]
{
probe.creation_flags(CREATE_NO_WINDOW);
}
let status = probe.status().map_err(|_| GitNotFound)?;
if !status.success() {
return Err(GitNotFound);
}
// 2. Build the reusable git command
let mut cmd = Command::new("git");
cmd.arg("-C").arg(dir);
#[cfg(target_os = "windows")]
{
cmd.creation_flags(CREATE_NO_WINDOW);
}
Ok(cmd)
}
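
Callers elsewhere in this crate use the returned Command roughly like this sketch (the subcommand is illustrative):

let out = new_binary_command(dir)?.args(["status", "--porcelain"]).output()?;
if !out.status.success() {
    // surface the combined stdout/stderr, as the commit/fetch/pull modules do
}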

View File

@@ -0,0 +1,99 @@
use crate::error::Error::{GenericError, NoActiveBranch};
use crate::error::Result;
use crate::merge::do_merge;
use crate::repository::open_repo;
use crate::util::{bytes_to_string, get_branch_by_name, get_current_branch};
use git2::BranchType;
use git2::build::CheckoutBuilder;
use log::info;
use std::path::Path;
pub fn git_checkout_branch(dir: &Path, branch_name: &str, force: bool) -> Result<String> {
if branch_name.starts_with("origin/") {
return git_checkout_remote_branch(dir, branch_name, force);
}
let repo = open_repo(dir)?;
let branch = get_branch_by_name(&repo, branch_name)?;
let branch_ref = branch.into_reference();
let branch_tree = branch_ref.peel_to_tree()?;
let mut options = CheckoutBuilder::default();
if force {
options.force();
}
repo.checkout_tree(branch_tree.as_object(), Some(&mut options))?;
repo.set_head(branch_ref.name().unwrap())?;
Ok(branch_name.to_string())
}
pub(crate) fn git_checkout_remote_branch(
dir: &Path,
branch_name: &str,
force: bool,
) -> Result<String> {
let branch_name = branch_name.trim_start_matches("origin/");
let repo = open_repo(dir)?;
let refname = format!("refs/remotes/origin/{}", branch_name);
let remote_ref = repo.find_reference(&refname)?;
let commit = remote_ref.peel_to_commit()?;
let mut new_branch = repo.branch(branch_name, &commit, false)?;
let upstream_name = format!("origin/{}", branch_name);
new_branch.set_upstream(Some(&upstream_name))?;
git_checkout_branch(dir, branch_name, force)
}
pub fn git_create_branch(dir: &Path, name: &str) -> Result<()> {
let repo = open_repo(dir)?;
let head = match repo.head() {
Ok(h) => h,
Err(e) if e.code() == git2::ErrorCode::UnbornBranch => {
let msg = "Cannot create branch when there are no commits";
return Err(GenericError(msg.into()));
}
Err(e) => return Err(e.into()),
};
let head = head.peel_to_commit()?;
repo.branch(name, &head, false)?;
Ok(())
}
pub fn git_delete_branch(dir: &Path, name: &str) -> Result<()> {
let repo = open_repo(dir)?;
let mut branch = get_branch_by_name(&repo, name)?;
if branch.is_head() {
info!("Deleting head branch");
let branches = repo.branches(Some(BranchType::Local))?;
let other_branch = branches.into_iter().filter_map(|b| b.ok()).find(|b| !b.0.is_head());
let other_branch = match other_branch {
None => return Err(GenericError("Cannot delete only branch".into())),
Some(b) => bytes_to_string(b.0.name_bytes()?)?,
};
git_checkout_branch(dir, &other_branch, true)?;
}
branch.delete()?;
Ok(())
}
pub fn git_merge_branch(dir: &Path, name: &str, _force: bool) -> Result<()> {
let repo = open_repo(dir)?;
let local_branch = get_current_branch(&repo)?.ok_or(NoActiveBranch)?;
let commit_to_merge = get_branch_by_name(&repo, name)?.into_reference();
let commit_to_merge = repo.reference_to_annotated_commit(&commit_to_merge)?;
do_merge(&repo, &local_branch, &commit_to_merge)?;
Ok(())
}

View File

@@ -0,0 +1,20 @@
use crate::binary::new_binary_command;
use crate::error::Error::GenericError;
use log::info;
use std::path::Path;
pub fn git_commit(dir: &Path, message: &str) -> crate::error::Result<()> {
let out = new_binary_command(dir)?.args(["commit", "--message", message]).output()?;
let stdout = String::from_utf8_lossy(&out.stdout);
let stderr = String::from_utf8_lossy(&out.stderr);
let combined = stdout + stderr;
if !out.status.success() {
return Err(GenericError(format!("Failed to commit: {}", combined)));
}
info!("Committed to {dir:?}");
Ok(())
}

View File

@@ -0,0 +1,47 @@
use crate::binary::new_binary_command;
use crate::error::Error::GenericError;
use crate::error::Result;
use std::io::Write;
use std::path::Path;
use std::process::Stdio;
use url::Url;
pub async fn git_add_credential(
dir: &Path,
remote_url: &str,
username: &str,
password: &str,
) -> Result<()> {
let url = Url::parse(remote_url)
.map_err(|e| GenericError(format!("Failed to parse remote url {remote_url}: {e:?}")))?;
let protocol = url.scheme();
let host = url.host_str().unwrap();
let path = Some(url.path());
let mut child = new_binary_command(dir)?
.args(["credential", "approve"])
.stdin(Stdio::piped())
.stdout(Stdio::null())
.spawn()?;
{
let stdin = child.stdin.as_mut().unwrap();
writeln!(stdin, "protocol={}", protocol)?;
writeln!(stdin, "host={}", host)?;
if let Some(path) = path {
if !path.is_empty() {
writeln!(stdin, "path={}", path.trim_start_matches('/'))?;
}
}
writeln!(stdin, "username={}", username)?;
writeln!(stdin, "password={}", password)?;
writeln!(stdin)?; // blank line terminator
}
let status = child.wait()?;
if !status.success() {
return Err(GenericError("Failed to approve git credential".to_string()));
}
Ok(())
}
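
For reference, the lines piped to `git credential approve` follow git's key=value credential protocol and end with a blank line. For a hypothetical HTTPS remote the stream would look like:

protocol=https
host=github.com
path=user/repo.git
username=alice
password=s3cret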

View File

@@ -0,0 +1,64 @@
use serde::{Serialize, Serializer};
use std::io;
use std::path::PathBuf;
use std::string::FromUtf8Error;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Git repo not found {0}")]
GitRepoNotFound(PathBuf),
#[error("Git error: {0}")]
GitUnknown(#[from] git2::Error),
#[error("Yaml error: {0}")]
YamlParseError(#[from] serde_yaml::Error),
#[error(transparent)]
ModelError(#[from] yaak_models::error::Error),
#[error("Sync error: {0}")]
SyncError(#[from] yaak_sync::error::Error),
#[error("I/o error: {0}")]
IoError(#[from] io::Error),
#[error("JSON error: {0}")]
JsonParseError(#[from] serde_json::Error),
#[error("UTF8 error: {0}")]
Utf8ConversionError(#[from] FromUtf8Error),
#[error("Git error: {0}")]
GenericError(String),
#[error("'git' not found. Please ensure it's installed and available in $PATH")]
GitNotFound,
#[error("Credentials required: {0}")]
CredentialsRequiredError(String),
#[error("No default remote found")]
NoDefaultRemoteFound,
#[error("No remotes found for repo")]
NoRemotesFound,
#[error("Merge failed due to conflicts")]
MergeConflicts,
#[error("No active branch")]
NoActiveBranch,
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,20 @@
use crate::binary::new_binary_command;
use crate::error::Error::GenericError;
use crate::error::Result;
use std::path::Path;
pub fn git_fetch_all(dir: &Path) -> Result<()> {
let out = new_binary_command(dir)?
.args(["fetch", "--all", "--prune", "--tags"])
.output()
.map_err(|e| GenericError(format!("failed to run git pull: {e}")))?;
let stdout = String::from_utf8_lossy(&out.stdout);
let stderr = String::from_utf8_lossy(&out.stderr);
let combined = stdout + stderr;
if !out.status.success() {
return Err(GenericError(format!("Failed to fetch: {}", combined)));
}
Ok(())
}

View File

@@ -0,0 +1,14 @@
use crate::error::Result;
use crate::repository::open_repo;
use log::info;
use std::path::Path;
pub fn git_init(dir: &Path) -> Result<()> {
git2::Repository::init(dir)?;
let repo = open_repo(dir)?;
// Default to main instead of master, to align with
// the official Git and GitHub behavior
repo.set_head("refs/heads/main")?;
info!("Initialized {dir:?}");
Ok(())
}

View File

@@ -0,0 +1,31 @@
mod add;
mod binary;
mod branch;
mod commit;
mod credential;
pub mod error;
mod fetch;
mod init;
mod log;
mod merge;
mod pull;
mod push;
mod remotes;
mod repository;
mod status;
mod unstage;
mod util;
// Re-export all git functions for external use
pub use add::git_add;
pub use branch::{git_checkout_branch, git_create_branch, git_delete_branch, git_merge_branch};
pub use commit::git_commit;
pub use credential::git_add_credential;
pub use fetch::git_fetch_all;
pub use init::git_init;
pub use log::{GitCommit, git_log};
pub use pull::{PullResult, git_pull};
pub use push::{PushResult, git_push};
pub use remotes::{GitRemote, git_add_remote, git_remotes, git_rm_remote};
pub use status::{GitStatusSummary, git_status};
pub use unstage::git_unstage;

View File

@@ -0,0 +1,73 @@
use crate::repository::open_repo;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::Path;
use ts_rs::TS;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export, export_to = "gen_git.ts")]
pub struct GitCommit {
pub author: GitAuthor,
pub when: DateTime<Utc>,
pub message: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export, export_to = "gen_git.ts")]
pub struct GitAuthor {
pub name: Option<String>,
pub email: Option<String>,
}
pub fn git_log(dir: &Path) -> crate::error::Result<Vec<GitCommit>> {
let repo = open_repo(dir)?;
// Return empty if empty repo or no head (new repo)
if repo.is_empty()? || repo.head().is_err() {
return Ok(vec![]);
}
let mut revwalk = repo.revwalk()?;
revwalk.push_head()?;
revwalk.set_sorting(git2::Sort::TIME)?;
// Run git log
macro_rules! filter_try {
($e:expr) => {
match $e {
Ok(t) => t,
Err(_) => return None,
}
};
}
let log: Vec<GitCommit> = revwalk
.filter_map(|oid| {
let oid = filter_try!(oid);
let commit = filter_try!(repo.find_commit(oid));
let author = commit.author();
Some(GitCommit {
author: GitAuthor {
name: author.name().map(|s| s.to_string()),
email: author.email().map(|s| s.to_string()),
},
when: convert_git_time_to_date(author.when()),
message: commit.message().map(|m| m.to_string()),
})
})
.collect();
Ok(log)
}
// In tests, return a fixed epoch timestamp so commit logs are deterministic
#[cfg(test)]
fn convert_git_time_to_date(_git_time: git2::Time) -> DateTime<Utc> {
DateTime::from_timestamp(0, 0).unwrap()
}
#[cfg(not(test))]
fn convert_git_time_to_date(git_time: git2::Time) -> DateTime<Utc> {
let timestamp = git_time.seconds();
DateTime::from_timestamp(timestamp, 0).unwrap()
}

View File

@@ -0,0 +1,135 @@
use crate::error::Error::MergeConflicts;
use crate::util::bytes_to_string;
use git2::{AnnotatedCommit, Branch, IndexEntry, Reference, Repository};
use log::{debug, info};
pub(crate) fn do_merge(
repo: &Repository,
local_branch: &Branch,
commit_to_merge: &AnnotatedCommit,
) -> crate::error::Result<()> {
debug!("Merging remote branches");
let analysis = repo.merge_analysis(&[&commit_to_merge])?;
if analysis.0.is_fast_forward() {
let refname = bytes_to_string(local_branch.get().name_bytes())?;
match repo.find_reference(&refname) {
Ok(mut r) => {
merge_fast_forward(repo, &mut r, &commit_to_merge)?;
}
Err(_) => {
// The branch doesn't exist, so set the reference to the commit directly. Usually
// this is because you are pulling into an empty repository.
repo.reference(
&refname,
commit_to_merge.id(),
true,
&format!("Setting {} to {}", refname, commit_to_merge.id()),
)?;
repo.set_head(&refname)?;
repo.checkout_head(Some(
git2::build::CheckoutBuilder::default()
.allow_conflicts(true)
.conflict_style_merge(true)
.force(),
))?;
}
};
} else if analysis.0.is_normal() {
let head_commit = repo.reference_to_annotated_commit(&repo.head()?)?;
merge_normal(repo, &head_commit, commit_to_merge)?;
} else {
debug!("Skipping merge. Nothing to do")
}
Ok(())
}
pub(crate) fn merge_fast_forward(
repo: &Repository,
local_reference: &mut Reference,
remote_commit: &AnnotatedCommit,
) -> crate::error::Result<()> {
info!("Performing fast forward");
let name = match local_reference.name() {
Some(s) => s.to_string(),
None => String::from_utf8_lossy(local_reference.name_bytes()).to_string(),
};
let msg = format!("Fast-Forward: Setting {} to id: {}", name, remote_commit.id());
local_reference.set_target(remote_commit.id(), &msg)?;
repo.set_head(&name)?;
repo.checkout_head(Some(
git2::build::CheckoutBuilder::default()
// For some reason, force() is required to make the working directory actually get
// updated. We may eventually want logic here to handle dirty working-directory states.
.force(),
))?;
Ok(())
}
pub(crate) fn merge_normal(
repo: &Repository,
local: &AnnotatedCommit,
remote: &AnnotatedCommit,
) -> crate::error::Result<()> {
info!("Performing normal merge");
let local_tree = repo.find_commit(local.id())?.tree()?;
let remote_tree = repo.find_commit(remote.id())?.tree()?;
let ancestor = repo.find_commit(repo.merge_base(local.id(), remote.id())?)?.tree()?;
let mut idx = repo.merge_trees(&ancestor, &local_tree, &remote_tree, None)?;
if idx.has_conflicts() {
let conflicts = idx.conflicts()?;
for conflict in conflicts.flatten() {
print_conflict(&conflict);
}
return Err(MergeConflicts);
}
let result_tree = repo.find_tree(idx.write_tree_to(repo)?)?;
// now create the merge commit
let msg = format!("Merge: {} into {}", remote.id(), local.id());
let sig = repo.signature()?;
let local_commit = repo.find_commit(local.id())?;
let remote_commit = repo.find_commit(remote.id())?;
// Do our merge commit and set current branch head to that commit.
let _merge_commit = repo.commit(
Some("HEAD"),
&sig,
&sig,
&msg,
&result_tree,
&[&local_commit, &remote_commit],
)?;
// Set working tree to match head.
repo.checkout_head(None)?;
Ok(())
}
fn print_conflict(conflict: &git2::IndexConflict) {
let ancestor = conflict.ancestor.as_ref().map(path_from_index_entry);
let ours = conflict.our.as_ref().map(path_from_index_entry);
let theirs = conflict.their.as_ref().map(path_from_index_entry);
println!("Conflict detected:");
if let Some(path) = ancestor {
println!(" Common ancestor: {:?}", path);
}
if let Some(path) = ours {
println!(" Ours: {:?}", path);
}
if let Some(path) = theirs {
println!(" Theirs: {:?}", path);
}
}
fn path_from_index_entry(entry: &IndexEntry) -> String {
String::from_utf8_lossy(entry.path.as_slice()).into_owned()
}

View File

@@ -0,0 +1,95 @@
use crate::binary::new_binary_command;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::repository::open_repo;
use crate::util::{get_current_branch_name, get_default_remote_in_repo};
use log::info;
use serde::{Deserialize, Serialize};
use std::path::Path;
use ts_rs::TS;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[serde(rename_all = "snake_case", tag = "type")]
#[ts(export, export_to = "gen_git.ts")]
pub enum PullResult {
Success { message: String },
UpToDate,
NeedsCredentials { url: String, error: Option<String> },
}
pub fn git_pull(dir: &Path) -> Result<PullResult> {
let repo = open_repo(dir)?;
let branch_name = get_current_branch_name(&repo)?;
let remote = get_default_remote_in_repo(&repo)?;
let remote_name = remote.name().ok_or(GenericError("Failed to get remote name".to_string()))?;
let remote_url = remote.url().ok_or(GenericError("Failed to get remote url".to_string()))?;
let out = new_binary_command(dir)?
.args(["pull", &remote_name, &branch_name])
.env("GIT_TERMINAL_PROMPT", "0")
.output()
.map_err(|e| GenericError(format!("failed to run git pull: {e}")))?;
let stdout = String::from_utf8_lossy(&out.stdout);
let stderr = String::from_utf8_lossy(&out.stderr);
let combined = stdout + stderr;
info!("Pulled status={} {combined}", out.status);
if combined.to_lowercase().contains("could not read") {
return Ok(PullResult::NeedsCredentials { url: remote_url.to_string(), error: None });
}
if combined.to_lowercase().contains("unable to access") {
return Ok(PullResult::NeedsCredentials {
url: remote_url.to_string(),
error: Some(combined.to_string()),
});
}
if !out.status.success() {
return Err(GenericError(format!("Failed to pull {combined}")));
}
if combined.to_lowercase().contains("up to date") {
return Ok(PullResult::UpToDate);
}
Ok(PullResult::Success { message: format!("Pulled from {}/{}", remote_name, branch_name) })
}
// pub(crate) fn git_pull_old(dir: &Path) -> Result<PullResult> {
// let repo = open_repo(dir)?;
//
// let branch = get_current_branch(&repo)?.ok_or(NoActiveBranch)?;
// let branch_ref = branch.get();
// let branch_ref = bytes_to_string(branch_ref.name_bytes())?;
//
// let remote_name = repo.branch_upstream_remote(&branch_ref)?;
// let remote_name = bytes_to_string(&remote_name)?;
// debug!("Pulling from {remote_name}");
//
// let mut remote = repo.find_remote(&remote_name)?;
//
// let mut options = FetchOptions::new();
// let callbacks = default_callbacks();
// options.remote_callbacks(callbacks);
//
// let mut proxy = ProxyOptions::new();
// proxy.auto();
// options.proxy_options(proxy);
//
// remote.fetch(&[&branch_ref], Some(&mut options), None)?;
//
// let stats = remote.stats();
//
// let fetch_head = repo.find_reference("FETCH_HEAD")?;
// let fetch_commit = repo.reference_to_annotated_commit(&fetch_head)?;
// do_merge(&repo, &branch, &fetch_commit)?;
//
// Ok(PullResult::Success {
// message: "Hello".to_string(),
// // received_bytes: stats.received_bytes(),
// // received_objects: stats.received_objects(),
// })
// }

View File

@@ -0,0 +1,81 @@
use crate::binary::new_binary_command;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::repository::open_repo;
use crate::util::{get_current_branch_name, get_default_remote_for_push_in_repo};
use log::info;
use serde::{Deserialize, Serialize};
use std::path::Path;
use ts_rs::TS;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[serde(rename_all = "snake_case", tag = "type")]
#[ts(export, export_to = "gen_git.ts")]
pub enum PushResult {
Success { message: String },
UpToDate,
NeedsCredentials { url: String, error: Option<String> },
}
pub fn git_push(dir: &Path) -> Result<PushResult> {
let repo = open_repo(dir)?;
let branch_name = get_current_branch_name(&repo)?;
let remote = get_default_remote_for_push_in_repo(&repo)?;
let remote_name = remote.name().ok_or(GenericError("Failed to get remote name".to_string()))?;
let remote_url = remote.url().ok_or(GenericError("Failed to get remote url".to_string()))?;
let out = new_binary_command(dir)?
.args(["push", &remote_name, &branch_name])
.env("GIT_TERMINAL_PROMPT", "0")
.output()
.map_err(|e| GenericError(format!("failed to run git push: {e}")))?;
let stdout = String::from_utf8_lossy(&out.stdout);
let stderr = String::from_utf8_lossy(&out.stderr);
let combined = stdout + stderr;
let combined_lower = combined.to_lowercase();
info!("Pushed to repo status={} {combined}", out.status);
// Helper to check if this is a credentials error
let is_credentials_error = || {
combined_lower.contains("could not read")
|| combined_lower.contains("unable to access")
|| combined_lower.contains("authentication failed")
};
// Check for explicit rejection indicators first (e.g., protected branch rejections)
// These can occur even if some git servers don't properly set exit codes
if combined_lower.contains("rejected") || combined_lower.contains("failed to push") {
if is_credentials_error() {
return Ok(PushResult::NeedsCredentials {
url: remote_url.to_string(),
error: Some(combined.to_string()),
});
}
return Err(GenericError(format!("Failed to push: {combined}")));
}
// Check exit status for any other failures
if !out.status.success() {
if combined_lower.contains("could not read") {
return Ok(PushResult::NeedsCredentials { url: remote_url.to_string(), error: None });
}
if combined_lower.contains("unable to access")
|| combined_lower.contains("authentication failed")
{
return Ok(PushResult::NeedsCredentials {
url: remote_url.to_string(),
error: Some(combined.to_string()),
});
}
return Err(GenericError(format!("Failed to push: {combined}")));
}
// Success cases (exit code 0 and no rejection indicators)
if combined_lower.contains("up-to-date") {
return Ok(PushResult::UpToDate);
}
Ok(PushResult::Success { message: format!("Pushed to {}/{}", remote_name, branch_name) })
}

View File

@@ -0,0 +1,47 @@
use crate::error::Result;
use crate::repository::open_repo;
use log::warn;
use serde::{Deserialize, Serialize};
use std::path::Path;
use ts_rs::TS;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[ts(export, export_to = "gen_git.ts")]
pub struct GitRemote {
name: String,
url: Option<String>,
}
pub fn git_remotes(dir: &Path) -> Result<Vec<GitRemote>> {
let repo = open_repo(dir)?;
let mut remotes = Vec::new();
for remote in repo.remotes()?.into_iter() {
let name = match remote {
None => continue,
Some(name) => name,
};
let r = match repo.find_remote(name) {
Ok(r) => r,
Err(e) => {
warn!("Failed to get remote {name}: {e:?}");
continue;
}
};
remotes.push(GitRemote { name: name.to_string(), url: r.url().map(|u| u.to_string()) });
}
Ok(remotes)
}
pub fn git_add_remote(dir: &Path, name: &str, url: &str) -> Result<GitRemote> {
let repo = open_repo(dir)?;
repo.remote(name, url)?;
Ok(GitRemote { name: name.to_string(), url: Some(url.to_string()) })
}
pub fn git_rm_remote(dir: &Path, name: &str) -> Result<()> {
let repo = open_repo(dir)?;
repo.remote_delete(name)?;
Ok(())
}

View File

@@ -0,0 +1,10 @@
use crate::error::Error::{GitRepoNotFound, GitUnknown};
use std::path::Path;
pub(crate) fn open_repo(dir: &Path) -> crate::error::Result<git2::Repository> {
match git2::Repository::discover(dir) {
Ok(r) => Ok(r),
Err(e) if e.code() == git2::ErrorCode::NotFound => Err(GitRepoNotFound(dir.to_path_buf())),
Err(e) => Err(GitUnknown(e)),
}
}

View File

@@ -0,0 +1,172 @@
use crate::repository::open_repo;
use crate::util::{local_branch_names, remote_branch_names};
use log::warn;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
use ts_rs::TS;
use yaak_sync::models::SyncModel;
#[derive(Debug, Clone, Serialize, Deserialize, TS, PartialEq)]
#[serde(rename_all = "camelCase")]
#[ts(export, export_to = "gen_git.ts")]
pub struct GitStatusSummary {
pub path: String,
pub head_ref: Option<String>,
pub head_ref_shorthand: Option<String>,
pub entries: Vec<GitStatusEntry>,
pub origins: Vec<String>,
pub local_branches: Vec<String>,
pub remote_branches: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export, export_to = "gen_git.ts")]
pub struct GitStatusEntry {
pub rela_path: String,
pub status: GitStatus,
pub staged: bool,
pub prev: Option<SyncModel>,
pub next: Option<SyncModel>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TS)]
#[serde(rename_all = "snake_case")]
#[ts(export, export_to = "gen_git.ts")]
pub enum GitStatus {
Untracked,
Conflict,
Current,
Modified,
Removed,
Renamed,
TypeChange,
}
pub fn git_status(dir: &Path) -> crate::error::Result<GitStatusSummary> {
let repo = open_repo(dir)?;
let (head_tree, head_ref, head_ref_shorthand) = match repo.head() {
Ok(head) => {
let tree = head.peel_to_tree().ok();
let head_ref_shorthand = head.shorthand().map(|s| s.to_string());
let head_ref = head.name().map(|s| s.to_string());
(tree, head_ref, head_ref_shorthand)
}
Err(_) => {
// For "unborn" repos, reading from HEAD is the only way to get the branch name
// See https://github.com/starship/starship/pull/1336
let head_path = repo.path().join("HEAD");
let head_ref = fs::read_to_string(&head_path)
.ok()
.unwrap_or_default()
.lines()
.next()
.map(|s| s.trim_start_matches("ref:").trim().to_string());
let head_ref_shorthand =
head_ref.clone().map(|r| r.split('/').last().unwrap_or("unknown").to_string());
(None, head_ref, head_ref_shorthand)
}
};
let mut opts = git2::StatusOptions::new();
opts.include_ignored(false)
.include_untracked(true) // Include untracked
.recurse_untracked_dirs(true) // Show all untracked
.include_unmodified(true); // Include unchanged
// TODO: Support renames
let mut entries: Vec<GitStatusEntry> = Vec::new();
for entry in repo.statuses(Some(&mut opts))?.into_iter() {
let rela_path = entry.path().unwrap().to_string();
let status = entry.status();
let index_status = match status {
// Note: order matters here, since we're checking a bitmap!
s if s.contains(git2::Status::CONFLICTED) => GitStatus::Conflict,
s if s.contains(git2::Status::INDEX_NEW) => GitStatus::Untracked,
s if s.contains(git2::Status::INDEX_MODIFIED) => GitStatus::Modified,
s if s.contains(git2::Status::INDEX_DELETED) => GitStatus::Removed,
s if s.contains(git2::Status::INDEX_RENAMED) => GitStatus::Renamed,
s if s.contains(git2::Status::INDEX_TYPECHANGE) => GitStatus::TypeChange,
s if s.contains(git2::Status::CURRENT) => GitStatus::Current,
s => {
warn!("Unknown index status {s:?}");
continue;
}
};
let worktree_status = match status {
// Note: order matters here, since we're checking a bitmap!
s if s.contains(git2::Status::CONFLICTED) => GitStatus::Conflict,
s if s.contains(git2::Status::WT_NEW) => GitStatus::Untracked,
s if s.contains(git2::Status::WT_MODIFIED) => GitStatus::Modified,
s if s.contains(git2::Status::WT_DELETED) => GitStatus::Removed,
s if s.contains(git2::Status::WT_RENAMED) => GitStatus::Renamed,
s if s.contains(git2::Status::WT_TYPECHANGE) => GitStatus::TypeChange,
s if s.contains(git2::Status::CURRENT) => GitStatus::Current,
s => {
warn!("Unknown worktree status {s:?}");
continue;
}
};
let status = if index_status == GitStatus::Current {
worktree_status.clone()
} else {
index_status.clone()
};
// Staged when the index differs from HEAD; if both index and worktree are
// Current, there is no change to stage
let staged = index_status != GitStatus::Current;
// Get previous content from Git, if it's in there
let prev = match head_tree.clone() {
None => None,
Some(t) => match t.get_path(Path::new(&rela_path)) {
Ok(entry) => {
let obj = entry.to_object(&repo)?;
let content = obj.as_blob().unwrap().content();
let name = Path::new(entry.name().unwrap_or_default());
SyncModel::from_bytes(content.into(), name)?.map(|m| m.0)
}
Err(_) => None,
},
};
let next = {
let full_path = repo.workdir().unwrap().join(rela_path.clone());
SyncModel::from_file(full_path.as_path())?.map(|m| m.0)
};
entries.push(GitStatusEntry {
status,
staged,
rela_path,
prev: prev.clone(),
next: next.clone(),
})
}
let origins = repo.remotes()?.into_iter().filter_map(|o| Some(o?.to_string())).collect();
let local_branches = local_branch_names(&repo)?;
let remote_branches = remote_branch_names(&repo)?;
Ok(GitStatusSummary {
entries,
origins,
path: dir.to_string_lossy().to_string(),
head_ref,
head_ref_shorthand,
local_branches,
remote_branches,
})
}

View File

@@ -0,0 +1,27 @@
use crate::repository::open_repo;
use log::info;
use std::path::Path;
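/// Unstage `rela_path` in the repository at `dir`. On an unborn branch this
/// removes the path from the index; otherwise the index entry is reset to HEAD.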
pub fn git_unstage(dir: &Path, rela_path: &Path) -> crate::error::Result<()> {
let repo = open_repo(dir)?;
let head = match repo.head() {
Ok(h) => h,
Err(e) if e.code() == git2::ErrorCode::UnbornBranch => {
info!("Unstaging file in empty branch {rela_path:?} to {dir:?}");
// Repo has no commits, so "unstage" means remove from index
let mut index = repo.index()?;
index.remove_path(rela_path)?;
index.write()?;
return Ok(());
}
Err(e) => return Err(e.into()),
};
// If repo has commits, update the index entry back to HEAD
info!("Unstaging file {rela_path:?} to {dir:?}");
let commit = head.peel_to_commit()?;
repo.reset_default(Some(commit.as_object()), &[rela_path])?;
Ok(())
}

crates/yaak-git/src/util.rs Normal file
View File

@@ -0,0 +1,125 @@
use crate::error::Error::{GenericError, NoDefaultRemoteFound};
use crate::error::Result;
use git2::{Branch, BranchType, Remote, Repository};
const DEFAULT_REMOTE_NAME: &str = "origin";
pub(crate) fn get_current_branch(repo: &Repository) -> Result<Option<Branch<'_>>> {
for b in repo.branches(None)? {
let branch = b?.0;
if branch.is_head() {
return Ok(Some(branch));
}
}
Ok(None)
}
pub(crate) fn get_current_branch_name(repo: &Repository) -> Result<String> {
Ok(get_current_branch(repo)?
.ok_or(GenericError("Failed to get current branch".to_string()))?
.name()?
.ok_or(GenericError("Failed to get current branch name".to_string()))?
.to_string())
}
pub(crate) fn local_branch_names(repo: &Repository) -> Result<Vec<String>> {
let mut branches = Vec::new();
for branch in repo.branches(Some(BranchType::Local))? {
let (branch, _) = branch?;
let name = branch.name_bytes()?;
let name = bytes_to_string(name)?;
branches.push(name);
}
Ok(branches)
}
pub(crate) fn remote_branch_names(repo: &Repository) -> Result<Vec<String>> {
let mut branches = Vec::new();
for branch in repo.branches(Some(BranchType::Remote))? {
let (branch, _) = branch?;
let name = branch.name_bytes()?;
let name = bytes_to_string(name)?;
if name.ends_with("/HEAD") {
continue;
}
branches.push(name);
}
Ok(branches)
}
pub(crate) fn get_branch_by_name<'s>(repo: &'s Repository, name: &str) -> Result<Branch<'s>> {
Ok(repo.find_branch(name, BranchType::Local)?)
}
pub(crate) fn bytes_to_string(bytes: &[u8]) -> Result<String> {
Ok(String::from_utf8(bytes.to_vec())?)
}
pub(crate) fn get_default_remote_for_push_in_repo(repo: &'_ Repository) -> Result<Remote<'_>> {
let name = get_default_remote_name_for_push_in_repo(repo)?;
let remote = repo.find_remote(&name)?;
Ok(remote)
}
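/// Resolve the remote name to push to, following Git's own precedence:
/// `branch.<name>.pushRemote`, then `remote.pushDefault`, then
/// `branch.<name>.remote`, and finally the repo's default remote.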
pub(crate) fn get_default_remote_name_for_push_in_repo(repo: &Repository) -> Result<String> {
let config = repo.config()?;
let branch = get_current_branch(repo)?;
if let Some(branch) = branch {
        let branch_name = bytes_to_string(branch.name_bytes()?)?;
        let entry_name = format!("branch.{}.pushRemote", &branch_name);
if let Ok(entry) = config.get_entry(&entry_name) {
return bytes_to_string(entry.value_bytes());
}
if let Ok(entry) = config.get_entry("remote.pushDefault") {
return bytes_to_string(entry.value_bytes());
}
let entry_name = format!("branch.{}.remote", &remote_name);
if let Ok(entry) = config.get_entry(&entry_name) {
return bytes_to_string(entry.value_bytes());
}
}
get_default_remote_name_in_repo(repo)
}
pub(crate) fn get_default_remote_in_repo(repo: &'_ Repository) -> Result<Remote<'_>> {
let name = get_default_remote_name_in_repo(repo)?;
let remote = repo.find_remote(&name)?;
Ok(remote)
}
pub(crate) fn get_default_remote_name_in_repo(repo: &Repository) -> Result<String> {
let remotes = repo.remotes()?;
if remotes.is_empty() {
return Err(NoDefaultRemoteFound);
}
// if `origin` exists return that
let found_origin = remotes.iter().any(|r| r.is_some_and(|r| r == DEFAULT_REMOTE_NAME));
if found_origin {
return Ok(DEFAULT_REMOTE_NAME.into());
}
// if only one remote exists, pick that
if remotes.len() == 1 {
let first_remote = remotes
.iter()
.next()
.flatten()
.map(String::from)
.ok_or_else(|| GenericError("no remote found".into()))?;
return Ok(first_remote);
}
// inconclusive
Err(NoDefaultRemoteFound)
}
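// A minimal sketch (not part of this commit) exercising the default-remote
// resolution above. It initializes a throwaway repo under the OS temp dir;
// the directory name and URL are arbitrary, and cleanup is omitted.
#[cfg(test)]
mod remote_name_tests {
    use super::get_default_remote_name_in_repo;
    use git2::Repository;

    #[test]
    fn prefers_origin_when_present() {
        let dir = std::env::temp_dir().join(format!("yaak-git-test-{}", std::process::id()));
        let repo = Repository::init(&dir).unwrap();
        repo.remote("origin", "https://example.com/repo.git").unwrap();
        assert_eq!(get_default_remote_name_in_repo(&repo).unwrap(), "origin");
    }
}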

View File

@@ -0,0 +1,26 @@
[package]
name = "yaak-grpc"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
anyhow = "1.0.97"
async-recursion = "1.1.1"
dunce = "1.0.4"
hyper-rustls = { version = "0.27.7", default-features = false, features = ["http2"] }
hyper-util = { version = "0.1.13", default-features = false, features = ["client-legacy"] }
log = { workspace = true }
md5 = "0.7.0"
prost = "0.13.4"
prost-reflect = { version = "0.14.4", default-features = false, features = ["serde", "derive"] }
prost-types = "0.13.4"
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "fs", "process"] }
tokio-stream = "0.1.14"
tonic = { version = "0.12.3", default-features = false, features = ["transport"] }
tonic-reflection = "0.12.3"
uuid = { version = "1.7.0", features = ["v4"] }
yaak-tls = { workspace = true }
thiserror = "2.0.17"

View File

@@ -0,0 +1,60 @@
use log::error;
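/// Collect the fully-qualified type names referenced by `@type` fields
/// (i.e. `google.protobuf.Any` payloads) anywhere in the given JSON message.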
pub(crate) fn collect_any_types(json: &str, out: &mut Vec<String>) {
    let value: serde_json::Value = match serde_json::from_str(json) {
        Ok(v) => v,
        Err(e) => {
            error!("Failed to parse gRPC message JSON: {e}");
            return;
        }
    };
collect_any_types_value(&value, out);
}
fn collect_any_types_value(json: &serde_json::Value, out: &mut Vec<String>) {
match json {
serde_json::Value::Object(map) => {
if let Some(t) = map.get("@type").and_then(|v| v.as_str()) {
if let Some(full_name) = t.rsplit_once('/').map(|(_, n)| n) {
out.push(full_name.to_string());
}
}
for v in map.values() {
collect_any_types_value(v, out);
}
}
serde_json::Value::Array(arr) => {
for v in arr {
collect_any_types_value(v, out);
}
}
_ => {}
}
}
// Tests for the Any-type collection above
#[cfg(test)]
mod tests {
#[test]
fn test_collect_any_types() {
let json = r#"{
"mounts": [
{
"mountSource": {
"@type": "type.googleapis.com/mount_source.MountSourceRBDVolume",
"volumeID": "volumes/rbd"
}
}
],
"foo": {
"@type": "type.googleapis.com/foo.bar",
"foo": "fooo"
}
}"#;
let mut out = Vec::new();
super::collect_any_types(json, &mut out);
assert_eq!(out, vec!["foo.bar", "mount_source.MountSourceRBDVolume"]);
}
}

View File

@@ -0,0 +1,182 @@
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::manager::decorate_req;
use crate::transport::get_transport;
use async_recursion::async_recursion;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::debug;
use std::collections::BTreeMap;
use tokio_stream::StreamExt;
use tonic::Request;
use tonic::body::BoxBody;
use tonic::transport::Uri;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use tonic_reflection::pb::v1::{
ErrorResponse, ExtensionNumberResponse, ListServiceResponse, ServerReflectionRequest,
ServiceResponse,
};
use tonic_reflection::pb::v1::{ExtensionRequest, FileDescriptorResponse};
use tonic_reflection::pb::{v1, v1alpha};
use yaak_tls::ClientCertificateConfig;
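/// A reflection client that speaks the v1 server-reflection protocol and
/// transparently retries with v1alpha when the server responds Unimplemented.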
pub struct AutoReflectionClient<T = Client<HttpsConnector<HttpConnector>, BoxBody>> {
use_v1alpha: bool,
client_v1: v1::server_reflection_client::ServerReflectionClient<T>,
client_v1alpha: v1alpha::server_reflection_client::ServerReflectionClient<T>,
}
impl AutoReflectionClient {
pub fn new(
uri: &Uri,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Self> {
let client_v1 = v1::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(validate_certificates, client_cert.clone())?,
uri.clone(),
);
let client_v1alpha = v1alpha::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(validate_certificates, client_cert.clone())?,
uri.clone(),
);
Ok(AutoReflectionClient { use_v1alpha: false, client_v1, client_v1alpha })
}
#[async_recursion]
pub async fn send_reflection_request(
&mut self,
message: MessageRequest,
metadata: &BTreeMap<String, String>,
) -> Result<MessageResponse> {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message.clone()),
};
if self.use_v1alpha {
let mut request =
Request::new(tokio_stream::once(to_v1alpha_request(reflection_request)));
decorate_req(metadata, &mut request)?;
self.client_v1alpha
.server_reflection_info(request)
.await
.map_err(|e| match e.code() {
tonic::Code::Unavailable => {
GenericError("Failed to connect to endpoint".to_string())
}
tonic::Code::Unauthenticated => {
GenericError("Authentication failed".to_string())
}
tonic::Code::DeadlineExceeded => GenericError("Deadline exceeded".to_string()),
_ => GenericError(e.to_string()),
})?
.into_inner()
.next()
.await
.ok_or(GenericError("Missing reflection message".to_string()))??
.message_response
.ok_or(GenericError("No reflection response".to_string()))
.map(|resp| to_v1_msg_response(resp))
} else {
let mut request = Request::new(tokio_stream::once(reflection_request));
decorate_req(metadata, &mut request)?;
let resp = self.client_v1.server_reflection_info(request).await;
match resp {
Ok(r) => Ok(r),
                Err(e) => match e.code() {
tonic::Code::Unimplemented => {
// If v1 fails, change to v1alpha and try again
debug!("gRPC schema reflection falling back to v1alpha");
self.use_v1alpha = true;
return self.send_reflection_request(message, metadata).await;
}
_ => Err(e),
},
}
.map_err(|e| match e.code() {
tonic::Code::Unavailable => {
GenericError("Failed to connect to endpoint".to_string())
}
tonic::Code::Unauthenticated => GenericError("Authentication failed".to_string()),
tonic::Code::DeadlineExceeded => GenericError("Deadline exceeded".to_string()),
_ => GenericError(e.to_string()),
})?
.into_inner()
.next()
.await
.ok_or(GenericError("Missing reflection message".to_string()))??
.message_response
.ok_or(GenericError("No reflection response".to_string()))
}
}
}
fn to_v1_msg_response(
response: v1alpha::server_reflection_response::MessageResponse,
) -> MessageResponse {
match response {
v1alpha::server_reflection_response::MessageResponse::FileDescriptorResponse(v) => {
MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
file_descriptor_proto: v.file_descriptor_proto,
})
}
v1alpha::server_reflection_response::MessageResponse::AllExtensionNumbersResponse(v) => {
MessageResponse::AllExtensionNumbersResponse(ExtensionNumberResponse {
extension_number: v.extension_number,
base_type_name: v.base_type_name,
})
}
v1alpha::server_reflection_response::MessageResponse::ListServicesResponse(v) => {
MessageResponse::ListServicesResponse(ListServiceResponse {
service: v
.service
.iter()
.map(|s| ServiceResponse { name: s.name.clone() })
.collect(),
})
}
v1alpha::server_reflection_response::MessageResponse::ErrorResponse(v) => {
MessageResponse::ErrorResponse(ErrorResponse {
error_code: v.error_code,
error_message: v.error_message,
})
}
}
}
fn to_v1alpha_request(request: ServerReflectionRequest) -> v1alpha::ServerReflectionRequest {
v1alpha::ServerReflectionRequest {
host: request.host,
message_request: request.message_request.map(|m| to_v1alpha_msg_request(m)),
}
}
fn to_v1alpha_msg_request(
message: MessageRequest,
) -> v1alpha::server_reflection_request::MessageRequest {
match message {
MessageRequest::FileByFilename(v) => {
v1alpha::server_reflection_request::MessageRequest::FileByFilename(v)
}
MessageRequest::FileContainingSymbol(v) => {
v1alpha::server_reflection_request::MessageRequest::FileContainingSymbol(v)
}
MessageRequest::FileContainingExtension(ExtensionRequest {
extension_number,
containing_type,
}) => v1alpha::server_reflection_request::MessageRequest::FileContainingExtension(
v1alpha::ExtensionRequest { extension_number, containing_type },
),
MessageRequest::AllExtensionNumbersOfType(v) => {
v1alpha::server_reflection_request::MessageRequest::AllExtensionNumbersOfType(v)
}
MessageRequest::ListServices(v) => {
v1alpha::server_reflection_request::MessageRequest::ListServices(v)
}
}
}

View File

@@ -0,0 +1,50 @@
use prost_reflect::prost::Message;
use prost_reflect::{DynamicMessage, MethodDescriptor};
use tonic::Status;
use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
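/// A tonic codec that encodes and decodes `DynamicMessage`s for a single
/// method descriptor, allowing calls without compile-time generated types.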
#[derive(Clone)]
pub struct DynamicCodec(MethodDescriptor);
impl DynamicCodec {
#[allow(dead_code)]
pub fn new(md: MethodDescriptor) -> Self {
Self(md)
}
}
impl Codec for DynamicCodec {
type Encode = DynamicMessage;
type Decode = DynamicMessage;
type Encoder = Self;
type Decoder = Self;
fn encoder(&mut self) -> Self::Encoder {
self.clone()
}
fn decoder(&mut self) -> Self::Decoder {
self.clone()
}
}
impl Encoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
        item.encode(dst).expect("buffer is too small to encode this message");
Ok(())
}
}
impl Decoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let mut msg = DynamicMessage::new(self.0.output());
msg.merge(src).map_err(|err| Status::internal(err.to_string()))?;
Ok(Some(msg))
}
}

View File

@@ -0,0 +1,51 @@
use crate::manager::GrpcStreamError;
use prost::DecodeError;
use serde::{Serialize, Serializer};
use serde_json::Error as SerdeJsonError;
use std::io;
use thiserror::Error;
use tonic::Status;
#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
TlsError(#[from] yaak_tls::error::Error),
#[error(transparent)]
TonicError(#[from] Status),
#[error("Prost reflect error: {0:?}")]
ProstReflectError(#[from] prost_reflect::DescriptorError),
#[error(transparent)]
DeserializerError(#[from] SerdeJsonError),
#[error(transparent)]
GrpcStreamError(#[from] GrpcStreamError),
#[error(transparent)]
GrpcDecodeError(#[from] DecodeError),
#[error(transparent)]
GrpcInvalidMetadataKeyError(#[from] tonic::metadata::errors::InvalidMetadataKey),
#[error(transparent)]
GrpcInvalidMetadataValueError(#[from] tonic::metadata::errors::InvalidMetadataValue),
#[error(transparent)]
IOError(#[from] io::Error),
#[error("GRPC error: {0}")]
GenericError(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,382 @@
use prost_reflect::{DescriptorPool, FieldDescriptor, MessageDescriptor};
use std::collections::{HashMap, HashSet, VecDeque};
pub fn message_to_json_schema(_: &DescriptorPool, root_msg: MessageDescriptor) -> JsonSchemaEntry {
JsonSchemaGenerator::generate_json_schema(root_msg)
}
struct JsonSchemaGenerator {
msg_mapping: HashMap<String, JsonSchemaEntry>,
}
impl JsonSchemaGenerator {
pub fn new() -> Self {
JsonSchemaGenerator { msg_mapping: HashMap::new() }
}
pub fn generate_json_schema(msg: MessageDescriptor) -> JsonSchemaEntry {
let generator = JsonSchemaGenerator::new();
generator.scan_root(msg)
}
fn add_message(&mut self, msg: &MessageDescriptor) {
let name = msg.full_name().to_string();
if self.msg_mapping.contains_key(&name) {
return;
}
self.msg_mapping.insert(name.clone(), JsonSchemaEntry::object());
}
pub fn scan_root(mut self, root_msg: MessageDescriptor) -> JsonSchemaEntry {
self.init_structure(root_msg.clone());
self.fill_properties(root_msg.clone());
let mut root = self.msg_mapping.remove(root_msg.full_name()).unwrap();
        if !self.msg_mapping.is_empty() {
root.defs = Some(self.msg_mapping);
}
root
}
fn fill_properties(&mut self, root_msg: MessageDescriptor) {
let root_name = root_msg.full_name().to_string();
let mut visited = HashSet::new();
let mut msg_queue = VecDeque::new();
msg_queue.push_back(root_msg);
        while let Some(msg) = msg_queue.pop_front() {
let msg_name = msg.full_name();
if visited.contains(msg_name) {
continue;
}
visited.insert(msg_name.to_string());
let entry = self.msg_mapping.get_mut(msg_name).unwrap();
for field in msg.fields() {
let field_name = field.name().to_string();
if matches!(field.cardinality(), prost_reflect::Cardinality::Required) {
entry.add_required(field_name.clone());
}
if let Some(oneof) = field.containing_oneof() {
for oneof_field in oneof.fields() {
if let Some(fm) = is_message_field(&oneof_field) {
msg_queue.push_back(fm);
}
entry.add_property(
oneof_field.name().to_string(),
field_to_type_or_ref(&root_name, oneof_field),
);
}
continue;
}
let (field_type, nest_msg) = {
if let Some(fm) = is_message_field(&field) {
if field.is_list() {
// repeated message type
(
JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)),
Some(fm),
)
} else if field.is_map() {
let value_field = fm.get_field_by_name("value").unwrap();
if let Some(fm) = is_message_field(&value_field) {
(
JsonSchemaEntry::map(field_to_type_or_ref(
&root_name,
value_field,
)),
Some(fm),
)
} else {
(
JsonSchemaEntry::map(field_to_type_or_ref(
&root_name,
value_field,
)),
None,
)
}
} else {
(field_to_type_or_ref(&root_name, field), Some(fm))
}
} else {
if field.is_list() {
// repeated scalar type
(JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), None)
} else {
(field_to_type_or_ref(&root_name, field), None)
}
}
};
if let Some(fm) = nest_msg {
msg_queue.push_back(fm);
}
entry.add_property(field_name, field_type);
}
}
}
fn init_structure(&mut self, root_msg: MessageDescriptor) {
let mut visited = HashSet::new();
let mut msg_queue = VecDeque::new();
msg_queue.push_back(root_msg.clone());
        // Level-order traversal, to make sure every message type is defined before it is used
        while let Some(msg) = msg_queue.pop_front() {
let name = msg.full_name();
if visited.contains(name) {
continue;
}
visited.insert(name.to_string());
self.add_message(&msg);
for child in msg.child_messages() {
if child.is_map_entry() {
                    // For a map<key, value> field, protoc generates a synthetic `*Entry`
                    // child message; just skip it
continue;
}
self.add_message(&child);
msg_queue.push_back(child);
}
for field in msg.fields() {
if let Some(oneof) = field.containing_oneof() {
for oneof_field in oneof.fields() {
if let Some(fm) = is_message_field(&oneof_field) {
self.add_message(&fm);
msg_queue.push_back(fm);
}
}
continue;
}
if field.is_map() {
                // The key is always a scalar type, so it needs no processing.
                // The value can be any type, so unpack the value type.
let map_field_msg = is_message_field(&field).unwrap();
let map_value_field = map_field_msg.get_field_by_name("value").unwrap();
if let Some(value_fm) = is_message_field(&map_value_field) {
self.add_message(&value_fm);
msg_queue.push_back(value_fm);
}
continue;
}
if let Some(fm) = is_message_field(&field) {
self.add_message(&fm);
msg_queue.push_back(fm);
}
}
}
}
}
fn field_to_type_or_ref(root_name: &str, field: FieldDescriptor) -> JsonSchemaEntry {
match field.kind() {
prost_reflect::Kind::Bool => JsonSchemaEntry::boolean(),
prost_reflect::Kind::Double => JsonSchemaEntry::number("double"),
prost_reflect::Kind::Float => JsonSchemaEntry::number("float"),
prost_reflect::Kind::Int32 => JsonSchemaEntry::number("int32"),
prost_reflect::Kind::Int64 => JsonSchemaEntry::string_with_format("int64"),
prost_reflect::Kind::Uint32 => JsonSchemaEntry::number("int64"),
prost_reflect::Kind::Uint64 => JsonSchemaEntry::string_with_format("uint64"),
prost_reflect::Kind::Sint32 => JsonSchemaEntry::number("sint32"),
prost_reflect::Kind::Sint64 => JsonSchemaEntry::string_with_format("sint64"),
prost_reflect::Kind::Fixed32 => JsonSchemaEntry::number("int64"),
prost_reflect::Kind::Fixed64 => JsonSchemaEntry::string_with_format("fixed64"),
prost_reflect::Kind::Sfixed32 => JsonSchemaEntry::number("sfixed32"),
prost_reflect::Kind::Sfixed64 => JsonSchemaEntry::string_with_format("sfixed64"),
prost_reflect::Kind::String => JsonSchemaEntry::string(),
prost_reflect::Kind::Bytes => JsonSchemaEntry::string_with_format("byte"),
prost_reflect::Kind::Enum(enums) => {
let values = enums.values().map(|v| v.name().to_string()).collect::<Vec<_>>();
JsonSchemaEntry::enums(values)
}
prost_reflect::Kind::Message(fm) => {
let field_type_full_name = fm.full_name();
match field_type_full_name {
// [Protocol Buffers Well-Known Types]: https://protobuf.dev/reference/protobuf/google.protobuf/
"google.protobuf.FieldMask" => JsonSchemaEntry::string(),
"google.protobuf.Timestamp" => JsonSchemaEntry::string_with_format("date-time"),
"google.protobuf.Duration" => JsonSchemaEntry::string(),
"google.protobuf.StringValue" => JsonSchemaEntry::string(),
"google.protobuf.BytesValue" => JsonSchemaEntry::string_with_format("byte"),
"google.protobuf.Int32Value" => JsonSchemaEntry::number("int32"),
"google.protobuf.UInt32Value" => JsonSchemaEntry::string_with_format("int64"),
"google.protobuf.Int64Value" => JsonSchemaEntry::string_with_format("int64"),
"google.protobuf.UInt64Value" => JsonSchemaEntry::string_with_format("uint64"),
"google.protobuf.FloatValue" => JsonSchemaEntry::number("float"),
"google.protobuf.DoubleValue" => JsonSchemaEntry::number("double"),
"google.protobuf.BoolValue" => JsonSchemaEntry::boolean(),
"google.protobuf.Empty" => JsonSchemaEntry::default(),
"google.protobuf.Struct" => JsonSchemaEntry::object(),
"google.protobuf.ListValue" => JsonSchemaEntry::array(JsonSchemaEntry::default()),
"google.protobuf.NullValue" => JsonSchemaEntry::null(),
                name if name == root_name => JsonSchemaEntry::root_reference(),
_ => JsonSchemaEntry::reference(fm.full_name()),
}
}
}
}
fn is_message_field(field: &FieldDescriptor) -> Option<MessageDescriptor> {
match field.kind() {
prost_reflect::Kind::Message(m) => Some(m),
_ => None,
}
}
#[derive(Default, serde::Serialize)]
#[serde(default, rename_all = "camelCase")]
pub struct JsonSchemaEntry {
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
type_: Option<JsonType>,
#[serde(skip_serializing_if = "Option::is_none")]
format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
properties: Option<HashMap<String, JsonSchemaEntry>>,
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
enum_: Option<Vec<String>>,
// for map type
#[serde(skip_serializing_if = "Option::is_none")]
additional_properties: Option<Box<JsonSchemaEntry>>,
// Set all properties to required
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
items: Option<Box<JsonSchemaEntry>>,
#[serde(skip_serializing_if = "Option::is_none", rename = "$defs")]
defs: Option<HashMap<String, JsonSchemaEntry>>,
#[serde(skip_serializing_if = "Option::is_none", rename = "$ref")]
ref_: Option<String>,
}
impl JsonSchemaEntry {
pub fn add_property(&mut self, name: String, entry: JsonSchemaEntry) {
if self.properties.is_none() {
self.properties = Some(HashMap::new());
}
self.properties.as_mut().unwrap().insert(name, entry);
}
pub fn add_required(&mut self, name: String) {
if self.required.is_none() {
self.required = Some(Vec::new());
}
self.required.as_mut().unwrap().push(name);
}
}
impl JsonSchemaEntry {
pub fn object() -> Self {
JsonSchemaEntry { type_: Some(JsonType::Object), ..Default::default() }
}
pub fn boolean() -> Self {
JsonSchemaEntry { type_: Some(JsonType::Boolean), ..Default::default() }
}
pub fn number<S: Into<String>>(format: S) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::Number),
format: Some(format.into()),
..Default::default()
}
}
pub fn string() -> Self {
JsonSchemaEntry { type_: Some(JsonType::String), ..Default::default() }
}
pub fn string_with_format<S: Into<String>>(format: S) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::String),
format: Some(format.into()),
..Default::default()
}
}
pub fn reference<S: AsRef<str>>(ref_: S) -> Self {
JsonSchemaEntry { ref_: Some(format!("#/$defs/{}", ref_.as_ref())), ..Default::default() }
}
pub fn root_reference() -> Self {
JsonSchemaEntry { ref_: Some("#".to_string()), ..Default::default() }
}
pub fn array(item: JsonSchemaEntry) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::Array),
items: Some(Box::new(item)),
..Default::default()
}
}
pub fn enums(enums: Vec<String>) -> Self {
JsonSchemaEntry { type_: Some(JsonType::String), enum_: Some(enums), ..Default::default() }
}
pub fn map(value_type: JsonSchemaEntry) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::Object),
additional_properties: Some(Box::new(value_type)),
..Default::default()
}
}
pub fn null() -> Self {
JsonSchemaEntry { type_: Some(JsonType::Null), ..Default::default() }
}
}
enum JsonType {
String,
Number,
Object,
Array,
Boolean,
Null,
_UNKNOWN,
}
impl Default for JsonType {
fn default() -> Self {
JsonType::_UNKNOWN
}
}
impl serde::Serialize for JsonType {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
JsonType::String => serializer.serialize_str("string"),
JsonType::Number => serializer.serialize_str("number"),
JsonType::Object => serializer.serialize_str("object"),
JsonType::Array => serializer.serialize_str("array"),
JsonType::Boolean => serializer.serialize_str("boolean"),
JsonType::Null => serializer.serialize_str("null"),
JsonType::_UNKNOWN => serializer.serialize_str("unknown"),
}
}
}
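// A small sketch (not part of this commit) showing how the constructors above
// compose into a schema: an object with a required string and a repeated int32.
#[cfg(test)]
mod json_schema_entry_tests {
    use super::JsonSchemaEntry;

    #[test]
    fn composes_object_schema() {
        let mut root = JsonSchemaEntry::object();
        root.add_property("name".to_string(), JsonSchemaEntry::string());
        root.add_required("name".to_string());
        root.add_property(
            "scores".to_string(),
            JsonSchemaEntry::array(JsonSchemaEntry::number("int32")),
        );
        let json = serde_json::to_string(&root).unwrap();
        assert!(json.contains(r#""type":"object""#));
        assert!(json.contains(r#""required":["name"]"#));
    }
}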

View File

@@ -0,0 +1,54 @@
use prost_reflect::{DynamicMessage, MethodDescriptor, SerializeOptions};
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
mod any;
mod client;
mod codec;
pub mod error;
mod json_schema;
pub mod manager;
mod reflection;
mod transport;
pub use tonic::Code;
pub use tonic::metadata::*;
pub fn serialize_options() -> SerializeOptions {
SerializeOptions::new().skip_default_fields(false)
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct ServiceDefinition {
pub name: String,
pub methods: Vec<MethodDefinition>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct MethodDefinition {
pub name: String,
pub schema: String,
pub client_streaming: bool,
pub server_streaming: bool,
}
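// Shared options: emit default-valued fields and keep 64-bit integers as JSON
// numbers rather than the strings the canonical proto3 JSON mapping uses.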
static SERIALIZE_OPTIONS: &'static SerializeOptions =
&SerializeOptions::new().skip_default_fields(false).stringify_64_bit_integers(false);
pub fn serialize_message(msg: &DynamicMessage) -> Result<String, String> {
let mut buf = Vec::new();
let mut se = serde_json::Serializer::pretty(&mut buf);
msg.serialize_with_options(&mut se, SERIALIZE_OPTIONS).map_err(|e| e.to_string())?;
let s = String::from_utf8(buf).expect("serde_json to emit valid utf8");
Ok(s)
}
pub fn deserialize_message(msg: &str, method: MethodDescriptor) -> Result<DynamicMessage, String> {
let mut deserializer = Deserializer::from_str(&msg);
let req_message = DynamicMessage::deserialize(method.input(), &mut deserializer)
.map_err(|e| e.to_string())?;
deserializer.end().map_err(|e| e.to_string())?;
Ok(req_message)
}

View File

@@ -0,0 +1,426 @@
use crate::codec::DynamicCodec;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::reflection::{
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path, reflect_types_for_message,
};
use crate::transport::get_transport;
use crate::{MethodDefinition, ServiceDefinition, json_schema};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::{info, warn};
pub use prost_reflect::DynamicMessage;
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
use serde_json::Deserializer;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;
use std::fmt::Display;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tonic::body::BoxBody;
use tonic::metadata::{MetadataKey, MetadataValue};
use tonic::transport::Uri;
use tonic::{IntoRequest, IntoStreamingRequest, Request, Response, Status, Streaming};
use yaak_tls::ClientCertificateConfig;
#[derive(Clone)]
pub struct GrpcConnection {
pool: Arc<RwLock<DescriptorPool>>,
conn: Client<HttpsConnector<HttpConnector>, BoxBody>,
pub uri: Uri,
use_reflection: bool,
}
#[derive(Default, Debug)]
pub struct GrpcStreamError {
pub message: String,
pub status: Option<Status>,
}
impl Error for GrpcStreamError {}
impl Display for GrpcStreamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.status {
Some(status) => write!(f, "[{}] {}", status, self.message),
None => write!(f, "{}", self.message),
}
}
}
impl From<String> for GrpcStreamError {
fn from(value: String) -> Self {
GrpcStreamError { message: value.to_string(), status: None }
}
}
impl From<Status> for GrpcStreamError {
fn from(s: Status) -> Self {
GrpcStreamError { message: s.message().to_string(), status: Some(s) }
}
}
impl GrpcConnection {
pub async fn method(&self, service: &str, method: &str) -> Result<MethodDescriptor> {
let service = self.service(service).await?;
let method = service
.methods()
.find(|m| m.name() == method)
.ok_or(GenericError("Failed to find method".to_string()))?;
Ok(method)
}
async fn service(&self, service: &str) -> Result<ServiceDescriptor> {
let pool = self.pool.read().await;
let service = pool
.get_service_by_name(service)
.ok_or(GenericError("Failed to find service".to_string()))?;
Ok(service)
}
pub async fn unary(
&self,
service: &str,
method: &str,
message: &str,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Response<DynamicMessage>> {
if self.use_reflection {
reflect_types_for_message(self.pool.clone(), &self.uri, message, metadata, client_cert)
.await?;
}
let method = &self.method(&service, &method).await?;
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)?;
deserializer.end()?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = req_message.into_request();
decorate_req(metadata, &mut req)?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client.unary(req, path, codec).await?)
}
pub async fn streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<String>,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Response<Streaming<DynamicMessage>>> {
let method = &self.method(&service, &method).await?;
let mapped_stream = {
let input_message = method.input();
let pool = self.pool.clone();
let uri = self.uri.clone();
let md = metadata.clone();
            let use_reflection = self.use_reflection;
let client_cert = client_cert.clone();
stream.filter_map(move |json| {
let pool = pool.clone();
let uri = uri.clone();
let input_message = input_message.clone();
let md = md.clone();
let use_reflection = use_reflection.clone();
let client_cert = client_cert.clone();
tokio::runtime::Handle::current().block_on(async move {
if use_reflection {
if let Err(e) =
reflect_types_for_message(pool, &uri, &json, &md, client_cert).await
{
warn!("Failed to resolve Any types: {e}");
}
}
let mut de = Deserializer::from_str(&json);
match DynamicMessage::deserialize(input_message, &mut de) {
Ok(m) => Some(m),
Err(e) => {
warn!("Failed to deserialize message: {e}");
None
}
}
})
})
};
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
let mut req = mapped_stream.into_streaming_request();
decorate_req(metadata, &mut req)?;
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client.streaming(req, path, codec).await?)
}
pub async fn client_streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<String>,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Response<DynamicMessage>> {
let method = &self.method(&service, &method).await?;
let mapped_stream = {
let input_message = method.input();
let pool = self.pool.clone();
let uri = self.uri.clone();
let md = metadata.clone();
            let use_reflection = self.use_reflection;
let client_cert = client_cert.clone();
stream.filter_map(move |json| {
let pool = pool.clone();
let uri = uri.clone();
let input_message = input_message.clone();
let md = md.clone();
let use_reflection = use_reflection.clone();
let client_cert = client_cert.clone();
tokio::runtime::Handle::current().block_on(async move {
if use_reflection {
if let Err(e) =
reflect_types_for_message(pool, &uri, &json, &md, client_cert).await
{
warn!("Failed to resolve Any types: {e}");
}
}
let mut de = Deserializer::from_str(&json);
match DynamicMessage::deserialize(input_message, &mut de) {
Ok(m) => Some(m),
Err(e) => {
warn!("Failed to deserialize message: {e}");
None
}
}
})
})
};
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
let mut req = mapped_stream.into_streaming_request();
decorate_req(metadata, &mut req)?;
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client
.client_streaming(req, path, codec)
.await
.map_err(|e| GrpcStreamError { message: e.message().to_string(), status: Some(e) })?)
}
pub async fn server_streaming(
&self,
service: &str,
method: &str,
message: &str,
metadata: &BTreeMap<String, String>,
) -> Result<Response<Streaming<DynamicMessage>>> {
let method = &self.method(&service, &method).await?;
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)?;
deserializer.end()?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = req_message.into_request();
decorate_req(metadata, &mut req)?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client.server_streaming(req, path, codec).await?)
}
}
/// Configuration for GrpcHandle to compile proto files
#[derive(Clone)]
pub struct GrpcConfig {
/// Path to the protoc include directory (vendored/protoc/include)
pub protoc_include_dir: PathBuf,
/// Path to the yaakprotoc sidecar binary
pub protoc_bin_path: PathBuf,
}
pub struct GrpcHandle {
config: GrpcConfig,
pools: BTreeMap<String, DescriptorPool>,
}
impl GrpcHandle {
pub fn new(config: GrpcConfig) -> Self {
let pools = BTreeMap::new();
Self { pools, config }
}
}
impl GrpcHandle {
/// Remove cached descriptor pool for the given key, if present.
pub fn invalidate_pool(&mut self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) {
let key = make_pool_key(id, uri, proto_files);
self.pools.remove(&key);
}
pub async fn reflect(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<bool> {
let server_reflection = proto_files.is_empty();
let key = make_pool_key(id, uri, proto_files);
// If we already have a pool for this key, reuse it and avoid re-reflection
if self.pools.contains_key(&key) {
return Ok(server_reflection);
}
let pool = if server_reflection {
let full_uri = uri_from_str(uri)?;
fill_pool_from_reflection(&full_uri, metadata, validate_certificates, client_cert).await
} else {
fill_pool_from_files(&self.config, proto_files).await
}?;
self.pools.insert(key, pool.clone());
Ok(server_reflection)
}
pub async fn services(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
skip_cache: bool,
) -> Result<Vec<ServiceDefinition>> {
// Ensure we have a pool; reflect only if missing
if skip_cache || self.get_pool(id, uri, proto_files).is_none() {
info!("Reflecting gRPC services for {} at {}", id, uri);
self.reflect(id, uri, proto_files, metadata, validate_certificates, client_cert)
.await?;
}
let pool = self
.get_pool(id, uri, proto_files)
.ok_or(GenericError("Failed to get pool".to_string()))?;
Ok(self.services_from_pool(&pool))
}
fn services_from_pool(&self, pool: &DescriptorPool) -> Vec<ServiceDefinition> {
pool.services()
.map(|s| {
let mut def =
ServiceDefinition { name: s.full_name().to_string(), methods: vec![] };
for method in s.methods() {
let input_message = method.input();
def.methods.push(MethodDefinition {
name: method.name().to_string(),
server_streaming: method.is_server_streaming(),
client_streaming: method.is_client_streaming(),
schema: serde_json::to_string_pretty(&json_schema::message_to_json_schema(
&pool,
input_message,
))
.expect("Failed to serialize JSON schema"),
})
}
def
})
.collect::<Vec<_>>()
}
pub async fn connect(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<GrpcConnection> {
let use_reflection = proto_files.is_empty();
if self.get_pool(id, uri, proto_files).is_none() {
self.reflect(
id,
uri,
proto_files,
metadata,
validate_certificates,
client_cert.clone(),
)
.await?;
}
let pool = self
.get_pool(id, uri, proto_files)
.ok_or(GenericError("Failed to get pool".to_string()))?
.clone();
let uri = uri_from_str(uri)?;
let conn = get_transport(validate_certificates, client_cert.clone())?;
Ok(GrpcConnection { pool: Arc::new(RwLock::new(pool)), use_reflection, conn, uri })
}
fn get_pool(&self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> Option<&DescriptorPool> {
self.pools.get(make_pool_key(id, uri, proto_files).as_str())
}
}
pub(crate) fn decorate_req<T>(
metadata: &BTreeMap<String, String>,
req: &mut Request<T>,
) -> Result<()> {
for (k, v) in metadata {
req.metadata_mut()
.insert(MetadataKey::from_str(k.as_str())?, MetadataValue::from_str(v.as_str())?);
}
Ok(())
}
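// A minimal sketch (not part of this commit) of how callers attach metadata;
// the key and value here are arbitrary examples.
#[cfg(test)]
mod decorate_req_tests {
    use super::decorate_req;
    use std::collections::BTreeMap;

    #[test]
    fn inserts_metadata_pairs() {
        let mut md = BTreeMap::new();
        md.insert("x-api-key".to_string(), "secret".to_string());
        let mut req = tonic::Request::new(());
        decorate_req(&md, &mut req).unwrap();
        assert!(req.metadata().contains_key("x-api-key"));
    }
}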
fn uri_from_str(uri_str: &str) -> Result<Uri> {
match Uri::from_str(uri_str) {
Ok(uri) => Ok(uri),
Err(err) => {
// Uri::from_str basically only returns "invalid format" so we add more context here
Err(GenericError(format!("Failed to parse URL, {}", err.to_string())))
}
}
}
fn make_pool_key(id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> String {
let pool_key = format!(
"{}::{}::{}",
id,
uri,
proto_files
.iter()
.map(|p| p.to_string_lossy().to_string())
.collect::<Vec<String>>()
.join(":")
);
format!("{:x}", md5::compute(pool_key))
}

View File

@@ -0,0 +1,452 @@
use crate::any::collect_any_types;
use crate::client::AutoReflectionClient;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::manager::GrpcConfig;
use async_recursion::async_recursion;
use log::{debug, info, warn};
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use std::collections::{BTreeMap, HashSet};
use std::env::temp_dir;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use tokio::fs;
use tokio::process::Command;
use tokio::sync::RwLock;
use tonic::codegen::http::uri::PathAndQuery;
use tonic::transport::Uri;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use yaak_tls::ClientCertificateConfig;
pub async fn fill_pool_from_files(
config: &GrpcConfig,
paths: &Vec<PathBuf>,
) -> Result<DescriptorPool> {
let mut pool = DescriptorPool::new();
let random_file_name = format!("{}.desc", uuid::Uuid::new_v4());
let desc_path = temp_dir().join(random_file_name);
// HACK: Remove UNC prefix for Windows paths
let global_import_dir =
dunce::simplified(config.protoc_include_dir.as_path()).to_string_lossy().to_string();
let desc_path = dunce::simplified(desc_path.as_path());
let mut args = vec![
"--include_imports".to_string(),
"--include_source_info".to_string(),
"-I".to_string(),
global_import_dir,
"-o".to_string(),
desc_path.to_string_lossy().to_string(),
];
let mut include_dirs = HashSet::new();
let mut include_protos = HashSet::new();
for p in paths {
if !p.exists() {
continue;
}
// Dirs are added as includes
if p.is_dir() {
include_dirs.insert(p.to_string_lossy().to_string());
continue;
}
let parent = p.as_path().parent();
if let Some(parent_path) = parent {
match find_parent_proto_dir(parent_path) {
None => {
// Add parent/grandparent as fallback
include_dirs.insert(parent_path.to_string_lossy().to_string());
if let Some(grandparent_path) = parent_path.parent() {
include_dirs.insert(grandparent_path.to_string_lossy().to_string());
}
}
Some(p) => {
include_dirs.insert(p.to_string_lossy().to_string());
}
};
} else {
debug!("ignoring {:?} since it does not exist.", parent)
}
include_protos.insert(p.to_string_lossy().to_string());
}
for d in include_dirs.clone() {
args.push("-I".to_string());
args.push(d);
}
for p in include_protos.clone() {
args.push(p);
}
info!("Invoking protoc with {}", args.join(" "));
let out = Command::new(&config.protoc_bin_path)
.args(&args)
.output()
.await
.map_err(|e| GenericError(format!("Failed to run protoc: {}", e)))?;
if !out.status.success() {
return Err(GenericError(format!(
"protoc failed with status {}: {}",
out.status.code().unwrap_or(-1),
String::from_utf8_lossy(out.stderr.as_slice())
)));
}
let bytes = fs::read(desc_path).await?;
let fdp = FileDescriptorSet::decode(bytes.deref())?;
pool.add_file_descriptor_set(fdp)?;
fs::remove_file(desc_path).await?;
Ok(pool)
}
pub async fn fill_pool_from_reflection(
uri: &Uri,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<DescriptorPool> {
let mut pool = DescriptorPool::new();
let mut client = AutoReflectionClient::new(uri, validate_certificates, client_cert)?;
for service in list_services(&mut client, metadata).await? {
if service == "grpc.reflection.v1alpha.ServerReflection" {
continue;
}
if service == "grpc.reflection.v1.ServerReflection" {
continue;
}
debug!("Fetching descriptors for {}", service);
file_descriptor_set_from_service_name(&service, &mut pool, &mut client, metadata).await;
}
Ok(pool)
}
async fn list_services(
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) -> Result<Vec<String>> {
let response =
client.send_reflection_request(MessageRequest::ListServices("".into()), metadata).await?;
let list_services_response = match response {
MessageResponse::ListServicesResponse(resp) => resp,
_ => panic!("Expected a ListServicesResponse variant"),
};
Ok(list_services_response.service.iter().map(|s| s.name.clone()).collect::<Vec<_>>())
}
async fn file_descriptor_set_from_service_name(
service_name: &str,
pool: &mut DescriptorPool,
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) {
let response = match client
.send_reflection_request(
MessageRequest::FileContainingSymbol(service_name.into()),
metadata,
)
.await
{
Ok(resp) => resp,
Err(e) => {
warn!("Error fetching file descriptor for service {}: {:?}", service_name, e);
return;
}
};
let file_descriptor_response = match response {
MessageResponse::FileDescriptorResponse(resp) => resp,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
add_file_descriptors_to_pool(
file_descriptor_response.file_descriptor_proto,
pool,
client,
metadata,
)
.await;
}
pub(crate) async fn reflect_types_for_message(
pool: Arc<RwLock<DescriptorPool>>,
uri: &Uri,
json: &str,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<()> {
// 1. Collect all Any types in the JSON
let mut extra_types = Vec::new();
collect_any_types(json, &mut extra_types);
if extra_types.is_empty() {
return Ok(()); // nothing to do
}
let mut client = AutoReflectionClient::new(uri, false, client_cert)?;
for extra_type in extra_types {
{
let guard = pool.read().await;
if guard.get_message_by_name(&extra_type).is_some() {
continue;
}
}
info!("Adding file descriptor for {:?} from reflection", extra_type);
let req = MessageRequest::FileContainingSymbol(extra_type.clone().into());
let resp = match client.send_reflection_request(req, metadata).await {
Ok(r) => r,
Err(e) => {
return Err(GenericError(format!(
"Error sending reflection request for @type \"{extra_type}\": {e:?}",
)));
}
};
let files = match resp {
MessageResponse::FileDescriptorResponse(resp) => resp.file_descriptor_proto,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
{
let mut guard = pool.write().await;
add_file_descriptors_to_pool(files, &mut *guard, &mut client, metadata).await;
}
}
Ok(())
}
#[async_recursion]
pub(crate) async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>,
pool: &mut DescriptorPool,
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) {
let mut topo_sort = topology::SimpleTopoSort::new();
let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len());
for fd in fds {
let fdp = FileDescriptorProto::decode(fd.deref()).unwrap();
topo_sort.insert(fdp.name().to_string(), fdp.dependency.clone());
fd_mapping.insert(fdp.name().to_string(), fdp);
}
for node in topo_sort {
match node {
Ok(node) => {
if let Some(fdp) = fd_mapping.remove(&node) {
pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto");
} else {
file_descriptor_set_by_filename(node.as_str(), pool, client, metadata).await;
}
}
            Err(_) => panic!("proto file dependency cycle detected"),
}
}
}
async fn file_descriptor_set_by_filename(
filename: &str,
pool: &mut DescriptorPool,
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) {
// We already fetched this file
    if pool.get_file_by_name(filename).is_some() {
return;
}
let msg = MessageRequest::FileByFilename(filename.into());
let response = client.send_reflection_request(msg, metadata).await;
let file_descriptor_response = match response {
Ok(MessageResponse::FileDescriptorResponse(resp)) => resp,
Ok(_) => {
panic!("Expected a FileDescriptorResponse variant")
}
Err(e) => {
warn!("Error fetching file descriptor for {}: {:?}", filename, e);
return;
}
};
add_file_descriptors_to_pool(
file_descriptor_response.file_descriptor_proto,
pool,
client,
metadata,
)
.await;
}
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {
let full_name = md.full_name();
    let (namespace, method_name) = full_name.rsplit_once('.').expect("invalid method path");
PathAndQuery::from_str(&format!("/{}/{}", namespace, method_name)).expect("invalid method path")
}
mod topology {
use std::collections::{HashMap, HashSet};
pub struct SimpleTopoSort<T> {
out_graph: HashMap<T, HashSet<T>>,
in_graph: HashMap<T, HashSet<T>>,
}
impl<T> SimpleTopoSort<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn new() -> Self {
SimpleTopoSort { out_graph: HashMap::new(), in_graph: HashMap::new() }
}
pub fn insert<I: IntoIterator<Item = T>>(&mut self, node: T, deps: I) {
self.out_graph.entry(node.clone()).or_insert(HashSet::new());
for dep in deps {
self.out_graph.entry(node.clone()).or_insert(HashSet::new()).insert(dep.clone());
self.in_graph.entry(dep.clone()).or_insert(HashSet::new()).insert(node.clone());
}
}
}
impl<T> IntoIterator for SimpleTopoSort<T>
where
T: Eq + std::hash::Hash + Clone,
{
type Item = <SimpleTopoSortIter<T> as Iterator>::Item;
type IntoIter = SimpleTopoSortIter<T>;
fn into_iter(self) -> Self::IntoIter {
SimpleTopoSortIter::new(self)
}
}
pub struct SimpleTopoSortIter<T> {
data: SimpleTopoSort<T>,
zero_indegree: Vec<T>,
}
impl<T> SimpleTopoSortIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn new(data: SimpleTopoSort<T>) -> Self {
let mut zero_indegree = Vec::new();
for (node, _) in data.in_graph.iter() {
if !data.out_graph.contains_key(node) {
zero_indegree.push(node.clone());
}
}
for (node, deps) in data.out_graph.iter() {
if deps.is_empty() {
zero_indegree.push(node.clone());
}
}
SimpleTopoSortIter { data, zero_indegree }
}
}
impl<T> Iterator for SimpleTopoSortIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
type Item = Result<T, &'static str>;
fn next(&mut self) -> Option<Self::Item> {
if self.zero_indegree.is_empty() {
if self.data.out_graph.is_empty() {
return None;
}
return Some(Err("Cycle detected"));
}
let node = self.zero_indegree.pop().unwrap();
if let Some(parents) = self.data.in_graph.get(&node) {
for parent in parents.iter() {
let deps = self.data.out_graph.get_mut(parent).unwrap();
deps.remove(&node);
if deps.is_empty() {
self.zero_indegree.push(parent.clone());
}
}
}
self.data.out_graph.remove(&node);
Some(Ok(node))
}
}
#[test]
fn test_sort() {
{
let mut topo_sort = SimpleTopoSort::new();
topo_sort.insert("a", []);
for node in topo_sort {
match node {
Ok(n) => assert_eq!(n, "a"),
Err(e) => panic!("err {}", e),
}
}
}
{
let mut topo_sort = SimpleTopoSort::new();
topo_sort.insert("a", ["b"]);
topo_sort.insert("b", []);
let mut iter = topo_sort.into_iter();
match iter.next() {
Some(Ok(n)) => assert_eq!(n, "b"),
_ => panic!("err"),
}
match iter.next() {
Some(Ok(n)) => assert_eq!(n, "a"),
_ => panic!("err"),
}
assert_eq!(iter.next(), None);
}
}
}
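/// Walk up from `start_path` looking for an ancestor directory named `proto`,
/// which is then used as a protoc include root for loose .proto files.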
fn find_parent_proto_dir(start_path: impl AsRef<Path>) -> Option<PathBuf> {
let mut dir = start_path.as_ref().canonicalize().ok()?;
loop {
if let Some(name) = dir.file_name().and_then(|n| n.to_str()) {
if name == "proto" {
return Some(dir);
}
}
let parent = dir.parent()?;
if parent == dir {
return None; // Reached root
}
dir = parent.to_path_buf();
}
}

View File

@@ -0,0 +1,40 @@
use crate::error::Result;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use log::info;
use tonic::body::BoxBody;
use yaak_tls::{ClientCertificateConfig, get_tls_config};
// I think ALPN breaks this because we're specifying http2_only
const WITH_ALPN: bool = false;
pub(crate) fn get_transport(
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Client<HttpsConnector<HttpConnector>, BoxBody>> {
let tls_config = get_tls_config(validate_certificates, WITH_ALPN, client_cert.clone())?;
let mut http = HttpConnector::new();
http.enforce_http(false);
let connector = HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http2()
.build();
let client = Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(0)
.http2_only(true)
.build(connector);
info!(
"Created gRPC client validate_certs={} client_cert={}",
validate_certificates,
client_cert.is_some()
);
Ok(client)
}

View File

@@ -0,0 +1,31 @@
[package]
name = "yaak-http"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
async-compression = { version = "0.4", features = ["tokio", "gzip", "deflate", "brotli", "zstd"] }
async-trait = "0.1"
brotli = "7"
bytes = "1.5.0"
cookie = "0.18.1"
flate2 = "1"
futures-util = "0.3"
url = "2"
zstd = "0.13"
hyper-util = { version = "0.1.17", default-features = false, features = ["client-legacy"] }
log = { workspace = true }
mime_guess = "2.0.5"
regex = "1.11.1"
reqwest = { workspace = true, features = ["rustls-tls-manual-roots-no-provider", "socks", "http2", "stream"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] }
tokio-util = { version = "0.7", features = ["codec", "io", "io-util"] }
tower-service = "0.3.3"
urlencoding = "2.1.3"
yaak-common = { workspace = true }
yaak-models = { workspace = true }
yaak-tls = { workspace = true }

View File

@@ -0,0 +1,78 @@
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
/// A stream that chains multiple AsyncRead sources together
pub(crate) struct ChainedReader {
readers: Vec<ReaderType>,
current_index: usize,
current_reader: Option<Box<dyn AsyncRead + Send + Unpin + 'static>>,
}
#[derive(Clone)]
pub(crate) enum ReaderType {
Bytes(Vec<u8>),
FilePath(String),
}
impl ChainedReader {
pub(crate) fn new(readers: Vec<ReaderType>) -> Self {
Self { readers, current_index: 0, current_reader: None }
}
}
impl AsyncRead for ChainedReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
// Try to read from current reader if we have one
if let Some(ref mut reader) = self.current_reader {
let before_len = buf.filled().len();
return match Pin::new(reader).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
if buf.filled().len() == before_len && buf.remaining() > 0 {
// Current reader is exhausted, move to next
self.current_reader = None;
continue;
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
};
}
// We need to get the next reader
if self.current_index >= self.readers.len() {
// No more readers
return Poll::Ready(Ok(()));
}
// Get the next reader
let reader_type = self.readers[self.current_index].clone();
self.current_index += 1;
match reader_type {
ReaderType::Bytes(bytes) => {
self.current_reader = Some(Box::new(io::Cursor::new(bytes)));
}
ReaderType::FilePath(path) => {
// We need to handle file opening synchronously in poll_read
// This is a limitation - we'll use blocking file open
match std::fs::File::open(&path) {
Ok(file) => {
// Convert std File to tokio File
let tokio_file = tokio::fs::File::from_std(file);
self.current_reader = Some(Box::new(tokio_file));
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
}
}
}
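// A minimal sketch (not part of this commit) showing the chaining behavior
// with two in-memory sources; file sources work the same way.
#[cfg(test)]
mod chained_reader_tests {
    use super::{ChainedReader, ReaderType};
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn reads_sources_in_order() {
        let mut reader = ChainedReader::new(vec![
            ReaderType::Bytes(b"hello ".to_vec()),
            ReaderType::Bytes(b"world".to_vec()),
        ]);
        let mut out = String::new();
        reader.read_to_string(&mut out).await.unwrap();
        assert_eq!(out, "hello world");
    }
}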

View File

@@ -0,0 +1,117 @@
use crate::dns::LocalhostResolver;
use crate::error::Result;
use log::{debug, info, warn};
use reqwest::{Client, Proxy, redirect};
use yaak_tls::{ClientCertificateConfig, get_tls_config};
#[derive(Clone)]
pub struct HttpConnectionProxySettingAuth {
pub user: String,
pub password: String,
}
#[derive(Clone)]
pub enum HttpConnectionProxySetting {
Disabled,
System,
Enabled {
http: String,
https: String,
auth: Option<HttpConnectionProxySettingAuth>,
bypass: String,
},
}
#[derive(Clone)]
pub struct HttpConnectionOptions {
pub id: String,
pub validate_certificates: bool,
pub proxy: HttpConnectionProxySetting,
pub client_certificate: Option<ClientCertificateConfig>,
}
impl HttpConnectionOptions {
pub(crate) fn build_client(&self) -> Result<Client> {
let mut client = Client::builder()
.connection_verbose(true)
.redirect(redirect::Policy::none())
// Decompression is handled by HttpTransaction, not reqwest
.no_gzip()
.no_brotli()
.no_deflate()
.referer(false)
.tls_info(true);
// Configure TLS with optional client certificate
let config =
get_tls_config(self.validate_certificates, true, self.client_certificate.clone())?;
client = client.use_preconfigured_tls(config);
// Configure DNS resolver
client = client.dns_resolver(LocalhostResolver::new());
// Configure proxy
match self.proxy.clone() {
HttpConnectionProxySetting::System => { /* Default */ }
HttpConnectionProxySetting::Disabled => {
client = client.no_proxy();
}
HttpConnectionProxySetting::Enabled { http, https, auth, bypass } => {
for p in build_enabled_proxy(http, https, auth, bypass) {
client = client.proxy(p)
}
}
}
info!(
"Building new HTTP client validate_certificates={} client_cert={}",
self.validate_certificates,
self.client_certificate.is_some()
);
Ok(client.build()?)
}
}
fn build_enabled_proxy(
http: String,
https: String,
auth: Option<HttpConnectionProxySettingAuth>,
bypass: String,
) -> Vec<Proxy> {
debug!("Using proxy http={http} https={https} bypass={bypass}");
let mut proxies = Vec::new();
if !http.is_empty() {
match Proxy::http(http) {
Ok(mut proxy) => {
if let Some(HttpConnectionProxySettingAuth { user, password }) = auth.clone() {
debug!("Using http proxy auth");
proxy = proxy.basic_auth(user.as_str(), password.as_str());
}
proxies.push(proxy.no_proxy(reqwest::NoProxy::from_string(&bypass)));
}
Err(e) => {
warn!("Failed to apply http proxy {e:?}");
}
};
}
if !https.is_empty() {
match Proxy::https(https) {
Ok(mut proxy) => {
if let Some(HttpConnectionProxySettingAuth { user, password }) = auth {
debug!("Using https proxy auth");
proxy = proxy.basic_auth(user.as_str(), password.as_str());
}
proxies.push(proxy.no_proxy(reqwest::NoProxy::from_string(&bypass)));
}
Err(e) => {
warn!("Failed to apply https proxy {e:?}");
}
};
}
proxies
}
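
// A quick illustration of how these proxy settings compose. The values below
// are hypothetical; `build_client` is pub(crate), so external callers obtain
// clients through `HttpConnectionManager` instead.
#[allow(dead_code)]
fn example_proxy_settings() -> HttpConnectionProxySetting {
    HttpConnectionProxySetting::Enabled {
        http: "http://proxy.internal:8080".to_string(),
        https: "http://proxy.internal:8080".to_string(),
        auth: Some(HttpConnectionProxySettingAuth {
            user: "alice".to_string(),
            password: "secret".to_string(),
        }),
        // Comma-separated hosts that bypass the proxy (reqwest NoProxy syntax)
        bypass: "localhost,127.0.0.1".to_string(),
    }
}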

View File

@@ -0,0 +1,484 @@
//! Custom cookie handling for HTTP requests
//!
//! This module provides cookie storage and matching functionality that was previously
//! delegated to reqwest. It implements RFC 6265 cookie domain and path matching.
use log::debug;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use url::Url;
use yaak_models::models::{Cookie, CookieDomain, CookieExpires};
/// A thread-safe cookie store that can be shared across requests
#[derive(Debug, Clone)]
pub struct CookieStore {
cookies: Arc<Mutex<Vec<Cookie>>>,
}
impl Default for CookieStore {
fn default() -> Self {
Self::new()
}
}
impl CookieStore {
/// Create a new empty cookie store
pub fn new() -> Self {
Self { cookies: Arc::new(Mutex::new(Vec::new())) }
}
/// Create a cookie store from existing cookies
pub fn from_cookies(cookies: Vec<Cookie>) -> Self {
Self { cookies: Arc::new(Mutex::new(cookies)) }
}
/// Get all cookies (for persistence)
pub fn get_all_cookies(&self) -> Vec<Cookie> {
self.cookies.lock().unwrap().clone()
}
/// Get the Cookie header value for the given URL
pub fn get_cookie_header(&self, url: &Url) -> Option<String> {
let cookies = self.cookies.lock().unwrap();
let now = SystemTime::now();
let matching_cookies: Vec<_> = cookies
.iter()
.filter(|cookie| self.cookie_matches(cookie, url, &now))
.filter_map(|cookie| {
// Parse the raw cookie to get name=value
parse_cookie_name_value(&cookie.raw_cookie)
})
.collect();
if matching_cookies.is_empty() {
None
} else {
Some(
matching_cookies
.into_iter()
.map(|(name, value)| format!("{}={}", name, value))
.collect::<Vec<_>>()
.join("; "),
)
}
}
/// Parse Set-Cookie headers and add cookies to the store
pub fn store_cookies_from_response(&self, url: &Url, set_cookie_headers: &[String]) {
let mut cookies = self.cookies.lock().unwrap();
for header_value in set_cookie_headers {
if let Some(cookie) = parse_set_cookie(header_value, url) {
// Remove any existing cookie with the same name and domain
cookies.retain(|existing| !cookies_match(existing, &cookie));
debug!(
"Storing cookie: {} for domain {:?}",
parse_cookie_name_value(&cookie.raw_cookie)
.map(|(n, _)| n)
.unwrap_or_else(|| "unknown".to_string()),
cookie.domain
);
cookies.push(cookie);
}
}
}
/// Check if a cookie matches the given URL
fn cookie_matches(&self, cookie: &Cookie, url: &Url, now: &SystemTime) -> bool {
// Check expiration
if let CookieExpires::AtUtc(expiry_str) = &cookie.expires {
if let Ok(expiry) = parse_cookie_date(expiry_str) {
if expiry < *now {
return false;
}
}
}
// Check domain
let url_host = match url.host_str() {
Some(h) => h.to_lowercase(),
None => return false,
};
let domain_matches = match &cookie.domain {
CookieDomain::HostOnly(domain) => url_host == domain.to_lowercase(),
CookieDomain::Suffix(domain) => {
let domain_lower = domain.to_lowercase();
url_host == domain_lower || url_host.ends_with(&format!(".{}", domain_lower))
}
// NotPresent and Empty should never occur in practice since we always set domain
// when parsing Set-Cookie headers. Treat as non-matching to be safe.
CookieDomain::NotPresent | CookieDomain::Empty => false,
};
if !domain_matches {
return false;
}
// Check path
let (cookie_path, _) = &cookie.path;
let url_path = url.path();
path_matches(url_path, cookie_path)
}
}
/// Parse name=value from a cookie string (raw_cookie format)
fn parse_cookie_name_value(raw_cookie: &str) -> Option<(String, String)> {
// The raw_cookie typically looks like "name=value" or "name=value; attr1; attr2=..."
let first_part = raw_cookie.split(';').next()?;
let mut parts = first_part.splitn(2, '=');
let name = parts.next()?.trim().to_string();
let value = parts.next().unwrap_or("").trim().to_string();
if name.is_empty() { None } else { Some((name, value)) }
}
/// Parse a Set-Cookie header into a Cookie
fn parse_set_cookie(header_value: &str, request_url: &Url) -> Option<Cookie> {
let parsed = cookie::Cookie::parse(header_value).ok()?;
let raw_cookie = format!("{}={}", parsed.name(), parsed.value());
// Determine domain
let domain = if let Some(domain_attr) = parsed.domain() {
// Domain attribute present - this is a suffix match
let domain = domain_attr.trim_start_matches('.').to_lowercase();
// Reject single-component domains (TLDs) except localhost
if is_single_component_domain(&domain) && !is_localhost(&domain) {
debug!("Rejecting cookie with single-component domain: {}", domain);
return None;
}
CookieDomain::Suffix(domain)
} else {
// No domain attribute - host-only cookie
CookieDomain::HostOnly(request_url.host_str().unwrap_or("").to_lowercase())
};
// Determine expiration
let expires = if let Some(max_age) = parsed.max_age() {
let duration = Duration::from_secs(max_age.whole_seconds().max(0) as u64);
let expiry = SystemTime::now() + duration;
let expiry_secs = expiry.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
CookieExpires::AtUtc(format!("{}", expiry_secs))
} else if let Some(expires_time) = parsed.expires() {
match expires_time {
cookie::Expiration::DateTime(dt) => {
let timestamp = dt.unix_timestamp();
CookieExpires::AtUtc(format!("{}", timestamp))
}
cookie::Expiration::Session => CookieExpires::SessionEnd,
}
} else {
CookieExpires::SessionEnd
};
// Determine path
let path = if let Some(path_attr) = parsed.path() {
(path_attr.to_string(), true)
} else {
// Default path is the directory of the request URI
let default_path = default_cookie_path(request_url.path());
(default_path, false)
};
Some(Cookie { raw_cookie, domain, expires, path })
}
/// Get the default cookie path from a request path (RFC 6265 Section 5.1.4)
fn default_cookie_path(request_path: &str) -> String {
if request_path.is_empty() || !request_path.starts_with('/') {
return "/".to_string();
}
// Find the last slash
if let Some(last_slash) = request_path.rfind('/') {
if last_slash == 0 { "/".to_string() } else { request_path[..last_slash].to_string() }
} else {
"/".to_string()
}
}
/// Check if a request path matches a cookie path (RFC 6265 Section 5.1.4)
fn path_matches(request_path: &str, cookie_path: &str) -> bool {
if request_path == cookie_path {
return true;
}
if request_path.starts_with(cookie_path) {
// Cookie path must end with / or the char after cookie_path in request_path must be /
if cookie_path.ends_with('/') {
return true;
}
if request_path.chars().nth(cookie_path.len()) == Some('/') {
return true;
}
}
false
}
/// Check if two cookies match (same name and domain)
fn cookies_match(a: &Cookie, b: &Cookie) -> bool {
let name_a = parse_cookie_name_value(&a.raw_cookie).map(|(n, _)| n);
let name_b = parse_cookie_name_value(&b.raw_cookie).map(|(n, _)| n);
if name_a != name_b {
return false;
}
// Check domain match
match (&a.domain, &b.domain) {
(CookieDomain::HostOnly(d1), CookieDomain::HostOnly(d2)) => {
d1.to_lowercase() == d2.to_lowercase()
}
(CookieDomain::Suffix(d1), CookieDomain::Suffix(d2)) => {
d1.to_lowercase() == d2.to_lowercase()
}
_ => false,
}
}
/// Parse a cookie date string (Unix timestamp in our format)
fn parse_cookie_date(date_str: &str) -> Result<SystemTime, ()> {
let timestamp: i64 = date_str.parse().map_err(|_| ())?;
let duration = Duration::from_secs(timestamp.max(0) as u64);
Ok(UNIX_EPOCH + duration)
}
/// Check if a domain is a single-component domain (TLD)
/// e.g., "com", "org", "net" - domains without any dots
fn is_single_component_domain(domain: &str) -> bool {
// Empty or only dots
let trimmed = domain.trim_matches('.');
if trimmed.is_empty() {
return true;
}
// IPv6 addresses use colons, not dots - don't consider them single-component
if domain.contains(':') {
return false;
}
!trimmed.contains('.')
}
/// Check if a domain is localhost or a localhost variant
fn is_localhost(domain: &str) -> bool {
let lower = domain.to_lowercase();
lower == "localhost"
|| lower.ends_with(".localhost")
|| lower == "127.0.0.1"
|| lower == "::1"
|| lower == "[::1]"
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_cookie_name_value() {
assert_eq!(
parse_cookie_name_value("session=abc123"),
Some(("session".to_string(), "abc123".to_string()))
);
assert_eq!(
parse_cookie_name_value("name=value; Path=/; HttpOnly"),
Some(("name".to_string(), "value".to_string()))
);
assert_eq!(parse_cookie_name_value("empty="), Some(("empty".to_string(), "".to_string())));
assert_eq!(parse_cookie_name_value(""), None);
}
#[test]
fn test_path_matches() {
assert!(path_matches("/", "/"));
assert!(path_matches("/foo", "/"));
assert!(path_matches("/foo/bar", "/foo"));
assert!(path_matches("/foo/bar", "/foo/"));
assert!(!path_matches("/foobar", "/foo"));
assert!(!path_matches("/foo", "/foo/bar"));
}
#[test]
fn test_default_cookie_path() {
assert_eq!(default_cookie_path("/"), "/");
assert_eq!(default_cookie_path("/foo"), "/");
assert_eq!(default_cookie_path("/foo/bar"), "/foo");
assert_eq!(default_cookie_path("/foo/bar/baz"), "/foo/bar");
assert_eq!(default_cookie_path(""), "/");
}
#[test]
fn test_cookie_store_basic() {
let store = CookieStore::new();
let url = Url::parse("https://example.com/path").unwrap();
// Initially empty
assert!(store.get_cookie_header(&url).is_none());
// Add a cookie
store.store_cookies_from_response(&url, &["session=abc123".to_string()]);
// Should now have the cookie
let header = store.get_cookie_header(&url);
assert_eq!(header, Some("session=abc123".to_string()));
}
#[test]
fn test_cookie_domain_matching() {
let store = CookieStore::new();
let url = Url::parse("https://example.com/").unwrap();
// Cookie with domain attribute (suffix match)
store.store_cookies_from_response(
&url,
&["domain_cookie=value; Domain=example.com".to_string()],
);
// Should match example.com
assert!(store.get_cookie_header(&url).is_some());
// Should match subdomain
let subdomain_url = Url::parse("https://sub.example.com/").unwrap();
assert!(store.get_cookie_header(&subdomain_url).is_some());
// Should not match different domain
let other_url = Url::parse("https://other.com/").unwrap();
assert!(store.get_cookie_header(&other_url).is_none());
}
#[test]
fn test_cookie_path_matching() {
let store = CookieStore::new();
let url = Url::parse("https://example.com/api/v1").unwrap();
// Cookie with path
store.store_cookies_from_response(&url, &["api_cookie=value; Path=/api".to_string()]);
// Should match /api/v1
assert!(store.get_cookie_header(&url).is_some());
// Should match /api
let api_url = Url::parse("https://example.com/api").unwrap();
assert!(store.get_cookie_header(&api_url).is_some());
// Should not match /other
let other_url = Url::parse("https://example.com/other").unwrap();
assert!(store.get_cookie_header(&other_url).is_none());
}
#[test]
fn test_cookie_replacement() {
let store = CookieStore::new();
let url = Url::parse("https://example.com/").unwrap();
// Add a cookie
store.store_cookies_from_response(&url, &["session=old".to_string()]);
assert_eq!(store.get_cookie_header(&url), Some("session=old".to_string()));
// Replace with new value
store.store_cookies_from_response(&url, &["session=new".to_string()]);
assert_eq!(store.get_cookie_header(&url), Some("session=new".to_string()));
// Should only have one cookie
assert_eq!(store.get_all_cookies().len(), 1);
}
#[test]
fn test_is_single_component_domain() {
// Single-component domains (TLDs)
assert!(is_single_component_domain("com"));
assert!(is_single_component_domain("org"));
assert!(is_single_component_domain("net"));
assert!(is_single_component_domain("localhost")); // Still single-component, but allowed separately
// Multi-component domains
assert!(!is_single_component_domain("example.com"));
assert!(!is_single_component_domain("sub.example.com"));
assert!(!is_single_component_domain("co.uk"));
// Edge cases
assert!(is_single_component_domain("")); // Empty is treated as single-component
assert!(is_single_component_domain(".")); // Only dots
assert!(is_single_component_domain("..")); // Only dots
// IPv6 addresses (have colons, not dots)
assert!(!is_single_component_domain("::1")); // IPv6 localhost
assert!(!is_single_component_domain("[::1]")); // Bracketed IPv6
assert!(!is_single_component_domain("2001:db8::1")); // IPv6 address
}
#[test]
fn test_is_localhost() {
// Localhost variants
assert!(is_localhost("localhost"));
assert!(is_localhost("LOCALHOST")); // Case-insensitive
assert!(is_localhost("sub.localhost"));
assert!(is_localhost("app.sub.localhost"));
// IP localhost
assert!(is_localhost("127.0.0.1"));
assert!(is_localhost("::1"));
assert!(is_localhost("[::1]"));
// Not localhost
assert!(!is_localhost("example.com"));
assert!(!is_localhost("localhost.com")); // .com domain, not localhost
assert!(!is_localhost("notlocalhost"));
}
#[test]
fn test_reject_tld_cookies() {
let store = CookieStore::new();
let url = Url::parse("https://example.com/").unwrap();
// Try to set a cookie with Domain=com (TLD)
store.store_cookies_from_response(&url, &["bad=cookie; Domain=com".to_string()]);
// Should be rejected - no cookies stored
assert_eq!(store.get_all_cookies().len(), 0);
assert!(store.get_cookie_header(&url).is_none());
}
#[test]
fn test_allow_localhost_cookies() {
let store = CookieStore::new();
let url = Url::parse("http://localhost:3000/").unwrap();
// Cookie with Domain=localhost should be allowed
store.store_cookies_from_response(&url, &["session=abc; Domain=localhost".to_string()]);
// Should be accepted
assert_eq!(store.get_all_cookies().len(), 1);
assert!(store.get_cookie_header(&url).is_some());
}
#[test]
fn test_allow_127_0_0_1_cookies() {
let store = CookieStore::new();
let url = Url::parse("http://127.0.0.1:8080/").unwrap();
// Cookie without Domain attribute (host-only) should work
store.store_cookies_from_response(&url, &["session=xyz".to_string()]);
// Should be accepted
assert_eq!(store.get_all_cookies().len(), 1);
assert!(store.get_cookie_header(&url).is_some());
}
#[test]
fn test_allow_normal_domain_cookies() {
let store = CookieStore::new();
let url = Url::parse("https://example.com/").unwrap();
// Cookie with valid domain should be allowed
store.store_cookies_from_response(&url, &["session=abc; Domain=example.com".to_string()]);
// Should be accepted
assert_eq!(store.get_all_cookies().len(), 1);
assert!(store.get_cookie_header(&url).is_some());
}
}
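
// A minimal persistence sketch, assuming cookies are loaded from and saved
// back to the yaak-models store elsewhere. `saved_cookies` and
// `set_cookie_headers` are hypothetical inputs for illustration.
#[allow(dead_code)]
fn example_persistence_roundtrip(saved_cookies: Vec<Cookie>, set_cookie_headers: Vec<String>) {
    let store = CookieStore::from_cookies(saved_cookies);
    let url = Url::parse("https://api.example.com/login").unwrap();
    // Outgoing: build the Cookie header for this URL, if any cookies match
    if let Some(header) = store.get_cookie_header(&url) {
        debug!("Cookie: {header}");
    }
    // Incoming: record any Set-Cookie headers from the response
    store.store_cookies_from_response(&url, &set_cookie_headers);
    // Persist: hand the updated cookies back to storage
    let _to_save = store.get_all_cookies();
}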

View File

@@ -0,0 +1,188 @@
use crate::error::{Error, Result};
use async_compression::tokio::bufread::{
BrotliDecoder, DeflateDecoder as AsyncDeflateDecoder, GzipDecoder,
ZstdDecoder as AsyncZstdDecoder,
};
use flate2::read::{DeflateDecoder, GzDecoder};
use std::io::Read;
use tokio::io::{AsyncBufRead, AsyncRead};
/// Supported compression encodings
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentEncoding {
Gzip,
Deflate,
Brotli,
Zstd,
Identity,
}
impl ContentEncoding {
/// Parse a Content-Encoding header value into an encoding type.
/// Returns Identity for unknown or missing encodings.
pub fn from_header(value: Option<&str>) -> Self {
match value.map(|s| s.trim().to_lowercase()).as_deref() {
Some("gzip") | Some("x-gzip") => ContentEncoding::Gzip,
Some("deflate") => ContentEncoding::Deflate,
Some("br") => ContentEncoding::Brotli,
Some("zstd") => ContentEncoding::Zstd,
_ => ContentEncoding::Identity,
}
}
}
/// Result of decompression, containing both the decompressed data and size info
#[derive(Debug)]
pub struct DecompressResult {
pub data: Vec<u8>,
pub compressed_size: u64,
pub decompressed_size: u64,
}
/// Decompress data based on the Content-Encoding.
/// Returns the original data unchanged if encoding is Identity or unknown.
pub fn decompress(data: Vec<u8>, encoding: ContentEncoding) -> Result<DecompressResult> {
let compressed_size = data.len() as u64;
let decompressed = match encoding {
ContentEncoding::Identity => data,
ContentEncoding::Gzip => decompress_gzip(&data)?,
ContentEncoding::Deflate => decompress_deflate(&data)?,
ContentEncoding::Brotli => decompress_brotli(&data)?,
ContentEncoding::Zstd => decompress_zstd(&data)?,
};
let decompressed_size = decompressed.len() as u64;
Ok(DecompressResult { data: decompressed, compressed_size, decompressed_size })
}
fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>> {
let mut decoder = GzDecoder::new(data);
let mut decompressed = Vec::new();
decoder
.read_to_end(&mut decompressed)
.map_err(|e| Error::DecompressionError(format!("gzip decompression failed: {}", e)))?;
Ok(decompressed)
}
fn decompress_deflate(data: &[u8]) -> Result<Vec<u8>> {
let mut decoder = DeflateDecoder::new(data);
let mut decompressed = Vec::new();
decoder
.read_to_end(&mut decompressed)
.map_err(|e| Error::DecompressionError(format!("deflate decompression failed: {}", e)))?;
Ok(decompressed)
}
fn decompress_brotli(data: &[u8]) -> Result<Vec<u8>> {
let mut decompressed = Vec::new();
brotli::BrotliDecompress(&mut std::io::Cursor::new(data), &mut decompressed)
.map_err(|e| Error::DecompressionError(format!("brotli decompression failed: {}", e)))?;
Ok(decompressed)
}
fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>> {
zstd::stream::decode_all(std::io::Cursor::new(data))
.map_err(|e| Error::DecompressionError(format!("zstd decompression failed: {}", e)))
}
/// Create a streaming decompressor that wraps an async reader.
/// Returns an AsyncRead that decompresses data on-the-fly.
pub fn streaming_decoder<R: AsyncBufRead + Unpin + Send + 'static>(
reader: R,
encoding: ContentEncoding,
) -> Box<dyn AsyncRead + Unpin + Send> {
match encoding {
ContentEncoding::Identity => Box::new(reader),
ContentEncoding::Gzip => Box::new(GzipDecoder::new(reader)),
ContentEncoding::Deflate => Box::new(AsyncDeflateDecoder::new(reader)),
ContentEncoding::Brotli => Box::new(BrotliDecoder::new(reader)),
ContentEncoding::Zstd => Box::new(AsyncZstdDecoder::new(reader)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use flate2::Compression;
use flate2::write::GzEncoder;
use std::io::Write;
#[test]
fn test_content_encoding_from_header() {
assert_eq!(ContentEncoding::from_header(Some("gzip")), ContentEncoding::Gzip);
assert_eq!(ContentEncoding::from_header(Some("x-gzip")), ContentEncoding::Gzip);
assert_eq!(ContentEncoding::from_header(Some("GZIP")), ContentEncoding::Gzip);
assert_eq!(ContentEncoding::from_header(Some("deflate")), ContentEncoding::Deflate);
assert_eq!(ContentEncoding::from_header(Some("br")), ContentEncoding::Brotli);
assert_eq!(ContentEncoding::from_header(Some("zstd")), ContentEncoding::Zstd);
assert_eq!(ContentEncoding::from_header(Some("identity")), ContentEncoding::Identity);
assert_eq!(ContentEncoding::from_header(Some("unknown")), ContentEncoding::Identity);
assert_eq!(ContentEncoding::from_header(None), ContentEncoding::Identity);
}
#[test]
fn test_decompress_identity() {
let data = b"hello world".to_vec();
let result = decompress(data.clone(), ContentEncoding::Identity).unwrap();
assert_eq!(result.data, data);
assert_eq!(result.compressed_size, 11);
assert_eq!(result.decompressed_size, 11);
}
#[test]
fn test_decompress_gzip() {
// Compress some data with gzip
let original = b"hello world, this is a test of gzip compression";
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(original).unwrap();
let compressed = encoder.finish().unwrap();
let result = decompress(compressed.clone(), ContentEncoding::Gzip).unwrap();
assert_eq!(result.data, original);
assert_eq!(result.compressed_size, compressed.len() as u64);
assert_eq!(result.decompressed_size, original.len() as u64);
}
#[test]
fn test_decompress_deflate() {
// Compress some data with deflate
let original = b"hello world, this is a test of deflate compression";
let mut encoder = flate2::write::DeflateEncoder::new(Vec::new(), Compression::default());
encoder.write_all(original).unwrap();
let compressed = encoder.finish().unwrap();
let result = decompress(compressed.clone(), ContentEncoding::Deflate).unwrap();
assert_eq!(result.data, original);
assert_eq!(result.compressed_size, compressed.len() as u64);
assert_eq!(result.decompressed_size, original.len() as u64);
}
#[test]
fn test_decompress_brotli() {
// Compress some data with brotli
let original = b"hello world, this is a test of brotli compression";
let mut compressed = Vec::new();
let mut writer = brotli::CompressorWriter::new(&mut compressed, 4096, 4, 22);
writer.write_all(original).unwrap();
drop(writer);
let result = decompress(compressed.clone(), ContentEncoding::Brotli).unwrap();
assert_eq!(result.data, original);
assert_eq!(result.compressed_size, compressed.len() as u64);
assert_eq!(result.decompressed_size, original.len() as u64);
}
#[test]
fn test_decompress_zstd() {
// Compress some data with zstd
let original = b"hello world, this is a test of zstd compression";
let compressed = zstd::stream::encode_all(std::io::Cursor::new(original), 3).unwrap();
let result = decompress(compressed.clone(), ContentEncoding::Zstd).unwrap();
assert_eq!(result.data, original);
assert_eq!(result.compressed_size, compressed.len() as u64);
assert_eq!(result.decompressed_size, original.len() as u64);
}
}
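
// A sketch of the streaming path, which the tests above don't exercise: wrap
// any buffered AsyncRead and pull decompressed bytes on demand. The in-memory
// cursor stands in for a network stream.
#[allow(dead_code)]
async fn example_stream_gzip(compressed: Vec<u8>) -> Result<Vec<u8>> {
    use tokio::io::{AsyncReadExt, BufReader};
    // tokio implements AsyncRead for std::io::Cursor, so this works in-memory
    let reader = BufReader::new(std::io::Cursor::new(compressed));
    let mut decoder = streaming_decoder(reader, ContentEncoding::Gzip);
    let mut out = Vec::new();
    decoder
        .read_to_end(&mut out)
        .await
        .map_err(|e| Error::DecompressionError(format!("gzip stream failed: {}", e)))?;
    Ok(out)
}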

View File

@@ -0,0 +1,54 @@
use hyper_util::client::legacy::connect::dns::{
GaiResolver as HyperGaiResolver, Name as HyperName,
};
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use tower_service::Service;
#[derive(Clone)]
pub struct LocalhostResolver {
fallback: HyperGaiResolver,
}
impl LocalhostResolver {
pub fn new() -> Arc<Self> {
let resolver = HyperGaiResolver::new();
Arc::new(Self { fallback: resolver })
}
}
impl Resolve for LocalhostResolver {
fn resolve(&self, name: Name) -> Resolving {
let host = name.as_str().to_lowercase();
let is_localhost = host.ends_with(".localhost");
if is_localhost {
            // Port 0 is fine; reqwest replaces it with the URL's explicit
            // port or the scheme's default (80/443, etc.).
let addrs: Vec<SocketAddr> = vec![
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0),
];
return Box::pin(async move {
Ok::<Addrs, Box<dyn std::error::Error + Send + Sync>>(Box::new(addrs.into_iter()))
});
}
let mut fallback = self.fallback.clone();
let name_str = name.as_str().to_string();
Box::pin(async move {
match HyperName::from_str(&name_str) {
Ok(n) => fallback
.call(n)
.await
.map(|addrs| Box::new(addrs) as Addrs)
.map_err(|err| Box::new(err) as Box<dyn std::error::Error + Send + Sync>),
Err(e) => Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
}
})
}
}
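
// Wiring the resolver into reqwest is a one-liner; any `*.localhost` hostname
// then resolves to loopback without consulting the system resolver. A sketch,
// mirroring how client.rs uses it.
#[allow(dead_code)]
fn example_client_with_resolver() -> reqwest::Result<reqwest::Client> {
    reqwest::Client::builder()
        // e.g. http://myapp.localhost:3000 now resolves to 127.0.0.1 / [::1]
        .dns_resolver(LocalhostResolver::new())
        .build()
}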

View File

@@ -0,0 +1,37 @@
use serde::{Serialize, Serializer};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Client error: {0:?}")]
Client(#[from] reqwest::Error),
#[error(transparent)]
TlsError(#[from] yaak_tls::error::Error),
#[error("Request failed with {0:?}")]
RequestError(String),
#[error("Request canceled")]
RequestCanceledError,
#[error("Timeout of {0:?} reached")]
RequestTimeout(std::time::Duration),
#[error("Decompression error: {0}")]
DecompressionError(String),
#[error("Failed to read response body: {0}")]
BodyReadError(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,13 @@
mod chained_reader;
pub mod client;
pub mod cookies;
pub mod decompress;
pub mod dns;
pub mod error;
pub mod manager;
pub mod path_placeholders;
mod proto;
pub mod sender;
pub mod tee_reader;
pub mod transaction;
pub mod types;

View File

@@ -0,0 +1,40 @@
use crate::client::HttpConnectionOptions;
use crate::error::Result;
use log::info;
use reqwest::Client;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
pub struct HttpConnectionManager {
connections: Arc<RwLock<BTreeMap<String, (Client, Instant)>>>,
ttl: Duration,
}
impl HttpConnectionManager {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(BTreeMap::new())),
ttl: Duration::from_secs(10 * 60),
}
}
pub async fn get_client(&self, opt: &HttpConnectionOptions) -> Result<Client> {
let mut connections = self.connections.write().await;
let id = opt.id.clone();
// Clean old connections
connections.retain(|_, (_, last_used)| last_used.elapsed() <= self.ttl);
if let Some((c, last_used)) = connections.get_mut(&id) {
info!("Re-using HTTP client {id}");
*last_used = Instant::now();
return Ok(c.clone());
}
let c = opt.build_client()?;
connections.insert(id.into(), (c.clone(), Instant::now()));
Ok(c)
}
}
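
// A usage sketch; the id value is hypothetical and doubles as the cache key.
// Repeated calls within the 10-minute TTL hand back the same pooled Client.
#[allow(dead_code)]
async fn example_reuse(manager: &HttpConnectionManager) -> Result<()> {
    let opts = HttpConnectionOptions {
        id: "wk_123".to_string(), // hypothetical per-workspace key
        validate_certificates: true,
        proxy: crate::client::HttpConnectionProxySetting::System,
        client_certificate: None,
    };
    let _first = manager.get_client(&opts).await?; // builds a fresh client
    let _second = manager.get_client(&opts).await?; // logs "Re-using HTTP client wk_123"
    Ok(())
}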

View File

@@ -0,0 +1,167 @@
use yaak_models::models::HttpUrlParameter;
pub fn apply_path_placeholders(
    url: &str,
    parameters: &[HttpUrlParameter],
) -> (String, Vec<HttpUrlParameter>) {
let mut new_parameters = Vec::new();
let mut url = url.to_string();
for p in parameters {
if !p.enabled || p.name.is_empty() {
continue;
}
// Replace path parameters with values from URL parameters
let old_url_string = url.clone();
url = replace_path_placeholder(&p, url.as_str());
        // Keep it as a regular parameter only if it didn't modify the URL
if old_url_string == *url {
new_parameters.push(p.to_owned());
}
}
(url, new_parameters)
}
fn replace_path_placeholder(p: &HttpUrlParameter, url: &str) -> String {
if !p.enabled {
return url.to_string();
}
if !p.name.starts_with(":") {
return url.to_string();
}
    // Escape the name so user-supplied regex metacharacters can't cause a panic
    let re =
        regex::Regex::new(format!("(/){}([/?#]|$)", regex::escape(&p.name)).as_str()).unwrap();
let result = re
.replace_all(url, |cap: &regex::Captures| {
format!(
"{}{}{}",
cap[1].to_string(),
urlencoding::encode(p.value.as_str()),
cap[2].to_string()
)
})
.into_owned();
result
}
#[cfg(test)]
mod placeholder_tests {
use crate::path_placeholders::{apply_path_placeholders, replace_path_placeholder};
use yaak_models::models::{HttpRequest, HttpUrlParameter};
#[test]
fn placeholder_middle() {
let p =
HttpUrlParameter { name: ":foo".into(), value: "xxx".into(), enabled: true, id: None };
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:foo/bar"),
"https://example.com/xxx/bar",
);
}
#[test]
fn placeholder_end() {
let p =
HttpUrlParameter { name: ":foo".into(), value: "xxx".into(), enabled: true, id: None };
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:foo"),
"https://example.com/xxx",
);
}
#[test]
fn placeholder_query() {
let p =
HttpUrlParameter { name: ":foo".into(), value: "xxx".into(), enabled: true, id: None };
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:foo?:foo"),
"https://example.com/xxx?:foo",
);
}
#[test]
fn placeholder_missing() {
let p = HttpUrlParameter {
enabled: true,
name: "".to_string(),
value: "".to_string(),
id: None,
};
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:missing"),
"https://example.com/:missing",
);
}
#[test]
fn placeholder_disabled() {
let p = HttpUrlParameter {
enabled: false,
name: ":foo".to_string(),
value: "xxx".to_string(),
id: None,
};
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:foo"),
"https://example.com/:foo",
);
}
#[test]
fn placeholder_prefix() {
let p =
HttpUrlParameter { name: ":foo".into(), value: "xxx".into(), enabled: true, id: None };
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:foooo"),
"https://example.com/:foooo",
);
}
#[test]
fn placeholder_encode() {
let p = HttpUrlParameter {
name: ":foo".into(),
value: "Hello World".into(),
enabled: true,
id: None,
};
assert_eq!(
replace_path_placeholder(&p, "https://example.com/:foo"),
"https://example.com/Hello%20World",
);
}
#[test]
fn apply_placeholder() {
let req = HttpRequest {
url: "example.com/:a/bar".to_string(),
url_parameters: vec![
HttpUrlParameter {
name: "b".to_string(),
value: "bbb".to_string(),
enabled: true,
id: None,
},
HttpUrlParameter {
name: ":a".to_string(),
value: "aaa".to_string(),
enabled: true,
id: None,
},
],
..Default::default()
};
let (url, url_parameters) = apply_path_placeholders(&req.url, &req.url_parameters);
        // The ":a" placeholder was applied and consumed; "b" remains a query param
assert_eq!(url, "example.com/aaa/bar");
assert_eq!(url_parameters.len(), 1);
assert_eq!(url_parameters[0].name, "b");
assert_eq!(url_parameters[0].value, "bbb");
}
}

View File

@@ -0,0 +1,29 @@
use reqwest::Url;
use std::str::FromStr;
pub(crate) fn ensure_proto(url_str: &str) -> String {
if url_str.is_empty() {
return "".to_string();
}
if url_str.starts_with("http://") || url_str.starts_with("https://") {
return url_str.to_string();
}
// Url::from_str will fail without a proto, so add one
let parseable_url = format!("http://{}", url_str);
if let Ok(u) = Url::from_str(parseable_url.as_str()) {
match u.host() {
Some(host) => {
let h = host.to_string();
                    // These TLDs are HTTPS-only (HSTS-preloaded), so force HTTPS
if h.ends_with(".app") || h.ends_with(".dev") || h.ends_with(".page") {
return format!("https://{url_str}");
}
}
None => {}
}
}
format!("http://{url_str}")
}
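
// The mapping at a glance; these cases follow directly from the logic above.
#[cfg(test)]
mod proto_tests {
    use super::*;
    #[test]
    fn ensure_proto_behavior() {
        assert_eq!(ensure_proto(""), "");
        assert_eq!(ensure_proto("example.com"), "http://example.com");
        assert_eq!(ensure_proto("http://example.com"), "http://example.com");
        assert_eq!(ensure_proto("myapp.dev"), "https://myapp.dev"); // HTTPS-only TLD
    }
}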

View File

@@ -0,0 +1,482 @@
use crate::decompress::{ContentEncoding, streaming_decoder};
use crate::error::{Error, Result};
use crate::types::{SendableBody, SendableHttpRequest};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::{Client, Method, Version};
use std::fmt::Display;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf};
use tokio::sync::mpsc;
use tokio_util::io::StreamReader;
#[derive(Debug, Clone)]
pub enum RedirectBehavior {
/// 307/308: Method and body are preserved
Preserve,
/// 303 or 301/302 with POST: Method changed to GET, body dropped
DropBody,
}
#[derive(Debug, Clone)]
pub enum HttpResponseEvent {
Setting(String, String),
Info(String),
Redirect {
url: String,
status: u16,
behavior: RedirectBehavior,
},
SendUrl {
method: String,
path: String,
},
ReceiveUrl {
version: Version,
status: String,
},
HeaderUp(String, String),
HeaderDown(String, String),
ChunkSent {
bytes: usize,
},
ChunkReceived {
bytes: usize,
},
}
impl Display for HttpResponseEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HttpResponseEvent::Setting(name, value) => write!(f, "* Setting {}={}", name, value),
HttpResponseEvent::Info(s) => write!(f, "* {}", s),
HttpResponseEvent::Redirect { url, status, behavior } => {
let behavior_str = match behavior {
RedirectBehavior::Preserve => "preserve",
RedirectBehavior::DropBody => "drop body",
};
write!(f, "* Redirect {} -> {} ({})", status, url, behavior_str)
}
HttpResponseEvent::SendUrl { method, path } => write!(f, "> {} {}", method, path),
HttpResponseEvent::ReceiveUrl { version, status } => {
write!(f, "< {} {}", version_to_str(version), status)
}
HttpResponseEvent::HeaderUp(name, value) => write!(f, "> {}: {}", name, value),
HttpResponseEvent::HeaderDown(name, value) => write!(f, "< {}: {}", name, value),
HttpResponseEvent::ChunkSent { bytes } => write!(f, "> [{} bytes sent]", bytes),
HttpResponseEvent::ChunkReceived { bytes } => write!(f, "< [{} bytes received]", bytes),
}
}
}
impl From<HttpResponseEvent> for yaak_models::models::HttpResponseEventData {
fn from(event: HttpResponseEvent) -> Self {
use yaak_models::models::HttpResponseEventData as D;
match event {
HttpResponseEvent::Setting(name, value) => D::Setting { name, value },
HttpResponseEvent::Info(message) => D::Info { message },
HttpResponseEvent::Redirect { url, status, behavior } => D::Redirect {
url,
status,
behavior: match behavior {
RedirectBehavior::Preserve => "preserve".to_string(),
RedirectBehavior::DropBody => "drop_body".to_string(),
},
},
HttpResponseEvent::SendUrl { method, path } => D::SendUrl { method, path },
HttpResponseEvent::ReceiveUrl { version, status } => {
D::ReceiveUrl { version: format!("{:?}", version), status }
}
HttpResponseEvent::HeaderUp(name, value) => D::HeaderUp { name, value },
HttpResponseEvent::HeaderDown(name, value) => D::HeaderDown { name, value },
HttpResponseEvent::ChunkSent { bytes } => D::ChunkSent { bytes },
HttpResponseEvent::ChunkReceived { bytes } => D::ChunkReceived { bytes },
}
}
}
/// Statistics about the body after consumption
#[derive(Debug, Default, Clone)]
pub struct BodyStats {
/// Size of the body as received over the wire (before decompression)
pub size_compressed: u64,
/// Size of the body after decompression
pub size_decompressed: u64,
}
/// An AsyncRead wrapper that sends chunk events as data is read
pub struct TrackingRead<R> {
inner: R,
event_tx: mpsc::Sender<HttpResponseEvent>,
ended: bool,
}
impl<R> TrackingRead<R> {
pub fn new(inner: R, event_tx: mpsc::Sender<HttpResponseEvent>) -> Self {
Self { inner, event_tx, ended: false }
}
}
impl<R: AsyncRead + Unpin> AsyncRead for TrackingRead<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let before = buf.filled().len();
let result = Pin::new(&mut self.inner).poll_read(cx, buf);
if let Poll::Ready(Ok(())) = &result {
let bytes_read = buf.filled().len() - before;
if bytes_read > 0 {
// Ignore send errors - receiver may have been dropped or channel is full
let _ =
self.event_tx.try_send(HttpResponseEvent::ChunkReceived { bytes: bytes_read });
} else if !self.ended {
self.ended = true;
}
}
result
}
}
/// Type alias for the body stream
type BodyStream = Pin<Box<dyn AsyncRead + Send>>;
/// HTTP response with deferred body consumption.
/// Headers are available immediately after send(), body can be consumed in different ways.
/// Note: Debug is manually implemented since BodyStream doesn't implement Debug.
pub struct HttpResponse {
/// HTTP status code
pub status: u16,
/// HTTP status reason phrase (e.g., "OK", "Not Found")
pub status_reason: Option<String>,
/// Response headers (Vec to support multiple headers with same name, e.g., Set-Cookie)
pub headers: Vec<(String, String)>,
/// Request headers (Vec to support multiple headers with same name)
pub request_headers: Vec<(String, String)>,
/// Content-Length from headers (may differ from actual body size)
pub content_length: Option<u64>,
/// Final URL (after redirects)
pub url: String,
/// Remote address of the server
pub remote_addr: Option<String>,
/// HTTP version (e.g., "HTTP/1.1", "HTTP/2")
pub version: Option<String>,
/// The body stream (consumed when calling bytes(), text(), write_to_file(), or drain())
body_stream: Option<BodyStream>,
/// Content-Encoding for decompression
encoding: ContentEncoding,
}
impl std::fmt::Debug for HttpResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HttpResponse")
.field("status", &self.status)
.field("status_reason", &self.status_reason)
.field("headers", &self.headers)
.field("content_length", &self.content_length)
.field("url", &self.url)
.field("remote_addr", &self.remote_addr)
.field("version", &self.version)
.field("body_stream", &"<stream>")
.field("encoding", &self.encoding)
.finish()
}
}
impl HttpResponse {
/// Create a new HttpResponse with an unconsumed body stream
#[allow(clippy::too_many_arguments)]
pub fn new(
status: u16,
status_reason: Option<String>,
headers: Vec<(String, String)>,
request_headers: Vec<(String, String)>,
content_length: Option<u64>,
url: String,
remote_addr: Option<String>,
version: Option<String>,
body_stream: BodyStream,
encoding: ContentEncoding,
) -> Self {
Self {
status,
status_reason,
headers,
request_headers,
content_length,
url,
remote_addr,
version,
body_stream: Some(body_stream),
encoding,
}
}
/// Consume the body and return it as bytes (loads entire body into memory).
/// Also decompresses the body if Content-Encoding is set.
pub async fn bytes(mut self) -> Result<(Vec<u8>, BodyStats)> {
let stream = self.body_stream.take().ok_or_else(|| {
Error::RequestError("Response body has already been consumed".to_string())
})?;
let buf_reader = BufReader::new(stream);
let mut decoder = streaming_decoder(buf_reader, self.encoding);
let mut decompressed = Vec::new();
let mut bytes_read = 0u64;
// Read through the decoder in chunks to track compressed size
let mut buf = [0u8; 8192];
loop {
match decoder.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
decompressed.extend_from_slice(&buf[..n]);
bytes_read += n as u64;
}
Err(e) => {
return Err(Error::BodyReadError(e.to_string()));
}
}
}
let stats = BodyStats {
            // Compressed size can't be tracked through the streaming decoder,
            // so use Content-Length when present, falling back to the
            // decompressed byte count
size_compressed: self.content_length.unwrap_or(bytes_read),
size_decompressed: decompressed.len() as u64,
};
Ok((decompressed, stats))
}
/// Consume the body and return it as a UTF-8 string.
pub async fn text(self) -> Result<(String, BodyStats)> {
let (bytes, stats) = self.bytes().await?;
let text = String::from_utf8(bytes)
.map_err(|e| Error::RequestError(format!("Response is not valid UTF-8: {}", e)))?;
Ok((text, stats))
}
/// Take the body stream for manual consumption.
/// Returns an AsyncRead that decompresses on-the-fly if Content-Encoding is set.
/// The caller is responsible for reading and processing the stream.
pub fn into_body_stream(&mut self) -> Result<Box<dyn AsyncRead + Unpin + Send>> {
let stream = self.body_stream.take().ok_or_else(|| {
Error::RequestError("Response body has already been consumed".to_string())
})?;
let buf_reader = BufReader::new(stream);
let decoder = streaming_decoder(buf_reader, self.encoding);
Ok(decoder)
}
/// Discard the body without reading it (useful for redirects).
pub async fn drain(mut self) -> Result<()> {
let stream = self.body_stream.take().ok_or_else(|| {
Error::RequestError("Response body has already been consumed".to_string())
})?;
// Just read and discard all bytes
let mut reader = stream;
let mut buf = [0u8; 8192];
loop {
match reader.read(&mut buf).await {
Ok(0) => break,
Ok(_) => continue,
Err(e) => {
return Err(Error::RequestError(format!(
"Failed to drain response body: {}",
e
)));
}
}
}
Ok(())
}
}
/// Trait for sending HTTP requests
#[async_trait]
pub trait HttpSender: Send + Sync {
/// Send an HTTP request and return the response with headers.
/// The body is not consumed until you call bytes(), text(), write_to_file(), or drain().
/// Events are sent through the provided channel.
async fn send(
&self,
request: SendableHttpRequest,
event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse>;
}
/// Reqwest-based implementation of HttpSender
pub struct ReqwestSender {
client: Client,
}
impl ReqwestSender {
/// Create a new ReqwestSender with a default client
pub fn new() -> Result<Self> {
let client = Client::builder().build().map_err(Error::Client)?;
Ok(Self { client })
}
/// Create a new ReqwestSender with a custom client
pub fn with_client(client: Client) -> Self {
Self { client }
}
}
#[async_trait]
impl HttpSender for ReqwestSender {
async fn send(
&self,
request: SendableHttpRequest,
event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
// Helper to send events (ignores errors if receiver is dropped or channel is full)
let send_event = |event: HttpResponseEvent| {
let _ = event_tx.try_send(event);
};
// Parse the HTTP method
let method = Method::from_bytes(request.method.as_bytes())
.map_err(|e| Error::RequestError(format!("Invalid HTTP method: {}", e)))?;
// Build the request
let mut req_builder = self.client.request(method, &request.url);
// Add headers
for header in request.headers {
req_builder = req_builder.header(&header.0, &header.1);
}
// Configure timeout
if let Some(d) = request.options.timeout
&& !d.is_zero()
{
req_builder = req_builder.timeout(d);
}
// Add body
match request.body {
None => {}
Some(SendableBody::Bytes(bytes)) => {
req_builder = req_builder.body(bytes);
}
Some(SendableBody::Stream(stream)) => {
// Convert AsyncRead stream to reqwest Body
let stream = tokio_util::io::ReaderStream::new(stream);
let body = reqwest::Body::wrap_stream(stream);
req_builder = req_builder.body(body);
}
}
// Send the request
let sendable_req = req_builder.build()?;
send_event(HttpResponseEvent::Setting(
"timeout".to_string(),
if request.options.timeout.unwrap_or_default().is_zero() {
"Infinity".to_string()
} else {
format!("{:?}", request.options.timeout)
},
));
send_event(HttpResponseEvent::SendUrl {
path: sendable_req.url().path().to_string(),
method: sendable_req.method().to_string(),
});
let mut request_headers = Vec::new();
for (name, value) in sendable_req.headers() {
let v = value.to_str().unwrap_or_default().to_string();
request_headers.push((name.to_string(), v.clone()));
send_event(HttpResponseEvent::HeaderUp(name.to_string(), v));
}
send_event(HttpResponseEvent::Info("Sending request to server".to_string()));
// Map some errors to our own, so they look nicer
let response = self.client.execute(sendable_req).await.map_err(|e| {
if reqwest::Error::is_timeout(&e) {
Error::RequestTimeout(
request.options.timeout.unwrap_or(Duration::from_secs(0)).clone(),
)
} else {
Error::Client(e)
}
})?;
let status = response.status().as_u16();
let status_reason = response.status().canonical_reason().map(|s| s.to_string());
let url = response.url().to_string();
let remote_addr = response.remote_addr().map(|a| a.to_string());
let version = Some(version_to_str(&response.version()));
let content_length = response.content_length();
send_event(HttpResponseEvent::ReceiveUrl {
version: response.version(),
status: response.status().to_string(),
});
// Extract headers (use Vec to preserve duplicates like Set-Cookie)
let mut headers = Vec::new();
for (key, value) in response.headers() {
if let Ok(v) = value.to_str() {
send_event(HttpResponseEvent::HeaderDown(key.to_string(), v.to_string()));
headers.push((key.to_string(), v.to_string()));
}
}
// Determine content encoding for decompression
// HTTP headers are case-insensitive, so we need to search for any casing
let encoding = ContentEncoding::from_header(
headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-encoding"))
.map(|(_, v)| v.as_str()),
);
// Get the byte stream instead of loading into memory
let byte_stream = response.bytes_stream();
// Convert the stream to an AsyncRead
let stream_reader = StreamReader::new(
byte_stream.map(|result| result.map_err(|e| std::io::Error::other(e))),
);
// Wrap the stream with tracking to emit chunk received events via the same channel
let tracking_reader = TrackingRead::new(stream_reader, event_tx);
let body_stream: BodyStream = Box::pin(tracking_reader);
Ok(HttpResponse::new(
status,
status_reason,
headers,
request_headers,
content_length,
url,
remote_addr,
version,
body_stream,
encoding,
))
}
}
fn version_to_str(version: &Version) -> String {
match *version {
Version::HTTP_09 => "HTTP/0.9".to_string(),
Version::HTTP_10 => "HTTP/1.0".to_string(),
Version::HTTP_11 => "HTTP/1.1".to_string(),
Version::HTTP_2 => "HTTP/2".to_string(),
Version::HTTP_3 => "HTTP/3".to_string(),
_ => "unknown".to_string(),
}
}
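
// Putting the pieces together: a minimal send sketch. Field values are
// illustrative; SendableHttpRequest implements Default (the transaction tests
// rely on this too).
#[allow(dead_code)]
async fn example_send() -> Result<()> {
    let sender = ReqwestSender::new()?;
    let (event_tx, mut event_rx) = mpsc::channel(100);
    let request = SendableHttpRequest {
        url: "https://example.com".to_string(),
        method: "GET".to_string(),
        ..Default::default()
    };
    let response = sender.send(request, event_tx).await?;
    // Headers are available immediately; the body is consumed on demand
    let (text, stats) = response.text().await?;
    println!("{} chars, {} bytes decompressed", text.len(), stats.size_decompressed);
    // Drain whatever wire events were emitted along the way
    while let Ok(event) = event_rx.try_recv() {
        println!("{event}"); // uses the Display impl above
    }
    Ok(())
}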

View File

@@ -0,0 +1,159 @@
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio::sync::mpsc;
/// A reader that forwards all read data to a channel while also returning it to the caller.
/// This allows capturing request body data as it's being sent.
/// Uses an unbounded channel to ensure all data is captured without blocking the request.
pub struct TeeReader<R> {
inner: R,
tx: mpsc::UnboundedSender<Vec<u8>>,
}
impl<R> TeeReader<R> {
pub fn new(inner: R, tx: mpsc::UnboundedSender<Vec<u8>>) -> Self {
Self { inner, tx }
}
}
impl<R: AsyncRead + Unpin> AsyncRead for TeeReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let before_len = buf.filled().len();
match Pin::new(&mut self.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let after_len = buf.filled().len();
if after_len > before_len {
// Data was read, send a copy to the channel
let data = buf.filled()[before_len..after_len].to_vec();
// Send to unbounded channel - this never blocks
// Ignore error if receiver is closed
let _ = self.tx.send(data);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
use tokio::io::AsyncReadExt;
#[tokio::test]
async fn test_tee_reader_captures_all_data() {
let data = b"Hello, World!";
let cursor = Cursor::new(data.to_vec());
let (tx, mut rx) = mpsc::unbounded_channel();
let mut tee = TeeReader::new(cursor, tx);
let mut output = Vec::new();
tee.read_to_end(&mut output).await.unwrap();
// Verify the reader returns the correct data
assert_eq!(output, data);
// Verify the channel received the data
let mut captured = Vec::new();
while let Ok(chunk) = rx.try_recv() {
captured.extend(chunk);
}
assert_eq!(captured, data);
}
#[tokio::test]
async fn test_tee_reader_with_chunked_reads() {
let data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
let cursor = Cursor::new(data.to_vec());
let (tx, mut rx) = mpsc::unbounded_channel();
let mut tee = TeeReader::new(cursor, tx);
// Read in small chunks
let mut buf = [0u8; 5];
let mut output = Vec::new();
loop {
let n = tee.read(&mut buf).await.unwrap();
if n == 0 {
break;
}
output.extend_from_slice(&buf[..n]);
}
// Verify the reader returns the correct data
assert_eq!(output, data);
// Verify the channel received all chunks
let mut captured = Vec::new();
while let Ok(chunk) = rx.try_recv() {
captured.extend(chunk);
}
assert_eq!(captured, data);
}
#[tokio::test]
async fn test_tee_reader_empty_data() {
let data: Vec<u8> = vec![];
let cursor = Cursor::new(data.clone());
let (tx, mut rx) = mpsc::unbounded_channel();
let mut tee = TeeReader::new(cursor, tx);
let mut output = Vec::new();
tee.read_to_end(&mut output).await.unwrap();
// Verify empty output
assert!(output.is_empty());
// Verify no data was sent to channel
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn test_tee_reader_works_when_receiver_dropped() {
let data = b"Hello, World!";
let cursor = Cursor::new(data.to_vec());
let (tx, rx) = mpsc::unbounded_channel();
// Drop the receiver before reading
drop(rx);
let mut tee = TeeReader::new(cursor, tx);
let mut output = Vec::new();
// Should still work even though receiver is dropped
tee.read_to_end(&mut output).await.unwrap();
assert_eq!(output, data);
}
#[tokio::test]
async fn test_tee_reader_large_data() {
// Test with 1MB of data
let data: Vec<u8> = (0..1024 * 1024).map(|i| (i % 256) as u8).collect();
let cursor = Cursor::new(data.clone());
let (tx, mut rx) = mpsc::unbounded_channel();
let mut tee = TeeReader::new(cursor, tx);
let mut output = Vec::new();
tee.read_to_end(&mut output).await.unwrap();
// Verify the reader returns the correct data
assert_eq!(output, data);
// Verify the channel received all data
let mut captured = Vec::new();
while let Ok(chunk) = rx.try_recv() {
captured.extend(chunk);
}
assert_eq!(captured, data);
}
}
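
// Beyond the synchronous try_recv pattern above, a sketch of the intended
// production shape: mirror a request body to a background collector while it
// streams out. The file path is hypothetical.
#[allow(dead_code)]
async fn example_capture_body() -> io::Result<Vec<u8>> {
    use tokio::io::AsyncReadExt;
    let (tx, mut rx) = mpsc::unbounded_channel();
    // The collector finishes once all senders are dropped
    let collector = tokio::spawn(async move {
        let mut captured = Vec::new();
        while let Some(chunk) = rx.recv().await {
            captured.extend(chunk);
        }
        captured
    });
    let file = tokio::fs::File::open("/tmp/body.json").await?; // hypothetical
    let mut tee = TeeReader::new(file, tx);
    let mut sent = Vec::new();
    tee.read_to_end(&mut sent).await?; // stands in for streaming to the server
    drop(tee); // drops tx, closing the channel so the collector can finish
    Ok(collector.await.expect("collector task panicked"))
}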

View File

@@ -0,0 +1,723 @@
use crate::cookies::CookieStore;
use crate::error::Result;
use crate::sender::{HttpResponse, HttpResponseEvent, HttpSender, RedirectBehavior};
use crate::types::SendableHttpRequest;
use log::debug;
use tokio::sync::mpsc;
use tokio::sync::watch::Receiver;
use url::Url;
/// HTTP Transaction that manages the lifecycle of a request, including redirect handling
pub struct HttpTransaction<S: HttpSender> {
sender: S,
max_redirects: usize,
cookie_store: Option<CookieStore>,
}
impl<S: HttpSender> HttpTransaction<S> {
/// Create a new transaction with default settings
pub fn new(sender: S) -> Self {
Self { sender, max_redirects: 10, cookie_store: None }
}
/// Create a new transaction with custom max redirects
pub fn with_max_redirects(sender: S, max_redirects: usize) -> Self {
Self { sender, max_redirects, cookie_store: None }
}
/// Create a new transaction with a cookie store
pub fn with_cookie_store(sender: S, cookie_store: CookieStore) -> Self {
Self { sender, max_redirects: 10, cookie_store: Some(cookie_store) }
}
/// Create a new transaction with custom max redirects and a cookie store
pub fn with_options(
sender: S,
max_redirects: usize,
cookie_store: Option<CookieStore>,
) -> Self {
Self { sender, max_redirects, cookie_store }
}
/// Execute the request with cancellation support.
/// Returns an HttpResponse with unconsumed body - caller decides how to consume it.
/// Events are sent through the provided channel.
pub async fn execute_with_cancellation(
&self,
request: SendableHttpRequest,
mut cancelled_rx: Receiver<bool>,
event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
let mut redirect_count = 0;
let mut current_url = request.url;
let mut current_method = request.method;
let mut current_headers = request.headers;
let mut current_body = request.body;
// Helper to send events (ignores errors if receiver is dropped or channel is full)
let send_event = |event: HttpResponseEvent| {
let _ = event_tx.try_send(event);
};
loop {
// Check for cancellation before each request
if *cancelled_rx.borrow() {
return Err(crate::error::Error::RequestCanceledError);
}
// Inject cookies into headers if we have a cookie store
let headers_with_cookies = if let Some(cookie_store) = &self.cookie_store {
let mut headers = current_headers.clone();
if let Ok(url) = Url::parse(&current_url) {
if let Some(cookie_header) = cookie_store.get_cookie_header(&url) {
debug!("Injecting Cookie header: {}", cookie_header);
// Check if there's already a Cookie header and merge if so
if let Some(existing) =
headers.iter_mut().find(|h| h.0.eq_ignore_ascii_case("cookie"))
{
existing.1 = format!("{}; {}", existing.1, cookie_header);
} else {
headers.push(("Cookie".to_string(), cookie_header));
}
}
}
headers
} else {
current_headers.clone()
};
// Build request for this iteration
let req = SendableHttpRequest {
url: current_url.clone(),
method: current_method.clone(),
headers: headers_with_cookies,
body: current_body,
options: request.options.clone(),
};
// Send the request
send_event(HttpResponseEvent::Setting(
"redirects".to_string(),
request.options.follow_redirects.to_string(),
));
// Execute with cancellation support
let response = tokio::select! {
result = self.sender.send(req, event_tx.clone()) => result?,
_ = cancelled_rx.changed() => {
return Err(crate::error::Error::RequestCanceledError);
}
};
// Parse Set-Cookie headers and store cookies
if let Some(cookie_store) = &self.cookie_store {
if let Ok(url) = Url::parse(&current_url) {
let set_cookie_headers: Vec<String> = response
.headers
.iter()
.filter(|(k, _)| k.eq_ignore_ascii_case("set-cookie"))
.map(|(_, v)| v.clone())
.collect();
if !set_cookie_headers.is_empty() {
debug!("Storing {} cookies from response", set_cookie_headers.len());
cookie_store.store_cookies_from_response(&url, &set_cookie_headers);
}
}
}
if !Self::is_redirect(response.status) {
// Not a redirect - return the response for caller to consume body
return Ok(response);
}
if !request.options.follow_redirects {
// Redirects disabled - return the redirect response as-is
return Ok(response);
}
// Check if we've exceeded max redirects
if redirect_count >= self.max_redirects {
// Drain the response before returning error
let _ = response.drain().await;
return Err(crate::error::Error::RequestError(format!(
"Maximum redirect limit ({}) exceeded",
self.max_redirects
)));
}
// Extract Location header before draining (headers are available immediately)
// HTTP headers are case-insensitive, so we need to search for any casing
let location = response
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("location"))
.map(|(_, v)| v.clone())
.ok_or_else(|| {
crate::error::Error::RequestError(
"Redirect response missing Location header".to_string(),
)
})?;
// Also get status before draining
let status = response.status;
send_event(HttpResponseEvent::Info("Ignoring the response body".to_string()));
// Drain the redirect response body before following
response.drain().await?;
// Update the request URL
current_url = if location.starts_with("http://") || location.starts_with("https://") {
// Absolute URL
location
} else if location.starts_with('/') {
// Absolute path - need to extract base URL from current request
let base_url = Self::extract_base_url(&current_url)?;
format!("{}{}", base_url, location)
} else {
// Relative path - need to resolve relative to current path
let base_path = Self::extract_base_path(&current_url)?;
format!("{}/{}", base_path, location)
};
// Determine redirect behavior based on status code and method
let behavior = if status == 303 {
// 303 See Other always changes to GET
RedirectBehavior::DropBody
} else if (status == 301 || status == 302) && current_method == "POST" {
// For 301/302, change POST to GET (common browser behavior)
RedirectBehavior::DropBody
} else {
// For 307 and 308, the method and body are preserved
// Also for 301/302 with non-POST methods
RedirectBehavior::Preserve
};
send_event(HttpResponseEvent::Redirect {
url: current_url.clone(),
status,
behavior: behavior.clone(),
});
// Handle method changes for certain redirect codes
if matches!(behavior, RedirectBehavior::DropBody) {
if current_method != "GET" {
current_method = "GET".to_string();
}
// Remove content-related headers
current_headers.retain(|h| {
let name_lower = h.0.to_lowercase();
!name_lower.starts_with("content-") && name_lower != "transfer-encoding"
});
}
            // Reset the body for the next iteration: it was moved into the
            // send call above, so it can't be replayed even when a 307/308
            // would otherwise preserve it
current_body = None;
redirect_count += 1;
}
}
/// Check if a status code indicates a redirect
fn is_redirect(status: u16) -> bool {
matches!(status, 301 | 302 | 303 | 307 | 308)
}
/// Extract the base URL (scheme + host) from a full URL
fn extract_base_url(url: &str) -> Result<String> {
// Find the position after "://"
let scheme_end = url.find("://").ok_or_else(|| {
crate::error::Error::RequestError(format!("Invalid URL format: {}", url))
})?;
// Find the first '/' after the scheme
let path_start = url[scheme_end + 3..].find('/');
if let Some(idx) = path_start {
Ok(url[..scheme_end + 3 + idx].to_string())
} else {
// No path, return entire URL
Ok(url.to_string())
}
}
    /// Extract the base path (everything except the last segment) from a URL
    fn extract_base_path(url: &str) -> Result<String> {
        if let Some(last_slash) = url.rfind('/') {
            // Don't treat the scheme's "//" as a path separator (e.g. the last
            // slash of "https://example.com" belongs to the scheme, not a path)
            if url[..last_slash].ends_with('/') || url[..last_slash].ends_with(':') {
                Ok(url.to_string())
            } else {
                Ok(url[..last_slash].to_string())
            }
        } else {
            Ok(url.to_string())
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::decompress::ContentEncoding;
use crate::sender::{HttpResponseEvent, HttpSender};
use async_trait::async_trait;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::sync::Mutex;
/// Mock sender for testing
struct MockSender {
responses: Arc<Mutex<Vec<MockResponse>>>,
}
struct MockResponse {
status: u16,
headers: Vec<(String, String)>,
body: Vec<u8>,
}
impl MockSender {
fn new(responses: Vec<MockResponse>) -> Self {
Self { responses: Arc::new(Mutex::new(responses)) }
}
}
#[async_trait]
impl HttpSender for MockSender {
async fn send(
&self,
_request: SendableHttpRequest,
_event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
let mut responses = self.responses.lock().await;
if responses.is_empty() {
Err(crate::error::Error::RequestError("No more mock responses".to_string()))
} else {
let mock = responses.remove(0);
// Create a simple in-memory stream from the body
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
Box::pin(std::io::Cursor::new(mock.body));
Ok(HttpResponse::new(
mock.status,
None, // status_reason
mock.headers,
Vec::new(),
None, // content_length
"https://example.com".to_string(), // url
None, // remote_addr
Some("HTTP/1.1".to_string()), // version
body_stream,
ContentEncoding::Identity,
))
}
}
}
#[tokio::test]
async fn test_transaction_no_redirect() {
let response = MockResponse { status: 200, headers: Vec::new(), body: b"OK".to_vec() };
let sender = MockSender::new(vec![response]);
let transaction = HttpTransaction::new(sender);
let request = SendableHttpRequest {
url: "https://example.com".to_string(),
method: "GET".to_string(),
headers: vec![],
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await.unwrap();
assert_eq!(result.status, 200);
// Consume the body to verify it
let (body, _) = result.bytes().await.unwrap();
assert_eq!(body, b"OK");
}
#[tokio::test]
async fn test_transaction_single_redirect() {
let redirect_headers = vec![("Location".to_string(), "https://example.com/new".to_string())];
let responses = vec![
MockResponse { status: 302, headers: redirect_headers, body: vec![] },
MockResponse { status: 200, headers: Vec::new(), body: b"Final".to_vec() },
];
let sender = MockSender::new(responses);
let transaction = HttpTransaction::new(sender);
let request = SendableHttpRequest {
url: "https://example.com/old".to_string(),
method: "GET".to_string(),
options: crate::types::SendableHttpRequestOptions {
follow_redirects: true,
..Default::default()
},
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await.unwrap();
assert_eq!(result.status, 200);
let (body, _) = result.bytes().await.unwrap();
assert_eq!(body, b"Final");
}
#[tokio::test]
async fn test_transaction_max_redirects_exceeded() {
let redirect_headers = vec![("Location".to_string(), "https://example.com/loop".to_string())];
// Create more redirects than allowed
let responses: Vec<MockResponse> = (0..12)
.map(|_| MockResponse { status: 302, headers: redirect_headers.clone(), body: vec![] })
.collect();
let sender = MockSender::new(responses);
let transaction = HttpTransaction::with_max_redirects(sender, 10);
let request = SendableHttpRequest {
url: "https://example.com/start".to_string(),
method: "GET".to_string(),
options: crate::types::SendableHttpRequestOptions {
follow_redirects: true,
..Default::default()
},
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await;
if let Err(crate::error::Error::RequestError(msg)) = result {
assert!(msg.contains("Maximum redirect limit"));
} else {
panic!("Expected RequestError with max redirect message. Got {result:?}");
}
}
#[test]
fn test_is_redirect() {
assert!(HttpTransaction::<MockSender>::is_redirect(301));
assert!(HttpTransaction::<MockSender>::is_redirect(302));
assert!(HttpTransaction::<MockSender>::is_redirect(303));
assert!(HttpTransaction::<MockSender>::is_redirect(307));
assert!(HttpTransaction::<MockSender>::is_redirect(308));
assert!(!HttpTransaction::<MockSender>::is_redirect(200));
assert!(!HttpTransaction::<MockSender>::is_redirect(404));
assert!(!HttpTransaction::<MockSender>::is_redirect(500));
}
#[test]
fn test_extract_base_url() {
let result =
HttpTransaction::<MockSender>::extract_base_url("https://example.com/path/to/resource");
assert_eq!(result.unwrap(), "https://example.com");
let result = HttpTransaction::<MockSender>::extract_base_url("http://localhost:8080/api");
assert_eq!(result.unwrap(), "http://localhost:8080");
let result = HttpTransaction::<MockSender>::extract_base_url("invalid-url");
assert!(result.is_err());
}
#[test]
fn test_extract_base_path() {
let result = HttpTransaction::<MockSender>::extract_base_path(
"https://example.com/path/to/resource",
);
assert_eq!(result.unwrap(), "https://example.com/path/to");
let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/single");
assert_eq!(result.unwrap(), "https://example.com");
let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/");
assert_eq!(result.unwrap(), "https://example.com");
}
#[tokio::test]
async fn test_cookie_injection() {
// Create a mock sender that verifies the Cookie header was injected
struct CookieVerifyingSender {
expected_cookie: String,
}
#[async_trait]
impl HttpSender for CookieVerifyingSender {
async fn send(
&self,
request: SendableHttpRequest,
_event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
// Verify the Cookie header was injected
let cookie_header =
request.headers.iter().find(|(k, _)| k.eq_ignore_ascii_case("cookie"));
assert!(cookie_header.is_some(), "Cookie header should be present");
assert!(
cookie_header.unwrap().1.contains(&self.expected_cookie),
"Cookie header should contain expected value"
);
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
Box::pin(std::io::Cursor::new(vec![]));
Ok(HttpResponse::new(
200,
None,
Vec::new(),
Vec::new(),
None,
"https://example.com".to_string(),
None,
Some("HTTP/1.1".to_string()),
body_stream,
ContentEncoding::Identity,
))
}
}
use yaak_models::models::{Cookie, CookieDomain, CookieExpires};
// Create a cookie store with a test cookie
let cookie = Cookie {
raw_cookie: "session=abc123".to_string(),
domain: CookieDomain::HostOnly("example.com".to_string()),
expires: CookieExpires::SessionEnd,
path: ("/".to_string(), false),
};
let cookie_store = CookieStore::from_cookies(vec![cookie]);
let sender = CookieVerifyingSender { expected_cookie: "session=abc123".to_string() };
let transaction = HttpTransaction::with_cookie_store(sender, cookie_store);
let request = SendableHttpRequest {
url: "https://example.com/api".to_string(),
method: "GET".to_string(),
headers: vec![],
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_set_cookie_parsing() {
// Create a cookie store
let cookie_store = CookieStore::new();
// Mock sender that returns a Set-Cookie header
struct SetCookieSender;
#[async_trait]
impl HttpSender for SetCookieSender {
async fn send(
&self,
_request: SendableHttpRequest,
_event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
let headers = vec![("set-cookie".to_string(), "session=xyz789; Path=/".to_string())];
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
Box::pin(std::io::Cursor::new(vec![]));
Ok(HttpResponse::new(
200,
None,
headers,
Vec::new(),
None,
"https://example.com".to_string(),
None,
Some("HTTP/1.1".to_string()),
body_stream,
ContentEncoding::Identity,
))
}
}
let sender = SetCookieSender;
let transaction = HttpTransaction::with_cookie_store(sender, cookie_store.clone());
let request = SendableHttpRequest {
url: "https://example.com/login".to_string(),
method: "POST".to_string(),
headers: vec![],
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await;
assert!(result.is_ok());
// Verify the cookie was stored
let cookies = cookie_store.get_all_cookies();
assert_eq!(cookies.len(), 1);
assert!(cookies[0].raw_cookie.contains("session=xyz789"));
}
#[tokio::test]
async fn test_multiple_set_cookie_headers() {
// Create a cookie store
let cookie_store = CookieStore::new();
// Mock sender that returns multiple Set-Cookie headers
struct MultiSetCookieSender;
#[async_trait]
impl HttpSender for MultiSetCookieSender {
async fn send(
&self,
_request: SendableHttpRequest,
_event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
// Multiple Set-Cookie headers (this is standard HTTP behavior)
let headers = vec![
("set-cookie".to_string(), "session=abc123; Path=/".to_string()),
("set-cookie".to_string(), "user_id=42; Path=/".to_string()),
("set-cookie".to_string(), "preferences=dark; Path=/; Max-Age=86400".to_string()),
];
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
Box::pin(std::io::Cursor::new(vec![]));
Ok(HttpResponse::new(
200,
None,
headers,
Vec::new(),
None,
"https://example.com".to_string(),
None,
Some("HTTP/1.1".to_string()),
body_stream,
ContentEncoding::Identity,
))
}
}
let sender = MultiSetCookieSender;
let transaction = HttpTransaction::with_cookie_store(sender, cookie_store.clone());
let request = SendableHttpRequest {
url: "https://example.com/login".to_string(),
method: "POST".to_string(),
headers: vec![],
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await;
assert!(result.is_ok());
// Verify all three cookies were stored
let cookies = cookie_store.get_all_cookies();
assert_eq!(cookies.len(), 3, "All three Set-Cookie headers should be parsed and stored");
let cookie_values: Vec<&str> = cookies.iter().map(|c| c.raw_cookie.as_str()).collect();
assert!(
cookie_values.iter().any(|c| c.contains("session=abc123")),
"session cookie should be stored"
);
assert!(
cookie_values.iter().any(|c| c.contains("user_id=42")),
"user_id cookie should be stored"
);
assert!(
cookie_values.iter().any(|c| c.contains("preferences=dark")),
"preferences cookie should be stored"
);
}
#[tokio::test]
async fn test_cookies_across_redirects() {
use std::sync::atomic::{AtomicUsize, Ordering};
// Create a cookie store
let cookie_store = CookieStore::new();
// Track request count
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
struct RedirectWithCookiesSender {
request_count: Arc<AtomicUsize>,
}
#[async_trait]
impl HttpSender for RedirectWithCookiesSender {
async fn send(
&self,
request: SendableHttpRequest,
_event_tx: mpsc::Sender<HttpResponseEvent>,
) -> Result<HttpResponse> {
let count = self.request_count.fetch_add(1, Ordering::SeqCst);
let (status, headers) = if count == 0 {
// First request: return redirect with Set-Cookie
let h = vec![
("location".to_string(), "https://example.com/final".to_string()),
("set-cookie".to_string(), "redirect_cookie=value1".to_string()),
];
(302, h)
} else {
// Second request: verify cookie was sent
let cookie_header =
request.headers.iter().find(|(k, _)| k.eq_ignore_ascii_case("cookie"));
assert!(cookie_header.is_some(), "Cookie header should be present on redirect");
assert!(
cookie_header.unwrap().1.contains("redirect_cookie=value1"),
"Redirect cookie should be included"
);
(200, Vec::new())
};
let body_stream: Pin<Box<dyn AsyncRead + Send>> =
Box::pin(std::io::Cursor::new(vec![]));
Ok(HttpResponse::new(
status,
None,
headers,
Vec::new(),
None,
"https://example.com".to_string(),
None,
Some("HTTP/1.1".to_string()),
body_stream,
ContentEncoding::Identity,
))
}
}
let sender = RedirectWithCookiesSender { request_count: request_count_clone };
let transaction = HttpTransaction::with_cookie_store(sender, cookie_store);
let request = SendableHttpRequest {
url: "https://example.com/start".to_string(),
method: "GET".to_string(),
headers: vec![],
options: crate::types::SendableHttpRequestOptions {
follow_redirects: true,
..Default::default()
},
..Default::default()
};
let (_tx, rx) = tokio::sync::watch::channel(false);
let (event_tx, _event_rx) = mpsc::channel(100);
let result = transaction.execute_with_cancellation(request, rx, event_tx).await;
assert!(result.is_ok());
assert_eq!(request_count.load(Ordering::SeqCst), 2);
}
}

View File

@@ -0,0 +1,981 @@
use crate::chained_reader::{ChainedReader, ReaderType};
use crate::error::Error::RequestError;
use crate::error::Result;
use crate::path_placeholders::apply_path_placeholders;
use crate::proto::ensure_proto;
use bytes::Bytes;
use log::warn;
use std::collections::BTreeMap;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::AsyncRead;
use yaak_common::serde::{get_bool, get_str, get_str_map};
use yaak_models::models::HttpRequest;
pub(crate) const MULTIPART_BOUNDARY: &str = "------YaakFormBoundary";
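// Note: multipart delimiters prefix the boundary with "--" (RFC 2046 §5.1.1),
// so on the wire each part starts with "--------YaakFormBoundary" (8 dashes).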
pub enum SendableBody {
Bytes(Bytes),
Stream(Pin<Box<dyn AsyncRead + Send + 'static>>),
}
enum SendableBodyWithMeta {
Bytes(Bytes),
Stream {
data: Pin<Box<dyn AsyncRead + Send + 'static>>,
content_length: Option<usize>,
},
}
impl From<SendableBodyWithMeta> for SendableBody {
fn from(value: SendableBodyWithMeta) -> Self {
match value {
SendableBodyWithMeta::Bytes(b) => SendableBody::Bytes(b),
SendableBodyWithMeta::Stream { data, .. } => SendableBody::Stream(data),
}
}
}
#[derive(Default)]
pub struct SendableHttpRequest {
pub url: String,
pub method: String,
pub headers: Vec<(String, String)>,
pub body: Option<SendableBody>,
pub options: SendableHttpRequestOptions,
}
#[derive(Default, Clone)]
pub struct SendableHttpRequestOptions {
pub timeout: Option<Duration>,
pub follow_redirects: bool,
}
impl SendableHttpRequest {
pub async fn from_http_request(
r: &HttpRequest,
options: SendableHttpRequestOptions,
) -> Result<Self> {
let initial_headers = build_headers(r);
let (body, headers) = build_body(&r.method, &r.body_type, &r.body, initial_headers).await?;
Ok(Self {
url: build_url(r),
method: r.method.to_uppercase(),
headers,
body: body.into(),
options,
})
}
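/// Insert a header, replacing any existing header whose name matches
/// case-insensitively (e.g. inserting "content-type" replaces "Content-Type").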
pub fn insert_header(&mut self, header: (String, String)) {
if let Some(existing) =
self.headers.iter_mut().find(|h| h.0.to_lowercase() == header.0.to_lowercase())
{
existing.1 = header.1;
} else {
self.headers.push(header);
}
}
}
pub fn append_query_params(url: &str, params: Vec<(String, String)>) -> String {
if params.is_empty() {
return url.to_string();
}
let url_string = url.to_string();
// Build query string
let query_string = params
.iter()
.map(|(name, value)| {
format!("{}={}", urlencoding::encode(name), urlencoding::encode(value))
})
.collect::<Vec<_>>()
.join("&");
// Split URL into parts: base URL, query, and fragment
let (base_and_query, fragment) = if let Some(hash_pos) = url_string.find('#') {
let (before_hash, after_hash) = url_string.split_at(hash_pos);
(before_hash.to_string(), Some(after_hash.to_string()))
} else {
(url_string, None)
};
// Now handle query parameters on the base URL (without fragment)
let mut result = if base_and_query.contains('?') {
// Check if there's already a query string after the '?'
let parts: Vec<&str> = base_and_query.splitn(2, '?').collect();
if parts.len() == 2 && !parts[1].trim().is_empty() {
// Append with & if there are existing parameters
format!("{}&{}", base_and_query, query_string)
} else {
// Just append the new parameters directly (URL ends with '?')
format!("{}{}", base_and_query, query_string)
}
} else {
// No existing query parameters, add with '?'
format!("{}?{}", base_and_query, query_string)
};
// Re-append the fragment if it exists
if let Some(fragment) = fragment {
result.push_str(&fragment);
}
result
}
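// Examples (sketch, not from the source):
//   append_query_params("https://x.test/a", vec![("q".into(), "1".into())])
//     -> "https://x.test/a?q=1"
//   append_query_params("https://x.test/a?b=2#frag", vec![("q".into(), "1".into())])
//     -> "https://x.test/a?b=2&q=1#frag"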
fn build_url(r: &HttpRequest) -> String {
let (url_string, params) = apply_path_placeholders(&ensure_proto(&r.url), &r.url_parameters);
append_query_params(
&url_string,
params
.iter()
.filter(|p| p.enabled && !p.name.is_empty())
.map(|p| (p.name.clone(), p.value.clone()))
.collect(),
)
}
fn build_headers(r: &HttpRequest) -> Vec<(String, String)> {
r.headers
.iter()
.filter_map(|h| {
if h.enabled && !h.name.is_empty() {
Some((h.name.clone(), h.value.clone()))
} else {
None
}
})
.collect()
}
async fn build_body(
method: &str,
body_type: &Option<String>,
body: &BTreeMap<String, serde_json::Value>,
headers: Vec<(String, String)>,
) -> Result<(Option<SendableBody>, Vec<(String, String)>)> {
let body_type = match &body_type {
None => return Ok((None, headers)),
Some(t) => t,
};
let (body, content_type) = match body_type.as_str() {
"binary" => (build_binary_body(body).await?, None),
"graphql" => (build_graphql_body(method, body), Some("application/json".to_string())),
"application/x-www-form-urlencoded" => {
(build_form_body(body), Some("application/x-www-form-urlencoded".to_string()))
}
"multipart/form-data" => build_multipart_body(body, &headers).await?,
_ if body.contains_key("text") => (build_text_body(body), None),
t => {
warn!("Unsupported body type: {}", t);
(None, None)
}
};
// Add or update the Content-Type header
let mut headers = headers;
if let Some(ct) = content_type {
if let Some(existing) = headers.iter_mut().find(|h| h.0.to_lowercase() == "content-type") {
existing.1 = ct;
} else {
headers.push(("Content-Type".to_string(), ct));
}
}
// Check if Transfer-Encoding: chunked is already set
let has_chunked_encoding = headers.iter().any(|h| {
h.0.to_lowercase() == "transfer-encoding" && h.1.to_lowercase().contains("chunked")
});
// Add a Content-Length header only if chunked encoding is not being used
if !has_chunked_encoding {
let content_length = match body {
Some(SendableBodyWithMeta::Bytes(ref bytes)) => Some(bytes.len()),
Some(SendableBodyWithMeta::Stream { content_length, .. }) => content_length,
None => None,
};
if let Some(cl) = content_length {
headers.push(("Content-Length".to_string(), cl.to_string()));
}
}
Ok((body.map(|b| b.into()), headers))
}
fn build_form_body(body: &BTreeMap<String, serde_json::Value>) -> Option<SendableBodyWithMeta> {
let form_params = match body.get("form").map(|f| f.as_array()) {
Some(Some(f)) => f,
_ => return None,
};
let mut body = String::new();
for p in form_params {
let enabled = get_bool(p, "enabled", true);
let name = get_str(p, "name");
if !enabled || name.is_empty() {
continue;
}
let value = get_str(p, "value");
if !body.is_empty() {
body.push('&');
}
body.push_str(&urlencoding::encode(&name));
body.push('=');
body.push_str(&urlencoding::encode(&value));
}
if body.is_empty() { None } else { Some(SendableBodyWithMeta::Bytes(Bytes::from(body))) }
}
async fn build_binary_body(
body: &BTreeMap<String, serde_json::Value>,
) -> Result<Option<SendableBodyWithMeta>> {
let file_path = match body.get("filePath").map(|f| f.as_str()) {
Some(Some(f)) => f,
_ => return Ok(None),
};
// Open a file for streaming
let content_length = tokio::fs::metadata(file_path)
.await
.map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))?
.len();
let file = tokio::fs::File::open(file_path)
.await
.map_err(|e| RequestError(format!("Failed to open file: {}", e)))?;
Ok(Some(SendableBodyWithMeta::Stream {
data: Box::pin(file),
content_length: Some(content_length as usize),
}))
}
fn build_text_body(body: &BTreeMap<String, serde_json::Value>) -> Option<SendableBodyWithMeta> {
let text = get_str_map(body, "text");
if text.is_empty() {
None
} else {
Some(SendableBodyWithMeta::Bytes(Bytes::from(text.to_string())))
}
}
fn build_graphql_body(
method: &str,
body: &BTreeMap<String, serde_json::Value>,
) -> Option<SendableBodyWithMeta> {
let query = get_str_map(body, "query");
let variables = get_str_map(body, "variables");
if method.to_lowercase() == "get" {
// GraphQL GET requests use query parameters, not a body
return None;
}
let body = if variables.trim().is_empty() {
format!(r#"{{"query":{}}}"#, serde_json::to_string(&query).unwrap_or_default())
} else {
format!(
r#"{{"query":{},"variables":{}}}"#,
serde_json::to_string(&query).unwrap_or_default(),
variables
)
};
Some(SendableBodyWithMeta::Bytes(Bytes::from(body)))
}
async fn build_multipart_body(
body: &BTreeMap<String, serde_json::Value>,
headers: &[(String, String)],
) -> Result<(Option<SendableBodyWithMeta>, Option<String>)> {
let boundary = extract_boundary_from_headers(headers);
let form_params = match body.get("form").map(|f| f.as_array()) {
Some(Some(f)) => f,
_ => return Ok((None, None)),
};
// Build a list of readers for streaming and calculate total content length
let mut readers: Vec<ReaderType> = Vec::new();
let mut has_content = false;
let mut total_size: usize = 0;
for p in form_params {
let enabled = get_bool(p, "enabled", true);
let name = get_str(p, "name");
if !enabled || name.is_empty() {
continue;
}
has_content = true;
// Add boundary delimiter
let boundary_bytes = format!("--{}\r\n", boundary).into_bytes();
total_size += boundary_bytes.len();
readers.push(ReaderType::Bytes(boundary_bytes));
let file_path = get_str(p, "file");
let value = get_str(p, "value");
let content_type = get_str(p, "contentType");
if file_path.is_empty() {
// Text field
let header = if !content_type.is_empty() {
format!(
"Content-Disposition: form-data; name=\"{}\"\r\nContent-Type: {}\r\n\r\n{}",
name, content_type, value
)
} else {
format!("Content-Disposition: form-data; name=\"{}\"\r\n\r\n{}", name, value)
};
let header_bytes = header.into_bytes();
total_size += header_bytes.len();
readers.push(ReaderType::Bytes(header_bytes));
} else {
// File field - validate that file exists first
if !tokio::fs::try_exists(file_path).await.unwrap_or(false) {
return Err(RequestError(format!("File not found: {}", file_path)));
}
// Get file size for content length calculation
let file_metadata = tokio::fs::metadata(file_path)
.await
.map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))?;
let file_size = file_metadata.len() as usize;
let filename = get_str(p, "filename");
let filename = if filename.is_empty() {
std::path::Path::new(file_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("file")
} else {
filename
};
// Add content type
let mime_type = if !content_type.is_empty() {
content_type.to_string()
} else {
// Guess mime type from file extension
mime_guess::from_path(file_path).first_or_octet_stream().to_string()
};
let header = format!(
"Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\nContent-Type: {}\r\n\r\n",
name, filename, mime_type
);
let header_bytes = header.into_bytes();
total_size += header_bytes.len();
total_size += file_size;
readers.push(ReaderType::Bytes(header_bytes));
// Add a file path for streaming
readers.push(ReaderType::FilePath(file_path.to_string()));
}
let line_ending = b"\r\n".to_vec();
total_size += line_ending.len();
readers.push(ReaderType::Bytes(line_ending));
}
if has_content {
// Add the final boundary
let final_boundary = format!("--{}--\r\n", boundary).into_bytes();
total_size += final_boundary.len();
readers.push(ReaderType::Bytes(final_boundary));
let content_type = format!("multipart/form-data; boundary={}", boundary);
let stream = ChainedReader::new(readers);
Ok((
Some(SendableBodyWithMeta::Stream {
data: Box::pin(stream),
content_length: Some(total_size),
}),
Some(content_type),
))
} else {
Ok((None, None))
}
}
fn extract_boundary_from_headers(headers: &[(String, String)]) -> String {
headers
.iter()
.find(|h| h.0.to_lowercase() == "content-type")
.and_then(|h| {
// Extract boundary from the Content-Type header (e.g., "multipart/form-data; boundary=xyz")
h.1.split(';')
.find(|part| part.trim().starts_with("boundary="))
.and_then(|boundary_part| boundary_part.split('=').nth(1))
.map(|b| b.trim().to_string())
})
.unwrap_or_else(|| MULTIPART_BOUNDARY.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use serde_json::json;
use std::collections::BTreeMap;
use yaak_models::models::{HttpRequest, HttpUrlParameter};
#[test]
fn test_build_url_no_params() {
let r = HttpRequest {
url: "https://example.com/api".to_string(),
url_parameters: vec![],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api");
}
#[test]
fn test_build_url_with_params() {
let r = HttpRequest {
url: "https://example.com/api".to_string(),
url_parameters: vec![
HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
},
HttpUrlParameter {
enabled: true,
name: "baz".to_string(),
value: "qux".to_string(),
id: None,
},
],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?foo=bar&baz=qux");
}
#[test]
fn test_build_url_with_disabled_params() {
let r = HttpRequest {
url: "https://example.com/api".to_string(),
url_parameters: vec![
HttpUrlParameter {
enabled: false,
name: "disabled".to_string(),
value: "value".to_string(),
id: None,
},
HttpUrlParameter {
enabled: true,
name: "enabled".to_string(),
value: "value".to_string(),
id: None,
},
],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?enabled=value");
}
#[test]
fn test_build_url_with_existing_query() {
let r = HttpRequest {
url: "https://example.com/api?existing=param".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "new".to_string(),
value: "value".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?existing=param&new=value");
}
#[test]
fn test_build_url_with_empty_existing_query() {
let r = HttpRequest {
url: "https://example.com/api?".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "new".to_string(),
value: "value".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?new=value");
}
#[test]
fn test_build_url_with_special_chars() {
let r = HttpRequest {
url: "https://example.com/api".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "special chars!@#".to_string(),
value: "value with spaces & symbols".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(
result,
"https://example.com/api?special%20chars%21%40%23=value%20with%20spaces%20%26%20symbols"
);
}
#[test]
fn test_build_url_adds_protocol() {
let r = HttpRequest {
url: "example.com/api".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
// ensure_proto defaults to http:// for regular domains
assert_eq!(result, "http://example.com/api?foo=bar");
}
#[test]
fn test_build_url_adds_https_for_dev_domain() {
let r = HttpRequest {
url: "example.dev/api".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
// .dev domains force https
assert_eq!(result, "https://example.dev/api?foo=bar");
}
#[test]
fn test_build_url_with_fragment() {
let r = HttpRequest {
url: "https://example.com/api#section".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?foo=bar#section");
}
#[test]
fn test_build_url_with_existing_query_and_fragment() {
let r = HttpRequest {
url: "https://yaak.app?foo=bar#some-hash".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "baz".to_string(),
value: "qux".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://yaak.app?foo=bar&baz=qux#some-hash");
}
#[test]
fn test_build_url_with_empty_query_and_fragment() {
let r = HttpRequest {
url: "https://example.com/api?#section".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?foo=bar#section");
}
#[test]
fn test_build_url_with_fragment_containing_special_chars() {
let r = HttpRequest {
url: "https://example.com#section/with/slashes?and=fake&query".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "real".to_string(),
value: "param".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com?real=param#section/with/slashes?and=fake&query");
}
#[test]
fn test_build_url_preserves_empty_fragment() {
let r = HttpRequest {
url: "https://example.com/api#".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
assert_eq!(result, "https://example.com/api?foo=bar#");
}
#[test]
fn test_build_url_with_multiple_fragments() {
// Testing edge case where the URL has multiple # characters (though technically invalid)
let r = HttpRequest {
url: "https://example.com#section#subsection".to_string(),
url_parameters: vec![HttpUrlParameter {
enabled: true,
name: "foo".to_string(),
value: "bar".to_string(),
id: None,
}],
..Default::default()
};
let result = build_url(&r);
// Should treat everything after first # as fragment
assert_eq!(result, "https://example.com?foo=bar#section#subsection");
}
#[tokio::test]
async fn test_text_body() {
let mut body = BTreeMap::new();
body.insert("text".to_string(), json!("Hello, World!"));
let result = build_text_body(&body);
match result {
Some(SendableBodyWithMeta::Bytes(bytes)) => {
assert_eq!(bytes, Bytes::from("Hello, World!"))
}
_ => panic!("Expected Some(SendableBody::Bytes)"),
}
}
#[tokio::test]
async fn test_text_body_empty() {
let mut body = BTreeMap::new();
body.insert("text".to_string(), json!(""));
let result = build_text_body(&body);
assert!(result.is_none());
}
#[tokio::test]
async fn test_text_body_missing() {
let body = BTreeMap::new();
let result = build_text_body(&body);
assert!(result.is_none());
}
#[tokio::test]
async fn test_form_urlencoded_body() -> Result<()> {
let mut body = BTreeMap::new();
body.insert(
"form".to_string(),
json!([
{ "enabled": true, "name": "basic", "value": "aaa"},
{ "enabled": true, "name": "fUnkey Stuff!$*#(", "value": "*)%&#$)@ *$#)@&"},
{ "enabled": false, "name": "disabled", "value": "won't show"},
]),
);
let result = build_form_body(&body);
match result {
Some(SendableBodyWithMeta::Bytes(bytes)) => {
let expected = "basic=aaa&fUnkey%20Stuff%21%24%2A%23%28=%2A%29%25%26%23%24%29%40%20%2A%24%23%29%40%26";
assert_eq!(bytes, Bytes::from(expected));
}
_ => panic!("Expected Some(SendableBody::Bytes)"),
}
Ok(())
}
#[tokio::test]
async fn test_form_urlencoded_body_missing_form() {
let body = BTreeMap::new();
let result = build_form_body(&body);
assert!(result.is_none());
}
#[tokio::test]
async fn test_binary_body() -> Result<()> {
let mut body = BTreeMap::new();
body.insert("filePath".to_string(), json!("./tests/test.txt"));
let result = build_binary_body(&body).await?;
assert!(matches!(result, Some(SendableBodyWithMeta::Stream { .. })));
Ok(())
}
#[tokio::test]
async fn test_binary_body_file_not_found() {
let mut body = BTreeMap::new();
body.insert("filePath".to_string(), json!("./nonexistent/file.txt"));
let result = build_binary_body(&body).await;
assert!(result.is_err());
if let Err(e) = result {
assert!(matches!(e, RequestError(_)));
}
}
#[tokio::test]
async fn test_graphql_body_with_variables() {
let mut body = BTreeMap::new();
body.insert("query".to_string(), json!("{ user(id: $id) { name } }"));
body.insert("variables".to_string(), json!(r#"{"id": "123"}"#));
let result = build_graphql_body("POST", &body);
match result {
Some(SendableBodyWithMeta::Bytes(bytes)) => {
let expected =
r#"{"query":"{ user(id: $id) { name } }","variables":{"id": "123"}}"#;
assert_eq!(bytes, Bytes::from(expected));
}
_ => panic!("Expected Some(SendableBody::Bytes)"),
}
}
#[tokio::test]
async fn test_graphql_body_without_variables() {
let mut body = BTreeMap::new();
body.insert("query".to_string(), json!("{ users { name } }"));
body.insert("variables".to_string(), json!(""));
let result = build_graphql_body("POST", &body);
match result {
Some(SendableBodyWithMeta::Bytes(bytes)) => {
let expected = r#"{"query":"{ users { name } }"}"#;
assert_eq!(bytes, Bytes::from(expected));
}
_ => panic!("Expected Some(SendableBody::Bytes)"),
}
}
#[tokio::test]
async fn test_graphql_body_get_method() {
let mut body = BTreeMap::new();
body.insert("query".to_string(), json!("{ users { name } }"));
let result = build_graphql_body("GET", &body);
assert!(result.is_none());
}
#[tokio::test]
async fn test_multipart_body_text_fields() -> Result<()> {
let mut body = BTreeMap::new();
body.insert(
"form".to_string(),
json!([
{ "enabled": true, "name": "field1", "value": "value1", "file": "" },
{ "enabled": true, "name": "field2", "value": "value2", "file": "" },
{ "enabled": false, "name": "disabled", "value": "won't show", "file": "" },
]),
);
let (result, content_type) = build_multipart_body(&body, &vec![]).await?;
assert!(content_type.is_some());
match result {
Some(SendableBodyWithMeta::Stream { data: mut stream, content_length }) => {
// Read the entire stream to verify content
let mut buf = Vec::new();
use tokio::io::AsyncReadExt;
stream.read_to_end(&mut buf).await.expect("Failed to read stream");
let body_str = String::from_utf8_lossy(&buf);
assert_eq!(
body_str,
"--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"field1\"\r\n\r\nvalue1\r\n--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"field2\"\r\n\r\nvalue2\r\n--------YaakFormBoundary--\r\n",
);
assert_eq!(content_length, Some(body_str.len()));
}
_ => panic!("Expected Some(SendableBody::Stream)"),
}
assert_eq!(
content_type.unwrap(),
format!("multipart/form-data; boundary={}", MULTIPART_BOUNDARY)
);
Ok(())
}
#[tokio::test]
async fn test_multipart_body_with_file() -> Result<()> {
let mut body = BTreeMap::new();
body.insert(
"form".to_string(),
json!([
{ "enabled": true, "name": "file_field", "file": "./tests/test.txt", "filename": "custom.txt", "contentType": "text/plain" },
]),
);
let (result, content_type) = build_multipart_body(&body, &vec![]).await?;
assert!(content_type.is_some());
match result {
Some(SendableBodyWithMeta::Stream { data: mut stream, content_length }) => {
// Read the entire stream to verify content
let mut buf = Vec::new();
use tokio::io::AsyncReadExt;
stream.read_to_end(&mut buf).await.expect("Failed to read stream");
let body_str = String::from_utf8_lossy(&buf);
assert_eq!(
body_str,
"--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"file_field\"; filename=\"custom.txt\"\r\nContent-Type: text/plain\r\n\r\nThis is a test file!\n\r\n--------YaakFormBoundary--\r\n"
);
assert_eq!(content_length, Some(body_str.len()));
}
_ => panic!("Expected Some(SendableBody::Stream)"),
}
assert_eq!(
content_type.unwrap(),
format!("multipart/form-data; boundary={}", MULTIPART_BOUNDARY)
);
Ok(())
}
#[tokio::test]
async fn test_multipart_body_empty() -> Result<()> {
let body = BTreeMap::new();
let (result, content_type) = build_multipart_body(&body, &vec![]).await?;
assert!(result.is_none());
assert_eq!(content_type, None);
Ok(())
}
#[test]
fn test_extract_boundary_from_headers_with_custom_boundary() {
let headers = vec![(
"Content-Type".to_string(),
"multipart/form-data; boundary=customBoundary123".to_string(),
)];
let boundary = extract_boundary_from_headers(&headers);
assert_eq!(boundary, "customBoundary123");
}
#[test]
fn test_extract_boundary_from_headers_default() {
let headers = vec![("Accept".to_string(), "*/*".to_string())];
let boundary = extract_boundary_from_headers(&headers);
assert_eq!(boundary, MULTIPART_BOUNDARY);
}
#[test]
fn test_extract_boundary_from_headers_no_boundary_in_content_type() {
let headers = vec![("Content-Type".to_string(), "multipart/form-data".to_string())];
let boundary = extract_boundary_from_headers(&headers);
assert_eq!(boundary, MULTIPART_BOUNDARY);
}
#[test]
fn test_extract_boundary_case_insensitive() {
let headers = vec![(
"Content-Type".to_string(),
"multipart/form-data; boundary=myBoundary".to_string(),
)];
let boundary = extract_boundary_from_headers(&headers);
assert_eq!(boundary, "myBoundary");
}
#[tokio::test]
async fn test_no_content_length_with_chunked_encoding() -> Result<()> {
let mut body = BTreeMap::new();
body.insert("text".to_string(), json!("Hello, World!"));
// Headers with Transfer-Encoding: chunked
let headers = vec![("Transfer-Encoding".to_string(), "chunked".to_string())];
let (_, result_headers) =
build_body("POST", &Some("text/plain".to_string()), &body, headers).await?;
// Verify that Content-Length is NOT present when Transfer-Encoding: chunked is set
let has_content_length =
result_headers.iter().any(|h| h.0.to_lowercase() == "content-length");
assert!(!has_content_length, "Content-Length should not be present with chunked encoding");
// Verify that the Transfer-Encoding header is still present
let has_chunked = result_headers.iter().any(|h| {
h.0.to_lowercase() == "transfer-encoding" && h.1.to_lowercase().contains("chunked")
});
assert!(has_chunked, "Transfer-Encoding: chunked should be preserved");
Ok(())
}
#[tokio::test]
async fn test_content_length_without_chunked_encoding() -> Result<()> {
let mut body = BTreeMap::new();
body.insert("text".to_string(), json!("Hello, World!"));
// Headers without Transfer-Encoding: chunked
let headers = vec![];
let (_, result_headers) =
build_body("POST", &Some("text/plain".to_string()), &body, headers).await?;
// Verify that Content-Length IS present when Transfer-Encoding: chunked is NOT set
let content_length_header =
result_headers.iter().find(|h| h.0.to_lowercase() == "content-length");
assert!(
content_length_header.is_some(),
"Content-Length should be present without chunked encoding"
);
assert_eq!(
content_length_header.unwrap().1,
"13",
"Content-Length should match the body size"
);
Ok(())
}
}

View File

@@ -0,0 +1 @@
This is a test file!

View File

@@ -0,0 +1,23 @@
[package]
name = "yaak-models"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
hex = { workspace = true }
include_dir = "0.7"
log = { workspace = true }
nanoid = "0.4.0"
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.25.0" }
rusqlite = { version = "0.32.1", features = ["bundled", "chrono"] }
sea-query = { version = "0.32.1", features = ["with-chrono", "attr"] }
sea-query-rusqlite = { version = "0.7.0", features = ["with-chrono"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sha2 = { workspace = true }
thiserror = { workspace = true }
ts-rs = { workspace = true, features = ["chrono-impl", "serde-json-impl"] }
yaak-core = { workspace = true }

View File

@@ -0,0 +1,96 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
export type AnyModel = CookieJar | Environment | Folder | GraphQlIntrospection | GrpcConnection | GrpcEvent | GrpcRequest | HttpRequest | HttpResponse | HttpResponseEvent | KeyValue | Plugin | Settings | SyncState | WebsocketConnection | WebsocketEvent | WebsocketRequest | Workspace | WorkspaceMeta;
export type ClientCertificate = { host: string, port: number | null, crtFile: string | null, keyFile: string | null, pfxFile: string | null, passphrase: string | null, enabled?: boolean, };
export type Cookie = { raw_cookie: string, domain: CookieDomain, expires: CookieExpires, path: [string, boolean], };
export type CookieDomain = { "HostOnly": string } | { "Suffix": string } | "NotPresent" | "Empty";
export type CookieExpires = { "AtUtc": string } | "SessionEnd";
export type CookieJar = { model: "cookie_jar", id: string, createdAt: string, updatedAt: string, workspaceId: string, cookies: Array<Cookie>, name: string, };
export type EditorKeymap = "default" | "vim" | "vscode" | "emacs";
export type EncryptedKey = { encryptedKey: string, };
export type Environment = { model: "environment", id: string, workspaceId: string, createdAt: string, updatedAt: string, name: string, public: boolean, parentModel: string, parentId: string | null, variables: Array<EnvironmentVariable>, color: string | null, sortPriority: number, };
export type EnvironmentVariable = { enabled?: boolean, name: string, value: string, id?: string, };
export type Folder = { model: "folder", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authentication: Record<string, any>, authenticationType: string | null, description: string, headers: Array<HttpRequestHeader>, name: string, sortPriority: number, };
export type GraphQlIntrospection = { model: "graphql_introspection", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, content: string | null, };
export type GrpcConnection = { model: "grpc_connection", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, elapsed: number, error: string | null, method: string, service: string, status: number, state: GrpcConnectionState, trailers: { [key in string]?: string }, url: string, };
export type GrpcConnectionState = "initialized" | "connected" | "closed";
export type GrpcEvent = { model: "grpc_event", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, connectionId: string, content: string, error: string | null, eventType: GrpcEventType, metadata: { [key in string]?: string }, status: number | null, };
export type GrpcEventType = "info" | "error" | "client_message" | "server_message" | "connection_start" | "connection_end";
export type GrpcRequest = { model: "grpc_request", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authenticationType: string | null, authentication: Record<string, any>, description: string, message: string, metadata: Array<HttpRequestHeader>, method: string | null, name: string, service: string | null, sortPriority: number, url: string, };
export type HttpRequest = { model: "http_request", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authentication: Record<string, any>, authenticationType: string | null, body: Record<string, any>, bodyType: string | null, description: string, headers: Array<HttpRequestHeader>, method: string, name: string, sortPriority: number, url: string, urlParameters: Array<HttpUrlParameter>, };
export type HttpRequestHeader = { enabled?: boolean, name: string, value: string, id?: string, };
export type HttpResponse = { model: "http_response", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, bodyPath: string | null, contentLength: number | null, contentLengthCompressed: number | null, elapsed: number, elapsedHeaders: number, error: string | null, headers: Array<HttpResponseHeader>, remoteAddr: string | null, requestContentLength: number | null, requestHeaders: Array<HttpResponseHeader>, status: number, statusReason: string | null, state: HttpResponseState, url: string, version: string | null, };
export type HttpResponseEvent = { model: "http_response_event", id: string, createdAt: string, updatedAt: string, workspaceId: string, responseId: string, event: HttpResponseEventData, };
/**
* Serializable representation of HTTP response events for DB storage.
* This mirrors `yaak_http::sender::HttpResponseEvent` but with serde support.
* The `From` impl is in yaak-http to avoid circular dependencies.
*/
export type HttpResponseEventData = { "type": "setting", name: string, value: string, } | { "type": "info", message: string, } | { "type": "redirect", url: string, status: number, behavior: string, } | { "type": "send_url", method: string, path: string, } | { "type": "receive_url", version: string, status: string, } | { "type": "header_up", name: string, value: string, } | { "type": "header_down", name: string, value: string, } | { "type": "chunk_sent", bytes: number, } | { "type": "chunk_received", bytes: number, };
export type HttpResponseHeader = { name: string, value: string, };
export type HttpResponseState = "initialized" | "connected" | "closed";
export type HttpUrlParameter = { enabled?: boolean, name: string, value: string, id?: string, };
export type KeyValue = { model: "key_value", id: string, createdAt: string, updatedAt: string, key: string, namespace: string, value: string, };
export type ModelChangeEvent = { "type": "upsert", created: boolean, } | { "type": "delete" };
export type ModelPayload = { model: AnyModel, updateSource: UpdateSource, change: ModelChangeEvent, };
export type ParentAuthentication = { authentication: Record<string, any>, authenticationType: string | null, };
export type ParentHeaders = { headers: Array<HttpRequestHeader>, };
export type Plugin = { model: "plugin", id: string, createdAt: string, updatedAt: string, checkedAt: string | null, directory: string, enabled: boolean, url: string | null, };
export type PluginKeyValue = { model: "plugin_key_value", createdAt: string, updatedAt: string, pluginName: string, key: string, value: string, };
export type ProxySetting = { "type": "enabled", http: string, https: string, auth: ProxySettingAuth | null, bypass: string, disabled: boolean, } | { "type": "disabled" };
export type ProxySettingAuth = { user: string, password: string, };
export type Settings = { model: "settings", id: string, createdAt: string, updatedAt: string, appearance: string, clientCertificates: Array<ClientCertificate>, coloredMethods: boolean, editorFont: string | null, editorFontSize: number, editorKeymap: EditorKeymap, editorSoftWrap: boolean, hideWindowControls: boolean, useNativeTitlebar: boolean, interfaceFont: string | null, interfaceFontSize: number, interfaceScale: number, openWorkspaceNewWindow: boolean | null, proxy: ProxySetting | null, themeDark: string, themeLight: string, updateChannel: string, hideLicenseBadge: boolean, autoupdate: boolean, autoDownloadUpdates: boolean, checkNotifications: boolean, hotkeys: { [key in string]?: Array<string> }, };
export type SyncState = { model: "sync_state", id: string, workspaceId: string, createdAt: string, updatedAt: string, flushedAt: string, modelId: string, checksum: string, relPath: string, syncDir: string, };
export type UpdateSource = { "type": "background" } | { "type": "import" } | { "type": "plugin" } | { "type": "sync" } | { "type": "window", label: string, };
export type WebsocketConnection = { model: "websocket_connection", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, elapsed: number, error: string | null, headers: Array<HttpResponseHeader>, state: WebsocketConnectionState, status: number, url: string, };
export type WebsocketConnectionState = "initialized" | "connected" | "closing" | "closed";
export type WebsocketEvent = { model: "websocket_event", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, connectionId: string, isServer: boolean, message: Array<number>, messageType: WebsocketEventType, };
export type WebsocketEventType = "binary" | "close" | "frame" | "open" | "ping" | "pong" | "text";
export type WebsocketMessageType = "text" | "binary";
export type WebsocketRequest = { model: "websocket_request", id: string, createdAt: string, updatedAt: string, workspaceId: string, folderId: string | null, authentication: Record<string, any>, authenticationType: string | null, description: string, headers: Array<HttpRequestHeader>, message: string, name: string, sortPriority: number, url: string, urlParameters: Array<HttpUrlParameter>, };
export type Workspace = { model: "workspace", id: string, createdAt: string, updatedAt: string, authentication: Record<string, any>, authenticationType: string | null, description: string, headers: Array<HttpRequestHeader>, name: string, encryptionKeyChallenge: string | null, settingValidateCertificates: boolean, settingFollowRedirects: boolean, settingRequestTimeout: number, };
export type WorkspaceMeta = { model: "workspace_meta", id: string, workspaceId: string, createdAt: string, updatedAt: string, encryptionKey: EncryptedKey | null, settingSyncDir: string | null, };

View File

@@ -0,0 +1,4 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { Environment, Folder, GrpcRequest, HttpRequest, WebsocketRequest, Workspace } from "./gen_models";
export type BatchUpsertResult = { workspaces: Array<Workspace>, environments: Array<Environment>, folders: Array<Folder>, httpRequests: Array<HttpRequest>, grpcRequests: Array<GrpcRequest>, websocketRequests: Array<WebsocketRequest>, };

View File

@@ -0,0 +1,12 @@
CREATE TABLE body_chunks
(
id TEXT PRIMARY KEY,
body_id TEXT NOT NULL,
chunk_index INTEGER NOT NULL,
data BLOB NOT NULL,
created_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
UNIQUE (body_id, chunk_index)
);
CREATE INDEX idx_body_chunks_body_id ON body_chunks (body_id, chunk_index);
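-- Reassembly sketch (assumes chunks are written with sequential chunk_index):
--   SELECT data FROM body_chunks WHERE body_id = ? ORDER BY chunk_index;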

View File

@@ -0,0 +1,81 @@
import { atom } from 'jotai';
import { selectAtom } from 'jotai/utils';
import type { AnyModel } from '../bindings/gen_models';
import { ExtractModel } from './types';
import { newStoreData } from './util';
export const modelStoreDataAtom = atom(newStoreData());
export const cookieJarsAtom = createOrderedModelAtom('cookie_jar', 'name', 'asc');
export const environmentsAtom = createOrderedModelAtom('environment', 'sortPriority', 'asc');
export const foldersAtom = createModelAtom('folder');
export const grpcConnectionsAtom = createOrderedModelAtom('grpc_connection', 'createdAt', 'desc');
export const grpcEventsAtom = createOrderedModelAtom('grpc_event', 'createdAt', 'asc');
export const grpcRequestsAtom = createModelAtom('grpc_request');
export const httpRequestsAtom = createModelAtom('http_request');
export const httpResponsesAtom = createOrderedModelAtom('http_response', 'createdAt', 'desc');
export const httpResponseEventsAtom = createOrderedModelAtom('http_response_event', 'createdAt', 'asc');
export const keyValuesAtom = createModelAtom('key_value');
export const pluginsAtom = createModelAtom('plugin');
export const settingsAtom = createSingularModelAtom('settings');
export const websocketRequestsAtom = createModelAtom('websocket_request');
export const websocketEventsAtom = createOrderedModelAtom('websocket_event', 'createdAt', 'asc');
export const websocketConnectionsAtom = createOrderedModelAtom(
'websocket_connection',
'createdAt',
'desc',
);
export const workspaceMetasAtom = createModelAtom('workspace_meta');
export const workspacesAtom = createOrderedModelAtom('workspace', 'name', 'asc');
export function createModelAtom<M extends AnyModel['model']>(modelType: M) {
return selectAtom(
modelStoreDataAtom,
(data) => Object.values(data[modelType] ?? {}),
shallowEqual,
);
}
export function createSingularModelAtom<M extends AnyModel['model']>(modelType: M) {
return selectAtom(modelStoreDataAtom, (data) => {
const modelData = Object.values(data[modelType] ?? {});
const item = modelData[0];
if (item == null) throw new Error('No data found for singular model: ' + modelType);
return item;
});
}
export function createOrderedModelAtom<M extends AnyModel['model']>(
modelType: M,
field: keyof ExtractModel<AnyModel, M>,
order: 'asc' | 'desc',
) {
return selectAtom(
modelStoreDataAtom,
(data) => {
const modelData = data[modelType] ?? {};
return Object.values(modelData).sort(
(a: ExtractModel<AnyModel, M>, b: ExtractModel<AnyModel, M>) => {
const n = a[field] > b[field] ? 1 : a[field] < b[field] ? -1 : 0;
return order === 'desc' ? n * -1 : n;
},
);
},
shallowEqual,
);
}
function shallowEqual<T>(a: T[], b: T[]): boolean {
if (a.length !== b.length) {
return false;
}
for (let i = 0; i < a.length; i++) {
if (a[i] !== b[i]) {
return false;
}
}
return true;
}
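// Usage sketch (assumption: consumed from React via jotai's useAtomValue):
//   const requests = useAtomValue(httpRequestsAtom);
//   // shallowEqual above keeps referential stability, so this only re-renders
//   // when the underlying list actually changes.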

View File

@@ -0,0 +1,11 @@
import { AnyModel } from '../bindings/gen_models';
export * from '../bindings/gen_models';
export * from '../bindings/gen_util';
export * from './store';
export * from './atoms';
export function modelTypeLabel(m: AnyModel): string {
const capitalize = (str: string) => str.charAt(0).toUpperCase() + str.slice(1);
return m.model.split('_').map(capitalize).join(' ');
}
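// e.g. modelTypeLabel({ model: 'http_request', ... }) === 'Http Request'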

View File

@@ -0,0 +1,230 @@
import { invoke } from '@tauri-apps/api/core';
import { getCurrentWebviewWindow } from '@tauri-apps/api/webviewWindow';
import { resolvedModelName } from '@yaakapp/app/lib/resolvedModelName';
import { AnyModel, ModelPayload } from '../bindings/gen_models';
import { modelStoreDataAtom } from './atoms';
import { ExtractModel, JotaiStore, ModelStoreData } from './types';
import { newStoreData } from './util';
let _store: JotaiStore | null = null;
export function initModelStore(store: JotaiStore) {
_store = store;
getCurrentWebviewWindow()
.listen<ModelPayload>('model_write', ({ payload }) => {
if (shouldIgnoreModel(payload)) return;
mustStore().set(modelStoreDataAtom, (prev: ModelStoreData) => {
if (payload.change.type === 'upsert') {
return {
...prev,
[payload.model.model]: {
...prev[payload.model.model],
[payload.model.id]: payload.model,
},
};
} else {
const modelData = { ...prev[payload.model.model] };
delete modelData[payload.model.id];
return { ...prev, [payload.model.model]: modelData };
}
});
})
.catch(console.error);
}
function mustStore(): JotaiStore {
if (_store == null) {
throw new Error('Model store was not initialized');
}
return _store;
}
let _activeWorkspaceId: string | null = null;
export async function changeModelStoreWorkspace(workspaceId: string | null) {
console.log('Syncing models with new workspace', workspaceId);
const workspaceModelsStr = await invoke<string>('models_workspace_models', {
workspaceId, // NOTE: if no workspace id provided, it will just fetch global models
});
const workspaceModels = JSON.parse(workspaceModelsStr) as AnyModel[];
const data = newStoreData();
for (const model of workspaceModels) {
data[model.model][model.id] = model;
}
mustStore().set(modelStoreDataAtom, data);
console.log('Synced model store with workspace', workspaceId, data);
_activeWorkspaceId = workspaceId;
}
export function listModels<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
modelType: M | ReadonlyArray<M>,
): T[] {
let data = mustStore().get(modelStoreDataAtom);
const types: ReadonlyArray<M> = Array.isArray(modelType) ? modelType : [modelType];
return types.flatMap((t) => Object.values(data[t]) as T[]);
}
export function getModel<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
modelType: M | ReadonlyArray<M>,
id: string,
): T | null {
let data = mustStore().get(modelStoreDataAtom);
const types: ReadonlyArray<M> = Array.isArray(modelType) ? modelType : [modelType];
for (const t of types) {
let v = data[t][id];
if (v?.model === t) return v as T;
}
return null;
}
export function getAnyModel(id: string): AnyModel | null {
let data = mustStore().get(modelStoreDataAtom);
for (const t of Object.keys(data)) {
let v = (data as any)[t]?.[id];
if (v?.model === t) return v;
}
return null;
}
export function patchModelById<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
model: M,
id: string,
patch: Partial<T> | ((prev: T) => T),
): Promise<string> {
let prev = getModel<M, T>(model, id);
if (prev == null) {
throw new Error(`Failed to get model to patch id=${id} model=${model}`);
}
const newModel = typeof patch === 'function' ? patch(prev) : { ...prev, ...patch };
return updateModel(newModel);
}
export async function patchModel<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
base: Pick<T, 'id' | 'model'>,
patch: Partial<T>,
): Promise<string> {
return patchModelById<M, T>(base.model, base.id, patch);
}
export async function updateModel<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
model: T,
): Promise<string> {
return invoke<string>('models_upsert', { model });
}
export async function deleteModelById<
M extends AnyModel['model'],
T extends ExtractModel<AnyModel, M>,
>(modelType: M | M[], id: string) {
let model = getModel<M, T>(modelType, id);
await deleteModel(model);
}
export async function deleteModel<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
model: T | null,
) {
if (model == null) {
throw new Error('Failed to delete null model');
}
await invoke<string>('models_delete', { model });
}
export function duplicateModel<M extends AnyModel['model'], T extends ExtractModel<AnyModel, M>>(
model: T | null,
) {
if (model == null) {
throw new Error('Failed to duplicate null model');
}
// If the model has a name, try to duplicate it with a name that doesn't conflict
let name = 'name' in model ? resolvedModelName(model) : undefined;
if (name != null) {
const existingModels = listModels(model.model);
for (let i = 0; i < 100; i++) {
const hasConflict = existingModels.some((m) => {
if ('folderId' in m && 'folderId' in model && model.folderId !== m.folderId) {
return false;
} else if (resolvedModelName(m) !== name) {
return false;
}
return true;
});
if (!hasConflict) {
break;
}
// Name conflict. Try another one
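// e.g. "Request" -> "Request Copy" -> "Request Copy 2" -> "Request Copy 3"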
const m: RegExpMatchArray | null = name.match(/ Copy( (?<n>\d+))?$/);
if (m != null && m.groups?.n == null) {
name = name.substring(0, m.index) + ' Copy 2';
} else if (m != null && m.groups?.n != null) {
name = name.substring(0, m.index) + ` Copy ${parseInt(m.groups.n) + 1}`;
} else {
name = `${name} Copy`;
}
}
}
return invoke<string>('models_duplicate', { model: { ...model, name } });
}
export async function createGlobalModel<T extends Exclude<AnyModel, { workspaceId: string }>>(
patch: Partial<T> & Pick<T, 'model'>,
): Promise<string> {
return invoke<string>('models_upsert', { model: patch });
}
export async function createWorkspaceModel<T extends Extract<AnyModel, { workspaceId: string }>>(
patch: Partial<T> & Pick<T, 'model' | 'workspaceId'>,
): Promise<string> {
return invoke<string>('models_upsert', { model: patch });
}
export function replaceModelsInStore<
M extends AnyModel['model'],
T extends Extract<AnyModel, { model: M }>,
>(model: M, models: T[]) {
const newModels: Record<string, T> = {};
for (const model of models) {
newModels[model.id] = model;
}
mustStore().set(modelStoreDataAtom, (prev: ModelStoreData) => {
return {
...prev,
[model]: newModels,
};
});
}
function shouldIgnoreModel({ model, updateSource }: ModelPayload) {
// Never ignore updates from non-user sources
if (updateSource.type !== 'window') {
return false;
}
// Never ignore same-window updates
if (updateSource.label === getCurrentWebviewWindow().label) {
return false;
}
// Only sync models that belong to this workspace, if a workspace ID is present
if ('workspaceId' in model && model.workspaceId !== _activeWorkspaceId) {
return true;
}
if (model.model === 'key_value' && model.namespace === 'no_sync') {
return true;
}
return false;
}

View File

@@ -0,0 +1,8 @@
import { createStore } from 'jotai';
import { AnyModel } from '../bindings/gen_models';
export type ExtractModel<T, M> = T extends { model: M } ? T : never;
export type ModelStoreData<T extends AnyModel = AnyModel> = {
[M in T['model']]: Record<string, Extract<T, { model: M }>>;
};
export type JotaiStore = ReturnType<typeof createStore>;

View File

@@ -0,0 +1,25 @@
import { ModelStoreData } from './types';
export function newStoreData(): ModelStoreData {
return {
cookie_jar: {},
environment: {},
folder: {},
graphql_introspection: {},
grpc_connection: {},
grpc_event: {},
grpc_request: {},
http_request: {},
http_response: {},
http_response_event: {},
key_value: {},
plugin: {},
settings: {},
sync_state: {},
websocket_connection: {},
websocket_event: {},
websocket_request: {},
workspace: {},
workspace_meta: {},
};
}
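A minimal sketch of seeding a fresh Jotai store with this empty shape (assuming createStore from 'jotai' and the modelStoreDataAtom used by the sync code earlier are both imported):

const store = createStore();
// Every model type starts as an empty Record; the backend sync later
// replaces each slice via replaceModelsInStore().
store.set(modelStoreDataAtom, newStoreData());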

View File

@@ -0,0 +1,65 @@
CREATE TABLE key_values
(
model TEXT DEFAULT 'key_value' NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at DATETIME,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
PRIMARY KEY (namespace, key)
);
CREATE TABLE workspaces
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'workspace' NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at DATETIME,
name TEXT NOT NULL,
description TEXT NOT NULL
);
CREATE TABLE http_requests
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'http_request' NOT NULL,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at DATETIME,
name TEXT NOT NULL,
url TEXT NOT NULL,
method TEXT NOT NULL,
headers TEXT NOT NULL,
body TEXT,
body_type TEXT
);
CREATE TABLE http_responses
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'http_response' NOT NULL,
request_id TEXT NOT NULL
REFERENCES http_requests
ON DELETE CASCADE,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at DATETIME,
elapsed INTEGER NOT NULL,
status INTEGER NOT NULL,
status_reason TEXT,
url TEXT NOT NULL,
body TEXT NOT NULL,
headers TEXT NOT NULL,
error TEXT
);

View File

@@ -0,0 +1 @@
ALTER TABLE main.http_requests ADD COLUMN sort_priority REAL NOT NULL DEFAULT 0;

View File

@@ -0,0 +1,2 @@
ALTER TABLE http_requests ADD COLUMN authentication TEXT NOT NULL DEFAULT '{}';
ALTER TABLE http_requests ADD COLUMN authentication_type TEXT;

View File

@@ -0,0 +1,5 @@
DELETE FROM main.http_responses;
ALTER TABLE http_responses DROP COLUMN body;
ALTER TABLE http_responses ADD COLUMN body BLOB;
ALTER TABLE http_responses ADD COLUMN body_path TEXT;
ALTER TABLE http_responses ADD COLUMN content_length INTEGER;

View File

@@ -0,0 +1,15 @@
CREATE TABLE environments
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'workspace' NOT NULL,
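-- NOTE: the 'workspace' default here appears to be a copy-paste slip; a later
-- migration in this commit re-creates this column with DEFAULT 'environment'.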
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at DATETIME,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
name TEXT NOT NULL,
data TEXT NOT NULL
DEFAULT '{}'
);

View File

@@ -0,0 +1,2 @@
ALTER TABLE environments DROP COLUMN data;
ALTER TABLE environments ADD COLUMN variables DEFAULT '[]' NOT NULL;

View File

@@ -0,0 +1 @@
ALTER TABLE workspaces ADD COLUMN variables DEFAULT '[]' NOT NULL;

View File

@@ -0,0 +1,19 @@
CREATE TABLE folders
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'folder' NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at DATETIME,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
folder_id TEXT NULL
REFERENCES folders
ON DELETE CASCADE,
name TEXT NOT NULL,
sort_priority REAL DEFAULT 0 NOT NULL
);
ALTER TABLE http_requests ADD COLUMN folder_id TEXT REFERENCES folders(id) ON DELETE CASCADE;

View File

@@ -0,0 +1,16 @@
-- Rename old column to backup name
ALTER TABLE http_requests
RENAME COLUMN body TO body_old;
-- Create desired new body column
ALTER TABLE http_requests
ADD COLUMN body TEXT NOT NULL DEFAULT '{}';
-- Copy data from old to new body, in new JSON format
UPDATE http_requests
SET body = CASE WHEN body_old IS NULL THEN '{}' ELSE JSON_OBJECT('text', body_old) END
WHERE TRUE;
-- Drop old column
ALTER TABLE http_requests
DROP COLUMN body_old;
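The transform above wraps any legacy plain-text body in the new JSON envelope; a rough TypeScript equivalent (migrateBody is a hypothetical helper, not part of this commit):

// NULL bodies become the empty object; text bodies become {"text": "..."},
// mirroring the SQL CASE/JSON_OBJECT expression above.
function migrateBody(bodyOld: string | null): string {
  return bodyOld == null ? '{}' : JSON.stringify({ text: bodyOld });
}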

View File

@@ -0,0 +1,2 @@
ALTER TABLE http_requests
ADD COLUMN url_parameters TEXT NOT NULL DEFAULT '[]';

View File

@@ -0,0 +1 @@
ALTER TABLE http_responses DROP COLUMN body;

View File

@@ -0,0 +1,13 @@
CREATE TABLE settings
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'settings' NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
follow_redirects BOOLEAN DEFAULT TRUE NOT NULL,
validate_certificates BOOLEAN DEFAULT TRUE NOT NULL,
request_timeout INTEGER DEFAULT 0 NOT NULL,
theme TEXT DEFAULT 'default' NOT NULL,
appearance TEXT DEFAULT 'system' NOT NULL
);

View File

@@ -0,0 +1,9 @@
-- Add existing request-related settings to workspace
ALTER TABLE workspaces ADD COLUMN setting_request_timeout INTEGER DEFAULT 0 NOT NULL;
ALTER TABLE workspaces ADD COLUMN setting_validate_certificates BOOLEAN DEFAULT TRUE NOT NULL;
ALTER TABLE workspaces ADD COLUMN setting_follow_redirects BOOLEAN DEFAULT TRUE NOT NULL;
-- Remove old settings that used to be global
ALTER TABLE settings DROP COLUMN request_timeout;
ALTER TABLE settings DROP COLUMN follow_redirects;
ALTER TABLE settings DROP COLUMN validate_certificates;

View File

@@ -0,0 +1 @@
ALTER TABLE settings ADD COLUMN update_channel TEXT DEFAULT 'stable' NOT NULL;

View File

@@ -0,0 +1,10 @@
CREATE TABLE cookie_jars
(
id TEXT NOT NULL PRIMARY KEY,
model TEXT DEFAULT 'cookie_jar' NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
name TEXT NOT NULL,
cookies TEXT DEFAULT '[]' NOT NULL,
workspace_id TEXT NOT NULL
);

View File

@@ -0,0 +1,3 @@
ALTER TABLE http_responses ADD COLUMN elapsed_headers INTEGER NOT NULL DEFAULT 0;
ALTER TABLE http_responses ADD COLUMN remote_addr TEXT;
ALTER TABLE http_responses ADD COLUMN version TEXT;

View File

@@ -0,0 +1,68 @@
CREATE TABLE grpc_requests
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'grpc_request' NOT NULL,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
folder_id TEXT NULL
REFERENCES folders
ON DELETE CASCADE,
created_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
updated_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
name TEXT NOT NULL,
sort_priority REAL NOT NULL,
url TEXT NOT NULL,
service TEXT NULL,
method TEXT NULL,
message TEXT NOT NULL,
authentication TEXT DEFAULT '{}' NOT NULL,
authentication_type TEXT NULL,
metadata TEXT DEFAULT '[]' NOT NULL
);
CREATE TABLE grpc_connections
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'grpc_connection' NOT NULL,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
request_id TEXT NOT NULL
REFERENCES grpc_requests
ON DELETE CASCADE,
created_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
updated_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
url TEXT NOT NULL,
service TEXT NOT NULL,
method TEXT NOT NULL,
status INTEGER DEFAULT -1 NOT NULL,
error TEXT NULL,
elapsed INTEGER DEFAULT 0 NOT NULL,
trailers TEXT DEFAULT '{}' NOT NULL
);
CREATE TABLE grpc_events
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'grpc_event' NOT NULL,
workspace_id TEXT NOT NULL
REFERENCES workspaces
ON DELETE CASCADE,
request_id TEXT NOT NULL
REFERENCES grpc_requests
ON DELETE CASCADE,
connection_id TEXT NOT NULL
REFERENCES grpc_connections
ON DELETE CASCADE,
created_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
updated_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL,
metadata TEXT DEFAULT '{}' NOT NULL,
event_type TEXT NOT NULL,
status INTEGER NULL,
error TEXT NULL,
content TEXT NOT NULL
);

View File

@@ -0,0 +1,4 @@
ALTER TABLE settings
ADD COLUMN theme_dark TEXT DEFAULT 'yaak-dark' NOT NULL;
ALTER TABLE settings
ADD COLUMN theme_light TEXT DEFAULT 'yaak-light' NOT NULL;

View File

@@ -0,0 +1,4 @@
ALTER TABLE settings ADD COLUMN interface_font_size INTEGER DEFAULT 15 NOT NULL;
ALTER TABLE settings ADD COLUMN interface_scale INTEGER DEFAULT 1 NOT NULL;
ALTER TABLE settings ADD COLUMN editor_font_size INTEGER DEFAULT 13 NOT NULL;
ALTER TABLE settings ADD COLUMN editor_soft_wrap BOOLEAN DEFAULT 1 NOT NULL;

View File

@@ -0,0 +1 @@
ALTER TABLE settings ADD COLUMN open_workspace_new_window BOOLEAN NULL DEFAULT NULL;

View File

@@ -0,0 +1,2 @@
ALTER TABLE environments DROP COLUMN model;
ALTER TABLE environments ADD COLUMN model TEXT DEFAULT 'environment';

View File

@@ -0,0 +1 @@
ALTER TABLE settings ADD COLUMN telemetry BOOLEAN DEFAULT TRUE;

View File

@@ -0,0 +1,12 @@
CREATE TABLE plugins
(
id TEXT NOT NULL
PRIMARY KEY,
model TEXT DEFAULT 'plugin' NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
checked_at DATETIME NULL,
enabled BOOLEAN NOT NULL,
directory TEXT NOT NULL,
url TEXT NULL
);

View File

@@ -0,0 +1,5 @@
ALTER TABLE http_responses
ADD COLUMN state TEXT DEFAULT 'closed' NOT NULL;
ALTER TABLE grpc_connections
ADD COLUMN state TEXT DEFAULT 'closed' NOT NULL;

View File

@@ -0,0 +1 @@
ALTER TABLE settings ADD COLUMN proxy TEXT;

Some files were not shown because too many files have changed in this diff.