Start extracting DBContext

This commit is contained in:
Gregory Schier
2026-03-08 08:56:08 -07:00
parent cf28229f5f
commit 4c37e62146
45 changed files with 695 additions and 242 deletions

View File

@@ -0,0 +1,18 @@
[package]
name = "yaak-proxy-models"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
include_dir = "0.7"
log = { workspace = true }
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.25.0" }
rusqlite = { version = "0.32.1", features = ["bundled", "chrono"] }
sea-query = { version = "0.32.1", features = ["with-chrono", "attr"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
yaak-database = { workspace = true }

View File

@@ -0,0 +1,35 @@
-- Proxy version of http_responses, duplicated from client.
-- No workspace_id/request_id foreign keys — proxy captures raw traffic.
CREATE TABLE proxy_http_responses (
id TEXT NOT NULL PRIMARY KEY,
model TEXT DEFAULT 'proxy_http_response' NOT NULL,
proxy_request_id INTEGER NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
elapsed INTEGER NOT NULL DEFAULT 0,
elapsed_headers INTEGER NOT NULL DEFAULT 0,
elapsed_dns INTEGER NOT NULL DEFAULT 0,
status INTEGER NOT NULL DEFAULT 0,
status_reason TEXT,
url TEXT NOT NULL,
headers TEXT NOT NULL DEFAULT '[]',
request_headers TEXT NOT NULL DEFAULT '[]',
error TEXT,
body_path TEXT,
content_length INTEGER,
content_length_compressed INTEGER,
request_content_length INTEGER,
remote_addr TEXT,
version TEXT,
state TEXT DEFAULT 'initialized' NOT NULL
);
CREATE INDEX idx_proxy_http_responses_created_at ON proxy_http_responses (created_at DESC);
-- Inline body storage (proxy keeps everything self-contained in one DB file)
CREATE TABLE proxy_http_response_bodies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
response_id TEXT NOT NULL REFERENCES proxy_http_responses(id) ON DELETE CASCADE,
body_type TEXT NOT NULL,
data BLOB NOT NULL,
UNIQUE(response_id, body_type)
);

View File

@@ -0,0 +1 @@
pub use yaak_database::error::{Error, Result};

View File

@@ -0,0 +1,73 @@
use crate::error::{Error, Result};
use include_dir::{Dir, include_dir};
use log::info;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use std::fs::create_dir_all;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use yaak_database::{ConnectionOrTx, DbContext};
pub mod error;
pub mod models;
static MIGRATIONS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
/// Manages the proxy session database pool.
/// Use `connect()` to get a `DbContext` for running queries.
#[derive(Debug, Clone)]
pub struct ProxyDb {
pool: Arc<Mutex<Pool<SqliteConnectionManager>>>,
}
impl ProxyDb {
pub fn connect(&self) -> DbContext<'_> {
let conn = self
.pool
.lock()
.expect("Failed to gain lock on proxy DB")
.get()
.expect("Failed to get proxy DB connection from pool");
DbContext::new(ConnectionOrTx::Connection(conn))
}
}
pub fn init_standalone(db_path: impl AsRef<Path>) -> Result<ProxyDb> {
let db_path = db_path.as_ref();
if let Some(parent) = db_path.parent() {
create_dir_all(parent)?;
}
info!("Initializing proxy session database {db_path:?}");
let manager = SqliteConnectionManager::file(db_path);
let pool = Pool::builder()
.max_size(100)
.connection_timeout(Duration::from_secs(10))
.build(manager)
.map_err(|e| Error::Database(e.to_string()))?;
pool.get()?.execute_batch(
"PRAGMA journal_mode=WAL;
PRAGMA foreign_keys=ON;",
)?;
yaak_database::run_migrations(&pool, &MIGRATIONS_DIR)?;
Ok(ProxyDb { pool: Arc::new(Mutex::new(pool)) })
}
pub fn init_in_memory() -> Result<ProxyDb> {
let manager = SqliteConnectionManager::memory();
let pool = Pool::builder()
.max_size(1)
.build(manager)
.map_err(|e| Error::Database(e.to_string()))?;
pool.get()?.execute_batch("PRAGMA foreign_keys=ON;")?;
yaak_database::run_migrations(&pool, &MIGRATIONS_DIR)?;
Ok(ProxyDb { pool: Arc::new(Mutex::new(pool)) })
}

View File

