benchmarks: v1

Benchmarking ingestion, retrieval precision and performance
This commit is contained in:
Per Stark
2025-11-04 11:22:45 +01:00
parent 112a6965a4
commit 0eda65b07e
46 changed files with 8407 additions and 144 deletions

2
.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
[alias]
eval = "run -p eval --"

4
.gitignore vendored
View File

@@ -10,6 +10,9 @@ result
data
database
eval/cache/
eval/reports/
# Devenv
.devenv*
devenv.local.nix
@@ -21,3 +24,4 @@ devenv.local.nix
.pre-commit-config.yaml
# html-router/assets/style.css
html-router/node_modules
.fastembed_cache/

View File

@@ -1,5 +1,6 @@
# Changelog
## Unreleased
- Added a shared `benchmarks` crate with deterministic fixtures and Criterion suites for ingestion, retrieval, and HTML handlers, plus documented baseline results for local performance checks.
## Version 0.2.6 (2025-10-29)
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.

163
Cargo.lock generated
View File

@@ -184,6 +184,12 @@ dependencies = [
"libc",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.18"
@@ -1090,6 +1096,12 @@ dependencies = [
"serde",
]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.3"
@@ -1626,6 +1638,42 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
@@ -2140,6 +2188,37 @@ dependencies = [
"num-traits",
]
[[package]]
name = "eval"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-trait",
"chrono",
"common",
"composite-retrieval",
"criterion",
"fastembed",
"futures",
"ingestion-pipeline",
"object_store 0.11.2",
"once_cell",
"rand 0.8.5",
"serde",
"serde_json",
"serde_yaml",
"sha2",
"state-machines",
"surrealdb",
"tempfile",
"text-splitter",
"tokio",
"tracing",
"tracing-subscriber",
"uuid",
]
[[package]]
name = "event-listener"
version = "5.4.0"
@@ -2735,6 +2814,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "hex"
version = "0.4.3"
@@ -3248,11 +3333,13 @@ dependencies = [
name = "ingestion-pipeline"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-trait",
"axum",
"axum_typed_multipart",
"base64 0.22.1",
"bytes",
"chrono",
"common",
"composite-retrieval",
@@ -3330,6 +3417,17 @@ version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is-terminal"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi 0.5.2",
"libc",
"windows-sys 0.60.2",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
@@ -4240,7 +4338,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"hermit-abi 0.3.9",
"libc",
]
@@ -4330,6 +4428,12 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "opaque-debug"
version = "0.3.1"
@@ -4705,6 +4809,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.18.0"
@@ -5920,6 +6052,19 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap 2.9.0",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "servo_arc"
version = "0.4.0"
@@ -6735,6 +6880,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.9.0"
@@ -7263,6 +7418,12 @@ dependencies = [
"subtle",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "untrusted"
version = "0.9.0"

View File

@@ -6,7 +6,8 @@ members = [
"html-router",
"ingestion-pipeline",
"composite-retrieval",
"json-stream-parser"
"json-stream-parser",
"eval"
]
resolver = "2"

View File

@@ -7,7 +7,7 @@ use include_dir::{include_dir, Dir};
use std::{ops::Deref, sync::Arc};
use surrealdb::{
engine::any::{connect, Any},
opt::auth::Root,
opt::auth::{Namespace, Root},
Error, Notification, Surreal,
};
use surrealdb_migrations::MigrationRunner;
@@ -48,6 +48,24 @@ impl SurrealDbClient {
Ok(SurrealDbClient { client: db })
}
/// Connects to SurrealDB at `address`, signs in with namespace-scoped
/// credentials, then selects `namespace`/`database` for subsequent queries.
///
/// Unlike the root-credential constructor above (which signs in as `Root`),
/// this authenticates as a `Namespace` user, so the session is limited to
/// the given namespace.
///
/// # Errors
/// Returns `surrealdb::Error` if connecting, signing in, or selecting the
/// namespace/database fails.
pub async fn new_with_namespace_user(
address: &str,
namespace: &str,
username: &str,
password: &str,
database: &str,
) -> Result<Self, Error> {
let db = connect(address).await?;
// Namespace-level auth: credentials are scoped to `namespace` only.
db.signin(Namespace {
namespace,
username,
password,
})
.await?;
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {

View File

@@ -12,6 +12,7 @@ pub struct RetrievalTuning {
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub lexical_match_weight: f32,
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
@@ -31,9 +32,10 @@ impl Default for RetrievalTuning {
chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 2800,
token_budget_estimate: 10000,
avg_chars_per_token: 4,
max_chunks_per_entity: 4,
lexical_match_weight: 0.15,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,

View File

@@ -0,0 +1,51 @@
use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
#[derive(Debug, Clone, Default, Serialize)]
pub struct PipelineDiagnostics {
/// Recorded by the candidate-collection stage; `None` until that stage runs.
pub collect_candidates: Option<CollectCandidatesStats>,
/// Recorded when chunks are enriched from surviving entities.
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
/// Recorded by the final assembly stage.
pub assemble: Option<AssembleStats>,
}
/// Candidate counts and score samples from the vector/FTS collection stage.
#[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats {
pub vector_entity_candidates: usize,
pub vector_chunk_candidates: usize,
pub fts_entity_candidates: usize,
pub fts_chunk_candidates: usize,
// Truncated score samples (not full lists) to keep payloads small.
pub vector_chunk_scores: Vec<f32>,
pub fts_chunk_scores: Vec<f32>,
}
/// Counters describing chunk enrichment from the filtered entity set.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
pub filtered_entity_count: usize,
pub fallback_min_results: usize,
pub chunk_sources_considered: usize,
pub chunk_candidates_before_enrichment: usize,
pub chunk_candidates_after_enrichment: usize,
pub top_chunk_scores: Vec<f32>,
}
/// Token-budget accounting and selection counters for the assembly stage.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats {
pub token_budget_start: usize,
pub token_budget_spent: usize,
pub token_budget_remaining: usize,
/// True when the remaining token budget reached zero during assembly.
pub budget_exhausted: bool,
pub chunks_selected: usize,
pub chunks_skipped_due_budget: usize,
pub entity_count: usize,
/// Per-entity breakdown of chunk selection decisions.
pub entity_traces: Vec<EntityAssemblyTrace>,
}
/// Per-entity record of which chunk candidates were inspected, selected,
/// or skipped because the token budget was exhausted.
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
pub entity_id: String,
pub source_id: String,
pub inspected_candidates: usize,
pub selected_chunk_ids: Vec<String>,
pub selected_chunk_scores: Vec<f32>,
pub skipped_due_budget: usize,
}

View File

@@ -1,14 +1,57 @@
mod config;
mod diagnostics;
mod stages;
mod state;
pub use config::{RetrievalConfig, RetrievalTuning};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
};
use crate::{reranking::RerankerLease, RetrievedEntity};
use async_openai::Client;
use common::{error::AppError, storage::db::SurrealDbClient};
use tracing::info;
/// Complete result of a pipeline run: the retrieved entities, optional
/// per-stage diagnostics (present only when capture was enabled), and the
/// accumulated per-stage wall-clock timings.
#[derive(Debug)]
pub struct PipelineRunOutput {
pub results: Vec<RetrievedEntity>,
pub diagnostics: Option<PipelineDiagnostics>,
pub stage_timings: PipelineStageTimings,
}
/// Cumulative wall-clock time (in milliseconds) spent in each retrieval
/// pipeline stage; each `record_*` call adds to the running total.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct PipelineStageTimings {
    pub collect_candidates_ms: u128,
    pub graph_expansion_ms: u128,
    pub chunk_attach_ms: u128,
    pub rerank_ms: u128,
    pub assemble_ms: u128,
}

impl PipelineStageTimings {
    // NOTE: `Duration::as_millis` already returns `u128`, so the former
    // `as u128` casts were redundant (clippy::unnecessary_cast) and have
    // been removed.

    /// Accumulates elapsed time for the candidate-collection stage.
    fn record_collect_candidates(&mut self, duration: std::time::Duration) {
        self.collect_candidates_ms += duration.as_millis();
    }

    /// Accumulates elapsed time for the graph-expansion stage.
    fn record_graph_expansion(&mut self, duration: std::time::Duration) {
        self.graph_expansion_ms += duration.as_millis();
    }

    /// Accumulates elapsed time for the chunk-attachment stage.
    fn record_chunk_attach(&mut self, duration: std::time::Duration) {
        self.chunk_attach_ms += duration.as_millis();
    }

    /// Accumulates elapsed time for the rerank stage.
    fn record_rerank(&mut self, duration: std::time::Duration) {
        self.rerank_ms += duration.as_millis();
    }

