From 0eda65b07e530fb950cab0c4dcf3018a4bdc2bdc Mon Sep 17 00:00:00 2001 From: Per Stark Date: Tue, 4 Nov 2025 11:22:45 +0100 Subject: [PATCH] benchmarks: v1 Benchmarking ingestion, retrieval precision and performance --- .cargo/config.toml | 2 + .gitignore | 4 + CHANGELOG.md | 1 + Cargo.lock | 163 ++- Cargo.toml | 3 +- common/src/storage/db.rs | 20 +- composite-retrieval/src/pipeline/config.rs | 4 +- .../src/pipeline/diagnostics.rs | 51 + composite-retrieval/src/pipeline/mod.rs | 144 ++- .../src/pipeline/stages/mod.rs | 411 +++++-- composite-retrieval/src/vector.rs | 1 + eval/Cargo.toml | 33 + eval/manifest.yaml | 33 + eval/src/args.rs | 638 +++++++++++ eval/src/cache.rs | 88 ++ eval/src/datasets.rs | 1003 +++++++++++++++++ eval/src/db_helpers.rs | 269 +++++ eval/src/embedding.rs | 171 +++ eval/src/eval/mod.rs | 767 +++++++++++++ eval/src/eval/pipeline/context.rs | 193 ++++ eval/src/eval/pipeline/mod.rs | 29 + eval/src/eval/pipeline/stages/finalize.rs | 59 + eval/src/eval/pipeline/stages/mod.rs | 26 + .../eval/pipeline/stages/prepare_corpus.rs | 84 ++ eval/src/eval/pipeline/stages/prepare_db.rs | 111 ++ .../eval/pipeline/stages/prepare_namespace.rs | 159 +++ .../src/eval/pipeline/stages/prepare_slice.rs | 66 ++ eval/src/eval/pipeline/stages/run_queries.rs | 337 ++++++ eval/src/eval/pipeline/stages/summarize.rs | 173 +++ eval/src/eval/pipeline/state.rs | 31 + eval/src/ingest.rs | 1000 ++++++++++++++++ eval/src/inspection.rs | 182 +++ eval/src/main.rs | 214 ++++ eval/src/openai.rs | 16 + eval/src/perf.rs | 313 +++++ eval/src/report.rs | 456 ++++++++ eval/src/slice.rs | 27 + eval/src/slices.rs | 941 ++++++++++++++++ eval/src/snapshot.rs | 182 +++ html-router/src/routes/chat/mod.rs | 2 +- html-router/src/routes/search/mod.rs | 2 +- ingestion-pipeline/src/lib.rs | 2 +- ingestion-pipeline/src/pipeline/context.rs | 42 +- ingestion-pipeline/src/pipeline/mod.rs | 30 +- ingestion-pipeline/src/pipeline/services.rs | 26 +- ingestion-pipeline/src/pipeline/stages/mod.rs | 42 +- 
46 files changed, 8407 insertions(+), 144 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 composite-retrieval/src/pipeline/diagnostics.rs create mode 100644 eval/Cargo.toml create mode 100644 eval/manifest.yaml create mode 100644 eval/src/args.rs create mode 100644 eval/src/cache.rs create mode 100644 eval/src/datasets.rs create mode 100644 eval/src/db_helpers.rs create mode 100644 eval/src/embedding.rs create mode 100644 eval/src/eval/mod.rs create mode 100644 eval/src/eval/pipeline/context.rs create mode 100644 eval/src/eval/pipeline/mod.rs create mode 100644 eval/src/eval/pipeline/stages/finalize.rs create mode 100644 eval/src/eval/pipeline/stages/mod.rs create mode 100644 eval/src/eval/pipeline/stages/prepare_corpus.rs create mode 100644 eval/src/eval/pipeline/stages/prepare_db.rs create mode 100644 eval/src/eval/pipeline/stages/prepare_namespace.rs create mode 100644 eval/src/eval/pipeline/stages/prepare_slice.rs create mode 100644 eval/src/eval/pipeline/stages/run_queries.rs create mode 100644 eval/src/eval/pipeline/stages/summarize.rs create mode 100644 eval/src/eval/pipeline/state.rs create mode 100644 eval/src/ingest.rs create mode 100644 eval/src/inspection.rs create mode 100644 eval/src/main.rs create mode 100644 eval/src/openai.rs create mode 100644 eval/src/perf.rs create mode 100644 eval/src/report.rs create mode 100644 eval/src/slice.rs create mode 100644 eval/src/slices.rs create mode 100644 eval/src/snapshot.rs diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..d830e40 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[alias] +eval = "run -p eval --" diff --git a/.gitignore b/.gitignore index 56e555b..a6df1fe 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ result data database +eval/cache/ +eval/reports/ + # Devenv .devenv* devenv.local.nix @@ -21,3 +24,4 @@ devenv.local.nix .pre-commit-config.yaml # html-router/assets/style.css html-router/node_modules +.fastembed_cache/ 
diff --git a/CHANGELOG.md b/CHANGELOG.md index ede988a..cd3fdb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Changelog ## Unreleased +- Added a shared `benchmarks` crate with deterministic fixtures and Criterion suites for ingestion, retrieval, and HTML handlers, plus documented baseline results for local performance checks. ## Version 0.2.6 (2025-10-29) - Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results. diff --git a/Cargo.lock b/Cargo.lock index 30dedd1..a298b77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,6 +184,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.18" @@ -1090,6 +1096,12 @@ dependencies = [ "serde", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "castaway" version = "0.2.3" @@ -1626,6 +1638,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] 
name = "crossbeam-deque" version = "0.8.6" @@ -2140,6 +2188,37 @@ dependencies = [ "num-traits", ] +[[package]] +name = "eval" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-openai", + "async-trait", + "chrono", + "common", + "composite-retrieval", + "criterion", + "fastembed", + "futures", + "ingestion-pipeline", + "object_store 0.11.2", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "serde_yaml", + "sha2", + "state-machines", + "surrealdb", + "tempfile", + "text-splitter", + "tokio", + "tracing", + "tracing-subscriber", + "uuid", +] + [[package]] name = "event-listener" version = "5.4.0" @@ -2735,6 +2814,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -3248,11 +3333,13 @@ dependencies = [ name = "ingestion-pipeline" version = "0.1.0" dependencies = [ + "anyhow", "async-openai", "async-trait", "axum", "axum_typed_multipart", "base64 0.22.1", + "bytes", "chrono", "common", "composite-retrieval", @@ -3330,6 +3417,17 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi 0.5.2", + "libc", + "windows-sys 0.60.2", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -4240,7 +4338,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -4330,6 +4428,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -4705,6 +4809,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.18.0" @@ -5920,6 +6052,19 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap 2.9.0", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "servo_arc" version = "0.4.0" @@ -6735,6 +6880,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.9.0" @@ -7263,6 +7418,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 6a8e1ab..febcc3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,8 @@ members = [ "html-router", "ingestion-pipeline", "composite-retrieval", - "json-stream-parser" + "json-stream-parser", + "eval" ] resolver = "2" diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index 8937bfb..f6ea172 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -7,7 +7,7 @@ use include_dir::{include_dir, Dir}; use std::{ops::Deref, sync::Arc}; use surrealdb::{ engine::any::{connect, Any}, - opt::auth::Root, + opt::auth::{Namespace, Root}, Error, Notification, Surreal, }; use surrealdb_migrations::MigrationRunner; @@ -48,6 +48,24 @@ impl SurrealDbClient { Ok(SurrealDbClient { client: db }) } + pub async fn new_with_namespace_user( + address: &str, + namespace: &str, + username: &str, + password: &str, + database: &str, + ) -> Result { + let db = connect(address).await?; + db.signin(Namespace { + namespace, + username, + password, + }) + .await?; + db.use_ns(namespace).use_db(database).await?; + Ok(SurrealDbClient { client: db }) + } + pub async fn create_session_store( &self, ) -> Result>, SessionError> { diff --git a/composite-retrieval/src/pipeline/config.rs b/composite-retrieval/src/pipeline/config.rs index 71a964e..446cf0f 100644 --- a/composite-retrieval/src/pipeline/config.rs +++ b/composite-retrieval/src/pipeline/config.rs @@ -12,6 +12,7 @@ pub struct RetrievalTuning { pub 
token_budget_estimate: usize, pub avg_chars_per_token: usize, pub max_chunks_per_entity: usize, + pub lexical_match_weight: f32, pub graph_traversal_seed_limit: usize, pub graph_neighbor_limit: usize, pub graph_score_decay: f32, @@ -31,9 +32,10 @@ impl Default for RetrievalTuning { chunk_fts_take: 20, score_threshold: 0.35, fallback_min_results: 10, - token_budget_estimate: 2800, + token_budget_estimate: 10000, avg_chars_per_token: 4, max_chunks_per_entity: 4, + lexical_match_weight: 0.15, graph_traversal_seed_limit: 5, graph_neighbor_limit: 6, graph_score_decay: 0.75, diff --git a/composite-retrieval/src/pipeline/diagnostics.rs b/composite-retrieval/src/pipeline/diagnostics.rs new file mode 100644 index 0000000..67c14d2 --- /dev/null +++ b/composite-retrieval/src/pipeline/diagnostics.rs @@ -0,0 +1,51 @@ +use serde::Serialize; + +/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled. +#[derive(Debug, Clone, Default, Serialize)] +pub struct PipelineDiagnostics { + pub collect_candidates: Option, + pub enrich_chunks_from_entities: Option, + pub assemble: Option, +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct CollectCandidatesStats { + pub vector_entity_candidates: usize, + pub vector_chunk_candidates: usize, + pub fts_entity_candidates: usize, + pub fts_chunk_candidates: usize, + pub vector_chunk_scores: Vec, + pub fts_chunk_scores: Vec, +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct ChunkEnrichmentStats { + pub filtered_entity_count: usize, + pub fallback_min_results: usize, + pub chunk_sources_considered: usize, + pub chunk_candidates_before_enrichment: usize, + pub chunk_candidates_after_enrichment: usize, + pub top_chunk_scores: Vec, +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct AssembleStats { + pub token_budget_start: usize, + pub token_budget_spent: usize, + pub token_budget_remaining: usize, + pub budget_exhausted: bool, + pub chunks_selected: usize, + pub 
chunks_skipped_due_budget: usize, + pub entity_count: usize, + pub entity_traces: Vec, +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct EntityAssemblyTrace { + pub entity_id: String, + pub source_id: String, + pub inspected_candidates: usize, + pub selected_chunk_ids: Vec, + pub selected_chunk_scores: Vec, + pub skipped_due_budget: usize, +} diff --git a/composite-retrieval/src/pipeline/mod.rs b/composite-retrieval/src/pipeline/mod.rs index a0023ef..fc6993f 100644 --- a/composite-retrieval/src/pipeline/mod.rs +++ b/composite-retrieval/src/pipeline/mod.rs @@ -1,14 +1,57 @@ mod config; +mod diagnostics; mod stages; mod state; pub use config::{RetrievalConfig, RetrievalTuning}; +pub use diagnostics::{ + AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, + PipelineDiagnostics, +}; use crate::{reranking::RerankerLease, RetrievedEntity}; use async_openai::Client; use common::{error::AppError, storage::db::SurrealDbClient}; use tracing::info; +#[derive(Debug)] +pub struct PipelineRunOutput { + pub results: Vec, + pub diagnostics: Option, + pub stage_timings: PipelineStageTimings, +} + +#[derive(Debug, Clone, Default, serde::Serialize)] +pub struct PipelineStageTimings { + pub collect_candidates_ms: u128, + pub graph_expansion_ms: u128, + pub chunk_attach_ms: u128, + pub rerank_ms: u128, + pub assemble_ms: u128, +} + +impl PipelineStageTimings { + fn record_collect_candidates(&mut self, duration: std::time::Duration) { + self.collect_candidates_ms += duration.as_millis() as u128; + } + + fn record_graph_expansion(&mut self, duration: std::time::Duration) { + self.graph_expansion_ms += duration.as_millis() as u128; + } + + fn record_chunk_attach(&mut self, duration: std::time::Duration) { + self.chunk_attach_ms += duration.as_millis() as u128; + } + + fn record_rerank(&mut self, duration: std::time::Duration) { + self.rerank_ms += duration.as_millis() as u128; + } + + fn record_assemble(&mut self, duration: std::time::Duration) { 
+ self.assemble_ms += duration.as_millis() as u128; + } +} + /// Drives the retrieval pipeline from embedding through final assembly. pub async fn run_pipeline( db_client: &SurrealDbClient, @@ -18,7 +61,6 @@ pub async fn run_pipeline( config: RetrievalConfig, reranker: Option, ) -> Result, AppError> { - let machine = state::ready(); let input_chars = input_text.chars().count(); let input_preview: String = input_text.chars().take(120).collect(); let input_preview_clean = input_preview.replace('\n', " "); @@ -30,7 +72,7 @@ pub async fn run_pipeline( preview = %input_preview_clean, "Starting ingestion retrieval pipeline" ); - let mut ctx = stages::PipelineContext::new( + let ctx = stages::PipelineContext::new( db_client, openai_client, input_text.to_owned(), @@ -38,17 +80,11 @@ pub async fn run_pipeline( config, reranker, ); - let machine = stages::embed(machine, &mut ctx).await?; - let machine = stages::collect_candidates(machine, &mut ctx).await?; - let machine = stages::expand_graph(machine, &mut ctx).await?; - let machine = stages::attach_chunks(machine, &mut ctx).await?; - let machine = stages::rerank(machine, &mut ctx).await?; - let results = stages::assemble(machine, &mut ctx)?; + let outcome = run_pipeline_internal(ctx, false).await?; - Ok(results) + Ok(outcome.results) } -#[cfg(test)] pub async fn run_pipeline_with_embedding( db_client: &SurrealDbClient, openai_client: &Client, @@ -58,8 +94,7 @@ pub async fn run_pipeline_with_embedding( config: RetrievalConfig, reranker: Option, ) -> Result, AppError> { - let machine = state::ready(); - let mut ctx = stages::PipelineContext::with_embedding( + let ctx = stages::PipelineContext::with_embedding( db_client, openai_client, query_embedding, @@ -68,14 +103,54 @@ pub async fn run_pipeline_with_embedding( config, reranker, ); - let machine = stages::embed(machine, &mut ctx).await?; - let machine = stages::collect_candidates(machine, &mut ctx).await?; - let machine = stages::expand_graph(machine, &mut ctx).await?; - 
let machine = stages::attach_chunks(machine, &mut ctx).await?; - let machine = stages::rerank(machine, &mut ctx).await?; - let results = stages::assemble(machine, &mut ctx)?; + let outcome = run_pipeline_internal(ctx, false).await?; - Ok(results) + Ok(outcome.results) +} + +/// Runs the pipeline with a precomputed embedding and returns stage metrics. +pub async fn run_pipeline_with_embedding_with_metrics( + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Vec, + input_text: &str, + user_id: &str, + config: RetrievalConfig, + reranker: Option, +) -> Result { + let ctx = stages::PipelineContext::with_embedding( + db_client, + openai_client, + query_embedding, + input_text.to_owned(), + user_id.to_owned(), + config, + reranker, + ); + + run_pipeline_internal(ctx, false).await +} + +pub async fn run_pipeline_with_embedding_with_diagnostics( + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Vec, + input_text: &str, + user_id: &str, + config: RetrievalConfig, + reranker: Option, +) -> Result { + let ctx = stages::PipelineContext::with_embedding( + db_client, + openai_client, + query_embedding, + input_text.to_owned(), + user_id.to_owned(), + config, + reranker, + ); + + run_pipeline_internal(ctx, true).await } /// Helper exposed for tests to convert retrieved entities into downstream prompt JSON. 
@@ -101,6 +176,37 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V .collect::>()) } +async fn run_pipeline_internal( + mut ctx: stages::PipelineContext<'_>, + capture_diagnostics: bool, +) -> Result { + if capture_diagnostics { + ctx.enable_diagnostics(); + } + + let results = drive_pipeline(&mut ctx).await?; + let diagnostics = ctx.take_diagnostics(); + + Ok(PipelineRunOutput { + results, + diagnostics, + stage_timings: ctx.take_stage_timings(), + }) +} + +async fn drive_pipeline( + ctx: &mut stages::PipelineContext<'_>, +) -> Result, AppError> { + let machine = state::ready(); + let machine = stages::embed(machine, ctx).await?; + let machine = stages::collect_candidates(machine, ctx).await?; + let machine = stages::expand_graph(machine, ctx).await?; + let machine = stages::attach_chunks(machine, ctx).await?; + let machine = stages::rerank(machine, ctx).await?; + let results = stages::assemble(machine, ctx)?; + Ok(results) +} + fn round_score(value: f32) -> f64 { (f64::from(value) * 1000.0).round() / 1000.0 } diff --git a/composite-retrieval/src/pipeline/stages/mod.rs b/composite-retrieval/src/pipeline/stages/mod.rs index 1a8ebb9..d7944e3 100644 --- a/composite-retrieval/src/pipeline/stages/mod.rs +++ b/composite-retrieval/src/pipeline/stages/mod.rs @@ -10,7 +10,11 @@ use common::{ use fastembed::RerankResult; use futures::{stream::FuturesUnordered, StreamExt}; use state_machines::core::GuardError; -use std::collections::{HashMap, HashSet}; +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, + time::Instant, +}; use tracing::{debug, instrument, warn}; use crate::{ @@ -27,10 +31,15 @@ use crate::{ use super::{ config::RetrievalConfig, + diagnostics::{ + AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, + PipelineDiagnostics, + }, state::{ CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready, Reranked, }, + PipelineStageTimings, }; pub struct 
PipelineContext<'a> { @@ -45,6 +54,8 @@ pub struct PipelineContext<'a> { pub filtered_entities: Vec>, pub chunk_values: Vec>, pub reranker: Option, + pub diagnostics: Option, + stage_timings: PipelineStageTimings, } impl<'a> PipelineContext<'a> { @@ -68,10 +79,11 @@ impl<'a> PipelineContext<'a> { filtered_entities: Vec::new(), chunk_values: Vec::new(), reranker, + diagnostics: None, + stage_timings: PipelineStageTimings::default(), } } - #[cfg(test)] pub fn with_embedding( db_client: &'a SurrealDbClient, openai_client: &'a Client, @@ -100,6 +112,62 @@ impl<'a> PipelineContext<'a> { ) }) } + + pub fn enable_diagnostics(&mut self) { + if self.diagnostics.is_none() { + self.diagnostics = Some(PipelineDiagnostics::default()); + } + } + + pub fn diagnostics_enabled(&self) -> bool { + self.diagnostics.is_some() + } + + pub fn record_collect_candidates(&mut self, stats: CollectCandidatesStats) { + if let Some(diag) = self.diagnostics.as_mut() { + diag.collect_candidates = Some(stats); + } + } + + pub fn record_chunk_enrichment(&mut self, stats: ChunkEnrichmentStats) { + if let Some(diag) = self.diagnostics.as_mut() { + diag.enrich_chunks_from_entities = Some(stats); + } + } + + pub fn record_assemble(&mut self, stats: AssembleStats) { + if let Some(diag) = self.diagnostics.as_mut() { + diag.assemble = Some(stats); + } + } + + pub fn take_diagnostics(&mut self) -> Option { + self.diagnostics.take() + } + + pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) { + self.stage_timings.record_collect_candidates(duration); + } + + pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) { + self.stage_timings.record_graph_expansion(duration); + } + + pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) { + self.stage_timings.record_chunk_attach(duration); + } + + pub fn record_rerank_timing(&mut self, duration: std::time::Duration) { + self.stage_timings.record_rerank(duration); + } + + pub fn 
record_assemble_timing(&mut self, duration: std::time::Duration) { + self.stage_timings.record_assemble(duration); + } + + pub fn take_stage_timings(&mut self) -> PipelineStageTimings { + std::mem::take(&mut self.stage_timings) + } } #[instrument(level = "trace", skip_all)] @@ -127,6 +195,7 @@ pub async fn collect_candidates( machine: HybridRetrievalMachine<(), Embedded>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + let stage_start = Instant::now(); debug!("Collecting initial candidates via vector and FTS search"); let embedding = ctx.ensure_embedding()?.clone(); let tuning = &ctx.config.tuning; @@ -172,6 +241,19 @@ pub async fn collect_candidates( "Hybrid retrieval initial candidate counts" ); + if ctx.diagnostics_enabled() { + ctx.record_collect_candidates(CollectCandidatesStats { + vector_entity_candidates: vector_entities.len(), + vector_chunk_candidates: vector_chunks.len(), + fts_entity_candidates: fts_entities.len(), + fts_chunk_candidates: fts_chunks.len(), + vector_chunk_scores: sample_scores(&vector_chunks, |chunk| { + chunk.scores.vector.unwrap_or(0.0) + }), + fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)), + }); + } + normalize_fts_scores(&mut fts_entities); normalize_fts_scores(&mut fts_chunks); @@ -183,9 +265,11 @@ pub async fn collect_candidates( apply_fusion(&mut ctx.entity_candidates, weights); apply_fusion(&mut ctx.chunk_candidates, weights); - machine + let next = machine .collect_candidates() - .map_err(|(_, guard)| map_guard_error("collect_candidates", guard)) + .map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?; + ctx.record_collect_candidates_timing(stage_start.elapsed()); + Ok(next) } #[instrument(level = "trace", skip_all)] @@ -193,82 +277,84 @@ pub async fn expand_graph( machine: HybridRetrievalMachine<(), CandidatesLoaded>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + let stage_start = Instant::now(); debug!("Expanding candidates using graph 
relationships"); - let tuning = &ctx.config.tuning; - let weights = FusionWeights::default(); + let next = { + let tuning = &ctx.config.tuning; + let weights = FusionWeights::default(); - if ctx.entity_candidates.is_empty() { - return machine - .expand_graph() - .map_err(|(_, guard)| map_guard_error("expand_graph", guard)); - } + if ctx.entity_candidates.is_empty() { + machine + .expand_graph() + .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) + } else { + let graph_seeds = seeds_from_candidates( + &ctx.entity_candidates, + tuning.graph_seed_min_score, + tuning.graph_traversal_seed_limit, + ); - let graph_seeds = seeds_from_candidates( - &ctx.entity_candidates, - tuning.graph_seed_min_score, - tuning.graph_traversal_seed_limit, - ); + if graph_seeds.is_empty() { + machine + .expand_graph() + .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) + } else { + let mut futures = FuturesUnordered::new(); + for seed in graph_seeds { + let db = ctx.db_client; + let user = ctx.user_id.clone(); + let limit = tuning.graph_neighbor_limit; + futures.push(async move { + let neighbors = + find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await; + (seed, neighbors) + }); + } - if graph_seeds.is_empty() { - return machine - .expand_graph() - .map_err(|(_, guard)| map_guard_error("expand_graph", guard)); - } + while let Some((seed, neighbors_result)) = futures.next().await { + let neighbors = neighbors_result.map_err(AppError::from)?; + if neighbors.is_empty() { + continue; + } - let mut futures = FuturesUnordered::new(); - for seed in graph_seeds { - let db = ctx.db_client; - let user = ctx.user_id.clone(); - futures.push(async move { - let neighbors = find_entities_by_relationship_by_id( - db, - &seed.id, - &user, - tuning.graph_neighbor_limit, - ) - .await; - (seed, neighbors) - }); - } + for neighbor in neighbors { + if neighbor.id == seed.id { + continue; + } - while let Some((seed, neighbors_result)) = futures.next().await { - let neighbors 
= neighbors_result.map_err(AppError::from)?; - if neighbors.is_empty() { - continue; + let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); + let entry = ctx + .entity_candidates + .entry(neighbor.id.clone()) + .or_insert_with(|| Scored::new(neighbor.clone())); + + entry.item = neighbor; + + let inherited_vector = + clamp_unit(graph_score * tuning.graph_vector_inheritance); + let vector_existing = entry.scores.vector.unwrap_or(0.0); + if inherited_vector > vector_existing { + entry.scores.vector = Some(inherited_vector); + } + + let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); + if graph_score > existing_graph || entry.scores.graph.is_none() { + entry.scores.graph = Some(graph_score); + } + + let fused = fuse_scores(&entry.scores, weights); + entry.update_fused(fused); + } + } + + machine + .expand_graph() + .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) + } } - - for neighbor in neighbors { - if neighbor.id == seed.id { - continue; - } - - let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); - let entry = ctx - .entity_candidates - .entry(neighbor.id.clone()) - .or_insert_with(|| Scored::new(neighbor.clone())); - - entry.item = neighbor; - - let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance); - let vector_existing = entry.scores.vector.unwrap_or(0.0); - if inherited_vector > vector_existing { - entry.scores.vector = Some(inherited_vector); - } - - let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); - if graph_score > existing_graph || entry.scores.graph.is_none() { - entry.scores.graph = Some(graph_score); - } - - let fused = fuse_scores(&entry.scores, weights); - entry.update_fused(fused); - } - } - - machine - .expand_graph() - .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) + }?; + ctx.record_graph_expansion_timing(stage_start.elapsed()); + Ok(next) } #[instrument(level = "trace", skip_all)] @@ -276,11 +362,14 @@ pub async fn attach_chunks( machine: 
HybridRetrievalMachine<(), GraphExpanded>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + let stage_start = Instant::now(); debug!("Attaching chunks to surviving entities"); let tuning = &ctx.config.tuning; let weights = FusionWeights::default(); let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates); + let chunk_candidates_before = ctx.chunk_candidates.len(); + let chunk_sources_considered = chunk_by_source.len(); backfill_entities_from_chunks( &mut ctx.entity_candidates, @@ -312,6 +401,8 @@ pub async fn attach_chunks( ctx.filtered_entities = filtered_entities; + let query_embedding = ctx.ensure_embedding()?.clone(); + let mut chunk_results: Vec> = ctx.chunk_candidates.values().cloned().collect(); sort_by_fused_desc(&mut chunk_results); @@ -327,17 +418,31 @@ pub async fn attach_chunks( ctx.db_client, &ctx.user_id, weights, + &query_embedding, ) .await?; let mut chunk_values: Vec> = chunk_by_id.into_values().collect(); sort_by_fused_desc(&mut chunk_values); + if ctx.diagnostics_enabled() { + ctx.record_chunk_enrichment(ChunkEnrichmentStats { + filtered_entity_count: ctx.filtered_entities.len(), + fallback_min_results: tuning.fallback_min_results, + chunk_sources_considered, + chunk_candidates_before_enrichment: chunk_candidates_before, + chunk_candidates_after_enrichment: chunk_values.len(), + top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused), + }); + } + ctx.chunk_values = chunk_values; - machine + let next = machine .attach_chunks() - .map_err(|(_, guard)| map_guard_error("attach_chunks", guard)) + .map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?; + ctx.record_chunk_attach_timing(stage_start.elapsed()); + Ok(next) } #[instrument(level = "trace", skip_all)] @@ -345,6 +450,7 @@ pub async fn rerank( machine: HybridRetrievalMachine<(), ChunksAttached>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + let stage_start = Instant::now(); let mut applied = false; if let Some(reranker) = 
ctx.reranker.as_ref() { @@ -384,9 +490,11 @@ pub async fn rerank( debug!("Applied reranking adjustments to candidate ordering"); } - machine + let next = machine .rerank() - .map_err(|(_, guard)| map_guard_error("rerank", guard)) + .map_err(|(_, guard)| map_guard_error("rerank", guard))?; + ctx.record_rerank_timing(stage_start.elapsed()); + Ok(next) } #[instrument(level = "trace", skip_all)] @@ -394,8 +502,11 @@ pub fn assemble( machine: HybridRetrievalMachine<(), Reranked>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + let stage_start = Instant::now(); debug!("Assembling final retrieved entities"); let tuning = &ctx.config.tuning; + let query_embedding = ctx.ensure_embedding()?.clone(); + let question_terms = extract_keywords(&ctx.input_text); let mut chunk_by_source: HashMap>> = HashMap::new(); for chunk in ctx.chunk_values.drain(..) { @@ -406,39 +517,68 @@ pub fn assemble( } for chunk_list in chunk_by_source.values_mut() { - sort_by_fused_desc(chunk_list); + chunk_list.sort_by(|a, b| { + let sim_a = cosine_similarity(&query_embedding, &a.item.embedding); + let sim_b = cosine_similarity(&query_embedding, &b.item.embedding); + sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal) + }); } let mut token_budget_remaining = tuning.token_budget_estimate; let mut results = Vec::new(); + let diagnostics_enabled = ctx.diagnostics_enabled(); + let mut per_entity_traces = Vec::new(); + let mut chunks_skipped_due_budget = 0usize; + let mut chunks_selected = 0usize; + let mut tokens_spent = 0usize; for entity in &ctx.filtered_entities { let mut selected_chunks = Vec::new(); + let mut entity_trace = if diagnostics_enabled { + Some(EntityAssemblyTrace { + entity_id: entity.item.id.clone(), + source_id: entity.item.source_id.clone(), + inspected_candidates: 0, + selected_chunk_ids: Vec::new(), + selected_chunk_scores: Vec::new(), + skipped_due_budget: 0, + }) + } else { + None + }; if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) { + 
rank_chunks_by_combined_score(candidates, &question_terms, tuning.lexical_match_weight); let mut per_entity_count = 0; - candidates.sort_by(|a, b| { - b.fused - .partial_cmp(&a.fused) - .unwrap_or(std::cmp::Ordering::Equal) - }); - for candidate in candidates.iter() { + if let Some(trace) = entity_trace.as_mut() { + trace.inspected_candidates += 1; + } if per_entity_count >= tuning.max_chunks_per_entity { break; } let estimated_tokens = estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token); if estimated_tokens > token_budget_remaining { + chunks_skipped_due_budget += 1; + if let Some(trace) = entity_trace.as_mut() { + trace.skipped_due_budget += 1; + } continue; } token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); + tokens_spent += estimated_tokens; per_entity_count += 1; + chunks_selected += 1; selected_chunks.push(RetrievedChunk { chunk: candidate.item.clone(), score: candidate.fused, }); + if let Some(trace) = entity_trace.as_mut() { + trace.selected_chunk_ids.push(candidate.item.id.clone()); + trace.selected_chunk_scores.push(candidate.fused); + } } } @@ -448,17 +588,48 @@ pub fn assemble( chunks: selected_chunks, }); + if let Some(trace) = entity_trace { + per_entity_traces.push(trace); + } + if token_budget_remaining == 0 { break; } } + if diagnostics_enabled { + ctx.record_assemble(AssembleStats { + token_budget_start: tuning.token_budget_estimate, + token_budget_spent: tokens_spent, + token_budget_remaining, + budget_exhausted: token_budget_remaining == 0, + chunks_selected, + chunks_skipped_due_budget, + entity_count: ctx.filtered_entities.len(), + entity_traces: per_entity_traces, + }); + } + machine .assemble() .map_err(|(_, guard)| map_guard_error("assemble", guard))?; + ctx.record_assemble_timing(stage_start.elapsed()); Ok(results) } +const SCORE_SAMPLE_LIMIT: usize = 8; + +fn sample_scores(items: &[Scored], mut extractor: F) -> Vec +where + F: FnMut(&Scored) -> f32, +{ + items + .iter() + 
.take(SCORE_SAMPLE_LIMIT) + .map(|item| extractor(item)) + .collect() +} + fn map_guard_error(stage: &'static str, err: GuardError) -> AppError { AppError::InternalError(format!( "state machine guard '{stage}' failed: guard={}, event={}, kind={:?}", @@ -582,6 +753,7 @@ async fn enrich_chunks_from_entities( db_client: &SurrealDbClient, user_id: &str, weights: FusionWeights, + query_embedding: &[f32], ) -> Result<(), AppError> { let mut source_ids: HashSet = HashSet::new(); for entity in entities { @@ -615,7 +787,16 @@ async fn enrich_chunks_from_entities( .copied() .unwrap_or(0.0); - entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8)); + let similarity = cosine_similarity(query_embedding, &chunk.embedding); + + entry.scores.vector = Some( + entry + .scores + .vector + .unwrap_or(0.0) + .max(entity_score * 0.8) + .max(similarity), + ); let fused = fuse_scores(&entry.scores, weights); entry.update_fused(fused); entry.item = chunk; @@ -624,6 +805,24 @@ async fn enrich_chunks_from_entities( Ok(()) } +fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 { + if query.is_empty() || embedding.is_empty() || query.len() != embedding.len() { + return 0.0; + } + let mut dot = 0.0f32; + let mut norm_q = 0.0f32; + let mut norm_e = 0.0f32; + for (q, e) in query.iter().zip(embedding.iter()) { + dot += q * e; + norm_q += q * q; + norm_e += e * e; + } + if norm_q == 0.0 || norm_e == 0.0 { + return 0.0; + } + dot / (norm_q.sqrt() * norm_e.sqrt()) +} + fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec { if ctx.filtered_entities.is_empty() { return Vec::new(); @@ -736,6 +935,48 @@ fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize { (chars / avg_chars_per_token).max(1) } +fn rank_chunks_by_combined_score( + candidates: &mut [Scored], + question_terms: &[String], + lexical_weight: f32, +) { + if lexical_weight > 0.0 && !question_terms.is_empty() { + for candidate in candidates.iter_mut() 
{ + let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk); + let combined = clamp_unit(candidate.fused + lexical_weight * lexical); + candidate.update_fused(combined); + } + } + candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)); +} + +fn extract_keywords(text: &str) -> Vec { + let mut terms = Vec::new(); + for raw in text.split(|c: char| !c.is_alphanumeric()) { + let term = raw.trim().to_ascii_lowercase(); + if term.len() >= 3 { + terms.push(term); + } + } + terms.sort(); + terms.dedup(); + terms +} + +fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 { + if terms.is_empty() { + return 0.0; + } + let lower = haystack.to_ascii_lowercase(); + let mut matches = 0usize; + for term in terms { + if lower.contains(term) { + matches += 1; + } + } + (matches as f32) / (terms.len() as f32) +} + #[derive(Clone)] struct GraphSeed { id: String, diff --git a/composite-retrieval/src/vector.rs b/composite-retrieval/src/vector.rs index d2aaae7..229ec12 100644 --- a/composite-retrieval/src/vector.rs +++ b/composite-retrieval/src/vector.rs @@ -68,6 +68,7 @@ where { let embedding_literal = serde_json::to_string(&query_embedding) .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?; + let closest_query = format!( "SELECT id, vector::distance::knn() AS distance \ FROM {table} \ diff --git a/eval/Cargo.toml b/eval/Cargo.toml new file mode 100644 index 0000000..82d0319 --- /dev/null +++ b/eval/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "eval" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +async-openai = { workspace = true } +chrono = { workspace = true } +common = { path = "../common" } +composite-retrieval = { path = "../composite-retrieval" } +ingestion-pipeline = { path = "../ingestion-pipeline" } +futures = { workspace = true } +fastembed = { workspace = true } +serde = { workspace = true, features = ["derive"] } +tokio = { workspace = 
true, features = ["macros", "rt-multi-thread"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +uuid = { workspace = true } +text-splitter = { workspace = true } +rand = "0.8" +sha2 = { workspace = true } +object_store = { workspace = true } +surrealdb = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +once_cell = "1.19" +serde_yaml = "0.9" +criterion = "0.5" +state-machines = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } diff --git a/eval/manifest.yaml b/eval/manifest.yaml new file mode 100644 index 0000000..3edf556 --- /dev/null +++ b/eval/manifest.yaml @@ -0,0 +1,33 @@ +default_dataset: squad-v2 +datasets: + - id: squad-v2 + label: "SQuAD v2.0" + category: "SQuAD v2.0" + entity_suffix: "SQuAD" + source_prefix: "squad" + raw: "data/raw/squad/dev-v2.0.json" + converted: "data/converted/squad-minne.json" + include_unanswerable: false + slices: + - id: squad-dev-200 + label: "SQuAD dev (200)" + description: "Deterministic 200-case slice for local eval" + limit: 200 + corpus_limit: 2000 + seed: 0x5eed2025 + - id: natural-questions-dev + label: "Natural Questions (dev)" + category: "Natural Questions" + entity_suffix: "Natural Questions" + source_prefix: "nq" + raw: "data/raw/nq/dev-all.jsonl" + converted: "data/converted/nq-dev-minne.json" + include_unanswerable: true + slices: + - id: nq-dev-200 + label: "NQ dev (200)" + description: "200-case slice of the dev set" + limit: 200 + corpus_limit: 2000 + include_unanswerable: false + seed: 0x5eed2025 diff --git a/eval/src/args.rs b/eval/src/args.rs new file mode 100644 index 0000000..c42a20e --- /dev/null +++ b/eval/src/args.rs @@ -0,0 +1,638 @@ +use std::{ + env, + path::{Path, PathBuf}, +}; + +use anyhow::{anyhow, Context, Result}; + +use crate::datasets::DatasetKind; + +pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EmbeddingBackend { + Hashed, + FastEmbed, 
+} + +impl Default for EmbeddingBackend { + fn default() -> Self { + Self::FastEmbed + } +} + +impl std::str::FromStr for EmbeddingBackend { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "hashed" => Ok(Self::Hashed), + "fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed), + other => Err(anyhow!( + "unknown embedding backend '{other}'. Expected 'hashed' or 'fastembed'." + )), + } + } +} + +#[derive(Debug, Clone)] +pub struct Config { + pub convert_only: bool, + pub force_convert: bool, + pub dataset: DatasetKind, + pub llm_mode: bool, + pub corpus_limit: Option, + pub raw_dataset_path: PathBuf, + pub converted_dataset_path: PathBuf, + pub report_dir: PathBuf, + pub k: usize, + pub limit: Option, + pub summary_sample: usize, + pub full_context: bool, + pub chunk_min_chars: usize, + pub chunk_max_chars: usize, + pub chunk_vector_take: Option, + pub chunk_fts_take: Option, + pub chunk_token_budget: Option, + pub chunk_avg_chars_per_token: Option, + pub max_chunks_per_entity: Option, + pub rerank: bool, + pub rerank_pool_size: usize, + pub rerank_keep_top: usize, + pub concurrency: usize, + pub embedding_backend: EmbeddingBackend, + pub embedding_model: Option, + pub cache_dir: PathBuf, + pub ingestion_cache_dir: PathBuf, + pub refresh_embeddings_only: bool, + pub detailed_report: bool, + pub slice: Option, + pub reseed_slice: bool, + pub slice_seed: u64, + pub slice_grow: Option, + pub slice_offset: usize, + pub slice_reset_ingestion: bool, + pub negative_multiplier: f32, + pub label: Option, + pub chunk_diagnostics_path: Option, + pub inspect_question: Option, + pub inspect_manifest: Option, + pub query_model: Option, + pub perf_log_json: Option, + pub perf_log_dir: Option, + pub perf_log_console: bool, + pub db_endpoint: String, + pub db_username: String, + pub db_password: String, + pub db_namespace: Option, + pub db_database: Option, + pub inspect_db_state: Option, +} + +impl Default for Config { 
+ fn default() -> Self { + let dataset = DatasetKind::default(); + Self { + convert_only: false, + force_convert: false, + dataset, + llm_mode: false, + corpus_limit: None, + raw_dataset_path: dataset.default_raw_path(), + converted_dataset_path: dataset.default_converted_path(), + report_dir: PathBuf::from("eval/reports"), + k: 5, + limit: Some(200), + summary_sample: 5, + full_context: false, + chunk_min_chars: 500, + chunk_max_chars: 2_000, + chunk_vector_take: None, + chunk_fts_take: None, + chunk_token_budget: None, + chunk_avg_chars_per_token: None, + max_chunks_per_entity: None, + rerank: true, + rerank_pool_size: 16, + rerank_keep_top: 10, + concurrency: 4, + embedding_backend: EmbeddingBackend::FastEmbed, + embedding_model: None, + cache_dir: PathBuf::from("eval/cache"), + ingestion_cache_dir: PathBuf::from("eval/cache/ingested"), + refresh_embeddings_only: false, + detailed_report: false, + slice: None, + reseed_slice: false, + slice_seed: DEFAULT_SLICE_SEED, + slice_grow: None, + slice_offset: 0, + slice_reset_ingestion: false, + negative_multiplier: crate::slices::DEFAULT_NEGATIVE_MULTIPLIER, + label: None, + chunk_diagnostics_path: None, + inspect_question: None, + inspect_manifest: None, + query_model: None, + inspect_db_state: None, + perf_log_json: None, + perf_log_dir: None, + perf_log_console: false, + db_endpoint: "ws://127.0.0.1:8000".to_string(), + db_username: "root_user".to_string(), + db_password: "root_password".to_string(), + db_namespace: None, + db_database: None, + } + } +} + +impl Config { + pub fn context_token_limit(&self) -> Option { + None + } +} + +#[derive(Debug)] +pub struct ParsedArgs { + pub config: Config, + pub show_help: bool, +} + +pub fn parse() -> Result { + let mut config = Config::default(); + let mut show_help = false; + let mut raw_overridden = false; + let mut converted_overridden = false; + + let mut args = env::args().skip(1).peekable(); + while let Some(arg) = args.next() { + match arg.as_str() { + "-h" | 
"--help" => { + show_help = true; + break; + } + "--convert-only" => config.convert_only = true, + "--force" | "--refresh" => config.force_convert = true, + "--llm-mode" => { + config.llm_mode = true; + } + "--dataset" => { + let value = take_value("--dataset", &mut args)?; + let parsed = value.parse::()?; + config.dataset = parsed; + if !raw_overridden { + config.raw_dataset_path = parsed.default_raw_path(); + } + if !converted_overridden { + config.converted_dataset_path = parsed.default_converted_path(); + } + } + "--slice" => { + let value = take_value("--slice", &mut args)?; + config.slice = Some(value); + } + "--label" => { + let value = take_value("--label", &mut args)?; + config.label = Some(value); + } + "--query-model" => { + let value = take_value("--query-model", &mut args)?; + if value.trim().is_empty() { + return Err(anyhow!("--query-model requires a non-empty model name")); + } + config.query_model = Some(value.trim().to_string()); + } + "--slice-grow" => { + let value = take_value("--slice-grow", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --slice-grow value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--slice-grow must be greater than zero")); + } + config.slice_grow = Some(parsed); + } + "--slice-offset" => { + let value = take_value("--slice-offset", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --slice-offset value '{value}' as usize") + })?; + config.slice_offset = parsed; + } + "--raw" => { + let value = take_value("--raw", &mut args)?; + config.raw_dataset_path = PathBuf::from(value); + raw_overridden = true; + } + "--converted" => { + let value = take_value("--converted", &mut args)?; + config.converted_dataset_path = PathBuf::from(value); + converted_overridden = true; + } + "--corpus-limit" => { + let value = take_value("--corpus-limit", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse 
--corpus-limit value '{value}' as usize") + })?; + config.corpus_limit = if parsed == 0 { None } else { Some(parsed) }; + } + "--reseed-slice" => { + config.reseed_slice = true; + } + "--slice-reset-ingestion" => { + config.slice_reset_ingestion = true; + } + "--report-dir" => { + let value = take_value("--report-dir", &mut args)?; + config.report_dir = PathBuf::from(value); + } + "--k" => { + let value = take_value("--k", &mut args)?; + let parsed = value + .parse::() + .with_context(|| format!("failed to parse --k value '{value}' as usize"))?; + if parsed == 0 { + return Err(anyhow!("--k must be greater than zero")); + } + config.k = parsed; + } + "--limit" => { + let value = take_value("--limit", &mut args)?; + let parsed = value + .parse::() + .with_context(|| format!("failed to parse --limit value '{value}' as usize"))?; + config.limit = if parsed == 0 { None } else { Some(parsed) }; + } + "--sample" => { + let value = take_value("--sample", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --sample value '{value}' as usize") + })?; + config.summary_sample = parsed.max(1); + } + "--full-context" => { + config.full_context = true; + } + "--chunk-min" => { + let value = take_value("--chunk-min", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --chunk-min value '{value}' as usize") + })?; + config.chunk_min_chars = parsed.max(1); + } + "--chunk-max" => { + let value = take_value("--chunk-max", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --chunk-max value '{value}' as usize") + })?; + config.chunk_max_chars = parsed.max(1); + } + "--chunk-vector-take" => { + let value = take_value("--chunk-vector-take", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --chunk-vector-take value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--chunk-vector-take must be greater than zero")); + } + 
config.chunk_vector_take = Some(parsed); + } + "--chunk-fts-take" => { + let value = take_value("--chunk-fts-take", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --chunk-fts-take value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--chunk-fts-take must be greater than zero")); + } + config.chunk_fts_take = Some(parsed); + } + "--chunk-token-budget" => { + let value = take_value("--chunk-token-budget", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --chunk-token-budget value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--chunk-token-budget must be greater than zero")); + } + config.chunk_token_budget = Some(parsed); + } + "--chunk-token-chars" => { + let value = take_value("--chunk-token-chars", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --chunk-token-chars value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--chunk-token-chars must be greater than zero")); + } + config.chunk_avg_chars_per_token = Some(parsed); + } + "--max-chunks-per-entity" => { + let value = take_value("--max-chunks-per-entity", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --max-chunks-per-entity value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--max-chunks-per-entity must be greater than zero")); + } + config.max_chunks_per_entity = Some(parsed); + } + "--embedding" => { + let value = take_value("--embedding", &mut args)?; + config.embedding_backend = value.parse()?; + } + "--embedding-model" => { + let value = take_value("--embedding-model", &mut args)?; + config.embedding_model = Some(value.trim().to_string()); + } + "--cache-dir" => { + let value = take_value("--cache-dir", &mut args)?; + config.cache_dir = PathBuf::from(value); + } + "--ingestion-cache-dir" => { + let value = take_value("--ingestion-cache-dir", &mut args)?; + 
config.ingestion_cache_dir = PathBuf::from(value); + } + "--negative-multiplier" => { + let value = take_value("--negative-multiplier", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --negative-multiplier value '{value}' as f32") + })?; + if !(parsed.is_finite() && parsed > 0.0) { + return Err(anyhow!( + "--negative-multiplier must be a positive finite number" + )); + } + config.negative_multiplier = parsed; + } + "--no-rerank" => { + config.rerank = false; + } + "--rerank-pool" => { + let value = take_value("--rerank-pool", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --rerank-pool value '{value}' as usize") + })?; + config.rerank_pool_size = parsed.max(1); + } + "--rerank-keep" => { + let value = take_value("--rerank-keep", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --rerank-keep value '{value}' as usize") + })?; + config.rerank_keep_top = parsed.max(1); + } + "--concurrency" => { + let value = take_value("--concurrency", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --concurrency value '{value}' as usize") + })?; + config.concurrency = parsed.max(1); + } + "--refresh-embeddings" => { + config.refresh_embeddings_only = true; + } + "--detailed-report" => { + config.detailed_report = true; + } + "--chunk-diagnostics" => { + let value = take_value("--chunk-diagnostics", &mut args)?; + config.chunk_diagnostics_path = Some(PathBuf::from(value)); + } + "--inspect-question" => { + let value = take_value("--inspect-question", &mut args)?; + config.inspect_question = Some(value); + } + "--inspect-manifest" => { + let value = take_value("--inspect-manifest", &mut args)?; + config.inspect_manifest = Some(PathBuf::from(value)); + } + "--inspect-db-state" => { + let value = take_value("--inspect-db-state", &mut args)?; + config.inspect_db_state = Some(PathBuf::from(value)); + } + "--perf-log-json" => { + 
let value = take_value("--perf-log-json", &mut args)?; + config.perf_log_json = Some(PathBuf::from(value)); + } + "--perf-log-dir" => { + let value = take_value("--perf-log-dir", &mut args)?; + config.perf_log_dir = Some(PathBuf::from(value)); + } + "--perf-log" => { + config.perf_log_console = true; + } + "--db-endpoint" => { + let value = take_value("--db-endpoint", &mut args)?; + config.db_endpoint = value; + } + "--db-user" => { + let value = take_value("--db-user", &mut args)?; + config.db_username = value; + } + "--db-pass" => { + let value = take_value("--db-pass", &mut args)?; + config.db_password = value; + } + "--db-namespace" => { + let value = take_value("--db-namespace", &mut args)?; + config.db_namespace = Some(value); + } + "--db-database" => { + let value = take_value("--db-database", &mut args)?; + config.db_database = Some(value); + } + unknown => { + return Err(anyhow!( + "unknown argument '{unknown}'. Use --help to see available options." + )); + } + } + } + + if config.chunk_min_chars >= config.chunk_max_chars { + return Err(anyhow!( + "--chunk-min must be less than --chunk-max (got {} >= {})", + config.chunk_min_chars, + config.chunk_max_chars + )); + } + + if config.rerank && config.rerank_pool_size == 0 { + return Err(anyhow!( + "--rerank-pool must be greater than zero when reranking is enabled" + )); + } + + if config.concurrency == 0 { + return Err(anyhow!("--concurrency must be greater than zero")); + } + + if config.embedding_backend == EmbeddingBackend::Hashed && config.embedding_model.is_some() { + return Err(anyhow!( + "--embedding-model cannot be used with the 'hashed' embedding backend" + )); + } + + if let Some(limit) = config.limit { + if let Some(corpus_limit) = config.corpus_limit { + if corpus_limit < limit { + config.corpus_limit = Some(limit); + } + } else { + let default_multiplier = 10usize; + let mut computed = limit.saturating_mul(default_multiplier); + if computed < limit { + computed = limit; + } + let max_cap = 
1_000usize; + if computed > max_cap { + computed = max_cap; + } + config.corpus_limit = Some(computed); + } + } + + if config.perf_log_dir.is_none() { + if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") { + if !dir.trim().is_empty() { + config.perf_log_dir = Some(PathBuf::from(dir)); + } + } + } + + if let Ok(endpoint) = env::var("EVAL_DB_ENDPOINT") { + if !endpoint.trim().is_empty() { + config.db_endpoint = endpoint; + } + } + if let Ok(username) = env::var("EVAL_DB_USERNAME") { + if !username.trim().is_empty() { + config.db_username = username; + } + } + if let Ok(password) = env::var("EVAL_DB_PASSWORD") { + if !password.trim().is_empty() { + config.db_password = password; + } + } + if let Ok(ns) = env::var("EVAL_DB_NAMESPACE") { + if !ns.trim().is_empty() { + config.db_namespace = Some(ns); + } + } + if let Ok(db) = env::var("EVAL_DB_DATABASE") { + if !db.trim().is_empty() { + config.db_database = Some(db); + } + } + Ok(ParsedArgs { config, show_help }) +} + +fn take_value<'a, I>(flag: &str, iter: &mut std::iter::Peekable) -> Result +where + I: Iterator, +{ + iter.next().ok_or_else(|| anyhow!("{flag} expects a value")) +} + +pub fn print_help() { + println!( + "\ +eval — dataset conversion, ingestion, and retrieval evaluation CLI + +USAGE: + cargo eval -- [options] + # or + cargo run -p eval -- [options] + +OPTIONS: + --convert-only Convert the selected dataset and exit. + --force, --refresh Regenerate the converted dataset even if it already exists. + --dataset Dataset to evaluate: 'squad' (default) or 'natural-questions'. + --llm-mode Enable LLM-assisted evaluation features (includes unanswerable cases). + --slice Use a cached dataset slice by id (under eval/cache/slices) or by explicit path. + --label Annotate the run; label is stored in JSON/Markdown reports. + --query-model Override the SurrealDB system settings query model (e.g., gpt-4o-mini) for this run. + --slice-grow Grow the slice ledger to contain at least this many answerable cases, then exit. 
+ --slice-offset Evaluate questions starting at this offset within the slice (default: 0). + --reseed-slice Ignore cached corpus state and rebuild the slice's SurrealDB corpus. + --slice-reset-ingestion + Delete cached paragraph shards before rebuilding the ingestion corpus. + --corpus-limit Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000. + --raw Path to the raw dataset (defaults per dataset). + --converted Path to write/read the converted dataset (defaults per dataset). + --report-dir Directory to write evaluation reports (default: eval/reports). + --k Precision@k cutoff (default: 5). + --limit Limit the number of questions evaluated (default: 200, 0 = all). + --sample Number of mismatches to surface in the Markdown summary (default: 5). + --full-context Disable context cropping when converting datasets (ingest entire documents). + --chunk-min Minimum characters per chunk for text splitting (default: 500). + --chunk-max Maximum characters per chunk for text splitting (default: 2000). + --chunk-vector-take + Override chunk vector candidate cap (default: 20). + --chunk-fts-take + Override chunk FTS candidate cap (default: 20). + --chunk-token-budget + Override chunk token budget estimate for assembly (default: 10000). + --chunk-token-chars + Override average characters per token used for budgeting (default: 4). + --max-chunks-per-entity + Override maximum chunks attached per entity (default: 4). + --embedding Embedding backend: 'fastembed' (default) or 'hashed'. + --embedding-model + FastEmbed model code (defaults to crate preset when omitted). + --cache-dir Directory for embedding caches (default: eval/cache). + --ingestion-cache-dir + Directory for ingestion corpora caches (default: eval/cache/ingested). + --negative-multiplier + Target negative-to-positive paragraph ratio for slice growth (default: 4.0). + --refresh-embeddings Recompute embeddings for cached corpora without re-running ingestion. 
+ --detailed-report Include entity descriptions and categories in JSON reports. + --chunk-diagnostics + Write per-query chunk diagnostics JSONL to the provided path. + --no-rerank Disable the FastEmbed reranking stage (enabled by default). + --rerank-pool Reranking engine pool size / parallelism (default: 16). + --rerank-keep Keep top-N entities after reranking (default: 10). + --inspect-question + Inspect an ingestion cache question and exit (requires --inspect-manifest). + --inspect-manifest + Path to an ingestion cache manifest JSON for inspection mode. + --inspect-db-state + Optional override for the SurrealDB state.json used during inspection; defaults to the state recorded for the selected dataset slice. + --db-endpoint SurrealDB server endpoint (use http:// or https:// to enable SurQL export/import; ws:// endpoints reuse existing namespaces but skip SurQL exports; default: ws://127.0.0.1:8000). + --db-user SurrealDB root username (default: root_user). + --db-pass SurrealDB root password (default: root_password). + --db-namespace Override the namespace used on the SurrealDB server; state.json tracks this value and the ledger case count so changing it or requesting more cases via --limit triggers a rebuild/import (default: derived from dataset). + --db-database Override the database used on the SurrealDB server; recorded alongside namespace in state.json (default: derived from slice). + --perf-log Print per-stage performance timings to stdout after the run. + --perf-log-json + Write structured performance telemetry JSON to the provided path. + --perf-log-dir + Directory that receives timestamped perf JSON copies (defaults to $EVAL_PERF_LOG_DIR). + +Examples: + cargo eval -- --dataset squad --limit 10 --detailed-report + cargo eval -- --dataset natural-questions --limit 1 --rerank-pool 1 --detailed-report + +Notes: + The latest run's JSON/Markdown reports are saved as eval/reports/latest.json and latest.md, making it easy to script automated checks. 
+ -h, --help Show this help text. + +Dataset defaults (from eval/manifest.yaml): + squad raw: eval/data/raw/squad/dev-v2.0.json + converted: eval/data/converted/squad-minne.json + natural-questions raw: eval/data/raw/nq/dev-all.jsonl + converted: eval/data/converted/nq-dev-minne.json +" + ); +} + +pub fn ensure_parent(path: &Path) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("creating parent directory for {}", path.display()))?; + } + Ok(()) +} diff --git a/eval/src/cache.rs b/eval/src/cache.rs new file mode 100644 index 0000000..db31905 --- /dev/null +++ b/eval/src/cache.rs @@ -0,0 +1,88 @@ +use std::{ + collections::HashMap, + path::{Path, PathBuf}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use tokio::sync::Mutex; + +#[derive(Debug, Default, Serialize, Deserialize)] +struct EmbeddingCacheData { + entities: HashMap>, + chunks: HashMap>, +} + +#[derive(Clone)] +pub struct EmbeddingCache { + path: Arc, + data: Arc>, + dirty: Arc, +} + +#[allow(dead_code)] +impl EmbeddingCache { + pub async fn load(path: impl AsRef) -> Result { + let path = path.as_ref().to_path_buf(); + let data = if path.exists() { + let raw = tokio::fs::read(&path) + .await + .with_context(|| format!("reading embedding cache {}", path.display()))?; + serde_json::from_slice(&raw) + .with_context(|| format!("parsing embedding cache {}", path.display()))? 
+ } else { + EmbeddingCacheData::default() + }; + + Ok(Self { + path: Arc::new(path), + data: Arc::new(Mutex::new(data)), + dirty: Arc::new(AtomicBool::new(false)), + }) + } + + pub async fn get_entity(&self, id: &str) -> Option> { + let guard = self.data.lock().await; + guard.entities.get(id).cloned() + } + + pub async fn insert_entity(&self, id: String, embedding: Vec) { + let mut guard = self.data.lock().await; + guard.entities.insert(id, embedding); + self.dirty.store(true, Ordering::Relaxed); + } + + pub async fn get_chunk(&self, id: &str) -> Option> { + let guard = self.data.lock().await; + guard.chunks.get(id).cloned() + } + + pub async fn insert_chunk(&self, id: String, embedding: Vec) { + let mut guard = self.data.lock().await; + guard.chunks.insert(id, embedding); + self.dirty.store(true, Ordering::Relaxed); + } + + pub async fn persist(&self) -> Result<()> { + if !self.dirty.load(Ordering::Relaxed) { + return Ok(()); + } + + let guard = self.data.lock().await; + let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?; + if let Some(parent) = self.path.parent() { + tokio::fs::create_dir_all(parent) + .await + .with_context(|| format!("creating cache directory {}", parent.display()))?; + } + tokio::fs::write(&*self.path, body) + .await + .with_context(|| format!("writing embedding cache {}", self.path.display()))?; + self.dirty.store(false, Ordering::Relaxed); + Ok(()) + } +} diff --git a/eval/src/datasets.rs b/eval/src/datasets.rs new file mode 100644 index 0000000..c1f9271 --- /dev/null +++ b/eval/src/datasets.rs @@ -0,0 +1,1003 @@ +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + fs::{self, File}, + io::{BufRead, BufReader}, + path::{Path, PathBuf}, + str::FromStr, +}; + +use anyhow::{anyhow, bail, Context, Result}; +use chrono::{TimeZone, Utc}; +use once_cell::sync::OnceCell; +use serde::{Deserialize, Serialize}; +use tracing::warn; + +const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), 
"/manifest.yaml"); +static DATASET_CATALOG: OnceCell = OnceCell::new(); + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct DatasetCatalog { + datasets: BTreeMap, + slices: HashMap, + default_dataset: String, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct DatasetEntry { + pub metadata: DatasetMetadata, + pub raw_path: PathBuf, + pub converted_path: PathBuf, + pub include_unanswerable: bool, + pub slices: Vec, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct SliceEntry { + pub id: String, + pub dataset_id: String, + pub label: String, + pub description: Option, + pub limit: Option, + pub corpus_limit: Option, + pub include_unanswerable: Option, + pub seed: Option, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct SliceLocation { + dataset_id: String, + slice_index: usize, +} + +#[derive(Debug, Deserialize)] +struct ManifestFile { + default_dataset: Option, + datasets: Vec, +} + +#[derive(Debug, Deserialize)] +struct ManifestDataset { + id: String, + label: String, + category: String, + #[serde(default)] + entity_suffix: Option, + #[serde(default)] + source_prefix: Option, + raw: String, + converted: String, + #[serde(default)] + include_unanswerable: bool, + #[serde(default)] + slices: Vec, +} + +#[derive(Debug, Deserialize)] +struct ManifestSlice { + id: String, + label: String, + #[serde(default)] + description: Option, + #[serde(default)] + limit: Option, + #[serde(default)] + corpus_limit: Option, + #[serde(default)] + include_unanswerable: Option, + #[serde(default)] + seed: Option, +} + +impl DatasetCatalog { + pub fn load() -> Result { + let manifest_raw = fs::read_to_string(MANIFEST_PATH) + .with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?; + let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw) + .with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?; + + let root = Path::new(env!("CARGO_MANIFEST_DIR")); + let mut datasets = BTreeMap::new(); + let mut 
slices = HashMap::new(); + + for dataset in manifest.datasets { + let raw_path = resolve_path(root, &dataset.raw); + let converted_path = resolve_path(root, &dataset.converted); + + if !raw_path.exists() { + bail!( + "dataset '{}' raw file missing at {}", + dataset.id, + raw_path.display() + ); + } + if !converted_path.exists() { + warn!( + "dataset '{}' converted file missing at {}; the next conversion run will regenerate it", + dataset.id, + converted_path.display() + ); + } + + let metadata = DatasetMetadata { + id: dataset.id.clone(), + label: dataset.label.clone(), + category: dataset.category.clone(), + entity_suffix: dataset + .entity_suffix + .clone() + .unwrap_or_else(|| dataset.label.clone()), + source_prefix: dataset + .source_prefix + .clone() + .unwrap_or_else(|| dataset.id.clone()), + include_unanswerable: dataset.include_unanswerable, + context_token_limit: None, + }; + + let mut entry_slices = Vec::with_capacity(dataset.slices.len()); + + for (index, manifest_slice) in dataset.slices.into_iter().enumerate() { + if slices.contains_key(&manifest_slice.id) { + bail!( + "slice '{}' defined multiple times in manifest", + manifest_slice.id + ); + } + entry_slices.push(SliceEntry { + id: manifest_slice.id.clone(), + dataset_id: dataset.id.clone(), + label: manifest_slice.label, + description: manifest_slice.description, + limit: manifest_slice.limit, + corpus_limit: manifest_slice.corpus_limit, + include_unanswerable: manifest_slice.include_unanswerable, + seed: manifest_slice.seed, + }); + slices.insert( + manifest_slice.id, + SliceLocation { + dataset_id: dataset.id.clone(), + slice_index: index, + }, + ); + } + + datasets.insert( + metadata.id.clone(), + DatasetEntry { + metadata, + raw_path, + converted_path, + include_unanswerable: dataset.include_unanswerable, + slices: entry_slices, + }, + ); + } + + let default_dataset = manifest + .default_dataset + .or_else(|| datasets.keys().next().cloned()) + .ok_or_else(|| anyhow!("dataset manifest does not 
include any datasets"))?; + + Ok(Self { + datasets, + slices, + default_dataset, + }) + } + + pub fn global() -> Result<&'static Self> { + DATASET_CATALOG.get_or_try_init(Self::load) + } + + pub fn dataset(&self, id: &str) -> Result<&DatasetEntry> { + self.datasets + .get(id) + .ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest")) + } + + #[allow(dead_code)] + pub fn default_dataset(&self) -> Result<&DatasetEntry> { + self.dataset(&self.default_dataset) + } + + #[allow(dead_code)] + pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> { + let location = self + .slices + .get(slice_id) + .ok_or_else(|| anyhow!("unknown slice '{slice_id}' in manifest"))?; + let dataset = self + .datasets + .get(&location.dataset_id) + .ok_or_else(|| anyhow!("slice '{slice_id}' references missing dataset"))?; + let slice = dataset + .slices + .get(location.slice_index) + .ok_or_else(|| anyhow!("slice index out of bounds for '{slice_id}'"))?; + Ok((dataset, slice)) + } +} + +fn resolve_path(root: &Path, value: &str) -> PathBuf { + let path = Path::new(value); + if path.is_absolute() { + path.to_path_buf() + } else { + root.join(path) + } +} + +pub fn catalog() -> Result<&'static DatasetCatalog> { + DatasetCatalog::global() +} + +fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> { + let catalog = catalog()?; + catalog.dataset(kind.id()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DatasetKind { + SquadV2, + NaturalQuestions, +} + +impl DatasetKind { + pub fn id(self) -> &'static str { + match self { + Self::SquadV2 => "squad-v2", + Self::NaturalQuestions => "natural-questions-dev", + } + } + + pub fn label(self) -> &'static str { + match self { + Self::SquadV2 => "SQuAD v2.0", + Self::NaturalQuestions => "Natural Questions (dev)", + } + } + + pub fn category(self) -> &'static str { + match self { + Self::SquadV2 => "SQuAD v2.0", + Self::NaturalQuestions => "Natural Questions", + } + } + + pub fn 
entity_suffix(self) -> &'static str { + match self { + Self::SquadV2 => "SQuAD", + Self::NaturalQuestions => "Natural Questions", + } + } + + pub fn source_prefix(self) -> &'static str { + match self { + Self::SquadV2 => "squad", + Self::NaturalQuestions => "nq", + } + } + + pub fn default_raw_path(self) -> PathBuf { + dataset_entry_for_kind(self) + .map(|entry| entry.raw_path.clone()) + .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self)) + } + + pub fn default_converted_path(self) -> PathBuf { + dataset_entry_for_kind(self) + .map(|entry| entry.converted_path.clone()) + .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self)) + } +} + +impl Default for DatasetKind { + fn default() -> Self { + Self::SquadV2 + } +} + +impl FromStr for DatasetKind { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2), + "nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => { + Ok(Self::NaturalQuestions) + } + other => { + anyhow::bail!("unknown dataset '{other}'. 
Expected 'squad' or 'natural-questions'.") + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatasetMetadata { + pub id: String, + pub label: String, + pub category: String, + pub entity_suffix: String, + pub source_prefix: String, + #[serde(default)] + pub include_unanswerable: bool, + #[serde(default)] + pub context_token_limit: Option, +} + +impl DatasetMetadata { + pub fn for_kind( + kind: DatasetKind, + include_unanswerable: bool, + context_token_limit: Option, + ) -> Self { + if let Ok(entry) = dataset_entry_for_kind(kind) { + return Self { + id: entry.metadata.id.clone(), + label: entry.metadata.label.clone(), + category: entry.metadata.category.clone(), + entity_suffix: entry.metadata.entity_suffix.clone(), + source_prefix: entry.metadata.source_prefix.clone(), + include_unanswerable, + context_token_limit, + }; + } + + Self { + id: kind.id().to_string(), + label: kind.label().to_string(), + category: kind.category().to_string(), + entity_suffix: kind.entity_suffix().to_string(), + source_prefix: kind.source_prefix().to_string(), + include_unanswerable, + context_token_limit, + } + } +} + +fn default_metadata() -> DatasetMetadata { + DatasetMetadata::for_kind(DatasetKind::default(), false, None) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvertedDataset { + pub generated_at: chrono::DateTime, + #[serde(default = "default_metadata")] + pub metadata: DatasetMetadata, + pub source: String, + pub paragraphs: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvertedParagraph { + pub id: String, + pub title: String, + pub context: String, + pub questions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvertedQuestion { + pub id: String, + pub question: String, + pub answers: Vec, + pub is_impossible: bool, +} + +pub fn convert( + raw_path: &Path, + dataset: DatasetKind, + include_unanswerable: bool, + context_token_limit: Option, +) -> Result { + let paragraphs = 
match dataset { + DatasetKind::SquadV2 => convert_squad(raw_path)?, + DatasetKind::NaturalQuestions => { + convert_nq(raw_path, include_unanswerable, context_token_limit)? + } + }; + + let metadata_limit = match dataset { + DatasetKind::NaturalQuestions => None, + _ => context_token_limit, + }; + + Ok(ConvertedDataset { + generated_at: Utc::now(), + metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit), + source: raw_path.display().to_string(), + paragraphs, + }) +} + +fn convert_squad(raw_path: &Path) -> Result> { + #[derive(Debug, Deserialize)] + struct SquadDataset { + data: Vec, + } + + #[derive(Debug, Deserialize)] + struct SquadArticle { + title: String, + paragraphs: Vec, + } + + #[derive(Debug, Deserialize)] + struct SquadParagraph { + context: String, + qas: Vec, + } + + #[derive(Debug, Deserialize)] + struct SquadQuestion { + id: String, + question: String, + answers: Vec, + #[serde(default)] + is_impossible: bool, + } + + #[derive(Debug, Deserialize)] + struct SquadAnswer { + text: String, + } + + let raw = fs::read_to_string(raw_path) + .with_context(|| format!("reading raw SQuAD dataset at {}", raw_path.display()))?; + let parsed: SquadDataset = serde_json::from_str(&raw) + .with_context(|| format!("parsing SQuAD dataset at {}", raw_path.display()))?; + + let mut paragraphs = Vec::new(); + for (article_idx, article) in parsed.data.into_iter().enumerate() { + for (paragraph_idx, paragraph) in article.paragraphs.into_iter().enumerate() { + let mut questions = Vec::new(); + for qa in paragraph.qas { + let answers = dedupe_strings(qa.answers.into_iter().map(|answer| answer.text)); + questions.push(ConvertedQuestion { + id: qa.id, + question: qa.question.trim().to_string(), + answers, + is_impossible: qa.is_impossible, + }); + } + + let paragraph_id = + format!("{}-{}", slugify(&article.title, article_idx), paragraph_idx); + + paragraphs.push(ConvertedParagraph { + id: paragraph_id, + title: article.title.trim().to_string(), + 
context: paragraph.context.trim().to_string(), + questions, + }); + } + } + + Ok(paragraphs) +} + +#[allow(dead_code)] +pub const DEFAULT_CONTEXT_TOKEN_LIMIT: usize = 1_500; // retained for backwards compatibility (unused) + +fn convert_nq( + raw_path: &Path, + include_unanswerable: bool, + _context_token_limit: Option, +) -> Result> { + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqExample { + question_text: String, + document_title: String, + example_id: i64, + document_tokens: Vec, + long_answer_candidates: Vec, + annotations: Vec, + } + + #[derive(Debug, Deserialize)] + struct NqToken { + token: String, + #[serde(default)] + html_token: bool, + } + + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqLongAnswerCandidate { + start_token: i32, + end_token: i32, + } + + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqAnnotation { + short_answers: Vec, + #[serde(default)] + yes_no_answer: String, + long_answer: NqLongAnswer, + } + + #[derive(Debug, Deserialize)] + struct NqShortAnswer { + start_token: i32, + end_token: i32, + } + + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqLongAnswer { + candidate_index: i32, + } + + fn join_tokens(tokens: &[NqToken], start: usize, end: usize) -> String { + let mut buffer = String::new(); + let end = end.min(tokens.len()); + for token in tokens.iter().skip(start).take(end.saturating_sub(start)) { + if token.html_token { + continue; + } + let text = token.token.trim(); + if text.is_empty() { + continue; + } + let attach = matches!( + text, + "," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "…" | "..." 
+ ) || text.starts_with('\'') + || text == "n't" + || text == "'s" + || text == "'re" + || text == "'ve" + || text == "'d" + || text == "'ll"; + + if buffer.is_empty() || attach { + buffer.push_str(text); + } else { + buffer.push(' '); + buffer.push_str(text); + } + } + + buffer.trim().to_string() + } + + let file = File::open(raw_path).with_context(|| { + format!( + "opening Natural Questions dataset at {}", + raw_path.display() + ) + })?; + let reader = BufReader::new(file); + + let mut paragraphs = Vec::new(); + for (line_idx, line) in reader.lines().enumerate() { + let line = line.with_context(|| { + format!( + "reading Natural Questions line {} from {}", + line_idx + 1, + raw_path.display() + ) + })?; + if line.trim().is_empty() { + continue; + } + let example: NqExample = serde_json::from_str(&line).with_context(|| { + format!( + "parsing Natural Questions JSON (line {}) at {}", + line_idx + 1, + raw_path.display() + ) + })?; + + let mut answer_texts: Vec = Vec::new(); + let mut short_answer_texts: Vec = Vec::new(); + let mut has_short_or_yesno = false; + let mut has_short_answer = false; + for annotation in &example.annotations { + for short in &annotation.short_answers { + if short.start_token < 0 || short.end_token <= short.start_token { + continue; + } + let start = short.start_token as usize; + let end = short.end_token as usize; + if start >= example.document_tokens.len() || end > example.document_tokens.len() { + continue; + } + let text = join_tokens(&example.document_tokens, start, end); + if !text.is_empty() { + answer_texts.push(text.clone()); + short_answer_texts.push(text); + has_short_or_yesno = true; + has_short_answer = true; + } + } + + match annotation + .yes_no_answer + .trim() + .to_ascii_lowercase() + .as_str() + { + "yes" => { + answer_texts.push("yes".to_string()); + has_short_or_yesno = true; + } + "no" => { + answer_texts.push("no".to_string()); + has_short_or_yesno = true; + } + _ => {} + } + } + + let mut answers = 
dedupe_strings(answer_texts); + let is_unanswerable = !has_short_or_yesno || answers.is_empty(); + if is_unanswerable { + if !include_unanswerable { + continue; + } + answers.clear(); + } + + let paragraph_id = format!("nq-{}", example.example_id); + let question_id = format!("nq-{}", example.example_id); + + let context = join_tokens(&example.document_tokens, 0, example.document_tokens.len()); + if context.is_empty() { + continue; + } + + if has_short_answer && !short_answer_texts.is_empty() { + let normalized_context = context.to_ascii_lowercase(); + let missing_answer = short_answer_texts.iter().any(|answer| { + let needle = answer.trim().to_ascii_lowercase(); + !needle.is_empty() && !normalized_context.contains(&needle) + }); + if missing_answer { + warn!( + question_id = %question_id, + "Skipping Natural Questions example because answers were not found in the assembled context" + ); + continue; + } + } + + if !include_unanswerable && (!has_short_answer || short_answer_texts.is_empty()) { + // yes/no-only questions are excluded by default unless --llm-mode is used + continue; + } + + let question = ConvertedQuestion { + id: question_id, + question: example.question_text.trim().to_string(), + answers, + is_impossible: is_unanswerable, + }; + + paragraphs.push(ConvertedParagraph { + id: paragraph_id, + title: example.document_title.trim().to_string(), + context, + questions: vec![question], + }); + } + + Ok(paragraphs) +} + +fn ensure_parent(path: &Path) -> Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("creating parent directory for {}", path.display()))?; + } + Ok(()) +} + +pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> { + ensure_parent(converted_path)?; + let json = + serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?; + fs::write(converted_path, json) + .with_context(|| format!("writing converted dataset to {}", 
converted_path.display())) +} + +pub fn read_converted(converted_path: &Path) -> Result { + let raw = fs::read_to_string(converted_path) + .with_context(|| format!("reading converted dataset at {}", converted_path.display()))?; + let mut dataset: ConvertedDataset = serde_json::from_str(&raw) + .with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?; + if dataset.metadata.id.trim().is_empty() { + dataset.metadata = default_metadata(); + } + if dataset.source.is_empty() { + dataset.source = converted_path.display().to_string(); + } + Ok(dataset) +} + +pub fn ensure_converted( + dataset_kind: DatasetKind, + raw_path: &Path, + converted_path: &Path, + force: bool, + include_unanswerable: bool, + context_token_limit: Option, +) -> Result { + if force || !converted_path.exists() { + let dataset = convert( + raw_path, + dataset_kind, + include_unanswerable, + context_token_limit, + )?; + write_converted(&dataset, converted_path)?; + return Ok(dataset); + } + + match read_converted(converted_path) { + Ok(dataset) + if dataset.metadata.id == dataset_kind.id() + && dataset.metadata.include_unanswerable == include_unanswerable + && dataset.metadata.context_token_limit == context_token_limit => + { + Ok(dataset) + } + _ => { + let dataset = convert( + raw_path, + dataset_kind, + include_unanswerable, + context_token_limit, + )?; + write_converted(&dataset, converted_path)?; + Ok(dataset) + } + } +} + +fn dedupe_strings(values: I) -> Vec +where + I: IntoIterator, +{ + let mut set = BTreeSet::new(); + for value in values { + let trimmed = value.trim(); + if !trimmed.is_empty() { + set.insert(trimmed.to_string()); + } + } + set.into_iter().collect() +} + +fn slugify(input: &str, fallback_idx: usize) -> String { + let mut slug = String::new(); + let mut last_dash = false; + for ch in input.chars() { + let c = ch.to_ascii_lowercase(); + if c.is_ascii_alphanumeric() { + slug.push(c); + last_dash = false; + } else if !last_dash { + slug.push('-'); + 
last_dash = true; + } + } + + slug = slug.trim_matches('-').to_string(); + if slug.is_empty() { + slug = format!("article-{fallback_idx}"); + } + slug +} + +pub fn base_timestamp() -> chrono::DateTime { + Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0) + .single() + .expect("valid base timestamp") +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn convert_nq_handles_answers_and_skips_unanswerable() { + let mut file = NamedTempFile::new().expect("temp file"); + + let record_with_short_answers = json!({ + "question_text": "What is foo?", + "document_title": "Foo Title", + "example_id": 123, + "document_tokens": [ + {"token": "Foo", "html_token": false}, + {"token": "is", "html_token": false}, + {"token": "bar", "html_token": false}, + {"token": ".", "html_token": false} + ], + "long_answer_candidates": [ + {"start_token": 0, "end_token": 4, "top_level": true} + ], + "annotations": [ + { + "long_answer": {"start_token": 0, "end_token": 4, "candidate_index": 0}, + "short_answers": [ + {"start_token": 2, "end_token": 3}, + {"start_token": 2, "end_token": 3} + ], + "yes_no_answer": "NONE" + } + ] + }); + + let record_with_yes_no = json!({ + "question_text": "Is bar real?", + "document_title": "Bar Title", + "example_id": 456, + "document_tokens": [ + {"token": "Yes", "html_token": false}, + {"token": ",", "html_token": false}, + {"token": "bar", "html_token": false}, + {"token": "is", "html_token": false} + ], + "long_answer_candidates": [ + {"start_token": 0, "end_token": 4, "top_level": true} + ], + "annotations": [ + { + "long_answer": {"start_token": 0, "end_token": 4, "candidate_index": 0}, + "short_answers": [], + "yes_no_answer": "YES" + } + ] + }); + + let unanswerable_record = json!({ + "question_text": "Unknown?", + "document_title": "Unknown Title", + "example_id": 789, + "document_tokens": [ + {"token": "No", "html_token": false}, + {"token": "answer", "html_token": false} + 
], + "long_answer_candidates": [ + {"start_token": 0, "end_token": 2, "top_level": true} + ], + "annotations": [ + { + "long_answer": {"start_token": 0, "end_token": 2, "candidate_index": 0}, + "short_answers": [], + "yes_no_answer": "NONE" + } + ] + }); + + writeln!(file, "{}", record_with_short_answers).unwrap(); + writeln!(file, "{}", record_with_yes_no).unwrap(); + writeln!(file, "{}", unanswerable_record).unwrap(); + file.flush().unwrap(); + + let dataset = convert( + file.path(), + DatasetKind::NaturalQuestions, + false, + Some(DEFAULT_CONTEXT_TOKEN_LIMIT), + ) + .expect("convert natural questions"); + + assert_eq!(dataset.metadata.id, DatasetKind::NaturalQuestions.id()); + assert!(!dataset.metadata.include_unanswerable); + assert_eq!(dataset.paragraphs.len(), 2); + + let first = &dataset.paragraphs[0]; + assert_eq!(first.id, "nq-123"); + assert!(first.context.contains("Foo")); + let first_answers = &first.questions.first().expect("question present").answers; + assert_eq!(first_answers, &vec!["bar".to_string()]); + + let second = &dataset.paragraphs[1]; + assert_eq!(second.id, "nq-456"); + let second_answers = &second.questions.first().expect("question present").answers; + assert_eq!(second_answers, &vec!["yes".to_string()]); + + assert!(dataset + .paragraphs + .iter() + .all(|paragraph| paragraph.id != "nq-789")); + } + + #[test] + fn convert_nq_includes_unanswerable_when_flagged() { + let mut file = NamedTempFile::new().expect("temp file"); + + let answerable = json!({ + "question_text": "What is foo?", + "document_title": "Foo Title", + "example_id": 123, + "document_tokens": [ + {"token": "Foo", "html_token": false}, + {"token": "is", "html_token": false}, + {"token": "bar", "html_token": false} + ], + "long_answer_candidates": [ + {"start_token": 0, "end_token": 3, "top_level": true} + ], + "annotations": [ + { + "long_answer": {"start_token": 0, "end_token": 3, "candidate_index": 0}, + "short_answers": [ + {"start_token": 2, "end_token": 3} + ], + 
"yes_no_answer": "NONE" + } + ] + }); + + let unanswerable = json!({ + "question_text": "Unknown?", + "document_title": "Unknown Title", + "example_id": 456, + "document_tokens": [ + {"token": "No", "html_token": false}, + {"token": "answer", "html_token": false} + ], + "long_answer_candidates": [ + {"start_token": 0, "end_token": 2, "top_level": true} + ], + "annotations": [ + { + "long_answer": {"start_token": 0, "end_token": 2, "candidate_index": -1}, + "short_answers": [], + "yes_no_answer": "NONE" + } + ] + }); + + writeln!(file, "{}", answerable).unwrap(); + writeln!(file, "{}", unanswerable).unwrap(); + file.flush().unwrap(); + + let dataset = convert( + file.path(), + DatasetKind::NaturalQuestions, + true, + Some(DEFAULT_CONTEXT_TOKEN_LIMIT), + ) + .expect("convert natural questions with unanswerable"); + + assert!(dataset.metadata.include_unanswerable); + assert_eq!(dataset.paragraphs.len(), 2); + let impossible = dataset + .paragraphs + .iter() + .find(|p| p.id == "nq-456") + .expect("unanswerable paragraph present"); + let question = impossible.questions.first().expect("question present"); + assert!(question.answers.is_empty()); + assert!(question.is_impossible); + } + + #[test] + fn catalog_lists_datasets_and_slices() { + let catalog = catalog().expect("catalog"); + let squad = catalog.dataset("squad-v2").expect("squad dataset"); + assert!(squad.raw_path.exists()); + assert!(squad.converted_path.exists()); + assert!(!squad.slices.is_empty()); + + let (dataset, slice) = catalog.slice("squad-dev-200").expect("slice"); + assert_eq!(dataset.metadata.id, squad.metadata.id); + assert_eq!(slice.dataset_id, squad.metadata.id); + assert!(slice.limit.is_some()); + } +} diff --git a/eval/src/db_helpers.rs b/eval/src/db_helpers.rs new file mode 100644 index 0000000..01fe151 --- /dev/null +++ b/eval/src/db_helpers.rs @@ -0,0 +1,269 @@ +use anyhow::{Context, Result}; +use common::storage::db::SurrealDbClient; + +// Remove and recreate HNSW indexes for changing 
embedding lengths, used at beginning if embedding length differs from default system settings +pub async fn change_embedding_length_in_hnsw_indexes( + db: &SurrealDbClient, + dimension: usize, +) -> Result<()> { + tracing::info!("Changing embedding length in HNSW indexes"); + let query = format!( + "BEGIN TRANSACTION; + REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk; + REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity; + DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim}; + DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim}; + COMMIT TRANSACTION;", + dim = dimension + ); + + db.client + .query(query) + .await + .context("changing HNSW indexes")?; + tracing::info!("HNSW indexes successfully changed"); + Ok(()) +} + +// Helper functions for index management during namespace reseed +pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> { + tracing::info!("Removing ALL indexes before namespace reseed (aggressive approach)"); + + // Remove ALL indexes from ALL tables to ensure no cache access + db.client + .query( + "BEGIN TRANSACTION; + -- HNSW indexes + REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk; + REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity; + + -- FTS indexes on text_content (remove ALL of them) + REMOVE INDEX IF EXISTS text_content_fts_idx ON TABLE text_content; + REMOVE INDEX IF EXISTS text_content_fts_text_idx ON TABLE text_content; + REMOVE INDEX IF EXISTS text_content_fts_category_idx ON TABLE text_content; + REMOVE INDEX IF EXISTS text_content_fts_context_idx ON TABLE text_content; + REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON TABLE text_content; + REMOVE INDEX IF EXISTS text_content_fts_url_idx ON TABLE text_content; + REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON TABLE text_content; + + -- FTS indexes on knowledge_entity + REMOVE INDEX IF EXISTS 
knowledge_entity_fts_name_idx ON TABLE knowledge_entity; + REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity; + + -- FTS indexes on text_chunk + REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk; + + COMMIT TRANSACTION;", + ) + .await + .context("removing all indexes before namespace reseed")?; + + tracing::info!("All indexes removed before namespace reseed"); + Ok(()) +} + +async fn create_tokenizer(db: &SurrealDbClient) -> Result<()> { + tracing::info!("Creating FTS analyzers for namespace reseed"); + let res = db + .client + .query( + "BEGIN TRANSACTION; + DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer + TOKENIZERS class + FILTERS lowercase, ascii, snowball(english); + COMMIT TRANSACTION;", + ) + .await + .context("creating FTS analyzers for namespace reseed")?; + + res.check().context("failed to create the tokenizer")?; + Ok(()) +} + +pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> { + tracing::info!("Recreating ALL indexes after namespace reseed (SEQUENTIAL approach)"); + let total_start = std::time::Instant::now(); + + create_tokenizer(db) + .await + .context("creating FTS analyzer")?; + + // For now we dont remove these plain indexes, we could if they prove negatively impacting performance + // create_regular_indexes_for_snapshot(db) + // .await + // .context("creating regular indexes for namespace reseed")?; + + let fts_start = std::time::Instant::now(); + create_fts_indexes_for_snapshot(db) + .await + .context("creating FTS indexes for namespace reseed")?; + tracing::info!(duration = ?fts_start.elapsed(), "FTS indexes created"); + + let hnsw_start = std::time::Instant::now(); + create_hnsw_indexes_for_snapshot(db, dimension) + .await + .context("creating HNSW indexes for namespace reseed")?; + tracing::info!(duration = ?hnsw_start.elapsed(), "HNSW indexes created"); + + tracing::info!(duration = ?total_start.elapsed(), "All index groups recreated successfully in 
sequence"); + Ok(()) +} + +#[allow(dead_code)] // For now we dont do this. We could +async fn create_regular_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> { + tracing::info!("Creating regular indexes for namespace reseed (parallel group 1)"); + let res = db + .client + .query( + "BEGIN TRANSACTION; + DEFINE INDEX text_content_user_id_idx ON text_content FIELDS user_id; + DEFINE INDEX text_content_created_at_idx ON text_content FIELDS created_at; + DEFINE INDEX text_content_category_idx ON text_content FIELDS category; + DEFINE INDEX text_chunk_source_id_idx ON text_chunk FIELDS source_id; + DEFINE INDEX text_chunk_user_id_idx ON text_chunk FIELDS user_id; + DEFINE INDEX knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id; + DEFINE INDEX knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id; + DEFINE INDEX knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type; + DEFINE INDEX knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; + COMMIT TRANSACTION;", + ) + .await + .context("creating regular indexes for namespace reseed")?; + + res.check().context("one of the regular indexes failed")?; + + tracing::info!("Regular indexes for namespace reseed created"); + Ok(()) +} + +async fn create_fts_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> { + tracing::info!("Creating FTS indexes for namespace reseed (group 2)"); + let res = db.client + .query( + "BEGIN TRANSACTION; + DEFINE INDEX text_content_fts_idx ON TABLE text_content FIELDS text; + DEFINE INDEX knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name + SEARCH ANALYZER app_en_fts_analyzer BM25; + DEFINE INDEX knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description + SEARCH ANALYZER app_en_fts_analyzer BM25; + DEFINE INDEX text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk + SEARCH ANALYZER app_en_fts_analyzer BM25; + COMMIT TRANSACTION;", + ) + .await + .context("sending FTS index creation 
query")?; + + // This actually surfaces statement-level errors + res.check() + .context("one or more FTS index statements failed")?; + + tracing::info!("FTS indexes for namespace reseed created"); + Ok(()) +} + +async fn create_hnsw_indexes_for_snapshot(db: &SurrealDbClient, dimension: usize) -> Result<()> { + tracing::info!("Creating HNSW indexes for namespace reseed (group 3)"); + let query = format!( + "BEGIN TRANSACTION; + DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim}; + DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim}; + COMMIT TRANSACTION;", + dim = dimension + ); + + let res = db + .client + .query(query) + .await + .context("creating HNSW indexes for namespace reseed")?; + + res.check() + .context("one or more HNSW index statements failed")?; + + tracing::info!("HNSW indexes for namespace reseed created"); + Ok(()) +} + +pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> { + let query = format!( + "REMOVE NAMESPACE {ns}; + DEFINE NAMESPACE {ns}; + DEFINE DATABASE {db};", + ns = namespace, + db = database + ); + db.client + .query(query) + .await + .context("resetting SurrealDB namespace")?; + db.client + .use_ns(namespace) + .use_db(database) + .await + .context("selecting namespace/database after reset")?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::Deserialize; + use uuid::Uuid; + + #[derive(Debug, Deserialize)] + struct FooRow { + label: String, + } + + #[tokio::test] + async fn reset_namespace_drops_existing_rows() { + let namespace = format!("reset_ns_{}", Uuid::new_v4().simple()); + let database = format!("reset_db_{}", Uuid::new_v4().simple()); + let db = SurrealDbClient::memory(&namespace, &database) + .await + .expect("in-memory db"); + + db.client + .query( + "DEFINE TABLE foo SCHEMALESS; + CREATE foo:foo SET label = 'before';", + ) + .await + .expect("seed namespace") + .check() 
+ .expect("seed response"); + + let mut before = db + .client + .query("SELECT * FROM foo") + .await + .expect("select before reset"); + let existing: Vec = before.take(0).expect("rows before reset"); + assert_eq!(existing.len(), 1); + assert_eq!(existing[0].label, "before"); + + reset_namespace(&db, &namespace, &database) + .await + .expect("namespace reset"); + + match db.client.query("SELECT * FROM foo").await { + Ok(mut response) => { + let rows: Vec = response.take(0).unwrap_or_default(); + assert!( + rows.is_empty(), + "reset namespace should drop rows, found {:?}", + rows + ); + } + Err(error) => { + let message = error.to_string(); + assert!( + message.to_ascii_lowercase().contains("table") + || message.to_ascii_lowercase().contains("namespace") + || message.to_ascii_lowercase().contains("foo"), + "unexpected error after namespace reset: {message}" + ); + } + } + } +} diff --git a/eval/src/embedding.rs b/eval/src/embedding.rs new file mode 100644 index 0000000..c17f1fc --- /dev/null +++ b/eval/src/embedding.rs @@ -0,0 +1,171 @@ +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + str::FromStr, + sync::Arc, +}; + +use anyhow::{anyhow, Context, Result}; +use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; +use tokio::sync::Mutex; + +use crate::args::{Config, EmbeddingBackend}; + +#[derive(Clone)] +pub struct EmbeddingProvider { + inner: EmbeddingInner, +} + +#[derive(Clone)] +enum EmbeddingInner { + Hashed { + dimension: usize, + }, + FastEmbed { + model: Arc>, + model_name: EmbeddingModel, + dimension: usize, + }, +} + +impl EmbeddingProvider { + pub fn backend_label(&self) -> &'static str { + match self.inner { + EmbeddingInner::Hashed { .. } => "hashed", + EmbeddingInner::FastEmbed { .. } => "fastembed", + } + } + + pub fn dimension(&self) -> usize { + match &self.inner { + EmbeddingInner::Hashed { dimension } => *dimension, + EmbeddingInner::FastEmbed { dimension, .. 
} => *dimension, + } + } + + pub fn model_code(&self) -> Option { + match &self.inner { + EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), + EmbeddingInner::Hashed { .. } => None, + } + } + + pub async fn embed(&self, text: &str) -> Result> { + match &self.inner { + EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), + EmbeddingInner::FastEmbed { model, .. } => { + let mut guard = model.lock().await; + let embeddings = guard + .embed(vec![text.to_owned()], None) + .context("generating fastembed vector")?; + embeddings + .into_iter() + .next() + .ok_or_else(|| anyhow!("fastembed returned no embedding for input")) + } + } + } + + pub async fn embed_batch(&self, texts: Vec) -> Result>> { + match &self.inner { + EmbeddingInner::Hashed { dimension } => Ok(texts + .into_iter() + .map(|text| hashed_embedding(&text, *dimension)) + .collect()), + EmbeddingInner::FastEmbed { model, .. } => { + if texts.is_empty() { + return Ok(Vec::new()); + } + let mut guard = model.lock().await; + guard + .embed(texts, None) + .context("generating fastembed batch embeddings") + } + } + } +} + +pub async fn build_provider( + config: &Config, + default_dimension: usize, +) -> Result { + match config.embedding_backend { + EmbeddingBackend::Hashed => Ok(EmbeddingProvider { + inner: EmbeddingInner::Hashed { + dimension: default_dimension.max(1), + }, + }), + EmbeddingBackend::FastEmbed => { + let model_name = if let Some(code) = config.embedding_model.as_deref() { + EmbeddingModel::from_str(code).map_err(|err| anyhow!(err))? 
+ } else { + EmbeddingModel::default() + }; + + let options = + TextInitOptions::new(model_name.clone()).with_show_download_progress(true); + let model_name_for_task = model_name.clone(); + let model_name_code = model_name.to_string(); + + let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> { + let model = + TextEmbedding::try_new(options).context("initialising FastEmbed text model")?; + let info = + EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| { + anyhow!("FastEmbed model metadata missing for {model_name_code}") + })?; + Ok((model, info.dim)) + }) + .await + .context("joining FastEmbed initialisation task")??; + + Ok(EmbeddingProvider { + inner: EmbeddingInner::FastEmbed { + model: Arc::new(Mutex::new(model)), + model_name, + dimension, + }, + }) + } + } +} + +fn hashed_embedding(text: &str, dimension: usize) -> Vec { + let dim = dimension.max(1); + let mut vector = vec![0.0f32; dim]; + if text.is_empty() { + return vector; + } + + let mut token_count = 0f32; + for token in tokens(text) { + token_count += 1.0; + let idx = bucket(&token, dim); + vector[idx] += 1.0; + } + + if token_count == 0.0 { + return vector; + } + + let norm = vector.iter().map(|v| v * v).sum::().sqrt(); + if norm > 0.0 { + for value in &mut vector { + *value /= norm; + } + } + + vector +} + +fn tokens(text: &str) -> impl Iterator + '_ { + text.split(|c: char| !c.is_ascii_alphanumeric()) + .filter(|token| !token.is_empty()) + .map(|token| token.to_ascii_lowercase()) +} + +fn bucket(token: &str, dimension: usize) -> usize { + let mut hasher = DefaultHasher::new(); + token.hash(&mut hasher); + (hasher.finish() as usize) % dimension +} diff --git a/eval/src/eval/mod.rs b/eval/src/eval/mod.rs new file mode 100644 index 0000000..b9b0f36 --- /dev/null +++ b/eval/src/eval/mod.rs @@ -0,0 +1,767 @@ +mod pipeline; + +pub use pipeline::run_evaluation; + +use std::{ + collections::{HashMap, HashSet}, + path::Path, + time::Duration, +}; + +use anyhow::{anyhow, 
Context, Result}; +use chrono::{DateTime, SecondsFormat, Utc}; +use common::{ + error::AppError, + storage::{ + db::SurrealDbClient, + types::{system_settings::SystemSettings, user::User}, + }, +}; +use composite_retrieval::pipeline as retrieval_pipeline; +use composite_retrieval::pipeline::PipelineStageTimings; +use composite_retrieval::pipeline::RetrievalTuning; +use serde::{Deserialize, Serialize}; +use tokio::io::AsyncWriteExt; +use tracing::{info, warn}; + +use crate::{ + args::{self, Config}, + datasets::{self, ConvertedDataset}, + db_helpers::change_embedding_length_in_hnsw_indexes, + ingest, + slice::{self}, + snapshot::{self, DbSnapshotState}, +}; + +#[derive(Debug, Serialize)] +pub struct EvaluationSummary { + pub generated_at: DateTime, + pub k: usize, + pub limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub run_label: Option, + pub total_cases: usize, + pub correct: usize, + pub precision: f64, + pub correct_at_1: usize, + pub correct_at_2: usize, + pub correct_at_3: usize, + pub precision_at_1: f64, + pub precision_at_2: f64, + pub precision_at_3: f64, + pub duration_ms: u128, + pub dataset_id: String, + pub dataset_label: String, + pub dataset_includes_unanswerable: bool, + pub dataset_source: String, + pub slice_id: String, + pub slice_seed: u64, + pub slice_total_cases: usize, + pub slice_window_offset: usize, + pub slice_window_length: usize, + pub slice_cases: usize, + pub slice_positive_paragraphs: usize, + pub slice_negative_paragraphs: usize, + pub slice_total_paragraphs: usize, + pub slice_negative_multiplier: f32, + pub namespace_reused: bool, + pub corpus_paragraphs: usize, + pub ingestion_cache_path: String, + pub ingestion_reused: bool, + pub ingestion_embeddings_reused: bool, + pub ingestion_fingerprint: String, + pub positive_paragraphs_reused: usize, + pub negative_paragraphs_reused: usize, + pub latency_ms: LatencyStats, + pub perf: PerformanceTimings, + pub embedding_backend: String, + pub embedding_model: 
Option, + pub embedding_dimension: usize, + pub rerank_enabled: bool, + pub rerank_pool_size: Option, + pub rerank_keep_top: usize, + pub concurrency: usize, + pub detailed_report: bool, + pub chunk_vector_take: usize, + pub chunk_fts_take: usize, + pub chunk_token_budget: usize, + pub chunk_avg_chars_per_token: usize, + pub max_chunks_per_entity: usize, + pub cases: Vec, +} + +#[derive(Debug, Serialize)] +pub struct CaseSummary { + pub question_id: String, + pub question: String, + pub paragraph_id: String, + pub paragraph_title: String, + pub expected_source: String, + pub answers: Vec, + pub matched: bool, + pub entity_match: bool, + pub chunk_text_match: bool, + pub chunk_id_match: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub match_rank: Option, + pub latency_ms: u128, + pub retrieved: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LatencyStats { + pub avg: f64, + pub p50: u128, + pub p95: u128, +} + +#[derive(Debug, Clone, Serialize)] +pub struct StageLatencyBreakdown { + pub collect_candidates: LatencyStats, + pub graph_expansion: LatencyStats, + pub chunk_attach: LatencyStats, + pub rerank: LatencyStats, + pub assemble: LatencyStats, +} + +#[derive(Debug, Default, Clone, Serialize)] +pub struct EvaluationStageTimings { + pub prepare_slice_ms: u128, + pub prepare_db_ms: u128, + pub prepare_corpus_ms: u128, + pub prepare_namespace_ms: u128, + pub run_queries_ms: u128, + pub summarize_ms: u128, + pub finalize_ms: u128, +} + +#[derive(Debug, Serialize)] +pub struct PerformanceTimings { + pub openai_base_url: String, + pub ingestion_ms: u128, + #[serde(skip_serializing_if = "Option::is_none")] + pub namespace_seed_ms: Option, + pub evaluation_stage_ms: EvaluationStageTimings, + pub stage_latency: StageLatencyBreakdown, +} + +#[derive(Debug, Serialize)] +pub struct RetrievedSummary { + pub rank: usize, + pub entity_id: String, + pub source_id: String, + pub entity_name: String, + pub score: f32, + pub matched: bool, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub entity_description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub entity_category: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk_text_match: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk_id_match: Option, +} + +#[derive(Debug, Serialize)] +pub(crate) struct CaseDiagnostics { + question_id: String, + question: String, + paragraph_id: String, + paragraph_title: String, + expected_source: String, + expected_chunk_ids: Vec, + answers: Vec, + entity_match: bool, + chunk_text_match: bool, + chunk_id_match: bool, + failure_reasons: Vec, + missing_expected_chunk_ids: Vec, + attached_chunk_ids: Vec, + retrieved: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pipeline: Option, +} + +#[derive(Debug, Serialize)] +struct EntityDiagnostics { + rank: usize, + entity_id: String, + source_id: String, + name: String, + score: f32, + entity_match: bool, + chunk_text_match: bool, + chunk_id_match: bool, + chunks: Vec, +} + +#[derive(Debug, Serialize)] +struct ChunkDiagnosticsEntry { + chunk_id: String, + score: f32, + contains_answer: bool, + expected_chunk: bool, + snippet: String, +} + +pub(crate) struct SeededCase { + question_id: String, + question: String, + expected_source: String, + answers: Vec, + paragraph_id: String, + paragraph_title: String, + expected_chunk_ids: Vec, +} + +pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec { + let mut title_map = HashMap::new(); + for paragraph in &manifest.paragraphs { + title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone()); + } + + manifest + .questions + .iter() + .filter(|question| !question.is_impossible) + .map(|question| { + let title = title_map + .get(question.paragraph_id.as_str()) + .cloned() + .unwrap_or_else(|| "Untitled".to_string()); + SeededCase { + question_id: question.question_id.clone(), + question: question.question_text.clone(), + 
expected_source: question.text_content_id.clone(),
+                answers: question.answers.clone(),
+                paragraph_id: question.paragraph_id.clone(),
+                paragraph_title: title,
+                expected_chunk_ids: question.matching_chunk_ids.clone(),
+            }
+        })
+        .collect()
+}
+
+/// Case-insensitive check that `text` contains at least one gold answer.
+/// An empty `answers` slice means "nothing to verify" and counts as a match.
+pub(crate) fn text_contains_answer(text: &str, answers: &[String]) -> bool {
+    if answers.is_empty() {
+        return true;
+    }
+    let haystack = text.to_ascii_lowercase();
+    // Fix: lowercase each needle as well. Some call sites pass pre-lowercased
+    // answers (e.g. `answers_lower` in build_case_diagnostics), for which this
+    // is a no-op, but verbatim mixed-case answers previously never matched.
+    answers
+        .iter()
+        .any(|needle| haystack.contains(&needle.to_ascii_lowercase()))
+}
+
+/// Computes avg/p50/p95 over raw millisecond latencies; returns all-zero
+/// stats for an empty sample set.
+pub(crate) fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
+    if latencies.is_empty() {
+        return LatencyStats {
+            avg: 0.0,
+            p50: 0,
+            p95: 0,
+        };
+    }
+    let mut sorted = latencies.to_vec();
+    sorted.sort_unstable();
+    let sum: u128 = sorted.iter().copied().sum();
+    let avg = sum as f64 / (sorted.len() as f64);
+    let p50 = percentile(&sorted, 0.50);
+    let p95 = percentile(&sorted, 0.95);
+    LatencyStats { avg, p50, p95 }
+}
+
+/// Aggregates per-query pipeline stage timings into per-stage latency stats.
+pub(crate) fn build_stage_latency_breakdown(
+    samples: &[PipelineStageTimings],
+) -> StageLatencyBreakdown {
+    // Extracts one stage's samples from the per-query timing structs.
+    fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
+    where
+        F: Fn(&PipelineStageTimings) -> u128,
+    {
+        samples.iter().map(selector).collect()
+    }
+
+    StageLatencyBreakdown {
+        collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
+            entry.collect_candidates_ms
+        })),
+        graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
+            entry.graph_expansion_ms
+        })),
+        chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
+        rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
+        assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
+    }
+}
+
+/// Nearest-rank percentile over an already-sorted slice; `fraction` is
+/// clamped to [0, 1]. Returns 0 for an empty slice.
+fn percentile(sorted: &[u128], fraction: f64) -> u128 {
+    if sorted.is_empty() {
+        return 0;
+    }
+    let clamped = fraction.clamp(0.0, 1.0);
+    let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize;
+    sorted[idx.min(sorted.len() - 1)]
+}
+
+pub async fn
grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> { + let ledger_limit = ledger_target(config); + let slice_settings = slice::slice_config_with_limit(config, ledger_limit); + let slice = + slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?; + info!( + slice = slice.manifest.slice_id.as_str(), + cases = slice.manifest.case_count, + positives = slice.manifest.positive_paragraphs, + negatives = slice.manifest.negative_paragraphs, + total_paragraphs = slice.manifest.total_paragraphs, + "Slice ledger ready" + ); + println!( + "Slice `{}` now contains {} questions ({} positives, {} negatives)", + slice.manifest.slice_id, + slice.manifest.case_count, + slice.manifest.positive_paragraphs, + slice.manifest.negative_paragraphs + ); + Ok(()) +} + +pub(crate) fn ledger_target(config: &Config) -> Option { + match (config.slice_grow, config.limit) { + (Some(grow), Some(limit)) => Some(limit.max(grow)), + (Some(grow), None) => Some(grow), + (None, limit) => limit, + } +} + +pub(crate) fn apply_dataset_tuning_overrides( + dataset: &ConvertedDataset, + config: &Config, + tuning: &mut RetrievalTuning, +) { + let is_long_form = dataset + .metadata + .id + .to_ascii_lowercase() + .contains("natural-questions"); + if !is_long_form { + return; + } + + if config.chunk_vector_take.is_none() { + tuning.chunk_vector_take = tuning.chunk_vector_take.max(80); + } + if config.chunk_fts_take.is_none() { + tuning.chunk_fts_take = tuning.chunk_fts_take.max(80); + } + if config.chunk_token_budget.is_none() { + tuning.token_budget_estimate = tuning.token_budget_estimate.max(20_000); + } + if config.max_chunks_per_entity.is_none() { + tuning.max_chunks_per_entity = tuning.max_chunks_per_entity.max(12); + } + if tuning.lexical_match_weight < 0.25 { + tuning.lexical_match_weight = 0.3; + } +} + +pub(crate) fn build_case_diagnostics( + summary: &CaseSummary, + expected_chunk_ids: &[String], + answers_lower: &[String], + entities: 
&[composite_retrieval::RetrievedEntity], + pipeline_stats: Option, +) -> CaseDiagnostics { + let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect(); + let mut seen_chunks: HashSet = HashSet::new(); + let mut attached_chunk_ids = Vec::new(); + let mut entity_diagnostics = Vec::new(); + + for (idx, entity) in entities.iter().enumerate() { + let mut chunk_entries = Vec::new(); + for chunk in &entity.chunks { + let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower); + let expected_chunk = expected_set.contains(chunk.chunk.id.as_str()); + seen_chunks.insert(chunk.chunk.id.clone()); + attached_chunk_ids.push(chunk.chunk.id.clone()); + chunk_entries.push(ChunkDiagnosticsEntry { + chunk_id: chunk.chunk.id.clone(), + score: chunk.score, + contains_answer, + expected_chunk, + snippet: chunk_preview(&chunk.chunk.chunk), + }); + } + entity_diagnostics.push(EntityDiagnostics { + rank: idx + 1, + entity_id: entity.entity.id.clone(), + source_id: entity.entity.source_id.clone(), + name: entity.entity.name.clone(), + score: entity.score, + entity_match: entity.entity.source_id == summary.expected_source, + chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer), + chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk), + chunks: chunk_entries, + }); + } + + let missing_expected_chunk_ids = expected_chunk_ids + .iter() + .filter(|id| !seen_chunks.contains(id.as_str())) + .cloned() + .collect::>(); + + let mut failure_reasons = Vec::new(); + if !summary.entity_match { + failure_reasons.push("entity_miss".to_string()); + } + if !summary.chunk_text_match { + failure_reasons.push("chunk_text_missing".to_string()); + } + if !summary.chunk_id_match { + failure_reasons.push("chunk_id_missing".to_string()); + } + if !missing_expected_chunk_ids.is_empty() { + failure_reasons.push("expected_chunk_absent".to_string()); + } + + CaseDiagnostics { + question_id: summary.question_id.clone(), + question: 
summary.question.clone(), + paragraph_id: summary.paragraph_id.clone(), + paragraph_title: summary.paragraph_title.clone(), + expected_source: summary.expected_source.clone(), + expected_chunk_ids: expected_chunk_ids.to_vec(), + answers: summary.answers.clone(), + entity_match: summary.entity_match, + chunk_text_match: summary.chunk_text_match, + chunk_id_match: summary.chunk_id_match, + failure_reasons, + missing_expected_chunk_ids, + attached_chunk_ids, + retrieved: entity_diagnostics, + pipeline: pipeline_stats, + } +} + +fn chunk_preview(text: &str) -> String { + text.chars() + .take(200) + .collect::() + .replace('\n', " ") +} + +pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> { + args::ensure_parent(path)?; + let mut file = tokio::fs::File::create(path) + .await + .with_context(|| format!("creating diagnostics file {}", path.display()))?; + for case in cases { + let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?; + file.write_all(&line).await?; + file.write_all(b"\n").await?; + } + file.flush().await?; + Ok(()) +} + +pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> { + // Create a dummy embedding for cache warming + let dummy_embedding: Vec = (0..dimension).map(|i| (i as f32).sin()).collect(); + + info!("Warming HNSW caches with sample queries"); + + // Warm up chunk index + let _ = db + .client + .query("SELECT * FROM text_chunk WHERE embedding <|1,1|> $embedding LIMIT 5") + .bind(("embedding", dummy_embedding.clone())) + .await + .context("warming text chunk HNSW cache")?; + + // Warm up entity index + let _ = db + .client + .query("SELECT * FROM knowledge_entity WHERE embedding <|1,1|> $embedding LIMIT 5") + .bind(("embedding", dummy_embedding)) + .await + .context("warming knowledge entity HNSW cache")?; + + info!("HNSW cache warming completed"); + Ok(()) +} + +pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result { + 
let timestamp = datasets::base_timestamp(); + let user = User { + id: "eval-user".to_string(), + created_at: timestamp, + updated_at: timestamp, + email: "eval-retrieval@minne.dev".to_string(), + password: "not-used".to_string(), + anonymous: false, + api_key: None, + admin: false, + timezone: "UTC".to_string(), + }; + + if let Some(existing) = db.get_item::(&user.id).await? { + return Ok(existing); + } + + db.store_item(user.clone()) + .await + .context("storing evaluation user")?; + Ok(user) +} + +pub fn format_timestamp(timestamp: &DateTime) -> String { + timestamp.to_rfc3339_opts(SecondsFormat::Secs, true) +} + +pub(crate) fn sanitize_model_code(code: &str) -> String { + code.chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() { + ch.to_ascii_lowercase() + } else { + '_' + } + }) + .collect() +} + +pub(crate) async fn connect_eval_db( + config: &Config, + namespace: &str, + database: &str, +) -> Result { + match SurrealDbClient::new( + &config.db_endpoint, + &config.db_username, + &config.db_password, + namespace, + database, + ) + .await + { + Ok(client) => { + info!( + endpoint = %config.db_endpoint, + namespace, + database, + auth = "root", + "Connected to SurrealDB" + ); + Ok(client) + } + Err(root_err) => { + info!( + endpoint = %config.db_endpoint, + namespace, + database, + "Root authentication failed; trying namespace-level auth" + ); + let namespace_client = SurrealDbClient::new_with_namespace_user( + &config.db_endpoint, + namespace, + &config.db_username, + &config.db_password, + database, + ) + .await + .map_err(|ns_err| { + anyhow!( + "failed to connect to SurrealDB via root ({root_err}) or namespace ({ns_err}) credentials" + ) + })?; + info!( + endpoint = %config.db_endpoint, + namespace, + database, + auth = "namespace", + "Connected to SurrealDB" + ); + Ok(namespace_client) + } + } +} + +pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result { + #[derive(Deserialize)] + struct CountRow { + count: i64, + } + + let mut response 
= db + .client + .query("SELECT count() AS count FROM text_chunk") + .await + .context("checking namespace corpus state")?; + let rows: Vec = response.take(0).unwrap_or_default(); + Ok(rows.first().map(|row| row.count).unwrap_or(0) > 0) +} + +pub(crate) async fn can_reuse_namespace( + db: &SurrealDbClient, + descriptor: &snapshot::Descriptor, + namespace: &str, + database: &str, + dataset_id: &str, + slice_id: &str, + ingestion_fingerprint: &str, + slice_case_count: usize, +) -> Result { + let state = match descriptor.load_db_state().await? { + Some(state) => state, + None => { + info!("No namespace state recorded; reseeding corpus from cached shards"); + return Ok(false); + } + }; + + if state.slice_case_count < slice_case_count { + info!( + requested_cases = slice_case_count, + stored_cases = state.slice_case_count, + "Skipping live namespace reuse; ledger grew beyond cached state" + ); + return Ok(false); + } + + if state.dataset_id != dataset_id + || state.slice_id != slice_id + || state.ingestion_fingerprint != ingestion_fingerprint + || state.namespace.as_deref() != Some(namespace) + || state.database.as_deref() != Some(database) + { + info!( + namespace, + database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache" + ); + return Ok(false); + } + + if namespace_has_corpus(db).await? 
{ + Ok(true) + } else { + info!( + namespace, + database, + "Namespace metadata matches but tables are empty; reseeding from ingestion cache" + ); + Ok(false) + } +} + +fn sanitize_identifier(input: &str) -> String { + let mut cleaned: String = input + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() { + ch.to_ascii_lowercase() + } else { + '_' + } + }) + .collect(); + if cleaned.is_empty() { + cleaned.push('x'); + } + if cleaned.len() > 64 { + cleaned.truncate(64); + } + cleaned +} + +pub(crate) fn default_namespace(dataset_id: &str, limit: Option) -> String { + let dataset_component = sanitize_identifier(dataset_id); + let limit_component = match limit { + Some(value) if value > 0 => format!("limit{}", value), + _ => "all".to_string(), + }; + format!("eval_{}_{}", dataset_component, limit_component) +} + +pub(crate) fn default_database() -> String { + "retrieval_eval".to_string() +} + +pub(crate) async fn record_namespace_state( + descriptor: &snapshot::Descriptor, + dataset_id: &str, + slice_id: &str, + ingestion_fingerprint: &str, + namespace: &str, + database: &str, + slice_case_count: usize, +) { + let state = DbSnapshotState { + dataset_id: dataset_id.to_string(), + slice_id: slice_id.to_string(), + ingestion_fingerprint: ingestion_fingerprint.to_string(), + snapshot_hash: descriptor.metadata_hash().to_string(), + updated_at: Utc::now(), + namespace: Some(namespace.to_string()), + database: Some(database.to_string()), + slice_case_count, + }; + if let Err(err) = descriptor.store_db_state(&state).await { + warn!(error = %err, "Failed to record namespace state"); + } +} + +pub(crate) async fn enforce_system_settings( + db: &SurrealDbClient, + mut settings: SystemSettings, + provider_dimension: usize, + config: &Config, +) -> Result { + let mut updated_settings = settings.clone(); + let mut needs_settings_update = false; + let mut embedding_dimension_changed = false; + + if provider_dimension != settings.embedding_dimensions as usize { + 
updated_settings.embedding_dimensions = provider_dimension as u32; + needs_settings_update = true; + embedding_dimension_changed = true; + } + if let Some(query_override) = config.query_model.as_deref() { + if settings.query_model != query_override { + info!( + model = query_override, + "Overriding system query model for this run" + ); + updated_settings.query_model = query_override.to_string(); + needs_settings_update = true; + } + } + if needs_settings_update { + settings = SystemSettings::update(db, updated_settings) + .await + .context("updating system settings overrides")?; + } + if embedding_dimension_changed { + change_embedding_length_in_hnsw_indexes(db, provider_dimension) + .await + .context("redefining HNSW indexes for new embedding dimension")?; + } + Ok(settings) +} + +pub(crate) async fn load_or_init_system_settings( + db: &SurrealDbClient, +) -> Result<(SystemSettings, bool)> { + match SystemSettings::get_current(db).await { + Ok(settings) => Ok((settings, false)), + Err(AppError::NotFound(_)) => { + info!("System settings missing; applying database migrations for namespace"); + db.apply_migrations() + .await + .context("applying database migrations after missing system settings")?; + tokio::time::sleep(Duration::from_millis(50)).await; + let settings = SystemSettings::get_current(db) + .await + .context("loading system settings after migrations")?; + Ok((settings, true)) + } + Err(err) => Err(err).context("loading system settings"), + } +} diff --git a/eval/src/eval/pipeline/context.rs b/eval/src/eval/pipeline/context.rs new file mode 100644 index 0000000..877fde5 --- /dev/null +++ b/eval/src/eval/pipeline/context.rs @@ -0,0 +1,193 @@ +use std::{ + path::PathBuf, + sync::Arc, + time::{Duration, Instant}, +}; + +use async_openai::Client; +use common::storage::{ + db::SurrealDbClient, + types::{system_settings::SystemSettings, user::User}, +}; +use composite_retrieval::{ + pipeline::{PipelineStageTimings, RetrievalConfig}, + reranking::RerankerPool, 
+}; + +use crate::{ + args::Config, + cache::EmbeddingCache, + datasets::ConvertedDataset, + embedding::EmbeddingProvider, + eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase}, + ingest, slice, snapshot, +}; + +pub(super) struct EvaluationContext<'a> { + dataset: &'a ConvertedDataset, + config: &'a Config, + pub stage_timings: EvaluationStageTimings, + pub ledger_limit: Option, + pub slice_settings: Option>, + pub slice: Option>, + pub window_offset: usize, + pub window_length: usize, + pub window_total_cases: usize, + pub namespace: String, + pub database: String, + pub db: Option, + pub descriptor: Option, + pub settings: Option, + pub settings_missing: bool, + pub must_reapply_settings: bool, + pub embedding_provider: Option, + pub embedding_cache: Option, + pub openai_client: Option>>, + pub openai_base_url: Option, + pub expected_fingerprint: Option, + pub ingestion_duration_ms: u128, + pub namespace_seed_ms: Option, + pub namespace_reused: bool, + pub evaluation_start: Option, + pub eval_user: Option, + pub corpus_handle: Option, + pub cases: Vec, + pub stage_latency_samples: Vec, + pub latencies: Vec, + pub diagnostics_output: Vec, + pub query_summaries: Vec, + pub rerank_pool: Option>, + pub retrieval_config: Option>, + pub summary: Option, + pub diagnostics_path: Option, + pub diagnostics_enabled: bool, +} + +impl<'a> EvaluationContext<'a> { + pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self { + Self { + dataset, + config, + stage_timings: EvaluationStageTimings::default(), + ledger_limit: None, + slice_settings: None, + slice: None, + window_offset: 0, + window_length: 0, + window_total_cases: 0, + namespace: String::new(), + database: String::new(), + db: None, + descriptor: None, + settings: None, + settings_missing: false, + must_reapply_settings: false, + embedding_provider: None, + embedding_cache: None, + openai_client: None, + openai_base_url: None, + expected_fingerprint: None, + 
ingestion_duration_ms: 0, + namespace_seed_ms: None, + namespace_reused: false, + evaluation_start: None, + eval_user: None, + corpus_handle: None, + cases: Vec::new(), + stage_latency_samples: Vec::new(), + latencies: Vec::new(), + diagnostics_output: Vec::new(), + query_summaries: Vec::new(), + rerank_pool: None, + retrieval_config: None, + summary: None, + diagnostics_path: config.chunk_diagnostics_path.clone(), + diagnostics_enabled: config.chunk_diagnostics_path.is_some(), + } + } + + pub fn dataset(&self) -> &'a ConvertedDataset { + self.dataset + } + + pub fn config(&self) -> &'a Config { + self.config + } + + pub fn slice(&self) -> &slice::ResolvedSlice<'a> { + self.slice.as_ref().expect("slice has not been prepared") + } + + pub fn db(&self) -> &SurrealDbClient { + self.db.as_ref().expect("database connection missing") + } + + pub fn descriptor(&self) -> &snapshot::Descriptor { + self.descriptor + .as_ref() + .expect("snapshot descriptor unavailable") + } + + pub fn embedding_provider(&self) -> &EmbeddingProvider { + self.embedding_provider + .as_ref() + .expect("embedding provider not initialised") + } + + pub fn openai_client(&self) -> Arc> { + self.openai_client + .as_ref() + .expect("openai client missing") + .clone() + } + + pub fn corpus_handle(&self) -> &ingest::CorpusHandle { + self.corpus_handle.as_ref().expect("corpus handle missing") + } + + pub fn evaluation_user(&self) -> &User { + self.eval_user.as_ref().expect("evaluation user missing") + } + + pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) { + let elapsed = duration.as_millis() as u128; + match stage { + EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed, + EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed, + EvalStage::PrepareCorpus => self.stage_timings.prepare_corpus_ms += elapsed, + EvalStage::PrepareNamespace => self.stage_timings.prepare_namespace_ms += elapsed, + EvalStage::RunQueries => 
self.stage_timings.run_queries_ms += elapsed, + EvalStage::Summarize => self.stage_timings.summarize_ms += elapsed, + EvalStage::Finalize => self.stage_timings.finalize_ms += elapsed, + } + } + + pub fn into_summary(self) -> EvaluationSummary { + self.summary.expect("evaluation summary missing") + } +} + +#[derive(Copy, Clone)] +pub(super) enum EvalStage { + PrepareSlice, + PrepareDb, + PrepareCorpus, + PrepareNamespace, + RunQueries, + Summarize, + Finalize, +} + +impl EvalStage { + pub fn label(&self) -> &'static str { + match self { + EvalStage::PrepareSlice => "prepare-slice", + EvalStage::PrepareDb => "prepare-db", + EvalStage::PrepareCorpus => "prepare-corpus", + EvalStage::PrepareNamespace => "prepare-namespace", + EvalStage::RunQueries => "run-queries", + EvalStage::Summarize => "summarize", + EvalStage::Finalize => "finalize", + } + } +} diff --git a/eval/src/eval/pipeline/mod.rs b/eval/src/eval/pipeline/mod.rs new file mode 100644 index 0000000..2f62a6f --- /dev/null +++ b/eval/src/eval/pipeline/mod.rs @@ -0,0 +1,29 @@ +mod context; +mod stages; +mod state; + +use anyhow::Result; + +use crate::{args::Config, datasets::ConvertedDataset, eval::EvaluationSummary}; + +use context::EvaluationContext; + +pub async fn run_evaluation( + dataset: &ConvertedDataset, + config: &Config, +) -> Result { + let mut ctx = EvaluationContext::new(dataset, config); + let machine = state::ready(); + + let machine = stages::prepare_slice(machine, &mut ctx).await?; + let machine = stages::prepare_db(machine, &mut ctx).await?; + let machine = stages::prepare_corpus(machine, &mut ctx).await?; + let machine = stages::prepare_namespace(machine, &mut ctx).await?; + let machine = stages::run_queries(machine, &mut ctx).await?; + let machine = stages::summarize(machine, &mut ctx).await?; + let machine = stages::finalize(machine, &mut ctx).await?; + + drop(machine); + + Ok(ctx.into_summary()) +} diff --git a/eval/src/eval/pipeline/stages/finalize.rs 
b/eval/src/eval/pipeline/stages/finalize.rs new file mode 100644 index 0000000..17e7d40 --- /dev/null +++ b/eval/src/eval/pipeline/stages/finalize.rs @@ -0,0 +1,59 @@ +use std::time::Instant; + +use anyhow::Context; +use tracing::info; + +use crate::eval::write_chunk_diagnostics; + +use super::super::{ + context::{EvalStage, EvaluationContext}, + state::{Completed, EvaluationMachine, Summarized}, +}; +use super::{map_guard_error, StageResult}; + +pub(crate) async fn finalize( + machine: EvaluationMachine<(), Summarized>, + ctx: &mut EvaluationContext<'_>, +) -> StageResult { + let stage = EvalStage::Finalize; + info!( + evaluation_stage = stage.label(), + "starting evaluation stage" + ); + let started = Instant::now(); + + if let Some(cache) = ctx.embedding_cache.as_ref() { + cache + .persist() + .await + .context("persisting embedding cache")?; + } + + if let Some(path) = ctx.diagnostics_path.as_ref() { + if ctx.diagnostics_enabled { + write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output) + .await + .with_context(|| format!("writing chunk diagnostics to {}", path.display()))?; + } + } + + info!( + total_cases = ctx.summary.as_ref().map(|s| s.total_cases).unwrap_or(0), + correct = ctx.summary.as_ref().map(|s| s.correct).unwrap_or(0), + precision = ctx.summary.as_ref().map(|s| s.precision).unwrap_or(0.0), + dataset = ctx.dataset().metadata.id.as_str(), + "Evaluation complete" + ); + + let elapsed = started.elapsed(); + ctx.record_stage_duration(stage, elapsed); + info!( + evaluation_stage = stage.label(), + duration_ms = elapsed.as_millis(), + "completed evaluation stage" + ); + + machine + .finalize() + .map_err(|(_, guard)| map_guard_error("finalize", guard)) +} diff --git a/eval/src/eval/pipeline/stages/mod.rs b/eval/src/eval/pipeline/stages/mod.rs new file mode 100644 index 0000000..2fb0187 --- /dev/null +++ b/eval/src/eval/pipeline/stages/mod.rs @@ -0,0 +1,26 @@ +mod finalize; +mod prepare_corpus; +mod prepare_db; +mod prepare_namespace; +mod 
prepare_slice; +mod run_queries; +mod summarize; + +pub(crate) use finalize::finalize; +pub(crate) use prepare_corpus::prepare_corpus; +pub(crate) use prepare_db::prepare_db; +pub(crate) use prepare_namespace::prepare_namespace; +pub(crate) use prepare_slice::prepare_slice; +pub(crate) use run_queries::run_queries; +pub(crate) use summarize::summarize; + +use anyhow::Result; +use state_machines::core::GuardError; + +use super::state::EvaluationMachine; + +fn map_guard_error(event: &str, guard: GuardError) -> anyhow::Error { + anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}") +} + +type StageResult = Result>; diff --git a/eval/src/eval/pipeline/stages/prepare_corpus.rs b/eval/src/eval/pipeline/stages/prepare_corpus.rs new file mode 100644 index 0000000..16448f1 --- /dev/null +++ b/eval/src/eval/pipeline/stages/prepare_corpus.rs @@ -0,0 +1,84 @@ +use std::time::Instant; + +use anyhow::Context; +use tracing::info; + +use crate::{ingest, slice, snapshot}; + +use super::super::{ + context::{EvalStage, EvaluationContext}, + state::{CorpusReady, DbReady, EvaluationMachine}, +}; +use super::{map_guard_error, StageResult}; + +pub(crate) async fn prepare_corpus( + machine: EvaluationMachine<(), DbReady>, + ctx: &mut EvaluationContext<'_>, +) -> StageResult { + let stage = EvalStage::PrepareCorpus; + info!( + evaluation_stage = stage.label(), + "starting evaluation stage" + ); + let started = Instant::now(); + + let config = ctx.config(); + let cache_settings = ingest::CorpusCacheConfig::from(config); + let embedding_provider = ctx.embedding_provider().clone(); + let openai_client = ctx.openai_client(); + + let eval_user_id = "eval-user".to_string(); + let ingestion_timer = Instant::now(); + let corpus_handle = { + let slice = ctx.slice(); + let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit) + .context("selecting slice window for corpus preparation")?; + ingest::ensure_corpus( + ctx.dataset(), + slice, + 
&window, + &cache_settings, + &embedding_provider, + openai_client, + &eval_user_id, + config.converted_dataset_path.as_path(), + ) + .await + .context("ensuring ingestion-backed corpus")? + }; + let expected_fingerprint = corpus_handle + .manifest + .metadata + .ingestion_fingerprint + .clone(); + let ingestion_duration_ms = ingestion_timer.elapsed().as_millis() as u128; + info!( + cache = %corpus_handle.path.display(), + reused_ingestion = corpus_handle.reused_ingestion, + reused_embeddings = corpus_handle.reused_embeddings, + positive_ingested = corpus_handle.positive_ingested, + negative_ingested = corpus_handle.negative_ingested, + "Ingestion corpus ready" + ); + + ctx.corpus_handle = Some(corpus_handle); + ctx.expected_fingerprint = Some(expected_fingerprint); + ctx.ingestion_duration_ms = ingestion_duration_ms; + ctx.descriptor = Some(snapshot::Descriptor::new( + config, + ctx.slice(), + ctx.embedding_provider(), + )); + + let elapsed = started.elapsed(); + ctx.record_stage_duration(stage, elapsed); + info!( + evaluation_stage = stage.label(), + duration_ms = elapsed.as_millis(), + "completed evaluation stage" + ); + + machine + .prepare_corpus() + .map_err(|(_, guard)| map_guard_error("prepare_corpus", guard)) +} diff --git a/eval/src/eval/pipeline/stages/prepare_db.rs b/eval/src/eval/pipeline/stages/prepare_db.rs new file mode 100644 index 0000000..94cee4b --- /dev/null +++ b/eval/src/eval/pipeline/stages/prepare_db.rs @@ -0,0 +1,111 @@ +use std::{sync::Arc, time::Instant}; + +use anyhow::{anyhow, Context}; +use tracing::info; + +use crate::{ + args::EmbeddingBackend, + cache::EmbeddingCache, + embedding, + eval::{ + connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code, + }, + openai, +}; + +use super::super::{ + context::{EvalStage, EvaluationContext}, + state::{DbReady, EvaluationMachine, SlicePrepared}, +}; +use super::{map_guard_error, StageResult}; + +pub(crate) async fn prepare_db( + machine: 
/// Stage 2: connect to the eval database, initialise the embedding provider,
/// the OpenAI client and (for FastEmbed) an on-disk embedding cache, then
/// enforce system settings unless enforcement is deferred to the namespace
/// stage. All constructed services are stored on the context.
pub(crate) async fn prepare_db(
    machine: EvaluationMachine<(), SlicePrepared>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<DbReady> {
    let stage = EvalStage::PrepareDb;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();

    // Namespace/database were chosen by the prepare_slice stage.
    let namespace = ctx.namespace.clone();
    let database = ctx.database.clone();
    let config = ctx.config();

    let db = connect_eval_db(config, &namespace, &database).await?;
    // `settings_missing` is true when no settings row existed and defaults
    // were just initialised.
    let (mut settings, settings_missing) = load_or_init_system_settings(&db).await?;

    let embedding_provider =
        embedding::build_provider(config, settings.embedding_dimensions as usize)
            .await
            .context("building embedding provider")?;
    let (raw_openai_client, openai_base_url) =
        openai::build_client_from_env().context("building OpenAI client")?;
    let openai_client = Arc::new(raw_openai_client);
    let provider_dimension = embedding_provider.dimension();
    // A zero-dimension provider would make every vector index invalid.
    if provider_dimension == 0 {
        return Err(anyhow!(
            "embedding provider reported zero dimensions; cannot continue"
        ));
    }

    info!(
        backend = embedding_provider.backend_label(),
        model = embedding_provider
            .model_code()
            .as_deref()
            .unwrap_or(""),
        dimension = provider_dimension,
        "Embedding provider initialised"
    );
    info!(openai_base_url = %openai_base_url, "OpenAI client configured");

    // Only the FastEmbed backend uses a per-model JSON embedding cache.
    let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
        if let Some(model_code) = embedding_provider.model_code() {
            let sanitized = sanitize_model_code(&model_code);
            let path = config.cache_dir.join(format!("{sanitized}.json"));
            if config.force_convert && path.exists() {
                // Best-effort removal: `.ok()` deliberately swallows the error
                // so a stale cache never blocks the run.
                tokio::fs::remove_file(&path)
                    .await
                    .with_context(|| format!("removing stale cache {}", path.display()))
                    .ok();
            }
            let cache = EmbeddingCache::load(&path).await?;
            info!(path = %path.display(), "Embedding cache ready");
            Some(cache)
        } else {
            None
        }
    } else {
        None
    };

    // When settings were freshly initialised they must be re-applied after any
    // namespace reset (prepare_namespace handles that via this flag).
    let must_reapply_settings = settings_missing;
    // NOTE(review): enforcement is deferred when settings were missing AND the
    // slice is NOT being reseeded — confirm this condition is intentional
    // (one might expect deferral when a reseed/reset is imminent instead).
    let defer_initial_enforce = settings_missing && !config.reseed_slice;
    if !defer_initial_enforce {
        settings = enforce_system_settings(&db, settings, provider_dimension, config).await?;
    }

    // Hand everything to later stages via the shared context.
    ctx.db = Some(db);
    ctx.settings_missing = settings_missing;
    ctx.must_reapply_settings = must_reapply_settings;
    ctx.settings = Some(settings);
    ctx.embedding_provider = Some(embedding_provider);
    ctx.embedding_cache = embedding_cache;
    ctx.openai_client = Some(openai_client);
    ctx.openai_base_url = Some(openai_base_url);

    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );

    machine
        .prepare_db()
        .map_err(|(_, guard)| map_guard_error("prepare_db", guard))
}
/// Stage 4: decide whether the existing namespace contents can be reused for
/// this slice/fingerprint; if not, reset the namespace, reseed it from the
/// corpus manifest (with indexes dropped during bulk load and rebuilt after),
/// then re-apply system settings if needed, ensure the eval user exists and
/// materialise the query cases.
pub(crate) async fn prepare_namespace(
    machine: EvaluationMachine<(), CorpusReady>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<NamespaceReady> {
    let stage = EvalStage::PrepareNamespace;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();

    let config = ctx.config();
    let dataset = ctx.dataset();
    // Fingerprint produced by prepare_corpus; empty string if absent.
    let expected_fingerprint = ctx
        .expected_fingerprint
        .as_deref()
        .unwrap_or_default()
        .to_string();
    let namespace = ctx.namespace.clone();
    let database = ctx.database.clone();
    let embedding_provider = ctx.embedding_provider().clone();

    // Reuse check is skipped entirely when the caller forces a reseed.
    let mut namespace_reused = false;
    if !config.reseed_slice {
        namespace_reused = {
            let slice = ctx.slice();
            can_reuse_namespace(
                ctx.db(),
                ctx.descriptor(),
                &namespace,
                &database,
                dataset.metadata.id.as_str(),
                slice.manifest.slice_id.as_str(),
                expected_fingerprint.as_str(),
                slice.manifest.case_count,
            )
            .await?
        };
    }

    let mut namespace_seed_ms = None;
    if !namespace_reused {
        // A reseeded namespace invalidates previously applied settings.
        ctx.must_reapply_settings = true;
        // Reset is best-effort: on failure we warn and seed on top of the
        // existing data rather than aborting the run.
        if let Err(err) = reset_namespace(ctx.db(), &namespace, &database).await {
            warn!(
                error = %err,
                namespace,
                database = %database,
                "Failed to reset namespace before reseeding; continuing with existing data"
            );
        } else if let Err(err) = ctx.db().apply_migrations().await {
            warn!(error = %err, "Failed to reapply migrations after namespace reset");
        }

        {
            let slice = ctx.slice();
            info!(
                slice = slice.manifest.slice_id.as_str(),
                window_offset = ctx.window_offset,
                window_length = ctx.window_length,
                positives = slice.manifest.positive_paragraphs,
                negatives = slice.manifest.negative_paragraphs,
                total = slice.manifest.total_paragraphs,
                "Seeding ingestion corpus into SurrealDB"
            );
        }
        // Dropping indexes speeds up the bulk seed; remember whether the drop
        // succeeded so we only rebuild what we removed.
        let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
        let seed_start = Instant::now();
        ingest::seed_manifest_into_db(ctx.db(), &ctx.corpus_handle().manifest)
            .await
            .context("seeding ingestion corpus from manifest")?;
        namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
        if indexes_disabled {
            info!("Recreating indexes after namespace reset");
            // Index rebuild failure is non-fatal; cache warm-up only makes
            // sense when the rebuild succeeded.
            if let Err(err) = recreate_indexes(ctx.db(), embedding_provider.dimension()).await {
                warn!(error = %err, "failed to restore indexes after namespace reset");
            } else {
                warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
            }
        }
        {
            // Record what is now in the namespace so future runs can reuse it.
            let slice = ctx.slice();
            record_namespace_state(
                ctx.descriptor(),
                dataset.metadata.id.as_str(),
                slice.manifest.slice_id.as_str(),
                expected_fingerprint.as_str(),
                &namespace,
                &database,
                slice.manifest.case_count,
            )
            .await;
        }
    }

    // Re-apply settings if prepare_db deferred them or the namespace was reset.
    if ctx.must_reapply_settings {
        let mut settings = SystemSettings::get_current(ctx.db())
            .await
            .context("reloading system settings after namespace reset")?;
        settings =
            enforce_system_settings(ctx.db(), settings, embedding_provider.dimension(), config)
                .await?;
        ctx.settings = Some(settings);
        ctx.must_reapply_settings = false;
    }

    let user = ensure_eval_user(ctx.db()).await?;
    ctx.eval_user = Some(user);

    // Query cases come from the corpus manifest; an empty set is a hard error.
    let cases = cases_from_manifest(&ctx.corpus_handle().manifest);
    if cases.is_empty() {
        return Err(anyhow!(
            "no answerable questions found in converted dataset for evaluation"
        ));
    }
    ctx.cases = cases;
    ctx.namespace_reused = namespace_reused;
    ctx.namespace_seed_ms = namespace_seed_ms;

    info!(
        cases = ctx.cases.len(),
        window_offset = ctx.window_offset,
        namespace_reused = namespace_reused,
        "Dataset ready"
    );

    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );

    machine
        .prepare_namespace()
        .map_err(|(_, guard)| map_guard_error("prepare_namespace", guard))
}
+}; +use super::{map_guard_error, StageResult}; + +pub(crate) async fn prepare_slice( + machine: EvaluationMachine<(), Ready>, + ctx: &mut EvaluationContext<'_>, +) -> StageResult { + let stage = EvalStage::PrepareSlice; + info!( + evaluation_stage = stage.label(), + "starting evaluation stage" + ); + let started = Instant::now(); + + let ledger_limit = ledger_target(ctx.config()); + let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit); + let resolved_slice = + slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?; + let window = slice::select_window( + &resolved_slice, + ctx.config().slice_offset, + ctx.config().limit, + ) + .context("selecting slice window (use --slice-grow to extend the ledger first)")?; + + ctx.ledger_limit = ledger_limit; + ctx.slice_settings = Some(slice_settings); + ctx.slice = Some(resolved_slice.clone()); + ctx.window_offset = window.offset; + ctx.window_length = window.length; + ctx.window_total_cases = window.total_cases; + + ctx.namespace = ctx.config().db_namespace.clone().unwrap_or_else(|| { + default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit) + }); + ctx.database = ctx + .config() + .db_database + .clone() + .unwrap_or_else(default_database); + + let elapsed = started.elapsed(); + ctx.record_stage_duration(stage, elapsed); + info!( + evaluation_stage = stage.label(), + duration_ms = elapsed.as_millis(), + "completed evaluation stage" + ); + + machine + .prepare_slice() + .map_err(|(_, guard)| map_guard_error("prepare_slice", guard)) +} diff --git a/eval/src/eval/pipeline/stages/run_queries.rs b/eval/src/eval/pipeline/stages/run_queries.rs new file mode 100644 index 0000000..1b1cf94 --- /dev/null +++ b/eval/src/eval/pipeline/stages/run_queries.rs @@ -0,0 +1,337 @@ +use std::{collections::HashSet, sync::Arc, time::Instant}; + +use anyhow::Context; +use futures::stream::{self, StreamExt, TryStreamExt}; +use tracing::{debug, info}; + +use crate::eval::{ 
/// Stage 5: execute every seeded case against the retrieval pipeline with
/// bounded concurrency, score each result at depth `k`, and stash per-case
/// summaries, latencies, optional diagnostics and stage timings on the
/// context for the summarize stage.
pub(crate) async fn run_queries(
    machine: EvaluationMachine<(), NamespaceReady>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<QueriesFinished> {
    let stage = EvalStage::RunQueries;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();

    let config = ctx.config();
    let dataset = ctx.dataset();
    let slice_settings = ctx
        .slice_settings
        .as_ref()
        .expect("slice settings missing during query stage");
    let total_cases = ctx.cases.len();
    // Take ownership of the cases; indices preserve the original ordering so
    // results can be re-sorted after unordered completion.
    let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate();

    let rerank_pool = if config.rerank {
        Some(RerankerPool::new(config.rerank_pool_size).context("initialising reranker pool")?)
    } else {
        None
    };

    // Start from defaults, then layer CLI overrides on the tuning knobs.
    let mut retrieval_config = RetrievalConfig::default();
    retrieval_config.tuning.rerank_keep_top = config.rerank_keep_top;
    if retrieval_config.tuning.fallback_min_results < config.rerank_keep_top {
        retrieval_config.tuning.fallback_min_results = config.rerank_keep_top;
    }
    if let Some(value) = config.chunk_vector_take {
        retrieval_config.tuning.chunk_vector_take = value;
    }
    if let Some(value) = config.chunk_fts_take {
        retrieval_config.tuning.chunk_fts_take = value;
    }
    if let Some(value) = config.chunk_token_budget {
        retrieval_config.tuning.token_budget_estimate = value;
    }
    if let Some(value) = config.chunk_avg_chars_per_token {
        retrieval_config.tuning.avg_chars_per_token = value;
    }
    if let Some(value) = config.max_chunks_per_entity {
        retrieval_config.tuning.max_chunks_per_entity = value;
    }

    // Dataset-specific overrides are applied last and may refine the above.
    apply_dataset_tuning_overrides(dataset, config, &mut retrieval_config.tuning);

    let active_tuning = retrieval_config.tuning.clone();
    let effective_chunk_vector = config
        .chunk_vector_take
        .unwrap_or(active_tuning.chunk_vector_take);
    let effective_chunk_fts = config
        .chunk_fts_take
        .unwrap_or(active_tuning.chunk_fts_take);

    info!(
        dataset = dataset.metadata.id.as_str(),
        slice_seed = config.slice_seed,
        slice_offset = config.slice_offset,
        slice_limit = config
            .limit
            .unwrap_or(ctx.window_total_cases),
        negative_multiplier = %slice_settings.negative_multiplier,
        rerank_enabled = config.rerank,
        rerank_pool_size = config.rerank_pool_size,
        rerank_keep_top = config.rerank_keep_top,
        chunk_min = config.chunk_min_chars,
        chunk_max = config.chunk_max_chars,
        chunk_vector_take = effective_chunk_vector,
        chunk_fts_take = effective_chunk_fts,
        chunk_token_budget = active_tuning.token_budget_estimate,
        embedding_backend = ctx.embedding_provider().backend_label(),
        embedding_model = ctx
            .embedding_provider()
            .model_code()
            .as_deref()
            .unwrap_or(""),
        "Starting evaluation run"
    );

    let retrieval_config = Arc::new(retrieval_config);
    ctx.rerank_pool = rerank_pool.clone();
    ctx.retrieval_config = Some(retrieval_config.clone());

    ctx.evaluation_start = Some(Instant::now());
    let user_id = ctx.evaluation_user().id.clone();
    // Guard against a configured concurrency of zero.
    let concurrency = config.concurrency.max(1);
    let diagnostics_enabled = ctx.diagnostics_enabled;

    let query_semaphore = Arc::new(Semaphore::new(concurrency));

    info!(
        total_cases = total_cases,
        max_concurrent_queries = concurrency,
        "Starting evaluation with staged query execution"
    );

    // Clone handles once here; the per-case closure clones again per task.
    let embedding_provider_for_queries = ctx.embedding_provider().clone();
    let rerank_pool_for_queries = rerank_pool.clone();
    let db = ctx.db().clone();
    let openai_client = ctx.openai_client();

    let results: Vec<(
        usize,
        CaseSummary,
        Option<CaseDiagnostics>,
        PipelineStageTimings,
    )> = stream::iter(cases_iter)
        .map(move |(idx, case)| {
            let db = db.clone();
            let openai_client = openai_client.clone();
            let user_id = user_id.clone();
            let retrieval_config = retrieval_config.clone();
            let embedding_provider = embedding_provider_for_queries.clone();
            let rerank_pool = rerank_pool_for_queries.clone();
            let semaphore = query_semaphore.clone();
            let diagnostics_enabled = diagnostics_enabled;

            async move {
                // The permit bounds in-flight pipeline work (in addition to
                // buffer_unordered's own limit) and is held for the whole case.
                let _permit = semaphore
                    .acquire()
                    .await
                    .context("acquiring query semaphore permit")?;

                let crate::eval::SeededCase {
                    question_id,
                    question,
                    expected_source,
                    answers,
                    paragraph_id,
                    paragraph_title,
                    expected_chunk_ids,
                } = case;
                let query_start = Instant::now();

                debug!(question_id = %question_id, "Evaluating query");
                let query_embedding =
                    embedding_provider.embed(&question).await.with_context(|| {
                        format!("generating embedding for question {}", question_id)
                    })?;
                // Checking a reranker out of the pool may wait for a free slot.
                let reranker = match &rerank_pool {
                    Some(pool) => Some(pool.checkout().await),
                    None => None,
                };

                // Same pipeline either way; the diagnostics variant also
                // returns per-stage diagnostic payloads.
                let (results, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
                    let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
                        &db,
                        &openai_client,
                        query_embedding,
                        &question,
                        &user_id,
                        (*retrieval_config).clone(),
                        reranker,
                    )
                    .await
                    .with_context(|| format!("running pipeline for question {}", question_id))?;
                    (outcome.results, outcome.diagnostics, outcome.stage_timings)
                } else {
                    let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
                        &db,
                        &openai_client,
                        query_embedding,
                        &question,
                        &user_id,
                        (*retrieval_config).clone(),
                        reranker,
                    )
                    .await
                    .with_context(|| format!("running pipeline for question {}", question_id))?;
                    (outcome.results, None, outcome.stage_timings)
                };
                let query_latency = query_start.elapsed().as_millis() as u128;

                // Score the top-k results. A case "matches" at the first rank
                // where entity, answer-text and (if required) chunk-id all hit
                // for the same entity.
                let mut retrieved = Vec::new();
                let mut match_rank = None;
                let answers_lower: Vec<String> =
                    answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
                let expected_chunk_ids_set: HashSet<&str> =
                    expected_chunk_ids.iter().map(|id| id.as_str()).collect();
                let chunk_id_required = !expected_chunk_ids_set.is_empty();
                let mut entity_hit = false;
                let mut chunk_text_hit = false;
                // Vacuously true when no chunk ids are expected.
                let mut chunk_id_hit = !chunk_id_required;

                for (idx_entity, entity) in results.iter().enumerate() {
                    if idx_entity >= config.k {
                        break;
                    }
                    let entity_match = entity.entity.source_id == expected_source;
                    if entity_match {
                        entity_hit = true;
                    }
                    let chunk_text_for_entity = entity
                        .chunks
                        .iter()
                        .any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower));
                    if chunk_text_for_entity {
                        chunk_text_hit = true;
                    }
                    let chunk_id_for_entity = if chunk_id_required {
                        expected_chunk_ids_set.contains(entity.entity.source_id.as_str())
                            || entity.chunks.iter().any(|chunk| {
                                expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
                            })
                    } else {
                        true
                    };
                    if chunk_id_for_entity {
                        chunk_id_hit = true;
                    }
                    let success = entity_match && chunk_text_for_entity && chunk_id_for_entity;
                    if success && match_rank.is_none() {
                        match_rank = Some(idx_entity + 1);
                    }
                    // Heavier per-result fields only when a detailed report
                    // was requested.
                    let detail_fields = if config.detailed_report {
                        (
                            Some(entity.entity.description.clone()),
                            Some(format!("{:?}", entity.entity.entity_type)),
                            Some(chunk_text_for_entity),
                            Some(chunk_id_for_entity),
                        )
                    } else {
                        (None, None, None, None)
                    };
                    retrieved.push(RetrievedSummary {
                        rank: idx_entity + 1,
                        entity_id: entity.entity.id.clone(),
                        source_id: entity.entity.source_id.clone(),
                        entity_name: entity.entity.name.clone(),
                        score: entity.score,
                        matched: success,
                        entity_description: detail_fields.0,
                        entity_category: detail_fields.1,
                        chunk_text_match: detail_fields.2,
                        chunk_id_match: detail_fields.3,
                    });
                }

                let overall_match = match_rank.is_some();
                let summary = CaseSummary {
                    question_id,
                    question,
                    paragraph_id,
                    paragraph_title,
                    expected_source,
                    answers,
                    matched: overall_match,
                    entity_match: entity_hit,
                    chunk_text_match: chunk_text_hit,
                    chunk_id_match: chunk_id_hit,
                    match_rank,
                    latency_ms: query_latency,
                    retrieved,
                };

                let diagnostics = if diagnostics_enabled {
                    Some(build_case_diagnostics(
                        &summary,
                        &expected_chunk_ids,
                        &answers_lower,
                        &results,
                        pipeline_diagnostics,
                    ))
                } else {
                    None
                };

                Ok::<
                    (
                        usize,
                        CaseSummary,
                        Option<CaseDiagnostics>,
                        PipelineStageTimings,
                    ),
                    anyhow::Error,
                >((idx, summary, diagnostics, stage_timings))
            }
        })
        .buffer_unordered(concurrency)
        .try_collect()
        .await?;

    // Restore submission order before unpacking the per-case outputs.
    let mut ordered = results;
    ordered.sort_by_key(|(idx, ..)| *idx);
    let mut summaries = Vec::with_capacity(ordered.len());
    let mut latencies = Vec::with_capacity(ordered.len());
    let mut diagnostics_output = Vec::new();
    let mut stage_latency_samples = Vec::with_capacity(ordered.len());
    for (_, summary, diagnostics, stage_timings) in ordered {
        latencies.push(summary.latency_ms);
        summaries.push(summary);
        if let Some(diag) = diagnostics {
            diagnostics_output.push(diag);
        }
        stage_latency_samples.push(stage_timings);
    }

    ctx.query_summaries = summaries;
    ctx.latencies = latencies;
    ctx.diagnostics_output = diagnostics_output;
    ctx.stage_latency_samples = stage_latency_samples;

    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );

    machine
        .run_queries()
        .map_err(|(_, guard)| map_guard_error("run_queries", guard))
}
summaries; + ctx.latencies = latencies; + ctx.diagnostics_output = diagnostics_output; + ctx.stage_latency_samples = stage_latency_samples; + + let elapsed = started.elapsed(); + ctx.record_stage_duration(stage, elapsed); + info!( + evaluation_stage = stage.label(), + duration_ms = elapsed.as_millis(), + "completed evaluation stage" + ); + + machine + .run_queries() + .map_err(|(_, guard)| map_guard_error("run_queries", guard)) +} diff --git a/eval/src/eval/pipeline/stages/summarize.rs b/eval/src/eval/pipeline/stages/summarize.rs new file mode 100644 index 0000000..a5508b3 --- /dev/null +++ b/eval/src/eval/pipeline/stages/summarize.rs @@ -0,0 +1,173 @@ +use std::time::Instant; + +use chrono::Utc; +use tracing::info; + +use crate::eval::{ + build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings, +}; + +use super::super::{ + context::{EvalStage, EvaluationContext}, + state::{EvaluationMachine, QueriesFinished, Summarized}, +}; +use super::{map_guard_error, StageResult}; + +pub(crate) async fn summarize( + machine: EvaluationMachine<(), QueriesFinished>, + ctx: &mut EvaluationContext<'_>, +) -> StageResult { + let stage = EvalStage::Summarize; + info!( + evaluation_stage = stage.label(), + "starting evaluation stage" + ); + let started = Instant::now(); + + let summaries = std::mem::take(&mut ctx.query_summaries); + let latencies = std::mem::take(&mut ctx.latencies); + let stage_latency_samples = std::mem::take(&mut ctx.stage_latency_samples); + let duration_ms = ctx + .evaluation_start + .take() + .map(|start| start.elapsed().as_millis()) + .unwrap_or_default(); + let config = ctx.config(); + let dataset = ctx.dataset(); + let slice = ctx.slice(); + let corpus_handle = ctx.corpus_handle(); + let total_cases = summaries.len(); + + let mut correct = 0usize; + let mut correct_at_1 = 0usize; + let mut correct_at_2 = 0usize; + let mut correct_at_3 = 0usize; + for summary in &summaries { + if summary.matched { + correct += 1; + if let 
Some(rank) = summary.match_rank { + if rank <= 1 { + correct_at_1 += 1; + } + if rank <= 2 { + correct_at_2 += 1; + } + if rank <= 3 { + correct_at_3 += 1; + } + } + } + } + + let latency_stats = compute_latency_stats(&latencies); + let stage_latency = build_stage_latency_breakdown(&stage_latency_samples); + + let precision = if total_cases == 0 { + 0.0 + } else { + (correct as f64) / (total_cases as f64) + }; + let precision_at_1 = if total_cases == 0 { + 0.0 + } else { + (correct_at_1 as f64) / (total_cases as f64) + }; + let precision_at_2 = if total_cases == 0 { + 0.0 + } else { + (correct_at_2 as f64) / (total_cases as f64) + }; + let precision_at_3 = if total_cases == 0 { + 0.0 + } else { + (correct_at_3 as f64) / (total_cases as f64) + }; + + let active_tuning = ctx + .retrieval_config + .as_ref() + .map(|cfg| cfg.tuning.clone()) + .unwrap_or_default(); + + let perf_timings = PerformanceTimings { + openai_base_url: ctx + .openai_base_url + .clone() + .unwrap_or_else(|| "".to_string()), + ingestion_ms: ctx.ingestion_duration_ms, + namespace_seed_ms: ctx.namespace_seed_ms, + evaluation_stage_ms: ctx.stage_timings.clone(), + stage_latency, + }; + + ctx.summary = Some(EvaluationSummary { + generated_at: Utc::now(), + k: config.k, + limit: config.limit, + run_label: config.label.clone(), + total_cases, + correct, + precision, + correct_at_1, + correct_at_2, + correct_at_3, + precision_at_1, + precision_at_2, + precision_at_3, + duration_ms, + dataset_id: dataset.metadata.id.clone(), + dataset_label: dataset.metadata.label.clone(), + dataset_includes_unanswerable: dataset.metadata.include_unanswerable, + dataset_source: dataset.source.clone(), + slice_id: slice.manifest.slice_id.clone(), + slice_seed: slice.manifest.seed, + slice_total_cases: slice.manifest.case_count, + slice_window_offset: ctx.window_offset, + slice_window_length: ctx.window_length, + slice_cases: total_cases, + slice_positive_paragraphs: slice.manifest.positive_paragraphs, + 
slice_negative_paragraphs: slice.manifest.negative_paragraphs, + slice_total_paragraphs: slice.manifest.total_paragraphs, + slice_negative_multiplier: slice.manifest.negative_multiplier, + namespace_reused: ctx.namespace_reused, + corpus_paragraphs: ctx.corpus_handle().manifest.metadata.paragraph_count, + ingestion_cache_path: corpus_handle.path.display().to_string(), + ingestion_reused: corpus_handle.reused_ingestion, + ingestion_embeddings_reused: corpus_handle.reused_embeddings, + ingestion_fingerprint: corpus_handle + .manifest + .metadata + .ingestion_fingerprint + .clone(), + positive_paragraphs_reused: corpus_handle.positive_reused, + negative_paragraphs_reused: corpus_handle.negative_reused, + latency_ms: latency_stats, + perf: perf_timings, + embedding_backend: ctx.embedding_provider().backend_label().to_string(), + embedding_model: ctx.embedding_provider().model_code(), + embedding_dimension: ctx.embedding_provider().dimension(), + rerank_enabled: config.rerank, + rerank_pool_size: ctx.rerank_pool.as_ref().map(|_| config.rerank_pool_size), + rerank_keep_top: config.rerank_keep_top, + concurrency: config.concurrency.max(1), + detailed_report: config.detailed_report, + chunk_vector_take: active_tuning.chunk_vector_take, + chunk_fts_take: active_tuning.chunk_fts_take, + chunk_token_budget: active_tuning.token_budget_estimate, + chunk_avg_chars_per_token: active_tuning.avg_chars_per_token, + max_chunks_per_entity: active_tuning.max_chunks_per_entity, + cases: summaries, + }); + + let elapsed = started.elapsed(); + ctx.record_stage_duration(stage, elapsed); + info!( + evaluation_stage = stage.label(), + duration_ms = elapsed.as_millis(), + "completed evaluation stage" + ); + + machine + .summarize() + .map_err(|(_, guard)| map_guard_error("summarize", guard)) +} diff --git a/eval/src/eval/pipeline/state.rs b/eval/src/eval/pipeline/state.rs new file mode 100644 index 0000000..aa9e753 --- /dev/null +++ b/eval/src/eval/pipeline/state.rs @@ -0,0 +1,31 @@ +use 
use state_machines::state_machine;