@@ -0,0 +1,160 @@
use chrono::NaiveDateTime;
use rusqlite::Row;
use sea_query::Order::Desc;
use sea_query::{IntoColumnRef, IntoIden, IntoTableRef, Order, SimpleExpr, enum_def};
use serde::{Deserialize, Serialize};
use yaak_database::{
UpsertModelInfo, UpdateSource, Result as DbResult,
generate_prefixed_id, upsert_date,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HttpResponseState {
Initialized,
Connected,
Closed,
}
impl Default for HttpResponseState {
fn default() -> Self {
Self::Initialized
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpResponseHeader {
pub name: String,
pub value: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default, rename_all = "camelCase")]
#[enum_def(table_name = "proxy_http_responses")]
pub struct HttpResponse {
pub model: String,
pub id: String,
pub proxy_request_id: i64,
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
pub elapsed: i32,
pub elapsed_headers: i32,
pub elapsed_dns: i32,
pub status: i32,
pub status_reason: Option<String>,
pub url: String,
pub headers: Vec<HttpResponseHeader>,
pub request_headers: Vec<HttpResponseHeader>,
pub error: Option<String>,
pub body_path: Option<String>,
pub content_length: Option<i32>,
pub content_length_compressed: Option<i32>,
pub request_content_length: Option<i32>,
pub remote_addr: Option<String>,
pub version: Option<String>,
pub state: HttpResponseState,
}
impl UpsertModelInfo for HttpResponse {
fn table_name() -> impl IntoTableRef + IntoIden {
HttpResponseIden::Table
}
fn id_column() -> impl IntoIden + Eq + Clone {
HttpResponseIden::Id
}
fn generate_id() -> String {
generate_prefixed_id("rs")
}
fn order_by() -> (impl IntoColumnRef, Order) {
(HttpResponseIden::CreatedAt, Desc)
}
fn get_id(&self) -> String {
self.id.clone()
}
fn insert_values(
self,
source: &UpdateSource,
) -> DbResult<Vec<(impl IntoIden + Eq, impl Into<SimpleExpr>)>> {
use HttpResponseIden::*;
Ok(vec![
(CreatedAt, upsert_date(source, self.created_at)),
(UpdatedAt, upsert_date(source, self.updated_at)),
(ProxyRequestId, self.proxy_request_id.into()),
(BodyPath, self.body_path.into()),
(ContentLength, self.content_length.into()),
(ContentLengthCompressed, self.content_length_compressed.into()),
(Elapsed, self.elapsed.into()),
(ElapsedHeaders, self.elapsed_headers.into()),
(ElapsedDns, self.elapsed_dns.into()),
(Error, self.error.into()),
(Headers, serde_json::to_string(&self.headers)?.into()),
(RemoteAddr, self.remote_addr.into()),
(RequestContentLength, self.request_content_length.into()),
(RequestHeaders, serde_json::to_string(&self.request_headers)?.into()),
(State, serde_json::to_value(&self.state)?.as_str().into()),
(Status, self.status.into()),
(StatusReason, self.status_reason.into()),
(Url, self.url.into()),
(Version, self.version.into()),
])
}
fn update_columns() -> Vec<impl IntoIden> {
vec![
HttpResponseIden::UpdatedAt,
HttpResponseIden::BodyPath,
HttpResponseIden::ContentLength,
HttpResponseIden::ContentLengthCompressed,
HttpResponseIden::Elapsed,
HttpResponseIden::ElapsedHeaders,
HttpResponseIden::ElapsedDns,
HttpResponseIden::Error,
HttpResponseIden::Headers,
HttpResponseIden::RemoteAddr,
HttpResponseIden::RequestContentLength,
HttpResponseIden::RequestHeaders,
HttpResponseIden::State,
HttpResponseIden::Status,
HttpResponseIden::StatusReason,
HttpResponseIden::Url,
HttpResponseIden::Version,
]
}
fn from_row(r: &Row) -> rusqlite::Result<Self>
where
Self: Sized,
{
let headers: String = r.get("headers")?;
let request_headers: String = r.get("request_headers")?;
let state: String = r.get("state")?;
Ok(Self {
id: r.get("id")?,
model: r.get("model")?,
proxy_request_id: r.get("proxy_request_id")?,
created_at: r.get("created_at")?,
updated_at: r.get("updated_at")?,
error: r.get("error")?,
url: r.get("url")?,
content_length: r.get("content_length")?,
content_length_compressed: r.get("content_length_compressed").unwrap_or_default(),
version: r.get("version")?,
elapsed: r.get("elapsed")?,
elapsed_headers: r.get("elapsed_headers")?,
elapsed_dns: r.get("elapsed_dns").unwrap_or_default(),
remote_addr: r.get("remote_addr")?,
status: r.get("status")?,
status_reason: r.get("status_reason")?,
state: serde_json::from_str(format!(r#""{state}""#).as_str()).unwrap_or_default(),
body_path: r.get("body_path")?,
headers: serde_json::from_str(headers.as_str()).unwrap_or_default(),
request_content_length: r.get("request_content_length").unwrap_or_default(),
request_headers: serde_json::from_str(request_headers.as_str()).unwrap_or_default(),
})
}
}