    /// Accumulates elapsed time for the assembly stage.
    fn record_assemble(&mut self, duration: std::time::Duration) {
        self.assemble_ms += duration.as_millis();
    }
}
/// Drives the retrieval pipeline from embedding through final assembly.
pub async fn run_pipeline(
db_client: &SurrealDbClient,
@@ -18,7 +61,6 @@ pub async fn run_pipeline(
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
@@ -30,7 +72,7 @@ pub async fn run_pipeline(
preview = %input_preview_clean,
"Starting ingestion retrieval pipeline"
);
let mut ctx = stages::PipelineContext::new(
let ctx = stages::PipelineContext::new(
db_client,
openai_client,
input_text.to_owned(),
@@ -38,17 +80,11 @@ pub async fn run_pipeline(
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
let outcome = run_pipeline_internal(ctx, false).await?;
Ok(results)
Ok(outcome.results)
}
#[cfg(test)]
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
@@ -58,8 +94,7 @@ pub async fn run_pipeline_with_embedding(
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let mut ctx = stages::PipelineContext::with_embedding(
let ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
@@ -68,14 +103,54 @@ pub async fn run_pipeline_with_embedding(
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
let outcome = run_pipeline_internal(ctx, false).await?;
Ok(results)
Ok(outcome.results)
}
/// Runs the pipeline with a precomputed embedding and returns stage metrics.
///
/// Behaves like `run_pipeline_with_embedding`, but returns the full
/// `PipelineRunOutput` (results plus stage timings). Per-stage diagnostics
/// capture remains disabled.
pub async fn run_pipeline_with_embedding_with_metrics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput, AppError> {
let ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
// `false`: collect timings only, not stage diagnostics.
run_pipeline_internal(ctx, false).await
}
/// Runs the pipeline with a precomputed embedding and captures full
/// per-stage diagnostics in addition to stage timings.
pub async fn run_pipeline_with_embedding_with_diagnostics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput, AppError> {
let ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
// `true`: enable diagnostics capture inside the pipeline context.
run_pipeline_internal(ctx, true).await
}
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
@@ -101,6 +176,37 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
.collect::<Vec<_>>())
}
/// Shared driver behind all public pipeline entry points: optionally enables
/// diagnostics on the context, runs every stage, and bundles the results,
/// any captured diagnostics, and the accumulated stage timings.
async fn run_pipeline_internal(
mut ctx: stages::PipelineContext<'_>,
capture_diagnostics: bool,
) -> Result<PipelineRunOutput, AppError> {
if capture_diagnostics {
ctx.enable_diagnostics();
}
let results = drive_pipeline(&mut ctx).await?;
// Move out whatever the stages recorded; `None` when capture was off.
let diagnostics = ctx.take_diagnostics();
Ok(PipelineRunOutput {
results,
diagnostics,
stage_timings: ctx.take_stage_timings(),
})
}
/// Steps the typed state machine through every retrieval stage in its fixed
/// order: embed -> collect candidates -> expand graph -> attach chunks ->
/// rerank -> assemble. Each transition consumes the previous machine state,
/// so stages cannot be skipped or reordered.
async fn drive_pipeline(
ctx: &mut stages::PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let machine = stages::embed(machine, ctx).await?;
let machine = stages::collect_candidates(machine, ctx).await?;
let machine = stages::expand_graph(machine, ctx).await?;
let machine = stages::attach_chunks(machine, ctx).await?;
let machine = stages::rerank(machine, ctx).await?;
let results = stages::assemble(machine, ctx)?;
Ok(results)
}
/// Widens an `f32` score to `f64` and rounds it to three decimal places.
fn round_score(value: f32) -> f64 {
    let scaled = f64::from(value) * 1000.0;
    scaled.round() / 1000.0
}

View File

@@ -10,7 +10,11 @@ use common::{
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
use state_machines::core::GuardError;
use std::collections::{HashMap, HashSet};
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
time::Instant,
};
use tracing::{debug, instrument, warn};
use crate::{
@@ -27,10 +31,15 @@ use crate::{
use super::{
config::RetrievalConfig,
diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
},
state::{
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
Reranked,
},
PipelineStageTimings,
};
pub struct PipelineContext<'a> {
@@ -45,6 +54,8 @@ pub struct PipelineContext<'a> {
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>,
stage_timings: PipelineStageTimings,
}
impl<'a> PipelineContext<'a> {
@@ -68,10 +79,11 @@ impl<'a> PipelineContext<'a> {
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
reranker,
diagnostics: None,
stage_timings: PipelineStageTimings::default(),
}
}
#[cfg(test)]
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
@@ -100,6 +112,62 @@ impl<'a> PipelineContext<'a> {
)
})
}
/// Turns on diagnostics capture for this run; idempotent.
pub fn enable_diagnostics(&mut self) {
if self.diagnostics.is_none() {
self.diagnostics = Some(PipelineDiagnostics::default());
}
}
/// True once `enable_diagnostics` has been called on this context.
pub fn diagnostics_enabled(&self) -> bool {
self.diagnostics.is_some()
}
/// Stores candidate-collection stats; no-op when diagnostics are disabled.
pub fn record_collect_candidates(&mut self, stats: CollectCandidatesStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.collect_candidates = Some(stats);
}
}
/// Stores chunk-enrichment stats; no-op when diagnostics are disabled.
pub fn record_chunk_enrichment(&mut self, stats: ChunkEnrichmentStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.enrich_chunks_from_entities = Some(stats);
}
}
/// Stores assembly stats; no-op when diagnostics are disabled.
pub fn record_assemble(&mut self, stats: AssembleStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.assemble = Some(stats);
}
}
/// Moves the captured diagnostics out of the context (leaves `None`).
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
self.diagnostics.take()
}
/// Adds elapsed time to the candidate-collection stage timing.
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_collect_candidates(duration);
}
/// Adds elapsed time to the graph-expansion stage timing.
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_graph_expansion(duration);
}
/// Adds elapsed time to the chunk-attachment stage timing.
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_chunk_attach(duration);
}
/// Adds elapsed time to the rerank stage timing.
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_rerank(duration);
}
/// Adds elapsed time to the assembly stage timing.
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_assemble(duration);
}
/// Moves the accumulated stage timings out, resetting them to default.
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
std::mem::take(&mut self.stage_timings)
}
}
#[instrument(level = "trace", skip_all)]
@@ -127,6 +195,7 @@ pub async fn collect_candidates(
machine: HybridRetrievalMachine<(), Embedded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
let stage_start = Instant::now();
debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
@@ -172,6 +241,19 @@ pub async fn collect_candidates(
"Hybrid retrieval initial candidate counts"
);
if ctx.diagnostics_enabled() {
ctx.record_collect_candidates(CollectCandidatesStats {
vector_entity_candidates: vector_entities.len(),
vector_chunk_candidates: vector_chunks.len(),
fts_entity_candidates: fts_entities.len(),
fts_chunk_candidates: fts_chunks.len(),
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
chunk.scores.vector.unwrap_or(0.0)
}),
fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
});
}
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
@@ -183,9 +265,11 @@ pub async fn collect_candidates(
apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
machine
let next = machine
.collect_candidates()
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
ctx.record_collect_candidates_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -193,82 +277,84 @@ pub async fn expand_graph(
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
let stage_start = Instant::now();
debug!("Expanding candidates using graph relationships");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let next = {
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
if ctx.entity_candidates.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
if ctx.entity_candidates.is_empty() {
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
} else {
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
if graph_seeds.is_empty() {
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
} else {
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
let limit = tuning.graph_neighbor_limit;
futures.push(async move {
let neighbors =
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
(seed, neighbors)
});
}
if graph_seeds.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
futures.push(async move {
let neighbors = find_entities_by_relationship_by_id(
db,
&seed.id,
&user,
tuning.graph_neighbor_limit,
)
.await;
(seed, neighbors)
});
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector =
clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}?;
ctx.record_graph_expansion_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -276,11 +362,14 @@ pub async fn attach_chunks(
machine: HybridRetrievalMachine<(), GraphExpanded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
let stage_start = Instant::now();
debug!("Attaching chunks to surviving entities");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
let chunk_candidates_before = ctx.chunk_candidates.len();
let chunk_sources_considered = chunk_by_source.len();
backfill_entities_from_chunks(
&mut ctx.entity_candidates,
@@ -312,6 +401,8 @@ pub async fn attach_chunks(
ctx.filtered_entities = filtered_entities;
let query_embedding = ctx.ensure_embedding()?.clone();
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
@@ -327,17 +418,31 @@ pub async fn attach_chunks(
ctx.db_client,
&ctx.user_id,
weights,
&query_embedding,
)
.await?;
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
if ctx.diagnostics_enabled() {
ctx.record_chunk_enrichment(ChunkEnrichmentStats {
filtered_entity_count: ctx.filtered_entities.len(),
fallback_min_results: tuning.fallback_min_results,
chunk_sources_considered,
chunk_candidates_before_enrichment: chunk_candidates_before,
chunk_candidates_after_enrichment: chunk_values.len(),
top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused),
});
}
ctx.chunk_values = chunk_values;
machine
let next = machine
.attach_chunks()
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
ctx.record_chunk_attach_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -345,6 +450,7 @@ pub async fn rerank(
machine: HybridRetrievalMachine<(), ChunksAttached>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
let stage_start = Instant::now();
let mut applied = false;
if let Some(reranker) = ctx.reranker.as_ref() {
@@ -384,9 +490,11 @@ pub async fn rerank(
debug!("Applied reranking adjustments to candidate ordering");
}
machine
let next = machine
.rerank()
.map_err(|(_, guard)| map_guard_error("rerank", guard))
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
ctx.record_rerank_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -394,8 +502,11 @@ pub fn assemble(
machine: HybridRetrievalMachine<(), Reranked>,
ctx: &mut PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let stage_start = Instant::now();
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let query_embedding = ctx.ensure_embedding()?.clone();
let question_terms = extract_keywords(&ctx.input_text);
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in ctx.chunk_values.drain(..) {
@@ -406,39 +517,68 @@ pub fn assemble(
}
for chunk_list in chunk_by_source.values_mut() {
sort_by_fused_desc(chunk_list);
chunk_list.sort_by(|a, b| {
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
});
}
let mut token_budget_remaining = tuning.token_budget_estimate;
let mut results = Vec::new();
let diagnostics_enabled = ctx.diagnostics_enabled();
let mut per_entity_traces = Vec::new();
let mut chunks_skipped_due_budget = 0usize;
let mut chunks_selected = 0usize;
let mut tokens_spent = 0usize;
for entity in &ctx.filtered_entities {
let mut selected_chunks = Vec::new();
let mut entity_trace = if diagnostics_enabled {
Some(EntityAssemblyTrace {
entity_id: entity.item.id.clone(),
source_id: entity.item.source_id.clone(),
inspected_candidates: 0,
selected_chunk_ids: Vec::new(),
selected_chunk_scores: Vec::new(),
skipped_due_budget: 0,
})
} else {
None
};
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
rank_chunks_by_combined_score(candidates, &question_terms, tuning.lexical_match_weight);
let mut per_entity_count = 0;
candidates.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
for candidate in candidates.iter() {
if let Some(trace) = entity_trace.as_mut() {
trace.inspected_candidates += 1;
}
if per_entity_count >= tuning.max_chunks_per_entity {
break;
}
let estimated_tokens =
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
if estimated_tokens > token_budget_remaining {
chunks_skipped_due_budget += 1;
if let Some(trace) = entity_trace.as_mut() {
trace.skipped_due_budget += 1;
}
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
tokens_spent += estimated_tokens;
per_entity_count += 1;
chunks_selected += 1;
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
score: candidate.fused,
});
if let Some(trace) = entity_trace.as_mut() {
trace.selected_chunk_ids.push(candidate.item.id.clone());
trace.selected_chunk_scores.push(candidate.fused);
}
}
}
@@ -448,17 +588,48 @@ pub fn assemble(
chunks: selected_chunks,
});
if let Some(trace) = entity_trace {
per_entity_traces.push(trace);
}
if token_budget_remaining == 0 {
break;
}
}
if diagnostics_enabled {
ctx.record_assemble(AssembleStats {
token_budget_start: tuning.token_budget_estimate,
token_budget_spent: tokens_spent,
token_budget_remaining,
budget_exhausted: token_budget_remaining == 0,
chunks_selected,
chunks_skipped_due_budget,
entity_count: ctx.filtered_entities.len(),
entity_traces: per_entity_traces,
});
}
machine
.assemble()
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
ctx.record_assemble_timing(stage_start.elapsed());
Ok(results)
}
/// Maximum number of scores included in any diagnostics sample.
const SCORE_SAMPLE_LIMIT: usize = 8;

/// Extracts a score from each of the first `SCORE_SAMPLE_LIMIT` items so
/// diagnostics payloads stay small regardless of candidate-set size.
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where
    F: FnMut(&Scored<T>) -> f32,
{
    // Pass the closure straight to `map` instead of re-wrapping it in
    // `|item| extractor(item)` (clippy::redundant_closure).
    items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
}
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
AppError::InternalError(format!(
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
@@ -582,6 +753,7 @@ async fn enrich_chunks_from_entities(
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
query_embedding: &[f32],
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
@@ -615,7 +787,16 @@ async fn enrich_chunks_from_entities(
.copied()
.unwrap_or(0.0);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
entry.scores.vector = Some(
entry
.scores
.vector
.unwrap_or(0.0)
.max(entity_score * 0.8)
.max(similarity),
);
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
@@ -624,6 +805,24 @@ async fn enrich_chunks_from_entities(
Ok(())
}
/// Cosine similarity between two equal-length vectors.
///
/// Returns `0.0` for empty inputs, mismatched lengths, or when either
/// vector has zero magnitude (avoiding a division by zero).
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    if query.is_empty() || query.len() != embedding.len() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared norms.
    let (dot, norm_q, norm_e) = query.iter().zip(embedding.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, nq, ne), (q, e)| (dot + q * e, nq + q * q, ne + e * e),
    );
    if norm_q == 0.0 || norm_e == 0.0 {
        0.0
    } else {
        dot / (norm_q.sqrt() * norm_e.sqrt())
    }
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();
@@ -736,6 +935,48 @@ fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
(chars / avg_chars_per_token).max(1)
}
/// Blends each candidate's fused retrieval score with a lexical-overlap
/// bonus (weighted by `lexical_weight`), then sorts the slice descending by
/// the resulting fused score.
fn rank_chunks_by_combined_score(
candidates: &mut [Scored<TextChunk>],
question_terms: &[String],
lexical_weight: f32,
) {
// Skip the lexical pass entirely when it cannot change any score.
if lexical_weight > 0.0 && !question_terms.is_empty() {
for candidate in candidates.iter_mut() {
let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk);
// Clamped so the combined score stays within the unit range.
let combined = clamp_unit(candidate.fused + lexical_weight * lexical);
candidate.update_fused(combined);
}
}
candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
}
fn extract_keywords(text: &str) -> Vec<String> {
let mut terms = Vec::new();
for raw in text.split(|c: char| !c.is_alphanumeric()) {
let term = raw.trim().to_ascii_lowercase();
if term.len() >= 3 {
terms.push(term);
}
}
terms.sort();
terms.dedup();
terms
}
/// Fraction of `terms` that occur (case-insensitively, as substrings) in
/// `haystack`. Returns 0.0 when there are no terms.
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
    if terms.is_empty() {
        return 0.0;
    }
    let lowered = haystack.to_ascii_lowercase();
    let hits = terms
        .iter()
        .filter(|term| lowered.contains(term.as_str()))
        .count();
    hits as f32 / terms.len() as f32
}
#[derive(Clone)]
struct GraphSeed {
id: String,

View File

@@ -68,6 +68,7 @@ where
{
let embedding_literal = serde_json::to_string(&query_embedding)
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
let closest_query = format!(
"SELECT id, vector::distance::knn() AS distance \
FROM {table} \

33
eval/Cargo.toml Normal file
View File

@@ -0,0 +1,33 @@
[package]
name = "eval"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
async-openai = { workspace = true }
chrono = { workspace = true }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
ingestion-pipeline = { path = "../ingestion-pipeline" }
futures = { workspace = true }
fastembed = { workspace = true }
serde = { workspace = true, features = ["derive"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
text-splitter = { workspace = true }
rand = "0.8"
sha2 = { workspace = true }
object_store = { workspace = true }
surrealdb = { workspace = true }
serde_json = { workspace = true }
async-trait = { workspace = true }
once_cell = "1.19"
serde_yaml = "0.9"
criterion = "0.5"
state-machines = { workspace = true }
[dev-dependencies]
tempfile = { workspace = true }

33
eval/manifest.yaml Normal file
View File

@@ -0,0 +1,33 @@
default_dataset: squad-v2
datasets:
- id: squad-v2
label: "SQuAD v2.0"
category: "SQuAD v2.0"
entity_suffix: "SQuAD"
source_prefix: "squad"
raw: "data/raw/squad/dev-v2.0.json"
converted: "data/converted/squad-minne.json"
include_unanswerable: false
slices:
- id: squad-dev-200
label: "SQuAD dev (200)"
description: "Deterministic 200-case slice for local eval"
limit: 200
corpus_limit: 2000
seed: 0x5eed2025
- id: natural-questions-dev
label: "Natural Questions (dev)"
category: "Natural Questions"
entity_suffix: "Natural Questions"
source_prefix: "nq"
raw: "data/raw/nq/dev-all.jsonl"
converted: "data/converted/nq-dev-minne.json"
include_unanswerable: true
slices:
- id: nq-dev-200
label: "NQ dev (200)"
description: "200-case slice of the dev set"
limit: 200
corpus_limit: 2000
include_unanswerable: false
seed: 0x5eed2025

638
eval/src/args.rs Normal file
View File

@@ -0,0 +1,638 @@
use std::{
env,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use crate::datasets::DatasetKind;
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
/// Embedding backend used for evaluation runs.
///
/// `FastEmbed` is the default; `Hashed` is a cheap, deterministic fallback
/// that needs no model download. Uses `#[derive(Default)]` with the
/// `#[default]` variant attribute instead of a hand-written `impl Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend {
    Hashed,
    /// Default backend, matching the crate's FastEmbed preset.
    #[default]
    FastEmbed,
}
impl std::str::FromStr for EmbeddingBackend {
    type Err = anyhow::Error;

    /// Parses a case-insensitive backend name; accepts `fast-embed` and
    /// `fast` as aliases for the FastEmbed backend.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_ascii_lowercase();
        if normalized == "hashed" {
            return Ok(Self::Hashed);
        }
        if matches!(normalized.as_str(), "fastembed" | "fast-embed" | "fast") {
            return Ok(Self::FastEmbed);
        }
        let other = normalized.as_str();
        Err(anyhow!(
            "unknown embedding backend '{other}'. Expected 'hashed' or 'fastembed'."
        ))
    }
}
/// Fully-resolved configuration for an eval run.
///
/// Built by `parse` from defaults, command-line flags, and `EVAL_*`
/// environment-variable overrides.
#[derive(Debug, Clone)]
pub struct Config {
    // Dataset selection and conversion
    pub convert_only: bool,
    pub force_convert: bool,
    pub dataset: DatasetKind,
    pub llm_mode: bool,
    pub corpus_limit: Option<usize>,
    pub raw_dataset_path: PathBuf,
    pub converted_dataset_path: PathBuf,
    // Reporting and evaluation scope
    pub report_dir: PathBuf,
    pub k: usize,
    pub limit: Option<usize>,
    pub summary_sample: usize,
    pub full_context: bool,
    // Chunking knobs (None = use crate defaults)
    pub chunk_min_chars: usize,
    pub chunk_max_chars: usize,
    pub chunk_vector_take: Option<usize>,
    pub chunk_fts_take: Option<usize>,
    pub chunk_token_budget: Option<usize>,
    pub chunk_avg_chars_per_token: Option<usize>,
    pub max_chunks_per_entity: Option<usize>,
    // Reranking and parallelism
    pub rerank: bool,
    pub rerank_pool_size: usize,
    pub rerank_keep_top: usize,
    pub concurrency: usize,
    // Embeddings and caches
    pub embedding_backend: EmbeddingBackend,
    pub embedding_model: Option<String>,
    pub cache_dir: PathBuf,
    pub ingestion_cache_dir: PathBuf,
    pub refresh_embeddings_only: bool,
    pub detailed_report: bool,
    // Slice management
    pub slice: Option<String>,
    pub reseed_slice: bool,
    pub slice_seed: u64,
    pub slice_grow: Option<usize>,
    pub slice_offset: usize,
    pub slice_reset_ingestion: bool,
    pub negative_multiplier: f32,
    pub label: Option<String>,
    // Diagnostics / inspection mode
    pub chunk_diagnostics_path: Option<PathBuf>,
    pub inspect_question: Option<String>,
    pub inspect_manifest: Option<PathBuf>,
    pub query_model: Option<String>,
    // Performance telemetry sinks
    pub perf_log_json: Option<PathBuf>,
    pub perf_log_dir: Option<PathBuf>,
    pub perf_log_console: bool,
    // SurrealDB connection
    pub db_endpoint: String,
    pub db_username: String,
    pub db_password: String,
    pub db_namespace: Option<String>,
    pub db_database: Option<String>,
    pub inspect_db_state: Option<PathBuf>,
}
impl Default for Config {
    /// Baseline configuration: default dataset, 200-question limit,
    /// FastEmbed embeddings, reranking enabled, and local SurrealDB
    /// connection defaults. CLI flags and `EVAL_*` env vars override these.
    fn default() -> Self {
        let dataset = DatasetKind::default();
        Self {
            convert_only: false,
            force_convert: false,
            dataset,
            llm_mode: false,
            corpus_limit: None,
            // Paths derive from the selected dataset's manifest defaults.
            raw_dataset_path: dataset.default_raw_path(),
            converted_dataset_path: dataset.default_converted_path(),
            report_dir: PathBuf::from("eval/reports"),
            k: 5,
            limit: Some(200),
            summary_sample: 5,
            full_context: false,
            chunk_min_chars: 500,
            chunk_max_chars: 2_000,
            // None means "use the retrieval crate's built-in caps".
            chunk_vector_take: None,
            chunk_fts_take: None,
            chunk_token_budget: None,
            chunk_avg_chars_per_token: None,
            max_chunks_per_entity: None,
            rerank: true,
            rerank_pool_size: 16,
            rerank_keep_top: 10,
            concurrency: 4,
            embedding_backend: EmbeddingBackend::FastEmbed,
            embedding_model: None,
            cache_dir: PathBuf::from("eval/cache"),
            ingestion_cache_dir: PathBuf::from("eval/cache/ingested"),
            refresh_embeddings_only: false,
            detailed_report: false,
            slice: None,
            reseed_slice: false,
            slice_seed: DEFAULT_SLICE_SEED,
            slice_grow: None,
            slice_offset: 0,
            slice_reset_ingestion: false,
            negative_multiplier: crate::slices::DEFAULT_NEGATIVE_MULTIPLIER,
            label: None,
            chunk_diagnostics_path: None,
            inspect_question: None,
            inspect_manifest: None,
            query_model: None,
            inspect_db_state: None,
            perf_log_json: None,
            perf_log_dir: None,
            perf_log_console: false,
            // Matches the local dev SurrealDB instance.
            db_endpoint: "ws://127.0.0.1:8000".to_string(),
            db_username: "root_user".to_string(),
            db_password: "root_password".to_string(),
            db_namespace: None,
            db_database: None,
        }
    }
}
impl Config {
    /// Optional token cap for assembled context.
    ///
    /// Currently always `None` (no cap). Kept as a method so call sites do
    /// not need to change if a real limit is introduced later.
    pub fn context_token_limit(&self) -> Option<usize> {
        None
    }
}
/// Result of CLI parsing: the resolved `Config`, plus whether `-h`/`--help`
/// was requested (in which case the caller should print help and exit).
#[derive(Debug)]
pub struct ParsedArgs {
    pub config: Config,
    pub show_help: bool,
}
/// Parses `std::env::args()` into a [`ParsedArgs`].
///
/// Processing order: flags are consumed left-to-right (later flags win),
/// then cross-flag validation runs, then `EVAL_*` environment variables
/// fill in values the flags did not set (DB settings override flag values;
/// `EVAL_PERF_LOG_DIR` only applies when `--perf-log-dir` was absent).
///
/// # Errors
/// Returns an error for unknown flags, missing or unparsable flag values,
/// and invalid combinations (e.g. `--chunk-min >= --chunk-max`).
pub fn parse() -> Result<ParsedArgs> {
    let mut config = Config::default();
    let mut show_help = false;
    // Track explicit path overrides so a later `--dataset` does not clobber
    // user-supplied `--raw`/`--converted` paths.
    let mut raw_overridden = false;
    let mut converted_overridden = false;
    let mut args = env::args().skip(1).peekable();
    while let Some(arg) = args.next() {
        match arg.as_str() {
            "-h" | "--help" => {
                show_help = true;
                break;
            }
            "--convert-only" => config.convert_only = true,
            "--force" | "--refresh" => config.force_convert = true,
            "--llm-mode" => {
                config.llm_mode = true;
            }
            "--dataset" => {
                let value = take_value("--dataset", &mut args)?;
                let parsed = value.parse::<DatasetKind>()?;
                config.dataset = parsed;
                // Re-derive dataset paths unless the user pinned them.
                if !raw_overridden {
                    config.raw_dataset_path = parsed.default_raw_path();
                }
                if !converted_overridden {
                    config.converted_dataset_path = parsed.default_converted_path();
                }
            }
            "--slice" => {
                let value = take_value("--slice", &mut args)?;
                config.slice = Some(value);
            }
            "--label" => {
                let value = take_value("--label", &mut args)?;
                config.label = Some(value);
            }
            "--query-model" => {
                let value = take_value("--query-model", &mut args)?;
                if value.trim().is_empty() {
                    return Err(anyhow!("--query-model requires a non-empty model name"));
                }
                config.query_model = Some(value.trim().to_string());
            }
            "--slice-grow" => {
                let value = take_value("--slice-grow", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --slice-grow value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--slice-grow must be greater than zero"));
                }
                config.slice_grow = Some(parsed);
            }
            "--slice-offset" => {
                let value = take_value("--slice-offset", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --slice-offset value '{value}' as usize")
                })?;
                config.slice_offset = parsed;
            }
            "--raw" => {
                let value = take_value("--raw", &mut args)?;
                config.raw_dataset_path = PathBuf::from(value);
                raw_overridden = true;
            }
            "--converted" => {
                let value = take_value("--converted", &mut args)?;
                config.converted_dataset_path = PathBuf::from(value);
                converted_overridden = true;
            }
            "--corpus-limit" => {
                let value = take_value("--corpus-limit", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --corpus-limit value '{value}' as usize")
                })?;
                // 0 means "no cap".
                config.corpus_limit = if parsed == 0 { None } else { Some(parsed) };
            }
            "--reseed-slice" => {
                config.reseed_slice = true;
            }
            "--slice-reset-ingestion" => {
                config.slice_reset_ingestion = true;
            }
            "--report-dir" => {
                let value = take_value("--report-dir", &mut args)?;
                config.report_dir = PathBuf::from(value);
            }
            "--k" => {
                let value = take_value("--k", &mut args)?;
                let parsed = value
                    .parse::<usize>()
                    .with_context(|| format!("failed to parse --k value '{value}' as usize"))?;
                if parsed == 0 {
                    return Err(anyhow!("--k must be greater than zero"));
                }
                config.k = parsed;
            }
            "--limit" => {
                let value = take_value("--limit", &mut args)?;
                let parsed = value
                    .parse::<usize>()
                    .with_context(|| format!("failed to parse --limit value '{value}' as usize"))?;
                // 0 means "evaluate everything".
                config.limit = if parsed == 0 { None } else { Some(parsed) };
            }
            "--sample" => {
                let value = take_value("--sample", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --sample value '{value}' as usize")
                })?;
                config.summary_sample = parsed.max(1);
            }
            "--full-context" => {
                config.full_context = true;
            }
            "--chunk-min" => {
                let value = take_value("--chunk-min", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-min value '{value}' as usize")
                })?;
                config.chunk_min_chars = parsed.max(1);
            }
            "--chunk-max" => {
                let value = take_value("--chunk-max", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-max value '{value}' as usize")
                })?;
                config.chunk_max_chars = parsed.max(1);
            }
            "--chunk-vector-take" => {
                let value = take_value("--chunk-vector-take", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-vector-take value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-vector-take must be greater than zero"));
                }
                config.chunk_vector_take = Some(parsed);
            }
            "--chunk-fts-take" => {
                let value = take_value("--chunk-fts-take", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-fts-take value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-fts-take must be greater than zero"));
                }
                config.chunk_fts_take = Some(parsed);
            }
            "--chunk-token-budget" => {
                let value = take_value("--chunk-token-budget", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-token-budget value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-token-budget must be greater than zero"));
                }
                config.chunk_token_budget = Some(parsed);
            }
            "--chunk-token-chars" => {
                let value = take_value("--chunk-token-chars", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-token-chars value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-token-chars must be greater than zero"));
                }
                config.chunk_avg_chars_per_token = Some(parsed);
            }
            "--max-chunks-per-entity" => {
                let value = take_value("--max-chunks-per-entity", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --max-chunks-per-entity value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--max-chunks-per-entity must be greater than zero"));
                }
                config.max_chunks_per_entity = Some(parsed);
            }
            "--embedding" => {
                let value = take_value("--embedding", &mut args)?;
                config.embedding_backend = value.parse()?;
            }
            "--embedding-model" => {
                let value = take_value("--embedding-model", &mut args)?;
                config.embedding_model = Some(value.trim().to_string());
            }
            "--cache-dir" => {
                let value = take_value("--cache-dir", &mut args)?;
                config.cache_dir = PathBuf::from(value);
            }
            "--ingestion-cache-dir" => {
                let value = take_value("--ingestion-cache-dir", &mut args)?;
                config.ingestion_cache_dir = PathBuf::from(value);
            }
            "--negative-multiplier" => {
                let value = take_value("--negative-multiplier", &mut args)?;
                let parsed = value.parse::<f32>().with_context(|| {
                    format!("failed to parse --negative-multiplier value '{value}' as f32")
                })?;
                if !(parsed.is_finite() && parsed > 0.0) {
                    return Err(anyhow!(
                        "--negative-multiplier must be a positive finite number"
                    ));
                }
                config.negative_multiplier = parsed;
            }
            "--no-rerank" => {
                config.rerank = false;
            }
            "--rerank-pool" => {
                let value = take_value("--rerank-pool", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --rerank-pool value '{value}' as usize")
                })?;
                config.rerank_pool_size = parsed.max(1);
            }
            "--rerank-keep" => {
                let value = take_value("--rerank-keep", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --rerank-keep value '{value}' as usize")
                })?;
                config.rerank_keep_top = parsed.max(1);
            }
            "--concurrency" => {
                let value = take_value("--concurrency", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --concurrency value '{value}' as usize")
                })?;
                config.concurrency = parsed.max(1);
            }
            "--refresh-embeddings" => {
                config.refresh_embeddings_only = true;
            }
            "--detailed-report" => {
                config.detailed_report = true;
            }
            "--chunk-diagnostics" => {
                let value = take_value("--chunk-diagnostics", &mut args)?;
                config.chunk_diagnostics_path = Some(PathBuf::from(value));
            }
            "--inspect-question" => {
                let value = take_value("--inspect-question", &mut args)?;
                config.inspect_question = Some(value);
            }
            "--inspect-manifest" => {
                let value = take_value("--inspect-manifest", &mut args)?;
                config.inspect_manifest = Some(PathBuf::from(value));
            }
            "--inspect-db-state" => {
                let value = take_value("--inspect-db-state", &mut args)?;
                config.inspect_db_state = Some(PathBuf::from(value));
            }
            "--perf-log-json" => {
                let value = take_value("--perf-log-json", &mut args)?;
                config.perf_log_json = Some(PathBuf::from(value));
            }
            "--perf-log-dir" => {
                let value = take_value("--perf-log-dir", &mut args)?;
                config.perf_log_dir = Some(PathBuf::from(value));
            }
            "--perf-log" => {
                config.perf_log_console = true;
            }
            "--db-endpoint" => {
                let value = take_value("--db-endpoint", &mut args)?;
                config.db_endpoint = value;
            }
            "--db-user" => {
                let value = take_value("--db-user", &mut args)?;
                config.db_username = value;
            }
            "--db-pass" => {
                let value = take_value("--db-pass", &mut args)?;
                config.db_password = value;
            }
            "--db-namespace" => {
                let value = take_value("--db-namespace", &mut args)?;
                config.db_namespace = Some(value);
            }
            "--db-database" => {
                let value = take_value("--db-database", &mut args)?;
                config.db_database = Some(value);
            }
            unknown => {
                return Err(anyhow!(
                    "unknown argument '{unknown}'. Use --help to see available options."
                ));
            }
        }
    }
    // Cross-flag validation.
    if config.chunk_min_chars >= config.chunk_max_chars {
        return Err(anyhow!(
            "--chunk-min must be less than --chunk-max (got {} >= {})",
            config.chunk_min_chars,
            config.chunk_max_chars
        ));
    }
    if config.rerank && config.rerank_pool_size == 0 {
        return Err(anyhow!(
            "--rerank-pool must be greater than zero when reranking is enabled"
        ));
    }
    if config.concurrency == 0 {
        return Err(anyhow!("--concurrency must be greater than zero"));
    }
    if config.embedding_backend == EmbeddingBackend::Hashed && config.embedding_model.is_some() {
        return Err(anyhow!(
            "--embedding-model cannot be used with the 'hashed' embedding backend"
        ));
    }
    // Derive a corpus limit from --limit when none was given: ~10x the
    // question limit, capped at 1000, never below the limit itself.
    if let Some(limit) = config.limit {
        if let Some(corpus_limit) = config.corpus_limit {
            if corpus_limit < limit {
                config.corpus_limit = Some(limit);
            }
        } else {
            let default_multiplier = 10usize;
            let mut computed = limit.saturating_mul(default_multiplier);
            if computed < limit {
                computed = limit;
            }
            let max_cap = 1_000usize;
            if computed > max_cap {
                computed = max_cap;
            }
            config.corpus_limit = Some(computed);
        }
    }
    // Environment-variable overrides (non-empty values only).
    if config.perf_log_dir.is_none() {
        if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") {
            if !dir.trim().is_empty() {
                config.perf_log_dir = Some(PathBuf::from(dir));
            }
        }
    }
    if let Ok(endpoint) = env::var("EVAL_DB_ENDPOINT") {
        if !endpoint.trim().is_empty() {
            config.db_endpoint = endpoint;
        }
    }
    if let Ok(username) = env::var("EVAL_DB_USERNAME") {
        if !username.trim().is_empty() {
            config.db_username = username;
        }
    }
    if let Ok(password) = env::var("EVAL_DB_PASSWORD") {
        if !password.trim().is_empty() {
            config.db_password = password;
        }
    }
    if let Ok(ns) = env::var("EVAL_DB_NAMESPACE") {
        if !ns.trim().is_empty() {
            config.db_namespace = Some(ns);
        }
    }
    if let Ok(db) = env::var("EVAL_DB_DATABASE") {
        if !db.trim().is_empty() {
            config.db_database = Some(db);
        }
    }
    Ok(ParsedArgs { config, show_help })
}
/// Consumes the next CLI argument as the value for `flag`, erroring when the
/// argument list is exhausted. The unused `'a` lifetime parameter from the
/// original signature has been removed (clippy: `extra_unused_lifetimes`).
///
/// NOTE(review): the next token is taken unconditionally, so a following
/// flag (e.g. `--label --k 5`) is swallowed as the value — confirm this is
/// acceptable for all string-valued flags before tightening it.
fn take_value<I>(flag: &str, iter: &mut std::iter::Peekable<I>) -> Result<String>
where
    I: Iterator<Item = String>,
{
    iter.next().ok_or_else(|| anyhow!("{flag} expects a value"))
}
/// Prints CLI usage, per-flag documentation, examples, and dataset defaults
/// to stdout. Keep this text in sync with the flags handled in `parse`.
pub fn print_help() {
    println!(
        "\
eval — dataset conversion, ingestion, and retrieval evaluation CLI
USAGE:
cargo eval -- [options]
# or
cargo run -p eval -- [options]
OPTIONS:
--convert-only Convert the selected dataset and exit.
--force, --refresh Regenerate the converted dataset even if it already exists.
--dataset <name> Dataset to evaluate: 'squad' (default) or 'natural-questions'.
--llm-mode Enable LLM-assisted evaluation features (includes unanswerable cases).
--slice <id|path> Use a cached dataset slice by id (under eval/cache/slices) or by explicit path.
--label <text> Annotate the run; label is stored in JSON/Markdown reports.
--query-model <name> Override the SurrealDB system settings query model (e.g., gpt-4o-mini) for this run.
--slice-grow <int> Grow the slice ledger to contain at least this many answerable cases, then exit.
--slice-offset <int> Evaluate questions starting at this offset within the slice (default: 0).
--reseed-slice Ignore cached corpus state and rebuild the slice's SurrealDB corpus.
--slice-reset-ingestion
Delete cached paragraph shards before rebuilding the ingestion corpus.
--corpus-limit <int> Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000.
--raw <path> Path to the raw dataset (defaults per dataset).
--converted <path> Path to write/read the converted dataset (defaults per dataset).
--report-dir <path> Directory to write evaluation reports (default: eval/reports).
--k <int> Precision@k cutoff (default: 5).
--limit <int> Limit the number of questions evaluated (default: 200, 0 = all).
--sample <int> Number of mismatches to surface in the Markdown summary (default: 5).
--full-context Disable context cropping when converting datasets (ingest entire documents).
--chunk-min <int> Minimum characters per chunk for text splitting (default: 500).
--chunk-max <int> Maximum characters per chunk for text splitting (default: 2000).
--chunk-vector-take <int>
Override chunk vector candidate cap (default: 20).
--chunk-fts-take <int>
Override chunk FTS candidate cap (default: 20).
--chunk-token-budget <int>
Override chunk token budget estimate for assembly (default: 10000).
--chunk-token-chars <int>
Override average characters per token used for budgeting (default: 4).
--max-chunks-per-entity <int>
Override maximum chunks attached per entity (default: 4).
--embedding <name> Embedding backend: 'fastembed' (default) or 'hashed'.
--embedding-model <code>
FastEmbed model code (defaults to crate preset when omitted).
--cache-dir <path> Directory for embedding caches (default: eval/cache).
--ingestion-cache-dir <path>
Directory for ingestion corpora caches (default: eval/cache/ingested).
--negative-multiplier <float>
Target negative-to-positive paragraph ratio for slice growth (default: 4.0).
--refresh-embeddings Recompute embeddings for cached corpora without re-running ingestion.
--detailed-report Include entity descriptions and categories in JSON reports.
--chunk-diagnostics <path>
Write per-query chunk diagnostics JSONL to the provided path.
--no-rerank Disable the FastEmbed reranking stage (enabled by default).
--rerank-pool <int> Reranking engine pool size / parallelism (default: 16).
--rerank-keep <int> Keep top-N entities after reranking (default: 10).
--inspect-question <id>
Inspect an ingestion cache question and exit (requires --inspect-manifest).
--inspect-manifest <path>
Path to an ingestion cache manifest JSON for inspection mode.
--inspect-db-state <path>
Optional override for the SurrealDB state.json used during inspection; defaults to the state recorded for the selected dataset slice.
--db-endpoint <url> SurrealDB server endpoint (use http:// or https:// to enable SurQL export/import; ws:// endpoints reuse existing namespaces but skip SurQL exports; default: ws://127.0.0.1:8000).
--db-user <value> SurrealDB root username (default: root_user).
--db-pass <value> SurrealDB root password (default: root_password).
--db-namespace <ns> Override the namespace used on the SurrealDB server; state.json tracks this value and the ledger case count so changing it or requesting more cases via --limit triggers a rebuild/import (default: derived from dataset).
--db-database <db> Override the database used on the SurrealDB server; recorded alongside namespace in state.json (default: derived from slice).
--perf-log Print per-stage performance timings to stdout after the run.
--perf-log-json <path>
Write structured performance telemetry JSON to the provided path.
--perf-log-dir <path>
Directory that receives timestamped perf JSON copies (defaults to $EVAL_PERF_LOG_DIR).
Examples:
cargo eval -- --dataset squad --limit 10 --detailed-report
cargo eval -- --dataset natural-questions --limit 1 --rerank-pool 1 --detailed-report
Notes:
The latest run's JSON/Markdown reports are saved as eval/reports/latest.json and latest.md, making it easy to script automated checks.
-h, --help Show this help text.
Dataset defaults (from eval/manifest.yaml):
squad raw: eval/data/raw/squad/dev-v2.0.json
converted: eval/data/converted/squad-minne.json
natural-questions raw: eval/data/raw/nq/dev-all.jsonl
converted: eval/data/converted/nq-dev-minne.json
"
    );
}
/// Creates the parent directory of `path` (and all ancestors) if it has one.
/// A path with no parent is a no-op success.
pub fn ensure_parent(path: &Path) -> Result<()> {
    match path.parent() {
        Some(parent) => std::fs::create_dir_all(parent)
            .with_context(|| format!("creating parent directory for {}", path.display())),
        None => Ok(()),
    }
}

88
eval/src/cache.rs Normal file
View File

@@ -0,0 +1,88 @@
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
/// On-disk JSON payload for the embedding cache.
#[derive(Debug, Default, Serialize, Deserialize)]
struct EmbeddingCacheData {
    // entity id -> embedding vector
    entities: HashMap<String, Vec<f32>>,
    // chunk id -> embedding vector
    chunks: HashMap<String, Vec<f32>>,
}
/// Cheaply-cloneable, async-safe embedding cache backed by a JSON file.
///
/// Clones share the same in-memory maps and dirty flag via `Arc`s; inserts
/// only touch memory until `persist` is called to flush to disk.
#[derive(Clone)]
pub struct EmbeddingCache {
    path: Arc<PathBuf>,
    data: Arc<Mutex<EmbeddingCacheData>>,
    // Set on insert, cleared after a successful persist.
    dirty: Arc<AtomicBool>,
}
#[allow(dead_code)]
impl EmbeddingCache {
    /// Loads the cache from `path`, starting empty when the file is absent.
    pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let data = if !path.exists() {
            EmbeddingCacheData::default()
        } else {
            let bytes = tokio::fs::read(&path)
                .await
                .with_context(|| format!("reading embedding cache {}", path.display()))?;
            serde_json::from_slice(&bytes)
                .with_context(|| format!("parsing embedding cache {}", path.display()))?
        };
        Ok(Self {
            path: Arc::new(path),
            data: Arc::new(Mutex::new(data)),
            dirty: Arc::new(AtomicBool::new(false)),
        })
    }

    /// Returns the cached entity embedding for `id`, if any.
    pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
        self.data.lock().await.entities.get(id).cloned()
    }

    /// Stores an entity embedding and marks the cache dirty.
    pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
        let mut state = self.data.lock().await;
        state.entities.insert(id, embedding);
        self.dirty.store(true, Ordering::Relaxed);
    }

    /// Returns the cached chunk embedding for `id`, if any.
    pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
        self.data.lock().await.chunks.get(id).cloned()
    }

    /// Stores a chunk embedding and marks the cache dirty.
    pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
        let mut state = self.data.lock().await;
        state.chunks.insert(id, embedding);
        self.dirty.store(true, Ordering::Relaxed);
    }

    /// Writes the cache to disk when dirty; otherwise a no-op.
    pub async fn persist(&self) -> Result<()> {
        if !self.dirty.load(Ordering::Relaxed) {
            return Ok(());
        }
        let state = self.data.lock().await;
        let body = serde_json::to_vec_pretty(&*state).context("serialising embedding cache")?;
        if let Some(parent) = self.path.parent() {
            tokio::fs::create_dir_all(parent)
                .await
                .with_context(|| format!("creating cache directory {}", parent.display()))?;
        }
        tokio::fs::write(&*self.path, body)
            .await
            .with_context(|| format!("writing embedding cache {}", self.path.display()))?;
        self.dirty.store(false, Ordering::Relaxed);
        Ok(())
    }
}

1003
eval/src/datasets.rs Normal file

File diff suppressed because it is too large Load Diff

269
eval/src/db_helpers.rs Normal file
View File