// Typestate machine for the evaluation pipeline. Each stage event moves the
// machine strictly forward (Ready -> ... -> Completed); `abort` can move any
// non-terminal (and the Completed) state into `Failed`. The macro generates
// `EvaluationMachine<Ctx, State>` with one consuming method per event that
// returns either the next-state machine or `(machine, GuardError)` on an
// invalid transition.
state_machine! {
    name: EvaluationMachine,
    state: EvaluationState,
    initial: Ready,
    states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
    events {
        prepare_slice { transition: { from: Ready, to: SlicePrepared } }
        prepare_db { transition: { from: SlicePrepared, to: DbReady } }
        prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
        prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
        run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
        summarize { transition: { from: QueriesFinished, to: Summarized } }
        finalize { transition: { from: Summarized, to: Completed } }
        abort {
            transition: { from: Ready, to: Failed }
            transition: { from: SlicePrepared, to: Failed }
            transition: { from: DbReady, to: Failed }
            transition: { from: CorpusReady, to: Failed }
            transition: { from: NamespaceReady, to: Failed }
            transition: { from: QueriesFinished, to: Failed }
            transition: { from: Summarized, to: Failed }
            transition: { from: Completed, to: Failed }
        }
    }
}

// Entry point used by the pipeline driver: a fresh machine in the `Ready`
// state carrying unit context (all run state lives in `EvaluationContext`).
pub fn ready() -> EvaluationMachine<(), Ready> {
    EvaluationMachine::new(())
}
/// Settings that control how the on-disk ingestion corpus cache is used:
/// where shards live, whether the whole cache is rebuilt, and whether only
/// the embeddings are refreshed.
#[derive(Debug, Clone)]
pub struct CorpusCacheConfig {
    pub ingestion_cache_dir: PathBuf,
    pub force_refresh: bool,
    pub refresh_embeddings_only: bool,
}