@@ -0,0 +1,269 @@
use anyhow::{Context, Result};
use common::storage::db::SurrealDbClient;
/// Drops and recreates both HNSW indexes with the given embedding
/// `dimension`. Used at startup when the configured embedding length
/// differs from the default system settings.
///
/// Fix: the response is now `.check()`ed so statement-level failures
/// surface as errors, matching the other index helpers in this module
/// (a bare `query` only reports transport-level errors).
pub async fn change_embedding_length_in_hnsw_indexes(
    db: &SurrealDbClient,
    dimension: usize,
) -> Result<()> {
    tracing::info!("Changing embedding length in HNSW indexes");
    let query = format!(
        "BEGIN TRANSACTION;
    REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
    REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
    DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
    DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
    COMMIT TRANSACTION;",
        dim = dimension
    );
    let res = db
        .client
        .query(query)
        .await
        .context("changing HNSW indexes")?;
    res.check()
        .context("one or more HNSW index statements failed")?;
    tracing::info!("HNSW indexes successfully changed");
    Ok(())
}
// Helper functions for index management during namespace reseed
/// Removes every index from every table ahead of a namespace reseed so no
/// stale index caches survive the rebuild.
///
/// Fix: the response is now `.check()`ed so statement-level failures
/// surface as errors, matching the create-index helpers in this module.
/// Every statement uses `IF EXISTS`, so this only fails on real problems.
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
    tracing::info!("Removing ALL indexes before namespace reseed (aggressive approach)");
    // Remove ALL indexes from ALL tables to ensure no cache access
    let res = db
        .client
        .query(
            "BEGIN TRANSACTION;
    -- HNSW indexes
    REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
    REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
    -- FTS indexes on text_content (remove ALL of them)
    REMOVE INDEX IF EXISTS text_content_fts_idx ON TABLE text_content;
    REMOVE INDEX IF EXISTS text_content_fts_text_idx ON TABLE text_content;
    REMOVE INDEX IF EXISTS text_content_fts_category_idx ON TABLE text_content;
    REMOVE INDEX IF EXISTS text_content_fts_context_idx ON TABLE text_content;
    REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON TABLE text_content;
    REMOVE INDEX IF EXISTS text_content_fts_url_idx ON TABLE text_content;
    REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON TABLE text_content;
    -- FTS indexes on knowledge_entity
    REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity;
    REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity;
    -- FTS indexes on text_chunk
    REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk;
    COMMIT TRANSACTION;",
        )
        .await
        .context("removing all indexes before namespace reseed")?;
    res.check()
        .context("one or more index removal statements failed")?;
    tracing::info!("All indexes removed before namespace reseed");
    Ok(())
}
/// Defines the shared English FTS analyzer (idempotent via `IF NOT EXISTS`)
/// that the search indexes reference.
async fn create_tokenizer(db: &SurrealDbClient) -> Result<()> {
    tracing::info!("Creating FTS analyzers for namespace reseed");
    db.client
        .query(
            "BEGIN TRANSACTION;
    DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
        TOKENIZERS class
        FILTERS lowercase, ascii, snowball(english);
    COMMIT TRANSACTION;",
        )
        .await
        .context("creating FTS analyzers for namespace reseed")?
        .check()
        .context("failed to create the tokenizer")?;
    Ok(())
}
/// Rebuilds the analyzer, FTS indexes, and HNSW indexes sequentially after
/// a namespace reseed, logging the duration of each phase.
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
    tracing::info!("Recreating ALL indexes after namespace reseed (SEQUENTIAL approach)");
    let started = std::time::Instant::now();

    create_tokenizer(db)
        .await
        .context("creating FTS analyzer")?;

    // For now we dont remove these plain indexes, we could if they prove negatively impacting performance
    // create_regular_indexes_for_snapshot(db)
    //     .await
    //     .context("creating regular indexes for namespace reseed")?;

    let fts_started = std::time::Instant::now();
    create_fts_indexes_for_snapshot(db)
        .await
        .context("creating FTS indexes for namespace reseed")?;
    tracing::info!(duration = ?fts_started.elapsed(), "FTS indexes created");

    let hnsw_started = std::time::Instant::now();
    create_hnsw_indexes_for_snapshot(db, dimension)
        .await
        .context("creating HNSW indexes for namespace reseed")?;
    tracing::info!(duration = ?hnsw_started.elapsed(), "HNSW indexes created");

    tracing::info!(duration = ?started.elapsed(), "All index groups recreated successfully in sequence");
    Ok(())
}
#[allow(dead_code)] // For now we dont do this. We could
/// Defines the plain (non-FTS, non-HNSW) lookup indexes; currently unused
/// by `recreate_indexes` but kept for future tuning.
async fn create_regular_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
    tracing::info!("Creating regular indexes for namespace reseed (parallel group 1)");
    db.client
        .query(
            "BEGIN TRANSACTION;
    DEFINE INDEX text_content_user_id_idx ON text_content FIELDS user_id;
    DEFINE INDEX text_content_created_at_idx ON text_content FIELDS created_at;
    DEFINE INDEX text_content_category_idx ON text_content FIELDS category;
    DEFINE INDEX text_chunk_source_id_idx ON text_chunk FIELDS source_id;
    DEFINE INDEX text_chunk_user_id_idx ON text_chunk FIELDS user_id;
    DEFINE INDEX knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
    DEFINE INDEX knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
    DEFINE INDEX knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
    DEFINE INDEX knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
    COMMIT TRANSACTION;",
        )
        .await
        .context("creating regular indexes for namespace reseed")?
        .check()
        .context("one of the regular indexes failed")?;
    tracing::info!("Regular indexes for namespace reseed created");
    Ok(())
}
/// Defines the BM25 full-text indexes used by hybrid retrieval.
async fn create_fts_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
    tracing::info!("Creating FTS indexes for namespace reseed (group 2)");
    let response = db
        .client
        .query(
            "BEGIN TRANSACTION;
    DEFINE INDEX text_content_fts_idx ON TABLE text_content FIELDS text;
    DEFINE INDEX knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name
        SEARCH ANALYZER app_en_fts_analyzer BM25;
    DEFINE INDEX knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description
        SEARCH ANALYZER app_en_fts_analyzer BM25;
    DEFINE INDEX text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk
        SEARCH ANALYZER app_en_fts_analyzer BM25;
    COMMIT TRANSACTION;",
        )
        .await
        .context("sending FTS index creation query")?;
    // `check` is what actually surfaces statement-level errors.
    response
        .check()
        .context("one or more FTS index statements failed")?;
    tracing::info!("FTS indexes for namespace reseed created");
    Ok(())
}
/// Defines the HNSW vector indexes with the supplied embedding `dimension`.
async fn create_hnsw_indexes_for_snapshot(db: &SurrealDbClient, dimension: usize) -> Result<()> {
    tracing::info!("Creating HNSW indexes for namespace reseed (group 3)");
    let query = format!(
        "BEGIN TRANSACTION;
    DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
    DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
    COMMIT TRANSACTION;",
        dim = dimension
    );
    db.client
        .query(query)
        .await
        .context("creating HNSW indexes for namespace reseed")?
        .check()
        .context("one or more HNSW index statements failed")?;
    tracing::info!("HNSW indexes for namespace reseed created");
    Ok(())
}
/// Drops and recreates a SurrealDB namespace/database pair, then selects it
/// on the client session.
///
/// Uses `REMOVE NAMESPACE IF EXISTS` so the first run against a fresh server
/// does not fail, and calls `Response::check` so statement-level errors are
/// surfaced instead of being silently swallowed (the sibling index helpers in
/// this file do the same; the original `query` call only reports transport
/// errors).
///
/// # Errors
/// Fails if the reset query cannot be sent, if any reset statement fails, or
/// if the namespace/database cannot be selected afterwards.
pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> {
    // NOTE: `namespace`/`database` are interpolated directly; callers are
    // expected to pass sanitized identifiers (see `sanitize_identifier`).
    let query = format!(
        "REMOVE NAMESPACE IF EXISTS {ns};
        DEFINE NAMESPACE {ns};
        DEFINE DATABASE {db};",
        ns = namespace,
        db = database
    );
    let res = db
        .client
        .query(query)
        .await
        .context("resetting SurrealDB namespace")?;
    res.check()
        .context("one of the namespace reset statements failed")?;
    db.client
        .use_ns(namespace)
        .use_db(database)
        .await
        .context("selecting namespace/database after reset")?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use uuid::Uuid;

    // Minimal row shape for reading back the seeded `foo` record.
    #[derive(Debug, Deserialize)]
    struct FooRow {
        label: String,
    }

    /// `reset_namespace` must remove previously stored rows: seed one row,
    /// reset, then confirm the table is empty (or gone entirely).
    #[tokio::test]
    async fn reset_namespace_drops_existing_rows() {
        // Unique names so concurrent test runs cannot collide.
        let namespace = format!("reset_ns_{}", Uuid::new_v4().simple());
        let database = format!("reset_db_{}", Uuid::new_v4().simple());
        let db = SurrealDbClient::memory(&namespace, &database)
            .await
            .expect("in-memory db");
        // Seed a single row so we can observe it disappearing.
        db.client
            .query(
                "DEFINE TABLE foo SCHEMALESS;
            CREATE foo:foo SET label = 'before';",
            )
            .await
            .expect("seed namespace")
            .check()
            .expect("seed response");
        // Sanity check: the seeded row is visible before the reset.
        let mut before = db
            .client
            .query("SELECT * FROM foo")
            .await
            .expect("select before reset");
        let existing: Vec<FooRow> = before.take(0).expect("rows before reset");
        assert_eq!(existing.len(), 1);
        assert_eq!(existing[0].label, "before");
        reset_namespace(&db, &namespace, &database)
            .await
            .expect("namespace reset");
        // After the reset the table may either exist and be empty, or the
        // query may fail because the table/namespace is gone — both outcomes
        // prove the old rows were dropped.
        match db.client.query("SELECT * FROM foo").await {
            Ok(mut response) => {
                let rows: Vec<FooRow> = response.take(0).unwrap_or_default();
                assert!(
                    rows.is_empty(),
                    "reset namespace should drop rows, found {:?}",
                    rows
                );
            }
            Err(error) => {
                let message = error.to_string();
                assert!(
                    message.to_ascii_lowercase().contains("table")
                        || message.to_ascii_lowercase().contains("namespace")
                        || message.to_ascii_lowercase().contains("foo"),
                    "unexpected error after namespace reset: {message}"
                );
            }
        }
    }
}

171
eval/src/embedding.rs Normal file
View File

@@ -0,0 +1,171 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
str::FromStr,
sync::Arc,
};
use anyhow::{anyhow, Context, Result};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tokio::sync::Mutex;
use crate::args::{Config, EmbeddingBackend};
/// Facade over the configured embedding backend (deterministic hashing or a
/// real FastEmbed model) used by the evaluation harness.
#[derive(Clone)]
pub struct EmbeddingProvider {
    // Backend selection plus backend-specific state; see `EmbeddingInner`.
    inner: EmbeddingInner,
}
/// Backend-specific state behind `EmbeddingProvider`.
#[derive(Clone)]
enum EmbeddingInner {
    // Deterministic bag-of-words hashing — cheap, offline, reproducible.
    Hashed {
        dimension: usize,
    },
    // Real FastEmbed model; the `Mutex` serialises embedding calls, and the
    // `Arc` lets cloned providers share one loaded model.
    FastEmbed {
        model: Arc<Mutex<TextEmbedding>>,
        model_name: EmbeddingModel,
        dimension: usize,
    },
}
impl EmbeddingProvider {
    /// Short identifier for the active backend, used in report metadata.
    pub fn backend_label(&self) -> &'static str {
        if matches!(self.inner, EmbeddingInner::FastEmbed { .. }) {
            "fastembed"
        } else {
            "hashed"
        }
    }

    /// Length of the vectors produced by this provider.
    pub fn dimension(&self) -> usize {
        match &self.inner {
            EmbeddingInner::Hashed { dimension }
            | EmbeddingInner::FastEmbed { dimension, .. } => *dimension,
        }
    }

    /// Model code when running FastEmbed; `None` for the hashed backend.
    pub fn model_code(&self) -> Option<String> {
        if let EmbeddingInner::FastEmbed { model_name, .. } = &self.inner {
            Some(model_name.to_string())
        } else {
            None
        }
    }

    /// Embeds a single text, returning one vector.
    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        match &self.inner {
            EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
            EmbeddingInner::FastEmbed { model, .. } => {
                let input = vec![text.to_owned()];
                // Lock serialises access to the (stateful) FastEmbed model.
                let mut embedder = model.lock().await;
                let mut vectors = embedder
                    .embed(input, None)
                    .context("generating fastembed vector")?;
                if vectors.is_empty() {
                    return Err(anyhow!("fastembed returned no embedding for input"));
                }
                Ok(vectors.remove(0))
            }
        }
    }

    /// Embeds many texts at once; output order matches input order.
    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        match &self.inner {
            EmbeddingInner::FastEmbed { model, .. } => {
                // Avoid taking the lock for a no-op batch.
                if texts.is_empty() {
                    return Ok(Vec::new());
                }
                let mut embedder = model.lock().await;
                embedder
                    .embed(texts, None)
                    .context("generating fastembed batch embeddings")
            }
            EmbeddingInner::Hashed { dimension } => {
                let dim = *dimension;
                Ok(texts
                    .iter()
                    .map(|text| hashed_embedding(text, dim))
                    .collect())
            }
        }
    }
}
/// Constructs the embedding provider selected by `config.embedding_backend`.
///
/// * `Hashed` — uses `default_dimension` (clamped to at least 1); requires no
///   model download.
/// * `FastEmbed` — resolves the model code from config (or the crate default)
///   and loads the model on a blocking thread; the output dimension comes
///   from the model's own metadata.
///
/// # Errors
/// Fails when the model code is unknown, model initialisation fails, or the
/// model's metadata is missing.
pub async fn build_provider(
    config: &Config,
    default_dimension: usize,
) -> Result<EmbeddingProvider> {
    match config.embedding_backend {
        EmbeddingBackend::Hashed => Ok(EmbeddingProvider {
            inner: EmbeddingInner::Hashed {
                // Guard against a zero dimension so bucketing never hits `% 0`.
                dimension: default_dimension.max(1),
            },
        }),
        EmbeddingBackend::FastEmbed => {
            let model_name = if let Some(code) = config.embedding_model.as_deref() {
                EmbeddingModel::from_str(code).map_err(|err| anyhow!(err))?
            } else {
                EmbeddingModel::default()
            };
            let options =
                TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
            let model_name_for_task = model_name.clone();
            let model_name_code = model_name.to_string();
            // Initialisation downloads weights and is CPU-heavy, so run it off
            // the async executor.
            let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
                let model =
                    TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
                let info =
                    EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
                        anyhow!("FastEmbed model metadata missing for {model_name_code}")
                    })?;
                Ok((model, info.dim))
            })
            .await
            // Outer `?` is the join error, inner `?` the initialisation result.
            .context("joining FastEmbed initialisation task")??;
            Ok(EmbeddingProvider {
                inner: EmbeddingInner::FastEmbed {
                    model: Arc::new(Mutex::new(model)),
                    model_name,
                    dimension,
                },
            })
        }
    }
}
/// Deterministic bag-of-words embedding used by the `hashed` backend.
///
/// Tokenizes `text` into lowercase ASCII-alphanumeric runs, counts each token
/// into bucket `hash(token) % dimension`, then L2-normalizes the result.
/// Inputs producing no tokens (empty or punctuation-only text) yield an
/// all-zero vector.
///
/// The previous version kept a `token_count` accumulator and two early
/// returns (`text.is_empty()` and `token_count == 0`) that were redundant:
/// both cases already fall through to the `norm > 0.0` guard and return the
/// zero vector unchanged.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
    // A zero dimension would make the modulo in `bucket` panic; clamp to 1.
    let dim = dimension.max(1);
    let mut vector = vec![0.0f32; dim];
    for token in tokens(text) {
        vector[bucket(&token, dim)] += 1.0;
    }
    // L2-normalize so dot products behave like cosine similarity; a zero
    // vector (no tokens) is returned as-is.
    let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        for value in &mut vector {
            *value /= norm;
        }
    }
    vector
}

/// Splits `text` into lowercase tokens of consecutive ASCII alphanumerics.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| token.to_ascii_lowercase())
}

/// Hashes a token into a bucket index in `0..dimension`.
fn bucket(token: &str, dimension: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    (hasher.finish() as usize) % dimension
}

767
eval/src/eval/mod.rs Normal file
View File

@@ -0,0 +1,767 @@
mod pipeline;
pub use pipeline::run_evaluation;
use std::{
collections::{HashMap, HashSet},
path::Path,
time::Duration,
};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, SecondsFormat, Utc};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
},
};
use composite_retrieval::pipeline as retrieval_pipeline;
use composite_retrieval::pipeline::PipelineStageTimings;
use composite_retrieval::pipeline::RetrievalTuning;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
use crate::{
args::{self, Config},
datasets::{self, ConvertedDataset},
db_helpers::change_embedding_length_in_hnsw_indexes,
ingest,
slice::{self},
snapshot::{self, DbSnapshotState},
};
/// Full result of one evaluation run; serialised as the JSON report payload.
/// Field order is the serialisation order — do not reorder.
#[derive(Debug, Serialize)]
pub struct EvaluationSummary {
    pub generated_at: DateTime<Utc>,
    // Retrieval depth used when judging a case as correct (precision@k).
    pub k: usize,
    pub limit: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_label: Option<String>,
    // Aggregate accuracy counters.
    pub total_cases: usize,
    pub correct: usize,
    pub precision: f64,
    pub correct_at_1: usize,
    pub correct_at_2: usize,
    pub correct_at_3: usize,
    pub precision_at_1: f64,
    pub precision_at_2: f64,
    pub precision_at_3: f64,
    pub duration_ms: u128,
    // Dataset provenance.
    pub dataset_id: String,
    pub dataset_label: String,
    pub dataset_includes_unanswerable: bool,
    pub dataset_source: String,
    // Slice (deterministic subset) provenance.
    pub slice_id: String,
    pub slice_seed: u64,
    pub slice_total_cases: usize,
    pub slice_window_offset: usize,
    pub slice_window_length: usize,
    pub slice_cases: usize,
    pub slice_positive_paragraphs: usize,
    pub slice_negative_paragraphs: usize,
    pub slice_total_paragraphs: usize,
    pub slice_negative_multiplier: f32,
    // Ingestion / cache reuse bookkeeping.
    pub namespace_reused: bool,
    pub corpus_paragraphs: usize,
    pub ingestion_cache_path: String,
    pub ingestion_reused: bool,
    pub ingestion_embeddings_reused: bool,
    pub ingestion_fingerprint: String,
    pub positive_paragraphs_reused: usize,
    pub negative_paragraphs_reused: usize,
    // Performance measurements.
    pub latency_ms: LatencyStats,
    pub perf: PerformanceTimings,
    // Embedding / reranking configuration used for the run.
    pub embedding_backend: String,
    pub embedding_model: Option<String>,
    pub embedding_dimension: usize,
    pub rerank_enabled: bool,
    pub rerank_pool_size: Option<usize>,
    pub rerank_keep_top: usize,
    pub concurrency: usize,
    pub detailed_report: bool,
    // Chunk-retrieval tuning knobs in effect.
    pub chunk_vector_take: usize,
    pub chunk_fts_take: usize,
    pub chunk_token_budget: usize,
    pub chunk_avg_chars_per_token: usize,
    pub max_chunks_per_entity: usize,
    // Per-question outcomes.
    pub cases: Vec<CaseSummary>,
}
/// Outcome of a single evaluation question.
#[derive(Debug, Serialize)]
pub struct CaseSummary {
    pub question_id: String,
    pub question: String,
    pub paragraph_id: String,
    pub paragraph_title: String,
    // Source id the retrieval is expected to surface.
    pub expected_source: String,
    pub answers: Vec<String>,
    // Overall hit, plus the individual criteria that contributed to it.
    pub matched: bool,
    pub entity_match: bool,
    pub chunk_text_match: bool,
    pub chunk_id_match: bool,
    // 1-based rank of the first match, when any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub match_rank: Option<usize>,
    pub latency_ms: u128,
    pub retrieved: Vec<RetrievedSummary>,
}
/// Average / median / 95th-percentile latency, in milliseconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
    pub avg: f64,
    pub p50: u128,
    pub p95: u128,
}
/// Latency statistics per retrieval-pipeline stage (mirrors
/// `PipelineStageTimings`).
#[derive(Debug, Clone, Serialize)]
pub struct StageLatencyBreakdown {
    pub collect_candidates: LatencyStats,
    pub graph_expansion: LatencyStats,
    pub chunk_attach: LatencyStats,
    pub rerank: LatencyStats,
    pub assemble: LatencyStats,
}
/// Wall-clock milliseconds spent in each stage of the evaluation state
/// machine (see `eval::pipeline`).
#[derive(Debug, Default, Clone, Serialize)]
pub struct EvaluationStageTimings {
    pub prepare_slice_ms: u128,
    pub prepare_db_ms: u128,
    pub prepare_corpus_ms: u128,
    pub prepare_namespace_ms: u128,
    pub run_queries_ms: u128,
    pub summarize_ms: u128,
    pub finalize_ms: u128,
}
/// Performance section of the report: where the run's time went.
#[derive(Debug, Serialize)]
pub struct PerformanceTimings {
    pub openai_base_url: String,
    pub ingestion_ms: u128,
    // Present only when the namespace had to be (re)seeded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace_seed_ms: Option<u128>,
    pub evaluation_stage_ms: EvaluationStageTimings,
    pub stage_latency: StageLatencyBreakdown,
}
/// One retrieved entity in a case's ranked result list.
#[derive(Debug, Serialize)]
pub struct RetrievedSummary {
    // 1-based position in the result list.
    pub rank: usize,
    pub entity_id: String,
    pub source_id: String,
    pub entity_name: String,
    pub score: f32,
    pub matched: bool,
    // Optional detail fields, emitted only for detailed reports.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_category: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chunk_text_match: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chunk_id_match: Option<bool>,
}
/// Per-case failure-analysis record written to the chunk-diagnostics JSONL
/// file (see `write_chunk_diagnostics`).
#[derive(Debug, Serialize)]
pub(crate) struct CaseDiagnostics {
    question_id: String,
    question: String,
    paragraph_id: String,
    paragraph_title: String,
    expected_source: String,
    expected_chunk_ids: Vec<String>,
    answers: Vec<String>,
    // Which of the match criteria succeeded.
    entity_match: bool,
    chunk_text_match: bool,
    chunk_id_match: bool,
    // Machine-readable reasons for a miss (e.g. "entity_miss").
    failure_reasons: Vec<String>,
    missing_expected_chunk_ids: Vec<String>,
    attached_chunk_ids: Vec<String>,
    retrieved: Vec<EntityDiagnostics>,
    // Stage-level pipeline diagnostics, when collected.
    #[serde(skip_serializing_if = "Option::is_none")]
    pipeline: Option<retrieval_pipeline::PipelineDiagnostics>,
}
/// Diagnostics for one retrieved entity, including its attached chunks.
#[derive(Debug, Serialize)]
struct EntityDiagnostics {
    // 1-based rank in the retrieval result.
    rank: usize,
    entity_id: String,
    source_id: String,
    name: String,
    score: f32,
    entity_match: bool,
    chunk_text_match: bool,
    chunk_id_match: bool,
    chunks: Vec<ChunkDiagnosticsEntry>,
}
/// Diagnostics for one chunk attached to a retrieved entity.
#[derive(Debug, Serialize)]
struct ChunkDiagnosticsEntry {
    chunk_id: String,
    score: f32,
    // Whether the chunk text contains any expected answer string.
    contains_answer: bool,
    // Whether the chunk id is one of the case's expected chunk ids.
    expected_chunk: bool,
    // First ~200 characters of the chunk, newlines flattened.
    snippet: String,
}
/// A single answerable evaluation question, derived from the corpus manifest
/// by `cases_from_manifest`.
pub(crate) struct SeededCase {
    question_id: String,
    question: String,
    // `text_content` id the retrieval must surface for a hit.
    expected_source: String,
    answers: Vec<String>,
    paragraph_id: String,
    paragraph_title: String,
    expected_chunk_ids: Vec<String>,
}
/// Builds the seeded evaluation cases from a corpus manifest, skipping any
/// question flagged as unanswerable.
pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<SeededCase> {
    // paragraph_id -> title lookup used to label each case.
    let title_map: HashMap<&str, &String> = manifest
        .paragraphs
        .iter()
        .map(|paragraph| (paragraph.paragraph_id.as_str(), &paragraph.title))
        .collect();
    let mut cases = Vec::new();
    for question in &manifest.questions {
        if question.is_impossible {
            continue;
        }
        let paragraph_title = match title_map.get(question.paragraph_id.as_str()) {
            Some(title) => (*title).clone(),
            None => "Untitled".to_string(),
        };
        cases.push(SeededCase {
            question_id: question.question_id.clone(),
            question: question.question_text.clone(),
            expected_source: question.text_content_id.clone(),
            answers: question.answers.clone(),
            paragraph_id: question.paragraph_id.clone(),
            paragraph_title,
            expected_chunk_ids: question.matching_chunk_ids.clone(),
        });
    }
    cases
}
/// Returns `true` when `text` contains any of the expected `answers`
/// (ASCII-case-insensitively), or when there are no answers to check.
///
/// Both the haystack and each needle are lowercased, so callers no longer
/// have to pre-lowercase the answer strings; previously a mixed-case answer
/// could never match because only the haystack was lowercased.
/// Already-lowercase input behaves exactly as before.
pub(crate) fn text_contains_answer(text: &str, answers: &[String]) -> bool {
    if answers.is_empty() {
        // No expected answers means the check is vacuously satisfied.
        return true;
    }
    let haystack = text.to_ascii_lowercase();
    answers
        .iter()
        .any(|needle| haystack.contains(needle.to_ascii_lowercase().as_str()))
}
/// Computes average, p50, and p95 over a set of latency samples
/// (milliseconds). Empty input yields all-zero statistics.
pub(crate) fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
    let mut sorted = latencies.to_vec();
    sorted.sort_unstable();
    if sorted.is_empty() {
        return LatencyStats {
            avg: 0.0,
            p50: 0,
            p95: 0,
        };
    }
    let total: u128 = sorted.iter().copied().sum();
    LatencyStats {
        avg: total as f64 / sorted.len() as f64,
        p50: percentile(&sorted, 0.50),
        p95: percentile(&sorted, 0.95),
    }
}
pub(crate) fn build_stage_latency_breakdown(
samples: &[PipelineStageTimings],
) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
where
F: Fn(&PipelineStageTimings) -> u128,
{
samples.iter().map(selector).collect()
}
StageLatencyBreakdown {
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
}
}
/// Nearest-rank percentile over an already-sorted slice; `fraction` is
/// clamped to `[0, 1]` and an empty slice yields 0.
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
    if sorted.is_empty() {
        return 0;
    }
    let span = sorted.len() as f64 - 1.0;
    let idx = (fraction.clamp(0.0, 1.0) * span).round() as usize;
    sorted[idx.min(sorted.len() - 1)]
}
/// Resolves (and thereby grows, if needed) the deterministic slice ledger for
/// `dataset`, then reports the resulting slice composition on the log and on
/// stdout. No evaluation is run.
pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
    // Target size combines `--slice-grow` and `--limit` (see `ledger_target`).
    let ledger_limit = ledger_target(config);
    let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
    let slice =
        slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
    info!(
        slice = slice.manifest.slice_id.as_str(),
        cases = slice.manifest.case_count,
        positives = slice.manifest.positive_paragraphs,
        negatives = slice.manifest.negative_paragraphs,
        total_paragraphs = slice.manifest.total_paragraphs,
        "Slice ledger ready"
    );
    // Human-readable confirmation for the CLI user.
    println!(
        "Slice `{}` now contains {} questions ({} positives, {} negatives)",
        slice.manifest.slice_id,
        slice.manifest.case_count,
        slice.manifest.positive_paragraphs,
        slice.manifest.negative_paragraphs
    );
    Ok(())
}
/// Combines `--slice-grow` and `--limit` into the target ledger size:
/// the larger of the two when both are set, otherwise whichever is present.
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
    let limit = config.limit;
    match config.slice_grow {
        Some(grow) => Some(limit.map_or(grow, |value| value.max(grow))),
        None => limit,
    }
}
/// Raises retrieval-tuning floors for long-form datasets (currently any
/// dataset whose id contains "natural-questions").
///
/// Each floor is applied only when the user did not explicitly override the
/// corresponding knob on the CLI, and values are only ever raised
/// (`max`), never lowered.
pub(crate) fn apply_dataset_tuning_overrides(
    dataset: &ConvertedDataset,
    config: &Config,
    tuning: &mut RetrievalTuning,
) {
    let is_long_form = dataset
        .metadata
        .id
        .to_ascii_lowercase()
        .contains("natural-questions");
    if !is_long_form {
        return;
    }
    if config.chunk_vector_take.is_none() {
        tuning.chunk_vector_take = tuning.chunk_vector_take.max(80);
    }
    if config.chunk_fts_take.is_none() {
        tuning.chunk_fts_take = tuning.chunk_fts_take.max(80);
    }
    if config.chunk_token_budget.is_none() {
        tuning.token_budget_estimate = tuning.token_budget_estimate.max(20_000);
    }
    if config.max_chunks_per_entity.is_none() {
        tuning.max_chunks_per_entity = tuning.max_chunks_per_entity.max(12);
    }
    // Long-form answers benefit from extra lexical weight; bump only when
    // the configured weight is below the floor.
    if tuning.lexical_match_weight < 0.25 {
        tuning.lexical_match_weight = 0.3;
    }
}
/// Assembles the detailed diagnostics record for one evaluation case from the
/// already-computed `summary`, the retrieved entities (with their attached
/// chunks), and optional pipeline-stage diagnostics.
///
/// `answers_lower` must be the pre-lowercased answer strings, matching what
/// `text_contains_answer` expects.
pub(crate) fn build_case_diagnostics(
    summary: &CaseSummary,
    expected_chunk_ids: &[String],
    answers_lower: &[String],
    entities: &[composite_retrieval::RetrievedEntity],
    pipeline_stats: Option<retrieval_pipeline::PipelineDiagnostics>,
) -> CaseDiagnostics {
    let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
    // Tracks every chunk id seen, to compute which expected chunks are absent.
    let mut seen_chunks: HashSet<String> = HashSet::new();
    let mut attached_chunk_ids = Vec::new();
    let mut entity_diagnostics = Vec::new();
    for (idx, entity) in entities.iter().enumerate() {
        let mut chunk_entries = Vec::new();
        for chunk in &entity.chunks {
            let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
            let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
            seen_chunks.insert(chunk.chunk.id.clone());
            attached_chunk_ids.push(chunk.chunk.id.clone());
            chunk_entries.push(ChunkDiagnosticsEntry {
                chunk_id: chunk.chunk.id.clone(),
                score: chunk.score,
                contains_answer,
                expected_chunk,
                snippet: chunk_preview(&chunk.chunk.chunk),
            });
        }
        entity_diagnostics.push(EntityDiagnostics {
            rank: idx + 1,
            entity_id: entity.entity.id.clone(),
            source_id: entity.entity.source_id.clone(),
            name: entity.entity.name.clone(),
            score: entity.score,
            // Per-entity match flags, independent of the case-level verdict.
            entity_match: entity.entity.source_id == summary.expected_source,
            chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
            chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
            chunks: chunk_entries,
        });
    }
    // Expected chunks that never appeared in any retrieved entity.
    let missing_expected_chunk_ids = expected_chunk_ids
        .iter()
        .filter(|id| !seen_chunks.contains(id.as_str()))
        .cloned()
        .collect::<Vec<_>>();
    // Machine-readable failure tags; empty for a full hit.
    let mut failure_reasons = Vec::new();
    if !summary.entity_match {
        failure_reasons.push("entity_miss".to_string());
    }
    if !summary.chunk_text_match {
        failure_reasons.push("chunk_text_missing".to_string());
    }
    if !summary.chunk_id_match {
        failure_reasons.push("chunk_id_missing".to_string());
    }
    if !missing_expected_chunk_ids.is_empty() {
        failure_reasons.push("expected_chunk_absent".to_string());
    }
    CaseDiagnostics {
        question_id: summary.question_id.clone(),
        question: summary.question.clone(),
        paragraph_id: summary.paragraph_id.clone(),
        paragraph_title: summary.paragraph_title.clone(),
        expected_source: summary.expected_source.clone(),
        expected_chunk_ids: expected_chunk_ids.to_vec(),
        answers: summary.answers.clone(),
        entity_match: summary.entity_match,
        chunk_text_match: summary.chunk_text_match,
        chunk_id_match: summary.chunk_id_match,
        failure_reasons,
        missing_expected_chunk_ids,
        attached_chunk_ids,
        retrieved: entity_diagnostics,
        pipeline: pipeline_stats,
    }
}
/// Returns the first 200 characters of `text` with newlines flattened to
/// spaces, for use as a log/report snippet.
fn chunk_preview(text: &str) -> String {
    text.chars()
        .take(200)
        .map(|ch| if ch == '\n' { ' ' } else { ch })
        .collect()
}
/// Writes per-case diagnostics to `path` as JSON Lines (one JSON object per
/// line), creating parent directories as needed.
///
/// # Errors
/// Fails on serialisation errors or any file-system/write error.
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
    args::ensure_parent(path)?;
    let mut file = tokio::fs::File::create(path)
        .await
        .with_context(|| format!("creating diagnostics file {}", path.display()))?;
    for case in cases {
        let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
        file.write_all(&line).await?;
        file.write_all(b"\n").await?;
    }
    // Explicit flush so write errors surface here rather than being lost on drop.
    file.flush().await?;
    Ok(())
}
/// Runs one KNN (`<|1,1|>`) query against each HNSW-indexed table so the
/// indexes are exercised before timing starts.
///
/// The query results are discarded; only transport/query errors are
/// propagated. NOTE(review): assumes SurrealDB loads/warms the HNSW
/// structures on first KNN use — confirm against the server version in use.
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
    // Deterministic dummy vector of the right dimension for the KNN probe.
    let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
    info!("Warming HNSW caches with sample queries");
    // Warm up chunk index
    let _ = db
        .client
        .query("SELECT * FROM text_chunk WHERE embedding <|1,1|> $embedding LIMIT 5")
        .bind(("embedding", dummy_embedding.clone()))
        .await
        .context("warming text chunk HNSW cache")?;
    // Warm up entity index
    let _ = db
        .client
        .query("SELECT * FROM knowledge_entity WHERE embedding <|1,1|> $embedding LIMIT 5")
        .bind(("embedding", dummy_embedding))
        .await
        .context("warming knowledge entity HNSW cache")?;
    info!("HNSW cache warming completed");
    Ok(())
}
/// Returns the fixed evaluation user (id `eval-user`), creating it with
/// placeholder credentials when it does not exist yet.
///
/// The user record is deterministic (timestamps come from
/// `datasets::base_timestamp`) so repeated runs reuse the same identity.
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
    let timestamp = datasets::base_timestamp();
    let user = User {
        id: "eval-user".to_string(),
        created_at: timestamp,
        updated_at: timestamp,
        email: "eval-retrieval@minne.dev".to_string(),
        // Placeholder — this account is never logged into.
        password: "not-used".to_string(),
        anonymous: false,
        api_key: None,
        admin: false,
        timezone: "UTC".to_string(),
    };
    // Prefer any existing record over re-creating it.
    if let Some(existing) = db.get_item::<User>(&user.id).await? {
        return Ok(existing);
    }
    db.store_item(user.clone())
        .await
        .context("storing evaluation user")?;
    Ok(user)
}
/// Formats a timestamp as RFC 3339 with whole-second precision and a `Z`
/// (UTC) suffix, e.g. `2025-01-01T12:00:00Z`.
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
    timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
}
/// Normalises a model code for use in file names / identifiers: ASCII
/// alphanumerics are lowercased, every other character becomes `_`.
pub(crate) fn sanitize_model_code(code: &str) -> String {
    let mut sanitized = String::with_capacity(code.len());
    for ch in code.chars() {
        if ch.is_ascii_alphanumeric() {
            sanitized.push(ch.to_ascii_lowercase());
        } else {
            sanitized.push('_');
        }
    }
    sanitized
}
/// Connects to SurrealDB, trying root-level credentials first and falling
/// back to namespace-level authentication with the same username/password.
///
/// # Errors
/// Fails only when both authentication modes fail; the error message includes
/// both underlying failures.
pub(crate) async fn connect_eval_db(
    config: &Config,
    namespace: &str,
    database: &str,
) -> Result<SurrealDbClient> {
    match SurrealDbClient::new(
        &config.db_endpoint,
        &config.db_username,
        &config.db_password,
        namespace,
        database,
    )
    .await
    {
        Ok(client) => {
            info!(
                endpoint = %config.db_endpoint,
                namespace,
                database,
                auth = "root",
                "Connected to SurrealDB"
            );
            Ok(client)
        }
        Err(root_err) => {
            // Root auth can legitimately fail on restricted servers; retry as
            // a namespace-scoped user before giving up.
            info!(
                endpoint = %config.db_endpoint,
                namespace,
                database,
                "Root authentication failed; trying namespace-level auth"
            );
            let namespace_client = SurrealDbClient::new_with_namespace_user(
                &config.db_endpoint,
                namespace,
                &config.db_username,
                &config.db_password,
                database,
            )
            .await
            .map_err(|ns_err| {
                anyhow!(
                    "failed to connect to SurrealDB via root ({root_err}) or namespace ({ns_err}) credentials"
                )
            })?;
            info!(
                endpoint = %config.db_endpoint,
                namespace,
                database,
                auth = "namespace",
                "Connected to SurrealDB"
            );
            Ok(namespace_client)
        }
    }
}
/// Returns `true` when the selected namespace already contains at least one
/// `text_chunk` row, i.e. a corpus has been ingested.
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
    // Local row shape for the count query result.
    #[derive(Deserialize)]
    struct CountRow {
        count: i64,
    }
    let mut response = db
        .client
        .query("SELECT count() AS count FROM text_chunk")
        .await
        .context("checking namespace corpus state")?;
    // A missing/undeserialisable result set is treated as "no rows" rather
    // than an error (e.g. the table does not exist yet).
    let rows: Vec<CountRow> = response.take(0).unwrap_or_default();
    Ok(rows.first().map(|row| row.count).unwrap_or(0) > 0)
}
/// Decides whether the live namespace can be reused for this run instead of
/// reseeding from the ingestion cache.
///
/// Reuse requires, in order: a recorded namespace state; a stored case count
/// at least as large as the requested one; matching dataset/slice/fingerprint
/// and namespace/database identifiers; and non-empty corpus tables. Each
/// rejection is logged with its reason.
pub(crate) async fn can_reuse_namespace(
    db: &SurrealDbClient,
    descriptor: &snapshot::Descriptor,
    namespace: &str,
    database: &str,
    dataset_id: &str,
    slice_id: &str,
    ingestion_fingerprint: &str,
    slice_case_count: usize,
) -> Result<bool> {
    let state = match descriptor.load_db_state().await? {
        Some(state) => state,
        None => {
            info!("No namespace state recorded; reseeding corpus from cached shards");
            return Ok(false);
        }
    };
    // The ledger can only grow; a smaller stored count means the namespace
    // predates the current slice and is stale.
    if state.slice_case_count < slice_case_count {
        info!(
            requested_cases = slice_case_count,
            stored_cases = state.slice_case_count,
            "Skipping live namespace reuse; ledger grew beyond cached state"
        );
        return Ok(false);
    }
    if state.dataset_id != dataset_id
        || state.slice_id != slice_id
        || state.ingestion_fingerprint != ingestion_fingerprint
        || state.namespace.as_deref() != Some(namespace)
        || state.database.as_deref() != Some(database)
    {
        info!(
            namespace,
            database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
        );
        return Ok(false);
    }
    // Metadata matches, but verify the data is actually present.
    if namespace_has_corpus(db).await? {
        Ok(true)
    } else {
        info!(
            namespace,
            database,
            "Namespace metadata matches but tables are empty; reseeding from ingestion cache"
        );
        Ok(false)
    }
}
/// Normalises `input` into a safe SurrealDB identifier: ASCII alphanumerics
/// lowercased, everything else replaced with `_`, never empty (falls back to
/// `"x"`), and capped at 64 bytes.
fn sanitize_identifier(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len().min(64));
    for ch in input.chars() {
        if ch.is_ascii_alphanumeric() {
            cleaned.push(ch.to_ascii_lowercase());
        } else {
            cleaned.push('_');
        }
    }
    if cleaned.is_empty() {
        cleaned.push('x');
    }
    // All pushed characters are ASCII, so a byte-based truncate is safe.
    cleaned.truncate(64);
    cleaned
}

/// Derives the default namespace name for a dataset/limit combination,
/// e.g. `eval_squad_limit25` or `eval_squad_all`.
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
    let suffix = match limit {
        Some(value) if value > 0 => format!("limit{value}"),
        _ => String::from("all"),
    };
    format!("eval_{}_{}", sanitize_identifier(dataset_id), suffix)
}

/// Default database name used inside evaluation namespaces.
pub(crate) fn default_database() -> String {
    String::from("retrieval_eval")
}
/// Persists the current namespace's provenance into the snapshot descriptor
/// so later runs can decide whether to reuse it (see `can_reuse_namespace`).
///
/// Best-effort: a storage failure is logged as a warning, not propagated.
pub(crate) async fn record_namespace_state(
    descriptor: &snapshot::Descriptor,
    dataset_id: &str,
    slice_id: &str,
    ingestion_fingerprint: &str,
    namespace: &str,
    database: &str,
    slice_case_count: usize,
) {
    let state = DbSnapshotState {
        dataset_id: dataset_id.to_string(),
        slice_id: slice_id.to_string(),
        ingestion_fingerprint: ingestion_fingerprint.to_string(),
        snapshot_hash: descriptor.metadata_hash().to_string(),
        updated_at: Utc::now(),
        namespace: Some(namespace.to_string()),
        database: Some(database.to_string()),
        slice_case_count,
    };
    if let Err(err) = descriptor.store_db_state(&state).await {
        warn!(error = %err, "Failed to record namespace state");
    }
}
/// Brings the stored `SystemSettings` in line with this run's configuration:
/// syncs the embedding dimension to the active provider and applies an
/// optional query-model override, persisting only when something changed.
///
/// When the embedding dimension changed, the HNSW indexes are redefined to
/// the new dimension as well.
pub(crate) async fn enforce_system_settings(
    db: &SurrealDbClient,
    mut settings: SystemSettings,
    provider_dimension: usize,
    config: &Config,
) -> Result<SystemSettings> {
    let mut updated_settings = settings.clone();
    let mut needs_settings_update = false;
    let mut embedding_dimension_changed = false;
    if provider_dimension != settings.embedding_dimensions as usize {
        updated_settings.embedding_dimensions = provider_dimension as u32;
        needs_settings_update = true;
        embedding_dimension_changed = true;
    }
    if let Some(query_override) = config.query_model.as_deref() {
        if settings.query_model != query_override {
            info!(
                model = query_override,
                "Overriding system query model for this run"
            );
            updated_settings.query_model = query_override.to_string();
            needs_settings_update = true;
        }
    }
    if needs_settings_update {
        settings = SystemSettings::update(db, updated_settings)
            .await
            .context("updating system settings overrides")?;
    }
    // Existing HNSW indexes carry the old dimension; redefine them so vector
    // queries match the provider's output size.
    if embedding_dimension_changed {
        change_embedding_length_in_hnsw_indexes(db, provider_dimension)
            .await
            .context("redefining HNSW indexes for new embedding dimension")?;
    }
    Ok(settings)
}
/// Loads the system settings, bootstrapping a fresh namespace when they are
/// missing by running the database migrations first.
///
/// Returns the settings and a flag indicating whether migrations were applied
/// (`true` means the namespace was just initialised).
pub(crate) async fn load_or_init_system_settings(
    db: &SurrealDbClient,
) -> Result<(SystemSettings, bool)> {
    match SystemSettings::get_current(db).await {
        Ok(settings) => Ok((settings, false)),
        Err(AppError::NotFound(_)) => {
            info!("System settings missing; applying database migrations for namespace");
            db.apply_migrations()
                .await
                .context("applying database migrations after missing system settings")?;
            // Brief pause before re-reading — presumably lets the migration
            // writes become visible; TODO confirm this is still needed.
            tokio::time::sleep(Duration::from_millis(50)).await;
            let settings = SystemSettings::get_current(db)
                .await
                .context("loading system settings after migrations")?;
            Ok((settings, true))
        }
        Err(err) => Err(err).context("loading system settings"),
    }
}

View File

@@ -0,0 +1,193 @@
use std::{
path::PathBuf,
sync::Arc,
time::{Duration, Instant},
};
use async_openai::Client;
use common::storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
};
use composite_retrieval::{
pipeline::{PipelineStageTimings, RetrievalConfig},
reranking::RerankerPool,
};
use crate::{
args::Config,
cache::EmbeddingCache,
datasets::ConvertedDataset,
embedding::EmbeddingProvider,
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
ingest, slice, snapshot,
};
/// Mutable state threaded through every stage of the evaluation state
/// machine. Stages fill in the `Option` fields as they run; the typed
/// accessors below panic when a stage reads a field before it was prepared.
pub(super) struct EvaluationContext<'a> {
    dataset: &'a ConvertedDataset,
    config: &'a Config,
    // Wall-clock time spent per stage.
    pub stage_timings: EvaluationStageTimings,
    // Slice resolution (prepare-slice stage).
    pub ledger_limit: Option<usize>,
    pub slice_settings: Option<slice::SliceConfig<'a>>,
    pub slice: Option<slice::ResolvedSlice<'a>>,
    pub window_offset: usize,
    pub window_length: usize,
    pub window_total_cases: usize,
    // Database connection and snapshot metadata (prepare-db stage).
    pub namespace: String,
    pub database: String,
    pub db: Option<SurrealDbClient>,
    pub descriptor: Option<snapshot::Descriptor>,
    pub settings: Option<SystemSettings>,
    pub settings_missing: bool,
    pub must_reapply_settings: bool,
    // Embedding and LLM clients (prepare-corpus stage).
    pub embedding_provider: Option<EmbeddingProvider>,
    pub embedding_cache: Option<EmbeddingCache>,
    pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
    pub openai_base_url: Option<String>,
    pub expected_fingerprint: Option<String>,
    pub ingestion_duration_ms: u128,
    pub namespace_seed_ms: Option<u128>,
    pub namespace_reused: bool,
    // Query execution state (run-queries stage).
    pub evaluation_start: Option<Instant>,
    pub eval_user: Option<User>,
    pub corpus_handle: Option<ingest::CorpusHandle>,
    pub cases: Vec<SeededCase>,
    pub stage_latency_samples: Vec<PipelineStageTimings>,
    pub latencies: Vec<u128>,
    pub diagnostics_output: Vec<CaseDiagnostics>,
    pub query_summaries: Vec<CaseSummary>,
    pub rerank_pool: Option<Arc<RerankerPool>>,
    pub retrieval_config: Option<Arc<RetrievalConfig>>,
    // Final output (summarize/finalize stages).
    pub summary: Option<EvaluationSummary>,
    pub diagnostics_path: Option<PathBuf>,
    pub diagnostics_enabled: bool,
}
impl<'a> EvaluationContext<'a> {
    /// Creates an empty context; stages populate the `Option` fields as the
    /// evaluation progresses.
    pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
        Self {
            dataset,
            config,
            stage_timings: EvaluationStageTimings::default(),
            ledger_limit: None,
            slice_settings: None,
            slice: None,
            window_offset: 0,
            window_length: 0,
            window_total_cases: 0,
            namespace: String::new(),
            database: String::new(),
            db: None,
            descriptor: None,
            settings: None,
            settings_missing: false,
            must_reapply_settings: false,
            embedding_provider: None,
            embedding_cache: None,
            openai_client: None,
            openai_base_url: None,
            expected_fingerprint: None,
            ingestion_duration_ms: 0,
            namespace_seed_ms: None,
            namespace_reused: false,
            evaluation_start: None,
            eval_user: None,
            corpus_handle: None,
            cases: Vec::new(),
            stage_latency_samples: Vec::new(),
            latencies: Vec::new(),
            diagnostics_output: Vec::new(),
            query_summaries: Vec::new(),
            rerank_pool: None,
            retrieval_config: None,
            summary: None,
            diagnostics_path: config.chunk_diagnostics_path.clone(),
            // Diagnostics are enabled exactly when a path was configured.
            diagnostics_enabled: config.chunk_diagnostics_path.is_some(),
        }
    }

    /// Dataset under evaluation.
    pub fn dataset(&self) -> &'a ConvertedDataset {
        self.dataset
    }

    /// Run configuration.
    pub fn config(&self) -> &'a Config {
        self.config
    }

    /// Resolved slice; panics if the prepare-slice stage has not run.
    pub fn slice(&self) -> &slice::ResolvedSlice<'a> {
        self.slice.as_ref().expect("slice has not been prepared")
    }

    /// Database connection; panics if the prepare-db stage has not run.
    pub fn db(&self) -> &SurrealDbClient {
        self.db.as_ref().expect("database connection missing")
    }

    /// Snapshot descriptor; panics if the prepare-db stage has not run.
    pub fn descriptor(&self) -> &snapshot::Descriptor {
        self.descriptor
            .as_ref()
            .expect("snapshot descriptor unavailable")
    }

    /// Embedding provider; panics if not yet initialised.
    pub fn embedding_provider(&self) -> &EmbeddingProvider {
        self.embedding_provider
            .as_ref()
            .expect("embedding provider not initialised")
    }

    /// Cloned handle to the OpenAI client; panics if not yet initialised.
    pub fn openai_client(&self) -> Arc<Client<async_openai::config::OpenAIConfig>> {
        self.openai_client
            .as_ref()
            .expect("openai client missing")
            .clone()
    }

    /// Corpus handle; panics if the prepare-corpus stage has not run.
    pub fn corpus_handle(&self) -> &ingest::CorpusHandle {
        self.corpus_handle.as_ref().expect("corpus handle missing")
    }

    /// Evaluation user; panics if not yet ensured.
    pub fn evaluation_user(&self) -> &User {
        self.eval_user.as_ref().expect("evaluation user missing")
    }

    /// Accumulates `duration` into the named stage's timing bucket.
    pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) {
        // `as_millis` already returns u128; the previous `as u128` cast was
        // redundant (clippy: unnecessary_cast).
        let elapsed = duration.as_millis();
        match stage {
            EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed,
            EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed,
            EvalStage::PrepareCorpus => self.stage_timings.prepare_corpus_ms += elapsed,
            EvalStage::PrepareNamespace => self.stage_timings.prepare_namespace_ms += elapsed,
            EvalStage::RunQueries => self.stage_timings.run_queries_ms += elapsed,
            EvalStage::Summarize => self.stage_timings.summarize_ms += elapsed,
            EvalStage::Finalize => self.stage_timings.finalize_ms += elapsed,
        }
    }

    /// Consumes the context, returning the final summary; panics if the
    /// summarize stage has not run.
    pub fn into_summary(self) -> EvaluationSummary {
        self.summary.expect("evaluation summary missing")
    }
}
/// The stages of the evaluation state machine, in execution order; used as
/// keys for stage timing accumulation and log labels.
#[derive(Copy, Clone)]
pub(super) enum EvalStage {
    PrepareSlice,
    PrepareDb,
    PrepareCorpus,
    PrepareNamespace,
    RunQueries,
    Summarize,
    Finalize,
}
impl EvalStage {
pub fn label(&self) -> &'static str {
match self {
EvalStage::PrepareSlice => "prepare-slice",
EvalStage::PrepareDb => "prepare-db",
EvalStage::PrepareCorpus => "prepare-corpus",
EvalStage::PrepareNamespace => "prepare-namespace",
EvalStage::RunQueries => "run-queries",
EvalStage::Summarize => "summarize",
EvalStage::Finalize => "finalize",
}
}
}

View File

@@ -0,0 +1,29 @@
mod context;
mod stages;
mod state;
use anyhow::Result;
use crate::{args::Config, datasets::ConvertedDataset, eval::EvaluationSummary};
use context::EvaluationContext;
/// Drives one full evaluation through the staged state machine.
///
/// Each stage consumes the machine value produced by the previous stage, so
/// stages cannot run out of order; all accumulated results live in `ctx` and
/// the final report is extracted from it at the end.
pub async fn run_evaluation(
    dataset: &ConvertedDataset,
    config: &Config,
) -> Result<EvaluationSummary> {
    let mut ctx = EvaluationContext::new(dataset, config);
    let machine = state::ready();
    let machine = stages::prepare_slice(machine, &mut ctx).await?;
    let machine = stages::prepare_db(machine, &mut ctx).await?;
    let machine = stages::prepare_corpus(machine, &mut ctx).await?;
    let machine = stages::prepare_namespace(machine, &mut ctx).await?;
    let machine = stages::run_queries(machine, &mut ctx).await?;
    let machine = stages::summarize(machine, &mut ctx).await?;
    let machine = stages::finalize(machine, &mut ctx).await?;
    // The machine has reached its terminal state; the summary lives in `ctx`.
    drop(machine);
    Ok(ctx.into_summary())
}

View File

@@ -0,0 +1,59 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::eval::write_chunk_diagnostics;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{Completed, EvaluationMachine, Summarized},
};
use super::{map_guard_error, StageResult};
/// Final stage: persist the embedding cache, write chunk diagnostics when
/// enabled, log the run outcome, and complete the state machine.
pub(crate) async fn finalize(
    machine: EvaluationMachine<(), Summarized>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<Completed> {
    let stage = EvalStage::Finalize;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();
    // Persist the embedding cache (if one exists) so later runs reuse it.
    if let Some(cache) = ctx.embedding_cache.as_ref() {
        cache
            .persist()
            .await
            .context("persisting embedding cache")?;
    }
    // Chunk diagnostics are written only when both enabled and a path is set.
    if ctx.diagnostics_enabled {
        if let Some(path) = ctx.diagnostics_path.as_ref() {
            write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
                .await
                .with_context(|| format!("writing chunk diagnostics to {}", path.display()))?;
        }
    }
    // Pull the headline numbers out of the summary once for logging.
    let (total_cases, correct, precision) = ctx
        .summary
        .as_ref()
        .map(|summary| (summary.total_cases, summary.correct, summary.precision))
        .unwrap_or((0, 0, 0.0));
    info!(
        total_cases,
        correct,
        precision,
        dataset = ctx.dataset().metadata.id.as_str(),
        "Evaluation complete"
    );
    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );
    machine
        .finalize()
        .map_err(|(_, guard)| map_guard_error("finalize", guard))
}

View File

@@ -0,0 +1,26 @@
mod finalize;
mod prepare_corpus;
mod prepare_db;
mod prepare_namespace;
mod prepare_slice;
mod run_queries;
mod summarize;
pub(crate) use finalize::finalize;
pub(crate) use prepare_corpus::prepare_corpus;
pub(crate) use prepare_db::prepare_db;
pub(crate) use prepare_namespace::prepare_namespace;
pub(crate) use prepare_slice::prepare_slice;
pub(crate) use run_queries::run_queries;
pub(crate) use summarize::summarize;
use anyhow::Result;
use state_machines::core::GuardError;
use super::state::EvaluationMachine;
fn map_guard_error(event: &str, guard: GuardError) -> anyhow::Error {
anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
}
type StageResult<S> = Result<EvaluationMachine<(), S>>;

View File

@@ -0,0 +1,84 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{ingest, slice, snapshot};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{CorpusReady, DbReady, EvaluationMachine},
};
use super::{map_guard_error, StageResult};
/// Stage 3: build (or reuse) the ingestion-backed corpus cache for the
/// selected slice window, then record its fingerprint, timing, and snapshot
/// descriptor on the context before advancing the machine to `CorpusReady`.
pub(crate) async fn prepare_corpus(
    machine: EvaluationMachine<(), DbReady>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<CorpusReady> {
    let stage = EvalStage::PrepareCorpus;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();
    let config = ctx.config();
    let cache_settings = ingest::CorpusCacheConfig::from(config);
    let embedding_provider = ctx.embedding_provider().clone();
    let openai_client = ctx.openai_client();
    let eval_user_id = "eval-user".to_string();
    // Time only the corpus/ingestion work, separately from the whole stage.
    let ingestion_timer = Instant::now();
    let corpus_handle = {
        let slice = ctx.slice();
        let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
            .context("selecting slice window for corpus preparation")?;
        ingest::ensure_corpus(
            ctx.dataset(),
            slice,
            &window,
            &cache_settings,
            &embedding_provider,
            openai_client,
            &eval_user_id,
            config.converted_dataset_path.as_path(),
        )
        .await
        .context("ensuring ingestion-backed corpus")?
    };
    let expected_fingerprint = corpus_handle
        .manifest
        .metadata
        .ingestion_fingerprint
        .clone();
    // `as_millis()` already yields `u128`; the previous cast was redundant.
    let ingestion_duration_ms = ingestion_timer.elapsed().as_millis();
    info!(
        cache = %corpus_handle.path.display(),
        reused_ingestion = corpus_handle.reused_ingestion,
        reused_embeddings = corpus_handle.reused_embeddings,
        positive_ingested = corpus_handle.positive_ingested,
        negative_ingested = corpus_handle.negative_ingested,
        "Ingestion corpus ready"
    );
    ctx.corpus_handle = Some(corpus_handle);
    ctx.expected_fingerprint = Some(expected_fingerprint);
    ctx.ingestion_duration_ms = ingestion_duration_ms;
    ctx.descriptor = Some(snapshot::Descriptor::new(
        config,
        ctx.slice(),
        ctx.embedding_provider(),
    ));
    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );
    machine
        .prepare_corpus()
        .map_err(|(_, guard)| map_guard_error("prepare_corpus", guard))
}

View File

@@ -0,0 +1,111 @@
use std::{sync::Arc, time::Instant};
use anyhow::{anyhow, Context};
use tracing::info;
use crate::{
args::EmbeddingBackend,
cache::EmbeddingCache,
embedding,
eval::{
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
},
openai,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{DbReady, EvaluationMachine, SlicePrepared},
};
use super::{map_guard_error, StageResult};
/// Stage 2: connect to the evaluation database and initialise the embedding
/// provider, the optional FastEmbed embedding cache, and the OpenAI client.
///
/// Stores all handles and the (possibly enforced) system settings on the
/// context, then advances the machine from `SlicePrepared` to `DbReady`.
pub(crate) async fn prepare_db(
machine: EvaluationMachine<(), SlicePrepared>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<DbReady> {
let stage = EvalStage::PrepareDb;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let namespace = ctx.namespace.clone();
let database = ctx.database.clone();
let config = ctx.config();
let db = connect_eval_db(config, &namespace, &database).await?;
// `settings_missing` is true when defaults were just created for a fresh DB.
let (mut settings, settings_missing) = load_or_init_system_settings(&db).await?;
let embedding_provider =
embedding::build_provider(config, settings.embedding_dimensions as usize)
.await
.context("building embedding provider")?;
let (raw_openai_client, openai_base_url) =
openai::build_client_from_env().context("building OpenAI client")?;
let openai_client = Arc::new(raw_openai_client);
let provider_dimension = embedding_provider.dimension();
// A zero-dimension provider would make all vector operations meaningless.
if provider_dimension == 0 {
return Err(anyhow!(
"embedding provider reported zero dimensions; cannot continue"
));
}
info!(
backend = embedding_provider.backend_label(),
model = embedding_provider
.model_code()
.as_deref()
.unwrap_or("<none>"),
dimension = provider_dimension,
"Embedding provider initialised"
);
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
// An on-disk embedding cache is used only for the FastEmbed backend, and
// only when the provider exposes a model code to key the cache file by.
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
if let Some(model_code) = embedding_provider.model_code() {
let sanitized = sanitize_model_code(&model_code);
let path = config.cache_dir.join(format!("{sanitized}.json"));
// `force_convert` discards a stale cache; removal errors are
// deliberately ignored via `.ok()` (best-effort cleanup).
if config.force_convert && path.exists() {
tokio::fs::remove_file(&path)
.await
.with_context(|| format!("removing stale cache {}", path.display()))
.ok();
}
let cache = EmbeddingCache::load(&path).await?;
info!(path = %path.display(), "Embedding cache ready");
Some(cache)
} else {
None
}
} else {
None
};
let must_reapply_settings = settings_missing;
// On a fresh DB without `reseed_slice`, enforcement is deferred; the
// namespace stage re-applies settings when `must_reapply_settings` is set.
let defer_initial_enforce = settings_missing && !config.reseed_slice;
if !defer_initial_enforce {
settings = enforce_system_settings(&db, settings, provider_dimension, config).await?;
}
ctx.db = Some(db);
ctx.settings_missing = settings_missing;
ctx.must_reapply_settings = must_reapply_settings;
ctx.settings = Some(settings);
ctx.embedding_provider = Some(embedding_provider);
ctx.embedding_cache = embedding_cache;
ctx.openai_client = Some(openai_client);
ctx.openai_base_url = Some(openai_base_url);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_db()
.map_err(|(_, guard)| map_guard_error("prepare_db", guard))
}