impl CorpusCacheConfig {
    /// Assemble a cache configuration; `ingestion_cache_dir` accepts anything
    /// convertible into a `PathBuf` (e.g. `&str`, `String`, `PathBuf`).
    pub fn new(
        ingestion_cache_dir: impl Into<PathBuf>,
        force_refresh: bool,
        refresh_embeddings_only: bool,
    ) -> Self {
        let ingestion_cache_dir = ingestion_cache_dir.into();
        Self {
            ingestion_cache_dir,
            force_refresh,
            refresh_embeddings_only,
        }
    }
}
/// On-disk store for per-paragraph ingestion shards, rooted at the
/// dataset/slice cache directory.
struct ParagraphShardStore {
    // Base directory all shard-relative paths are resolved against.
    base_dir: PathBuf,
}

impl ParagraphShardStore {
    fn new(base_dir: PathBuf) -> Self {
        Self { base_dir }
    }

    /// Creates the base directory (and parents) if it does not exist yet.
    fn ensure_base_dir(&self) -> Result<()> {
        fs::create_dir_all(&self.base_dir)
            .with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
    }

    /// Joins a shard-relative path onto the base directory.
    fn resolve(&self, relative: &str) -> PathBuf {
        self.base_dir.join(relative)
    }

    /// Loads a cached shard, returning `Ok(None)` when the shard is absent,
    /// has a mismatched on-disk format version, or was produced under a
    /// different ingestion fingerprint — all three are cache misses, not
    /// errors.
    fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
        let path = self.resolve(relative);
        // A missing file is simply a miss; any other I/O failure is an error.
        let file = match fs::File::open(&path) {
            Ok(file) => file,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(err) => {
                return Err(err).with_context(|| format!("opening shard {}", path.display()))
            }
        };
        let reader = BufReader::new(file);
        let mut shard: ParagraphShard = serde_json::from_reader(reader)
            .with_context(|| format!("parsing shard {}", path.display()))?;
        if shard.version != PARAGRAPH_SHARD_VERSION {
            // Older/newer shard format: log and treat as a miss so the
            // paragraph is re-ingested rather than failing the run.
            warn!(
                path = %path.display(),
                version = shard.version,
                expected = PARAGRAPH_SHARD_VERSION,
                "Skipping shard due to version mismatch"
            );
            return Ok(None);
        }
        if shard.ingestion_fingerprint != fingerprint {
            // Stale shard from a previous dataset/slice/config revision.
            return Ok(None);
        }
        // Normalise the stored path to the caller-supplied relative path.
        shard.shard_path = relative.to_string();
        Ok(Some(shard))
    }

    /// Persists a shard via write-to-temp-then-rename so a crash mid-write
    /// never leaves a truncated shard at the final path.
    fn persist(&self, shard: &ParagraphShard) -> Result<()> {
        let path = self.resolve(&shard.shard_path);
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)
                .with_context(|| format!("creating shard dir {}", parent.display()))?;
        }
        let tmp_path = path.with_extension("json.tmp");
        let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
        fs::write(&tmp_path, &body)
            .with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
        fs::rename(&tmp_path, &path)
            .with_context(|| format!("renaming shard tmp {}", path.display()))?;
        Ok(())
    }
}
/// A paragraph that must be (re-)ingested in this run, together with the
/// slice questions that reference it.
#[derive(Clone)]
struct IngestRequest<'a> {
    // Index into the per-run records vector where the resulting shard goes.
    slot: usize,
    paragraph: &'a ConvertedParagraph,
    // Relative path of the shard file inside the cache directory.
    shard_path: String,
    // Questions bound to this paragraph (empty for negative paragraphs).
    question_refs: Vec<&'a ConvertedQuestion>,
}