View File

@@ -0,0 +1,159 @@
use std::time::Instant;
use anyhow::{anyhow, Context};
use common::storage::types::system_settings::SystemSettings;
use tracing::{info, warn};
use crate::{
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
eval::{
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
record_namespace_state, warm_hnsw_cache,
},
ingest,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{CorpusReady, EvaluationMachine, NamespaceReady},
};
use super::{map_guard_error, StageResult};
/// Stage 4: make the SurrealDB namespace match the prepared corpus.
///
/// Reuses the namespace when the recorded state matches the expected
/// fingerprint (unless `reseed_slice` forces a reset); otherwise resets it,
/// reseeds from the corpus manifest, and rebuilds indexes. Also re-applies
/// system settings when flagged, ensures the eval user exists, and loads the
/// evaluation cases from the manifest.
pub(crate) async fn prepare_namespace(
machine: EvaluationMachine<(), CorpusReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<NamespaceReady> {
let stage = EvalStage::PrepareNamespace;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let dataset = ctx.dataset();
let expected_fingerprint = ctx
.expected_fingerprint
.as_deref()
.unwrap_or_default()
.to_string();
let namespace = ctx.namespace.clone();
let database = ctx.database.clone();
let embedding_provider = ctx.embedding_provider().clone();
let mut namespace_reused = false;
// `reseed_slice` skips the reuse check entirely and always reseeds.
if !config.reseed_slice {
namespace_reused = {
let slice = ctx.slice();
can_reuse_namespace(
ctx.db(),
ctx.descriptor(),
&namespace,
&database,
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
slice.manifest.case_count,
)
.await?
};
}
let mut namespace_seed_ms = None;
if !namespace_reused {
// Reseeding invalidates any previously-applied settings.
ctx.must_reapply_settings = true;
// Reset failures are tolerated: seeding proceeds over existing data.
if let Err(err) = reset_namespace(ctx.db(), &namespace, &database).await {
warn!(
error = %err,
namespace,
database = %database,
"Failed to reset namespace before reseeding; continuing with existing data"
);
} else if let Err(err) = ctx.db().apply_migrations().await {
warn!(error = %err, "Failed to reapply migrations after namespace reset");
}
{
let slice = ctx.slice();
info!(
slice = slice.manifest.slice_id.as_str(),
window_offset = ctx.window_offset,
window_length = ctx.window_length,
positives = slice.manifest.positive_paragraphs,
negatives = slice.manifest.negative_paragraphs,
total = slice.manifest.total_paragraphs,
"Seeding ingestion corpus into SurrealDB"
);
}
// Indexes are dropped before the bulk insert and recreated afterwards
// (with the HNSW cache warmed) — only if the drop actually succeeded.
let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
let seed_start = Instant::now();
ingest::seed_manifest_into_db(ctx.db(), &ctx.corpus_handle().manifest)
.await
.context("seeding ingestion corpus from manifest")?;
namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
if indexes_disabled {
info!("Recreating indexes after namespace reset");
if let Err(err) = recreate_indexes(ctx.db(), embedding_provider.dimension()).await {
warn!(error = %err, "failed to restore indexes after namespace reset");
} else {
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
}
}
{
let slice = ctx.slice();
// Record the seeded state so later runs can reuse this namespace.
record_namespace_state(
ctx.descriptor(),
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
&namespace,
&database,
slice.manifest.case_count,
)
.await;
}
}
// Re-apply settings deferred from prepare_db or invalidated by the reset.
if ctx.must_reapply_settings {
let mut settings = SystemSettings::get_current(ctx.db())
.await
.context("reloading system settings after namespace reset")?;
settings =
enforce_system_settings(ctx.db(), settings, embedding_provider.dimension(), config)
.await?;
ctx.settings = Some(settings);
ctx.must_reapply_settings = false;
}
let user = ensure_eval_user(ctx.db()).await?;
ctx.eval_user = Some(user);
let cases = cases_from_manifest(&ctx.corpus_handle().manifest);
if cases.is_empty() {
return Err(anyhow!(
"no answerable questions found in converted dataset for evaluation"
));
}
ctx.cases = cases;
ctx.namespace_reused = namespace_reused;
ctx.namespace_seed_ms = namespace_seed_ms;
info!(
cases = ctx.cases.len(),
window_offset = ctx.window_offset,
namespace_reused = namespace_reused,
"Dataset ready"
);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_namespace()
.map_err(|(_, guard)| map_guard_error("prepare_namespace", guard))
}

View File

@@ -0,0 +1,66 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{
eval::{default_database, default_namespace, ledger_target},
slice,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, Ready, SlicePrepared},
};
use super::{map_guard_error, StageResult};
/// Stage 1: resolve the dataset slice, validate the requested case window,
/// and derive the SurrealDB namespace/database names for this run.
///
/// Populates the slice, window, namespace and database fields on the context
/// and advances the machine from `Ready` to `SlicePrepared`.
pub(crate) async fn prepare_slice(
    machine: EvaluationMachine<(), Ready>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<SlicePrepared> {
    let stage = EvalStage::PrepareSlice;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();
    let ledger_limit = ledger_target(ctx.config());
    let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
    let resolved_slice =
        slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
    let window = slice::select_window(
        &resolved_slice,
        ctx.config().slice_offset,
        ctx.config().limit,
    )
    .context("selecting slice window (use --slice-grow to extend the ledger first)")?;
    ctx.ledger_limit = ledger_limit;
    ctx.slice_settings = Some(slice_settings);
    ctx.window_offset = window.offset;
    ctx.window_length = window.length;
    ctx.window_total_cases = window.total_cases;
    // The window fields have been copied out above, so the slice can be moved
    // into the context directly — no clone required.
    ctx.slice = Some(resolved_slice);
    ctx.namespace = ctx.config().db_namespace.clone().unwrap_or_else(|| {
        default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
    });
    ctx.database = ctx
        .config()
        .db_database
        .clone()
        .unwrap_or_else(default_database);
    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );
    machine
        .prepare_slice()
        .map_err(|(_, guard)| map_guard_error("prepare_slice", guard))
}

View File

@@ -0,0 +1,337 @@
use std::{collections::HashSet, sync::Arc, time::Instant};
use anyhow::Context;
use futures::stream::{self, StreamExt, TryStreamExt};
use tracing::{debug, info};
use crate::eval::{
apply_dataset_tuning_overrides, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
CaseSummary, RetrievedSummary,
};
use composite_retrieval::pipeline::{self, PipelineStageTimings, RetrievalConfig};
use composite_retrieval::reranking::RerankerPool;
use tokio::sync::Semaphore;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, NamespaceReady, QueriesFinished},
};
use super::{map_guard_error, StageResult};
/// Stage 5: execute every seeded case through the retrieval pipeline and
/// collect per-case summaries, optional diagnostics, and stage timings.
///
/// Queries run concurrently up to `config.concurrency` (minimum 1), bounded
/// by a semaphore; completed results are re-sorted back into the original
/// case order before being stored on the context.
pub(crate) async fn run_queries(
machine: EvaluationMachine<(), NamespaceReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<QueriesFinished> {
let stage = EvalStage::RunQueries;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let dataset = ctx.dataset();
let slice_settings = ctx
.slice_settings
.as_ref()
.expect("slice settings missing during query stage");
let total_cases = ctx.cases.len();
// Take ownership of the cases; `enumerate` preserves the original order.
let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate();
// The reranker pool is only constructed when reranking was requested.
let rerank_pool = if config.rerank {
Some(RerankerPool::new(config.rerank_pool_size).context("initialising reranker pool")?)
} else {
None
};
// Start from default tuning, layer CLI overrides on top, then let the
// dataset apply its own overrides last.
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.tuning.rerank_keep_top = config.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.rerank_keep_top;
}
if let Some(value) = config.chunk_vector_take {
retrieval_config.tuning.chunk_vector_take = value;
}
if let Some(value) = config.chunk_fts_take {
retrieval_config.tuning.chunk_fts_take = value;
}
if let Some(value) = config.chunk_token_budget {
retrieval_config.tuning.token_budget_estimate = value;
}
if let Some(value) = config.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value;
}
if let Some(value) = config.max_chunks_per_entity {
retrieval_config.tuning.max_chunks_per_entity = value;
}
apply_dataset_tuning_overrides(dataset, config, &mut retrieval_config.tuning);
let active_tuning = retrieval_config.tuning.clone();
let effective_chunk_vector = config
.chunk_vector_take
.unwrap_or(active_tuning.chunk_vector_take);
let effective_chunk_fts = config
.chunk_fts_take
.unwrap_or(active_tuning.chunk_fts_take);
info!(
dataset = dataset.metadata.id.as_str(),
slice_seed = config.slice_seed,
slice_offset = config.slice_offset,
slice_limit = config
.limit
.unwrap_or(ctx.window_total_cases),
negative_multiplier = %slice_settings.negative_multiplier,
rerank_enabled = config.rerank,
rerank_pool_size = config.rerank_pool_size,
rerank_keep_top = config.rerank_keep_top,
chunk_min = config.chunk_min_chars,
chunk_max = config.chunk_max_chars,
chunk_vector_take = effective_chunk_vector,
chunk_fts_take = effective_chunk_fts,
chunk_token_budget = active_tuning.token_budget_estimate,
embedding_backend = ctx.embedding_provider().backend_label(),
embedding_model = ctx
.embedding_provider()
.model_code()
.as_deref()
.unwrap_or("<default>"),
"Starting evaluation run"
);
let retrieval_config = Arc::new(retrieval_config);
ctx.rerank_pool = rerank_pool.clone();
ctx.retrieval_config = Some(retrieval_config.clone());
ctx.evaluation_start = Some(Instant::now());
let user_id = ctx.evaluation_user().id.clone();
// Concurrency floor of 1 so the semaphore always has at least one permit.
let concurrency = config.concurrency.max(1);
let diagnostics_enabled = ctx.diagnostics_enabled;
let query_semaphore = Arc::new(Semaphore::new(concurrency));
info!(
total_cases = total_cases,
max_concurrent_queries = concurrency,
"Starting evaluation with staged query execution"
);
let embedding_provider_for_queries = ctx.embedding_provider().clone();
let rerank_pool_for_queries = rerank_pool.clone();
let db = ctx.db().clone();
let openai_client = ctx.openai_client();
let results: Vec<(
usize,
CaseSummary,
Option<CaseDiagnostics>,
PipelineStageTimings,
)> = stream::iter(cases_iter)
.map(move |(idx, case)| {
// Clone per-case handles so each future is self-contained.
let db = db.clone();
let openai_client = openai_client.clone();
let user_id = user_id.clone();
let retrieval_config = retrieval_config.clone();
let embedding_provider = embedding_provider_for_queries.clone();
let rerank_pool = rerank_pool_for_queries.clone();
let semaphore = query_semaphore.clone();
let diagnostics_enabled = diagnostics_enabled;
async move {
// The permit bounds how many queries are in flight at once.
let _permit = semaphore
.acquire()
.await
.context("acquiring query semaphore permit")?;
let crate::eval::SeededCase {
question_id,
question,
expected_source,
answers,
paragraph_id,
paragraph_title,
expected_chunk_ids,
} = case;
let query_start = Instant::now();
debug!(question_id = %question_id, "Evaluating query");
let query_embedding =
embedding_provider.embed(&question).await.with_context(|| {
format!("generating embedding for question {}", question_id)
})?;
// Check a reranker out of the pool only when reranking is enabled.
let reranker = match &rerank_pool {
Some(pool) => Some(pool.checkout().await),
None => None,
};
// The diagnostics variant returns extra pipeline detail per query.
let (results, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
&db,
&openai_client,
query_embedding,
&question,
&user_id,
(*retrieval_config).clone(),
reranker,
)
.await
.with_context(|| format!("running pipeline for question {}", question_id))?;
(outcome.results, outcome.diagnostics, outcome.stage_timings)
} else {
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
&db,
&openai_client,
query_embedding,
&question,
&user_id,
(*retrieval_config).clone(),
reranker,
)
.await
.with_context(|| format!("running pipeline for question {}", question_id))?;
(outcome.results, None, outcome.stage_timings)
};
let query_latency = query_start.elapsed().as_millis() as u128;
let mut retrieved = Vec::new();
let mut match_rank = None;
// Answers are lowercased once; matching below is case-insensitive.
let answers_lower: Vec<String> =
answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
let expected_chunk_ids_set: HashSet<&str> =
expected_chunk_ids.iter().map(|id| id.as_str()).collect();
// With no expected chunk ids, the chunk-id criterion is vacuously met.
let chunk_id_required = !expected_chunk_ids_set.is_empty();
let mut entity_hit = false;
let mut chunk_text_hit = false;
let mut chunk_id_hit = !chunk_id_required;
// Score only the top-k entities; a case matches when one entity meets
// the source, answer-text, and chunk-id criteria simultaneously.
for (idx_entity, entity) in results.iter().enumerate() {
if idx_entity >= config.k {
break;
}
let entity_match = entity.entity.source_id == expected_source;
if entity_match {
entity_hit = true;
}
let chunk_text_for_entity = entity
.chunks
.iter()
.any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower));
if chunk_text_for_entity {
chunk_text_hit = true;
}
let chunk_id_for_entity = if chunk_id_required {
expected_chunk_ids_set.contains(entity.entity.source_id.as_str())
|| entity.chunks.iter().any(|chunk| {
expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
})
} else {
true
};
if chunk_id_for_entity {
chunk_id_hit = true;
}
let success = entity_match && chunk_text_for_entity && chunk_id_for_entity;
// Record the first rank at which all criteria were satisfied.
if success && match_rank.is_none() {
match_rank = Some(idx_entity + 1);
}
// Extra per-entity fields are only populated for detailed reports.
let detail_fields = if config.detailed_report {
(
Some(entity.entity.description.clone()),
Some(format!("{:?}", entity.entity.entity_type)),
Some(chunk_text_for_entity),
Some(chunk_id_for_entity),
)
} else {
(None, None, None, None)
};
retrieved.push(RetrievedSummary {
rank: idx_entity + 1,
entity_id: entity.entity.id.clone(),
source_id: entity.entity.source_id.clone(),
entity_name: entity.entity.name.clone(),
score: entity.score,
matched: success,
entity_description: detail_fields.0,
entity_category: detail_fields.1,
chunk_text_match: detail_fields.2,
chunk_id_match: detail_fields.3,
});
}
let overall_match = match_rank.is_some();
let summary = CaseSummary {
question_id,
question,
paragraph_id,
paragraph_title,
expected_source,
answers,
matched: overall_match,
entity_match: entity_hit,
chunk_text_match: chunk_text_hit,
chunk_id_match: chunk_id_hit,
match_rank,
latency_ms: query_latency,
retrieved,
};
let diagnostics = if diagnostics_enabled {
Some(build_case_diagnostics(
&summary,
&expected_chunk_ids,
&answers_lower,
&results,
pipeline_diagnostics,
))
} else {
None
};
Ok::<
(
usize,
CaseSummary,
Option<CaseDiagnostics>,
PipelineStageTimings,
),
anyhow::Error,
>((idx, summary, diagnostics, stage_timings))
}
})
.buffer_unordered(concurrency)
.try_collect()
.await?;
// `buffer_unordered` yields in completion order; restore case order by index.
let mut ordered = results;
ordered.sort_by_key(|(idx, ..)| *idx);
let mut summaries = Vec::with_capacity(ordered.len());
let mut latencies = Vec::with_capacity(ordered.len());
let mut diagnostics_output = Vec::new();
let mut stage_latency_samples = Vec::with_capacity(ordered.len());
for (_, summary, diagnostics, stage_timings) in ordered {
latencies.push(summary.latency_ms);
summaries.push(summary);
if let Some(diag) = diagnostics {
diagnostics_output.push(diag);
}
stage_latency_samples.push(stage_timings);
}
ctx.query_summaries = summaries;
ctx.latencies = latencies;
ctx.diagnostics_output = diagnostics_output;
ctx.stage_latency_samples = stage_latency_samples;
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.run_queries()
.map_err(|(_, guard)| map_guard_error("run_queries", guard))
}

View File

@@ -0,0 +1,173 @@
use std::time::Instant;
use chrono::Utc;
use tracing::info;
use crate::eval::{
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, QueriesFinished, Summarized},
};
use super::{map_guard_error, StageResult};
/// Stage 6: aggregate per-case results into the final `EvaluationSummary`.
///
/// Computes overall precision plus precision@1/2/3, latency statistics, and
/// performance timings, then stores the summary on the context and advances
/// the machine from `QueriesFinished` to `Summarized`.
pub(crate) async fn summarize(
    machine: EvaluationMachine<(), QueriesFinished>,
    ctx: &mut EvaluationContext<'_>,
) -> StageResult<Summarized> {
    let stage = EvalStage::Summarize;
    info!(
        evaluation_stage = stage.label(),
        "starting evaluation stage"
    );
    let started = Instant::now();
    let summaries = std::mem::take(&mut ctx.query_summaries);
    let latencies = std::mem::take(&mut ctx.latencies);
    let stage_latency_samples = std::mem::take(&mut ctx.stage_latency_samples);
    let duration_ms = ctx
        .evaluation_start
        .take()
        .map(|start| start.elapsed().as_millis())
        .unwrap_or_default();
    let config = ctx.config();
    let dataset = ctx.dataset();
    let slice = ctx.slice();
    let corpus_handle = ctx.corpus_handle();
    let total_cases = summaries.len();
    // Count overall matches plus matches within the top 1/2/3 ranks.
    let mut correct = 0usize;
    let mut correct_at_1 = 0usize;
    let mut correct_at_2 = 0usize;
    let mut correct_at_3 = 0usize;
    for summary in &summaries {
        if summary.matched {
            correct += 1;
            if let Some(rank) = summary.match_rank {
                if rank <= 1 {
                    correct_at_1 += 1;
                }
                if rank <= 2 {
                    correct_at_2 += 1;
                }
                if rank <= 3 {
                    correct_at_3 += 1;
                }
            }
        }
    }
    let latency_stats = compute_latency_stats(&latencies);
    let stage_latency = build_stage_latency_breakdown(&stage_latency_samples);
    // Shared helper: hit count -> precision, guarding the empty-run case.
    // Replaces four copy-pasted zero-division guards.
    let ratio = |hits: usize| {
        if total_cases == 0 {
            0.0
        } else {
            hits as f64 / total_cases as f64
        }
    };
    let precision = ratio(correct);
    let precision_at_1 = ratio(correct_at_1);
    let precision_at_2 = ratio(correct_at_2);
    let precision_at_3 = ratio(correct_at_3);
    let active_tuning = ctx
        .retrieval_config
        .as_ref()
        .map(|cfg| cfg.tuning.clone())
        .unwrap_or_default();
    let perf_timings = PerformanceTimings {
        openai_base_url: ctx
            .openai_base_url
            .clone()
            .unwrap_or_else(|| "<unknown>".to_string()),
        ingestion_ms: ctx.ingestion_duration_ms,
        namespace_seed_ms: ctx.namespace_seed_ms,
        evaluation_stage_ms: ctx.stage_timings.clone(),
        stage_latency,
    };
    ctx.summary = Some(EvaluationSummary {
        generated_at: Utc::now(),
        k: config.k,
        limit: config.limit,
        run_label: config.label.clone(),
        total_cases,
        correct,
        precision,
        correct_at_1,
        correct_at_2,
        correct_at_3,
        precision_at_1,
        precision_at_2,
        precision_at_3,
        duration_ms,
        dataset_id: dataset.metadata.id.clone(),
        dataset_label: dataset.metadata.label.clone(),
        dataset_includes_unanswerable: dataset.metadata.include_unanswerable,
        dataset_source: dataset.source.clone(),
        slice_id: slice.manifest.slice_id.clone(),
        slice_seed: slice.manifest.seed,
        slice_total_cases: slice.manifest.case_count,
        slice_window_offset: ctx.window_offset,
        slice_window_length: ctx.window_length,
        slice_cases: total_cases,
        slice_positive_paragraphs: slice.manifest.positive_paragraphs,
        slice_negative_paragraphs: slice.manifest.negative_paragraphs,
        slice_total_paragraphs: slice.manifest.total_paragraphs,
        slice_negative_multiplier: slice.manifest.negative_multiplier,
        namespace_reused: ctx.namespace_reused,
        corpus_paragraphs: ctx.corpus_handle().manifest.metadata.paragraph_count,
        ingestion_cache_path: corpus_handle.path.display().to_string(),
        ingestion_reused: corpus_handle.reused_ingestion,
        ingestion_embeddings_reused: corpus_handle.reused_embeddings,
        ingestion_fingerprint: corpus_handle
            .manifest
            .metadata
            .ingestion_fingerprint
            .clone(),
        positive_paragraphs_reused: corpus_handle.positive_reused,
        negative_paragraphs_reused: corpus_handle.negative_reused,
        latency_ms: latency_stats,
        perf: perf_timings,
        embedding_backend: ctx.embedding_provider().backend_label().to_string(),
        embedding_model: ctx.embedding_provider().model_code(),
        embedding_dimension: ctx.embedding_provider().dimension(),
        rerank_enabled: config.rerank,
        rerank_pool_size: ctx.rerank_pool.as_ref().map(|_| config.rerank_pool_size),
        rerank_keep_top: config.rerank_keep_top,
        concurrency: config.concurrency.max(1),
        detailed_report: config.detailed_report,
        chunk_vector_take: active_tuning.chunk_vector_take,
        chunk_fts_take: active_tuning.chunk_fts_take,
        chunk_token_budget: active_tuning.token_budget_estimate,
        chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
        max_chunks_per_entity: active_tuning.max_chunks_per_entity,
        cases: summaries,
    });
    let elapsed = started.elapsed();
    ctx.record_stage_duration(stage, elapsed);
    info!(
        evaluation_stage = stage.label(),
        duration_ms = elapsed.as_millis(),
        "completed evaluation stage"
    );
    machine
        .summarize()
        .map_err(|(_, guard)| map_guard_error("summarize", guard))
}

View File

@@ -0,0 +1,31 @@
use state_machines::state_machine;
// Typed state machine for the evaluation pipeline. Each event is only legal
// from the state produced by the previous stage, so stage ordering is
// enforced at compile time; `abort` can fail the pipeline from any state.
state_machine! {
name: EvaluationMachine,
state: EvaluationState,
initial: Ready,
states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
events {
prepare_slice { transition: { from: Ready, to: SlicePrepared } }
prepare_db { transition: { from: SlicePrepared, to: DbReady } }
prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
summarize { transition: { from: QueriesFinished, to: Summarized } }
finalize { transition: { from: Summarized, to: Completed } }
abort {
transition: { from: Ready, to: Failed }
transition: { from: SlicePrepared, to: Failed }
transition: { from: DbReady, to: Failed }
transition: { from: CorpusReady, to: Failed }
transition: { from: NamespaceReady, to: Failed }
transition: { from: QueriesFinished, to: Failed }
transition: { from: Summarized, to: Failed }
transition: { from: Completed, to: Failed }
}
}
}
/// Entry point: a fresh machine in the `Ready` state with an empty `()` payload.
pub fn ready() -> EvaluationMachine<(), Ready> {
EvaluationMachine::new(())
}

1000
eval/src/ingest.rs Normal file

File diff suppressed because it is too large Load Diff

182
eval/src/inspection.rs Normal file
View File

@@ -0,0 +1,182 @@
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
use crate::{args::Config, eval::connect_eval_db, ingest, snapshot::DbSnapshotState};
/// Inspection mode: explain a question's expected-chunk bookkeeping.
///
/// Loads the ingestion manifest, prints the question's `matching_chunk_ids`
/// with paragraph titles and snippets, reports any ids missing from the
/// manifest, and — when a DB snapshot state file is available — verifies the
/// chunks also exist in the live SurrealDB namespace recorded there.
///
/// Requires `--inspect-question` and `--inspect-manifest` on the config.
pub async fn inspect_question(config: &Config) -> Result<()> {
let question_id = config
.inspect_question
.as_ref()
.ok_or_else(|| anyhow!("--inspect-question is required for inspection mode"))?;
let manifest_path = config
.inspect_manifest
.as_ref()
.ok_or_else(|| anyhow!("--inspect-manifest must be provided for inspection mode"))?;
let manifest = load_manifest(manifest_path)?;
let chunk_lookup = build_chunk_lookup(&manifest);
let question = manifest
.questions
.iter()
.find(|q| q.question_id == *question_id)
.ok_or_else(|| {
anyhow!(
"question '{}' not found in manifest {}",
question_id,
manifest_path.display()
)
})?;
println!("Question: {}", question.question_text);
println!("Answers: {:?}", question.answers);
println!(
"matching_chunk_ids ({}):",
question.matching_chunk_ids.len()
);
// Print each expected chunk with context; collect ids absent from the manifest.
let mut missing_in_manifest = Vec::new();
for chunk_id in &question.matching_chunk_ids {
if let Some(entry) = chunk_lookup.get(chunk_id) {
println!(
" - {} (paragraph: {})\n snippet: {}",
chunk_id, entry.paragraph_title, entry.snippet
);
} else {
println!(" - {} (missing from manifest)", chunk_id);
missing_in_manifest.push(chunk_id.clone());
}
}
if missing_in_manifest.is_empty() {
println!("All matching_chunk_ids are present in the ingestion manifest");
} else {
println!(
"Missing chunk IDs in manifest {}: {:?}",
manifest_path.display(),
missing_in_manifest
);
}
// Live-DB validation is best-effort: it needs a snapshot state file with
// namespace/database fields and a reachable SurrealDB instance; every
// failure path prints a message instead of erroring out.
let db_state_path = config
.inspect_db_state
.clone()
.unwrap_or_else(|| default_state_path(config, &manifest));
if let Some(state) = load_db_state(&db_state_path)? {
if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
match connect_eval_db(config, ns, db_name).await {
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
MissingChunks::None => println!(
"All matching_chunk_ids exist in namespace '{}', database '{}'",
ns, db_name
),
MissingChunks::Missing(list) => println!(
"Missing chunks in namespace '{}', database '{}': {:?}",
ns, db_name, list
),
},
Err(err) => {
println!(
"Failed to connect to SurrealDB namespace '{}' / database '{}': {err}",
ns, db_name
);
}
}
} else {
println!(
"State file {} is missing namespace/database fields; skipping live DB validation",
db_state_path.display()
);
}
} else {
println!(
"State file {} not found; skipping live DB validation",
db_state_path.display()
);
}
Ok(())
}
/// Manifest-derived display info for one chunk, shown when listing a
/// question's matching chunk ids.
struct ChunkEntry {
// Title of the paragraph the chunk was cut from.
paragraph_title: String,
// First 160 chars of the chunk text, with newlines flattened to spaces.
snippet: String,
}
fn load_manifest(path: &Path) -> Result<ingest::CorpusManifest> {
let bytes =
fs::read(path).with_context(|| format!("reading ingestion manifest {}", path.display()))?;
serde_json::from_slice(&bytes)
.with_context(|| format!("parsing ingestion manifest {}", path.display()))
}
/// Index every chunk in the manifest by id, keeping its paragraph title and
/// a 160-character, newline-flattened snippet for display.
fn build_chunk_lookup(manifest: &ingest::CorpusManifest) -> HashMap<String, ChunkEntry> {
    manifest
        .paragraphs
        .iter()
        .flat_map(|paragraph| {
            paragraph.chunks.iter().map(move |chunk| {
                let snippet = chunk
                    .chunk
                    .chars()
                    .take(160)
                    .collect::<String>()
                    .replace('\n', " ");
                let entry = ChunkEntry {
                    paragraph_title: paragraph.title.clone(),
                    snippet,
                };
                (chunk.id.clone(), entry)
            })
        })
        .collect()
}
/// Default location of the DB snapshot state file for this dataset slice:
/// `<cache_dir>/snapshots/<dataset_id>/<slice_id>/db/state.json`.
fn default_state_path(config: &Config, manifest: &ingest::CorpusManifest) -> PathBuf {
    let meta = &manifest.metadata;
    let slice_root = config
        .cache_dir
        .join("snapshots")
        .join(&meta.dataset_id)
        .join(&meta.slice_id);
    slice_root.join("db/state.json")
}
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
if !path.exists() {
return Ok(None);
}
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
let state = serde_json::from_slice(&bytes)
.with_context(|| format!("parsing db state {}", path.display()))?;
Ok(Some(state))
}
/// Result of verifying a set of chunk ids against the live database.
enum MissingChunks {
    // Every id resolved to a stored chunk.
    None,
    // The listed ids were not found.
    Missing(Vec<String>),
}
/// Check that every id in `chunk_ids` resolves to a stored `TextChunk`.
///
/// Performs one point lookup per id (sequential awaits) and collects the
/// ids that were not found. Lookup errors abort with context; absence is
/// reported, not an error.
async fn verify_chunks_in_db(db: &SurrealDbClient, chunk_ids: &[String]) -> Result<MissingChunks> {
    let mut missing = Vec::new();
    for chunk_id in chunk_ids {
        let exists = db
            .get_item::<TextChunk>(chunk_id)
            .await
            .with_context(|| format!("fetching text_chunk {}", chunk_id))?
            .is_some();
        if !exists {
            missing.push(chunk_id.clone());
        }
    }
    if missing.is_empty() {
        Ok(MissingChunks::None)
    } else {
        Ok(MissingChunks::Missing(missing))
    }
}

214
eval/src/main.rs Normal file
View File

@@ -0,0 +1,214 @@
mod args;
mod cache;
mod datasets;
mod db_helpers;
mod embedding;
mod eval;
mod ingest;
mod inspection;
mod openai;
mod perf;
mod report;
mod slice;
mod slices;
mod snapshot;
use anyhow::Context;
use tokio::runtime::Builder;
use tracing::info;
use tracing_subscriber::{fmt, EnvFilter};
/// Configure SurrealDB tuning environment variables scaled by `cpu_count`.
///
/// Each variable keeps any value already present in the environment and
/// otherwise falls back to a CPU-scaled default; the resolved values are
/// logged once at the end. The original repeated the read/default/set
/// pattern five times and then re-read every variable with
/// `env::var(..).unwrap()` just to log it — this version resolves each
/// value once through a helper and logs the returned strings directly.
fn configure_surrealdb_performance(cpu_count: usize) {
    /// Resolve `name` to its existing value or `default`, write it back
    /// into the environment, and return the value in effect.
    fn resolve(name: &str, default: String) -> String {
        let value = std::env::var(name).unwrap_or(default);
        std::env::set_var(name, &value);
        value
    }

    let indexing_batch_size = resolve("SURREAL_INDEXING_BATCH_SIZE", (cpu_count * 2).to_string());
    let max_order_queue = resolve(
        "SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
        (cpu_count * 4).to_string(),
    );
    let websocket_concurrent = resolve(
        "SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
        cpu_count.to_string(),
    );
    let websocket_buffer = resolve(
        "SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE",
        (cpu_count * 8).to_string(),
    );
    let transaction_cache = resolve(
        "SURREAL_TRANSACTION_CACHE_SIZE",
        (cpu_count * 16).to_string(),
    );

    info!(
        indexing_batch_size = %indexing_batch_size,
        max_order_queue = %max_order_queue,
        websocket_concurrent = %websocket_concurrent,
        websocket_buffer = %websocket_buffer,
        transaction_cache = %transaction_cache,
        "Configured SurrealDB performance variables"
    );
}
/// Process entry point: build an explicitly-configured multi-threaded
/// tokio runtime and run the async main on it.
fn main() -> anyhow::Result<()> {
    // Query parallelism once; the original called the fallible
    // `available_parallelism()` syscall twice for the same value.
    let worker_count = std::thread::available_parallelism()?.get();
    let runtime = Builder::new_multi_thread()
        .enable_all()
        .worker_threads(worker_count)
        .max_blocking_threads(worker_count)
        .thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
        .thread_name("eval-retrieval-worker")
        .build()
        .context("failed to create tokio runtime")?;
    runtime.block_on(async_main())
}
/// Async entry point for the eval binary.
///
/// Initialises tracing, tunes SurrealDB environment variables, parses CLI
/// arguments, then dispatches to one of: question inspection, dataset
/// conversion, slice-ledger growth, or the full retrieval evaluation with
/// report and perf-log output.
async fn async_main() -> anyhow::Result<()> {
    // Install the tracing subscriber before anything logs. The previous
    // version initialised it only after the runtime/SurrealDB `info!`
    // calls below, so those lines were silently dropped.
    let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
    let _ = fmt()
        .with_env_filter(EnvFilter::try_new(&filter).unwrap_or_else(|_| EnvFilter::new("info")))
        .try_init();
    // Log runtime configuration
    let cpu_count = std::thread::available_parallelism()?.get();
    info!(
        cpu_cores = cpu_count,
        worker_threads = cpu_count,
        blocking_threads = cpu_count,
        thread_stack_size = "10MiB",
        "Started multi-threaded tokio runtime"
    );
    // Configure SurrealDB environment variables for better performance
    configure_surrealdb_performance(cpu_count);
    let parsed = args::parse()?;
    if parsed.show_help {
        args::print_help();
        return Ok(());
    }
    // Question inspection is a short-circuit mode: no evaluation runs.
    if parsed.config.inspect_question.is_some() {
        inspection::inspect_question(&parsed.config).await?;
        return Ok(());
    }
    let dataset_kind = parsed.config.dataset;
    // Convert-only mode: write the converted dataset to disk and exit.
    if parsed.config.convert_only {
        info!(
            dataset = dataset_kind.id(),
            "Starting dataset conversion only run"
        );
        let dataset = crate::datasets::convert(
            parsed.config.raw_dataset_path.as_path(),
            dataset_kind,
            parsed.config.llm_mode,
            parsed.config.context_token_limit(),
        )
        .with_context(|| {
            format!(
                "converting {} dataset at {}",
                dataset_kind.label(),
                parsed.config.raw_dataset_path.display()
            )
        })?;
        crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
            .with_context(|| {
                format!(
                    "writing converted dataset to {}",
                    parsed.config.converted_dataset_path.display()
                )
            })?;
        println!(
            "Converted dataset written to {}",
            parsed.config.converted_dataset_path.display()
        );
        return Ok(());
    }
    info!(dataset = dataset_kind.id(), "Preparing converted dataset");
    // Reuse an existing conversion unless force_convert was requested.
    let dataset = crate::datasets::ensure_converted(
        dataset_kind,
        parsed.config.raw_dataset_path.as_path(),
        parsed.config.converted_dataset_path.as_path(),
        parsed.config.force_convert,
        parsed.config.llm_mode,
        parsed.config.context_token_limit(),
    )
    .with_context(|| {
        format!(
            "preparing converted dataset at {}",
            parsed.config.converted_dataset_path.display()
        )
    })?;
    info!(
        questions = dataset
            .paragraphs
            .iter()
            .map(|p| p.questions.len())
            .sum::<usize>(),
        paragraphs = dataset.paragraphs.len(),
        dataset = dataset.metadata.id.as_str(),
        "Dataset ready"
    );
    // Slice-growth mode only extends the slice ledger; no evaluation.
    if parsed.config.slice_grow.is_some() {
        eval::grow_slice(&dataset, &parsed.config)
            .await
            .context("growing slice ledger")?;
        return Ok(());
    }
    info!("Running retrieval evaluation");
    let summary = eval::run_evaluation(&dataset, &parsed.config)
        .await
        .context("running retrieval evaluation")?;
    // Persist JSON/Markdown reports plus the perf log, then print a
    // one-line summary listing all output locations.
    let report_paths = report::write_reports(
        &summary,
        parsed.config.report_dir.as_path(),
        parsed.config.summary_sample,
    )
    .with_context(|| format!("writing reports to {}", parsed.config.report_dir.display()))?;
    let perf_log_path = perf::write_perf_logs(
        &summary,
        parsed.config.report_dir.as_path(),
        parsed.config.perf_log_json.as_deref(),
        parsed.config.perf_log_dir.as_deref(),
    )
    .with_context(|| {
        format!(
            "writing perf logs under {}",
            parsed.config.report_dir.display()
        )
    })?;
    println!(
        "[{}] Precision@{k}: {precision:.3} ({correct}/{total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
        summary.dataset_label,
        k = summary.k,
        precision = summary.precision,
        correct = summary.correct,
        total = summary.total_cases,
        json = report_paths.json.display(),
        md = report_paths.markdown.display(),
        perf = perf_log_path.display()
    );
    if parsed.config.perf_log_console {
        perf::print_console_summary(&summary);
    }
    Ok(())
}

16
eval/src/openai.rs Normal file
View File

@@ -0,0 +1,16 @@
use anyhow::{Context, Result};
use async_openai::{config::OpenAIConfig, Client};
/// Default OpenAI API endpoint used when `OPENAI_BASE_URL` is unset.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Build an OpenAI client from `OPENAI_API_KEY` / `OPENAI_BASE_URL`.
///
/// Returns the configured client together with the base URL actually in
/// effect (useful for reporting). Fails when the API key is absent.
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
    let api_key = std::env::var("OPENAI_API_KEY")
        .context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
    let base_url = match std::env::var("OPENAI_BASE_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_BASE_URL.to_string(),
    };
    let config = OpenAIConfig::new()
        .with_api_key(api_key)
        .with_api_base(&base_url);
    let client = Client::with_config(config);
    Ok((client, base_url))
}

313
eval/src/perf.rs Normal file
View File

@@ -0,0 +1,313 @@
use std::{
fs::{self, OpenOptions},
io::Write,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use serde::Serialize;
use crate::{
args,
eval::{format_timestamp, EvaluationStageTimings, EvaluationSummary},
report,
};
/// One JSONL record describing a full evaluation run's performance profile.
#[derive(Debug, Serialize)]
struct PerformanceLogEntry {
    generated_at: String,
    dataset_id: String,
    dataset_label: String,
    run_label: Option<String>,
    slice_id: String,
    slice_seed: u64,
    slice_window_offset: usize,
    slice_window_length: usize,
    limit: Option<usize>,
    total_cases: usize,
    correct: usize,
    precision: f64,
    k: usize,
    openai_base_url: String,
    ingestion: IngestionPerf,
    namespace: NamespacePerf,
    retrieval: RetrievalPerf,
    evaluation_stages: EvaluationStageTimings,
}
/// Ingestion timing plus cache provenance for the run.
#[derive(Debug, Serialize)]
struct IngestionPerf {
    duration_ms: u128,
    cache_path: String,
    reused: bool,
    embeddings_reused: bool,
    fingerprint: String,
    positives_total: usize,
    negatives_total: usize,
}
/// Namespace reuse flag and optional seeding duration.
#[derive(Debug, Serialize)]
struct NamespacePerf {
    reused: bool,
    // Copied from `summary.perf.namespace_seed_ms`; absent when no seed
    // duration was recorded.
    seed_ms: Option<u128>,
}
/// Retrieval latency breakdown and the settings that shaped it.
#[derive(Debug, Serialize)]
struct RetrievalPerf {
    latency_ms: crate::eval::LatencyStats,
    stage_latency: crate::eval::StageLatencyBreakdown,
    concurrency: usize,
    rerank_enabled: bool,
    rerank_pool_size: Option<usize>,
    rerank_keep_top: usize,
    evaluated_cases: usize,
}
impl PerformanceLogEntry {
    /// Flatten an `EvaluationSummary` into the serialisable log record.
    fn from_summary(summary: &EvaluationSummary) -> Self {
        // Ingestion timings and cache provenance.
        let ingestion = IngestionPerf {
            duration_ms: summary.perf.ingestion_ms,
            cache_path: summary.ingestion_cache_path.clone(),
            reused: summary.ingestion_reused,
            embeddings_reused: summary.ingestion_embeddings_reused,
            fingerprint: summary.ingestion_fingerprint.clone(),
            positives_total: summary.slice_positive_paragraphs,
            negatives_total: summary.slice_negative_paragraphs,
        };
        let namespace = NamespacePerf {
            reused: summary.namespace_reused,
            seed_ms: summary.perf.namespace_seed_ms,
        };
        // Retrieval latency stats plus the knobs that shaped them.
        let retrieval = RetrievalPerf {
            latency_ms: summary.latency_ms.clone(),
            stage_latency: summary.perf.stage_latency.clone(),
            concurrency: summary.concurrency,
            rerank_enabled: summary.rerank_enabled,
            rerank_pool_size: summary.rerank_pool_size,
            rerank_keep_top: summary.rerank_keep_top,
            evaluated_cases: summary.total_cases,
        };
        Self {
            generated_at: format_timestamp(&summary.generated_at),
            dataset_id: summary.dataset_id.clone(),
            dataset_label: summary.dataset_label.clone(),
            run_label: summary.run_label.clone(),
            slice_id: summary.slice_id.clone(),
            slice_seed: summary.slice_seed,
            slice_window_offset: summary.slice_window_offset,
            slice_window_length: summary.slice_window_length,
            limit: summary.limit,
            total_cases: summary.total_cases,
            correct: summary.correct,
            precision: summary.precision,
            k: summary.k,
            openai_base_url: summary.perf.openai_base_url.clone(),
            ingestion,
            namespace,
            retrieval,
            evaluation_stages: summary.perf.evaluation_stage_ms.clone(),
        }
    }
}
/// Append the run's perf entry to the dataset's JSONL log and optionally
/// mirror it as pretty JSON to an explicit file and/or a directory.
///
/// Returns the path of the canonical JSONL log. The pretty JSON blob is
/// now serialised at most once and shared by both mirrors (the original
/// serialised it separately per mirror).
pub fn write_perf_logs(
    summary: &EvaluationSummary,
    report_root: &Path,
    extra_json: Option<&Path>,
    extra_dir: Option<&Path>,
) -> Result<PathBuf> {
    let entry = PerformanceLogEntry::from_summary(summary);
    let dataset_dir = report::dataset_report_dir(report_root, &summary.dataset_id);
    fs::create_dir_all(&dataset_dir)
        .with_context(|| format!("creating dataset perf directory {}", dataset_dir.display()))?;

    // Canonical log: one compact JSON object per line, appended per run.
    let log_path = dataset_dir.join("perf-log.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&log_path)
        .with_context(|| format!("opening perf log {}", log_path.display()))?;
    let line = serde_json::to_vec(&entry).context("serialising perf log entry")?;
    file.write_all(&line)?;
    file.write_all(b"\n")?;
    file.flush()?;

    // Serialise the pretty form lazily, only when a mirror was requested.
    let pretty = if extra_json.is_some() || extra_dir.is_some() {
        Some(serde_json::to_vec_pretty(&entry).context("serialising perf log JSON")?)
    } else {
        None
    };
    if let (Some(path), Some(blob)) = (extra_json, pretty.as_deref()) {
        args::ensure_parent(path)?;
        fs::write(path, blob)
            .with_context(|| format!("writing perf log copy to {}", path.display()))?;
    }
    if let (Some(dir), Some(blob)) = (extra_dir, pretty.as_deref()) {
        fs::create_dir_all(dir)
            .with_context(|| format!("creating perf log directory {}", dir.display()))?;
        // Mirror file name: perf-<dataset slug>-<timestamp>.json
        let dataset_slug = dataset_dir
            .file_name()
            .and_then(|os| os.to_str())
            .unwrap_or("dataset");
        let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S").to_string();
        let filename = format!("perf-{}-{}.json", dataset_slug, timestamp);
        let path = dir.join(filename);
        fs::write(&path, blob)
            .with_context(|| format!("writing perf log mirror {}", path.display()))?;
    }
    Ok(log_path)
}
/// Print a compact three-line perf summary to stdout:
/// ingestion/namespace timings, per-stage retrieval averages, and the
/// evaluation pipeline stage durations.
pub fn print_console_summary(summary: &EvaluationSummary) {
    let perf = &summary.perf;
    println!(
        "[perf] ingestion={}ms | namespace_seed={}",
        perf.ingestion_ms,
        format_duration(perf.namespace_seed_ms),
    );
    let stage = &perf.stage_latency;
    println!(
        "[perf] stage avg ms → collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
        stage.collect_candidates.avg,
        stage.graph_expansion.avg,
        stage.chunk_attach.avg,
        stage.rerank.avg,
        stage.assemble.avg,
    );
    let eval = &perf.evaluation_stage_ms;
    println!(
        "[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
        eval.prepare_slice_ms,
        eval.prepare_db_ms,
        eval.prepare_corpus_ms,
        eval.prepare_namespace_ms,
        eval.run_queries_ms,
        eval.summarize_ms,
        eval.finalize_ms,
    );
}
/// Render an optional millisecond duration as `"<n>ms"`, or `"-"` when
/// the value is absent.
fn format_duration(value: Option<u128>) -> String {
    match value {
        Some(ms) => format!("{ms}ms"),
        None => "-".to_string(),
    }
}
#[cfg(test)]
/// Unit tests for perf-log serialisation and on-disk layout.
mod tests {
    use super::*;
    use crate::eval::{EvaluationStageTimings, PerformanceTimings};
    use chrono::Utc;
    use tempfile::tempdir;
    // Fixed latency stats reused for every latency-shaped field below.
    fn sample_latency() -> crate::eval::LatencyStats {
        crate::eval::LatencyStats {
            avg: 10.0,
            p50: 8,
            p95: 15,
        }
    }
    // Same latency stats for all five retrieval stages.
    fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
        crate::eval::StageLatencyBreakdown {
            collect_candidates: sample_latency(),
            graph_expansion: sample_latency(),
            chunk_attach: sample_latency(),
            rerank: sample_latency(),
            assemble: sample_latency(),
        }
    }
    // Distinct values per stage so mixed-up fields would be visible.
    fn sample_eval_stage() -> EvaluationStageTimings {
        EvaluationStageTimings {
            prepare_slice_ms: 10,
            prepare_db_ms: 20,
            prepare_corpus_ms: 30,
            prepare_namespace_ms: 40,
            run_queries_ms: 50,
            summarize_ms: 60,
            finalize_ms: 70,
        }
    }
    // Fully-populated summary fixture; field values are arbitrary but stable.
    fn sample_summary() -> EvaluationSummary {
        EvaluationSummary {
            generated_at: Utc::now(),
            k: 5,
            limit: Some(10),
            run_label: Some("test".into()),
            total_cases: 2,
            correct: 1,
            precision: 0.5,
            correct_at_1: 1,
            correct_at_2: 1,
            correct_at_3: 1,
            precision_at_1: 0.5,
            precision_at_2: 0.5,
            precision_at_3: 0.5,
            duration_ms: 1234,
            dataset_id: "squad-v2".into(),
            dataset_label: "SQuAD v2".into(),
            dataset_includes_unanswerable: false,
            dataset_source: "dev".into(),
            slice_id: "slice123".into(),
            slice_seed: 42,
            slice_total_cases: 400,
            slice_window_offset: 0,
            slice_window_length: 10,
            slice_cases: 10,
            slice_positive_paragraphs: 10,
            slice_negative_paragraphs: 40,
            slice_total_paragraphs: 50,
            slice_negative_multiplier: 4.0,
            namespace_reused: true,
            corpus_paragraphs: 50,
            ingestion_cache_path: "/tmp/cache".into(),
            ingestion_reused: true,
            ingestion_embeddings_reused: true,
            ingestion_fingerprint: "fingerprint".into(),
            positive_paragraphs_reused: 10,
            negative_paragraphs_reused: 40,
            latency_ms: sample_latency(),
            perf: PerformanceTimings {
                openai_base_url: "https://example.com".into(),
                ingestion_ms: 1000,
                namespace_seed_ms: Some(150),
                evaluation_stage_ms: sample_eval_stage(),
                stage_latency: sample_stage_latency(),
            },
            embedding_backend: "fastembed".into(),
            embedding_model: Some("test-model".into()),
            embedding_dimension: 32,
            rerank_enabled: true,
            rerank_pool_size: Some(4),
            rerank_keep_top: 10,
            concurrency: 2,
            detailed_report: false,
            chunk_vector_take: 20,
            chunk_fts_take: 20,
            chunk_token_budget: 10000,
            chunk_avg_chars_per_token: 4,
            max_chunks_per_entity: 4,
            cases: Vec::new(),
        }
    }
    // End-to-end: writing perf logs creates the JSONL file under the
    // dataset report dir and the serialised line carries the base URL.
    #[test]
    fn writes_perf_log_jsonl() {
        let tmp = tempdir().unwrap();
        let report_root = tmp.path().join("reports");
        let summary = sample_summary();
        let log_path = write_perf_logs(&summary, &report_root, None, None).expect("perf log write");
        assert!(log_path.exists());
        let contents = std::fs::read_to_string(&log_path).expect("reading perf log jsonl");
        assert!(
            contents.contains("\"openai_base_url\":\"https://example.com\""),
            "serialized log should include base URL"
        );
        let dataset_dir = report::dataset_report_dir(&report_root, &summary.dataset_id);
        assert!(dataset_dir.join("perf-log.jsonl").exists());
    }
}

456
eval/src/report.rs Normal file
View File

@@ -0,0 +1,456 @@
use std::{
fs,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use crate::eval::{format_timestamp, CaseSummary, EvaluationSummary, LatencyStats};
/// Locations of the timestamped report files produced by `write_reports`.
#[derive(Debug)]
pub struct ReportPaths {
    pub json: PathBuf,
    pub markdown: PathBuf,
}
/// Write the per-run JSON and Markdown reports for `summary` under the
/// dataset's report directory.
///
/// Also refreshes `latest.json` / `latest.md` pointers and appends to the
/// dataset's evaluation history. `sample` caps how many missed queries the
/// Markdown lists. Returns the timestamped report paths.
pub fn write_reports(
    summary: &EvaluationSummary,
    report_dir: &Path,
    sample: usize,
) -> Result<ReportPaths> {
    fs::create_dir_all(report_dir)
        .with_context(|| format!("creating report directory {}", report_dir.display()))?;
    let dataset_dir = dataset_report_dir(report_dir, &summary.dataset_id);
    fs::create_dir_all(&dataset_dir).with_context(|| {
        format!(
            "creating dataset report directory {}",
            dataset_dir.display()
        )
    })?;
    let stem = build_report_stem(summary);
    let json_path = dataset_dir.join(format!("{stem}.json"));
    let json_blob = serde_json::to_string_pretty(summary).context("serialising JSON report")?;
    fs::write(&json_path, &json_blob)
        .with_context(|| format!("writing JSON report to {}", json_path.display()))?;
    let md_path = dataset_dir.join(format!("{stem}.md"));
    let markdown = render_markdown(summary, sample);
    fs::write(&md_path, &markdown)
        .with_context(|| format!("writing Markdown report to {}", md_path.display()))?;
    // Keep a latest.json pointer to simplify automation.
    let latest_json = dataset_dir.join("latest.json");
    fs::write(&latest_json, json_blob)
        .with_context(|| format!("writing latest JSON report to {}", latest_json.display()))?;
    let latest_md = dataset_dir.join("latest.md");
    fs::write(&latest_md, markdown)
        .with_context(|| format!("writing latest Markdown report to {}", latest_md.display()))?;
    record_history(summary, &dataset_dir)?;
    Ok(ReportPaths {
        json: json_path,
        markdown: md_path,
    })
}
/// Render the full Markdown report for one evaluation run.
///
/// Produces a metrics table, a retrieval-stage timing table, and either a
/// sampled table of missed queries (capped at `sample` rows) or an
/// all-matched note. Fix: the "Positives Cached" / "Negatives Cached"
/// rows previously used multi-line string literals with a raw embedded
/// newline (which also swallows whatever indentation precedes the closing
/// quote); they now use `\n` escapes like every other row.
fn render_markdown(summary: &EvaluationSummary, sample: usize) -> String {
    let mut md = String::new();
    md.push_str(&format!("# Retrieval Precision@{}\n\n", summary.k));
    md.push_str("| Metric | Value |\n");
    md.push_str("| --- | --- |\n");
    md.push_str(&format!(
        "| Generated | {} |\n",
        format_timestamp(&summary.generated_at)
    ));
    md.push_str(&format!(
        "| Dataset | {} (`{}`) |\n",
        summary.dataset_label, summary.dataset_id
    ));
    md.push_str(&format!(
        "| Run Label | {} |\n",
        summary
            .run_label
            .as_deref()
            .filter(|label| !label.is_empty())
            .unwrap_or("-")
    ));
    md.push_str(&format!(
        "| Unanswerable Included | {} |\n",
        if summary.dataset_includes_unanswerable {
            "yes"
        } else {
            "no"
        }
    ));
    md.push_str(&format!(
        "| Dataset Source | {} |\n",
        summary.dataset_source
    ));
    md.push_str(&format!(
        "| OpenAI Base URL | {} |\n",
        summary.perf.openai_base_url
    ));
    md.push_str(&format!("| Slice ID | `{}` |\n", summary.slice_id));
    md.push_str(&format!("| Slice Seed | {} |\n", summary.slice_seed));
    md.push_str(&format!(
        "| Slice Total Questions | {} |\n",
        summary.slice_total_cases
    ));
    md.push_str(&format!(
        "| Slice Window (offset/length) | {}/{} |\n",
        summary.slice_window_offset, summary.slice_window_length
    ));
    md.push_str(&format!(
        "| Slice Window Questions | {} |\n",
        summary.slice_cases
    ));
    md.push_str(&format!(
        "| Slice Negatives | {} |\n",
        summary.slice_negative_paragraphs
    ));
    md.push_str(&format!(
        "| Slice Total Paragraphs | {} |\n",
        summary.slice_total_paragraphs
    ));
    md.push_str(&format!(
        "| Slice Negative Multiplier | {:.2} |\n",
        summary.slice_negative_multiplier
    ));
    md.push_str(&format!(
        "| Namespace State | {} |\n",
        if summary.namespace_reused {
            "reused"
        } else {
            "seeded"
        }
    ));
    md.push_str(&format!(
        "| Corpus Paragraphs | {} |\n",
        summary.corpus_paragraphs
    ));
    md.push_str(&format!(
        "| Ingestion Duration | {} ms |\n",
        summary.perf.ingestion_ms
    ));
    if let Some(seed) = summary.perf.namespace_seed_ms {
        md.push_str(&format!("| Namespace Seed | {} ms |\n", seed));
    }
    // Cache provenance rows only appear in detailed reports.
    if summary.detailed_report {
        md.push_str(&format!(
            "| Ingestion Cache | `{}` |\n",
            summary.ingestion_cache_path
        ));
        md.push_str(&format!(
            "| Ingestion Reused | {} |\n",
            if summary.ingestion_reused {
                "yes"
            } else {
                "no"
            }
        ));
        md.push_str(&format!(
            "| Embeddings Reused | {} |\n",
            if summary.ingestion_embeddings_reused {
                "yes"
            } else {
                "no"
            }
        ));
    }
    md.push_str(&format!(
        "| Positives Cached | {} |\n",
        summary.positive_paragraphs_reused
    ));
    md.push_str(&format!(
        "| Negatives Cached | {} |\n",
        summary.negative_paragraphs_reused
    ));
    let embedding_label = if let Some(model) = summary.embedding_model.as_ref() {
        format!("{} ({model})", summary.embedding_backend)
    } else {
        summary.embedding_backend.clone()
    };
    md.push_str(&format!("| Embedding | {} |\n", embedding_label));
    md.push_str(&format!(
        "| Embedding Dim | {} |\n",
        summary.embedding_dimension
    ));
    if let Some(limit) = summary.limit {
        md.push_str(&format!(
            "| Evaluated Queries | {} (limit {}) |\n",
            summary.total_cases, limit
        ));
    } else {
        md.push_str(&format!(
            "| Evaluated Queries | {} |\n",
            summary.total_cases
        ));
    }
    if summary.rerank_enabled {
        let pool = summary
            .rerank_pool_size
            .map(|size| size.to_string())
            .unwrap_or_else(|| "?".to_string());
        md.push_str(&format!(
            "| Rerank | enabled (pool {pool}, keep top {}) |\n",
            summary.rerank_keep_top
        ));
    } else {
        md.push_str("| Rerank | disabled |\n");
    }
    md.push_str(&format!("| Concurrency | {} |\n", summary.concurrency));
    md.push_str(&format!(
        "| Correct@{} | {}/{} |\n",
        summary.k, summary.correct, summary.total_cases
    ));
    md.push_str(&format!(
        "| Precision@{} | {:.3} |\n",
        summary.k, summary.precision
    ));
    md.push_str(&format!(
        "| Precision@1 | {:.3} |\n",
        summary.precision_at_1
    ));
    md.push_str(&format!(
        "| Precision@2 | {:.3} |\n",
        summary.precision_at_2
    ));
    md.push_str(&format!(
        "| Precision@3 | {:.3} |\n",
        summary.precision_at_3
    ));
    md.push_str(&format!("| Duration | {} ms |\n", summary.duration_ms));
    md.push_str(&format!(
        "| Latency Avg (ms) | {:.1} |\n",
        summary.latency_ms.avg
    ));
    md.push_str(&format!(
        "| Latency P50 (ms) | {} |\n",
        summary.latency_ms.p50
    ));
    md.push_str(&format!(
        "| Latency P95 (ms) | {} |\n",
        summary.latency_ms.p95
    ));
    // Per-stage retrieval timing table.
    md.push_str("\n## Retrieval Stage Timings\n\n");
    md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\n");
    md.push_str("| --- | --- | --- | --- |\n");
    write_stage_row(
        &mut md,
        "Collect Candidates",
        &summary.perf.stage_latency.collect_candidates,
    );
    write_stage_row(
        &mut md,
        "Graph Expansion",
        &summary.perf.stage_latency.graph_expansion,
    );
    write_stage_row(
        &mut md,
        "Chunk Attach",
        &summary.perf.stage_latency.chunk_attach,
    );
    write_stage_row(&mut md, "Rerank", &summary.perf.stage_latency.rerank);
    write_stage_row(&mut md, "Assemble", &summary.perf.stage_latency.assemble);
    // Missed-query sample: detailed reports add per-measure match columns.
    let misses: Vec<&CaseSummary> = summary.cases.iter().filter(|case| !case.matched).collect();
    if !misses.is_empty() {
        md.push_str("\n## Missed Queries (sample)\n\n");
        if summary.detailed_report {
            md.push_str(
                "| Question ID | Paragraph | Expected Source | Entity Match | Chunk Text | Chunk ID | Top Retrieved |\n",
            );
            md.push_str("| --- | --- | --- | --- | --- | --- | --- |\n");
        } else {
            md.push_str("| Question ID | Paragraph | Expected Source | Top Retrieved |\n");
            md.push_str("| --- | --- | --- | --- |\n");
        }
        for case in misses.iter().take(sample) {
            let retrieved = case
                .retrieved
                .iter()
                .map(|entry| format!("{} (rank {})", entry.source_id, entry.rank))
                .take(3)
                .collect::<Vec<_>>()
                .join("<br>");
            if summary.detailed_report {
                md.push_str(&format!(
                    "| `{}` | {} | `{}` | {} | {} | {} | {} |\n",
                    case.question_id,
                    case.paragraph_title,
                    case.expected_source,
                    bool_badge(case.entity_match),
                    bool_badge(case.chunk_text_match),
                    bool_badge(case.chunk_id_match),
                    retrieved
                ));
            } else {
                md.push_str(&format!(
                    "| `{}` | {} | `{}` | {} |\n",
                    case.question_id, case.paragraph_title, case.expected_source, retrieved
                ));
            }
        }
    } else {
        md.push_str("\n_All evaluated queries matched within the top-k window._\n");
        if summary.detailed_report {
            md.push_str(
                "\nSuccess measures were captured for each query (entity, chunk text, chunk ID).\n",
            );
        }
    }
    md
}
/// Append one Markdown table row for a stage's latency stats.
fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
    let row = format!(
        "| {} | {:.1} | {} | {} |\n",
        label, stats.avg, stats.p50, stats.p95
    );
    buf.push_str(&row);
}
/// Render a boolean as a Markdown badge for the detailed miss table.
///
/// Fix: both branches previously returned the identical empty string
/// (most likely emoji lost in transit), so the Entity/Chunk match columns
/// were indistinguishable. Restore distinct check/cross marks.
fn bool_badge(value: bool) -> &'static str {
    if value {
        "✅"
    } else {
        "❌"
    }
}
/// Build the report file stem:
/// `precision_at_<k>_<dataset>_<timestamp>_<backend>[_<model>]`,
/// with dataset/backend/model sanitised to `[A-Za-z0-9_]`.
fn build_report_stem(summary: &EvaluationSummary) -> String {
    let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S");
    let backend = sanitize_component(&summary.embedding_backend);
    let dataset_component = sanitize_component(&summary.dataset_id);
    // The model suffix is only present when an embedding model is set.
    if let Some(model) = summary.embedding_model.as_ref() {
        let model = sanitize_component(model);
        format!(
            "precision_at_{}_{}_{}_{}_{}",
            summary.k, dataset_component, timestamp, backend, model
        )
    } else {
        format!(
            "precision_at_{}_{}_{}_{}",
            summary.k, dataset_component, timestamp, backend
        )
    }
}
/// Replace every non-ASCII-alphanumeric character with `_`, producing a
/// string safe for file and directory names.
fn sanitize_component(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for ch in input.chars() {
        out.push(if ch.is_ascii_alphanumeric() { ch } else { '_' });
    }
    out
}
/// Per-dataset report directory: `<report_dir>/<sanitised dataset id>`.
pub fn dataset_report_dir(report_dir: &Path, dataset_id: &str) -> PathBuf {
    let component = sanitize_component(dataset_id);
    report_dir.join(component)
}
/// One row of the dataset's `evaluations.json` history log.
#[derive(Debug, Serialize, Deserialize)]
struct HistoryEntry {
    generated_at: String,
    run_label: Option<String>,
    dataset_id: String,
    dataset_label: String,
    slice_id: String,
    slice_seed: u64,
    slice_window_offset: usize,
    slice_window_length: usize,
    slice_cases: usize,
    slice_total_cases: usize,
    k: usize,
    limit: Option<usize>,
    precision: f64,
    precision_at_1: f64,
    precision_at_2: f64,
    precision_at_3: f64,
    duration_ms: u128,
    latency_ms: LatencyStats,
    embedding_backend: String,
    embedding_model: Option<String>,
    ingestion_reused: bool,
    ingestion_embeddings_reused: bool,
    rerank_enabled: bool,
    rerank_keep_top: usize,
    rerank_pool_size: Option<usize>,
    // Difference against the immediately preceding entry; None for the
    // first entry in a history file.
    delta: Option<HistoryDelta>,
    openai_base_url: String,
    ingestion_ms: u128,
    // `default` keeps older history files (written before this field
    // existed) deserialisable.
    #[serde(default)]
    namespace_seed_ms: Option<u128>,
}
/// Metric deltas relative to the previous history entry.
#[derive(Debug, Serialize, Deserialize)]
struct HistoryDelta {
    precision: f64,
    precision_at_1: f64,
    latency_avg_ms: f64,
}
/// Append this run's metrics to `<report_dir>/evaluations.json`,
/// computing deltas against the most recent prior entry.
fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> {
    let path = report_dir.join("evaluations.json");
    let mut entries: Vec<HistoryEntry> = if path.exists() {
        let contents = fs::read(&path)
            .with_context(|| format!("reading evaluation log {}", path.display()))?;
        // NOTE(review): a corrupt/incompatible history file is silently
        // reset to empty here (unwrap_or_default), discarding prior
        // entries — presumably intentional best-effort; confirm.
        serde_json::from_slice(&contents).unwrap_or_default()
    } else {
        Vec::new()
    };
    // Deltas compare against the last recorded run, if any.
    let delta = entries.last().map(|prev| HistoryDelta {
        precision: summary.precision - prev.precision,
        precision_at_1: summary.precision_at_1 - prev.precision_at_1,
        latency_avg_ms: summary.latency_ms.avg - prev.latency_ms.avg,
    });
    let entry = HistoryEntry {
        generated_at: format_timestamp(&summary.generated_at),
        run_label: summary.run_label.clone(),
        dataset_id: summary.dataset_id.clone(),
        dataset_label: summary.dataset_label.clone(),
        slice_id: summary.slice_id.clone(),
        slice_seed: summary.slice_seed,
        slice_window_offset: summary.slice_window_offset,
        slice_window_length: summary.slice_window_length,
        slice_cases: summary.slice_cases,
        slice_total_cases: summary.slice_total_cases,
        k: summary.k,
        limit: summary.limit,
        precision: summary.precision,
        precision_at_1: summary.precision_at_1,
        precision_at_2: summary.precision_at_2,
        precision_at_3: summary.precision_at_3,
        duration_ms: summary.duration_ms,
        latency_ms: summary.latency_ms.clone(),
        embedding_backend: summary.embedding_backend.clone(),
        embedding_model: summary.embedding_model.clone(),
        ingestion_reused: summary.ingestion_reused,
        ingestion_embeddings_reused: summary.ingestion_embeddings_reused,
        rerank_enabled: summary.rerank_enabled,
        rerank_keep_top: summary.rerank_keep_top,
        rerank_pool_size: summary.rerank_pool_size,
        delta,
        openai_base_url: summary.perf.openai_base_url.clone(),
        ingestion_ms: summary.perf.ingestion_ms,
        namespace_seed_ms: summary.perf.namespace_seed_ms,
    };
    entries.push(entry);
    // Rewrite the whole file each run (pretty-printed array).
    let blob = serde_json::to_vec_pretty(&entries).context("serialising evaluation log")?;
    fs::write(&path, blob).with_context(|| format!("writing evaluation log {}", path.display()))?;
    Ok(())
}