impl<'a> IngestRequest<'a> {
    /// Builds a request from a slice-ledger entry, resolving every question
    /// id referenced by a positive entry to the paragraph's question objects.
    ///
    /// # Errors
    /// Fails when a positive entry references a question id the paragraph
    /// does not contain (a corrupt or out-of-date slice ledger).
    fn from_entry(
        slot: usize,
        paragraph: &'a ConvertedParagraph,
        entry: &'a slices::SliceParagraphEntry,
    ) -> Result<Self> {
        // Fall back to the deterministic default shard path when the ledger
        // does not pin one explicitly.
        let shard_path = entry
            .shard_path
            .clone()
            .unwrap_or_else(|| slices::default_shard_path(&entry.id));
        let question_refs = match &entry.kind {
            SliceParagraphKind::Positive { question_ids } => question_ids
                .iter()
                .map(|id| {
                    paragraph
                        .questions
                        .iter()
                        .find(|question| question.id == *id)
                        .ok_or_else(|| {
                            anyhow!(
                                "paragraph '{}' missing question '{}' referenced by slice",
                                paragraph.id,
                                id
                            )
                        })
                })
                .collect::<Result<Vec<_>>>()?,
            // Negative paragraphs act as distractors and carry no questions.
            SliceParagraphKind::Negative => Vec::new(),
        };
        Ok(Self {
            slot,
            paragraph,
            shard_path,
            question_refs,
        })
    }
}
} => positive_set.contains(entry.id.as_str()), + SliceParagraphKind::Negative => true, + }; + if include { + let paragraph = slice + .paragraphs + .get(idx) + .copied() + .ok_or_else(|| anyhow!("slice missing paragraph index {}", idx))?; + plan.push(ParagraphPlan { + slot: plan.len(), + entry, + paragraph, + }); + } + } + + if plan.is_empty() { + return Err(anyhow!( + "no paragraphs selected for ingestion (slice '{}')", + slice.manifest.slice_id + )); + } + + let mut records: Vec> = vec![None; plan.len()]; + let mut ingest_requests = Vec::new(); + let mut stats = IngestionStats::default(); + + for plan_entry in &plan { + let shard_path = plan_entry + .entry + .shard_path + .clone() + .unwrap_or_else(|| slices::default_shard_path(&plan_entry.entry.id)); + let shard = if cache.force_refresh { + None + } else { + store.load(&shard_path, &ingestion_fingerprint)? + }; + if let Some(shard) = shard { + let model_matches = shard.embedding_model.as_deref() == embedding_model_code.as_deref(); + let needs_reembed = shard.embedding_backend != embedding_backend_label + || shard.embedding_dimension != embedding_dimension + || !model_matches; + match plan_entry.entry.kind { + SliceParagraphKind::Positive { .. } => stats.positive_reused += 1, + SliceParagraphKind::Negative => stats.negative_reused += 1, + } + records[plan_entry.slot] = Some(ParagraphShardRecord { + shard, + dirty: false, + needs_reembed, + }); + } else { + match plan_entry.entry.kind { + SliceParagraphKind::Positive { .. 
} => stats.positive_ingested += 1, + SliceParagraphKind::Negative => stats.negative_ingested += 1, + } + let request = + IngestRequest::from_entry(plan_entry.slot, plan_entry.paragraph, plan_entry.entry)?; + ingest_requests.push(request); + } + } + + if cache.refresh_embeddings_only && !ingest_requests.is_empty() { + return Err(anyhow!( + "--refresh-embeddings requested but {} shard(s) missing for dataset '{}' slice '{}'", + ingest_requests.len(), + dataset.metadata.id, + slice.manifest.slice_id + )); + } + + if !ingest_requests.is_empty() { + let new_shards = ingest_paragraph_batch( + dataset, + &ingest_requests, + embedding, + openai.clone(), + user_id, + &ingestion_fingerprint, + &embedding_backend_label, + embedding_model_code.clone(), + embedding_dimension, + ) + .await + .context("ingesting missing slice paragraphs")?; + for (request, shard) in ingest_requests.into_iter().zip(new_shards.into_iter()) { + store.persist(&shard)?; + records[request.slot] = Some(ParagraphShardRecord { + shard, + dirty: false, + needs_reembed: false, + }); + } + } + + for record in &mut records { + let shard_record = record + .as_mut() + .context("shard record missing after ingestion run")?; + if cache.refresh_embeddings_only || shard_record.needs_reembed { + reembed_entities(&mut shard_record.shard.entities, embedding).await?; + reembed_chunks(&mut shard_record.shard.chunks, embedding).await?; + shard_record.shard.ingestion_fingerprint = ingestion_fingerprint.clone(); + shard_record.shard.ingested_at = Utc::now(); + shard_record.shard.embedding_backend = embedding_backend_label.clone(); + shard_record.shard.embedding_model = embedding_model_code.clone(); + shard_record.shard.embedding_dimension = embedding_dimension; + shard_record.dirty = true; + shard_record.needs_reembed = false; + } + } + + let mut record_index = HashMap::new(); + for (idx, plan_entry) in plan.iter().enumerate() { + record_index.insert(plan_entry.entry.id.as_str(), idx); + } + + let mut corpus_paragraphs = 
Vec::with_capacity(plan.len()); + for record in &records { + let shard = &record.as_ref().expect("record missing").shard; + corpus_paragraphs.push(shard.to_corpus_paragraph()); + } + + let mut corpus_questions = Vec::with_capacity(window.cases.len()); + for case in &window.cases { + let slot = record_index + .get(case.paragraph.id.as_str()) + .copied() + .ok_or_else(|| { + anyhow!( + "slice case references paragraph '{}' that is not part of the window", + case.paragraph.id + ) + })?; + let record_slot = records + .get_mut(slot) + .context("shard record slot missing for question binding")?; + let record = record_slot + .as_mut() + .context("shard record missing for question binding")?; + let (chunk_ids, updated) = match record.shard.ensure_question_binding(case.question) { + Ok(result) => result, + Err(err) => { + warn!( + question_id = %case.question.id, + paragraph_id = %case.paragraph.id, + error = %err, + "Failed to locate answer text in ingested content; recording empty chunk bindings" + ); + record + .shard + .question_bindings + .insert(case.question.id.clone(), Vec::new()); + record.dirty = true; + (Vec::new(), true) + } + }; + if updated { + record.dirty = true; + } + corpus_questions.push(CorpusQuestion { + question_id: case.question.id.clone(), + paragraph_id: case.paragraph.id.clone(), + text_content_id: record.shard.text_content.id.clone(), + question_text: case.question.question.clone(), + answers: case.question.answers.clone(), + is_impossible: case.question.is_impossible, + matching_chunk_ids: chunk_ids, + }); + } + + for record in &mut records { + if let Some(ref mut entry) = record { + if entry.dirty { + store.persist(&entry.shard)?; + } + } + } + + let manifest = CorpusManifest { + version: MANIFEST_VERSION, + metadata: CorpusMetadata { + dataset_id: dataset.metadata.id.clone(), + dataset_label: dataset.metadata.label.clone(), + slice_id: slice.manifest.slice_id.clone(), + include_unanswerable: slice.manifest.includes_unanswerable, + 
/// Re-embeds every knowledge entity in place using the supplied embedding
/// provider. No-op for an empty slice.
///
/// # Errors
/// Fails when the provider returns a different number of vectors than
/// entities were submitted, or when the embedding call itself fails.
// NOTE(review): the generic bound was reconstructed from call sites; confirm
// it matches the original `<E: CorpusEmbeddingProvider>` declaration.
async fn reembed_entities<E: CorpusEmbeddingProvider>(
    entities: &mut [KnowledgeEntity],
    embedding: &E,
) -> Result<()> {
    if entities.is_empty() {
        return Ok(());
    }
    // One formatted "name/description/type" payload per entity, in order.
    let payloads: Vec<String> = entities.iter().map(entity_embedding_text).collect();
    let vectors = embedding.embed_batch(payloads).await?;
    if vectors.len() != entities.len() {
        return Err(anyhow!(
            "entity embedding batch mismatch (expected {}, got {})",
            entities.len(),
            vectors.len()
        ));
    }
    // Positional zip is safe: length equality was just verified.
    for (entity, vector) in entities.iter_mut().zip(vectors.into_iter()) {
        entity.embedding = vector;
    }
    Ok(())
}
chunk.embedding = vector; + } + Ok(()) +} + +fn entity_embedding_text(entity: &KnowledgeEntity) -> String { + format!( + "name: {}\ndescription: {}\ntype: {:?}", + entity.name, entity.description, entity.entity_type + ) +} + +async fn ingest_paragraph_batch( + dataset: &ConvertedDataset, + targets: &[IngestRequest<'_>], + embedding: &E, + openai: Arc, + user_id: &str, + ingestion_fingerprint: &str, + embedding_backend: &str, + embedding_model: Option, + embedding_dimension: usize, +) -> Result> { + if targets.is_empty() { + return Ok(Vec::new()); + } + let namespace = format!("ingest_eval_{}", Uuid::new_v4()); + let db = Arc::new( + SurrealDbClient::memory(&namespace, "corpus") + .await + .context("creating ingestion SurrealDB instance")?, + ); + db.apply_migrations() + .await + .context("applying migrations for ingestion")?; + + let mut app_config = AppConfig::default(); + app_config.storage = StorageKind::Memory; + let backend: DynStore = Arc::new(InMemory::new()); + let storage = StorageManager::with_backend(backend, StorageKind::Memory); + + let pipeline = IngestionPipeline::new( + db, + openai.clone(), + app_config, + None::>, + storage, + ) + .await?; + let pipeline = Arc::new(pipeline); + + let mut shards = Vec::with_capacity(targets.len()); + let category = dataset.metadata.category.clone(); + for (batch_index, batch) in targets.chunks(INGESTION_BATCH_SIZE).enumerate() { + info!( + batch = batch_index, + batch_size = batch.len(), + total_batches = (targets.len() + INGESTION_BATCH_SIZE - 1) / INGESTION_BATCH_SIZE, + "Ingesting paragraph batch" + ); + let model_clone = embedding_model.clone(); + let backend_clone = embedding_backend.to_string(); + let pipeline_clone = pipeline.clone(); + let category_clone = category.clone(); + let tasks = batch.iter().cloned().map(move |request| { + ingest_single_paragraph( + pipeline_clone.clone(), + request, + category_clone.clone(), + embedding, + user_id, + ingestion_fingerprint, + backend_clone.clone(), + 
model_clone.clone(), + embedding_dimension, + ) + }); + let batch_results: Vec = try_join_all(tasks) + .await + .context("ingesting batch of paragraphs")?; + shards.extend(batch_results); + } + + Ok(shards) +} + +async fn ingest_single_paragraph( + pipeline: Arc, + request: IngestRequest<'_>, + category: String, + embedding: &E, + user_id: &str, + ingestion_fingerprint: &str, + embedding_backend: String, + embedding_model: Option, + embedding_dimension: usize, +) -> Result { + let paragraph = request.paragraph; + let mut last_err: Option = None; + for attempt in 1..=INGESTION_MAX_RETRIES { + let payload = IngestionPayload::Text { + text: paragraph.context.clone(), + context: paragraph.title.clone(), + category: category.clone(), + user_id: user_id.to_string(), + }; + let task = IngestionTask::new(payload, user_id.to_string()); + match pipeline.produce_artifacts(&task).await { + Ok(mut artifacts) => { + reembed_entities(&mut artifacts.entities, embedding).await?; + reembed_chunks(&mut artifacts.chunks, embedding).await?; + let mut shard = ParagraphShard::new( + paragraph, + request.shard_path, + ingestion_fingerprint, + artifacts.text_content, + artifacts.entities, + artifacts.relationships, + artifacts.chunks, + &embedding_backend, + embedding_model.clone(), + embedding_dimension, + ); + for question in &request.question_refs { + if let Err(err) = shard.ensure_question_binding(question) { + warn!( + question_id = %question.id, + paragraph_id = %paragraph.id, + error = %err, + "Failed to locate answer text in ingested content; recording empty chunk bindings" + ); + shard + .question_bindings + .insert(question.id.clone(), Vec::new()); + } + } + return Ok(shard); + } + Err(err) => { + warn!( + paragraph_id = %paragraph.id, + attempt, + max_attempts = INGESTION_MAX_RETRIES, + error = ?err, + "ingestion attempt failed for paragraph; retrying" + ); + last_err = Some(err.into()); + } + } + } + + Err(last_err + .unwrap_or_else(|| anyhow!("ingestion failed")) + 
/// Verifies that at least one gold answer for `question` appears in the
/// ingested text, and returns the ids of every chunk containing an answer.
///
/// Matching is ASCII-case-insensitive and is additionally retried on a
/// whitespace/punctuation-normalised form of both needle and haystack (via
/// `normalize_answer_text`) so tokenisation differences do not cause false
/// negatives. An answer found only in the full text (e.g. spanning a chunk
/// boundary) still counts as found, with no chunk ids recorded for it.
///
/// Returns an empty id list for unanswerable or answer-less questions.
///
/// # Errors
/// Fails when none of the answers can be located in either the full text or
/// any chunk.
fn validate_answers(
    content: &TextContent,
    chunks: &[TextChunk],
    question: &ConvertedQuestion,
) -> Result<Vec<String>> {
    if question.is_impossible || question.answers.is_empty() {
        return Ok(Vec::new());
    }

    // BTreeSet keeps the resulting chunk-id list deterministic and deduped.
    let mut matches = std::collections::BTreeSet::new();
    let mut found_any = false;
    let haystack = content.text.to_ascii_lowercase();
    let haystack_norm = normalize_answer_text(&haystack);
    for answer in &question.answers {
        let needle: String = answer.to_ascii_lowercase();
        let needle_norm = normalize_answer_text(&needle);
        let text_match = haystack.contains(&needle)
            || (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
        if text_match {
            found_any = true;
        }
        for chunk in chunks {
            let chunk_text = chunk.chunk.to_ascii_lowercase();
            let chunk_norm = normalize_answer_text(&chunk_text);
            if chunk_text.contains(&needle)
                || (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
            {
                matches.insert(chunk.id.clone());
                found_any = true;
            }
        }
    }

    if !found_any {
        Err(anyhow!(
            "expected answer for question '{}' was not found in ingested content",
            question.id
        ))
    } else {
        Ok(matches.into_iter().collect())
    }
}
/// Canonicalises answer text for fuzzy matching: ASCII-lowercases letters,
/// treats every non-alphanumeric character as a separator, and collapses
/// separator runs into single spaces with no leading/trailing space.
fn normalize_answer_text(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    let mut separator_pending = false;
    for ch in text.chars() {
        if ch.is_alphanumeric() {
            // Emit at most one space between alphanumeric runs, never before
            // the first run.
            if separator_pending && !normalized.is_empty() {
                normalized.push(' ');
            }
            separator_pending = false;
            normalized.push(ch.to_ascii_lowercase());
        } else {
            separator_pending = true;
        }
    }
    normalized
}
"q1".into(), + question: "?".into(), + answers: vec!["Alpha".into()], + is_impossible: false, + }; + let matches = validate_answers(&content, &[chunk], &question).expect("answers match"); + assert_eq!(matches, vec!["chunk1".to_string()]); + } + + #[test] + fn validate_answers_fails_when_missing() { + let question = ConvertedQuestion { + id: "q1".into(), + question: "?".into(), + answers: vec!["delta".into()], + is_impossible: false, + }; + let err = validate_answers( + &mock_text_content(), + &[mock_chunk("chunk", "alpha")], + &question, + ) + .expect_err("missing answer should fail"); + assert!(err.to_string().contains("not found")); + } +} diff --git a/eval/src/inspection.rs b/eval/src/inspection.rs new file mode 100644 index 0000000..e2fb353 --- /dev/null +++ b/eval/src/inspection.rs @@ -0,0 +1,182 @@ +use std::{ + collections::HashMap, + fs, + path::{Path, PathBuf}, +}; + +use anyhow::{anyhow, Context, Result}; +use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk}; + +use crate::{args::Config, eval::connect_eval_db, ingest, snapshot::DbSnapshotState}; + +pub async fn inspect_question(config: &Config) -> Result<()> { + let question_id = config + .inspect_question + .as_ref() + .ok_or_else(|| anyhow!("--inspect-question is required for inspection mode"))?; + let manifest_path = config + .inspect_manifest + .as_ref() + .ok_or_else(|| anyhow!("--inspect-manifest must be provided for inspection mode"))?; + + let manifest = load_manifest(manifest_path)?; + let chunk_lookup = build_chunk_lookup(&manifest); + + let question = manifest + .questions + .iter() + .find(|q| q.question_id == *question_id) + .ok_or_else(|| { + anyhow!( + "question '{}' not found in manifest {}", + question_id, + manifest_path.display() + ) + })?; + + println!("Question: {}", question.question_text); + println!("Answers: {:?}", question.answers); + println!( + "matching_chunk_ids ({}):", + question.matching_chunk_ids.len() + ); + + let mut missing_in_manifest = Vec::new(); 
/// Reads and deserialises an ingestion corpus manifest from disk.
///
/// # Errors
/// Fails when the file cannot be read or does not parse; both errors carry
/// the offending path for context.
// NOTE(review): the return type was reconstructed from the call site in
// `inspect_question`; confirm it is `Result<ingest::CorpusManifest>`.
fn load_manifest(path: &Path) -> Result<ingest::CorpusManifest> {
    let bytes =
        fs::read(path).with_context(|| format!("reading ingestion manifest {}", path.display()))?;
    serde_json::from_slice(&bytes)
        .with_context(|| format!("parsing ingestion manifest {}", path.display()))
}
build_chunk_lookup(manifest: &ingest::CorpusManifest) -> HashMap { + let mut lookup = HashMap::new(); + for paragraph in &manifest.paragraphs { + for chunk in ¶graph.chunks { + let snippet = chunk + .chunk + .chars() + .take(160) + .collect::() + .replace('\n', " "); + lookup.insert( + chunk.id.clone(), + ChunkEntry { + paragraph_title: paragraph.title.clone(), + snippet, + }, + ); + } + } + lookup +} + +fn default_state_path(config: &Config, manifest: &ingest::CorpusManifest) -> PathBuf { + config + .cache_dir + .join("snapshots") + .join(&manifest.metadata.dataset_id) + .join(&manifest.metadata.slice_id) + .join("db/state.json") +} + +fn load_db_state(path: &Path) -> Result> { + if !path.exists() { + return Ok(None); + } + let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?; + let state = serde_json::from_slice(&bytes) + .with_context(|| format!("parsing db state {}", path.display()))?; + Ok(Some(state)) +} + +enum MissingChunks { + None, + Missing(Vec), +} + +async fn verify_chunks_in_db(db: &SurrealDbClient, chunk_ids: &[String]) -> Result { + let mut missing = Vec::new(); + for chunk_id in chunk_ids { + let exists = db + .get_item::(chunk_id) + .await + .with_context(|| format!("fetching text_chunk {}", chunk_id))? 
+            .is_some();
+        if !exists {
+            missing.push(chunk_id.clone());
+        }
+    }
+    if missing.is_empty() {
+        Ok(MissingChunks::None)
+    } else {
+        Ok(MissingChunks::Missing(missing))
+    }
+}
diff --git a/eval/src/main.rs b/eval/src/main.rs
new file mode 100644
index 0000000..ed35a8f
--- /dev/null
+++ b/eval/src/main.rs
+mod args;
+mod cache;
+mod datasets;
+mod db_helpers;
+mod embedding;
+mod eval;
+mod ingest;
+mod inspection;
+mod openai;
+mod perf;
+mod report;
+mod slice;
+mod slices;
+mod snapshot;
+
+use anyhow::Context;
+use tokio::runtime::Builder;
+use tracing::info;
+use tracing_subscriber::{fmt, EnvFilter};
+
+/// Configure SurrealDB tuning knobs via environment variables, scaling each
+/// default with the available CPU count. Values the caller already exported
+/// take precedence over the derived defaults.
+///
+/// NOTE(review): this is invoked from `async_main`, i.e. after the
+/// multi-threaded runtime has started; `std::env::set_var` is not safe while
+/// other threads may read the environment — consider calling this from
+/// `main` before `Builder::build`. TODO confirm with the SurrealDB setup.
+fn configure_surrealdb_performance(cpu_count: usize) {
+    let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
+        .unwrap_or_else(|_| (cpu_count * 2).to_string());
+    std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", &indexing_batch_size);
+
+    let max_order_queue = std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE")
+        .unwrap_or_else(|_| (cpu_count * 4).to_string());
+    std::env::set_var(
+        "SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
+        &max_order_queue,
+    );
+
+    let websocket_concurrent = std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
+        .unwrap_or_else(|_| cpu_count.to_string());
+    std::env::set_var(
+        "SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
+        &websocket_concurrent,
+    );
+
+    let websocket_buffer = std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE")
+        .unwrap_or_else(|_| (cpu_count * 8).to_string());
+    std::env::set_var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE", &websocket_buffer);
+
+    let transaction_cache = std::env::var("SURREAL_TRANSACTION_CACHE_SIZE")
+        .unwrap_or_else(|_| (cpu_count * 16).to_string());
+    std::env::set_var("SURREAL_TRANSACTION_CACHE_SIZE", &transaction_cache);
+
+    // Log the effective values directly: the locals above are the source of
+    // truth, so there is no need to re-read the environment and unwrap.
+    info!(
+        indexing_batch_size = %indexing_batch_size,
+        max_order_queue = %max_order_queue,
+        websocket_concurrent = %websocket_concurrent,
+        websocket_buffer = %websocket_buffer,
+        transaction_cache = %transaction_cache,
+        "Configured SurrealDB performance variables"
+    );
+}
+
+fn main() -> anyhow::Result<()> {
+    // Query the available parallelism once and reuse it for both pools
+    // (the original called `available_parallelism()` twice).
+    let cpu_count = std::thread::available_parallelism()?.get();
+
+    // Create an explicit multi-threaded runtime with optimized configuration.
+    let runtime = Builder::new_multi_thread()
+        .enable_all()
+        .worker_threads(cpu_count)
+        .max_blocking_threads(cpu_count)
+        .thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
+        .thread_name("eval-retrieval-worker")
+        .build()
+        .context("failed to create tokio runtime")?;
+
+    runtime.block_on(async_main())
+}
+
+/// Entry point proper: parses CLI args, prepares the dataset, then either
+/// converts, inspects, grows a slice, or runs the full evaluation.
+async fn async_main() -> anyhow::Result<()> {
+    // Log runtime configuration
+    let cpu_count = std::thread::available_parallelism()?.get();
+    info!(
+        cpu_cores = cpu_count,
+        worker_threads = cpu_count,
+        blocking_threads = cpu_count,
+        thread_stack_size = "10MiB",
+        "Started multi-threaded tokio runtime"
+    );
+
+    // Configure SurrealDB environment variables for better performance
+    configure_surrealdb_performance(cpu_count);
+
+    let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
+    let _ = fmt()
+        .with_env_filter(EnvFilter::try_new(&filter).unwrap_or_else(|_| EnvFilter::new("info")))
+        .try_init();
+
+    let parsed = args::parse()?;
+
+    if parsed.show_help {
+        args::print_help();
+        return Ok(());
+    }
+
+    if parsed.config.inspect_question.is_some() {
+        inspection::inspect_question(&parsed.config).await?;
+        return Ok(());
+    }
+
+    let dataset_kind = parsed.config.dataset;
+
+    if parsed.config.convert_only {
+        info!(
+            dataset = dataset_kind.id(),
+            "Starting dataset conversion only run"
+        );
+        let dataset = crate::datasets::convert(
+            parsed.config.raw_dataset_path.as_path(),
+            dataset_kind,
+            parsed.config.llm_mode,
+            parsed.config.context_token_limit(),
+        )
+        .with_context(|| {
+            format!(
+                "converting {} dataset at {}",
+                dataset_kind.label(),
+                parsed.config.raw_dataset_path.display()
+            )
+        })?;
+        crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
+            .with_context(|| {
+                format!(
+                    "writing converted dataset to {}",
+                    parsed.config.converted_dataset_path.display()
+                )
+            })?;
+        println!(
+            "Converted dataset written to {}",
+            parsed.config.converted_dataset_path.display()
+        );
+        return Ok(());
+    }
+
+    info!(dataset = dataset_kind.id(), "Preparing converted dataset");
+    let dataset = crate::datasets::ensure_converted(
+        dataset_kind,
+        parsed.config.raw_dataset_path.as_path(),
+        parsed.config.converted_dataset_path.as_path(),
+        parsed.config.force_convert,
+        parsed.config.llm_mode,
+        parsed.config.context_token_limit(),
+    )
+    .with_context(|| {
+        format!(
+            "preparing converted dataset at {}",
+            parsed.config.converted_dataset_path.display()
+        )
+    })?;
+
+    info!(
+        questions = dataset
+            .paragraphs
+            .iter()
+            .map(|p| p.questions.len())
+            .sum::<usize>(),
+        paragraphs = dataset.paragraphs.len(),
+        dataset = dataset.metadata.id.as_str(),
+        "Dataset ready"
+    );
+
+    if parsed.config.slice_grow.is_some() {
+        eval::grow_slice(&dataset, &parsed.config)
+            .await
+            .context("growing slice ledger")?;
+        return Ok(());
+    }
+
+    info!("Running retrieval evaluation");
+    let summary = eval::run_evaluation(&dataset, &parsed.config)
+        .await
+        .context("running retrieval evaluation")?;
+
+    let report_paths = report::write_reports(
+        &summary,
+        parsed.config.report_dir.as_path(),
+        parsed.config.summary_sample,
+    )
+    .with_context(|| format!("writing reports to {}", parsed.config.report_dir.display()))?;
+    let perf_log_path = perf::write_perf_logs(
+        &summary,
+        parsed.config.report_dir.as_path(),
+        parsed.config.perf_log_json.as_deref(),
+        parsed.config.perf_log_dir.as_deref(),
+    )
+    .with_context(|| {
+        format!(
+            "writing perf logs under {}",
+            parsed.config.report_dir.display()
+        )
+    })?;
+
+    println!(
+        "[{}] Precision@{k}: {precision:.3} ({correct}/{total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
+        summary.dataset_label,
+        k = summary.k,
+        precision = summary.precision,
+        correct = summary.correct,
+        total = summary.total_cases,
+        json = report_paths.json.display(),
+        md = report_paths.markdown.display(),
+        perf = perf_log_path.display()
+    );
+
+    if parsed.config.perf_log_console {
+        perf::print_console_summary(&summary);
+    }
+
+    Ok(())
+}
diff --git a/eval/src/openai.rs b/eval/src/openai.rs
new file mode 100644
index 0000000..7c5e644
--- /dev/null
+++ b/eval/src/openai.rs
+use anyhow::{Context, Result};
+use async_openai::{config::OpenAIConfig, Client};
+
+const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
+
+/// Build an OpenAI client from `OPENAI_API_KEY` / `OPENAI_BASE_URL`, returning
+/// the client together with the effective base URL (for reporting).
+///
+/// # Errors
+/// Fails when `OPENAI_API_KEY` is unset.
+pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
+    let api_key = std::env::var("OPENAI_API_KEY")
+        .context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
+    let base_url =
+        std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
+
+    let config = OpenAIConfig::new()
+        .with_api_key(api_key)
+        .with_api_base(&base_url);
+    Ok((Client::with_config(config), base_url))
+}
diff --git a/eval/src/perf.rs b/eval/src/perf.rs
new file mode 100644
index 0000000..7315df2
--- /dev/null
+++ b/eval/src/perf.rs
+use std::{
+    fs::{self, OpenOptions},
+    io::Write,
+    path::{Path, PathBuf},
+};
+
+use anyhow::{Context, Result};
+use serde::Serialize;
+
+use crate::{
+    args,
+    eval::{format_timestamp, EvaluationStageTimings, EvaluationSummary},
+    report,
+};
+
+/// One JSONL record capturing the performance profile of a single run.
+#[derive(Debug, Serialize)]
+struct PerformanceLogEntry {
+    generated_at: String,
+    dataset_id: String,
+    dataset_label: String,
+    run_label: Option<String>,
+    slice_id: String,
+    slice_seed: u64,
+    slice_window_offset: usize,
+    slice_window_length: usize,
+    limit: Option<usize>,
+    total_cases: usize,
+    correct: usize,
+    precision: f64,
+    k: usize,
+    openai_base_url: String,
+    ingestion: IngestionPerf,
+    namespace: NamespacePerf,
+    retrieval: RetrievalPerf,
+    evaluation_stages: EvaluationStageTimings,
+}
+
+/// Ingestion-phase timings and cache provenance.
+#[derive(Debug, Serialize)]
+struct IngestionPerf {
+    duration_ms: u128,
+    cache_path: String,
+    reused: bool,
+    embeddings_reused: bool,
+    fingerprint: String,
+    positives_total: usize,
+    negatives_total: usize,
+}
+
+/// Namespace seeding timings (`seed_ms` is absent when the namespace was reused).
+#[derive(Debug, Serialize)]
+struct NamespacePerf {
+    reused: bool,
+    seed_ms: Option<u128>,
+}
+
+/// Retrieval-phase latency breakdown and configuration.
+#[derive(Debug, Serialize)]
+struct RetrievalPerf {
+    latency_ms: crate::eval::LatencyStats,
+    stage_latency: crate::eval::StageLatencyBreakdown,
+    concurrency: usize,
+    rerank_enabled: bool,
+    rerank_pool_size: Option<usize>,
+    rerank_keep_top: usize,
+    evaluated_cases: usize,
+}
+
+impl PerformanceLogEntry {
+    /// Flatten an `EvaluationSummary` into the log-entry shape.
+    fn from_summary(summary: &EvaluationSummary) -> Self {
+        let ingestion = IngestionPerf {
+            duration_ms: summary.perf.ingestion_ms,
+            cache_path: summary.ingestion_cache_path.clone(),
+            reused: summary.ingestion_reused,
+            embeddings_reused: summary.ingestion_embeddings_reused,
+            fingerprint: summary.ingestion_fingerprint.clone(),
+            positives_total: summary.slice_positive_paragraphs,
+            negatives_total: summary.slice_negative_paragraphs,
+        };
+
+        let namespace = NamespacePerf {
+            reused: summary.namespace_reused,
+            seed_ms: summary.perf.namespace_seed_ms,
+        };
+
+        let retrieval = RetrievalPerf {
+            latency_ms: summary.latency_ms.clone(),
+            stage_latency: summary.perf.stage_latency.clone(),
+            concurrency: summary.concurrency,
+            rerank_enabled: summary.rerank_enabled,
+            rerank_pool_size: summary.rerank_pool_size,
+            rerank_keep_top: summary.rerank_keep_top,
+            evaluated_cases: summary.total_cases,
+        };
+
+        Self {
+            generated_at: format_timestamp(&summary.generated_at),
+            dataset_id: summary.dataset_id.clone(),
+            dataset_label: summary.dataset_label.clone(),
+            run_label: summary.run_label.clone(),
+            slice_id: summary.slice_id.clone(),
+            slice_seed: summary.slice_seed,
+            slice_window_offset: summary.slice_window_offset,
+            slice_window_length: summary.slice_window_length,
+            limit: summary.limit,
+            total_cases: summary.total_cases,
+            correct: summary.correct,
+            precision: summary.precision,
+            k: summary.k,
+            openai_base_url: summary.perf.openai_base_url.clone(),
+            ingestion,
+            namespace,
+            retrieval,
+            evaluation_stages: summary.perf.evaluation_stage_ms.clone(),
+        }
+    }
+}
+
+/// Append one perf record to the dataset's `perf-log.jsonl`, and optionally
+/// mirror a pretty-printed copy to `extra_json` and/or a timestamped file
+/// under `extra_dir`. Returns the JSONL path.
+pub fn write_perf_logs(
+    summary: &EvaluationSummary,
+    report_root: &Path,
+    extra_json: Option<&Path>,
+    extra_dir: Option<&Path>,
+) -> Result<PathBuf> {
+    let entry = PerformanceLogEntry::from_summary(summary);
+    let dataset_dir = report::dataset_report_dir(report_root, &summary.dataset_id);
+    fs::create_dir_all(&dataset_dir)
+        .with_context(|| format!("creating dataset perf directory {}", dataset_dir.display()))?;
+
+    let log_path = dataset_dir.join("perf-log.jsonl");
+    let mut file = OpenOptions::new()
+        .create(true)
+        .append(true)
+        .open(&log_path)
+        .with_context(|| format!("opening perf log {}", log_path.display()))?;
+    let line = serde_json::to_vec(&entry).context("serialising perf log entry")?;
+    file.write_all(&line)?;
+    file.write_all(b"\n")?;
+    file.flush()?;
+
+    if let Some(path) = extra_json {
+        args::ensure_parent(path)?;
+        let blob = serde_json::to_vec_pretty(&entry).context("serialising perf log JSON")?;
+        fs::write(path, blob)
+            .with_context(|| format!("writing perf log copy to {}", path.display()))?;
+    }
+
+    if let Some(dir) = extra_dir {
+        fs::create_dir_all(dir)
+            .with_context(|| format!("creating perf log directory {}", dir.display()))?;
+        let dataset_slug = dataset_dir
+            .file_name()
+            .and_then(|os| os.to_str())
+            .unwrap_or("dataset");
+        let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S").to_string();
+        let filename = format!("perf-{}-{}.json", dataset_slug, timestamp);
+        let path = dir.join(filename);
+        let blob = serde_json::to_vec_pretty(&entry).context("serialising perf log JSON")?;
+        fs::write(&path, blob)
+            .with_context(|| format!("writing perf log mirror {}", path.display()))?;
+    }
+
+    Ok(log_path)
+}
+
+/// Print a compact perf summary to stdout (used with `--perf-log-console`).
+pub fn print_console_summary(summary: &EvaluationSummary) {
+    let perf = &summary.perf;
+    println!(
+        "[perf] ingestion={}ms | namespace_seed={}",
+        perf.ingestion_ms,
+        format_duration(perf.namespace_seed_ms),
+    );
+    let stage = &perf.stage_latency;
+    println!(
+        "[perf] stage avg ms → collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
+        stage.collect_candidates.avg,
+        stage.graph_expansion.avg,
+        stage.chunk_attach.avg,
+        stage.rerank.avg,
+        stage.assemble.avg,
+    );
+    let eval = &perf.evaluation_stage_ms;
+    println!(
+        "[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
+        eval.prepare_slice_ms,
+        eval.prepare_db_ms,
+        eval.prepare_corpus_ms,
+        eval.prepare_namespace_ms,
+        eval.run_queries_ms,
+        eval.summarize_ms,
+        eval.finalize_ms,
+    );
+}
+
+/// Render an optional millisecond count, using "-" when absent.
+fn format_duration(value: Option<u128>) -> String {
+    value
+        .map(|ms| format!("{ms}ms"))
+        .unwrap_or_else(|| "-".to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::eval::{EvaluationStageTimings, PerformanceTimings};
+    use chrono::Utc;
+    use tempfile::tempdir;
+
+    fn sample_latency() -> crate::eval::LatencyStats {
+        crate::eval::LatencyStats {
+            avg: 10.0,
+            p50: 8,
+            p95: 15,
+        }
+    }
+
+    fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
+        crate::eval::StageLatencyBreakdown {
+            collect_candidates: sample_latency(),
+            graph_expansion: sample_latency(),
+            chunk_attach: sample_latency(),
+            rerank: sample_latency(),
+            assemble: sample_latency(),
+        }
+    }
+
+    fn sample_eval_stage() -> EvaluationStageTimings {
+        EvaluationStageTimings {
+            prepare_slice_ms: 10,
+            prepare_db_ms: 20,
+            prepare_corpus_ms: 30,
+            prepare_namespace_ms: 40,
+            run_queries_ms: 50,
+            summarize_ms: 60,
+            finalize_ms: 70,
+        }
+    }
+
+    fn sample_summary() -> EvaluationSummary {
+        EvaluationSummary {
+            generated_at: Utc::now(),
+            k: 5,
+            limit: Some(10),
+            run_label: Some("test".into()),
+            total_cases: 2,
+            correct: 1,
+            precision: 0.5,
+            correct_at_1: 1,
+            correct_at_2: 1,
+            correct_at_3: 1,
+            precision_at_1: 0.5,
+            precision_at_2: 0.5,
+            precision_at_3: 0.5,
+            duration_ms: 1234,
+            dataset_id: "squad-v2".into(),
+            dataset_label: "SQuAD v2".into(),
+            dataset_includes_unanswerable: false,
+            dataset_source: "dev".into(),
+            slice_id: "slice123".into(),
+            slice_seed: 42,
+            slice_total_cases: 400,
+            slice_window_offset: 0,
+            slice_window_length: 10,
+            slice_cases: 10,
+            slice_positive_paragraphs: 10,
+            slice_negative_paragraphs: 40,
+            slice_total_paragraphs: 50,
+            slice_negative_multiplier: 4.0,
+            namespace_reused: true,
+            corpus_paragraphs: 50,
+            ingestion_cache_path: "/tmp/cache".into(),
+            ingestion_reused: true,
+            ingestion_embeddings_reused: true,
+            ingestion_fingerprint: "fingerprint".into(),
+            positive_paragraphs_reused: 10,
+            negative_paragraphs_reused: 40,
+            latency_ms: sample_latency(),
+            perf: PerformanceTimings {
+                openai_base_url: "https://example.com".into(),
+                ingestion_ms: 1000,
+                namespace_seed_ms: Some(150),
+                evaluation_stage_ms: sample_eval_stage(),
+                stage_latency: sample_stage_latency(),
+            },
+            embedding_backend: "fastembed".into(),
+            embedding_model: Some("test-model".into()),
+            embedding_dimension: 32,
+            rerank_enabled: true,
+            rerank_pool_size: Some(4),
+            rerank_keep_top: 10,
+            concurrency: 2,
+            detailed_report: false,
+            chunk_vector_take: 20,
+            chunk_fts_take: 20,
+            chunk_token_budget: 10000,
+            chunk_avg_chars_per_token: 4,
+            max_chunks_per_entity: 4,
+            cases: Vec::new(),
+        }
+    }
+
+    #[test]
+    fn writes_perf_log_jsonl() {
+        let tmp = tempdir().unwrap();
+        let report_root = tmp.path().join("reports");
+        let summary = sample_summary();
+        let log_path = write_perf_logs(&summary, &report_root, None, None).expect("perf log write");
+        assert!(log_path.exists());
+        let contents = std::fs::read_to_string(&log_path).expect("reading perf log jsonl");
+        assert!(
+            contents.contains("\"openai_base_url\":\"https://example.com\""),
+            "serialized log should include base URL"
+        );
+        let dataset_dir = report::dataset_report_dir(&report_root, &summary.dataset_id);
+        assert!(dataset_dir.join("perf-log.jsonl").exists());
+    }
+}
diff --git a/eval/src/report.rs b/eval/src/report.rs
new file mode 100644
index 0000000..b6990af
--- /dev/null
+++ b/eval/src/report.rs
+use std::{
+    fs,
+    path::{Path, PathBuf},
+};
+
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+
+use crate::eval::{format_timestamp, CaseSummary, EvaluationSummary, LatencyStats};
+
+/// Filesystem locations of the two reports produced for a run.
+#[derive(Debug)]
+pub struct ReportPaths {
+    pub json: PathBuf,
+    pub markdown: PathBuf,
+}
+
+/// Write the JSON and Markdown reports (plus `latest.*` pointers and the
+/// rolling history log) for `summary` under `report_dir`. `sample` caps how
+/// many missed queries the Markdown report lists.
+pub fn write_reports(
+    summary: &EvaluationSummary,
+    report_dir: &Path,
+    sample: usize,
+) -> Result<ReportPaths> {
+    fs::create_dir_all(report_dir)
+        .with_context(|| format!("creating report directory {}", report_dir.display()))?;
+    let dataset_dir = dataset_report_dir(report_dir, &summary.dataset_id);
+    fs::create_dir_all(&dataset_dir).with_context(|| {
+        format!(
+            "creating dataset report directory {}",
+            dataset_dir.display()
+        )
+    })?;
+
+    let stem = build_report_stem(summary);
+
+    let json_path = dataset_dir.join(format!("{stem}.json"));
+    let json_blob = serde_json::to_string_pretty(summary).context("serialising JSON report")?;
+    fs::write(&json_path, &json_blob)
+        .with_context(|| format!("writing JSON report to {}", json_path.display()))?;
+
+    let md_path = dataset_dir.join(format!("{stem}.md"));
+    let markdown = render_markdown(summary, sample);
+    fs::write(&md_path, &markdown)
+        .with_context(|| format!("writing Markdown report to {}", md_path.display()))?;
+
+    // Keep a latest.json pointer to simplify automation.
+    let latest_json = dataset_dir.join("latest.json");
+    fs::write(&latest_json, json_blob)
+        .with_context(|| format!("writing latest JSON report to {}", latest_json.display()))?;
+    let latest_md = dataset_dir.join("latest.md");
+    fs::write(&latest_md, markdown)
+        .with_context(|| format!("writing latest Markdown report to {}", latest_md.display()))?;
+
+    record_history(summary, &dataset_dir)?;
+
+    Ok(ReportPaths {
+        json: json_path,
+        markdown: md_path,
+    })
+}
+
+/// Render the Markdown report: a metric table, per-stage timings, and a
+/// sample of missed queries (up to `sample` rows).
+fn render_markdown(summary: &EvaluationSummary, sample: usize) -> String {
+    let mut md = String::new();
+
+    md.push_str(&format!("# Retrieval Precision@{}\n\n", summary.k));
+    md.push_str("| Metric | Value |\n");
+    md.push_str("| --- | --- |\n");
+    md.push_str(&format!(
+        "| Generated | {} |\n",
+        format_timestamp(&summary.generated_at)
+    ));
+    md.push_str(&format!(
+        "| Dataset | {} (`{}`) |\n",
+        summary.dataset_label, summary.dataset_id
+    ));
+    md.push_str(&format!(
+        "| Run Label | {} |\n",
+        summary
+            .run_label
+            .as_deref()
+            .filter(|label| !label.is_empty())
+            .unwrap_or("-")
+    ));
+    md.push_str(&format!(
+        "| Unanswerable Included | {} |\n",
+        if summary.dataset_includes_unanswerable {
+            "yes"
+        } else {
+            "no"
+        }
+    ));
+    md.push_str(&format!(
+        "| Dataset Source | {} |\n",
+        summary.dataset_source
+    ));
+    md.push_str(&format!(
+        "| OpenAI Base URL | {} |\n",
+        summary.perf.openai_base_url
+    ));
+    md.push_str(&format!("| Slice ID | `{}` |\n", summary.slice_id));
+    md.push_str(&format!("| Slice Seed | {} |\n", summary.slice_seed));
+    md.push_str(&format!(
+        "| Slice Total Questions | {} |\n",
+        summary.slice_total_cases
+    ));
+    md.push_str(&format!(
+        "| Slice Window (offset/length) | {}/{} |\n",
+        summary.slice_window_offset, summary.slice_window_length
+    ));
+    md.push_str(&format!(
+        "| Slice Window Questions | {} |\n",
+        summary.slice_cases
+    ));
+    md.push_str(&format!(
+        "| Slice Negatives | {} |\n",
+        summary.slice_negative_paragraphs
+    ));
+    md.push_str(&format!(
+        "| Slice Total Paragraphs | {} |\n",
+        summary.slice_total_paragraphs
+    ));
+    md.push_str(&format!(
+        "| Slice Negative Multiplier | {:.2} |\n",
+        summary.slice_negative_multiplier
+    ));
+    md.push_str(&format!(
+        "| Namespace State | {} |\n",
+        if summary.namespace_reused {
+            "reused"
+        } else {
+            "seeded"
+        }
+    ));
+    md.push_str(&format!(
+        "| Corpus Paragraphs | {} |\n",
+        summary.corpus_paragraphs
+    ));
+    md.push_str(&format!(
+        "| Ingestion Duration | {} ms |\n",
+        summary.perf.ingestion_ms
+    ));
+    if let Some(seed) = summary.perf.namespace_seed_ms {
+        md.push_str(&format!("| Namespace Seed | {} ms |\n", seed));
+    }
+    if summary.detailed_report {
+        md.push_str(&format!(
+            "| Ingestion Cache | `{}` |\n",
+            summary.ingestion_cache_path
+        ));
+        md.push_str(&format!(
+            "| Ingestion Reused | {} |\n",
+            if summary.ingestion_reused { "yes" } else { "no" }
+        ));
+        md.push_str(&format!(
+            "| Embeddings Reused | {} |\n",
+            if summary.ingestion_embeddings_reused {
+                "yes"
+            } else {
+                "no"
+            }
+        ));
+    }
+    // Use the `\n` escape like every other row; the original embedded a raw
+    // newline inside these two literals, which rendered identically but read
+    // inconsistently.
+    md.push_str(&format!(
+        "| Positives Cached | {} |\n",
+        summary.positive_paragraphs_reused
+    ));
+    md.push_str(&format!(
+        "| Negatives Cached | {} |\n",
+        summary.negative_paragraphs_reused
+    ));
+    let embedding_label = if let Some(model) = summary.embedding_model.as_ref() {
+        format!("{} ({model})", summary.embedding_backend)
+    } else {
+        summary.embedding_backend.clone()
+    };
+    md.push_str(&format!("| Embedding | {} |\n", embedding_label));
+    md.push_str(&format!(
+        "| Embedding Dim | {} |\n",
+        summary.embedding_dimension
+    ));
+    if let Some(limit) = summary.limit {
+        md.push_str(&format!(
+            "| Evaluated Queries | {} (limit {}) |\n",
+            summary.total_cases, limit
+        ));
+    } else {
+        md.push_str(&format!(
+            "| Evaluated Queries | {} |\n",
+            summary.total_cases
+        ));
+    }
+    if summary.rerank_enabled {
+        let pool = summary
+            .rerank_pool_size
+            .map(|size| size.to_string())
+            .unwrap_or_else(|| "?".to_string());
+        md.push_str(&format!(
+            "| Rerank | enabled (pool {pool}, keep top {}) |\n",
+            summary.rerank_keep_top
+        ));
+    } else {
+        md.push_str("| Rerank | disabled |\n");
+    }
+    md.push_str(&format!("| Concurrency | {} |\n", summary.concurrency));
+    md.push_str(&format!(
+        "| Correct@{} | {}/{} |\n",
+        summary.k, summary.correct, summary.total_cases
+    ));
+    md.push_str(&format!(
+        "| Precision@{} | {:.3} |\n",
+        summary.k, summary.precision
+    ));
+    md.push_str(&format!(
+        "| Precision@1 | {:.3} |\n",
+        summary.precision_at_1
+    ));
+    md.push_str(&format!(
+        "| Precision@2 | {:.3} |\n",
+        summary.precision_at_2
+    ));
+    md.push_str(&format!(
+        "| Precision@3 | {:.3} |\n",
+        summary.precision_at_3
+    ));
+    md.push_str(&format!("| Duration | {} ms |\n", summary.duration_ms));
+    md.push_str(&format!(
+        "| Latency Avg (ms) | {:.1} |\n",
+        summary.latency_ms.avg
+    ));
+    md.push_str(&format!(
+        "| Latency P50 (ms) | {} |\n",
+        summary.latency_ms.p50
+    ));
+    md.push_str(&format!(
+        "| Latency P95 (ms) | {} |\n",
+        summary.latency_ms.p95
+    ));
+
+    md.push_str("\n## Retrieval Stage Timings\n\n");
+    md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\n");
+    md.push_str("| --- | --- | --- | --- |\n");
+    write_stage_row(
+        &mut md,
+        "Collect Candidates",
+        &summary.perf.stage_latency.collect_candidates,
+    );
+    write_stage_row(
+        &mut md,
+        "Graph Expansion",
+        &summary.perf.stage_latency.graph_expansion,
+    );
+    write_stage_row(
+        &mut md,
+        "Chunk Attach",
+        &summary.perf.stage_latency.chunk_attach,
+    );
+    write_stage_row(&mut md, "Rerank", &summary.perf.stage_latency.rerank);
+    write_stage_row(&mut md, "Assemble", &summary.perf.stage_latency.assemble);
+
+    let misses: Vec<&CaseSummary> = summary.cases.iter().filter(|case| !case.matched).collect();
+    if !misses.is_empty() {
+        md.push_str("\n## Missed Queries (sample)\n\n");
+        if summary.detailed_report {
+            md.push_str(
+                "| Question ID | Paragraph | Expected Source | Entity Match | Chunk Text | Chunk ID | Top Retrieved |\n",
+            );
+            md.push_str("| --- | --- | --- | --- | --- | --- | --- |\n");
+        } else {
+            md.push_str("| Question ID | Paragraph | Expected Source | Top Retrieved |\n");
+            md.push_str("| --- | --- | --- | --- |\n");
+        }
+
+        for case in misses.iter().take(sample) {
+            // Join with <br> so multiple hits stay inside one table cell
+            // (a raw newline would terminate the Markdown row).
+            let retrieved = case
+                .retrieved
+                .iter()
+                .map(|entry| format!("{} (rank {})", entry.source_id, entry.rank))
+                .take(3)
+                .collect::<Vec<_>>()
+                .join("<br>");
+            if summary.detailed_report {
+                md.push_str(&format!(
+                    "| `{}` | {} | `{}` | {} | {} | {} | {} |\n",
+                    case.question_id,
+                    case.paragraph_title,
+                    case.expected_source,
+                    bool_badge(case.entity_match),
+                    bool_badge(case.chunk_text_match),
+                    bool_badge(case.chunk_id_match),
+                    retrieved
+                ));
+            } else {
+                md.push_str(&format!(
+                    "| `{}` | {} | `{}` | {} |\n",
+                    case.question_id, case.paragraph_title, case.expected_source, retrieved
+                ));
+            }
+        }
+    } else {
+        md.push_str("\n_All evaluated queries matched within the top-k window._\n");
+        if summary.detailed_report {
+            md.push_str(
+                "\nSuccess measures were captured for each query (entity, chunk text, chunk ID).\n",
+            );
+        }
+    }
+
+    md
+}
+
+/// Append one timing row to the stage-timings table.
+fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
+    buf.push_str(&format!(
+        "| {} | {:.1} | {} | {} |\n",
+        label, stats.avg, stats.p50, stats.p95
+    ));
+}
+
+/// Emoji badge for a boolean success measure.
+fn bool_badge(value: bool) -> &'static str {
+    if value {
+        "✅"
+    } else {
+        "⚪"
+    }
+}
+
+/// Build the report filename stem from k, dataset, timestamp and embedding.
+fn build_report_stem(summary: &EvaluationSummary) -> String {
+    let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S");
+    let backend = sanitize_component(&summary.embedding_backend);
+    let dataset_component = sanitize_component(&summary.dataset_id);
+    let model_component = summary
+        .embedding_model
+        .as_ref()
+        .map(|model| sanitize_component(model));
+
+    match model_component {
+        Some(model) => format!(
+            "precision_at_{}_{}_{}_{}_{}",
+            summary.k, dataset_component, timestamp, backend, model
+        ),
+        None => format!(
+            "precision_at_{}_{}_{}_{}",
+            summary.k, dataset_component, timestamp, backend
+        ),
+    }
+}
+
+/// Replace every non-alphanumeric character with `_` for safe filenames.
+fn sanitize_component(input: &str) -> String {
+    input
+        .chars()
+        .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '_' })
+        .collect()
+}
+
+/// Per-dataset report directory under `report_dir`.
+pub fn dataset_report_dir(report_dir: &Path, dataset_id: &str) -> PathBuf {
+    report_dir.join(sanitize_component(dataset_id))
+}
+
+/// One row in the rolling `evaluations.json` history log.
+#[derive(Debug, Serialize, Deserialize)]
+struct HistoryEntry {
+    generated_at: String,
+    run_label: Option<String>,
+    dataset_id: String,
+    dataset_label: String,
+    slice_id: String,
+    slice_seed: u64,
+    slice_window_offset: usize,
+    slice_window_length: usize,
+    slice_cases: usize,
+    slice_total_cases: usize,
+    k: usize,
+    limit: Option<usize>,
+    precision: f64,
+    precision_at_1: f64,
+    precision_at_2: f64,
+    precision_at_3: f64,
+    duration_ms: u128,
+    latency_ms: LatencyStats,
+    embedding_backend: String,
+    embedding_model: Option<String>,
+    ingestion_reused: bool,
+    ingestion_embeddings_reused: bool,
+    rerank_enabled: bool,
+    rerank_keep_top: usize,
+    rerank_pool_size: Option<usize>,
+    delta: Option<HistoryDelta>,
+    openai_base_url: String,
+    ingestion_ms: u128,
+    #[serde(default)]
+    namespace_seed_ms: Option<u128>,
+}
+
+/// Differences against the previous history entry.
+#[derive(Debug, Serialize, Deserialize)]
+struct HistoryDelta {
+    precision: f64,
+    precision_at_1: f64,
+    latency_avg_ms: f64,
+}
+
+/// Append this run to the dataset's `evaluations.json`, computing deltas
+/// against the previous entry when one exists.
+fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> {
+    let path = report_dir.join("evaluations.json");
+    let mut entries: Vec<HistoryEntry> = if path.exists() {
+        let contents = fs::read(&path)
+            .with_context(|| format!("reading evaluation log {}", path.display()))?;
+        // Best effort: a corrupt log starts a fresh history instead of
+        // aborting the run.
+        serde_json::from_slice(&contents).unwrap_or_default()
+    } else {
+        Vec::new()
+    };
+
+    let delta = entries.last().map(|prev| HistoryDelta {
+        precision: summary.precision - prev.precision,
+        precision_at_1: summary.precision_at_1 - prev.precision_at_1,
+        latency_avg_ms: summary.latency_ms.avg - prev.latency_ms.avg,
+    });
+
+    let entry = HistoryEntry {
+        generated_at: format_timestamp(&summary.generated_at),
+        run_label: summary.run_label.clone(),
+        dataset_id: summary.dataset_id.clone(),
+        dataset_label: summary.dataset_label.clone(),
+        slice_id: summary.slice_id.clone(),
+        slice_seed: summary.slice_seed,
+        slice_window_offset: summary.slice_window_offset,
+        slice_window_length: summary.slice_window_length,
+        slice_cases: summary.slice_cases,
+        slice_total_cases: summary.slice_total_cases,
+        k: summary.k,
+        limit: summary.limit,
+        precision: summary.precision,
+        precision_at_1: summary.precision_at_1,
+        precision_at_2: summary.precision_at_2,
+        precision_at_3: summary.precision_at_3,
+        duration_ms: summary.duration_ms,
+        latency_ms: summary.latency_ms.clone(),
+        embedding_backend: summary.embedding_backend.clone(),
+        embedding_model: summary.embedding_model.clone(),
+        ingestion_reused: summary.ingestion_reused,
+        ingestion_embeddings_reused: summary.ingestion_embeddings_reused,
+        rerank_enabled: summary.rerank_enabled,
+        rerank_keep_top: summary.rerank_keep_top,
+        rerank_pool_size: summary.rerank_pool_size,
+        delta,
+        openai_base_url: summary.perf.openai_base_url.clone(),
+        ingestion_ms: summary.perf.ingestion_ms,
+        namespace_seed_ms: summary.perf.namespace_seed_ms,
+    };
+
+    entries.push(entry);
+
+    let blob = serde_json::to_vec_pretty(&entries).context("serialising evaluation log")?;
+    fs::write(&path, blob).with_context(|| format!("writing evaluation log {}", path.display()))?;
+    Ok(())
+}
diff --git a/eval/src/slice.rs b/eval/src/slice.rs
new file mode 100644
index 0000000..3bf456e
--- /dev/null
+++ b/eval/src/slice.rs
+use crate::slices::SliceConfig as CoreSliceConfig;
+
+pub use crate::slices::*;
+
+use crate::args::Config;
+
+impl<'a> From<&'a Config> for CoreSliceConfig<'a> {
+    fn from(config: &'a Config) -> Self {
+        slice_config_with_limit(config, None)
+    }
+}
+
+/// Build a `SliceConfig` from the CLI `Config`, optionally overriding the
+/// case limit (used by callers that window the slice themselves).
+pub fn slice_config_with_limit<'a>(
+    config: &'a Config,
+    limit_override: Option<usize>,
+) -> CoreSliceConfig<'a> {
+    CoreSliceConfig {
+        cache_dir: config.cache_dir.as_path(),
+        force_convert: config.force_convert,
+        explicit_slice: config.slice.as_deref(),
+        limit: limit_override.or(config.limit),
+        corpus_limit: config.corpus_limit,
+        slice_seed: config.slice_seed,
+        llm_mode: config.llm_mode,
+        negative_multiplier: config.negative_multiplier,
+    }
+}
diff --git a/eval/src/slices.rs b/eval/src/slices.rs
new file mode 100644
index 0000000..8c4231c
--- /dev/null
+++ b/eval/src/slices.rs
+use std::{
+    collections::{HashMap, HashSet},
+    fs,
+    path::{Path, PathBuf},
+};
+
+use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
+use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use tracing::{info, warn};
+
+use crate::datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion};
+
+const SLICE_VERSION: u32 = 2;
+pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
+
+/// Inputs controlling how a slice ledger is located, grown and windowed.
+#[derive(Debug, Clone)]
+pub struct SliceConfig<'a> {
+    pub cache_dir: &'a Path,
+    pub force_convert: bool,
+    pub explicit_slice: Option<&'a str>,
+    pub limit: Option<usize>,
+    pub corpus_limit: Option<usize>,
+    pub slice_seed: u64,
+    pub llm_mode: bool,
+    pub negative_multiplier: f32,
+}
+
+/// On-disk ledger describing a reproducible evaluation slice.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SliceManifest {
+    pub version: u32,
+    pub slice_id: String,
+    pub dataset_id: String,
+    pub dataset_label: String,
+    pub dataset_source: String,
+    pub includes_unanswerable: bool,
+    pub seed: u64,
+    pub requested_limit: Option<usize>,
+    pub requested_corpus: usize,
+    pub generated_at: DateTime<Utc>,
+    pub case_count: usize,
+    pub positive_paragraphs: usize,
+    pub negative_paragraphs: usize,
+    pub total_paragraphs: usize,
+    pub negative_multiplier: f32,
+    pub cases: Vec<SliceCaseEntry>,
+    pub paragraphs: Vec<SliceParagraphEntry>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SliceCaseEntry {
+    pub question_id: String,
+    pub paragraph_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SliceParagraphEntry {
+    pub id: String,
+    pub kind: SliceParagraphKind,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub shard_path: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "kind", rename_all = "snake_case")]
+pub enum SliceParagraphKind {
+    Positive { question_ids: Vec<String> },
+    Negative,
+}
+
+/// Default shard location for a paragraph's cached payload.
+pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
+    let sanitized = sanitize_identifier(paragraph_id);
+    format!("paragraphs/{sanitized}.json")
+}
+
+/// Map an arbitrary identifier onto a filesystem-safe name; falls back to a
+/// short SHA-256 hex prefix when nothing alphanumeric survives trimming.
+fn sanitize_identifier(input: &str) -> String {
+    let mut sanitized = String::with_capacity(input.len());
+    for ch in input.chars() {
+        if ch.is_ascii_alphanumeric() {
+            sanitized.push(ch);
+        } else {
+            sanitized.push('-');
+        }
+    }
+    let trimmed = sanitized.trim_matches('-').to_string();
+    if trimmed.is_empty() {
+        let mut hasher = Sha256::new();
+        hasher.update(input.as_bytes());
+        let digest = hasher.finalize();
+        digest[..6]
+            .iter()
+            .map(|byte| format!("{byte:02x}"))
+            .collect::<String>()
+    } else {
+        trimmed
+    }
+}
+
+/// A slice manifest resolved against an in-memory dataset.
+#[derive(Debug, Clone)]
+pub struct ResolvedSlice<'a> {
+    pub manifest: SliceManifest,
+    pub path: PathBuf,
+    pub paragraphs: Vec<&'a ConvertedParagraph>,
+    pub cases: Vec<CaseRef<'a>>,
+}
+
+/// A contiguous window of cases within a resolved slice.
+#[derive(Debug, Clone)]
+pub struct SliceWindow<'a> {
+    pub offset: usize,
+    pub length: usize,
+    pub total_cases: usize,
+    pub cases: Vec<CaseRef<'a>>,
+    positive_paragraph_ids: Vec<String>,
+}
+
+impl<'a> SliceWindow<'a> {
+    /// Deduplicated paragraph ids that appear as positives in this window.
+    pub fn positive_ids(&self) -> impl Iterator<Item = &str> + '_ {
+        self.positive_paragraph_ids.iter().map(|id| id.as_str())
+    }
+}
+
+/// One evaluation case: a question and the paragraph that answers it.
+#[derive(Debug, Clone)]
+pub struct CaseRef<'a> {
+    pub paragraph: &'a ConvertedParagraph,
+    pub question: &'a ConvertedQuestion,
+}
+
+/// Lookup index from paragraph/question ids to dataset positions.
+struct DatasetIndex {
+    paragraph_by_id: HashMap<String, usize>,
+    question_by_id: HashMap<String, (usize, usize)>,
+}
+
+impl DatasetIndex {
+    fn build(dataset: &ConvertedDataset) -> Self {
+        let mut paragraph_by_id = HashMap::new();
+        let mut question_by_id = HashMap::new();
+
+        for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
+            paragraph_by_id.insert(paragraph.id.clone(), p_idx);
+            for (q_idx, question) in paragraph.questions.iter().enumerate() {
+                question_by_id.insert(question.id.clone(), (p_idx, q_idx));
+            }
+        }
+
+        Self {
+            paragraph_by_id,
+            question_by_id,
+        }
+    }
+
+    fn paragraph<'a>(
+        &self,
+        dataset: &'a ConvertedDataset,
+        id: &str,
+    ) -> Result<&'a ConvertedParagraph> {
+        let idx = self
+            .paragraph_by_id
+            .get(id)
+            .ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?;
+        Ok(&dataset.paragraphs[*idx])
+    }
+
+    fn question<'a>(
+        &self,
+        dataset: &'a ConvertedDataset,
+        question_id: &str,
+    ) -> Result<(&'a ConvertedParagraph, &'a ConvertedQuestion)> {
+        let (p_idx, q_idx) = self
+            .question_by_id
+            .get(question_id)
+            .ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?;
+        let paragraph = &dataset.paragraphs[*p_idx];
+        let question = paragraph
+            .questions
+            .get(*q_idx)
+            .ok_or_else(|| anyhow!("slice maps question '{question_id}' to missing index"))?;
+        Ok((paragraph, question))
+    }
+}
+
+/// The fields that deterministically identify a slice.
+#[derive(Debug, Serialize)]
+struct SliceKey<'a> {
+    dataset_id: &'a str,
+    includes_unanswerable: bool,
+    requested_corpus: usize,
+    seed: u64,
+}
+
+/// Parameters used while (re)building a slice ledger.
+#[derive(Debug)]
+struct BuildParams {
+    include_impossible: bool,
+    base_seed: u64,
+    rng_seed: u64,
+}
+
+/// Resolve (load, validate and, if necessary, grow) the slice ledger for
+/// `dataset` according to `config`. An explicitly named slice bypasses the
+/// derived cache lookup entirely.
+pub fn resolve_slice<'a>(
+    dataset: &'a ConvertedDataset,
+    config: &SliceConfig<'_>,
+) -> Result<ResolvedSlice<'a>> {
+    let index = DatasetIndex::build(dataset);
+
+    if let Some(slice_arg) = config.explicit_slice {
+        let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
+        let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
+        info!(
+            slice = %resolved.manifest.slice_id,
+            path = %resolved.path.display(),
+            cases = resolved.manifest.case_count,
+            positives = resolved.manifest.positive_paragraphs,
+            negatives = resolved.manifest.negative_paragraphs,
+            "Using explicitly selected slice"
+        );
+        return Ok(resolved);
+    }
+
+    // Clamp the requested corpus into [1, dataset size].
+    let requested_corpus = config
+        .corpus_limit
+        .unwrap_or(dataset.paragraphs.len())
+        .min(dataset.paragraphs.len())
+        .max(1);
+    let key = SliceKey {
+        dataset_id: dataset.metadata.id.as_str(),
+        includes_unanswerable: dataset.metadata.include_unanswerable,
+        requested_corpus,
+        seed: config.slice_seed,
+    };
+    let slice_id = compute_slice_id(&key);
+    let base = config
+        .cache_dir
+        .join("slices")
+        .join(dataset.metadata.id.as_str());
+    let path = base.join(format!("{slice_id}.json"));
+
+    let total_questions = dataset
+        .paragraphs
+        .iter()
+        .map(|p| p.questions.len())
+        .sum::<usize>()
+        .max(1);
+    let requested_limit = config
+        .limit
+        .unwrap_or(total_questions)
+        .min(total_questions)
+        .max(1);
+
+    // Try the cached manifest; any mismatch (dataset, unanswerable flag,
+    // read error) falls back to regeneration with a warning.
+    let mut manifest = if !config.force_convert && path.exists() {
+        match read_manifest(&path) {
+            Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
+                if manifest.includes_unanswerable != dataset.metadata.include_unanswerable {
+                    warn!(
+                        slice = manifest.slice_id,
+                        path = %path.display(),
+                        "Slice manifest includes_unanswerable mismatch; regenerating"
+                    );
+                    None
+                } else {
+                    Some(manifest)
+                }
+            }
+            Ok(manifest) => {
+                warn!(
+                    slice = manifest.slice_id,
+                    path = %path.display(),
+                    loaded_dataset = %manifest.dataset_id,
+                    expected = %dataset.metadata.id,
+                    "Slice manifest targets different dataset; regenerating"
+                );
+                None
+            }
+            Err(err) => {
+                warn!(
+                    path = %path.display(),
+                    error = %err,
+                    "Failed to read cached slice; regenerating"
+                );
+                None
+            }
+        }
+    } else {
+        None
+    };
+
+    let params = BuildParams {
+        include_impossible: config.llm_mode,
+        base_seed: config.slice_seed,
+        rng_seed: mix_seed(dataset.metadata.id.as_str(), config.slice_seed),
+    };
+
+    // A version bump invalidates any cached manifest.
+    if manifest
+        .as_ref()
+        .map(|manifest| manifest.version != SLICE_VERSION)
+        .unwrap_or(false)
+    {
+        warn!(
+            slice = manifest
+                .as_ref()
+                .map(|m| m.slice_id.as_str())
+                .unwrap_or("unknown"),
+            found = manifest.as_ref().map(|m| m.version).unwrap_or(0),
+            expected = SLICE_VERSION,
+            "Slice manifest version mismatch; regenerating"
+        );
+        manifest = None;
+    }
+
+    let mut manifest = manifest.unwrap_or_else(|| {
+        empty_manifest(
+            dataset,
+            slice_id.clone(),
+            &params,
+            requested_corpus,
+            config.negative_multiplier,
+            config.limit,
+        )
+    });
+
+    manifest.requested_limit = config.limit;
+    manifest.requested_corpus = requested_corpus;
+    manifest.negative_multiplier = config.negative_multiplier;
+
+    // Grow the ledger as needed; `changed` tracks whether we must rewrite it.
+    let mut changed = ensure_shard_paths(&mut manifest);
+
+    changed |= ensure_case_capacity(dataset, &mut manifest, &params, requested_limit)?;
+    refresh_manifest_stats(&mut manifest);
+
+    let desired_negatives = desired_negative_target(
+        manifest.positive_paragraphs,
+        requested_corpus,
+        dataset.paragraphs.len(),
+        config.negative_multiplier,
+    );
+    changed |= ensure_negative_pool(
+        dataset,
+        &mut manifest,
+        &params,
+        desired_negatives,
+        requested_corpus,
+    )?;
+    refresh_manifest_stats(&mut manifest);
+
+    if changed {
+        manifest.generated_at = Utc::now();
+        write_manifest(&path, &manifest)?;
+        info!(
+            slice = %manifest.slice_id,
+            path = %path.display(),
+            cases = manifest.case_count,
+            positives = manifest.positive_paragraphs,
+            negatives = manifest.negative_paragraphs,
+            "Updated dataset slice ledger"
+        );
+    } else {
+        info!(
+            slice = %manifest.slice_id,
+            path = %path.display(),
+            cases = manifest.case_count,
+            positives = manifest.positive_paragraphs,
+            negatives = manifest.negative_paragraphs,
+            "Reusing cached slice ledger"
+        );
+    }
+
+    // Neither `manifest` nor `path` is used past this point, so hand them
+    // over by value (the original cloned both redundantly), mirroring the
+    // explicit-slice branch above.
+    manifest_to_resolved(dataset, &index, manifest, path)
+}
+
+/// Select `limit` cases starting at `offset` within a resolved slice,
+/// collecting the deduplicated positive paragraph ids of the window.
+///
+/// # Errors
+/// Fails when the slice is empty or `offset` is past the last case.
+pub fn select_window<'a>(
+    resolved: &'a ResolvedSlice<'a>,
+    offset: usize,
+    limit: Option<usize>,
+) -> Result<SliceWindow<'a>> {
+    let total = resolved.manifest.case_count;
+    if total == 0 {
+        return Err(anyhow!(
+            "slice '{}' contains no cases",
+            resolved.manifest.slice_id
+        ));
+    }
+    if offset >= total {
+        return Err(anyhow!(
+            "slice offset {} exceeds available cases ({})",
+            offset,
+            total
+        ));
+    }
+    let available = total - offset;
+    // NOTE: `.max(1)` means an explicit limit of 0 still evaluates one case.
+    let requested = limit.unwrap_or(available).max(1);
+    let length = requested.min(available);
+    let cases = resolved.cases[offset..offset + length].to_vec();
+    let mut seen = HashSet::new();
+    let mut positive_ids = Vec::new();
+    for case in &cases {
+        if seen.insert(case.paragraph.id.as_str()) {
+            positive_ids.push(case.paragraph.id.clone());
+        }
+    }
+    Ok(SliceWindow {
+        offset,
+        length,
+        total_cases: total,
+        cases,
+        positive_paragraph_ids: positive_ids,
+    })
+}
+
+/// Window over the entire slice.
+#[allow(dead_code)]
+pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
+    select_window(resolved,
0, None) +} + +fn load_explicit_slice<'a>( + dataset: &'a ConvertedDataset, + index: &DatasetIndex, + config: &SliceConfig<'_>, + slice_arg: &str, +) -> Result<(PathBuf, SliceManifest)> { + let explicit_path = Path::new(slice_arg); + let candidate_path = if explicit_path.exists() { + explicit_path.to_path_buf() + } else { + config + .cache_dir + .join("slices") + .join(dataset.metadata.id.as_str()) + .join(format!("{slice_arg}.json")) + }; + + let manifest = read_manifest(&candidate_path) + .with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?; + + if manifest.dataset_id != dataset.metadata.id { + return Err(anyhow!( + "slice '{}' targets dataset '{}', but '{}' is loaded", + manifest.slice_id, + manifest.dataset_id, + dataset.metadata.id + )); + } + + // Validate the manifest before returning. + manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?; + + Ok((candidate_path, manifest)) +} + +fn empty_manifest( + dataset: &ConvertedDataset, + slice_id: String, + params: &BuildParams, + requested_corpus: usize, + negative_multiplier: f32, + requested_limit: Option, +) -> SliceManifest { + SliceManifest { + version: SLICE_VERSION, + slice_id, + dataset_id: dataset.metadata.id.clone(), + dataset_label: dataset.metadata.label.clone(), + dataset_source: dataset.source.clone(), + includes_unanswerable: dataset.metadata.include_unanswerable, + seed: params.base_seed, + requested_limit, + requested_corpus, + negative_multiplier, + generated_at: Utc::now(), + case_count: 0, + positive_paragraphs: 0, + negative_paragraphs: 0, + total_paragraphs: 0, + cases: Vec::new(), + paragraphs: Vec::new(), + } +} + +fn ensure_case_capacity( + dataset: &ConvertedDataset, + manifest: &mut SliceManifest, + params: &BuildParams, + target_cases: usize, +) -> Result { + if manifest.case_count >= target_cases { + return Ok(false); + } + + let question_refs = ordered_question_refs(dataset, params)?; + let mut existing_questions: HashSet = 
manifest + .cases + .iter() + .map(|case| case.question_id.clone()) + .collect(); + let mut paragraph_positions: HashMap = manifest + .paragraphs + .iter() + .enumerate() + .map(|(idx, entry)| (entry.id.clone(), idx)) + .collect(); + + let mut changed = false; + + for (p_idx, q_idx) in question_refs { + if manifest.case_count >= target_cases { + break; + } + let paragraph = &dataset.paragraphs[p_idx]; + let question = ¶graph.questions[q_idx]; + if !existing_questions.insert(question.id.clone()) { + continue; + } + + if let Some(idx) = paragraph_positions.get(paragraph.id.as_str()).copied() { + match &mut manifest.paragraphs[idx].kind { + SliceParagraphKind::Positive { question_ids } => { + if !question_ids.contains(&question.id) { + question_ids.push(question.id.clone()); + } + } + SliceParagraphKind::Negative => { + manifest.paragraphs[idx].kind = SliceParagraphKind::Positive { + question_ids: vec![question.id.clone()], + }; + } + } + } else { + manifest.paragraphs.push(SliceParagraphEntry { + id: paragraph.id.clone(), + kind: SliceParagraphKind::Positive { + question_ids: vec![question.id.clone()], + }, + shard_path: Some(default_shard_path(¶graph.id)), + }); + let idx = manifest.paragraphs.len() - 1; + paragraph_positions.insert(paragraph.id.clone(), idx); + } + + manifest.cases.push(SliceCaseEntry { + question_id: question.id.clone(), + paragraph_id: paragraph.id.clone(), + }); + manifest.case_count += 1; + changed = true; + } + + if manifest.case_count < target_cases { + return Err(anyhow!( + "only {}/{} eligible questions available for dataset {}", + manifest.case_count, + target_cases, + dataset.metadata.id + )); + } + + Ok(changed) +} + +fn ordered_question_refs( + dataset: &ConvertedDataset, + params: &BuildParams, +) -> Result> { + let mut question_refs = Vec::new(); + for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() { + for (q_idx, question) in paragraph.questions.iter().enumerate() { + let include = if params.include_impossible { + true + 
} else { + !question.is_impossible && !question.answers.is_empty() + }; + if include { + question_refs.push((p_idx, q_idx)); + } + } + } + + if question_refs.is_empty() { + return Err(anyhow!( + "no eligible questions found for dataset {}; cannot build slice", + dataset.metadata.id + )); + } + + let mut rng = StdRng::seed_from_u64(params.rng_seed); + question_refs.shuffle(&mut rng); + Ok(question_refs) +} + +fn ensure_negative_pool( + dataset: &ConvertedDataset, + manifest: &mut SliceManifest, + params: &BuildParams, + target_negatives: usize, + requested_corpus: usize, +) -> Result { + let current_negatives = manifest + .paragraphs + .iter() + .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative)) + .count(); + if current_negatives >= target_negatives { + return Ok(false); + } + + let positive_ids: HashSet = manifest + .paragraphs + .iter() + .filter_map(|entry| match entry.kind { + SliceParagraphKind::Positive { .. } => Some(entry.id.clone()), + _ => None, + }) + .collect(); + let mut negative_ids: HashSet = manifest + .paragraphs + .iter() + .filter_map(|entry| match entry.kind { + SliceParagraphKind::Negative => Some(entry.id.clone()), + _ => None, + }) + .collect(); + + let negative_seed = mix_seed( + &format!("{}::negatives", dataset.metadata.id), + params.base_seed, + ); + let candidates = ordered_negative_indices(dataset, &positive_ids, negative_seed); + let mut added = false; + for idx in candidates { + if negative_ids.len() >= target_negatives { + break; + } + let paragraph = &dataset.paragraphs[idx]; + if negative_ids.contains(paragraph.id.as_str()) + || positive_ids.contains(paragraph.id.as_str()) + { + continue; + } + manifest.paragraphs.push(SliceParagraphEntry { + id: paragraph.id.clone(), + kind: SliceParagraphKind::Negative, + shard_path: Some(default_shard_path(¶graph.id)), + }); + negative_ids.insert(paragraph.id.clone()); + added = true; + } + + if negative_ids.len() < target_negatives { + warn!( + dataset = %dataset.metadata.id, + 
desired = target_negatives, + available = negative_ids.len(), + requested_corpus, + "Insufficient negative paragraphs to satisfy multiplier" + ); + } + + Ok(added) +} + +fn ordered_negative_indices( + dataset: &ConvertedDataset, + positive_ids: &HashSet, + rng_seed: u64, +) -> Vec { + let mut candidates: Vec = dataset + .paragraphs + .iter() + .enumerate() + .filter_map(|(idx, paragraph)| { + if positive_ids.contains(paragraph.id.as_str()) { + None + } else { + Some(idx) + } + }) + .collect(); + let mut rng = StdRng::seed_from_u64(rng_seed); + candidates.shuffle(&mut rng); + candidates +} + +fn refresh_manifest_stats(manifest: &mut SliceManifest) { + manifest.case_count = manifest.cases.len(); + manifest.positive_paragraphs = manifest + .paragraphs + .iter() + .filter(|entry| matches!(entry.kind, SliceParagraphKind::Positive { .. })) + .count(); + manifest.negative_paragraphs = manifest + .paragraphs + .iter() + .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative)) + .count(); + manifest.total_paragraphs = manifest.paragraphs.len(); +} + +fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool { + let mut changed = false; + for entry in &mut manifest.paragraphs { + if entry.shard_path.is_none() { + entry.shard_path = Some(default_shard_path(&entry.id)); + changed = true; + } + } + changed +} + +fn desired_negative_target( + positive_count: usize, + requested_corpus: usize, + dataset_paragraphs: usize, + multiplier: f32, +) -> usize { + if positive_count == 0 { + return 0; + } + let ratio = multiplier.max(0.0); + let mut desired = ((positive_count as f32) * ratio).ceil() as usize; + let max_total = requested_corpus.min(dataset_paragraphs).max(positive_count); + let max_negatives = max_total.saturating_sub(positive_count); + desired = desired.min(max_negatives); + desired +} + +fn manifest_to_resolved<'a>( + dataset: &'a ConvertedDataset, + index: &DatasetIndex, + manifest: SliceManifest, + path: PathBuf, +) -> Result> { + if manifest.version != 
SLICE_VERSION { + return Err(anyhow!( + "slice version {} does not match expected {}", + manifest.version, + SLICE_VERSION + )); + } + + let mut paragraphs = Vec::with_capacity(manifest.paragraphs.len()); + for entry in &manifest.paragraphs { + let paragraph = index.paragraph(dataset, &entry.id)?; + if let SliceParagraphKind::Positive { question_ids } = &entry.kind { + for question_id in question_ids { + let (linked_paragraph, _) = index.question(dataset, question_id)?; + if linked_paragraph.id != entry.id { + return Err(anyhow!( + "slice question '{}' expected paragraph '{}', found '{}'", + question_id, + entry.id, + linked_paragraph.id + )); + } + } + } + paragraphs.push(paragraph); + } + + let mut cases = Vec::with_capacity(manifest.cases.len()); + for entry in &manifest.cases { + let (paragraph, question) = index.question(dataset, &entry.question_id)?; + if paragraph.id != entry.paragraph_id { + return Err(anyhow!( + "slice case '{}' expected paragraph '{}', found '{}'", + entry.question_id, + entry.paragraph_id, + paragraph.id + )); + } + cases.push(CaseRef { + paragraph, + question, + }); + } + + if cases.is_empty() { + return Err(anyhow!( + "slice '{}' contains no cases after validation", + manifest.slice_id + )); + } + + Ok(ResolvedSlice { + manifest, + path, + paragraphs, + cases, + }) +} + +fn compute_slice_id(key: &SliceKey<'_>) -> String { + let payload = serde_json::to_vec(key).expect("SliceKey serialisation should not fail"); + let mut hasher = Sha256::new(); + hasher.update(payload); + let digest = hasher.finalize(); + digest[..16] + .iter() + .map(|byte| format!("{byte:02x}")) + .collect::() +} + +fn mix_seed(dataset_id: &str, seed: u64) -> u64 { + let mut hasher = Sha256::new(); + hasher.update(dataset_id.as_bytes()); + hasher.update(seed.to_le_bytes()); + let digest = hasher.finalize(); + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&digest[..8]); + u64::from_le_bytes(bytes) +} + +fn read_manifest(path: &Path) -> Result { + let raw = 
fs::read_to_string(path) + .with_context(|| format!("reading slice manifest {}", path.display()))?; + let manifest: SliceManifest = serde_json::from_str(&raw) + .with_context(|| format!("parsing slice manifest {}", path.display()))?; + Ok(manifest) +} + +fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("creating slice directory {}", parent.display()))?; + } + let json = serde_json::to_vec_pretty(manifest).context("serialising slice manifest to JSON")?; + fs::write(path, json).with_context(|| format!("writing slice manifest {}", path.display()))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::datasets::{ + ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, DatasetMetadata, + }; + use tempfile::tempdir; + + fn sample_dataset() -> ConvertedDataset { + let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false, None); + ConvertedDataset { + generated_at: Utc::now(), + metadata, + source: "test-source".to_string(), + paragraphs: vec![ + ConvertedParagraph { + id: "p1".to_string(), + title: "Alpha".to_string(), + context: "Alpha context".to_string(), + questions: vec![ConvertedQuestion { + id: "q1".to_string(), + question: "What is alpha?".to_string(), + answers: vec!["Alpha".to_string()], + is_impossible: false, + }], + }, + ConvertedParagraph { + id: "p2".to_string(), + title: "Beta".to_string(), + context: "Beta context".to_string(), + questions: vec![ConvertedQuestion { + id: "q2".to_string(), + question: "What is beta?".to_string(), + answers: vec!["Beta".to_string()], + is_impossible: false, + }], + }, + ConvertedParagraph { + id: "p3".to_string(), + title: "Gamma".to_string(), + context: "Gamma context".to_string(), + questions: vec![ConvertedQuestion { + id: "q3".to_string(), + question: "What is gamma?".to_string(), + answers: vec!["Gamma".to_string()], + is_impossible: false, + }], + }, + ], 
+ } + } + + #[test] + fn resolve_slice_reuses_cached_manifest() -> Result<()> { + let dataset = sample_dataset(); + let temp = tempdir().context("creating temp directory")?; + + let mut config = SliceConfig { + cache_dir: temp.path(), + force_convert: false, + explicit_slice: None, + limit: Some(2), + corpus_limit: Some(3), + slice_seed: 0x5eed_2025, + llm_mode: false, + negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER, + }; + + let first = resolve_slice(&dataset, &config)?; + assert!(first.path.exists()); + let initial_generated = first.manifest.generated_at; + + let second = resolve_slice(&dataset, &config)?; + assert_eq!(first.manifest.slice_id, second.manifest.slice_id); + assert_eq!(initial_generated, second.manifest.generated_at); + + config.force_convert = true; + let third = resolve_slice(&dataset, &config)?; + assert_eq!(first.manifest.slice_id, third.manifest.slice_id); + assert_ne!(third.manifest.generated_at, initial_generated); + + Ok(()) + } + + #[test] + fn select_window_yields_expected_cases() -> Result<()> { + let dataset = sample_dataset(); + let temp = tempdir().context("creating temp directory")?; + let config = SliceConfig { + cache_dir: temp.path(), + force_convert: false, + explicit_slice: None, + limit: Some(3), + corpus_limit: Some(3), + slice_seed: 0x5eed_2025, + llm_mode: false, + negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER, + }; + let resolved = resolve_slice(&dataset, &config)?; + let window = select_window(&resolved, 1, Some(1))?; + assert_eq!(window.offset, 1); + assert_eq!(window.length, 1); + assert_eq!(window.total_cases, resolved.manifest.case_count); + assert_eq!(window.cases.len(), 1); + let positive_ids: Vec<&str> = window.positive_ids().collect(); + assert_eq!(positive_ids.len(), 1); + assert!(resolved + .manifest + .paragraphs + .iter() + .any(|entry| entry.id == positive_ids[0])); + Ok(()) + } +} diff --git a/eval/src/snapshot.rs b/eval/src/snapshot.rs new file mode 100644 index 0000000..6068417 --- /dev/null +++ 
b/eval/src/snapshot.rs @@ -0,0 +1,182 @@ +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use tokio::fs; + +use crate::{args::Config, embedding::EmbeddingProvider, slice}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SnapshotMetadata { + pub dataset_id: String, + pub slice_id: String, + pub embedding_backend: String, + pub embedding_model: Option, + pub embedding_dimension: usize, + pub chunk_min_chars: usize, + pub chunk_max_chars: usize, + pub rerank_enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbSnapshotState { + pub dataset_id: String, + pub slice_id: String, + pub ingestion_fingerprint: String, + pub snapshot_hash: String, + pub updated_at: DateTime, + #[serde(default)] + pub namespace: Option, + #[serde(default)] + pub database: Option, + #[serde(default)] + pub slice_case_count: usize, +} + +pub struct Descriptor { + #[allow(dead_code)] + metadata: SnapshotMetadata, + dir: PathBuf, + metadata_hash: String, +} + +impl Descriptor { + pub fn new( + config: &Config, + slice: &slice::ResolvedSlice<'_>, + embedding_provider: &EmbeddingProvider, + ) -> Self { + let metadata = SnapshotMetadata { + dataset_id: slice.manifest.dataset_id.clone(), + slice_id: slice.manifest.slice_id.clone(), + embedding_backend: embedding_provider.backend_label().to_string(), + embedding_model: embedding_provider.model_code(), + embedding_dimension: embedding_provider.dimension(), + chunk_min_chars: config.chunk_min_chars, + chunk_max_chars: config.chunk_max_chars, + rerank_enabled: config.rerank, + }; + + let dir = config + .cache_dir + .join("snapshots") + .join(&metadata.dataset_id) + .join(&metadata.slice_id); + let metadata_hash = compute_hash(&metadata); + + Self { + metadata, + dir, + metadata_hash, + } + } + + pub fn metadata_hash(&self) -> &str { + &self.metadata_hash + } + + pub async fn load_db_state(&self) -> 
Result> { + let path = self.db_state_path(); + if !path.exists() { + return Ok(None); + } + let bytes = fs::read(&path) + .await + .with_context(|| format!("reading namespace state {}", path.display()))?; + let state = serde_json::from_slice(&bytes) + .with_context(|| format!("deserialising namespace state {}", path.display()))?; + Ok(Some(state)) + } + + pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> { + let path = self.db_state_path(); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await.with_context(|| { + format!("creating namespace state directory {}", parent.display()) + })?; + } + let blob = + serde_json::to_vec_pretty(state).context("serialising namespace state payload")?; + fs::write(&path, blob) + .await + .with_context(|| format!("writing namespace state {}", path.display()))?; + Ok(()) + } + + fn db_dir(&self) -> PathBuf { + self.dir.join("db") + } + + fn db_state_path(&self) -> PathBuf { + self.db_dir().join("state.json") + } + + #[cfg(test)] + pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self { + let metadata_hash = compute_hash(&metadata); + Self { + metadata, + dir, + metadata_hash, + } + } +} + +fn compute_hash(metadata: &SnapshotMetadata) -> String { + let mut hasher = Sha256::new(); + hasher.update( + serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed"), + ); + format!("{:x}", hasher.finalize()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn state_round_trip() { + let temp_dir = tempfile::tempdir().unwrap(); + let metadata = SnapshotMetadata { + dataset_id: "dataset".into(), + slice_id: "slice".into(), + embedding_backend: "hashed".into(), + embedding_model: None, + embedding_dimension: 128, + chunk_min_chars: 10, + chunk_max_chars: 100, + rerank_enabled: true, + }; + let descriptor = Descriptor::from_parts( + metadata, + temp_dir + .path() + .join("snapshots") + .join("dataset") + .join("slice"), + ); + + let state = 
DbSnapshotState { + dataset_id: "dataset".into(), + slice_id: "slice".into(), + ingestion_fingerprint: "fingerprint".into(), + snapshot_hash: descriptor.metadata_hash().to_string(), + updated_at: Utc::now(), + namespace: Some("ns".into()), + database: Some("db".into()), + slice_case_count: 42, + }; + descriptor.store_db_state(&state).await.unwrap(); + + let loaded = descriptor.load_db_state().await.unwrap().unwrap(); + assert_eq!(loaded.dataset_id, state.dataset_id); + assert_eq!(loaded.slice_id, state.slice_id); + assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint); + assert_eq!(loaded.snapshot_hash, state.snapshot_hash); + assert_eq!(loaded.namespace, state.namespace); + assert_eq!(loaded.database, state.database); + assert_eq!(loaded.slice_case_count, state.slice_case_count); + } +} diff --git a/html-router/src/routes/chat/mod.rs b/html-router/src/routes/chat/mod.rs index b7052d0..0d6bd0a 100644 --- a/html-router/src/routes/chat/mod.rs +++ b/html-router/src/routes/chat/mod.rs @@ -7,7 +7,7 @@ use axum::{ routing::{get, post}, Router, }; -use chat_handlers::{ +pub use chat_handlers::{ delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title, reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat, show_initialized_chat, diff --git a/html-router/src/routes/search/mod.rs b/html-router/src/routes/search/mod.rs index 693f659..e61e2bc 100644 --- a/html-router/src/routes/search/mod.rs +++ b/html-router/src/routes/search/mod.rs @@ -1,7 +1,7 @@ mod handlers; use axum::{extract::FromRef, routing::get, Router}; -use handlers::search_result_handler; +pub use handlers::{search_result_handler, SearchParams}; use crate::html_state::HtmlState; diff --git a/ingestion-pipeline/src/lib.rs b/ingestion-pipeline/src/lib.rs index 4eb7e9f..671f2ae 100644 --- a/ingestion-pipeline/src/lib.rs +++ b/ingestion-pipeline/src/lib.rs @@ -6,7 +6,7 @@ use common::storage::{ db::SurrealDbClient, 
types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS}, }; -use pipeline::IngestionPipeline; +pub use pipeline::{IngestionConfig, IngestionPipeline, IngestionTuning}; use std::sync::Arc; use tokio::time::{sleep, Duration}; use tracing::{error, info, warn}; diff --git a/ingestion-pipeline/src/pipeline/context.rs b/ingestion-pipeline/src/pipeline/context.rs index 74959b1..26ee9a1 100644 --- a/ingestion-pipeline/src/pipeline/context.rs +++ b/ingestion-pipeline/src/pipeline/context.rs @@ -1,8 +1,14 @@ +use std::ops::Range; + use common::{ error::AppError, storage::{ db::SurrealDbClient, - types::{ingestion_task::IngestionTask, text_content::TextContent}, + types::{ + ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity, + knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, + text_content::TextContent, + }, }, }; use composite_retrieval::RetrievedEntity; @@ -24,6 +30,14 @@ pub struct PipelineContext<'a> { pub analysis: Option, } +#[derive(Debug)] +pub struct PipelineArtifacts { + pub text_content: TextContent, + pub entities: Vec, + pub relationships: Vec, + pub chunks: Vec, +} + impl<'a> PipelineContext<'a> { pub fn new( task: &'a IngestionTask, @@ -73,4 +87,30 @@ impl<'a> PipelineContext<'a> { ); err } + + pub async fn build_artifacts(&mut self) -> Result { + let content = self.take_text_content()?; + let analysis = self.take_analysis()?; + + let (entities, relationships) = self + .services + .convert_analysis( + &content, + &analysis, + self.pipeline_config.tuning.entity_embedding_concurrency, + ) + .await?; + + let chunk_range: Range = self.pipeline_config.tuning.chunk_min_chars + ..self.pipeline_config.tuning.chunk_max_chars; + + let chunks = self.services.prepare_chunks(&content, chunk_range).await?; + + Ok(PipelineArtifacts { + text_content: content, + entities, + relationships, + chunks, + }) + } } diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs index 686162b..6355446 100644 --- 
a/ingestion-pipeline/src/pipeline/mod.rs +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -7,6 +7,7 @@ mod stages; mod state; pub use config::{IngestionConfig, IngestionTuning}; +pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship}; pub use services::{DefaultPipelineServices, PipelineServices}; use std::{ @@ -31,7 +32,7 @@ use composite_retrieval::reranking::RerankerPool; use tracing::{debug, info, warn}; use self::{ - context::PipelineContext, + context::{PipelineArtifacts, PipelineContext}, stages::{enrich, persist, prepare_content, retrieve_related}, state::ready, }; @@ -224,6 +225,33 @@ impl IngestionPipeline { Ok(()) } + + /// Runs the ingestion pipeline up to (but excluding) persistence and returns the prepared artifacts. + pub async fn produce_artifacts( + &self, + task: &IngestionTask, + ) -> Result { + let payload = task.content.clone(); + let mut ctx = PipelineContext::new( + task, + self.db.as_ref(), + &self.pipeline_config, + self.services.as_ref(), + ); + + let machine = ready(); + let machine = prepare_content(machine, &mut ctx, payload) + .await + .map_err(|err| ctx.abort(err))?; + let machine = retrieve_related(machine, &mut ctx) + .await + .map_err(|err| ctx.abort(err))?; + let _machine = enrich(machine, &mut ctx) + .await + .map_err(|err| ctx.abort(err))?; + + ctx.build_artifacts().await.map_err(|err| ctx.abort(err)) + } } #[cfg(test)] diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 3a63203..719e463 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -29,6 +29,8 @@ use crate::utils::llm_instructions::{ get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE, }; +const EMBEDDING_QUERY_CHAR_LIMIT: usize = 12_000; + #[async_trait] pub trait PipelineServices: Send + Sync { async fn prepare_text_content( @@ -162,9 +164,13 @@ impl PipelineServices for DefaultPipelineServices { &self, content: 
&TextContent, ) -> Result, AppError> { + let truncated_body = truncate_for_embedding(&content.text, EMBEDDING_QUERY_CHAR_LIMIT); let input_text = format!( - "content: {}, category: {}, user_context: {:?}", - content.text, content.category, content.context + "content: {}\n[truncated={}], category: {}, user_context: {:?}", + truncated_body, + truncated_body.len() < content.text.len(), + content.category, + content.context ); let rerank_lease = match &self.reranker_pool { @@ -239,3 +245,19 @@ impl PipelineServices for DefaultPipelineServices { Ok(chunks) } } + +fn truncate_for_embedding(text: &str, max_chars: usize) -> String { + if text.chars().count() <= max_chars { + return text.to_string(); + } + + let mut truncated = String::with_capacity(max_chars + 3); + for (idx, ch) in text.chars().enumerate() { + if idx >= max_chars { + break; + } + truncated.push(ch); + } + truncated.push_str("…"); + truncated +} diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs b/ingestion-pipeline/src/pipeline/stages/mod.rs index 8143dc8..8f3085d 100644 --- a/ingestion-pipeline/src/pipeline/stages/mod.rs +++ b/ingestion-pipeline/src/pipeline/stages/mod.rs @@ -7,7 +7,6 @@ use common::{ types::{ ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, - text_content::TextContent, }, }, }; @@ -16,8 +15,7 @@ use tokio::time::{sleep, Duration}; use tracing::{debug, instrument, warn}; use super::{ - context::PipelineContext, - services::PipelineServices, + context::{PipelineArtifacts, PipelineContext}, state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved}, }; @@ -134,37 +132,26 @@ pub async fn persist( machine: IngestionMachine<(), Enriched>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { - let content = ctx.take_text_content()?; - let analysis = ctx.take_analysis()?; - - let (entities, relationships) = ctx - .services - .convert_analysis( - &content, - &analysis, - 
ctx.pipeline_config.tuning.entity_embedding_concurrency, - ) - .await?; - + let PipelineArtifacts { + text_content, + entities, + relationships, + chunks, + } = ctx.build_artifacts().await?; let entity_count = entities.len(); let relationship_count = relationships.len(); - let chunk_range = - ctx.pipeline_config.tuning.chunk_min_chars..ctx.pipeline_config.tuning.chunk_max_chars; - let ((), chunk_count) = tokio::try_join!( store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships), store_vector_chunks( ctx.db, - ctx.services, ctx.task_id.as_str(), - &content, - chunk_range, + &chunks, &ctx.pipeline_config.tuning ) )?; - ctx.db.store_item(content).await?; + ctx.db.store_item(text_content).await?; ctx.db.rebuild_indexes().await?; debug!( @@ -252,17 +239,14 @@ async fn store_graph_entities( async fn store_vector_chunks( db: &SurrealDbClient, - services: &dyn PipelineServices, task_id: &str, - content: &TextContent, - chunk_range: std::ops::Range, + chunks: &[TextChunk], tuning: &super::config::IngestionTuning, ) -> Result { - let prepared_chunks = services.prepare_chunks(content, chunk_range).await?; - let chunk_count = prepared_chunks.len(); + let chunk_count = chunks.len(); let batch_size = tuning.chunk_insert_concurrency.max(1); - for chunk in &prepared_chunks { + for chunk in chunks { debug!( task_id = %task_id, chunk_id = %chunk.id, @@ -271,7 +255,7 @@ async fn store_vector_chunks( ); } - for batch in prepared_chunks.chunks(batch_size) { + for batch in chunks.chunks(batch_size) { store_chunk_batch(db, batch, tuning).await?; }