27
eval/src/slice.rs Normal file
View File

@@ -0,0 +1,27 @@
use crate::slices::SliceConfig as CoreSliceConfig;
pub use crate::slices::*;
use crate::args::Config;
/// Borrow a `SliceConfig` straight from the CLI config with no limit
/// override (delegates to `slice_config_with_limit`).
impl<'a> From<&'a Config> for CoreSliceConfig<'a> {
    fn from(config: &'a Config) -> Self {
        slice_config_with_limit(config, None)
    }
}
/// Build a `SliceConfig` view over `config`, optionally overriding the
/// case limit (the override wins over `config.limit` when present).
pub fn slice_config_with_limit<'a>(
    config: &'a Config,
    limit_override: Option<usize>,
) -> CoreSliceConfig<'a> {
    CoreSliceConfig {
        cache_dir: config.cache_dir.as_path(),
        force_convert: config.force_convert,
        explicit_slice: config.slice.as_deref(),
        // Explicit override takes precedence over the CLI limit.
        limit: limit_override.or(config.limit),
        corpus_limit: config.corpus_limit,
        slice_seed: config.slice_seed,
        llm_mode: config.llm_mode,
        negative_multiplier: config.negative_multiplier,
    }
}

941
eval/src/slices.rs Normal file
View File

@@ -0,0 +1,941 @@
use std::{
collections::{HashMap, HashSet},
fs,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tracing::{info, warn};
use crate::datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion};
// Slice manifest format version; presumably bumped on incompatible
// manifest changes — confirm against the consumer (outside this view).
const SLICE_VERSION: u32 = 2;
/// Default ratio of negative (distractor) paragraphs to positives.
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
/// Borrowed knobs controlling how an evaluation slice is selected/built.
#[derive(Debug, Clone)]
pub struct SliceConfig<'a> {
    pub cache_dir: &'a Path,
    pub force_convert: bool,
    // Explicit slice id/path chosen by the user, if any.
    pub explicit_slice: Option<&'a str>,
    pub limit: Option<usize>,
    pub corpus_limit: Option<usize>,
    pub slice_seed: u64,
    pub llm_mode: bool,
    pub negative_multiplier: f32,
}
/// Persisted description of a deterministic dataset slice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceManifest {
    pub version: u32,
    pub slice_id: String,
    pub dataset_id: String,
    pub dataset_label: String,
    pub dataset_source: String,
    pub includes_unanswerable: bool,
    pub seed: u64,
    pub requested_limit: Option<usize>,
    pub requested_corpus: usize,
    pub generated_at: DateTime<Utc>,
    pub case_count: usize,
    pub positive_paragraphs: usize,
    pub negative_paragraphs: usize,
    pub total_paragraphs: usize,
    pub negative_multiplier: f32,
    pub cases: Vec<SliceCaseEntry>,
    pub paragraphs: Vec<SliceParagraphEntry>,
}
/// Links a question to the paragraph that answers it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceCaseEntry {
    pub question_id: String,
    pub paragraph_id: String,
}
/// One paragraph in the slice, with its role and optional cached shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceParagraphEntry {
    pub id: String,
    pub kind: SliceParagraphKind,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shard_path: Option<String>,
}
/// Role of a paragraph in the slice, serialised internally tagged on
/// `kind` in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SliceParagraphKind {
    // Answer-bearing paragraph and the questions it answers.
    Positive { question_ids: Vec<String> },
    // Distractor paragraph.
    Negative,
}
/// Relative shard path for a paragraph's cached JSON:
/// `paragraphs/<sanitised id>.json`.
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
    let sanitized = sanitize_identifier(paragraph_id);
    format!("paragraphs/{}.json", sanitized)
}
/// Make `input` filename-safe: non-alphanumerics become `-`, then leading
/// and trailing dashes are trimmed. If nothing survives trimming, fall
/// back to a short stable hex digest of the original input so distinct
/// ids still map to distinct paths.
fn sanitize_identifier(input: &str) -> String {
    let sanitized: String = input
        .chars()
        .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
        .collect();
    let trimmed = sanitized.trim_matches('-');
    if trimmed.is_empty() {
        // First 6 bytes of SHA-256(input), lowercase hex.
        let digest = Sha256::digest(input.as_bytes());
        digest[..6]
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect::<String>()
    } else {
        trimmed.to_string()
    }
}
/// A slice manifest resolved against the in-memory converted dataset,
/// with borrowed paragraph/question references.
#[derive(Debug, Clone)]
pub struct ResolvedSlice<'a> {
    pub manifest: SliceManifest,
    pub path: PathBuf,
    pub paragraphs: Vec<&'a ConvertedParagraph>,
    pub cases: Vec<CaseRef<'a>>,
}
/// A contiguous window of cases taken from a resolved slice.
#[derive(Debug, Clone)]
pub struct SliceWindow<'a> {
    pub offset: usize,
    pub length: usize,
    pub total_cases: usize,
    pub cases: Vec<CaseRef<'a>>,
    // Ids of the answer-bearing paragraphs inside this window.
    positive_paragraph_ids: Vec<String>,
}
impl<'a> SliceWindow<'a> {
    /// Iterate the positive paragraph ids in this window.
    pub fn positive_ids(&self) -> impl Iterator<Item = &str> {
        self.positive_paragraph_ids.iter().map(|id| id.as_str())
    }
}
/// One evaluation case: a question and the paragraph that answers it.
#[derive(Debug, Clone)]
pub struct CaseRef<'a> {
    pub paragraph: &'a ConvertedParagraph,
    pub question: &'a ConvertedQuestion,
}
/// Id → index maps over a converted dataset for O(1) lookups.
struct DatasetIndex {
    paragraph_by_id: HashMap<String, usize>,
    // question id → (paragraph index, question index within paragraph)
    question_by_id: HashMap<String, (usize, usize)>,
}
impl DatasetIndex {
    /// Builds the id → position tables in one pass over the dataset.
    fn build(dataset: &ConvertedDataset) -> Self {
        let mut paragraph_by_id = HashMap::new();
        let mut question_by_id = HashMap::new();
        for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
            paragraph_by_id.insert(paragraph.id.clone(), p_idx);
            for (q_idx, question) in paragraph.questions.iter().enumerate() {
                question_by_id.insert(question.id.clone(), (p_idx, q_idx));
            }
        }
        Self {
            paragraph_by_id,
            question_by_id,
        }
    }
    /// Resolves a paragraph id to a borrowed paragraph, or errors if the slice
    /// references an id the dataset does not contain.
    fn paragraph<'a>(
        &self,
        dataset: &'a ConvertedDataset,
        id: &str,
    ) -> Result<&'a ConvertedParagraph> {
        let idx = self
            .paragraph_by_id
            .get(id)
            .ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?;
        Ok(&dataset.paragraphs[*idx])
    }
    /// Resolves a question id to its owning paragraph and the question itself.
    ///
    /// The second error arm guards against a stale index pointing past the
    /// paragraph's current question list.
    fn question<'a>(
        &self,
        dataset: &'a ConvertedDataset,
        question_id: &str,
    ) -> Result<(&'a ConvertedParagraph, &'a ConvertedQuestion)> {
        let (p_idx, q_idx) = self
            .question_by_id
            .get(question_id)
            .ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?;
        let paragraph = &dataset.paragraphs[*p_idx];
        let question = paragraph
            .questions
            .get(*q_idx)
            .ok_or_else(|| anyhow!("slice maps question '{question_id}' to missing index"))?;
        Ok((paragraph, question))
    }
}
/// Serialisable identity of a slice; its JSON encoding is hashed by
/// `compute_slice_id` to produce the cache file name.
#[derive(Debug, Serialize)]
struct SliceKey<'a> {
    dataset_id: &'a str,
    includes_unanswerable: bool,
    requested_corpus: usize,
    seed: u64,
}
/// Parameters controlling slice construction.
#[derive(Debug)]
struct BuildParams {
    /// Whether unanswerable questions are eligible (set from `llm_mode`).
    include_impossible: bool,
    /// User-supplied seed, stored verbatim in the manifest.
    base_seed: u64,
    /// Seed mixed with the dataset id; drives the question shuffle.
    rng_seed: u64,
}
/// Resolves the dataset slice to evaluate.
///
/// If `config.explicit_slice` is set, that manifest is loaded and validated.
/// Otherwise a deterministic slice id is derived from the dataset id,
/// `includes_unanswerable`, the requested corpus size, and the seed; a cached
/// manifest under that id is reused when compatible and grown in place (cases
/// and negative paragraphs are only ever appended), then re-persisted if
/// anything changed.
pub fn resolve_slice<'a>(
    dataset: &'a ConvertedDataset,
    config: &SliceConfig<'_>,
) -> Result<ResolvedSlice<'a>> {
    let index = DatasetIndex::build(dataset);
    // Explicitly selected slices bypass the cache-key machinery entirely.
    if let Some(slice_arg) = config.explicit_slice {
        let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
        let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
        info!(
            slice = %resolved.manifest.slice_id,
            path = %resolved.path.display(),
            cases = resolved.manifest.case_count,
            positives = resolved.manifest.positive_paragraphs,
            negatives = resolved.manifest.negative_paragraphs,
            "Using explicitly selected slice"
        );
        return Ok(resolved);
    }
    // Clamp the requested corpus to [1, total paragraphs].
    let requested_corpus = config
        .corpus_limit
        .unwrap_or(dataset.paragraphs.len())
        .min(dataset.paragraphs.len())
        .max(1);
    let key = SliceKey {
        dataset_id: dataset.metadata.id.as_str(),
        includes_unanswerable: dataset.metadata.include_unanswerable,
        requested_corpus,
        seed: config.slice_seed,
    };
    let slice_id = compute_slice_id(&key);
    let base = config
        .cache_dir
        .join("slices")
        .join(dataset.metadata.id.as_str());
    let path = base.join(format!("{slice_id}.json"));
    // Clamp the requested case count to [1, total questions].
    let total_questions = dataset
        .paragraphs
        .iter()
        .map(|p| p.questions.len())
        .sum::<usize>()
        .max(1);
    let requested_limit = config
        .limit
        .unwrap_or(total_questions)
        .min(total_questions)
        .max(1);
    // Try the cached manifest unless a rebuild was forced; discard it on a
    // dataset-id or includes_unanswerable mismatch, or on a read failure.
    let mut manifest = if !config.force_convert && path.exists() {
        match read_manifest(&path) {
            Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
                if manifest.includes_unanswerable != dataset.metadata.include_unanswerable {
                    warn!(
                        slice = manifest.slice_id,
                        path = %path.display(),
                        "Slice manifest includes_unanswerable mismatch; regenerating"
                    );
                    None
                } else {
                    Some(manifest)
                }
            }
            Ok(manifest) => {
                warn!(
                    slice = manifest.slice_id,
                    path = %path.display(),
                    loaded_dataset = %manifest.dataset_id,
                    expected = %dataset.metadata.id,
                    "Slice manifest targets different dataset; regenerating"
                );
                None
            }
            Err(err) => {
                warn!(
                    path = %path.display(),
                    error = %err,
                    "Failed to read cached slice; regenerating"
                );
                None
            }
        }
    } else {
        None
    };
    let params = BuildParams {
        // LLM mode keeps unanswerable questions eligible.
        include_impossible: config.llm_mode,
        base_seed: config.slice_seed,
        rng_seed: mix_seed(dataset.metadata.id.as_str(), config.slice_seed),
    };
    // Discard manifests written by an incompatible slice version.
    if manifest
        .as_ref()
        .map(|manifest| manifest.version != SLICE_VERSION)
        .unwrap_or(false)
    {
        warn!(
            slice = manifest
                .as_ref()
                .map(|m| m.slice_id.as_str())
                .unwrap_or("unknown"),
            found = manifest.as_ref().map(|m| m.version).unwrap_or(0),
            expected = SLICE_VERSION,
            "Slice manifest version mismatch; regenerating"
        );
        manifest = None;
    }
    let mut manifest = manifest.unwrap_or_else(|| {
        empty_manifest(
            dataset,
            slice_id.clone(),
            &params,
            requested_corpus,
            config.negative_multiplier,
            config.limit,
        )
    });
    // Record the currently requested shape even when reusing a cached ledger.
    manifest.requested_limit = config.limit;
    manifest.requested_corpus = requested_corpus;
    manifest.negative_multiplier = config.negative_multiplier;
    // Grow the ledger as needed: backfill shard paths, top up cases, then top
    // up the negative pool sized from the refreshed positive count.
    let mut changed = ensure_shard_paths(&mut manifest);
    changed |= ensure_case_capacity(dataset, &mut manifest, &params, requested_limit)?;
    refresh_manifest_stats(&mut manifest);
    let desired_negatives = desired_negative_target(
        manifest.positive_paragraphs,
        requested_corpus,
        dataset.paragraphs.len(),
        config.negative_multiplier,
    );
    changed |= ensure_negative_pool(
        dataset,
        &mut manifest,
        &params,
        desired_negatives,
        requested_corpus,
    )?;
    refresh_manifest_stats(&mut manifest);
    // Persist (and restamp generated_at) only when something actually changed.
    if changed {
        manifest.generated_at = Utc::now();
        write_manifest(&path, &manifest)?;
        info!(
            slice = %manifest.slice_id,
            path = %path.display(),
            cases = manifest.case_count,
            positives = manifest.positive_paragraphs,
            negatives = manifest.negative_paragraphs,
            "Updated dataset slice ledger"
        );
    } else {
        info!(
            slice = %manifest.slice_id,
            path = %path.display(),
            cases = manifest.case_count,
            positives = manifest.positive_paragraphs,
            negatives = manifest.negative_paragraphs,
            "Reusing cached slice ledger"
        );
    }
    let resolved = manifest_to_resolved(dataset, &index, manifest.clone(), path.clone())?;
    Ok(resolved)
}
pub fn select_window<'a>(
resolved: &'a ResolvedSlice<'a>,
offset: usize,
limit: Option<usize>,
) -> Result<SliceWindow<'a>> {
let total = resolved.manifest.case_count;
if total == 0 {
return Err(anyhow!(
"slice '{}' contains no cases",
resolved.manifest.slice_id
));
}
if offset >= total {
return Err(anyhow!(
"slice offset {} exceeds available cases ({})",
offset,
total
));
}
let available = total - offset;
let requested = limit.unwrap_or(available).max(1);
let length = requested.min(available);
let cases = resolved.cases[offset..offset + length].to_vec();
let mut seen = HashSet::new();
let mut positive_ids = Vec::new();
for case in &cases {
if seen.insert(case.paragraph.id.as_str()) {
positive_ids.push(case.paragraph.id.clone());
}
}
Ok(SliceWindow {
offset,
length,
total_cases: total,
cases,
positive_paragraph_ids: positive_ids,
})
}
/// Convenience wrapper returning a window over every case in the slice.
#[allow(dead_code)]
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
    select_window(resolved, 0, None)
}
/// Loads a slice manifest the user named explicitly, either as a filesystem
/// path or as a slice id resolved inside the dataset's slice cache.
///
/// The manifest must target the loaded dataset and is fully validated via
/// `manifest_to_resolved` before being returned.
fn load_explicit_slice<'a>(
    dataset: &'a ConvertedDataset,
    index: &DatasetIndex,
    config: &SliceConfig<'_>,
    slice_arg: &str,
) -> Result<(PathBuf, SliceManifest)> {
    // Treat the argument as a path first; otherwise interpret it as a slice id
    // within the cache directory.
    let explicit_path = Path::new(slice_arg);
    let candidate_path = if explicit_path.exists() {
        explicit_path.to_path_buf()
    } else {
        config
            .cache_dir
            .join("slices")
            .join(dataset.metadata.id.as_str())
            .join(format!("{slice_arg}.json"))
    };
    let manifest = read_manifest(&candidate_path)
        .with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;
    if manifest.dataset_id != dataset.metadata.id {
        return Err(anyhow!(
            "slice '{}' targets dataset '{}', but '{}' is loaded",
            manifest.slice_id,
            manifest.dataset_id,
            dataset.metadata.id
        ));
    }
    // Validate the manifest before returning.
    manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?;
    Ok((candidate_path, manifest))
}
/// Creates a fresh, empty manifest carrying the dataset identity and the
/// requested build parameters; cases and paragraphs are filled in later by
/// `ensure_case_capacity` / `ensure_negative_pool`.
fn empty_manifest(
    dataset: &ConvertedDataset,
    slice_id: String,
    params: &BuildParams,
    requested_corpus: usize,
    negative_multiplier: f32,
    requested_limit: Option<usize>,
) -> SliceManifest {
    SliceManifest {
        version: SLICE_VERSION,
        slice_id,
        dataset_id: dataset.metadata.id.clone(),
        dataset_label: dataset.metadata.label.clone(),
        dataset_source: dataset.source.clone(),
        includes_unanswerable: dataset.metadata.include_unanswerable,
        seed: params.base_seed,
        requested_limit,
        requested_corpus,
        negative_multiplier,
        generated_at: Utc::now(),
        case_count: 0,
        positive_paragraphs: 0,
        negative_paragraphs: 0,
        total_paragraphs: 0,
        cases: Vec::new(),
        paragraphs: Vec::new(),
    }
}
/// Appends cases to the manifest until it holds `target_cases`, drawing from
/// the deterministically shuffled question order.
///
/// Existing cases are never removed or reordered, so growing a cached slice
/// keeps earlier runs' cases stable. Returns whether the manifest changed, or
/// an error when the dataset cannot supply enough eligible questions.
fn ensure_case_capacity(
    dataset: &ConvertedDataset,
    manifest: &mut SliceManifest,
    params: &BuildParams,
    target_cases: usize,
) -> Result<bool> {
    if manifest.case_count >= target_cases {
        return Ok(false);
    }
    let question_refs = ordered_question_refs(dataset, params)?;
    // Index what is already in the ledger so re-runs only append.
    let mut existing_questions: HashSet<String> = manifest
        .cases
        .iter()
        .map(|case| case.question_id.clone())
        .collect();
    let mut paragraph_positions: HashMap<String, usize> = manifest
        .paragraphs
        .iter()
        .enumerate()
        .map(|(idx, entry)| (entry.id.clone(), idx))
        .collect();
    let mut changed = false;
    for (p_idx, q_idx) in question_refs {
        if manifest.case_count >= target_cases {
            break;
        }
        let paragraph = &dataset.paragraphs[p_idx];
        let question = &paragraph.questions[q_idx];
        // Skip questions the ledger already covers.
        if !existing_questions.insert(question.id.clone()) {
            continue;
        }
        if let Some(idx) = paragraph_positions.get(paragraph.id.as_str()).copied() {
            match &mut manifest.paragraphs[idx].kind {
                SliceParagraphKind::Positive { question_ids } => {
                    if !question_ids.contains(&question.id) {
                        question_ids.push(question.id.clone());
                    }
                }
                // A former distractor that gains a question becomes positive.
                SliceParagraphKind::Negative => {
                    manifest.paragraphs[idx].kind = SliceParagraphKind::Positive {
                        question_ids: vec![question.id.clone()],
                    };
                }
            }
        } else {
            manifest.paragraphs.push(SliceParagraphEntry {
                id: paragraph.id.clone(),
                kind: SliceParagraphKind::Positive {
                    question_ids: vec![question.id.clone()],
                },
                shard_path: Some(default_shard_path(&paragraph.id)),
            });
            let idx = manifest.paragraphs.len() - 1;
            paragraph_positions.insert(paragraph.id.clone(), idx);
        }
        manifest.cases.push(SliceCaseEntry {
            question_id: question.id.clone(),
            paragraph_id: paragraph.id.clone(),
        });
        manifest.case_count += 1;
        changed = true;
    }
    if manifest.case_count < target_cases {
        return Err(anyhow!(
            "only {}/{} eligible questions available for dataset {}",
            manifest.case_count,
            target_cases,
            dataset.metadata.id
        ));
    }
    Ok(changed)
}
/// Collects every eligible (paragraph index, question index) pair and shuffles
/// them with the slice's deterministic seed.
///
/// Outside of `include_impossible` mode, unanswerable or answer-less questions
/// are excluded. Errors if nothing is eligible.
fn ordered_question_refs(
    dataset: &ConvertedDataset,
    params: &BuildParams,
) -> Result<Vec<(usize, usize)>> {
    let mut refs: Vec<(usize, usize)> = dataset
        .paragraphs
        .iter()
        .enumerate()
        .flat_map(|(p_idx, paragraph)| {
            paragraph
                .questions
                .iter()
                .enumerate()
                .filter(|(_, question)| {
                    params.include_impossible
                        || (!question.is_impossible && !question.answers.is_empty())
                })
                .map(move |(q_idx, _)| (p_idx, q_idx))
        })
        .collect();
    if refs.is_empty() {
        return Err(anyhow!(
            "no eligible questions found for dataset {}; cannot build slice",
            dataset.metadata.id
        ));
    }
    refs.shuffle(&mut StdRng::seed_from_u64(params.rng_seed));
    Ok(refs)
}
/// Tops up the manifest's pool of negative (distractor) paragraphs until it
/// reaches `target_negatives`, drawing from a deterministically shuffled list
/// of paragraphs that are not positives.
///
/// Returns whether any negatives were added; when the dataset cannot supply
/// enough distractors only a warning is logged.
fn ensure_negative_pool(
    dataset: &ConvertedDataset,
    manifest: &mut SliceManifest,
    params: &BuildParams,
    target_negatives: usize,
    requested_corpus: usize,
) -> Result<bool> {
    let current_negatives = manifest
        .paragraphs
        .iter()
        .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative))
        .count();
    if current_negatives >= target_negatives {
        return Ok(false);
    }
    let positive_ids: HashSet<String> = manifest
        .paragraphs
        .iter()
        .filter_map(|entry| match entry.kind {
            SliceParagraphKind::Positive { .. } => Some(entry.id.clone()),
            _ => None,
        })
        .collect();
    let mut negative_ids: HashSet<String> = manifest
        .paragraphs
        .iter()
        .filter_map(|entry| match entry.kind {
            SliceParagraphKind::Negative => Some(entry.id.clone()),
            _ => None,
        })
        .collect();
    // Derive a seed stream distinct from the question shuffle so growing one
    // pool does not perturb the other.
    let negative_seed = mix_seed(
        &format!("{}::negatives", dataset.metadata.id),
        params.base_seed,
    );
    let candidates = ordered_negative_indices(dataset, &positive_ids, negative_seed);
    let mut added = false;
    for idx in candidates {
        if negative_ids.len() >= target_negatives {
            break;
        }
        let paragraph = &dataset.paragraphs[idx];
        // Skip paragraphs already in either pool.
        if negative_ids.contains(paragraph.id.as_str())
            || positive_ids.contains(paragraph.id.as_str())
        {
            continue;
        }
        manifest.paragraphs.push(SliceParagraphEntry {
            id: paragraph.id.clone(),
            kind: SliceParagraphKind::Negative,
            shard_path: Some(default_shard_path(&paragraph.id)),
        });
        negative_ids.insert(paragraph.id.clone());
        added = true;
    }
    if negative_ids.len() < target_negatives {
        warn!(
            dataset = %dataset.metadata.id,
            desired = target_negatives,
            available = negative_ids.len(),
            requested_corpus,
            "Insufficient negative paragraphs to satisfy multiplier"
        );
    }
    Ok(added)
}
fn ordered_negative_indices(
dataset: &ConvertedDataset,
positive_ids: &HashSet<String>,
rng_seed: u64,
) -> Vec<usize> {
let mut candidates: Vec<usize> = dataset
.paragraphs
.iter()
.enumerate()
.filter_map(|(idx, paragraph)| {
if positive_ids.contains(paragraph.id.as_str()) {
None
} else {
Some(idx)
}
})
.collect();
let mut rng = StdRng::seed_from_u64(rng_seed);
candidates.shuffle(&mut rng);
candidates
}
/// Recomputes the manifest's derived counters (case, positive, negative, and
/// total paragraph counts) from its ledgers.
fn refresh_manifest_stats(manifest: &mut SliceManifest) {
    let mut positives = 0;
    let mut negatives = 0;
    for entry in &manifest.paragraphs {
        match entry.kind {
            SliceParagraphKind::Positive { .. } => positives += 1,
            SliceParagraphKind::Negative => negatives += 1,
        }
    }
    manifest.case_count = manifest.cases.len();
    manifest.positive_paragraphs = positives;
    manifest.negative_paragraphs = negatives;
    manifest.total_paragraphs = manifest.paragraphs.len();
}
/// Backfills missing shard paths on paragraph entries (manifests written
/// before shard paths existed lack them); returns whether anything changed.
fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool {
    let mut changed = false;
    for entry in manifest
        .paragraphs
        .iter_mut()
        .filter(|entry| entry.shard_path.is_none())
    {
        entry.shard_path = Some(default_shard_path(&entry.id));
        changed = true;
    }
    changed
}
/// Computes how many negative paragraphs the slice should hold.
///
/// The target is `ceil(positive_count * multiplier)` (negative multipliers are
/// treated as zero), capped so positives plus negatives never exceed the
/// effective corpus size. With no positives the target is zero.
fn desired_negative_target(
    positive_count: usize,
    requested_corpus: usize,
    dataset_paragraphs: usize,
    multiplier: f32,
) -> usize {
    if positive_count == 0 {
        return 0;
    }
    let ratio = multiplier.max(0.0);
    let unclamped = ((positive_count as f32) * ratio).ceil() as usize;
    // Effective corpus: requested size, bounded by the dataset, but never
    // smaller than the positives already selected.
    let corpus_cap = requested_corpus.min(dataset_paragraphs).max(positive_count);
    unclamped.min(corpus_cap.saturating_sub(positive_count))
}
/// Validates a manifest against the loaded dataset and resolves its ids into
/// borrowed paragraph/question references.
///
/// Fails when the manifest version is unsupported, when a referenced id is
/// missing from the dataset, when a question is not attached to the paragraph
/// the manifest claims, or when no cases survive validation.
fn manifest_to_resolved<'a>(
    dataset: &'a ConvertedDataset,
    index: &DatasetIndex,
    manifest: SliceManifest,
    path: PathBuf,
) -> Result<ResolvedSlice<'a>> {
    if manifest.version != SLICE_VERSION {
        return Err(anyhow!(
            "slice version {} does not match expected {}",
            manifest.version,
            SLICE_VERSION
        ));
    }
    let mut paragraphs = Vec::with_capacity(manifest.paragraphs.len());
    for entry in &manifest.paragraphs {
        let paragraph = index.paragraph(dataset, &entry.id)?;
        // For positives, every listed question must truly belong to this paragraph.
        if let SliceParagraphKind::Positive { question_ids } = &entry.kind {
            for question_id in question_ids {
                let (linked_paragraph, _) = index.question(dataset, question_id)?;
                if linked_paragraph.id != entry.id {
                    return Err(anyhow!(
                        "slice question '{}' expected paragraph '{}', found '{}'",
                        question_id,
                        entry.id,
                        linked_paragraph.id
                    ));
                }
            }
        }
        paragraphs.push(paragraph);
    }
    let mut cases = Vec::with_capacity(manifest.cases.len());
    for entry in &manifest.cases {
        let (paragraph, question) = index.question(dataset, &entry.question_id)?;
        // The case's recorded paragraph must match the question's real owner.
        if paragraph.id != entry.paragraph_id {
            return Err(anyhow!(
                "slice case '{}' expected paragraph '{}', found '{}'",
                entry.question_id,
                entry.paragraph_id,
                paragraph.id
            ));
        }
        cases.push(CaseRef {
            paragraph,
            question,
        });
    }
    if cases.is_empty() {
        return Err(anyhow!(
            "slice '{}' contains no cases after validation",
            manifest.slice_id
        ));
    }
    Ok(ResolvedSlice {
        manifest,
        path,
        paragraphs,
        cases,
    })
}
/// Derives the slice id: the hex encoding of the first 16 bytes of the
/// SHA-256 digest of the JSON-serialised key.
fn compute_slice_id(key: &SliceKey<'_>) -> String {
    let payload = serde_json::to_vec(key).expect("SliceKey serialisation should not fail");
    let mut hasher = Sha256::new();
    hasher.update(payload);
    let digest = hasher.finalize();
    let mut id = String::with_capacity(32);
    for byte in &digest[..16] {
        id.push_str(&format!("{byte:02x}"));
    }
    id
}
/// Mixes a dataset id into a user seed by hashing both and taking the first
/// eight digest bytes as a little-endian u64.
fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
    let mut hasher = Sha256::new();
    hasher.update(dataset_id.as_bytes());
    hasher.update(seed.to_le_bytes());
    let digest = hasher.finalize();
    let bytes: [u8; 8] = digest[..8]
        .try_into()
        .expect("SHA-256 digest always has at least 8 bytes");
    u64::from_le_bytes(bytes)
}
/// Reads and deserialises a slice manifest from disk, attaching the file path
/// to any I/O or parse error.
fn read_manifest(path: &Path) -> Result<SliceManifest> {
    let raw = fs::read_to_string(path)
        .with_context(|| format!("reading slice manifest {}", path.display()))?;
    serde_json::from_str(&raw)
        .with_context(|| format!("parsing slice manifest {}", path.display()))
}
/// Serialises the manifest to pretty-printed JSON and writes it to `path`,
/// creating the parent directory first if needed.
fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)
            .with_context(|| format!("creating slice directory {}", dir.display()))?;
    }
    let json = serde_json::to_vec_pretty(manifest).context("serialising slice manifest to JSON")?;
    fs::write(path, &json).with_context(|| format!("writing slice manifest {}", path.display()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasets::{
        ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, DatasetMetadata,
    };
    use tempfile::tempdir;

    /// Builds a three-paragraph dataset, each with one answerable question.
    fn sample_dataset() -> ConvertedDataset {
        let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false, None);
        ConvertedDataset {
            generated_at: Utc::now(),
            metadata,
            source: "test-source".to_string(),
            paragraphs: vec![
                ConvertedParagraph {
                    id: "p1".to_string(),
                    title: "Alpha".to_string(),
                    context: "Alpha context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q1".to_string(),
                        question: "What is alpha?".to_string(),
                        answers: vec!["Alpha".to_string()],
                        is_impossible: false,
                    }],
                },
                ConvertedParagraph {
                    id: "p2".to_string(),
                    title: "Beta".to_string(),
                    context: "Beta context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q2".to_string(),
                        question: "What is beta?".to_string(),
                        answers: vec!["Beta".to_string()],
                        is_impossible: false,
                    }],
                },
                ConvertedParagraph {
                    id: "p3".to_string(),
                    title: "Gamma".to_string(),
                    context: "Gamma context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q3".to_string(),
                        question: "What is gamma?".to_string(),
                        answers: vec!["Gamma".to_string()],
                        is_impossible: false,
                    }],
                },
            ],
        }
    }

    /// A second resolve with the same config must reuse the cached manifest
    /// (same id, same timestamp); forcing conversion must re-stamp it.
    #[test]
    fn resolve_slice_reuses_cached_manifest() -> Result<()> {
        let dataset = sample_dataset();
        let temp = tempdir().context("creating temp directory")?;
        let mut config = SliceConfig {
            cache_dir: temp.path(),
            force_convert: false,
            explicit_slice: None,
            limit: Some(2),
            corpus_limit: Some(3),
            slice_seed: 0x5eed_2025,
            llm_mode: false,
            negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
        };
        let first = resolve_slice(&dataset, &config)?;
        assert!(first.path.exists());
        let initial_generated = first.manifest.generated_at;
        let second = resolve_slice(&dataset, &config)?;
        assert_eq!(first.manifest.slice_id, second.manifest.slice_id);
        assert_eq!(initial_generated, second.manifest.generated_at);
        // force_convert ignores the cache; the slice id stays stable but the
        // generated_at timestamp moves.
        config.force_convert = true;
        let third = resolve_slice(&dataset, &config)?;
        assert_eq!(first.manifest.slice_id, third.manifest.slice_id);
        assert_ne!(third.manifest.generated_at, initial_generated);
        Ok(())
    }

    /// A windowed selection reports offset/length/total and exposes the
    /// distinct positive paragraph ids of its cases.
    #[test]
    fn select_window_yields_expected_cases() -> Result<()> {
        let dataset = sample_dataset();
        let temp = tempdir().context("creating temp directory")?;
        let config = SliceConfig {
            cache_dir: temp.path(),
            force_convert: false,
            explicit_slice: None,
            limit: Some(3),
            corpus_limit: Some(3),
            slice_seed: 0x5eed_2025,
            llm_mode: false,
            negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
        };
        let resolved = resolve_slice(&dataset, &config)?;
        let window = select_window(&resolved, 1, Some(1))?;
        assert_eq!(window.offset, 1);
        assert_eq!(window.length, 1);
        assert_eq!(window.total_cases, resolved.manifest.case_count);
        assert_eq!(window.cases.len(), 1);
        let positive_ids: Vec<&str> = window.positive_ids().collect();
        assert_eq!(positive_ids.len(), 1);
        assert!(resolved
            .manifest
            .paragraphs
            .iter()
            .any(|entry| entry.id == positive_ids[0]));
        Ok(())
    }
}

182
eval/src/snapshot.rs Normal file
View File

@@ -0,0 +1,182 @@
use std::path::PathBuf;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::fs;
use crate::{args::Config, embedding::EmbeddingProvider, slice};
/// Configuration fingerprint of an ingestion snapshot; any change to these
/// fields changes the metadata hash produced by `compute_hash`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    pub dataset_id: String,
    pub slice_id: String,
    pub embedding_backend: String,
    pub embedding_model: Option<String>,
    pub embedding_dimension: usize,
    pub chunk_min_chars: usize,
    pub chunk_max_chars: usize,
    pub rerank_enabled: bool,
}
/// Persisted record of what was last ingested for a dataset/slice pair.
/// NOTE(review): presumably consulted by callers to decide whether
/// re-ingestion is needed — confirm against the ingestion driver.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbSnapshotState {
    pub dataset_id: String,
    pub slice_id: String,
    pub ingestion_fingerprint: String,
    pub snapshot_hash: String,
    pub updated_at: DateTime<Utc>,
    // The remaining fields default when absent, so state files written by
    // older versions still deserialise.
    #[serde(default)]
    pub namespace: Option<String>,
    #[serde(default)]
    pub database: Option<String>,
    #[serde(default)]
    pub slice_case_count: usize,
}
/// Addresses a snapshot's on-disk location and its metadata identity hash.
pub struct Descriptor {
    #[allow(dead_code)]
    metadata: SnapshotMetadata,
    // Snapshot root: <cache_dir>/snapshots/<dataset_id>/<slice_id>.
    dir: PathBuf,
    // Hex SHA-256 of the JSON-serialised metadata.
    metadata_hash: String,
}
impl Descriptor {
    /// Builds a descriptor from the run configuration, the resolved slice, and
    /// the active embedding provider; the metadata hash is computed eagerly.
    pub fn new(
        config: &Config,
        slice: &slice::ResolvedSlice<'_>,
        embedding_provider: &EmbeddingProvider,
    ) -> Self {
        let metadata = SnapshotMetadata {
            dataset_id: slice.manifest.dataset_id.clone(),
            slice_id: slice.manifest.slice_id.clone(),
            embedding_backend: embedding_provider.backend_label().to_string(),
            embedding_model: embedding_provider.model_code(),
            embedding_dimension: embedding_provider.dimension(),
            chunk_min_chars: config.chunk_min_chars,
            chunk_max_chars: config.chunk_max_chars,
            rerank_enabled: config.rerank,
        };
        let dir = config
            .cache_dir
            .join("snapshots")
            .join(&metadata.dataset_id)
            .join(&metadata.slice_id);
        let metadata_hash = compute_hash(&metadata);
        Self {
            metadata,
            dir,
            metadata_hash,
        }
    }
    /// Hex SHA-256 of the serialised snapshot metadata.
    pub fn metadata_hash(&self) -> &str {
        &self.metadata_hash
    }
    /// Loads the persisted database state, returning `Ok(None)` when no state
    /// file exists yet.
    pub async fn load_db_state(&self) -> Result<Option<DbSnapshotState>> {
        let path = self.db_state_path();
        if !path.exists() {
            return Ok(None);
        }
        let bytes = fs::read(&path)
            .await
            .with_context(|| format!("reading namespace state {}", path.display()))?;
        let state = serde_json::from_slice(&bytes)
            .with_context(|| format!("deserialising namespace state {}", path.display()))?;
        Ok(Some(state))
    }
    /// Persists the database state as pretty-printed JSON, creating the state
    /// directory if needed.
    pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> {
        let path = self.db_state_path();
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).await.with_context(|| {
                format!("creating namespace state directory {}", parent.display())
            })?;
        }
        let blob =
            serde_json::to_vec_pretty(state).context("serialising namespace state payload")?;
        fs::write(&path, blob)
            .await
            .with_context(|| format!("writing namespace state {}", path.display()))?;
        Ok(())
    }
    // Directory holding database-related snapshot files.
    fn db_dir(&self) -> PathBuf {
        self.dir.join("db")
    }
    // Location of the persisted DbSnapshotState JSON.
    fn db_state_path(&self) -> PathBuf {
        self.db_dir().join("state.json")
    }
    /// Test-only constructor that bypasses `Config`/slice resolution.
    #[cfg(test)]
    pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self {
        let metadata_hash = compute_hash(&metadata);
        Self {
            metadata,
            dir,
            metadata_hash,
        }
    }
}
/// Hashes the JSON-serialised snapshot metadata into a lowercase hex string.
fn compute_hash(metadata: &SnapshotMetadata) -> String {
    let payload =
        serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed");
    let mut hasher = Sha256::new();
    hasher.update(&payload);
    format!("{:x}", hasher.finalize())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stored state must survive a store/load round trip field-for-field.
    #[tokio::test]
    async fn state_round_trip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let metadata = SnapshotMetadata {
            dataset_id: "dataset".into(),
            slice_id: "slice".into(),
            embedding_backend: "hashed".into(),
            embedding_model: None,
            embedding_dimension: 128,
            chunk_min_chars: 10,
            chunk_max_chars: 100,
            rerank_enabled: true,
        };
        let descriptor = Descriptor::from_parts(
            metadata,
            temp_dir
                .path()
                .join("snapshots")
                .join("dataset")
                .join("slice"),
        );
        let state = DbSnapshotState {
            dataset_id: "dataset".into(),
            slice_id: "slice".into(),
            ingestion_fingerprint: "fingerprint".into(),
            snapshot_hash: descriptor.metadata_hash().to_string(),
            updated_at: Utc::now(),
            namespace: Some("ns".into()),
            database: Some("db".into()),
            slice_case_count: 42,
        };
        descriptor.store_db_state(&state).await.unwrap();
        let loaded = descriptor.load_db_state().await.unwrap().unwrap();
        assert_eq!(loaded.dataset_id, state.dataset_id);
        assert_eq!(loaded.slice_id, state.slice_id);
        assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint);
        assert_eq!(loaded.snapshot_hash, state.snapshot_hash);
        assert_eq!(loaded.namespace, state.namespace);
        assert_eq!(loaded.database, state.database);
        assert_eq!(loaded.slice_case_count, state.slice_case_count);
    }
}

View File

@@ -7,7 +7,7 @@ use axum::{
routing::{get, post},
Router,
};
use chat_handlers::{
pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat,
show_initialized_chat,

View File

@@ -1,7 +1,7 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::search_result_handler;
pub use handlers::{search_result_handler, SearchParams};
use crate::html_state::HtmlState;

View File

@@ -6,7 +6,7 @@ use common::storage::{
db::SurrealDbClient,
types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS},
};
use pipeline::IngestionPipeline;
pub use pipeline::{IngestionConfig, IngestionPipeline, IngestionTuning};
use std::sync::Arc;
use tokio::time::{sleep, Duration};
use tracing::{error, info, warn};

View File

@@ -1,8 +1,14 @@
use std::ops::Range;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{ingestion_task::IngestionTask, text_content::TextContent},
types::{
ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
text_content::TextContent,
},
},
};
use composite_retrieval::RetrievedEntity;
@@ -24,6 +30,14 @@ pub struct PipelineContext<'a> {
pub analysis: Option<LLMEnrichmentResult>,
}
/// Output of the ingestion pipeline's non-persisting stages: everything the
/// persist step would write to storage.
#[derive(Debug)]
pub struct PipelineArtifacts {
    pub text_content: TextContent,
    pub entities: Vec<KnowledgeEntity>,
    pub relationships: Vec<KnowledgeRelationship>,
    pub chunks: Vec<TextChunk>,
}
impl<'a> PipelineContext<'a> {
pub fn new(
task: &'a IngestionTask,
@@ -73,4 +87,30 @@ impl<'a> PipelineContext<'a> {
);
err
}
    /// Converts the prepared content and enrichment analysis into persistable
    /// artifacts (entities, relationships, chunks) without writing anything.
    ///
    /// Consumes the context's text content and analysis via `take_*`, so it
    /// can only be called once per pipeline run.
    pub async fn build_artifacts(&mut self) -> Result<PipelineArtifacts, AppError> {
        let content = self.take_text_content()?;
        let analysis = self.take_analysis()?;
        let (entities, relationships) = self
            .services
            .convert_analysis(
                &content,
                &analysis,
                self.pipeline_config.tuning.entity_embedding_concurrency,
            )
            .await?;
        // Chunk sizing comes straight from the pipeline tuning bounds.
        let chunk_range: Range<usize> = self.pipeline_config.tuning.chunk_min_chars
            ..self.pipeline_config.tuning.chunk_max_chars;
        let chunks = self.services.prepare_chunks(&content, chunk_range).await?;
        Ok(PipelineArtifacts {
            text_content: content,
            entities,
            relationships,
            chunks,
        })
    }
}

View File

@@ -7,6 +7,7 @@ mod stages;
mod state;
pub use config::{IngestionConfig, IngestionTuning};
pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship};
pub use services::{DefaultPipelineServices, PipelineServices};
use std::{
@@ -31,7 +32,7 @@ use composite_retrieval::reranking::RerankerPool;
use tracing::{debug, info, warn};
use self::{
context::PipelineContext,
context::{PipelineArtifacts, PipelineContext},
stages::{enrich, persist, prepare_content, retrieve_related},
state::ready,
};
@@ -224,6 +225,33 @@ impl IngestionPipeline {
Ok(())
}
    /// Runs the ingestion pipeline up to (but excluding) persistence and returns the prepared artifacts.
    ///
    /// Drives the same state machine as a full run (prepare → retrieve →
    /// enrich), then materialises the artifacts instead of persisting them.
    /// Any stage failure is routed through `ctx.abort` before propagating.
    pub async fn produce_artifacts(
        &self,
        task: &IngestionTask,
    ) -> Result<PipelineArtifacts, AppError> {
        let payload = task.content.clone();
        let mut ctx = PipelineContext::new(
            task,
            self.db.as_ref(),
            &self.pipeline_config,
            self.services.as_ref(),
        );
        let machine = ready();
        let machine = prepare_content(machine, &mut ctx, payload)
            .await
            .map_err(|err| ctx.abort(err))?;
        let machine = retrieve_related(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;
        let _machine = enrich(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;
        ctx.build_artifacts().await.map_err(|err| ctx.abort(err))
    }
}
#[cfg(test)]

View File

@@ -29,6 +29,8 @@ use crate::utils::llm_instructions::{
get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE,
};
const EMBEDDING_QUERY_CHAR_LIMIT: usize = 12_000;
#[async_trait]
pub trait PipelineServices: Send + Sync {
async fn prepare_text_content(
@@ -162,9 +164,13 @@ impl PipelineServices for DefaultPipelineServices {
&self,
content: &TextContent,
) -> Result<Vec<RetrievedEntity>, AppError> {
let truncated_body = truncate_for_embedding(&content.text, EMBEDDING_QUERY_CHAR_LIMIT);
let input_text = format!(
"content: {}, category: {}, user_context: {:?}",
content.text, content.category, content.context
"content: {}\n[truncated={}], category: {}, user_context: {:?}",
truncated_body,
truncated_body.len() < content.text.len(),
content.category,
content.context
);
let rerank_lease = match &self.reranker_pool {
@@ -239,3 +245,19 @@ impl PipelineServices for DefaultPipelineServices {
Ok(chunks)
}
}
/// Returns `text` limited to at most `max_chars` characters.
///
/// Fixes two defects in the original: `truncated.push_str("")` was a no-op
/// (dead leftover of an ellipsis marker — truncation is already signalled by
/// the caller's `[truncated={}]` field, so no marker belongs in the text),
/// and `String::with_capacity(max_chars + 3)` sized in bytes while counting
/// chars. Also adds a byte-length fast path so short inputs skip the O(n)
/// `chars().count()`.
fn truncate_for_embedding(text: &str, max_chars: usize) -> String {
    // Fast path: the char count can never exceed the byte length.
    if text.len() <= max_chars || text.chars().count() <= max_chars {
        return text.to_string();
    }
    text.chars().take(max_chars).collect()
}

View File

@@ -7,7 +7,6 @@ use common::{
types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
text_content::TextContent,
},
},
};
@@ -16,8 +15,7 @@ use tokio::time::{sleep, Duration};
use tracing::{debug, instrument, warn};
use super::{
context::PipelineContext,
services::PipelineServices,
context::{PipelineArtifacts, PipelineContext},
state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
};
@@ -134,37 +132,26 @@ pub async fn persist(
machine: IngestionMachine<(), Enriched>,
ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Persisted>, AppError> {
let content = ctx.take_text_content()?;
let analysis = ctx.take_analysis()?;
let (entities, relationships) = ctx
.services
.convert_analysis(
&content,
&analysis,
ctx.pipeline_config.tuning.entity_embedding_concurrency,
)
.await?;
let PipelineArtifacts {
text_content,
entities,
relationships,
chunks,
} = ctx.build_artifacts().await?;
let entity_count = entities.len();
let relationship_count = relationships.len();
let chunk_range =
ctx.pipeline_config.tuning.chunk_min_chars..ctx.pipeline_config.tuning.chunk_max_chars;
let ((), chunk_count) = tokio::try_join!(
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
store_vector_chunks(
ctx.db,
ctx.services,
ctx.task_id.as_str(),
&content,
chunk_range,
&chunks,
&ctx.pipeline_config.tuning
)
)?;
ctx.db.store_item(content).await?;
ctx.db.store_item(text_content).await?;
ctx.db.rebuild_indexes().await?;
debug!(
@@ -252,17 +239,14 @@ async fn store_graph_entities(
async fn store_vector_chunks(
db: &SurrealDbClient,
services: &dyn PipelineServices,
task_id: &str,
content: &TextContent,
chunk_range: std::ops::Range<usize>,
chunks: &[TextChunk],
tuning: &super::config::IngestionTuning,
) -> Result<usize, AppError> {
let prepared_chunks = services.prepare_chunks(content, chunk_range).await?;
let chunk_count = prepared_chunks.len();
let chunk_count = chunks.len();
let batch_size = tuning.chunk_insert_concurrency.max(1);
for chunk in &prepared_chunks {
for chunk in chunks {
debug!(
task_id = %task_id,
chunk_id = %chunk.id,
@@ -271,7 +255,7 @@ async fn store_vector_chunks(
);
}
for batch in prepared_chunks.chunks(batch_size) {
for batch in chunks.chunks(batch_size) {
store_chunk_batch(db, batch, tuning).await?;
}