mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-14 06:16:19 +01:00
benchmarks: v1
Benchmarking ingestion, retrieval precision and performance
This commit is contained in:
2
.cargo/config.toml
Normal file
2
.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[alias]
|
||||
eval = "run -p eval --"
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -10,6 +10,9 @@ result
|
||||
data
|
||||
database
|
||||
|
||||
eval/cache/
|
||||
eval/reports/
|
||||
|
||||
# Devenv
|
||||
.devenv*
|
||||
devenv.local.nix
|
||||
@@ -21,3 +24,4 @@ devenv.local.nix
|
||||
.pre-commit-config.yaml
|
||||
# html-router/assets/style.css
|
||||
html-router/node_modules
|
||||
.fastembed_cache/
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
- Added a shared `benchmarks` crate with deterministic fixtures and Criterion suites for ingestion, retrieval, and HTML handlers, plus documented baseline results for local performance checks.
|
||||
|
||||
## Version 0.2.6 (2025-10-29)
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
|
||||
163
Cargo.lock
generated
163
Cargo.lock
generated
@@ -184,6 +184,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anes"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.18"
|
||||
@@ -1090,6 +1096,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cast"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "castaway"
|
||||
version = "0.2.3"
|
||||
@@ -1626,6 +1638,42 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"is-terminal",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
@@ -2140,6 +2188,37 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "eval"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"common",
|
||||
"composite-retrieval",
|
||||
"criterion",
|
||||
"fastembed",
|
||||
"futures",
|
||||
"ingestion-pipeline",
|
||||
"object_store 0.11.2",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"sha2",
|
||||
"state-machines",
|
||||
"surrealdb",
|
||||
"tempfile",
|
||||
"text-splitter",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
version = "5.4.0"
|
||||
@@ -2735,6 +2814,12 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||
|
||||
[[package]]
|
||||
name = "hex"
|
||||
version = "0.4.3"
|
||||
@@ -3248,11 +3333,13 @@ dependencies = [
|
||||
name = "ingestion-pipeline"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum_typed_multipart",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"common",
|
||||
"composite-retrieval",
|
||||
@@ -3330,6 +3417,17 @@ version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
|
||||
dependencies = [
|
||||
"hermit-abi 0.5.2",
|
||||
"libc",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
@@ -4240,7 +4338,7 @@ version = "1.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"hermit-abi 0.3.9",
|
||||
"libc",
|
||||
]
|
||||
|
||||
@@ -4330,6 +4428,12 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
version = "11.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
@@ -4705,6 +4809,34 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"plotters-backend",
|
||||
"plotters-svg",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "plotters-backend"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
|
||||
|
||||
[[package]]
|
||||
name = "plotters-svg"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
|
||||
dependencies = [
|
||||
"plotters-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.18.0"
|
||||
@@ -5920,6 +6052,19 @@ dependencies = [
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_yaml"
|
||||
version = "0.9.34+deprecated"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
|
||||
dependencies = [
|
||||
"indexmap 2.9.0",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
"unsafe-libyaml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "servo_arc"
|
||||
version = "0.4.0"
|
||||
@@ -6735,6 +6880,16 @@ dependencies = [
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.9.0"
|
||||
@@ -7263,6 +7418,12 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unsafe-libyaml"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
|
||||
@@ -6,7 +6,8 @@ members = [
|
||||
"html-router",
|
||||
"ingestion-pipeline",
|
||||
"composite-retrieval",
|
||||
"json-stream-parser"
|
||||
"json-stream-parser",
|
||||
"eval"
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use include_dir::{include_dir, Dir};
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
opt::auth::Root,
|
||||
opt::auth::{Namespace, Root},
|
||||
Error, Notification, Surreal,
|
||||
};
|
||||
use surrealdb_migrations::MigrationRunner;
|
||||
@@ -48,6 +48,24 @@ impl SurrealDbClient {
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn new_with_namespace_user(
|
||||
address: &str,
|
||||
namespace: &str,
|
||||
username: &str,
|
||||
password: &str,
|
||||
database: &str,
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
db.signin(Namespace {
|
||||
namespace,
|
||||
username,
|
||||
password,
|
||||
})
|
||||
.await?;
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
|
||||
@@ -12,6 +12,7 @@ pub struct RetrievalTuning {
|
||||
pub token_budget_estimate: usize,
|
||||
pub avg_chars_per_token: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub lexical_match_weight: f32,
|
||||
pub graph_traversal_seed_limit: usize,
|
||||
pub graph_neighbor_limit: usize,
|
||||
pub graph_score_decay: f32,
|
||||
@@ -31,9 +32,10 @@ impl Default for RetrievalTuning {
|
||||
chunk_fts_take: 20,
|
||||
score_threshold: 0.35,
|
||||
fallback_min_results: 10,
|
||||
token_budget_estimate: 2800,
|
||||
token_budget_estimate: 10000,
|
||||
avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
lexical_match_weight: 0.15,
|
||||
graph_traversal_seed_limit: 5,
|
||||
graph_neighbor_limit: 6,
|
||||
graph_score_decay: 0.75,
|
||||
|
||||
51
composite-retrieval/src/pipeline/diagnostics.rs
Normal file
51
composite-retrieval/src/pipeline/diagnostics.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct PipelineDiagnostics {
|
||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
pub assemble: Option<AssembleStats>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct CollectCandidatesStats {
|
||||
pub vector_entity_candidates: usize,
|
||||
pub vector_chunk_candidates: usize,
|
||||
pub fts_entity_candidates: usize,
|
||||
pub fts_chunk_candidates: usize,
|
||||
pub vector_chunk_scores: Vec<f32>,
|
||||
pub fts_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct ChunkEnrichmentStats {
|
||||
pub filtered_entity_count: usize,
|
||||
pub fallback_min_results: usize,
|
||||
pub chunk_sources_considered: usize,
|
||||
pub chunk_candidates_before_enrichment: usize,
|
||||
pub chunk_candidates_after_enrichment: usize,
|
||||
pub top_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct AssembleStats {
|
||||
pub token_budget_start: usize,
|
||||
pub token_budget_spent: usize,
|
||||
pub token_budget_remaining: usize,
|
||||
pub budget_exhausted: bool,
|
||||
pub chunks_selected: usize,
|
||||
pub chunks_skipped_due_budget: usize,
|
||||
pub entity_count: usize,
|
||||
pub entity_traces: Vec<EntityAssemblyTrace>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct EntityAssemblyTrace {
|
||||
pub entity_id: String,
|
||||
pub source_id: String,
|
||||
pub inspected_candidates: usize,
|
||||
pub selected_chunk_ids: Vec<String>,
|
||||
pub selected_chunk_scores: Vec<f32>,
|
||||
pub skipped_due_budget: usize,
|
||||
}
|
||||
@@ -1,14 +1,57 @@
|
||||
mod config;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineRunOutput {
|
||||
pub results: Vec<RetrievedEntity>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub collect_candidates_ms: u128,
|
||||
pub graph_expansion_ms: u128,
|
||||
pub chunk_attach_ms: u128,
|
||||
pub rerank_ms: u128,
|
||||
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
fn record_collect_candidates(&mut self, duration: std::time::Duration) {
|
||||
self.collect_candidates_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_graph_expansion(&mut self, duration: std::time::Duration) {
|
||||
self.graph_expansion_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_chunk_attach(&mut self, duration: std::time::Duration) {
|
||||
self.chunk_attach_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_rerank(&mut self, duration: std::time::Duration) {
|
||||
self.rerank_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_assemble(&mut self, duration: std::time::Duration) {
|
||||
self.assemble_ms += duration.as_millis() as u128;
|
||||
}
|
||||
}
|
||||
|
||||
/// Drives the retrieval pipeline from embedding through final assembly.
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
@@ -18,7 +61,6 @@ pub async fn run_pipeline(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
@@ -30,7 +72,7 @@ pub async fn run_pipeline(
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
let mut ctx = stages::PipelineContext::new(
|
||||
let ctx = stages::PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
@@ -38,17 +80,11 @@ pub async fn run_pipeline(
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
||||
|
||||
Ok(results)
|
||||
Ok(outcome.results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
@@ -58,8 +94,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let mut ctx = stages::PipelineContext::with_embedding(
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
@@ -68,14 +103,54 @@ pub async fn run_pipeline_with_embedding(
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
||||
|
||||
Ok(results)
|
||||
Ok(outcome.results)
|
||||
}
|
||||
|
||||
/// Runs the pipeline with a precomputed embedding and returns stage metrics.
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
|
||||
run_pipeline_internal(ctx, false).await
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
|
||||
run_pipeline_internal(ctx, true).await
|
||||
}
|
||||
|
||||
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
||||
@@ -101,6 +176,37 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
async fn run_pipeline_internal(
|
||||
mut ctx: stages::PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
let results = drive_pipeline(&mut ctx).await?;
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings: ctx.take_stage_timings(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drive_pipeline(
|
||||
ctx: &mut stages::PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let machine = stages::embed(machine, ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, ctx).await?;
|
||||
let machine = stages::expand_graph(machine, ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, ctx).await?;
|
||||
let machine = stages::rerank(machine, ctx).await?;
|
||||
let results = stages::assemble(machine, ctx)?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
|
||||
@@ -10,7 +10,11 @@ use common::{
|
||||
use fastembed::RerankResult;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use state_machines::core::GuardError;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::{HashMap, HashSet},
|
||||
time::Instant,
|
||||
};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
@@ -27,10 +31,15 @@ use crate::{
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
},
|
||||
state::{
|
||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||
Reranked,
|
||||
},
|
||||
PipelineStageTimings,
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
@@ -45,6 +54,8 @@ pub struct PipelineContext<'a> {
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
@@ -68,10 +79,11 @@ impl<'a> PipelineContext<'a> {
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
reranker,
|
||||
diagnostics: None,
|
||||
stage_timings: PipelineStageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
@@ -100,6 +112,62 @@ impl<'a> PipelineContext<'a> {
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(PipelineDiagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diagnostics_enabled(&self) -> bool {
|
||||
self.diagnostics.is_some()
|
||||
}
|
||||
|
||||
pub fn record_collect_candidates(&mut self, stats: CollectCandidatesStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.collect_candidates = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_chunk_enrichment(&mut self, stats: ChunkEnrichmentStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.enrich_chunks_from_entities = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_assemble(&mut self, stats: AssembleStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.assemble = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_collect_candidates(duration);
|
||||
}
|
||||
|
||||
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_graph_expansion(duration);
|
||||
}
|
||||
|
||||
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_chunk_attach(duration);
|
||||
}
|
||||
|
||||
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_rerank(duration);
|
||||
}
|
||||
|
||||
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_assemble(duration);
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -127,6 +195,7 @@ pub async fn collect_candidates(
|
||||
machine: HybridRetrievalMachine<(), Embedded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
@@ -172,6 +241,19 @@ pub async fn collect_candidates(
|
||||
"Hybrid retrieval initial candidate counts"
|
||||
);
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_collect_candidates(CollectCandidatesStats {
|
||||
vector_entity_candidates: vector_entities.len(),
|
||||
vector_chunk_candidates: vector_chunks.len(),
|
||||
fts_entity_candidates: fts_entities.len(),
|
||||
fts_chunk_candidates: fts_chunks.len(),
|
||||
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
||||
chunk.scores.vector.unwrap_or(0.0)
|
||||
}),
|
||||
fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
|
||||
});
|
||||
}
|
||||
|
||||
normalize_fts_scores(&mut fts_entities);
|
||||
normalize_fts_scores(&mut fts_chunks);
|
||||
|
||||
@@ -183,9 +265,11 @@ pub async fn collect_candidates(
|
||||
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||
|
||||
machine
|
||||
let next = machine
|
||||
.collect_candidates()
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
|
||||
ctx.record_collect_candidates_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -193,82 +277,84 @@ pub async fn expand_graph(
|
||||
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Expanding candidates using graph relationships");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
let next = {
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
} else {
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
if graph_seeds.is_empty() {
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
} else {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
let limit = tuning.graph_neighbor_limit;
|
||||
futures.push(async move {
|
||||
let neighbors =
|
||||
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
if graph_seeds.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
futures.push(async move {
|
||||
let neighbors = find_entities_by_relationship_by_id(
|
||||
db,
|
||||
&seed.id,
|
||||
&user,
|
||||
tuning.graph_neighbor_limit,
|
||||
)
|
||||
.await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector =
|
||||
clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}
|
||||
}
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}?;
|
||||
ctx.record_graph_expansion_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -276,11 +362,14 @@ pub async fn attach_chunks(
|
||||
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Attaching chunks to surviving entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
|
||||
let chunk_candidates_before = ctx.chunk_candidates.len();
|
||||
let chunk_sources_considered = chunk_by_source.len();
|
||||
|
||||
backfill_entities_from_chunks(
|
||||
&mut ctx.entity_candidates,
|
||||
@@ -312,6 +401,8 @@ pub async fn attach_chunks(
|
||||
|
||||
ctx.filtered_entities = filtered_entities;
|
||||
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> =
|
||||
ctx.chunk_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
@@ -327,17 +418,31 @@ pub async fn attach_chunks(
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
&query_embedding,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_values);
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_chunk_enrichment(ChunkEnrichmentStats {
|
||||
filtered_entity_count: ctx.filtered_entities.len(),
|
||||
fallback_min_results: tuning.fallback_min_results,
|
||||
chunk_sources_considered,
|
||||
chunk_candidates_before_enrichment: chunk_candidates_before,
|
||||
chunk_candidates_after_enrichment: chunk_values.len(),
|
||||
top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused),
|
||||
});
|
||||
}
|
||||
|
||||
ctx.chunk_values = chunk_values;
|
||||
|
||||
machine
|
||||
let next = machine
|
||||
.attach_chunks()
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
|
||||
ctx.record_chunk_attach_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -345,6 +450,7 @@ pub async fn rerank(
|
||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
let mut applied = false;
|
||||
|
||||
if let Some(reranker) = ctx.reranker.as_ref() {
|
||||
@@ -384,9 +490,11 @@ pub async fn rerank(
|
||||
debug!("Applied reranking adjustments to candidate ordering");
|
||||
}
|
||||
|
||||
machine
|
||||
let next = machine
|
||||
.rerank()
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
|
||||
ctx.record_rerank_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -394,8 +502,11 @@ pub fn assemble(
|
||||
machine: HybridRetrievalMachine<(), Reranked>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
let question_terms = extract_keywords(&ctx.input_text);
|
||||
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in ctx.chunk_values.drain(..) {
|
||||
@@ -406,39 +517,68 @@ pub fn assemble(
|
||||
}
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
sort_by_fused_desc(chunk_list);
|
||||
chunk_list.sort_by(|a, b| {
|
||||
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
|
||||
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
|
||||
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
|
||||
});
|
||||
}
|
||||
|
||||
let mut token_budget_remaining = tuning.token_budget_estimate;
|
||||
let mut results = Vec::new();
|
||||
let diagnostics_enabled = ctx.diagnostics_enabled();
|
||||
let mut per_entity_traces = Vec::new();
|
||||
let mut chunks_skipped_due_budget = 0usize;
|
||||
let mut chunks_selected = 0usize;
|
||||
let mut tokens_spent = 0usize;
|
||||
|
||||
for entity in &ctx.filtered_entities {
|
||||
let mut selected_chunks = Vec::new();
|
||||
let mut entity_trace = if diagnostics_enabled {
|
||||
Some(EntityAssemblyTrace {
|
||||
entity_id: entity.item.id.clone(),
|
||||
source_id: entity.item.source_id.clone(),
|
||||
inspected_candidates: 0,
|
||||
selected_chunk_ids: Vec::new(),
|
||||
selected_chunk_scores: Vec::new(),
|
||||
skipped_due_budget: 0,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||
rank_chunks_by_combined_score(candidates, &question_terms, tuning.lexical_match_weight);
|
||||
let mut per_entity_count = 0;
|
||||
candidates.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.inspected_candidates += 1;
|
||||
}
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
}
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
chunks_skipped_due_budget += 1;
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.skipped_due_budget += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
tokens_spent += estimated_tokens;
|
||||
per_entity_count += 1;
|
||||
chunks_selected += 1;
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
score: candidate.fused,
|
||||
});
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.selected_chunk_ids.push(candidate.item.id.clone());
|
||||
trace.selected_chunk_scores.push(candidate.fused);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,17 +588,48 @@ pub fn assemble(
|
||||
chunks: selected_chunks,
|
||||
});
|
||||
|
||||
if let Some(trace) = entity_trace {
|
||||
per_entity_traces.push(trace);
|
||||
}
|
||||
|
||||
if token_budget_remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if diagnostics_enabled {
|
||||
ctx.record_assemble(AssembleStats {
|
||||
token_budget_start: tuning.token_budget_estimate,
|
||||
token_budget_spent: tokens_spent,
|
||||
token_budget_remaining,
|
||||
budget_exhausted: token_budget_remaining == 0,
|
||||
chunks_selected,
|
||||
chunks_skipped_due_budget,
|
||||
entity_count: ctx.filtered_entities.len(),
|
||||
entity_traces: per_entity_traces,
|
||||
});
|
||||
}
|
||||
|
||||
machine
|
||||
.assemble()
|
||||
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
||||
ctx.record_assemble_timing(stage_start.elapsed());
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items
|
||||
.iter()
|
||||
.take(SCORE_SAMPLE_LIMIT)
|
||||
.map(|item| extractor(item))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
||||
AppError::InternalError(format!(
|
||||
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
||||
@@ -582,6 +753,7 @@ async fn enrich_chunks_from_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
query_embedding: &[f32],
|
||||
) -> Result<(), AppError> {
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
@@ -615,7 +787,16 @@ async fn enrich_chunks_from_entities(
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
|
||||
|
||||
entry.scores.vector = Some(
|
||||
entry
|
||||
.scores
|
||||
.vector
|
||||
.unwrap_or(0.0)
|
||||
.max(entity_score * 0.8)
|
||||
.max(similarity),
|
||||
);
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
@@ -624,6 +805,24 @@ async fn enrich_chunks_from_entities(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cosine similarity between two equal-length embedding vectors.
///
/// Returns 0.0 when either slice is empty, when the lengths differ, or when
/// either vector has zero magnitude, so callers never divide by zero.
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    if query.len() != embedding.len() || query.is_empty() {
        return 0.0;
    }
    let (mut dot, mut q_sq, mut e_sq) = (0.0f32, 0.0f32, 0.0f32);
    for (lhs, rhs) in query.iter().zip(embedding.iter()) {
        dot += lhs * rhs;
        q_sq += lhs * lhs;
        e_sq += rhs * rhs;
    }
    if q_sq == 0.0 || e_sq == 0.0 {
        0.0
    } else {
        dot / (q_sq.sqrt() * e_sq.sqrt())
    }
}
|
||||
|
||||
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
|
||||
if ctx.filtered_entities.is_empty() {
|
||||
return Vec::new();
|
||||
@@ -736,6 +935,48 @@ fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
}
|
||||
|
||||
fn rank_chunks_by_combined_score(
|
||||
candidates: &mut [Scored<TextChunk>],
|
||||
question_terms: &[String],
|
||||
lexical_weight: f32,
|
||||
) {
|
||||
if lexical_weight > 0.0 && !question_terms.is_empty() {
|
||||
for candidate in candidates.iter_mut() {
|
||||
let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk);
|
||||
let combined = clamp_unit(candidate.fused + lexical_weight * lexical);
|
||||
candidate.update_fused(combined);
|
||||
}
|
||||
}
|
||||
candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||
}
|
||||
|
||||
/// Extract lower-cased keywords (alphanumeric runs of three or more
/// characters) from `text`, returned sorted and deduplicated so the result
/// is deterministic for matching.
fn extract_keywords(text: &str) -> Vec<String> {
    let mut terms: Vec<String> = text
        .split(|c: char| !c.is_alphanumeric())
        .map(|raw| raw.trim().to_ascii_lowercase())
        .filter(|term| term.len() >= 3)
        .collect();
    terms.sort();
    terms.dedup();
    terms
}
|
||||
|
||||
/// Fraction of `terms` that occur as substrings of `haystack`, compared
/// against the ASCII-lower-cased haystack (terms are assumed to already be
/// lower-cased, as produced by `extract_keywords`). An empty term list
/// scores 0.0.
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
    if terms.is_empty() {
        return 0.0;
    }
    let haystack_lower = haystack.to_ascii_lowercase();
    let hits = terms
        .iter()
        .filter(|term| haystack_lower.contains(term.as_str()))
        .count();
    hits as f32 / terms.len() as f32
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct GraphSeed {
|
||||
id: String,
|
||||
|
||||
@@ -68,6 +68,7 @@ where
|
||||
{
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, vector::distance::knn() AS distance \
|
||||
FROM {table} \
|
||||
|
||||
33
eval/Cargo.toml
Normal file
33
eval/Cargo.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[package]
|
||||
name = "eval"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
ingestion-pipeline = { path = "../ingestion-pipeline" }
|
||||
futures = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
text-splitter = { workspace = true }
|
||||
rand = "0.8"
|
||||
sha2 = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
once_cell = "1.19"
|
||||
serde_yaml = "0.9"
|
||||
criterion = "0.5"
|
||||
state-machines = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
33
eval/manifest.yaml
Normal file
33
eval/manifest.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
default_dataset: squad-v2
|
||||
datasets:
|
||||
- id: squad-v2
|
||||
label: "SQuAD v2.0"
|
||||
category: "SQuAD v2.0"
|
||||
entity_suffix: "SQuAD"
|
||||
source_prefix: "squad"
|
||||
raw: "data/raw/squad/dev-v2.0.json"
|
||||
converted: "data/converted/squad-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: squad-dev-200
|
||||
label: "SQuAD dev (200)"
|
||||
description: "Deterministic 200-case slice for local eval"
|
||||
limit: 200
|
||||
corpus_limit: 2000
|
||||
seed: 0x5eed2025
|
||||
- id: natural-questions-dev
|
||||
label: "Natural Questions (dev)"
|
||||
category: "Natural Questions"
|
||||
entity_suffix: "Natural Questions"
|
||||
source_prefix: "nq"
|
||||
raw: "data/raw/nq/dev-all.jsonl"
|
||||
converted: "data/converted/nq-dev-minne.json"
|
||||
include_unanswerable: true
|
||||
slices:
|
||||
- id: nq-dev-200
|
||||
label: "NQ dev (200)"
|
||||
description: "200-case slice of the dev set"
|
||||
limit: 200
|
||||
corpus_limit: 2000
|
||||
include_unanswerable: false
|
||||
seed: 0x5eed2025
|
||||
638
eval/src/args.rs
Normal file
638
eval/src/args.rs
Normal file
@@ -0,0 +1,638 @@
|
||||
use std::{
|
||||
env,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
|
||||
use crate::datasets::DatasetKind;
|
||||
|
||||
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
|
||||
|
||||
/// Embedding implementation selected via the `--embedding` CLI flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingBackend {
    /// Selected with `--embedding hashed`; incompatible with `--embedding-model`.
    Hashed,
    /// Selected with `--embedding fastembed` (also accepts the aliases
    /// `fast-embed` / `fast`); this is the default backend.
    FastEmbed,
}
|
||||
|
||||
impl Default for EmbeddingBackend {
|
||||
fn default() -> Self {
|
||||
Self::FastEmbed
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for EmbeddingBackend {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(anyhow!(
|
||||
"unknown embedding backend '{other}'. Expected 'hashed' or 'fastembed'."
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fully-resolved runtime configuration for the eval CLI.
///
/// Defaults live in the `Default` impl. Most fields map 1:1 to a CLI flag
/// handled in `parse()` (see `print_help` for flag descriptions); the
/// perf-log and `db_*` fields can additionally be filled in from `EVAL_*`
/// environment variables after flag parsing.
#[derive(Debug, Clone)]
pub struct Config {
    // Dataset conversion controls (--convert-only, --force/--refresh, --dataset).
    pub convert_only: bool,
    pub force_convert: bool,
    pub dataset: DatasetKind,
    /// Enables LLM-assisted evaluation features (--llm-mode).
    pub llm_mode: bool,
    /// Cap on corpus size; when absent, `parse()` derives ~10x `limit`, capped at 1000.
    pub corpus_limit: Option<usize>,
    // Input/output paths (--raw, --converted, --report-dir).
    pub raw_dataset_path: PathBuf,
    pub converted_dataset_path: PathBuf,
    pub report_dir: PathBuf,
    /// Precision@k cutoff (--k, must be > 0).
    pub k: usize,
    /// Max questions evaluated; `None` means all (--limit 0).
    pub limit: Option<usize>,
    pub summary_sample: usize,
    pub full_context: bool,
    // Chunking knobs (--chunk-min/--chunk-max and the optional overrides below).
    pub chunk_min_chars: usize,
    pub chunk_max_chars: usize,
    pub chunk_vector_take: Option<usize>,
    pub chunk_fts_take: Option<usize>,
    pub chunk_token_budget: Option<usize>,
    pub chunk_avg_chars_per_token: Option<usize>,
    pub max_chunks_per_entity: Option<usize>,
    // Reranking (--no-rerank, --rerank-pool, --rerank-keep).
    pub rerank: bool,
    pub rerank_pool_size: usize,
    pub rerank_keep_top: usize,
    /// Parallelism for evaluation work (--concurrency, min 1).
    pub concurrency: usize,
    // Embeddings (--embedding, --embedding-model; model invalid with 'hashed').
    pub embedding_backend: EmbeddingBackend,
    pub embedding_model: Option<String>,
    // Cache locations (--cache-dir, --ingestion-cache-dir).
    pub cache_dir: PathBuf,
    pub ingestion_cache_dir: PathBuf,
    pub refresh_embeddings_only: bool,
    pub detailed_report: bool,
    // Slice selection and maintenance (--slice and related flags).
    pub slice: Option<String>,
    pub reseed_slice: bool,
    pub slice_seed: u64,
    pub slice_grow: Option<usize>,
    pub slice_offset: usize,
    pub slice_reset_ingestion: bool,
    /// Must be positive and finite (--negative-multiplier).
    pub negative_multiplier: f32,
    /// Free-form run annotation stored in reports (--label).
    pub label: Option<String>,
    pub chunk_diagnostics_path: Option<PathBuf>,
    // Inspection helpers (--inspect-question, --inspect-manifest).
    pub inspect_question: Option<String>,
    pub inspect_manifest: Option<PathBuf>,
    /// Overrides the query model for this run (--query-model).
    pub query_model: Option<String>,
    // Performance logging (--perf-log-json, --perf-log-dir, --perf-log;
    // EVAL_PERF_LOG_DIR is the env fallback for the directory).
    pub perf_log_json: Option<PathBuf>,
    pub perf_log_dir: Option<PathBuf>,
    pub perf_log_console: bool,
    // SurrealDB connection (--db-* flags; EVAL_DB_* env values applied after).
    pub db_endpoint: String,
    pub db_username: String,
    pub db_password: String,
    pub db_namespace: Option<String>,
    pub db_database: Option<String>,
    // Inspection dump target (--inspect-db-state).
    pub inspect_db_state: Option<PathBuf>,
}
|
||||
|
||||
impl Default for Config {
    /// Baseline configuration mirroring the documented CLI defaults
    /// (k = 5, limit = 200, FastEmbed embeddings, reranking enabled,
    /// SurrealDB at ws://127.0.0.1:8000 with root credentials).
    fn default() -> Self {
        // Dataset-dependent paths are derived from the default dataset.
        let dataset = DatasetKind::default();
        Self {
            convert_only: false,
            force_convert: false,
            dataset,
            llm_mode: false,
            corpus_limit: None,
            raw_dataset_path: dataset.default_raw_path(),
            converted_dataset_path: dataset.default_converted_path(),
            report_dir: PathBuf::from("eval/reports"),
            k: 5,
            limit: Some(200),
            summary_sample: 5,
            full_context: false,
            chunk_min_chars: 500,
            chunk_max_chars: 2_000,
            chunk_vector_take: None,
            chunk_fts_take: None,
            chunk_token_budget: None,
            chunk_avg_chars_per_token: None,
            max_chunks_per_entity: None,
            rerank: true,
            rerank_pool_size: 16,
            rerank_keep_top: 10,
            concurrency: 4,
            embedding_backend: EmbeddingBackend::FastEmbed,
            embedding_model: None,
            cache_dir: PathBuf::from("eval/cache"),
            ingestion_cache_dir: PathBuf::from("eval/cache/ingested"),
            refresh_embeddings_only: false,
            detailed_report: false,
            slice: None,
            reseed_slice: false,
            slice_seed: DEFAULT_SLICE_SEED,
            slice_grow: None,
            slice_offset: 0,
            slice_reset_ingestion: false,
            negative_multiplier: crate::slices::DEFAULT_NEGATIVE_MULTIPLIER,
            label: None,
            chunk_diagnostics_path: None,
            inspect_question: None,
            inspect_manifest: None,
            query_model: None,
            inspect_db_state: None,
            perf_log_json: None,
            perf_log_dir: None,
            perf_log_console: false,
            db_endpoint: "ws://127.0.0.1:8000".to_string(),
            db_username: "root_user".to_string(),
            db_password: "root_password".to_string(),
            db_namespace: None,
            db_database: None,
        }
    }
}
||||
|
||||
impl Config {
    /// Optional cap on context tokens for evaluation output.
    ///
    /// Currently always `None` (no cap). NOTE(review): looks like a
    /// placeholder kept as a method so call sites don't change when a CLI
    /// flag is added later — confirm the intended behavior.
    pub fn context_token_limit(&self) -> Option<usize> {
        None
    }
}
|
||||
|
||||
/// Result of CLI parsing: the resolved configuration plus a help flag.
#[derive(Debug)]
pub struct ParsedArgs {
    /// Fully-resolved run configuration.
    pub config: Config,
    /// True when `-h`/`--help` was seen; the caller should print help and exit.
    pub show_help: bool,
}
|
||||
|
||||
/// Parse the process CLI arguments into a [`ParsedArgs`].
///
/// Flags are consumed one token at a time; value-carrying flags read their
/// value with `take_value`. After the flag loop, cross-flag invariants are
/// validated, a corpus limit is derived from `--limit` when none was given,
/// and `EVAL_*` environment variables are applied.
///
/// # Errors
/// Returns an error for unknown flags, missing or unparseable flag values,
/// or invalid flag combinations (e.g. `--embedding-model` with `hashed`).
pub fn parse() -> Result<ParsedArgs> {
    let mut config = Config::default();
    let mut show_help = false;
    // Remember explicit path overrides so a later --dataset flag does not
    // reset them to the dataset defaults.
    let mut raw_overridden = false;
    let mut converted_overridden = false;

    let mut args = env::args().skip(1).peekable();
    while let Some(arg) = args.next() {
        match arg.as_str() {
            "-h" | "--help" => {
                show_help = true;
                break;
            }
            "--convert-only" => config.convert_only = true,
            "--force" | "--refresh" => config.force_convert = true,
            "--llm-mode" => {
                config.llm_mode = true;
            }
            "--dataset" => {
                let value = take_value("--dataset", &mut args)?;
                let parsed = value.parse::<DatasetKind>()?;
                config.dataset = parsed;
                // Only apply dataset-default paths when the user has not
                // already passed --raw / --converted.
                if !raw_overridden {
                    config.raw_dataset_path = parsed.default_raw_path();
                }
                if !converted_overridden {
                    config.converted_dataset_path = parsed.default_converted_path();
                }
            }
            "--slice" => {
                let value = take_value("--slice", &mut args)?;
                config.slice = Some(value);
            }
            "--label" => {
                let value = take_value("--label", &mut args)?;
                config.label = Some(value);
            }
            "--query-model" => {
                let value = take_value("--query-model", &mut args)?;
                if value.trim().is_empty() {
                    return Err(anyhow!("--query-model requires a non-empty model name"));
                }
                config.query_model = Some(value.trim().to_string());
            }
            "--slice-grow" => {
                let value = take_value("--slice-grow", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --slice-grow value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--slice-grow must be greater than zero"));
                }
                config.slice_grow = Some(parsed);
            }
            "--slice-offset" => {
                let value = take_value("--slice-offset", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --slice-offset value '{value}' as usize")
                })?;
                config.slice_offset = parsed;
            }
            "--raw" => {
                let value = take_value("--raw", &mut args)?;
                config.raw_dataset_path = PathBuf::from(value);
                raw_overridden = true;
            }
            "--converted" => {
                let value = take_value("--converted", &mut args)?;
                config.converted_dataset_path = PathBuf::from(value);
                converted_overridden = true;
            }
            "--corpus-limit" => {
                let value = take_value("--corpus-limit", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --corpus-limit value '{value}' as usize")
                })?;
                // 0 means "no explicit cap"; a derived cap may still be set below.
                config.corpus_limit = if parsed == 0 { None } else { Some(parsed) };
            }
            "--reseed-slice" => {
                config.reseed_slice = true;
            }
            "--slice-reset-ingestion" => {
                config.slice_reset_ingestion = true;
            }
            "--report-dir" => {
                let value = take_value("--report-dir", &mut args)?;
                config.report_dir = PathBuf::from(value);
            }
            "--k" => {
                let value = take_value("--k", &mut args)?;
                let parsed = value
                    .parse::<usize>()
                    .with_context(|| format!("failed to parse --k value '{value}' as usize"))?;
                if parsed == 0 {
                    return Err(anyhow!("--k must be greater than zero"));
                }
                config.k = parsed;
            }
            "--limit" => {
                let value = take_value("--limit", &mut args)?;
                let parsed = value
                    .parse::<usize>()
                    .with_context(|| format!("failed to parse --limit value '{value}' as usize"))?;
                // 0 means "evaluate all questions".
                config.limit = if parsed == 0 { None } else { Some(parsed) };
            }
            "--sample" => {
                let value = take_value("--sample", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --sample value '{value}' as usize")
                })?;
                config.summary_sample = parsed.max(1);
            }
            "--full-context" => {
                config.full_context = true;
            }
            "--chunk-min" => {
                let value = take_value("--chunk-min", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-min value '{value}' as usize")
                })?;
                config.chunk_min_chars = parsed.max(1);
            }
            "--chunk-max" => {
                let value = take_value("--chunk-max", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-max value '{value}' as usize")
                })?;
                config.chunk_max_chars = parsed.max(1);
            }
            "--chunk-vector-take" => {
                let value = take_value("--chunk-vector-take", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-vector-take value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-vector-take must be greater than zero"));
                }
                config.chunk_vector_take = Some(parsed);
            }
            "--chunk-fts-take" => {
                let value = take_value("--chunk-fts-take", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-fts-take value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-fts-take must be greater than zero"));
                }
                config.chunk_fts_take = Some(parsed);
            }
            "--chunk-token-budget" => {
                let value = take_value("--chunk-token-budget", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-token-budget value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-token-budget must be greater than zero"));
                }
                config.chunk_token_budget = Some(parsed);
            }
            "--chunk-token-chars" => {
                let value = take_value("--chunk-token-chars", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --chunk-token-chars value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--chunk-token-chars must be greater than zero"));
                }
                config.chunk_avg_chars_per_token = Some(parsed);
            }
            "--max-chunks-per-entity" => {
                let value = take_value("--max-chunks-per-entity", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --max-chunks-per-entity value '{value}' as usize")
                })?;
                if parsed == 0 {
                    return Err(anyhow!("--max-chunks-per-entity must be greater than zero"));
                }
                config.max_chunks_per_entity = Some(parsed);
            }
            "--embedding" => {
                let value = take_value("--embedding", &mut args)?;
                config.embedding_backend = value.parse()?;
            }
            "--embedding-model" => {
                let value = take_value("--embedding-model", &mut args)?;
                config.embedding_model = Some(value.trim().to_string());
            }
            "--cache-dir" => {
                let value = take_value("--cache-dir", &mut args)?;
                config.cache_dir = PathBuf::from(value);
            }
            "--ingestion-cache-dir" => {
                let value = take_value("--ingestion-cache-dir", &mut args)?;
                config.ingestion_cache_dir = PathBuf::from(value);
            }
            "--negative-multiplier" => {
                let value = take_value("--negative-multiplier", &mut args)?;
                let parsed = value.parse::<f32>().with_context(|| {
                    format!("failed to parse --negative-multiplier value '{value}' as f32")
                })?;
                // Rejects NaN, infinities, zero, and negatives in one check.
                if !(parsed.is_finite() && parsed > 0.0) {
                    return Err(anyhow!(
                        "--negative-multiplier must be a positive finite number"
                    ));
                }
                config.negative_multiplier = parsed;
            }
            "--no-rerank" => {
                config.rerank = false;
            }
            "--rerank-pool" => {
                let value = take_value("--rerank-pool", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --rerank-pool value '{value}' as usize")
                })?;
                config.rerank_pool_size = parsed.max(1);
            }
            "--rerank-keep" => {
                let value = take_value("--rerank-keep", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --rerank-keep value '{value}' as usize")
                })?;
                config.rerank_keep_top = parsed.max(1);
            }
            "--concurrency" => {
                let value = take_value("--concurrency", &mut args)?;
                let parsed = value.parse::<usize>().with_context(|| {
                    format!("failed to parse --concurrency value '{value}' as usize")
                })?;
                config.concurrency = parsed.max(1);
            }
            "--refresh-embeddings" => {
                config.refresh_embeddings_only = true;
            }
            "--detailed-report" => {
                config.detailed_report = true;
            }
            "--chunk-diagnostics" => {
                let value = take_value("--chunk-diagnostics", &mut args)?;
                config.chunk_diagnostics_path = Some(PathBuf::from(value));
            }
            "--inspect-question" => {
                let value = take_value("--inspect-question", &mut args)?;
                config.inspect_question = Some(value);
            }
            "--inspect-manifest" => {
                let value = take_value("--inspect-manifest", &mut args)?;
                config.inspect_manifest = Some(PathBuf::from(value));
            }
            "--inspect-db-state" => {
                let value = take_value("--inspect-db-state", &mut args)?;
                config.inspect_db_state = Some(PathBuf::from(value));
            }
            "--perf-log-json" => {
                let value = take_value("--perf-log-json", &mut args)?;
                config.perf_log_json = Some(PathBuf::from(value));
            }
            "--perf-log-dir" => {
                let value = take_value("--perf-log-dir", &mut args)?;
                config.perf_log_dir = Some(PathBuf::from(value));
            }
            "--perf-log" => {
                config.perf_log_console = true;
            }
            "--db-endpoint" => {
                let value = take_value("--db-endpoint", &mut args)?;
                config.db_endpoint = value;
            }
            "--db-user" => {
                let value = take_value("--db-user", &mut args)?;
                config.db_username = value;
            }
            "--db-pass" => {
                let value = take_value("--db-pass", &mut args)?;
                config.db_password = value;
            }
            "--db-namespace" => {
                let value = take_value("--db-namespace", &mut args)?;
                config.db_namespace = Some(value);
            }
            "--db-database" => {
                let value = take_value("--db-database", &mut args)?;
                config.db_database = Some(value);
            }
            unknown => {
                return Err(anyhow!(
                    "unknown argument '{unknown}'. Use --help to see available options."
                ));
            }
        }
    }

    // Cross-flag validation.
    if config.chunk_min_chars >= config.chunk_max_chars {
        return Err(anyhow!(
            "--chunk-min must be less than --chunk-max (got {} >= {})",
            config.chunk_min_chars,
            config.chunk_max_chars
        ));
    }

    // NOTE(review): rerank_pool_size is already forced >= 1 above, so this
    // guard can never fire — confirm whether it is intentionally defensive.
    if config.rerank && config.rerank_pool_size == 0 {
        return Err(anyhow!(
            "--rerank-pool must be greater than zero when reranking is enabled"
        ));
    }

    if config.concurrency == 0 {
        return Err(anyhow!("--concurrency must be greater than zero"));
    }

    if config.embedding_backend == EmbeddingBackend::Hashed && config.embedding_model.is_some() {
        return Err(anyhow!(
            "--embedding-model cannot be used with the 'hashed' embedding backend"
        ));
    }

    // Derive a corpus limit from --limit when none was given: ~10x the
    // question limit (but never below the limit itself), capped at 1000.
    // An explicit corpus limit is raised to at least the question limit.
    if let Some(limit) = config.limit {
        if let Some(corpus_limit) = config.corpus_limit {
            if corpus_limit < limit {
                config.corpus_limit = Some(limit);
            }
        } else {
            let default_multiplier = 10usize;
            let mut computed = limit.saturating_mul(default_multiplier);
            if computed < limit {
                computed = limit;
            }
            let max_cap = 1_000usize;
            if computed > max_cap {
                computed = max_cap;
            }
            config.corpus_limit = Some(computed);
        }
    }

    // Environment fallback for the perf-log directory (flag wins).
    if config.perf_log_dir.is_none() {
        if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") {
            if !dir.trim().is_empty() {
                config.perf_log_dir = Some(PathBuf::from(dir));
            }
        }
    }

    // NOTE(review): unlike EVAL_PERF_LOG_DIR, non-empty EVAL_DB_* variables
    // override any --db-* flags because they are applied unconditionally
    // after parsing — confirm this precedence is intended.
    if let Ok(endpoint) = env::var("EVAL_DB_ENDPOINT") {
        if !endpoint.trim().is_empty() {
            config.db_endpoint = endpoint;
        }
    }
    if let Ok(username) = env::var("EVAL_DB_USERNAME") {
        if !username.trim().is_empty() {
            config.db_username = username;
        }
    }
    if let Ok(password) = env::var("EVAL_DB_PASSWORD") {
        if !password.trim().is_empty() {
            config.db_password = password;
        }
    }
    if let Ok(ns) = env::var("EVAL_DB_NAMESPACE") {
        if !ns.trim().is_empty() {
            config.db_namespace = Some(ns);
        }
    }
    if let Ok(db) = env::var("EVAL_DB_DATABASE") {
        if !db.trim().is_empty() {
            config.db_database = Some(db);
        }
    }
    Ok(ParsedArgs { config, show_help })
}
|
||||
|
||||
fn take_value<'a, I>(flag: &str, iter: &mut std::iter::Peekable<I>) -> Result<String>
|
||||
where
|
||||
I: Iterator<Item = String>,
|
||||
{
|
||||
iter.next().ok_or_else(|| anyhow!("{flag} expects a value"))
|
||||
}
|
||||
|
||||
/// Prints the CLI usage/help text for the `eval` binary to stdout.
///
/// Invoked for `-h`/`--help`; the option list mirrors the flags handled by
/// the argument parser in this module.
pub fn print_help() {
    println!(
        "\
eval — dataset conversion, ingestion, and retrieval evaluation CLI

USAGE:
    cargo eval -- [options]
    # or
    cargo run -p eval -- [options]

OPTIONS:
    --convert-only           Convert the selected dataset and exit.
    --force, --refresh       Regenerate the converted dataset even if it already exists.
    --dataset <name>         Dataset to evaluate: 'squad' (default) or 'natural-questions'.
    --llm-mode               Enable LLM-assisted evaluation features (includes unanswerable cases).
    --slice <id|path>        Use a cached dataset slice by id (under eval/cache/slices) or by explicit path.
    --label <text>           Annotate the run; label is stored in JSON/Markdown reports.
    --query-model <name>     Override the SurrealDB system settings query model (e.g., gpt-4o-mini) for this run.
    --slice-grow <int>       Grow the slice ledger to contain at least this many answerable cases, then exit.
    --slice-offset <int>     Evaluate questions starting at this offset within the slice (default: 0).
    --reseed-slice           Ignore cached corpus state and rebuild the slice's SurrealDB corpus.
    --slice-reset-ingestion
                             Delete cached paragraph shards before rebuilding the ingestion corpus.
    --corpus-limit <int>     Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000.
    --raw <path>             Path to the raw dataset (defaults per dataset).
    --converted <path>       Path to write/read the converted dataset (defaults per dataset).
    --report-dir <path>      Directory to write evaluation reports (default: eval/reports).
    --k <int>                Precision@k cutoff (default: 5).
    --limit <int>            Limit the number of questions evaluated (default: 200, 0 = all).
    --sample <int>           Number of mismatches to surface in the Markdown summary (default: 5).
    --full-context           Disable context cropping when converting datasets (ingest entire documents).
    --chunk-min <int>        Minimum characters per chunk for text splitting (default: 500).
    --chunk-max <int>        Maximum characters per chunk for text splitting (default: 2000).
    --chunk-vector-take <int>
                             Override chunk vector candidate cap (default: 20).
    --chunk-fts-take <int>
                             Override chunk FTS candidate cap (default: 20).
    --chunk-token-budget <int>
                             Override chunk token budget estimate for assembly (default: 10000).
    --chunk-token-chars <int>
                             Override average characters per token used for budgeting (default: 4).
    --max-chunks-per-entity <int>
                             Override maximum chunks attached per entity (default: 4).
    --embedding <name>       Embedding backend: 'fastembed' (default) or 'hashed'.
    --embedding-model <code>
                             FastEmbed model code (defaults to crate preset when omitted).
    --cache-dir <path>       Directory for embedding caches (default: eval/cache).
    --ingestion-cache-dir <path>
                             Directory for ingestion corpora caches (default: eval/cache/ingested).
    --negative-multiplier <float>
                             Target negative-to-positive paragraph ratio for slice growth (default: 4.0).
    --refresh-embeddings     Recompute embeddings for cached corpora without re-running ingestion.
    --detailed-report        Include entity descriptions and categories in JSON reports.
    --chunk-diagnostics <path>
                             Write per-query chunk diagnostics JSONL to the provided path.
    --no-rerank              Disable the FastEmbed reranking stage (enabled by default).
    --rerank-pool <int>      Reranking engine pool size / parallelism (default: 16).
    --rerank-keep <int>      Keep top-N entities after reranking (default: 10).
    --inspect-question <id>
                             Inspect an ingestion cache question and exit (requires --inspect-manifest).
    --inspect-manifest <path>
                             Path to an ingestion cache manifest JSON for inspection mode.
    --inspect-db-state <path>
                             Optional override for the SurrealDB state.json used during inspection; defaults to the state recorded for the selected dataset slice.
    --db-endpoint <url>      SurrealDB server endpoint (use http:// or https:// to enable SurQL export/import; ws:// endpoints reuse existing namespaces but skip SurQL exports; default: ws://127.0.0.1:8000).
    --db-user <value>        SurrealDB root username (default: root_user).
    --db-pass <value>        SurrealDB root password (default: root_password).
    --db-namespace <ns>      Override the namespace used on the SurrealDB server; state.json tracks this value and the ledger case count so changing it or requesting more cases via --limit triggers a rebuild/import (default: derived from dataset).
    --db-database <db>       Override the database used on the SurrealDB server; recorded alongside namespace in state.json (default: derived from slice).
    --perf-log               Print per-stage performance timings to stdout after the run.
    --perf-log-json <path>
                             Write structured performance telemetry JSON to the provided path.
    --perf-log-dir <path>
                             Directory that receives timestamped perf JSON copies (defaults to $EVAL_PERF_LOG_DIR).

Examples:
    cargo eval -- --dataset squad --limit 10 --detailed-report
    cargo eval -- --dataset natural-questions --limit 1 --rerank-pool 1 --detailed-report

Notes:
    The latest run's JSON/Markdown reports are saved as eval/reports/latest.json and latest.md, making it easy to script automated checks.
    -h, --help               Show this help text.

Dataset defaults (from eval/manifest.yaml):
    squad              raw: eval/data/raw/squad/dev-v2.0.json
                       converted: eval/data/converted/squad-minne.json
    natural-questions  raw: eval/data/raw/nq/dev-all.jsonl
                       converted: eval/data/converted/nq-dev-minne.json
"
    );
}
|
||||
|
||||
pub fn ensure_parent(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
88
eval/src/cache.rs
Normal file
88
eval/src/cache.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::{Path, PathBuf},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// On-disk payload of the embedding cache JSON file.
#[derive(Debug, Default, Serialize, Deserialize)]
struct EmbeddingCacheData {
    // Entity id -> embedding vector.
    entities: HashMap<String, Vec<f32>>,
    // Chunk id -> embedding vector.
    chunks: HashMap<String, Vec<f32>>,
}
|
||||
|
||||
/// Cheaply cloneable, task-shareable embedding cache backed by one JSON file.
#[derive(Clone)]
pub struct EmbeddingCache {
    // Location of the backing JSON file.
    path: Arc<PathBuf>,
    // In-memory cache contents, guarded by an async mutex.
    data: Arc<Mutex<EmbeddingCacheData>>,
    // Set by inserts, cleared after a successful persist(); lets persist()
    // no-op when nothing changed.
    dirty: Arc<AtomicBool>,
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl EmbeddingCache {
|
||||
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref().to_path_buf();
|
||||
let data = if path.exists() {
|
||||
let raw = tokio::fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading embedding cache {}", path.display()))?;
|
||||
serde_json::from_slice(&raw)
|
||||
.with_context(|| format!("parsing embedding cache {}", path.display()))?
|
||||
} else {
|
||||
EmbeddingCacheData::default()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
path: Arc::new(path),
|
||||
data: Arc::new(Mutex::new(data)),
|
||||
dirty: Arc::new(AtomicBool::new(false)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.entities.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.entities.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.chunks.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.chunks.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> Result<()> {
|
||||
if !self.dirty.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let guard = self.data.lock().await;
|
||||
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
|
||||
if let Some(parent) = self.path.parent() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.with_context(|| format!("creating cache directory {}", parent.display()))?;
|
||||
}
|
||||
tokio::fs::write(&*self.path, body)
|
||||
.await
|
||||
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
|
||||
self.dirty.store(false, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
1003
eval/src/datasets.rs
Normal file
1003
eval/src/datasets.rs
Normal file
File diff suppressed because it is too large
Load Diff
269
eval/src/db_helpers.rs
Normal file
269
eval/src/db_helpers.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
use anyhow::{Context, Result};
|
||||
use common::storage::db::SurrealDbClient;
|
||||
|
||||
// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings
|
||||
pub async fn change_embedding_length_in_hnsw_indexes(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<()> {
|
||||
tracing::info!("Changing embedding length in HNSW indexes");
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
|
||||
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
|
||||
COMMIT TRANSACTION;",
|
||||
dim = dimension
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.await
|
||||
.context("changing HNSW indexes")?;
|
||||
tracing::info!("HNSW indexes successfully changed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Helper functions for index management during namespace reseed
|
||||
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Removing ALL indexes before namespace reseed (aggressive approach)");
|
||||
|
||||
// Remove ALL indexes from ALL tables to ensure no cache access
|
||||
db.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
-- HNSW indexes
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
|
||||
-- FTS indexes on text_content (remove ALL of them)
|
||||
REMOVE INDEX IF EXISTS text_content_fts_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_text_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_category_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_context_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_url_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON TABLE text_content;
|
||||
|
||||
-- FTS indexes on knowledge_entity
|
||||
REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity;
|
||||
REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity;
|
||||
|
||||
-- FTS indexes on text_chunk
|
||||
REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk;
|
||||
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("removing all indexes before namespace reseed")?;
|
||||
|
||||
tracing::info!("All indexes removed before namespace reseed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_tokenizer(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Creating FTS analyzers for namespace reseed");
|
||||
let res = db
|
||||
.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("creating FTS analyzers for namespace reseed")?;
|
||||
|
||||
res.check().context("failed to create the tokenizer")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
tracing::info!("Recreating ALL indexes after namespace reseed (SEQUENTIAL approach)");
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
create_tokenizer(db)
|
||||
.await
|
||||
.context("creating FTS analyzer")?;
|
||||
|
||||
// For now we dont remove these plain indexes, we could if they prove negatively impacting performance
|
||||
// create_regular_indexes_for_snapshot(db)
|
||||
// .await
|
||||
// .context("creating regular indexes for namespace reseed")?;
|
||||
|
||||
let fts_start = std::time::Instant::now();
|
||||
create_fts_indexes_for_snapshot(db)
|
||||
.await
|
||||
.context("creating FTS indexes for namespace reseed")?;
|
||||
tracing::info!(duration = ?fts_start.elapsed(), "FTS indexes created");
|
||||
|
||||
let hnsw_start = std::time::Instant::now();
|
||||
create_hnsw_indexes_for_snapshot(db, dimension)
|
||||
.await
|
||||
.context("creating HNSW indexes for namespace reseed")?;
|
||||
tracing::info!(duration = ?hnsw_start.elapsed(), "HNSW indexes created");
|
||||
|
||||
tracing::info!(duration = ?total_start.elapsed(), "All index groups recreated successfully in sequence");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // For now we dont do this. We could
|
||||
async fn create_regular_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Creating regular indexes for namespace reseed (parallel group 1)");
|
||||
let res = db
|
||||
.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE INDEX text_content_user_id_idx ON text_content FIELDS user_id;
|
||||
DEFINE INDEX text_content_created_at_idx ON text_content FIELDS created_at;
|
||||
DEFINE INDEX text_content_category_idx ON text_content FIELDS category;
|
||||
DEFINE INDEX text_chunk_source_id_idx ON text_chunk FIELDS source_id;
|
||||
DEFINE INDEX text_chunk_user_id_idx ON text_chunk FIELDS user_id;
|
||||
DEFINE INDEX knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
DEFINE INDEX knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||
DEFINE INDEX knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||
DEFINE INDEX knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("creating regular indexes for namespace reseed")?;
|
||||
|
||||
res.check().context("one of the regular indexes failed")?;
|
||||
|
||||
tracing::info!("Regular indexes for namespace reseed created");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_fts_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Creating FTS indexes for namespace reseed (group 2)");
|
||||
let res = db.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE INDEX text_content_fts_idx ON TABLE text_content FIELDS text;
|
||||
DEFINE INDEX knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("sending FTS index creation query")?;
|
||||
|
||||
// This actually surfaces statement-level errors
|
||||
res.check()
|
||||
.context("one or more FTS index statements failed")?;
|
||||
|
||||
tracing::info!("FTS indexes for namespace reseed created");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Defines the two HNSW vector indexes with the given embedding dimension.
async fn create_hnsw_indexes_for_snapshot(db: &SurrealDbClient, dimension: usize) -> Result<()> {
    tracing::info!("Creating HNSW indexes for namespace reseed (group 3)");
    let query = format!(
        "BEGIN TRANSACTION;
        DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
        DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
        COMMIT TRANSACTION;",
        dim = dimension
    );

    db.client
        .query(query)
        .await
        .context("creating HNSW indexes for namespace reseed")?
        .check()
        .context("one or more HNSW index statements failed")?;

    tracing::info!("HNSW indexes for namespace reseed created");
    Ok(())
}
|
||||
|
||||
/// Drops and recreates the namespace/database pair, then selects them on the
/// shared client so subsequent queries target the fresh namespace.
pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> {
    // Fix: errors were previously swallowed entirely. `IF EXISTS` keeps the
    // very first run (no namespace yet) from raising a statement error, which
    // in turn lets us surface *real* failures via `check()`.
    let query = format!(
        "REMOVE NAMESPACE IF EXISTS {ns};
        DEFINE NAMESPACE {ns};
        DEFINE DATABASE {db};",
        ns = namespace,
        db = database
    );
    db.client
        .query(query)
        .await
        .context("resetting SurrealDB namespace")?
        .check()
        .context("one or more namespace reset statements failed")?;
    db.client
        .use_ns(namespace)
        .use_db(database)
        .await
        .context("selecting namespace/database after reset")?;
    Ok(())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use uuid::Uuid;

    // Minimal row shape used to read back the seeded `foo` records.
    #[derive(Debug, Deserialize)]
    struct FooRow {
        label: String,
    }

    /// Seeds one row into a fresh in-memory namespace, resets the namespace,
    /// and verifies the row is gone afterwards.
    #[tokio::test]
    async fn reset_namespace_drops_existing_rows() {
        // Unique names per run so parallel test executions cannot collide.
        let namespace = format!("reset_ns_{}", Uuid::new_v4().simple());
        let database = format!("reset_db_{}", Uuid::new_v4().simple());
        let db = SurrealDbClient::memory(&namespace, &database)
            .await
            .expect("in-memory db");

        // Seed a single record we expect the reset to wipe.
        db.client
            .query(
                "DEFINE TABLE foo SCHEMALESS;
                CREATE foo:foo SET label = 'before';",
            )
            .await
            .expect("seed namespace")
            .check()
            .expect("seed response");

        // Sanity check: the seeded row is visible before the reset.
        let mut before = db
            .client
            .query("SELECT * FROM foo")
            .await
            .expect("select before reset");
        let existing: Vec<FooRow> = before.take(0).expect("rows before reset");
        assert_eq!(existing.len(), 1);
        assert_eq!(existing[0].label, "before");

        reset_namespace(&db, &namespace, &database)
            .await
            .expect("namespace reset");

        // After the reset either the select succeeds with zero rows, or it
        // errors because the table/namespace no longer exists — both outcomes
        // prove the data was dropped.
        match db.client.query("SELECT * FROM foo").await {
            Ok(mut response) => {
                let rows: Vec<FooRow> = response.take(0).unwrap_or_default();
                assert!(
                    rows.is_empty(),
                    "reset namespace should drop rows, found {:?}",
                    rows
                );
            }
            Err(error) => {
                let message = error.to_string();
                assert!(
                    message.to_ascii_lowercase().contains("table")
                        || message.to_ascii_lowercase().contains("namespace")
                        || message.to_ascii_lowercase().contains("foo"),
                    "unexpected error after namespace reset: {message}"
                );
            }
        }
    }
}
|
||||
171
eval/src/embedding.rs
Normal file
171
eval/src/embedding.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::args::{Config, EmbeddingBackend};
|
||||
|
||||
/// Cloneable facade over the configured embedding backend.
#[derive(Clone)]
pub struct EmbeddingProvider {
    // Backend selection plus its state; see EmbeddingInner.
    inner: EmbeddingInner,
}
|
||||
|
||||
/// Backend variants: a deterministic hashed bag-of-words embedder (no model
/// download required) or a FastEmbed model.
#[derive(Clone)]
enum EmbeddingInner {
    Hashed {
        // Output vector length for the hashed embedder.
        dimension: usize,
    },
    FastEmbed {
        // `TextEmbedding::embed` takes `&mut self`, so calls are serialised
        // through this async mutex; Arc lets clones share the one model.
        model: Arc<Mutex<TextEmbedding>>,
        model_name: EmbeddingModel,
        // Output vector length reported by the model's metadata.
        dimension: usize,
    },
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
pub fn backend_label(&self) -> &'static str {
|
||||
match self.inner {
|
||||
EmbeddingInner::Hashed { .. } => "hashed",
|
||||
EmbeddingInner::FastEmbed { .. } => "fastembed",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimension(&self) -> usize {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => *dimension,
|
||||
EmbeddingInner::FastEmbed { dimension, .. } => *dimension,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn model_code(&self) -> Option<String> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
||||
EmbeddingInner::Hashed { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let mut guard = model.lock().await;
|
||||
let embeddings = guard
|
||||
.embed(vec![text.to_owned()], None)
|
||||
.context("generating fastembed vector")?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
.map(|text| hashed_embedding(&text, *dimension))
|
||||
.collect()),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut guard = model.lock().await;
|
||||
guard
|
||||
.embed(texts, None)
|
||||
.context("generating fastembed batch embeddings")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs the embedding provider selected by `config`.
///
/// For the hashed backend, `default_dimension` (clamped to at least 1) sizes
/// the output vector. For FastEmbed, the dimension comes from the model's own
/// metadata instead.
pub async fn build_provider(
    config: &Config,
    default_dimension: usize,
) -> Result<EmbeddingProvider> {
    match config.embedding_backend {
        EmbeddingBackend::Hashed => Ok(EmbeddingProvider {
            inner: EmbeddingInner::Hashed {
                dimension: default_dimension.max(1),
            },
        }),
        EmbeddingBackend::FastEmbed => {
            // Resolve the model: an explicit --embedding-model code wins,
            // otherwise fall back to the crate's default model.
            let model_name = if let Some(code) = config.embedding_model.as_deref() {
                EmbeddingModel::from_str(code).map_err(|err| anyhow!(err))?
            } else {
                EmbeddingModel::default()
            };

            let options =
                TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
            let model_name_for_task = model_name.clone();
            let model_name_code = model_name.to_string();

            // Model download/initialisation is blocking work, so run it off
            // the async executor. The double `?` unwraps the join error and
            // the closure's own Result.
            let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
                let model =
                    TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
                let info =
                    EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
                        anyhow!("FastEmbed model metadata missing for {model_name_code}")
                    })?;
                Ok((model, info.dim))
            })
            .await
            .context("joining FastEmbed initialisation task")??;

            Ok(EmbeddingProvider {
                inner: EmbeddingInner::FastEmbed {
                    model: Arc::new(Mutex::new(model)),
                    model_name,
                    dimension,
                },
            })
        }
    }
}
|
||||
|
||||
/// Deterministic bag-of-words embedding: hashes each token into one of
/// `dimension` buckets (clamped to at least 1), counts occurrences, then
/// L2-normalises. Texts with no alphanumeric tokens yield a zero vector.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
    let dim = dimension.max(1);
    let mut buckets = vec![0.0f32; dim];
    if text.is_empty() {
        return buckets;
    }

    let mut seen = 0usize;
    for token in tokens(text) {
        seen += 1;
        buckets[bucket(&token, dim)] += 1.0;
    }
    // Non-empty text can still contain zero tokens (e.g. punctuation only).
    if seen == 0 {
        return buckets;
    }

    let norm = buckets.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        for weight in &mut buckets {
            *weight /= norm;
        }
    }

    buckets
}

/// Splits on non-ASCII-alphanumeric boundaries and lowercases each token.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(str::to_ascii_lowercase)
}

/// Maps a token to a bucket index in `[0, dimension)` via the default hasher.
fn bucket(token: &str, dimension: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    hasher.finish() as usize % dimension
}
|
||||
767
eval/src/eval/mod.rs
Normal file
767
eval/src/eval/mod.rs
Normal file
@@ -0,0 +1,767 @@
|
||||
mod pipeline;
|
||||
|
||||
pub use pipeline::run_evaluation;
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
path::Path,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::pipeline as retrieval_pipeline;
|
||||
use composite_retrieval::pipeline::PipelineStageTimings;
|
||||
use composite_retrieval::pipeline::RetrievalTuning;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
args::{self, Config},
|
||||
datasets::{self, ConvertedDataset},
|
||||
db_helpers::change_embedding_length_in_hnsw_indexes,
|
||||
ingest,
|
||||
slice::{self},
|
||||
snapshot::{self, DbSnapshotState},
|
||||
};
|
||||
|
||||
/// Top-level JSON report for one evaluation run.
#[derive(Debug, Serialize)]
pub struct EvaluationSummary {
    pub generated_at: DateTime<Utc>,
    // Precision@k cutoff used for the headline `precision` value.
    pub k: usize,
    pub limit: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_label: Option<String>,
    // Aggregate accuracy numbers.
    pub total_cases: usize,
    pub correct: usize,
    pub precision: f64,
    pub correct_at_1: usize,
    pub correct_at_2: usize,
    pub correct_at_3: usize,
    pub precision_at_1: f64,
    pub precision_at_2: f64,
    pub precision_at_3: f64,
    pub duration_ms: u128,
    // Dataset provenance.
    pub dataset_id: String,
    pub dataset_label: String,
    pub dataset_includes_unanswerable: bool,
    pub dataset_source: String,
    // Slice composition (which cached subset of the dataset was evaluated).
    pub slice_id: String,
    pub slice_seed: u64,
    pub slice_total_cases: usize,
    pub slice_window_offset: usize,
    pub slice_window_length: usize,
    pub slice_cases: usize,
    pub slice_positive_paragraphs: usize,
    pub slice_negative_paragraphs: usize,
    pub slice_total_paragraphs: usize,
    pub slice_negative_multiplier: f32,
    // Cache / namespace reuse bookkeeping.
    pub namespace_reused: bool,
    pub corpus_paragraphs: usize,
    pub ingestion_cache_path: String,
    pub ingestion_reused: bool,
    pub ingestion_embeddings_reused: bool,
    pub ingestion_fingerprint: String,
    pub positive_paragraphs_reused: usize,
    pub negative_paragraphs_reused: usize,
    // Timing telemetry.
    pub latency_ms: LatencyStats,
    pub perf: PerformanceTimings,
    // Retrieval configuration used for this run.
    pub embedding_backend: String,
    pub embedding_model: Option<String>,
    pub embedding_dimension: usize,
    pub rerank_enabled: bool,
    pub rerank_pool_size: Option<usize>,
    pub rerank_keep_top: usize,
    pub concurrency: usize,
    pub detailed_report: bool,
    pub chunk_vector_take: usize,
    pub chunk_fts_take: usize,
    pub chunk_token_budget: usize,
    pub chunk_avg_chars_per_token: usize,
    pub max_chunks_per_entity: usize,
    // Per-question outcomes.
    pub cases: Vec<CaseSummary>,
}
|
||||
|
||||
/// Per-question outcome included in the JSON report.
#[derive(Debug, Serialize)]
pub struct CaseSummary {
    pub question_id: String,
    pub question: String,
    pub paragraph_id: String,
    pub paragraph_title: String,
    pub expected_source: String,
    pub answers: Vec<String>,
    // `matched` is the headline hit; the three flags below record by which
    // criterion it matched.
    pub matched: bool,
    pub entity_match: bool,
    pub chunk_text_match: bool,
    pub chunk_id_match: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub match_rank: Option<usize>,
    pub latency_ms: u128,
    pub retrieved: Vec<RetrievedSummary>,
}
|
||||
|
||||
/// Average / median / tail latency figures (milliseconds).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
    pub avg: f64,
    pub p50: u128,
    pub p95: u128,
}
|
||||
|
||||
/// Latency stats per retrieval-pipeline stage.
#[derive(Debug, Clone, Serialize)]
pub struct StageLatencyBreakdown {
    pub collect_candidates: LatencyStats,
    pub graph_expansion: LatencyStats,
    pub chunk_attach: LatencyStats,
    pub rerank: LatencyStats,
    pub assemble: LatencyStats,
}
|
||||
|
||||
/// Wall-clock milliseconds spent in each phase of the evaluation run itself.
#[derive(Debug, Default, Clone, Serialize)]
pub struct EvaluationStageTimings {
    pub prepare_slice_ms: u128,
    pub prepare_db_ms: u128,
    pub prepare_corpus_ms: u128,
    pub prepare_namespace_ms: u128,
    pub run_queries_ms: u128,
    pub summarize_ms: u128,
    pub finalize_ms: u128,
}
|
||||
|
||||
/// Performance telemetry attached to the report / perf JSON exports.
#[derive(Debug, Serialize)]
pub struct PerformanceTimings {
    pub openai_base_url: String,
    pub ingestion_ms: u128,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace_seed_ms: Option<u128>,
    pub evaluation_stage_ms: EvaluationStageTimings,
    pub stage_latency: StageLatencyBreakdown,
}
|
||||
|
||||
/// One retrieved entity in a case's ranked result list.
#[derive(Debug, Serialize)]
pub struct RetrievedSummary {
    pub rank: usize,
    pub entity_id: String,
    pub source_id: String,
    pub entity_name: String,
    pub score: f32,
    pub matched: bool,
    // Optional detail fields — omitted from the JSON when unset (presumably
    // populated only for --detailed-report runs; confirm against the pipeline).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_category: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chunk_text_match: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chunk_id_match: Option<bool>,
}
|
||||
|
||||
/// Per-question diagnostics record (written as JSONL for --chunk-diagnostics).
#[derive(Debug, Serialize)]
pub(crate) struct CaseDiagnostics {
    question_id: String,
    question: String,
    paragraph_id: String,
    paragraph_title: String,
    expected_source: String,
    expected_chunk_ids: Vec<String>,
    answers: Vec<String>,
    entity_match: bool,
    chunk_text_match: bool,
    chunk_id_match: bool,
    // Human-readable explanations of why the case missed, if it did.
    failure_reasons: Vec<String>,
    missing_expected_chunk_ids: Vec<String>,
    attached_chunk_ids: Vec<String>,
    retrieved: Vec<EntityDiagnostics>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pipeline: Option<retrieval_pipeline::PipelineDiagnostics>,
}
|
||||
|
||||
/// Diagnostics for one retrieved entity, including its attached chunks.
#[derive(Debug, Serialize)]
struct EntityDiagnostics {
    rank: usize,
    entity_id: String,
    source_id: String,
    name: String,
    score: f32,
    entity_match: bool,
    chunk_text_match: bool,
    chunk_id_match: bool,
    chunks: Vec<ChunkDiagnosticsEntry>,
}
|
||||
|
||||
/// Diagnostics for a single chunk attached to a retrieved entity.
#[derive(Debug, Serialize)]
struct ChunkDiagnosticsEntry {
    chunk_id: String,
    score: f32,
    contains_answer: bool,
    expected_chunk: bool,
    snippet: String,
}
|
||||
|
||||
/// Answerable question drawn from the corpus manifest, paired with the
/// paragraph/source it is expected to retrieve.
pub(crate) struct SeededCase {
    question_id: String,
    question: String,
    expected_source: String,
    answers: Vec<String>,
    paragraph_id: String,
    paragraph_title: String,
    expected_chunk_ids: Vec<String>,
}
|
||||
|
||||
pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<SeededCase> {
|
||||
let mut title_map = HashMap::new();
|
||||
for paragraph in &manifest.paragraphs {
|
||||
title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone());
|
||||
}
|
||||
|
||||
manifest
|
||||
.questions
|
||||
.iter()
|
||||
.filter(|question| !question.is_impossible)
|
||||
.map(|question| {
|
||||
let title = title_map
|
||||
.get(question.paragraph_id.as_str())
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
SeededCase {
|
||||
question_id: question.question_id.clone(),
|
||||
question: question.question_text.clone(),
|
||||
expected_source: question.text_content_id.clone(),
|
||||
answers: question.answers.clone(),
|
||||
paragraph_id: question.paragraph_id.clone(),
|
||||
paragraph_title: title,
|
||||
expected_chunk_ids: question.matching_chunk_ids.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns `true` when `text` contains at least one of the supplied answers.
///
/// The haystack is ASCII-lower-cased before matching, so callers must pass
/// needles that are already lower-case (call sites pass `answers_lower`).
/// An empty answer list counts as a trivial match.
pub(crate) fn text_contains_answer(text: &str, answers: &[String]) -> bool {
    match answers {
        [] => true,
        needles => {
            let lowered = text.to_ascii_lowercase();
            needles
                .iter()
                .any(|answer| lowered.contains(answer.as_str()))
        }
    }
}
|
||||
|
||||
pub(crate) fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
|
||||
if latencies.is_empty() {
|
||||
return LatencyStats {
|
||||
avg: 0.0,
|
||||
p50: 0,
|
||||
p95: 0,
|
||||
};
|
||||
}
|
||||
let mut sorted = latencies.to_vec();
|
||||
sorted.sort_unstable();
|
||||
let sum: u128 = sorted.iter().copied().sum();
|
||||
let avg = sum as f64 / (sorted.len() as f64);
|
||||
let p50 = percentile(&sorted, 0.50);
|
||||
let p95 = percentile(&sorted, 0.95);
|
||||
LatencyStats { avg, p50, p95 }
|
||||
}
|
||||
|
||||
pub(crate) fn build_stage_latency_breakdown(
|
||||
samples: &[PipelineStageTimings],
|
||||
) -> StageLatencyBreakdown {
|
||||
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
|
||||
where
|
||||
F: Fn(&PipelineStageTimings) -> u128,
|
||||
{
|
||||
samples.iter().map(selector).collect()
|
||||
}
|
||||
|
||||
StageLatencyBreakdown {
|
||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.collect_candidates_ms
|
||||
})),
|
||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.graph_expansion_ms
|
||||
})),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Nearest-rank percentile over an ascending-sorted slice.
///
/// `fraction` is clamped to `[0.0, 1.0]` and mapped to a rounded index, so
/// e.g. `0.5` over five elements selects the third. Returns 0 for an empty
/// slice.
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
    if sorted.is_empty() {
        return 0;
    }
    let last = sorted.len() - 1;
    let position = (fraction.clamp(0.0, 1.0) * last as f64).round() as usize;
    sorted[position.min(last)]
}
|
||||
|
||||
/// Resolves the dataset slice ledger, growing it to `ledger_target(config)`
/// cases if required, then reports its composition on the log and stdout.
///
/// # Errors
/// Fails when the slice cannot be resolved from the converted dataset.
pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
    let ledger_limit = ledger_target(config);
    let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
    let slice =
        slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
    info!(
        slice = slice.manifest.slice_id.as_str(),
        cases = slice.manifest.case_count,
        positives = slice.manifest.positive_paragraphs,
        negatives = slice.manifest.negative_paragraphs,
        total_paragraphs = slice.manifest.total_paragraphs,
        "Slice ledger ready"
    );
    // Human-facing confirmation in addition to the structured log line.
    println!(
        "Slice `{}` now contains {} questions ({} positives, {} negatives)",
        slice.manifest.slice_id,
        slice.manifest.case_count,
        slice.manifest.positive_paragraphs,
        slice.manifest.negative_paragraphs
    );
    Ok(())
}
|
||||
|
||||
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
|
||||
match (config.slice_grow, config.limit) {
|
||||
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
||||
(Some(grow), None) => Some(grow),
|
||||
(None, limit) => limit,
|
||||
}
|
||||
}
|
||||
|
||||
/// Widens retrieval tuning defaults for long-form datasets.
///
/// Only datasets whose id contains "natural-questions" (case-insensitive)
/// are affected; each override is applied only when the user did not set the
/// corresponding CLI/config value explicitly, and only ever raises the
/// existing tuning value (via `max`), never lowers it.
pub(crate) fn apply_dataset_tuning_overrides(
    dataset: &ConvertedDataset,
    config: &Config,
    tuning: &mut RetrievalTuning,
) {
    let is_long_form = dataset
        .metadata
        .id
        .to_ascii_lowercase()
        .contains("natural-questions");
    if !is_long_form {
        return;
    }

    // Pull in more candidate chunks for long documents.
    if config.chunk_vector_take.is_none() {
        tuning.chunk_vector_take = tuning.chunk_vector_take.max(80);
    }
    if config.chunk_fts_take.is_none() {
        tuning.chunk_fts_take = tuning.chunk_fts_take.max(80);
    }
    if config.chunk_token_budget.is_none() {
        tuning.token_budget_estimate = tuning.token_budget_estimate.max(20_000);
    }
    if config.max_chunks_per_entity.is_none() {
        tuning.max_chunks_per_entity = tuning.max_chunks_per_entity.max(12);
    }
    // Floor the lexical weight unconditionally; long-form questions lean more
    // on exact-term matches.
    if tuning.lexical_match_weight < 0.25 {
        tuning.lexical_match_weight = 0.3;
    }
}
|
||||
|
||||
/// Assembles the full per-case diagnostic record from a case summary and the
/// raw retrieval output, classifying why the case failed (if it did).
///
/// `answers_lower` must already be lower-cased (see `text_contains_answer`);
/// `expected_chunk_ids` are the ground-truth chunk ids for the question.
pub(crate) fn build_case_diagnostics(
    summary: &CaseSummary,
    expected_chunk_ids: &[String],
    answers_lower: &[String],
    entities: &[composite_retrieval::RetrievedEntity],
    pipeline_stats: Option<retrieval_pipeline::PipelineDiagnostics>,
) -> CaseDiagnostics {
    let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
    let mut seen_chunks: HashSet<String> = HashSet::new();
    let mut attached_chunk_ids = Vec::new();
    let mut entity_diagnostics = Vec::new();

    // Walk the ranked entities, recording per-chunk answer/id matches and
    // accumulating every attached chunk id in retrieval order.
    for (idx, entity) in entities.iter().enumerate() {
        let mut chunk_entries = Vec::new();
        for chunk in &entity.chunks {
            let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
            let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
            seen_chunks.insert(chunk.chunk.id.clone());
            attached_chunk_ids.push(chunk.chunk.id.clone());
            chunk_entries.push(ChunkDiagnosticsEntry {
                chunk_id: chunk.chunk.id.clone(),
                score: chunk.score,
                contains_answer,
                expected_chunk,
                snippet: chunk_preview(&chunk.chunk.chunk),
            });
        }
        entity_diagnostics.push(EntityDiagnostics {
            // 1-based rank for human-readable reports.
            rank: idx + 1,
            entity_id: entity.entity.id.clone(),
            source_id: entity.entity.source_id.clone(),
            name: entity.entity.name.clone(),
            score: entity.score,
            entity_match: entity.entity.source_id == summary.expected_source,
            chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
            chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
            chunks: chunk_entries,
        });
    }

    // Expected chunks that never appeared anywhere in the retrieval output.
    let missing_expected_chunk_ids = expected_chunk_ids
        .iter()
        .filter(|id| !seen_chunks.contains(id.as_str()))
        .cloned()
        .collect::<Vec<_>>();

    // Translate the boolean outcome flags into stable failure labels.
    let mut failure_reasons = Vec::new();
    if !summary.entity_match {
        failure_reasons.push("entity_miss".to_string());
    }
    if !summary.chunk_text_match {
        failure_reasons.push("chunk_text_missing".to_string());
    }
    if !summary.chunk_id_match {
        failure_reasons.push("chunk_id_missing".to_string());
    }
    if !missing_expected_chunk_ids.is_empty() {
        failure_reasons.push("expected_chunk_absent".to_string());
    }

    CaseDiagnostics {
        question_id: summary.question_id.clone(),
        question: summary.question.clone(),
        paragraph_id: summary.paragraph_id.clone(),
        paragraph_title: summary.paragraph_title.clone(),
        expected_source: summary.expected_source.clone(),
        expected_chunk_ids: expected_chunk_ids.to_vec(),
        answers: summary.answers.clone(),
        entity_match: summary.entity_match,
        chunk_text_match: summary.chunk_text_match,
        chunk_id_match: summary.chunk_id_match,
        failure_reasons,
        missing_expected_chunk_ids,
        attached_chunk_ids,
        retrieved: entity_diagnostics,
        pipeline: pipeline_stats,
    }
}
|
||||
|
||||
/// First 200 characters of `text` with every newline flattened to a space,
/// suitable for single-line diagnostic output.
fn chunk_preview(text: &str) -> String {
    let mut preview = String::new();
    for ch in text.chars().take(200) {
        preview.push(if ch == '\n' { ' ' } else { ch });
    }
    preview
}
|
||||
|
||||
/// Writes case diagnostics to `path` as JSON Lines (one serialised
/// `CaseDiagnostics` per line), creating parent directories as needed.
///
/// # Errors
/// Fails on directory creation, file creation, serialisation, or write/flush
/// errors.
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
    args::ensure_parent(path)?;
    let mut file = tokio::fs::File::create(path)
        .await
        .with_context(|| format!("creating diagnostics file {}", path.display()))?;
    for case in cases {
        let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
        file.write_all(&line).await?;
        file.write_all(b"\n").await?;
    }
    // Explicit flush: dropping the file would not surface write errors.
    file.flush().await?;
    Ok(())
}
|
||||
|
||||
/// Primes SurrealDB's HNSW vector indexes for `text_chunk` and
/// `knowledge_entity` by running one dummy nearest-neighbour query against
/// each, so the first measured query does not pay the cold-cache cost.
///
/// `dimension` must match the embedding dimension the indexes were built
/// with. Query results are intentionally discarded.
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
    // Create a dummy embedding for cache warming; sin(i) gives a cheap,
    // deterministic, non-constant vector.
    let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();

    info!("Warming HNSW caches with sample queries");

    // Warm up chunk index. `<|1,1|>` is SurrealDB's KNN operator
    // (k = 1, effort = 1) — NOTE(review): confirm against the SurrealQL docs
    // for the pinned SurrealDB version.
    let _ = db
        .client
        .query("SELECT * FROM text_chunk WHERE embedding <|1,1|> $embedding LIMIT 5")
        .bind(("embedding", dummy_embedding.clone()))
        .await
        .context("warming text chunk HNSW cache")?;

    // Warm up entity index
    let _ = db
        .client
        .query("SELECT * FROM knowledge_entity WHERE embedding <|1,1|> $embedding LIMIT 5")
        .bind(("embedding", dummy_embedding))
        .await
        .context("warming knowledge entity HNSW cache")?;

    info!("HNSW cache warming completed");
    Ok(())
}
|
||||
|
||||
/// Returns the fixed `eval-user` account, creating it with deterministic
/// timestamps (`datasets::base_timestamp`) if it does not exist yet.
///
/// The password is a placeholder — this user exists only so ingested content
/// has an owner; it is never used for authentication.
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
    let timestamp = datasets::base_timestamp();
    let user = User {
        id: "eval-user".to_string(),
        created_at: timestamp,
        updated_at: timestamp,
        email: "eval-retrieval@minne.dev".to_string(),
        password: "not-used".to_string(),
        anonymous: false,
        api_key: None,
        admin: false,
        timezone: "UTC".to_string(),
    };

    // Prefer the stored record so repeated runs keep the original timestamps.
    if let Some(existing) = db.get_item::<User>(&user.id).await? {
        return Ok(existing);
    }

    db.store_item(user.clone())
        .await
        .context("storing evaluation user")?;
    Ok(user)
}
|
||||
|
||||
/// Formats a UTC timestamp as RFC 3339 with whole-second precision and a `Z`
/// suffix, e.g. `2024-01-02T03:04:05Z`.
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
    timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
}
|
||||
|
||||
/// Normalises a model identifier for use in report/file names: ASCII
/// alphanumerics are lower-cased and every other character becomes `_`.
/// Unlike `sanitize_identifier`, the result may be empty and is not capped.
pub(crate) fn sanitize_model_code(code: &str) -> String {
    let mut sanitized = String::with_capacity(code.len());
    for ch in code.chars() {
        if ch.is_ascii_alphanumeric() {
            sanitized.push(ch.to_ascii_lowercase());
        } else {
            sanitized.push('_');
        }
    }
    sanitized
}
|
||||
|
||||
/// Connects to SurrealDB for the evaluation run.
///
/// First attempts root-level authentication with the configured credentials;
/// if that fails, retries the same credentials as a namespace-level user.
/// Only when both attempts fail is an error returned, carrying both failure
/// messages.
pub(crate) async fn connect_eval_db(
    config: &Config,
    namespace: &str,
    database: &str,
) -> Result<SurrealDbClient> {
    match SurrealDbClient::new(
        &config.db_endpoint,
        &config.db_username,
        &config.db_password,
        namespace,
        database,
    )
    .await
    {
        Ok(client) => {
            info!(
                endpoint = %config.db_endpoint,
                namespace,
                database,
                auth = "root",
                "Connected to SurrealDB"
            );
            Ok(client)
        }
        Err(root_err) => {
            // Root auth commonly fails on managed instances; fall back to
            // namespace-scoped credentials before giving up.
            info!(
                endpoint = %config.db_endpoint,
                namespace,
                database,
                "Root authentication failed; trying namespace-level auth"
            );
            let namespace_client = SurrealDbClient::new_with_namespace_user(
                &config.db_endpoint,
                namespace,
                &config.db_username,
                &config.db_password,
                database,
            )
            .await
            .map_err(|ns_err| {
                anyhow!(
                    "failed to connect to SurrealDB via root ({root_err}) or namespace ({ns_err}) credentials"
                )
            })?;
            info!(
                endpoint = %config.db_endpoint,
                namespace,
                database,
                auth = "namespace",
                "Connected to SurrealDB"
            );
            Ok(namespace_client)
        }
    }
}
|
||||
|
||||
/// Returns `true` when the connected namespace already contains at least one
/// `text_chunk` row, i.e. a previously seeded corpus.
///
/// NOTE(review): `response.take(0).unwrap_or_default()` treats a row
/// deserialisation failure the same as an empty table (returns `false`),
/// which silently triggers a reseed — confirm that is intended.
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
    #[derive(Deserialize)]
    struct CountRow {
        count: i64,
    }

    let mut response = db
        .client
        .query("SELECT count() AS count FROM text_chunk")
        .await
        .context("checking namespace corpus state")?;
    let rows: Vec<CountRow> = response.take(0).unwrap_or_default();
    Ok(rows.first().map(|row| row.count).unwrap_or(0) > 0)
}
|
||||
|
||||
/// Decides whether the live namespace can be reused as-is instead of
/// reseeding the corpus from the ingestion cache.
///
/// Reuse requires, in order:
/// 1. a recorded namespace state in the snapshot descriptor,
/// 2. a stored case count at least as large as the requested slice,
/// 3. matching dataset/slice ids, ingestion fingerprint, and namespace/
///    database names,
/// 4. non-empty corpus tables in the live namespace.
/// Each rejected condition logs its reason and returns `Ok(false)`.
pub(crate) async fn can_reuse_namespace(
    db: &SurrealDbClient,
    descriptor: &snapshot::Descriptor,
    namespace: &str,
    database: &str,
    dataset_id: &str,
    slice_id: &str,
    ingestion_fingerprint: &str,
    slice_case_count: usize,
) -> Result<bool> {
    let state = match descriptor.load_db_state().await? {
        Some(state) => state,
        None => {
            info!("No namespace state recorded; reseeding corpus from cached shards");
            return Ok(false);
        }
    };

    if state.slice_case_count < slice_case_count {
        info!(
            requested_cases = slice_case_count,
            stored_cases = state.slice_case_count,
            "Skipping live namespace reuse; ledger grew beyond cached state"
        );
        return Ok(false);
    }

    if state.dataset_id != dataset_id
        || state.slice_id != slice_id
        || state.ingestion_fingerprint != ingestion_fingerprint
        || state.namespace.as_deref() != Some(namespace)
        || state.database.as_deref() != Some(database)
    {
        info!(
            namespace,
            database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
        );
        return Ok(false);
    }

    // Metadata agrees — verify the tables were actually populated.
    if namespace_has_corpus(db).await? {
        Ok(true)
    } else {
        info!(
            namespace,
            database,
            "Namespace metadata matches but tables are empty; reseeding from ingestion cache"
        );
        Ok(false)
    }
}
|
||||
|
||||
/// Normalises `input` into a safe identifier component: ASCII alphanumerics
/// are lower-cased, every other character becomes `_`, an empty result falls
/// back to `"x"`, and the output is capped at 64 characters.
fn sanitize_identifier(input: &str) -> String {
    let mut cleaned: String = input
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();
    if cleaned.is_empty() {
        cleaned.push('x');
    }
    // Output is pure ASCII at this point, so truncating at byte index 64 is
    // always on a char boundary. `truncate` is already a no-op when the
    // string is shorter, so the previous `if cleaned.len() > 64` guard was
    // redundant.
    cleaned.truncate(64);
    cleaned
}
|
||||
|
||||
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
|
||||
let dataset_component = sanitize_identifier(dataset_id);
|
||||
let limit_component = match limit {
|
||||
Some(value) if value > 0 => format!("limit{}", value),
|
||||
_ => "all".to_string(),
|
||||
};
|
||||
format!("eval_{}_{}", dataset_component, limit_component)
|
||||
}
|
||||
|
||||
/// Database name shared by every retrieval-evaluation namespace.
pub(crate) fn default_database() -> String {
    String::from("retrieval_eval")
}
|
||||
|
||||
/// Persists the current namespace/corpus identity into the snapshot
/// descriptor so a later run can decide reuse via `can_reuse_namespace`.
///
/// Best-effort: a storage failure is logged as a warning and otherwise
/// ignored (the worst case is an unnecessary reseed next run).
pub(crate) async fn record_namespace_state(
    descriptor: &snapshot::Descriptor,
    dataset_id: &str,
    slice_id: &str,
    ingestion_fingerprint: &str,
    namespace: &str,
    database: &str,
    slice_case_count: usize,
) {
    let state = DbSnapshotState {
        dataset_id: dataset_id.to_string(),
        slice_id: slice_id.to_string(),
        ingestion_fingerprint: ingestion_fingerprint.to_string(),
        snapshot_hash: descriptor.metadata_hash().to_string(),
        updated_at: Utc::now(),
        namespace: Some(namespace.to_string()),
        database: Some(database.to_string()),
        slice_case_count,
    };
    if let Err(err) = descriptor.store_db_state(&state).await {
        warn!(error = %err, "Failed to record namespace state");
    }
}
|
||||
|
||||
/// Reconciles stored system settings with this run's embedding provider and
/// CLI overrides.
///
/// - Aligns `embedding_dimensions` with the provider's reported dimension.
/// - Applies a `--query-model` override when it differs from the stored one.
/// - Persists the settings only when something actually changed, and
///   redefines the HNSW indexes only when the embedding dimension changed.
/// Returns the (possibly updated) settings.
pub(crate) async fn enforce_system_settings(
    db: &SurrealDbClient,
    mut settings: SystemSettings,
    provider_dimension: usize,
    config: &Config,
) -> Result<SystemSettings> {
    let mut updated_settings = settings.clone();
    let mut needs_settings_update = false;
    let mut embedding_dimension_changed = false;

    if provider_dimension != settings.embedding_dimensions as usize {
        updated_settings.embedding_dimensions = provider_dimension as u32;
        needs_settings_update = true;
        embedding_dimension_changed = true;
    }
    if let Some(query_override) = config.query_model.as_deref() {
        if settings.query_model != query_override {
            info!(
                model = query_override,
                "Overriding system query model for this run"
            );
            updated_settings.query_model = query_override.to_string();
            needs_settings_update = true;
        }
    }
    if needs_settings_update {
        settings = SystemSettings::update(db, updated_settings)
            .await
            .context("updating system settings overrides")?;
    }
    if embedding_dimension_changed {
        // Index vector length must match the stored embedding dimension.
        change_embedding_length_in_hnsw_indexes(db, provider_dimension)
            .await
            .context("redefining HNSW indexes for new embedding dimension")?;
    }
    Ok(settings)
}
|
||||
|
||||
/// Loads the current system settings, bootstrapping a fresh namespace when
/// they are missing.
///
/// On `AppError::NotFound`, applies the database migrations (which create
/// default settings), waits briefly for them to land, and retries. Returns
/// the settings plus a flag that is `true` when migrations were applied.
pub(crate) async fn load_or_init_system_settings(
    db: &SurrealDbClient,
) -> Result<(SystemSettings, bool)> {
    match SystemSettings::get_current(db).await {
        Ok(settings) => Ok((settings, false)),
        Err(AppError::NotFound(_)) => {
            info!("System settings missing; applying database migrations for namespace");
            db.apply_migrations()
                .await
                .context("applying database migrations after missing system settings")?;
            // Short grace period so the migration writes are visible before
            // the retry below.
            tokio::time::sleep(Duration::from_millis(50)).await;
            let settings = SystemSettings::get_current(db)
                .await
                .context("loading system settings after migrations")?;
            Ok((settings, true))
        }
        Err(err) => Err(err).context("loading system settings"),
    }
}
|
||||
193
eval/src/eval/pipeline/context.rs
Normal file
193
eval/src/eval/pipeline/context.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use std::{
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_openai::Client;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
};
|
||||
use composite_retrieval::{
|
||||
pipeline::{PipelineStageTimings, RetrievalConfig},
|
||||
reranking::RerankerPool,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
cache::EmbeddingCache,
|
||||
datasets::ConvertedDataset,
|
||||
embedding::EmbeddingProvider,
|
||||
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
|
||||
ingest, slice, snapshot,
|
||||
};
|
||||
|
||||
/// Mutable scratch space threaded through every evaluation pipeline stage.
///
/// Fields start unset (`None`/empty/zero) and are progressively filled in by
/// the stage functions; the typed accessors on the `impl` panic if read
/// before their producing stage ran, encoding stage ordering as an invariant.
pub(super) struct EvaluationContext<'a> {
    dataset: &'a ConvertedDataset,
    config: &'a Config,
    // Wall-clock accounting per pipeline stage.
    pub stage_timings: EvaluationStageTimings,
    // Slice selection (prepare-slice stage).
    pub ledger_limit: Option<usize>,
    pub slice_settings: Option<slice::SliceConfig<'a>>,
    pub slice: Option<slice::ResolvedSlice<'a>>,
    pub window_offset: usize,
    pub window_length: usize,
    pub window_total_cases: usize,
    // Database target and connection (prepare-db stage).
    pub namespace: String,
    pub database: String,
    pub db: Option<SurrealDbClient>,
    pub descriptor: Option<snapshot::Descriptor>,
    pub settings: Option<SystemSettings>,
    pub settings_missing: bool,
    pub must_reapply_settings: bool,
    // Embedding / LLM clients (prepare-db stage).
    pub embedding_provider: Option<EmbeddingProvider>,
    pub embedding_cache: Option<EmbeddingCache>,
    pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
    pub openai_base_url: Option<String>,
    // Corpus bookkeeping (prepare-corpus / prepare-namespace stages).
    pub expected_fingerprint: Option<String>,
    pub ingestion_duration_ms: u128,
    pub namespace_seed_ms: Option<u128>,
    pub namespace_reused: bool,
    pub evaluation_start: Option<Instant>,
    pub eval_user: Option<User>,
    pub corpus_handle: Option<ingest::CorpusHandle>,
    // Query execution state and results (run-queries stage).
    pub cases: Vec<SeededCase>,
    pub stage_latency_samples: Vec<PipelineStageTimings>,
    pub latencies: Vec<u128>,
    pub diagnostics_output: Vec<CaseDiagnostics>,
    pub query_summaries: Vec<CaseSummary>,
    pub rerank_pool: Option<Arc<RerankerPool>>,
    pub retrieval_config: Option<Arc<RetrievalConfig>>,
    // Final outputs (summarize / finalize stages).
    pub summary: Option<EvaluationSummary>,
    pub diagnostics_path: Option<PathBuf>,
    pub diagnostics_enabled: bool,
}
|
||||
|
||||
impl<'a> EvaluationContext<'a> {
|
||||
pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
|
||||
Self {
|
||||
dataset,
|
||||
config,
|
||||
stage_timings: EvaluationStageTimings::default(),
|
||||
ledger_limit: None,
|
||||
slice_settings: None,
|
||||
slice: None,
|
||||
window_offset: 0,
|
||||
window_length: 0,
|
||||
window_total_cases: 0,
|
||||
namespace: String::new(),
|
||||
database: String::new(),
|
||||
db: None,
|
||||
descriptor: None,
|
||||
settings: None,
|
||||
settings_missing: false,
|
||||
must_reapply_settings: false,
|
||||
embedding_provider: None,
|
||||
embedding_cache: None,
|
||||
openai_client: None,
|
||||
openai_base_url: None,
|
||||
expected_fingerprint: None,
|
||||
ingestion_duration_ms: 0,
|
||||
namespace_seed_ms: None,
|
||||
namespace_reused: false,
|
||||
evaluation_start: None,
|
||||
eval_user: None,
|
||||
corpus_handle: None,
|
||||
cases: Vec::new(),
|
||||
stage_latency_samples: Vec::new(),
|
||||
latencies: Vec::new(),
|
||||
diagnostics_output: Vec::new(),
|
||||
query_summaries: Vec::new(),
|
||||
rerank_pool: None,
|
||||
retrieval_config: None,
|
||||
summary: None,
|
||||
diagnostics_path: config.chunk_diagnostics_path.clone(),
|
||||
diagnostics_enabled: config.chunk_diagnostics_path.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dataset(&self) -> &'a ConvertedDataset {
|
||||
self.dataset
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &'a Config {
|
||||
self.config
|
||||
}
|
||||
|
||||
pub fn slice(&self) -> &slice::ResolvedSlice<'a> {
|
||||
self.slice.as_ref().expect("slice has not been prepared")
|
||||
}
|
||||
|
||||
pub fn db(&self) -> &SurrealDbClient {
|
||||
self.db.as_ref().expect("database connection missing")
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> &snapshot::Descriptor {
|
||||
self.descriptor
|
||||
.as_ref()
|
||||
.expect("snapshot descriptor unavailable")
|
||||
}
|
||||
|
||||
pub fn embedding_provider(&self) -> &EmbeddingProvider {
|
||||
self.embedding_provider
|
||||
.as_ref()
|
||||
.expect("embedding provider not initialised")
|
||||
}
|
||||
|
||||
pub fn openai_client(&self) -> Arc<Client<async_openai::config::OpenAIConfig>> {
|
||||
self.openai_client
|
||||
.as_ref()
|
||||
.expect("openai client missing")
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn corpus_handle(&self) -> &ingest::CorpusHandle {
|
||||
self.corpus_handle.as_ref().expect("corpus handle missing")
|
||||
}
|
||||
|
||||
pub fn evaluation_user(&self) -> &User {
|
||||
self.eval_user.as_ref().expect("evaluation user missing")
|
||||
}
|
||||
|
||||
pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) {
|
||||
let elapsed = duration.as_millis() as u128;
|
||||
match stage {
|
||||
EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed,
|
||||
EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed,
|
||||
EvalStage::PrepareCorpus => self.stage_timings.prepare_corpus_ms += elapsed,
|
||||
EvalStage::PrepareNamespace => self.stage_timings.prepare_namespace_ms += elapsed,
|
||||
EvalStage::RunQueries => self.stage_timings.run_queries_ms += elapsed,
|
||||
EvalStage::Summarize => self.stage_timings.summarize_ms += elapsed,
|
||||
EvalStage::Finalize => self.stage_timings.finalize_ms += elapsed,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_summary(self) -> EvaluationSummary {
|
||||
self.summary.expect("evaluation summary missing")
|
||||
}
|
||||
}
|
||||
|
||||
/// The seven evaluation pipeline stages, in execution order; used for stage
/// timing accounting and log labels.
#[derive(Copy, Clone)]
pub(super) enum EvalStage {
    PrepareSlice,
    PrepareDb,
    PrepareCorpus,
    PrepareNamespace,
    RunQueries,
    Summarize,
    Finalize,
}
|
||||
|
||||
impl EvalStage {
    /// Stable kebab-case label for logs and reports.
    pub fn label(&self) -> &'static str {
        match self {
            EvalStage::PrepareSlice => "prepare-slice",
            EvalStage::PrepareDb => "prepare-db",
            EvalStage::PrepareCorpus => "prepare-corpus",
            EvalStage::PrepareNamespace => "prepare-namespace",
            EvalStage::RunQueries => "run-queries",
            EvalStage::Summarize => "summarize",
            EvalStage::Finalize => "finalize",
        }
    }
}
|
||||
29
eval/src/eval/pipeline/mod.rs
Normal file
29
eval/src/eval/pipeline/mod.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
mod context;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::{args::Config, datasets::ConvertedDataset, eval::EvaluationSummary};
|
||||
|
||||
use context::EvaluationContext;
|
||||
|
||||
/// Runs the full evaluation pipeline: each stage advances the typed state
/// machine (so stages cannot run out of order) while mutating the shared
/// context, and the finished summary is extracted from the context at the
/// end.
pub async fn run_evaluation(
    dataset: &ConvertedDataset,
    config: &Config,
) -> Result<EvaluationSummary> {
    let mut ctx = EvaluationContext::new(dataset, config);
    let machine = state::ready();

    let machine = stages::prepare_slice(machine, &mut ctx).await?;
    let machine = stages::prepare_db(machine, &mut ctx).await?;
    let machine = stages::prepare_corpus(machine, &mut ctx).await?;
    let machine = stages::prepare_namespace(machine, &mut ctx).await?;
    let machine = stages::run_queries(machine, &mut ctx).await?;
    let machine = stages::summarize(machine, &mut ctx).await?;
    let machine = stages::finalize(machine, &mut ctx).await?;

    // The machine has served its purpose; dispose of it explicitly before
    // consuming the context.
    drop(machine);

    Ok(ctx.into_summary())
}
|
||||
59
eval/src/eval/pipeline/stages/finalize.rs
Normal file
59
eval/src/eval/pipeline/stages/finalize.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::eval::write_chunk_diagnostics;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{Completed, EvaluationMachine, Summarized},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn finalize(
|
||||
machine: EvaluationMachine<(), Summarized>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<Completed> {
|
||||
let stage = EvalStage::Finalize;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
if let Some(cache) = ctx.embedding_cache.as_ref() {
|
||||
cache
|
||||
.persist()
|
||||
.await
|
||||
.context("persisting embedding cache")?;
|
||||
}
|
||||
|
||||
if let Some(path) = ctx.diagnostics_path.as_ref() {
|
||||
if ctx.diagnostics_enabled {
|
||||
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
|
||||
.await
|
||||
.with_context(|| format!("writing chunk diagnostics to {}", path.display()))?;
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
total_cases = ctx.summary.as_ref().map(|s| s.total_cases).unwrap_or(0),
|
||||
correct = ctx.summary.as_ref().map(|s| s.correct).unwrap_or(0),
|
||||
precision = ctx.summary.as_ref().map(|s| s.precision).unwrap_or(0.0),
|
||||
dataset = ctx.dataset().metadata.id.as_str(),
|
||||
"Evaluation complete"
|
||||
);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.finalize()
|
||||
.map_err(|(_, guard)| map_guard_error("finalize", guard))
|
||||
}
|
||||
26
eval/src/eval/pipeline/stages/mod.rs
Normal file
26
eval/src/eval/pipeline/stages/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
mod finalize;
|
||||
mod prepare_corpus;
|
||||
mod prepare_db;
|
||||
mod prepare_namespace;
|
||||
mod prepare_slice;
|
||||
mod run_queries;
|
||||
mod summarize;
|
||||
|
||||
pub(crate) use finalize::finalize;
|
||||
pub(crate) use prepare_corpus::prepare_corpus;
|
||||
pub(crate) use prepare_db::prepare_db;
|
||||
pub(crate) use prepare_namespace::prepare_namespace;
|
||||
pub(crate) use prepare_slice::prepare_slice;
|
||||
pub(crate) use run_queries::run_queries;
|
||||
pub(crate) use summarize::summarize;
|
||||
|
||||
use anyhow::Result;
|
||||
use state_machines::core::GuardError;
|
||||
|
||||
use super::state::EvaluationMachine;
|
||||
|
||||
/// Wraps a state-machine guard rejection into an `anyhow` error naming the
/// attempted transition, so mis-ordered stages fail with a clear message.
fn map_guard_error(event: &str, guard: GuardError) -> anyhow::Error {
    anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
}

// Every stage returns the machine advanced into its target state `S`.
type StageResult<S> = Result<EvaluationMachine<(), S>>;
|
||||
84
eval/src/eval/pipeline/stages/prepare_corpus.rs
Normal file
84
eval/src/eval/pipeline/stages/prepare_corpus.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{ingest, slice, snapshot};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{CorpusReady, DbReady, EvaluationMachine},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn prepare_corpus(
|
||||
machine: EvaluationMachine<(), DbReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<CorpusReady> {
|
||||
let stage = EvalStage::PrepareCorpus;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let config = ctx.config();
|
||||
let cache_settings = ingest::CorpusCacheConfig::from(config);
|
||||
let embedding_provider = ctx.embedding_provider().clone();
|
||||
let openai_client = ctx.openai_client();
|
||||
|
||||
let eval_user_id = "eval-user".to_string();
|
||||
let ingestion_timer = Instant::now();
|
||||
let corpus_handle = {
|
||||
let slice = ctx.slice();
|
||||
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
||||
.context("selecting slice window for corpus preparation")?;
|
||||
ingest::ensure_corpus(
|
||||
ctx.dataset(),
|
||||
slice,
|
||||
&window,
|
||||
&cache_settings,
|
||||
&embedding_provider,
|
||||
openai_client,
|
||||
&eval_user_id,
|
||||
config.converted_dataset_path.as_path(),
|
||||
)
|
||||
.await
|
||||
.context("ensuring ingestion-backed corpus")?
|
||||
};
|
||||
let expected_fingerprint = corpus_handle
|
||||
.manifest
|
||||
.metadata
|
||||
.ingestion_fingerprint
|
||||
.clone();
|
||||
let ingestion_duration_ms = ingestion_timer.elapsed().as_millis() as u128;
|
||||
info!(
|
||||
cache = %corpus_handle.path.display(),
|
||||
reused_ingestion = corpus_handle.reused_ingestion,
|
||||
reused_embeddings = corpus_handle.reused_embeddings,
|
||||
positive_ingested = corpus_handle.positive_ingested,
|
||||
negative_ingested = corpus_handle.negative_ingested,
|
||||
"Ingestion corpus ready"
|
||||
);
|
||||
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = ingestion_duration_ms;
|
||||
ctx.descriptor = Some(snapshot::Descriptor::new(
|
||||
config,
|
||||
ctx.slice(),
|
||||
ctx.embedding_provider(),
|
||||
));
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_corpus()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard))
|
||||
}
|
||||
111
eval/src/eval/pipeline/stages/prepare_db.rs
Normal file
111
eval/src/eval/pipeline/stages/prepare_db.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
args::EmbeddingBackend,
|
||||
cache::EmbeddingCache,
|
||||
embedding,
|
||||
eval::{
|
||||
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
|
||||
},
|
||||
openai,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{DbReady, EvaluationMachine, SlicePrepared},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn prepare_db(
|
||||
machine: EvaluationMachine<(), SlicePrepared>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<DbReady> {
|
||||
let stage = EvalStage::PrepareDb;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let namespace = ctx.namespace.clone();
|
||||
let database = ctx.database.clone();
|
||||
let config = ctx.config();
|
||||
|
||||
let db = connect_eval_db(config, &namespace, &database).await?;
|
||||
let (mut settings, settings_missing) = load_or_init_system_settings(&db).await?;
|
||||
|
||||
let embedding_provider =
|
||||
embedding::build_provider(config, settings.embedding_dimensions as usize)
|
||||
.await
|
||||
.context("building embedding provider")?;
|
||||
let (raw_openai_client, openai_base_url) =
|
||||
openai::build_client_from_env().context("building OpenAI client")?;
|
||||
let openai_client = Arc::new(raw_openai_client);
|
||||
let provider_dimension = embedding_provider.dimension();
|
||||
if provider_dimension == 0 {
|
||||
return Err(anyhow!(
|
||||
"embedding provider reported zero dimensions; cannot continue"
|
||||
));
|
||||
}
|
||||
|
||||
info!(
|
||||
backend = embedding_provider.backend_label(),
|
||||
model = embedding_provider
|
||||
.model_code()
|
||||
.as_deref()
|
||||
.unwrap_or("<none>"),
|
||||
dimension = provider_dimension,
|
||||
"Embedding provider initialised"
|
||||
);
|
||||
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
|
||||
|
||||
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
||||
if let Some(model_code) = embedding_provider.model_code() {
|
||||
let sanitized = sanitize_model_code(&model_code);
|
||||
let path = config.cache_dir.join(format!("{sanitized}.json"));
|
||||
if config.force_convert && path.exists() {
|
||||
tokio::fs::remove_file(&path)
|
||||
.await
|
||||
.with_context(|| format!("removing stale cache {}", path.display()))
|
||||
.ok();
|
||||
}
|
||||
let cache = EmbeddingCache::load(&path).await?;
|
||||
info!(path = %path.display(), "Embedding cache ready");
|
||||
Some(cache)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let must_reapply_settings = settings_missing;
|
||||
let defer_initial_enforce = settings_missing && !config.reseed_slice;
|
||||
if !defer_initial_enforce {
|
||||
settings = enforce_system_settings(&db, settings, provider_dimension, config).await?;
|
||||
}
|
||||
|
||||
ctx.db = Some(db);
|
||||
ctx.settings_missing = settings_missing;
|
||||
ctx.must_reapply_settings = must_reapply_settings;
|
||||
ctx.settings = Some(settings);
|
||||
ctx.embedding_provider = Some(embedding_provider);
|
||||
ctx.embedding_cache = embedding_cache;
|
||||
ctx.openai_client = Some(openai_client);
|
||||
ctx.openai_base_url = Some(openai_base_url);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_db()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_db", guard))
|
||||
}
|
||||
159
eval/src/eval/pipeline/stages/prepare_namespace.rs
Normal file
159
eval/src/eval/pipeline/stages/prepare_namespace.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use common::storage::types::system_settings::SystemSettings;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
|
||||
eval::{
|
||||
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
|
||||
record_namespace_state, warm_hnsw_cache,
|
||||
},
|
||||
ingest,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{CorpusReady, EvaluationMachine, NamespaceReady},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn prepare_namespace(
|
||||
machine: EvaluationMachine<(), CorpusReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<NamespaceReady> {
|
||||
let stage = EvalStage::PrepareNamespace;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let config = ctx.config();
|
||||
let dataset = ctx.dataset();
|
||||
let expected_fingerprint = ctx
|
||||
.expected_fingerprint
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let namespace = ctx.namespace.clone();
|
||||
let database = ctx.database.clone();
|
||||
let embedding_provider = ctx.embedding_provider().clone();
|
||||
|
||||
let mut namespace_reused = false;
|
||||
if !config.reseed_slice {
|
||||
namespace_reused = {
|
||||
let slice = ctx.slice();
|
||||
can_reuse_namespace(
|
||||
ctx.db(),
|
||||
ctx.descriptor(),
|
||||
&namespace,
|
||||
&database,
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
slice.manifest.case_count,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
}
|
||||
|
||||
let mut namespace_seed_ms = None;
|
||||
if !namespace_reused {
|
||||
ctx.must_reapply_settings = true;
|
||||
if let Err(err) = reset_namespace(ctx.db(), &namespace, &database).await {
|
||||
warn!(
|
||||
error = %err,
|
||||
namespace,
|
||||
database = %database,
|
||||
"Failed to reset namespace before reseeding; continuing with existing data"
|
||||
);
|
||||
} else if let Err(err) = ctx.db().apply_migrations().await {
|
||||
warn!(error = %err, "Failed to reapply migrations after namespace reset");
|
||||
}
|
||||
|
||||
{
|
||||
let slice = ctx.slice();
|
||||
info!(
|
||||
slice = slice.manifest.slice_id.as_str(),
|
||||
window_offset = ctx.window_offset,
|
||||
window_length = ctx.window_length,
|
||||
positives = slice.manifest.positive_paragraphs,
|
||||
negatives = slice.manifest.negative_paragraphs,
|
||||
total = slice.manifest.total_paragraphs,
|
||||
"Seeding ingestion corpus into SurrealDB"
|
||||
);
|
||||
}
|
||||
let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
|
||||
let seed_start = Instant::now();
|
||||
ingest::seed_manifest_into_db(ctx.db(), &ctx.corpus_handle().manifest)
|
||||
.await
|
||||
.context("seeding ingestion corpus from manifest")?;
|
||||
namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
|
||||
if indexes_disabled {
|
||||
info!("Recreating indexes after namespace reset");
|
||||
if let Err(err) = recreate_indexes(ctx.db(), embedding_provider.dimension()).await {
|
||||
warn!(error = %err, "failed to restore indexes after namespace reset");
|
||||
} else {
|
||||
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
|
||||
}
|
||||
}
|
||||
{
|
||||
let slice = ctx.slice();
|
||||
record_namespace_state(
|
||||
ctx.descriptor(),
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
&namespace,
|
||||
&database,
|
||||
slice.manifest.case_count,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.must_reapply_settings {
|
||||
let mut settings = SystemSettings::get_current(ctx.db())
|
||||
.await
|
||||
.context("reloading system settings after namespace reset")?;
|
||||
settings =
|
||||
enforce_system_settings(ctx.db(), settings, embedding_provider.dimension(), config)
|
||||
.await?;
|
||||
ctx.settings = Some(settings);
|
||||
ctx.must_reapply_settings = false;
|
||||
}
|
||||
|
||||
let user = ensure_eval_user(ctx.db()).await?;
|
||||
ctx.eval_user = Some(user);
|
||||
|
||||
let cases = cases_from_manifest(&ctx.corpus_handle().manifest);
|
||||
if cases.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no answerable questions found in converted dataset for evaluation"
|
||||
));
|
||||
}
|
||||
ctx.cases = cases;
|
||||
ctx.namespace_reused = namespace_reused;
|
||||
ctx.namespace_seed_ms = namespace_seed_ms;
|
||||
|
||||
info!(
|
||||
cases = ctx.cases.len(),
|
||||
window_offset = ctx.window_offset,
|
||||
namespace_reused = namespace_reused,
|
||||
"Dataset ready"
|
||||
);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_namespace()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_namespace", guard))
|
||||
}
|
||||
66
eval/src/eval/pipeline/stages/prepare_slice.rs
Normal file
66
eval/src/eval/pipeline/stages/prepare_slice.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
eval::{default_database, default_namespace, ledger_target},
|
||||
slice,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, Ready, SlicePrepared},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn prepare_slice(
|
||||
machine: EvaluationMachine<(), Ready>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<SlicePrepared> {
|
||||
let stage = EvalStage::PrepareSlice;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let ledger_limit = ledger_target(ctx.config());
|
||||
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
|
||||
let resolved_slice =
|
||||
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
|
||||
let window = slice::select_window(
|
||||
&resolved_slice,
|
||||
ctx.config().slice_offset,
|
||||
ctx.config().limit,
|
||||
)
|
||||
.context("selecting slice window (use --slice-grow to extend the ledger first)")?;
|
||||
|
||||
ctx.ledger_limit = ledger_limit;
|
||||
ctx.slice_settings = Some(slice_settings);
|
||||
ctx.slice = Some(resolved_slice.clone());
|
||||
ctx.window_offset = window.offset;
|
||||
ctx.window_length = window.length;
|
||||
ctx.window_total_cases = window.total_cases;
|
||||
|
||||
ctx.namespace = ctx.config().db_namespace.clone().unwrap_or_else(|| {
|
||||
default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
|
||||
});
|
||||
ctx.database = ctx
|
||||
.config()
|
||||
.db_database
|
||||
.clone()
|
||||
.unwrap_or_else(default_database);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_slice()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_slice", guard))
|
||||
}
|
||||
337
eval/src/eval/pipeline/stages/run_queries.rs
Normal file
337
eval/src/eval/pipeline/stages/run_queries.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
use std::{collections::HashSet, sync::Arc, time::Instant};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::stream::{self, StreamExt, TryStreamExt};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::eval::{
|
||||
apply_dataset_tuning_overrides, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
CaseSummary, RetrievedSummary,
|
||||
};
|
||||
use composite_retrieval::pipeline::{self, PipelineStageTimings, RetrievalConfig};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, NamespaceReady, QueriesFinished},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn run_queries(
|
||||
machine: EvaluationMachine<(), NamespaceReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<QueriesFinished> {
|
||||
let stage = EvalStage::RunQueries;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let config = ctx.config();
|
||||
let dataset = ctx.dataset();
|
||||
let slice_settings = ctx
|
||||
.slice_settings
|
||||
.as_ref()
|
||||
.expect("slice settings missing during query stage");
|
||||
let total_cases = ctx.cases.len();
|
||||
let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate();
|
||||
|
||||
let rerank_pool = if config.rerank {
|
||||
Some(RerankerPool::new(config.rerank_pool_size).context("initialising reranker pool")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.tuning.rerank_keep_top = config.rerank_keep_top;
|
||||
if retrieval_config.tuning.fallback_min_results < config.rerank_keep_top {
|
||||
retrieval_config.tuning.fallback_min_results = config.rerank_keep_top;
|
||||
}
|
||||
if let Some(value) = config.chunk_vector_take {
|
||||
retrieval_config.tuning.chunk_vector_take = value;
|
||||
}
|
||||
if let Some(value) = config.chunk_fts_take {
|
||||
retrieval_config.tuning.chunk_fts_take = value;
|
||||
}
|
||||
if let Some(value) = config.chunk_token_budget {
|
||||
retrieval_config.tuning.token_budget_estimate = value;
|
||||
}
|
||||
if let Some(value) = config.chunk_avg_chars_per_token {
|
||||
retrieval_config.tuning.avg_chars_per_token = value;
|
||||
}
|
||||
if let Some(value) = config.max_chunks_per_entity {
|
||||
retrieval_config.tuning.max_chunks_per_entity = value;
|
||||
}
|
||||
|
||||
apply_dataset_tuning_overrides(dataset, config, &mut retrieval_config.tuning);
|
||||
|
||||
let active_tuning = retrieval_config.tuning.clone();
|
||||
let effective_chunk_vector = config
|
||||
.chunk_vector_take
|
||||
.unwrap_or(active_tuning.chunk_vector_take);
|
||||
let effective_chunk_fts = config
|
||||
.chunk_fts_take
|
||||
.unwrap_or(active_tuning.chunk_fts_take);
|
||||
|
||||
info!(
|
||||
dataset = dataset.metadata.id.as_str(),
|
||||
slice_seed = config.slice_seed,
|
||||
slice_offset = config.slice_offset,
|
||||
slice_limit = config
|
||||
.limit
|
||||
.unwrap_or(ctx.window_total_cases),
|
||||
negative_multiplier = %slice_settings.negative_multiplier,
|
||||
rerank_enabled = config.rerank,
|
||||
rerank_pool_size = config.rerank_pool_size,
|
||||
rerank_keep_top = config.rerank_keep_top,
|
||||
chunk_min = config.chunk_min_chars,
|
||||
chunk_max = config.chunk_max_chars,
|
||||
chunk_vector_take = effective_chunk_vector,
|
||||
chunk_fts_take = effective_chunk_fts,
|
||||
chunk_token_budget = active_tuning.token_budget_estimate,
|
||||
embedding_backend = ctx.embedding_provider().backend_label(),
|
||||
embedding_model = ctx
|
||||
.embedding_provider()
|
||||
.model_code()
|
||||
.as_deref()
|
||||
.unwrap_or("<default>"),
|
||||
"Starting evaluation run"
|
||||
);
|
||||
|
||||
let retrieval_config = Arc::new(retrieval_config);
|
||||
ctx.rerank_pool = rerank_pool.clone();
|
||||
ctx.retrieval_config = Some(retrieval_config.clone());
|
||||
|
||||
ctx.evaluation_start = Some(Instant::now());
|
||||
let user_id = ctx.evaluation_user().id.clone();
|
||||
let concurrency = config.concurrency.max(1);
|
||||
let diagnostics_enabled = ctx.diagnostics_enabled;
|
||||
|
||||
let query_semaphore = Arc::new(Semaphore::new(concurrency));
|
||||
|
||||
info!(
|
||||
total_cases = total_cases,
|
||||
max_concurrent_queries = concurrency,
|
||||
"Starting evaluation with staged query execution"
|
||||
);
|
||||
|
||||
let embedding_provider_for_queries = ctx.embedding_provider().clone();
|
||||
let rerank_pool_for_queries = rerank_pool.clone();
|
||||
let db = ctx.db().clone();
|
||||
let openai_client = ctx.openai_client();
|
||||
|
||||
let results: Vec<(
|
||||
usize,
|
||||
CaseSummary,
|
||||
Option<CaseDiagnostics>,
|
||||
PipelineStageTimings,
|
||||
)> = stream::iter(cases_iter)
|
||||
.map(move |(idx, case)| {
|
||||
let db = db.clone();
|
||||
let openai_client = openai_client.clone();
|
||||
let user_id = user_id.clone();
|
||||
let retrieval_config = retrieval_config.clone();
|
||||
let embedding_provider = embedding_provider_for_queries.clone();
|
||||
let rerank_pool = rerank_pool_for_queries.clone();
|
||||
let semaphore = query_semaphore.clone();
|
||||
let diagnostics_enabled = diagnostics_enabled;
|
||||
|
||||
async move {
|
||||
let _permit = semaphore
|
||||
.acquire()
|
||||
.await
|
||||
.context("acquiring query semaphore permit")?;
|
||||
|
||||
let crate::eval::SeededCase {
|
||||
question_id,
|
||||
question,
|
||||
expected_source,
|
||||
answers,
|
||||
paragraph_id,
|
||||
paragraph_title,
|
||||
expected_chunk_ids,
|
||||
} = case;
|
||||
let query_start = Instant::now();
|
||||
|
||||
debug!(question_id = %question_id, "Evaluating query");
|
||||
let query_embedding =
|
||||
embedding_provider.embed(&question).await.with_context(|| {
|
||||
format!("generating embedding for question {}", question_id)
|
||||
})?;
|
||||
let reranker = match &rerank_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let (results, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
|
||||
&db,
|
||||
&openai_client,
|
||||
query_embedding,
|
||||
&question,
|
||||
&user_id,
|
||||
(*retrieval_config).clone(),
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("running pipeline for question {}", question_id))?;
|
||||
(outcome.results, outcome.diagnostics, outcome.stage_timings)
|
||||
} else {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
|
||||
&db,
|
||||
&openai_client,
|
||||
query_embedding,
|
||||
&question,
|
||||
&user_id,
|
||||
(*retrieval_config).clone(),
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("running pipeline for question {}", question_id))?;
|
||||
(outcome.results, None, outcome.stage_timings)
|
||||
};
|
||||
let query_latency = query_start.elapsed().as_millis() as u128;
|
||||
|
||||
let mut retrieved = Vec::new();
|
||||
let mut match_rank = None;
|
||||
let answers_lower: Vec<String> =
|
||||
answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
|
||||
let expected_chunk_ids_set: HashSet<&str> =
|
||||
expected_chunk_ids.iter().map(|id| id.as_str()).collect();
|
||||
let chunk_id_required = !expected_chunk_ids_set.is_empty();
|
||||
let mut entity_hit = false;
|
||||
let mut chunk_text_hit = false;
|
||||
let mut chunk_id_hit = !chunk_id_required;
|
||||
|
||||
for (idx_entity, entity) in results.iter().enumerate() {
|
||||
if idx_entity >= config.k {
|
||||
break;
|
||||
}
|
||||
let entity_match = entity.entity.source_id == expected_source;
|
||||
if entity_match {
|
||||
entity_hit = true;
|
||||
}
|
||||
let chunk_text_for_entity = entity
|
||||
.chunks
|
||||
.iter()
|
||||
.any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower));
|
||||
if chunk_text_for_entity {
|
||||
chunk_text_hit = true;
|
||||
}
|
||||
let chunk_id_for_entity = if chunk_id_required {
|
||||
expected_chunk_ids_set.contains(entity.entity.source_id.as_str())
|
||||
|| entity.chunks.iter().any(|chunk| {
|
||||
expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
|
||||
})
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if chunk_id_for_entity {
|
||||
chunk_id_hit = true;
|
||||
}
|
||||
let success = entity_match && chunk_text_for_entity && chunk_id_for_entity;
|
||||
if success && match_rank.is_none() {
|
||||
match_rank = Some(idx_entity + 1);
|
||||
}
|
||||
let detail_fields = if config.detailed_report {
|
||||
(
|
||||
Some(entity.entity.description.clone()),
|
||||
Some(format!("{:?}", entity.entity.entity_type)),
|
||||
Some(chunk_text_for_entity),
|
||||
Some(chunk_id_for_entity),
|
||||
)
|
||||
} else {
|
||||
(None, None, None, None)
|
||||
};
|
||||
retrieved.push(RetrievedSummary {
|
||||
rank: idx_entity + 1,
|
||||
entity_id: entity.entity.id.clone(),
|
||||
source_id: entity.entity.source_id.clone(),
|
||||
entity_name: entity.entity.name.clone(),
|
||||
score: entity.score,
|
||||
matched: success,
|
||||
entity_description: detail_fields.0,
|
||||
entity_category: detail_fields.1,
|
||||
chunk_text_match: detail_fields.2,
|
||||
chunk_id_match: detail_fields.3,
|
||||
});
|
||||
}
|
||||
|
||||
let overall_match = match_rank.is_some();
|
||||
let summary = CaseSummary {
|
||||
question_id,
|
||||
question,
|
||||
paragraph_id,
|
||||
paragraph_title,
|
||||
expected_source,
|
||||
answers,
|
||||
matched: overall_match,
|
||||
entity_match: entity_hit,
|
||||
chunk_text_match: chunk_text_hit,
|
||||
chunk_id_match: chunk_id_hit,
|
||||
match_rank,
|
||||
latency_ms: query_latency,
|
||||
retrieved,
|
||||
};
|
||||
|
||||
let diagnostics = if diagnostics_enabled {
|
||||
Some(build_case_diagnostics(
|
||||
&summary,
|
||||
&expected_chunk_ids,
|
||||
&answers_lower,
|
||||
&results,
|
||||
pipeline_diagnostics,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok::<
|
||||
(
|
||||
usize,
|
||||
CaseSummary,
|
||||
Option<CaseDiagnostics>,
|
||||
PipelineStageTimings,
|
||||
),
|
||||
anyhow::Error,
|
||||
>((idx, summary, diagnostics, stage_timings))
|
||||
}
|
||||
})
|
||||
.buffer_unordered(concurrency)
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
let mut ordered = results;
|
||||
ordered.sort_by_key(|(idx, ..)| *idx);
|
||||
let mut summaries = Vec::with_capacity(ordered.len());
|
||||
let mut latencies = Vec::with_capacity(ordered.len());
|
||||
let mut diagnostics_output = Vec::new();
|
||||
let mut stage_latency_samples = Vec::with_capacity(ordered.len());
|
||||
for (_, summary, diagnostics, stage_timings) in ordered {
|
||||
latencies.push(summary.latency_ms);
|
||||
summaries.push(summary);
|
||||
if let Some(diag) = diagnostics {
|
||||
diagnostics_output.push(diag);
|
||||
}
|
||||
stage_latency_samples.push(stage_timings);
|
||||
}
|
||||
|
||||
ctx.query_summaries = summaries;
|
||||
ctx.latencies = latencies;
|
||||
ctx.diagnostics_output = diagnostics_output;
|
||||
ctx.stage_latency_samples = stage_latency_samples;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.run_queries()
|
||||
.map_err(|(_, guard)| map_guard_error("run_queries", guard))
|
||||
}
|
||||
173
eval/src/eval/pipeline/stages/summarize.rs
Normal file
173
eval/src/eval/pipeline/stages/summarize.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use chrono::Utc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::eval::{
|
||||
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, QueriesFinished, Summarized},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn summarize(
|
||||
machine: EvaluationMachine<(), QueriesFinished>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<Summarized> {
|
||||
let stage = EvalStage::Summarize;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
"starting evaluation stage"
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let summaries = std::mem::take(&mut ctx.query_summaries);
|
||||
let latencies = std::mem::take(&mut ctx.latencies);
|
||||
let stage_latency_samples = std::mem::take(&mut ctx.stage_latency_samples);
|
||||
let duration_ms = ctx
|
||||
.evaluation_start
|
||||
.take()
|
||||
.map(|start| start.elapsed().as_millis())
|
||||
.unwrap_or_default();
|
||||
let config = ctx.config();
|
||||
let dataset = ctx.dataset();
|
||||
let slice = ctx.slice();
|
||||
let corpus_handle = ctx.corpus_handle();
|
||||
let total_cases = summaries.len();
|
||||
|
||||
let mut correct = 0usize;
|
||||
let mut correct_at_1 = 0usize;
|
||||
let mut correct_at_2 = 0usize;
|
||||
let mut correct_at_3 = 0usize;
|
||||
for summary in &summaries {
|
||||
if summary.matched {
|
||||
correct += 1;
|
||||
if let Some(rank) = summary.match_rank {
|
||||
if rank <= 1 {
|
||||
correct_at_1 += 1;
|
||||
}
|
||||
if rank <= 2 {
|
||||
correct_at_2 += 1;
|
||||
}
|
||||
if rank <= 3 {
|
||||
correct_at_3 += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let latency_stats = compute_latency_stats(&latencies);
|
||||
let stage_latency = build_stage_latency_breakdown(&stage_latency_samples);
|
||||
|
||||
let precision = if total_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct as f64) / (total_cases as f64)
|
||||
};
|
||||
let precision_at_1 = if total_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_1 as f64) / (total_cases as f64)
|
||||
};
|
||||
let precision_at_2 = if total_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_2 as f64) / (total_cases as f64)
|
||||
};
|
||||
let precision_at_3 = if total_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_3 as f64) / (total_cases as f64)
|
||||
};
|
||||
|
||||
let active_tuning = ctx
|
||||
.retrieval_config
|
||||
.as_ref()
|
||||
.map(|cfg| cfg.tuning.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let perf_timings = PerformanceTimings {
|
||||
openai_base_url: ctx
|
||||
.openai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "<unknown>".to_string()),
|
||||
ingestion_ms: ctx.ingestion_duration_ms,
|
||||
namespace_seed_ms: ctx.namespace_seed_ms,
|
||||
evaluation_stage_ms: ctx.stage_timings.clone(),
|
||||
stage_latency,
|
||||
};
|
||||
|
||||
ctx.summary = Some(EvaluationSummary {
|
||||
generated_at: Utc::now(),
|
||||
k: config.k,
|
||||
limit: config.limit,
|
||||
run_label: config.label.clone(),
|
||||
total_cases,
|
||||
correct,
|
||||
precision,
|
||||
correct_at_1,
|
||||
correct_at_2,
|
||||
correct_at_3,
|
||||
precision_at_1,
|
||||
precision_at_2,
|
||||
precision_at_3,
|
||||
duration_ms,
|
||||
dataset_id: dataset.metadata.id.clone(),
|
||||
dataset_label: dataset.metadata.label.clone(),
|
||||
dataset_includes_unanswerable: dataset.metadata.include_unanswerable,
|
||||
dataset_source: dataset.source.clone(),
|
||||
slice_id: slice.manifest.slice_id.clone(),
|
||||
slice_seed: slice.manifest.seed,
|
||||
slice_total_cases: slice.manifest.case_count,
|
||||
slice_window_offset: ctx.window_offset,
|
||||
slice_window_length: ctx.window_length,
|
||||
slice_cases: total_cases,
|
||||
slice_positive_paragraphs: slice.manifest.positive_paragraphs,
|
||||
slice_negative_paragraphs: slice.manifest.negative_paragraphs,
|
||||
slice_total_paragraphs: slice.manifest.total_paragraphs,
|
||||
slice_negative_multiplier: slice.manifest.negative_multiplier,
|
||||
namespace_reused: ctx.namespace_reused,
|
||||
corpus_paragraphs: ctx.corpus_handle().manifest.metadata.paragraph_count,
|
||||
ingestion_cache_path: corpus_handle.path.display().to_string(),
|
||||
ingestion_reused: corpus_handle.reused_ingestion,
|
||||
ingestion_embeddings_reused: corpus_handle.reused_embeddings,
|
||||
ingestion_fingerprint: corpus_handle
|
||||
.manifest
|
||||
.metadata
|
||||
.ingestion_fingerprint
|
||||
.clone(),
|
||||
positive_paragraphs_reused: corpus_handle.positive_reused,
|
||||
negative_paragraphs_reused: corpus_handle.negative_reused,
|
||||
latency_ms: latency_stats,
|
||||
perf: perf_timings,
|
||||
embedding_backend: ctx.embedding_provider().backend_label().to_string(),
|
||||
embedding_model: ctx.embedding_provider().model_code(),
|
||||
embedding_dimension: ctx.embedding_provider().dimension(),
|
||||
rerank_enabled: config.rerank,
|
||||
rerank_pool_size: ctx.rerank_pool.as_ref().map(|_| config.rerank_pool_size),
|
||||
rerank_keep_top: config.rerank_keep_top,
|
||||
concurrency: config.concurrency.max(1),
|
||||
detailed_report: config.detailed_report,
|
||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||
chunk_token_budget: active_tuning.token_budget_estimate,
|
||||
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
|
||||
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
||||
cases: summaries,
|
||||
});
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.summarize()
|
||||
.map_err(|(_, guard)| map_guard_error("summarize", guard))
|
||||
}
|
||||
31
eval/src/eval/pipeline/state.rs
Normal file
31
eval/src/eval/pipeline/state.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: EvaluationMachine,
|
||||
state: EvaluationState,
|
||||
initial: Ready,
|
||||
states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
|
||||
events {
|
||||
prepare_slice { transition: { from: Ready, to: SlicePrepared } }
|
||||
prepare_db { transition: { from: SlicePrepared, to: DbReady } }
|
||||
prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
|
||||
prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
|
||||
run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
|
||||
summarize { transition: { from: QueriesFinished, to: Summarized } }
|
||||
finalize { transition: { from: Summarized, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: SlicePrepared, to: Failed }
|
||||
transition: { from: DbReady, to: Failed }
|
||||
transition: { from: CorpusReady, to: Failed }
|
||||
transition: { from: NamespaceReady, to: Failed }
|
||||
transition: { from: QueriesFinished, to: Failed }
|
||||
transition: { from: Summarized, to: Failed }
|
||||
transition: { from: Completed, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ready() -> EvaluationMachine<(), Ready> {
|
||||
EvaluationMachine::new(())
|
||||
}
|
||||
1000
eval/src/ingest.rs
Normal file
1000
eval/src/ingest.rs
Normal file
File diff suppressed because it is too large
Load Diff
182
eval/src/inspection.rs
Normal file
182
eval/src/inspection.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
|
||||
|
||||
use crate::{args::Config, eval::connect_eval_db, ingest, snapshot::DbSnapshotState};
|
||||
|
||||
/// Debug helper for a single evaluation question: prints the question, its
/// expected answers and `matching_chunk_ids`, cross-checks those IDs against
/// the ingestion manifest, and — when a DB snapshot state file is available —
/// verifies the chunks also exist in the live SurrealDB namespace.
///
/// Requires `--inspect-question` and `--inspect-manifest`; `--inspect-db-state`
/// is optional and falls back to the default snapshot path for the manifest's
/// dataset/slice. Missing state file or connect failure only skips the live
/// check (printed, not an error).
pub async fn inspect_question(config: &Config) -> Result<()> {
    // Both flags are mandatory in inspection mode; fail with a usage hint.
    let question_id = config
        .inspect_question
        .as_ref()
        .ok_or_else(|| anyhow!("--inspect-question is required for inspection mode"))?;
    let manifest_path = config
        .inspect_manifest
        .as_ref()
        .ok_or_else(|| anyhow!("--inspect-manifest must be provided for inspection mode"))?;

    let manifest = load_manifest(manifest_path)?;
    // chunk id -> (paragraph title, text snippet), for readable output below.
    let chunk_lookup = build_chunk_lookup(&manifest);

    let question = manifest
        .questions
        .iter()
        .find(|q| q.question_id == *question_id)
        .ok_or_else(|| {
            anyhow!(
                "question '{}' not found in manifest {}",
                question_id,
                manifest_path.display()
            )
        })?;

    println!("Question: {}", question.question_text);
    println!("Answers: {:?}", question.answers);
    println!(
        "matching_chunk_ids ({}):",
        question.matching_chunk_ids.len()
    );

    // Manifest-level check: every expected chunk id should resolve in the
    // ingestion manifest; collect the ones that do not.
    let mut missing_in_manifest = Vec::new();
    for chunk_id in &question.matching_chunk_ids {
        if let Some(entry) = chunk_lookup.get(chunk_id) {
            println!(
                " - {} (paragraph: {})\n snippet: {}",
                chunk_id, entry.paragraph_title, entry.snippet
            );
        } else {
            println!(" - {} (missing from manifest)", chunk_id);
            missing_in_manifest.push(chunk_id.clone());
        }
    }

    if missing_in_manifest.is_empty() {
        println!("All matching_chunk_ids are present in the ingestion manifest");
    } else {
        println!(
            "Missing chunk IDs in manifest {}: {:?}",
            manifest_path.display(),
            missing_in_manifest
        );
    }

    // DB-level check: locate the snapshot state file (explicit flag wins over
    // the default path derived from the manifest's dataset/slice).
    let db_state_path = config
        .inspect_db_state
        .clone()
        .unwrap_or_else(|| default_state_path(config, &manifest));
    if let Some(state) = load_db_state(&db_state_path)? {
        // Both namespace and database must be recorded to target a live DB.
        if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
            match connect_eval_db(config, ns, db_name).await {
                Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
                    MissingChunks::None => println!(
                        "All matching_chunk_ids exist in namespace '{}', database '{}'",
                        ns, db_name
                    ),
                    MissingChunks::Missing(list) => println!(
                        "Missing chunks in namespace '{}', database '{}': {:?}",
                        ns, db_name, list
                    ),
                },
                // Connection failure is reported but does not abort inspection.
                Err(err) => {
                    println!(
                        "Failed to connect to SurrealDB namespace '{}' / database '{}': {err}",
                        ns, db_name
                    );
                }
            }
        } else {
            println!(
                "State file {} is missing namespace/database fields; skipping live DB validation",
                db_state_path.display()
            );
        }
    } else {
        println!(
            "State file {} not found; skipping live DB validation",
            db_state_path.display()
        );
    }

    Ok(())
}
|
||||
|
||||
/// Lookup value for a chunk id in the ingestion manifest: where the chunk
/// came from and a short preview of its text (built by `build_chunk_lookup`).
struct ChunkEntry {
    // Title of the paragraph this chunk was split from.
    paragraph_title: String,
    // First 160 chars of the chunk text, newlines flattened to spaces.
    snippet: String,
}
|
||||
|
||||
fn load_manifest(path: &Path) -> Result<ingest::CorpusManifest> {
|
||||
let bytes =
|
||||
fs::read(path).with_context(|| format!("reading ingestion manifest {}", path.display()))?;
|
||||
serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("parsing ingestion manifest {}", path.display()))
|
||||
}
|
||||
|
||||
fn build_chunk_lookup(manifest: &ingest::CorpusManifest) -> HashMap<String, ChunkEntry> {
|
||||
let mut lookup = HashMap::new();
|
||||
for paragraph in &manifest.paragraphs {
|
||||
for chunk in ¶graph.chunks {
|
||||
let snippet = chunk
|
||||
.chunk
|
||||
.chars()
|
||||
.take(160)
|
||||
.collect::<String>()
|
||||
.replace('\n', " ");
|
||||
lookup.insert(
|
||||
chunk.id.clone(),
|
||||
ChunkEntry {
|
||||
paragraph_title: paragraph.title.clone(),
|
||||
snippet,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
lookup
|
||||
}
|
||||
|
||||
/// Default location of the cached DB snapshot state for this manifest:
/// `<cache_dir>/snapshots/<dataset_id>/<slice_id>/db/state.json`.
fn default_state_path(config: &Config, manifest: &ingest::CorpusManifest) -> PathBuf {
    let meta = &manifest.metadata;
    let mut path = config.cache_dir.join("snapshots");
    path.push(&meta.dataset_id);
    path.push(&meta.slice_id);
    path.push("db/state.json");
    path
}
|
||||
|
||||
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
|
||||
let state = serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("parsing db state {}", path.display()))?;
|
||||
Ok(Some(state))
|
||||
}
|
||||
|
||||
/// Outcome of checking a set of chunk ids against the live database.
enum MissingChunks {
    // Every requested chunk id was found.
    None,
    // The chunk ids that could not be found.
    Missing(Vec<String>),
}
|
||||
|
||||
/// Check that each id in `chunk_ids` resolves to a `TextChunk` record in the
/// given database, collecting the ids that do not.
///
/// Fetches are sequential, one round-trip per id; a fetch *error* aborts the
/// whole check, whereas a clean "not found" (`None`) just marks the id missing.
async fn verify_chunks_in_db(db: &SurrealDbClient, chunk_ids: &[String]) -> Result<MissingChunks> {
    let mut missing = Vec::new();
    for chunk_id in chunk_ids {
        let exists = db
            .get_item::<TextChunk>(chunk_id)
            .await
            .with_context(|| format!("fetching text_chunk {}", chunk_id))?
            .is_some();
        if !exists {
            missing.push(chunk_id.clone());
        }
    }
    if missing.is_empty() {
        Ok(MissingChunks::None)
    } else {
        Ok(MissingChunks::Missing(missing))
    }
}
|
||||
214
eval/src/main.rs
Normal file
214
eval/src/main.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
mod args;
|
||||
mod cache;
|
||||
mod datasets;
|
||||
mod db_helpers;
|
||||
mod embedding;
|
||||
mod eval;
|
||||
mod ingest;
|
||||
mod inspection;
|
||||
mod openai;
|
||||
mod perf;
|
||||
mod report;
|
||||
mod slice;
|
||||
mod slices;
|
||||
mod snapshot;
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::runtime::Builder;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
/// Configure SurrealDB environment variables for optimal performance
|
||||
fn configure_surrealdb_performance(cpu_count: usize) {
|
||||
// Set environment variables only if they're not already set
|
||||
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 2).to_string());
|
||||
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
|
||||
|
||||
let max_order_queue = std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 4).to_string());
|
||||
std::env::set_var(
|
||||
"SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
|
||||
max_order_queue,
|
||||
);
|
||||
|
||||
let websocket_concurrent = std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
|
||||
.unwrap_or_else(|_| cpu_count.to_string());
|
||||
std::env::set_var(
|
||||
"SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
|
||||
websocket_concurrent,
|
||||
);
|
||||
|
||||
let websocket_buffer = std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 8).to_string());
|
||||
std::env::set_var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE", websocket_buffer);
|
||||
|
||||
let transaction_cache = std::env::var("SURREAL_TRANSACTION_CACHE_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 16).to_string());
|
||||
std::env::set_var("SURREAL_TRANSACTION_CACHE_SIZE", transaction_cache);
|
||||
|
||||
info!(
|
||||
indexing_batch_size = %std::env::var("SURREAL_INDEXING_BATCH_SIZE").unwrap(),
|
||||
max_order_queue = %std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE").unwrap(),
|
||||
websocket_concurrent = %std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS").unwrap(),
|
||||
websocket_buffer = %std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE").unwrap(),
|
||||
transaction_cache = %std::env::var("SURREAL_TRANSACTION_CACHE_SIZE").unwrap(),
|
||||
"Configured SurrealDB performance variables"
|
||||
);
|
||||
}
|
||||
|
||||
/// Process entry point: builds an explicitly-configured multi-threaded tokio
/// runtime (instead of `#[tokio::main]` defaults) and drives `async_main`.
fn main() -> anyhow::Result<()> {
    // Create an explicit multi-threaded runtime with optimized configuration:
    // worker and blocking pools sized to the detected core count, and an
    // enlarged per-thread stack.
    let runtime = Builder::new_multi_thread()
        .enable_all()
        .worker_threads(std::thread::available_parallelism()?.get())
        .max_blocking_threads(std::thread::available_parallelism()?.get())
        .thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
        .thread_name("eval-retrieval-worker")
        .build()
        .context("failed to create tokio runtime")?;

    // Block the main thread on the async entry point; its Result is ours.
    runtime.block_on(async_main())
}
|
||||
|
||||
async fn async_main() -> anyhow::Result<()> {
|
||||
// Log runtime configuration
|
||||
let cpu_count = std::thread::available_parallelism()?.get();
|
||||
info!(
|
||||
cpu_cores = cpu_count,
|
||||
worker_threads = cpu_count,
|
||||
blocking_threads = cpu_count,
|
||||
thread_stack_size = "10MiB",
|
||||
"Started multi-threaded tokio runtime"
|
||||
);
|
||||
|
||||
// Configure SurrealDB environment variables for better performance
|
||||
configure_surrealdb_performance(cpu_count);
|
||||
|
||||
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
|
||||
let _ = fmt()
|
||||
.with_env_filter(EnvFilter::try_new(&filter).unwrap_or_else(|_| EnvFilter::new("info")))
|
||||
.try_init();
|
||||
|
||||
let parsed = args::parse()?;
|
||||
|
||||
if parsed.show_help {
|
||||
args::print_help();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.inspect_question.is_some() {
|
||||
inspection::inspect_question(&parsed.config).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let dataset_kind = parsed.config.dataset;
|
||||
|
||||
if parsed.config.convert_only {
|
||||
info!(
|
||||
dataset = dataset_kind.id(),
|
||||
"Starting dataset conversion only run"
|
||||
);
|
||||
let dataset = crate::datasets::convert(
|
||||
parsed.config.raw_dataset_path.as_path(),
|
||||
dataset_kind,
|
||||
parsed.config.llm_mode,
|
||||
parsed.config.context_token_limit(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"converting {} dataset at {}",
|
||||
dataset_kind.label(),
|
||||
parsed.config.raw_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"writing converted dataset to {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
println!(
|
||||
"Converted dataset written to {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
|
||||
let dataset = crate::datasets::ensure_converted(
|
||||
dataset_kind,
|
||||
parsed.config.raw_dataset_path.as_path(),
|
||||
parsed.config.converted_dataset_path.as_path(),
|
||||
parsed.config.force_convert,
|
||||
parsed.config.llm_mode,
|
||||
parsed.config.context_token_limit(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"preparing converted dataset at {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
info!(
|
||||
questions = dataset
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|p| p.questions.len())
|
||||
.sum::<usize>(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
dataset = dataset.metadata.id.as_str(),
|
||||
"Dataset ready"
|
||||
);
|
||||
|
||||
if parsed.config.slice_grow.is_some() {
|
||||
eval::grow_slice(&dataset, &parsed.config)
|
||||
.await
|
||||
.context("growing slice ledger")?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Running retrieval evaluation");
|
||||
let summary = eval::run_evaluation(&dataset, &parsed.config)
|
||||
.await
|
||||
.context("running retrieval evaluation")?;
|
||||
|
||||
let report_paths = report::write_reports(
|
||||
&summary,
|
||||
parsed.config.report_dir.as_path(),
|
||||
parsed.config.summary_sample,
|
||||
)
|
||||
.with_context(|| format!("writing reports to {}", parsed.config.report_dir.display()))?;
|
||||
let perf_log_path = perf::write_perf_logs(
|
||||
&summary,
|
||||
parsed.config.report_dir.as_path(),
|
||||
parsed.config.perf_log_json.as_deref(),
|
||||
parsed.config.perf_log_dir.as_deref(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"writing perf logs under {}",
|
||||
parsed.config.report_dir.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
println!(
|
||||
"[{}] Precision@{k}: {precision:.3} ({correct}/{total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
|
||||
summary.dataset_label,
|
||||
k = summary.k,
|
||||
precision = summary.precision,
|
||||
correct = summary.correct,
|
||||
total = summary.total_cases,
|
||||
json = report_paths.json.display(),
|
||||
md = report_paths.markdown.display(),
|
||||
perf = perf_log_path.display()
|
||||
);
|
||||
|
||||
if parsed.config.perf_log_console {
|
||||
perf::print_console_summary(&summary);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
16
eval/src/openai.rs
Normal file
16
eval/src/openai.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use anyhow::{Context, Result};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY")
|
||||
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
|
||||
let base_url =
|
||||
std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
|
||||
|
||||
let config = OpenAIConfig::new()
|
||||
.with_api_key(api_key)
|
||||
.with_api_base(&base_url);
|
||||
Ok((Client::with_config(config), base_url))
|
||||
}
|
||||
313
eval/src/perf.rs
Normal file
313
eval/src/perf.rs
Normal file
@@ -0,0 +1,313 @@
|
||||
use std::{
|
||||
fs::{self, OpenOptions},
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::{
|
||||
args,
|
||||
eval::{format_timestamp, EvaluationStageTimings, EvaluationSummary},
|
||||
report,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct PerformanceLogEntry {
|
||||
generated_at: String,
|
||||
dataset_id: String,
|
||||
dataset_label: String,
|
||||
run_label: Option<String>,
|
||||
slice_id: String,
|
||||
slice_seed: u64,
|
||||
slice_window_offset: usize,
|
||||
slice_window_length: usize,
|
||||
limit: Option<usize>,
|
||||
total_cases: usize,
|
||||
correct: usize,
|
||||
precision: f64,
|
||||
k: usize,
|
||||
openai_base_url: String,
|
||||
ingestion: IngestionPerf,
|
||||
namespace: NamespacePerf,
|
||||
retrieval: RetrievalPerf,
|
||||
evaluation_stages: EvaluationStageTimings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct IngestionPerf {
|
||||
duration_ms: u128,
|
||||
cache_path: String,
|
||||
reused: bool,
|
||||
embeddings_reused: bool,
|
||||
fingerprint: String,
|
||||
positives_total: usize,
|
||||
negatives_total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NamespacePerf {
|
||||
reused: bool,
|
||||
seed_ms: Option<u128>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RetrievalPerf {
|
||||
latency_ms: crate::eval::LatencyStats,
|
||||
stage_latency: crate::eval::StageLatencyBreakdown,
|
||||
concurrency: usize,
|
||||
rerank_enabled: bool,
|
||||
rerank_pool_size: Option<usize>,
|
||||
rerank_keep_top: usize,
|
||||
evaluated_cases: usize,
|
||||
}
|
||||
|
||||
impl PerformanceLogEntry {
|
||||
fn from_summary(summary: &EvaluationSummary) -> Self {
|
||||
let ingestion = IngestionPerf {
|
||||
duration_ms: summary.perf.ingestion_ms,
|
||||
cache_path: summary.ingestion_cache_path.clone(),
|
||||
reused: summary.ingestion_reused,
|
||||
embeddings_reused: summary.ingestion_embeddings_reused,
|
||||
fingerprint: summary.ingestion_fingerprint.clone(),
|
||||
positives_total: summary.slice_positive_paragraphs,
|
||||
negatives_total: summary.slice_negative_paragraphs,
|
||||
};
|
||||
|
||||
let namespace = NamespacePerf {
|
||||
reused: summary.namespace_reused,
|
||||
seed_ms: summary.perf.namespace_seed_ms,
|
||||
};
|
||||
|
||||
let retrieval = RetrievalPerf {
|
||||
latency_ms: summary.latency_ms.clone(),
|
||||
stage_latency: summary.perf.stage_latency.clone(),
|
||||
concurrency: summary.concurrency,
|
||||
rerank_enabled: summary.rerank_enabled,
|
||||
rerank_pool_size: summary.rerank_pool_size,
|
||||
rerank_keep_top: summary.rerank_keep_top,
|
||||
evaluated_cases: summary.total_cases,
|
||||
};
|
||||
|
||||
Self {
|
||||
generated_at: format_timestamp(&summary.generated_at),
|
||||
dataset_id: summary.dataset_id.clone(),
|
||||
dataset_label: summary.dataset_label.clone(),
|
||||
run_label: summary.run_label.clone(),
|
||||
slice_id: summary.slice_id.clone(),
|
||||
slice_seed: summary.slice_seed,
|
||||
slice_window_offset: summary.slice_window_offset,
|
||||
slice_window_length: summary.slice_window_length,
|
||||
limit: summary.limit,
|
||||
total_cases: summary.total_cases,
|
||||
correct: summary.correct,
|
||||
precision: summary.precision,
|
||||
k: summary.k,
|
||||
openai_base_url: summary.perf.openai_base_url.clone(),
|
||||
ingestion,
|
||||
namespace,
|
||||
retrieval,
|
||||
evaluation_stages: summary.perf.evaluation_stage_ms.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Persist the run's performance entry to up to three sinks and return the
/// path of the primary log.
///
/// 1. Always appends one compact-JSON line to
///    `<report_root>/<dataset>/perf-log.jsonl` (created on first run).
/// 2. If `extra_json` is given, writes a pretty-printed copy to that exact path.
/// 3. If `extra_dir` is given, mirrors a pretty-printed, timestamped copy
///    (`perf-<dataset>-<timestamp>.json`) into that directory.
pub fn write_perf_logs(
    summary: &EvaluationSummary,
    report_root: &Path,
    extra_json: Option<&Path>,
    extra_dir: Option<&Path>,
) -> Result<PathBuf> {
    let entry = PerformanceLogEntry::from_summary(summary);
    let dataset_dir = report::dataset_report_dir(report_root, &summary.dataset_id);
    fs::create_dir_all(&dataset_dir)
        .with_context(|| format!("creating dataset perf directory {}", dataset_dir.display()))?;

    // Primary sink: append-only JSONL so repeated runs accumulate history.
    let log_path = dataset_dir.join("perf-log.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&log_path)
        .with_context(|| format!("opening perf log {}", log_path.display()))?;
    let line = serde_json::to_vec(&entry).context("serialising perf log entry")?;
    file.write_all(&line)?;
    file.write_all(b"\n")?;
    // Flush explicitly so errors surface here rather than being swallowed on Drop.
    file.flush()?;

    // Optional sink: exact-path pretty JSON copy (parent dirs created as needed).
    if let Some(path) = extra_json {
        args::ensure_parent(path)?;
        let blob = serde_json::to_vec_pretty(&entry).context("serialising perf log JSON")?;
        fs::write(path, blob)
            .with_context(|| format!("writing perf log copy to {}", path.display()))?;
    }

    // Optional sink: timestamped mirror inside a caller-chosen directory.
    if let Some(dir) = extra_dir {
        fs::create_dir_all(dir)
            .with_context(|| format!("creating perf log directory {}", dir.display()))?;
        // Slug falls back to "dataset" if the dir name is not valid UTF-8.
        let dataset_slug = dataset_dir
            .file_name()
            .and_then(|os| os.to_str())
            .unwrap_or("dataset");
        let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S").to_string();
        let filename = format!("perf-{}-{}.json", dataset_slug, timestamp);
        let path = dir.join(filename);
        let blob = serde_json::to_vec_pretty(&entry).context("serialising perf log JSON")?;
        fs::write(&path, blob)
            .with_context(|| format!("writing perf log mirror {}", path.display()))?;
    }

    Ok(log_path)
}
|
||||
|
||||
/// Print a compact three-line `[perf]` summary to stdout: overall ingestion
/// and namespace-seed timings, per-retrieval-stage average latencies, and
/// per-evaluation-stage wall-clock times.
pub fn print_console_summary(summary: &EvaluationSummary) {
    let perf = &summary.perf;
    // Line 1: coarse pipeline timings. Seed time may be absent (namespace reused).
    println!(
        "[perf] ingestion={}ms | namespace_seed={}",
        perf.ingestion_ms,
        format_duration(perf.namespace_seed_ms),
    );
    // Line 2: average latency of each retrieval stage.
    let stage = &perf.stage_latency;
    println!(
        "[perf] stage avg ms → collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
        stage.collect_candidates.avg,
        stage.graph_expansion.avg,
        stage.chunk_attach.avg,
        stage.rerank.avg,
        stage.assemble.avg,
    );
    // Line 3: wall-clock time spent in each evaluation state-machine stage.
    let eval = &perf.evaluation_stage_ms;
    println!(
        "[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
        eval.prepare_slice_ms,
        eval.prepare_db_ms,
        eval.prepare_corpus_ms,
        eval.prepare_namespace_ms,
        eval.run_queries_ms,
        eval.summarize_ms,
        eval.finalize_ms,
    );
}
|
||||
|
||||
/// Render an optional millisecond count as `"<n>ms"`, or `"-"` when absent.
fn format_duration(value: Option<u128>) -> String {
    match value {
        Some(ms) => format!("{ms}ms"),
        None => "-".to_string(),
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::eval::{EvaluationStageTimings, PerformanceTimings};
|
||||
use chrono::Utc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn sample_latency() -> crate::eval::LatencyStats {
|
||||
crate::eval::LatencyStats {
|
||||
avg: 10.0,
|
||||
p50: 8,
|
||||
p95: 15,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
||||
crate::eval::StageLatencyBreakdown {
|
||||
collect_candidates: sample_latency(),
|
||||
graph_expansion: sample_latency(),
|
||||
chunk_attach: sample_latency(),
|
||||
rerank: sample_latency(),
|
||||
assemble: sample_latency(),
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_eval_stage() -> EvaluationStageTimings {
|
||||
EvaluationStageTimings {
|
||||
prepare_slice_ms: 10,
|
||||
prepare_db_ms: 20,
|
||||
prepare_corpus_ms: 30,
|
||||
prepare_namespace_ms: 40,
|
||||
run_queries_ms: 50,
|
||||
summarize_ms: 60,
|
||||
finalize_ms: 70,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_summary() -> EvaluationSummary {
|
||||
EvaluationSummary {
|
||||
generated_at: Utc::now(),
|
||||
k: 5,
|
||||
limit: Some(10),
|
||||
run_label: Some("test".into()),
|
||||
total_cases: 2,
|
||||
correct: 1,
|
||||
precision: 0.5,
|
||||
correct_at_1: 1,
|
||||
correct_at_2: 1,
|
||||
correct_at_3: 1,
|
||||
precision_at_1: 0.5,
|
||||
precision_at_2: 0.5,
|
||||
precision_at_3: 0.5,
|
||||
duration_ms: 1234,
|
||||
dataset_id: "squad-v2".into(),
|
||||
dataset_label: "SQuAD v2".into(),
|
||||
dataset_includes_unanswerable: false,
|
||||
dataset_source: "dev".into(),
|
||||
slice_id: "slice123".into(),
|
||||
slice_seed: 42,
|
||||
slice_total_cases: 400,
|
||||
slice_window_offset: 0,
|
||||
slice_window_length: 10,
|
||||
slice_cases: 10,
|
||||
slice_positive_paragraphs: 10,
|
||||
slice_negative_paragraphs: 40,
|
||||
slice_total_paragraphs: 50,
|
||||
slice_negative_multiplier: 4.0,
|
||||
namespace_reused: true,
|
||||
corpus_paragraphs: 50,
|
||||
ingestion_cache_path: "/tmp/cache".into(),
|
||||
ingestion_reused: true,
|
||||
ingestion_embeddings_reused: true,
|
||||
ingestion_fingerprint: "fingerprint".into(),
|
||||
positive_paragraphs_reused: 10,
|
||||
negative_paragraphs_reused: 40,
|
||||
latency_ms: sample_latency(),
|
||||
perf: PerformanceTimings {
|
||||
openai_base_url: "https://example.com".into(),
|
||||
ingestion_ms: 1000,
|
||||
namespace_seed_ms: Some(150),
|
||||
evaluation_stage_ms: sample_eval_stage(),
|
||||
stage_latency: sample_stage_latency(),
|
||||
},
|
||||
embedding_backend: "fastembed".into(),
|
||||
embedding_model: Some("test-model".into()),
|
||||
embedding_dimension: 32,
|
||||
rerank_enabled: true,
|
||||
rerank_pool_size: Some(4),
|
||||
rerank_keep_top: 10,
|
||||
concurrency: 2,
|
||||
detailed_report: false,
|
||||
chunk_vector_take: 20,
|
||||
chunk_fts_take: 20,
|
||||
chunk_token_budget: 10000,
|
||||
chunk_avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
cases: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writes_perf_log_jsonl() {
|
||||
let tmp = tempdir().unwrap();
|
||||
let report_root = tmp.path().join("reports");
|
||||
let summary = sample_summary();
|
||||
let log_path = write_perf_logs(&summary, &report_root, None, None).expect("perf log write");
|
||||
assert!(log_path.exists());
|
||||
let contents = std::fs::read_to_string(&log_path).expect("reading perf log jsonl");
|
||||
assert!(
|
||||
contents.contains("\"openai_base_url\":\"https://example.com\""),
|
||||
"serialized log should include base URL"
|
||||
);
|
||||
let dataset_dir = report::dataset_report_dir(&report_root, &summary.dataset_id);
|
||||
assert!(dataset_dir.join("perf-log.jsonl").exists());
|
||||
}
|
||||
}
|
||||
456
eval/src/report.rs
Normal file
456
eval/src/report.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
use std::{
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::eval::{format_timestamp, CaseSummary, EvaluationSummary, LatencyStats};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportPaths {
|
||||
pub json: PathBuf,
|
||||
pub markdown: PathBuf,
|
||||
}
|
||||
|
||||
pub fn write_reports(
|
||||
summary: &EvaluationSummary,
|
||||
report_dir: &Path,
|
||||
sample: usize,
|
||||
) -> Result<ReportPaths> {
|
||||
fs::create_dir_all(report_dir)
|
||||
.with_context(|| format!("creating report directory {}", report_dir.display()))?;
|
||||
let dataset_dir = dataset_report_dir(report_dir, &summary.dataset_id);
|
||||
fs::create_dir_all(&dataset_dir).with_context(|| {
|
||||
format!(
|
||||
"creating dataset report directory {}",
|
||||
dataset_dir.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let stem = build_report_stem(summary);
|
||||
|
||||
let json_path = dataset_dir.join(format!("{stem}.json"));
|
||||
let json_blob = serde_json::to_string_pretty(summary).context("serialising JSON report")?;
|
||||
fs::write(&json_path, &json_blob)
|
||||
.with_context(|| format!("writing JSON report to {}", json_path.display()))?;
|
||||
|
||||
let md_path = dataset_dir.join(format!("{stem}.md"));
|
||||
let markdown = render_markdown(summary, sample);
|
||||
fs::write(&md_path, &markdown)
|
||||
.with_context(|| format!("writing Markdown report to {}", md_path.display()))?;
|
||||
|
||||
// Keep a latest.json pointer to simplify automation.
|
||||
let latest_json = dataset_dir.join("latest.json");
|
||||
fs::write(&latest_json, json_blob)
|
||||
.with_context(|| format!("writing latest JSON report to {}", latest_json.display()))?;
|
||||
let latest_md = dataset_dir.join("latest.md");
|
||||
fs::write(&latest_md, markdown)
|
||||
.with_context(|| format!("writing latest Markdown report to {}", latest_md.display()))?;
|
||||
|
||||
record_history(summary, &dataset_dir)?;
|
||||
|
||||
Ok(ReportPaths {
|
||||
json: json_path,
|
||||
markdown: md_path,
|
||||
})
|
||||
}
|
||||
|
||||
fn render_markdown(summary: &EvaluationSummary, sample: usize) -> String {
|
||||
let mut md = String::new();
|
||||
|
||||
md.push_str(&format!("# Retrieval Precision@{}\n\n", summary.k));
|
||||
md.push_str("| Metric | Value |\n");
|
||||
md.push_str("| --- | --- |\n");
|
||||
md.push_str(&format!(
|
||||
"| Generated | {} |\n",
|
||||
format_timestamp(&summary.generated_at)
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Dataset | {} (`{}`) |\n",
|
||||
summary.dataset_label, summary.dataset_id
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Run Label | {} |\n",
|
||||
summary
|
||||
.run_label
|
||||
.as_deref()
|
||||
.filter(|label| !label.is_empty())
|
||||
.unwrap_or("-")
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Unanswerable Included | {} |\n",
|
||||
if summary.dataset_includes_unanswerable {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
}
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Dataset Source | {} |\n",
|
||||
summary.dataset_source
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| OpenAI Base URL | {} |\n",
|
||||
summary.perf.openai_base_url
|
||||
));
|
||||
md.push_str(&format!("| Slice ID | `{}` |\n", summary.slice_id));
|
||||
md.push_str(&format!("| Slice Seed | {} |\n", summary.slice_seed));
|
||||
md.push_str(&format!(
|
||||
"| Slice Total Questions | {} |\n",
|
||||
summary.slice_total_cases
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Slice Window (offset/length) | {}/{} |\n",
|
||||
summary.slice_window_offset, summary.slice_window_length
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Slice Window Questions | {} |\n",
|
||||
summary.slice_cases
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Slice Negatives | {} |\n",
|
||||
summary.slice_negative_paragraphs
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Slice Total Paragraphs | {} |\n",
|
||||
summary.slice_total_paragraphs
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Slice Negative Multiplier | {:.2} |\n",
|
||||
summary.slice_negative_multiplier
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Namespace State | {} |\n",
|
||||
if summary.namespace_reused {
|
||||
"reused"
|
||||
} else {
|
||||
"seeded"
|
||||
}
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Corpus Paragraphs | {} |\n",
|
||||
summary.corpus_paragraphs
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Ingestion Duration | {} ms |\n",
|
||||
summary.perf.ingestion_ms
|
||||
));
|
||||
if let Some(seed) = summary.perf.namespace_seed_ms {
|
||||
md.push_str(&format!("| Namespace Seed | {} ms |\n", seed));
|
||||
}
|
||||
if summary.detailed_report {
|
||||
md.push_str(&format!(
|
||||
"| Ingestion Cache | `{}` |\n",
|
||||
summary.ingestion_cache_path
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Ingestion Reused | {} |\n",
|
||||
if summary.ingestion_reused {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
}
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Embeddings Reused | {} |\n",
|
||||
if summary.ingestion_embeddings_reused {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
}
|
||||
));
|
||||
}
|
||||
md.push_str(&format!(
|
||||
"| Positives Cached | {} |
|
||||
",
|
||||
summary.positive_paragraphs_reused
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Negatives Cached | {} |
|
||||
",
|
||||
summary.negative_paragraphs_reused
|
||||
));
|
||||
let embedding_label = if let Some(model) = summary.embedding_model.as_ref() {
|
||||
format!("{} ({model})", summary.embedding_backend)
|
||||
} else {
|
||||
summary.embedding_backend.clone()
|
||||
};
|
||||
md.push_str(&format!("| Embedding | {} |\n", embedding_label));
|
||||
md.push_str(&format!(
|
||||
"| Embedding Dim | {} |\n",
|
||||
summary.embedding_dimension
|
||||
));
|
||||
if let Some(limit) = summary.limit {
|
||||
md.push_str(&format!(
|
||||
"| Evaluated Queries | {} (limit {}) |\n",
|
||||
summary.total_cases, limit
|
||||
));
|
||||
} else {
|
||||
md.push_str(&format!(
|
||||
"| Evaluated Queries | {} |\n",
|
||||
summary.total_cases
|
||||
));
|
||||
}
|
||||
if summary.rerank_enabled {
|
||||
let pool = summary
|
||||
.rerank_pool_size
|
||||
.map(|size| size.to_string())
|
||||
.unwrap_or_else(|| "?".to_string());
|
||||
md.push_str(&format!(
|
||||
"| Rerank | enabled (pool {pool}, keep top {}) |\n",
|
||||
summary.rerank_keep_top
|
||||
));
|
||||
} else {
|
||||
md.push_str("| Rerank | disabled |\n");
|
||||
}
|
||||
md.push_str(&format!("| Concurrency | {} |\n", summary.concurrency));
|
||||
md.push_str(&format!(
|
||||
"| Correct@{} | {}/{} |\n",
|
||||
summary.k, summary.correct, summary.total_cases
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Precision@{} | {:.3} |\n",
|
||||
summary.k, summary.precision
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Precision@1 | {:.3} |\n",
|
||||
summary.precision_at_1
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Precision@2 | {:.3} |\n",
|
||||
summary.precision_at_2
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Precision@3 | {:.3} |\n",
|
||||
summary.precision_at_3
|
||||
));
|
||||
md.push_str(&format!("| Duration | {} ms |\n", summary.duration_ms));
|
||||
md.push_str(&format!(
|
||||
"| Latency Avg (ms) | {:.1} |\n",
|
||||
summary.latency_ms.avg
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Latency P50 (ms) | {} |\n",
|
||||
summary.latency_ms.p50
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Latency P95 (ms) | {} |\n",
|
||||
summary.latency_ms.p95
|
||||
));
|
||||
|
||||
md.push_str("\n## Retrieval Stage Timings\n\n");
|
||||
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\n");
|
||||
md.push_str("| --- | --- | --- | --- |\n");
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Collect Candidates",
|
||||
&summary.perf.stage_latency.collect_candidates,
|
||||
);
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Graph Expansion",
|
||||
&summary.perf.stage_latency.graph_expansion,
|
||||
);
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Chunk Attach",
|
||||
&summary.perf.stage_latency.chunk_attach,
|
||||
);
|
||||
write_stage_row(&mut md, "Rerank", &summary.perf.stage_latency.rerank);
|
||||
write_stage_row(&mut md, "Assemble", &summary.perf.stage_latency.assemble);
|
||||
|
||||
let misses: Vec<&CaseSummary> = summary.cases.iter().filter(|case| !case.matched).collect();
|
||||
if !misses.is_empty() {
|
||||
md.push_str("\n## Missed Queries (sample)\n\n");
|
||||
if summary.detailed_report {
|
||||
md.push_str(
|
||||
"| Question ID | Paragraph | Expected Source | Entity Match | Chunk Text | Chunk ID | Top Retrieved |\n",
|
||||
);
|
||||
md.push_str("| --- | --- | --- | --- | --- | --- | --- |\n");
|
||||
} else {
|
||||
md.push_str("| Question ID | Paragraph | Expected Source | Top Retrieved |\n");
|
||||
md.push_str("| --- | --- | --- | --- |\n");
|
||||
}
|
||||
|
||||
for case in misses.iter().take(sample) {
|
||||
let retrieved = case
|
||||
.retrieved
|
||||
.iter()
|
||||
.map(|entry| format!("{} (rank {})", entry.source_id, entry.rank))
|
||||
.take(3)
|
||||
.collect::<Vec<_>>()
|
||||
.join("<br>");
|
||||
if summary.detailed_report {
|
||||
md.push_str(&format!(
|
||||
"| `{}` | {} | `{}` | {} | {} | {} | {} |\n",
|
||||
case.question_id,
|
||||
case.paragraph_title,
|
||||
case.expected_source,
|
||||
bool_badge(case.entity_match),
|
||||
bool_badge(case.chunk_text_match),
|
||||
bool_badge(case.chunk_id_match),
|
||||
retrieved
|
||||
));
|
||||
} else {
|
||||
md.push_str(&format!(
|
||||
"| `{}` | {} | `{}` | {} |\n",
|
||||
case.question_id, case.paragraph_title, case.expected_source, retrieved
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
md.push_str("\n_All evaluated queries matched within the top-k window._\n");
|
||||
if summary.detailed_report {
|
||||
md.push_str(
|
||||
"\nSuccess measures were captured for each query (entity, chunk text, chunk ID).\n",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
md
|
||||
}
|
||||
|
||||
fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
|
||||
buf.push_str(&format!(
|
||||
"| {} | {:.1} | {} | {} |\n",
|
||||
label, stats.avg, stats.p50, stats.p95
|
||||
));
|
||||
}
|
||||
|
||||
/// Maps a success flag onto the emoji used in the markdown report:
/// a check mark for `true`, a neutral circle for `false`.
fn bool_badge(value: bool) -> &'static str {
    match value {
        true => "✅",
        false => "⚪",
    }
}
||||
|
||||
/// Builds the filename stem for a report file, e.g.
/// `precision_at_<k>_<dataset>_<timestamp>_<backend>[_<model>]`.
///
/// All free-form components go through `sanitize_component` so the stem is
/// filesystem-safe; the compact timestamp keeps stems sortable by time.
fn build_report_stem(summary: &EvaluationSummary) -> String {
    let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S");
    let backend = sanitize_component(&summary.embedding_backend);
    let dataset_component = sanitize_component(&summary.dataset_id);
    // Model is optional; when absent the stem simply omits the last segment.
    let model_component = summary
        .embedding_model
        .as_ref()
        .map(|model| sanitize_component(model));

    match model_component {
        Some(model) => format!(
            "precision_at_{}_{}_{}_{}_{}",
            summary.k, dataset_component, timestamp, backend, model
        ),
        None => format!(
            "precision_at_{}_{}_{}_{}",
            summary.k, dataset_component, timestamp, backend
        ),
    }
}
||||
|
||||
/// Replaces every non-ASCII-alphanumeric character with `_`, producing a
/// string safe to embed in file and directory names.
fn sanitize_component(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for ch in input.chars() {
        out.push(if ch.is_ascii_alphanumeric() { ch } else { '_' });
    }
    out
}
||||
|
||||
/// Returns the per-dataset subdirectory under `report_dir`, using a
/// filesystem-safe form of the dataset id as the directory name.
pub fn dataset_report_dir(report_dir: &Path, dataset_id: &str) -> PathBuf {
    report_dir.join(sanitize_component(dataset_id))
}
||||
|
||||
/// One row in the append-only `evaluations.json` run log.
///
/// Mirrors the parts of `EvaluationSummary` useful for comparing runs over
/// time; written as pretty JSON by `record_history`.
#[derive(Debug, Serialize, Deserialize)]
struct HistoryEntry {
    // Run identity / provenance.
    generated_at: String,
    run_label: Option<String>,
    dataset_id: String,
    dataset_label: String,
    // Slice selection parameters, recorded for reproducibility.
    slice_id: String,
    slice_seed: u64,
    slice_window_offset: usize,
    slice_window_length: usize,
    slice_cases: usize,
    slice_total_cases: usize,
    // Evaluation configuration.
    k: usize,
    limit: Option<usize>,
    // Headline metrics.
    precision: f64,
    precision_at_1: f64,
    precision_at_2: f64,
    precision_at_3: f64,
    duration_ms: u128,
    latency_ms: LatencyStats,
    // Embedding / ingestion / rerank configuration.
    embedding_backend: String,
    embedding_model: Option<String>,
    ingestion_reused: bool,
    ingestion_embeddings_reused: bool,
    rerank_enabled: bool,
    rerank_keep_top: usize,
    rerank_pool_size: Option<usize>,
    // Difference against the previous log entry, when one existed.
    delta: Option<HistoryDelta>,
    openai_base_url: String,
    ingestion_ms: u128,
    // `default` keeps older log files (without this field) deserialisable.
    #[serde(default)]
    namespace_seed_ms: Option<u128>,
}
||||
|
||||
/// Change in headline metrics relative to the previous history entry
/// (current value minus previous value).
#[derive(Debug, Serialize, Deserialize)]
struct HistoryDelta {
    precision: f64,
    precision_at_1: f64,
    latency_avg_ms: f64,
}
||||
|
||||
/// Appends the current run's summary to `evaluations.json` under `report_dir`.
///
/// Existing entries are loaded first so a delta against the most recent prior
/// run can be recorded. The whole file is rewritten with the new entry added.
///
/// # Errors
/// Fails on unreadable/unwritable log files or on serialisation failure; a
/// log that exists but fails to parse is treated as empty instead.
fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> {
    let path = report_dir.join("evaluations.json");
    let mut entries: Vec<HistoryEntry> = if path.exists() {
        let contents = fs::read(&path)
            .with_context(|| format!("reading evaluation log {}", path.display()))?;
        // Best-effort: a corrupt log restarts the history rather than failing.
        serde_json::from_slice(&contents).unwrap_or_default()
    } else {
        Vec::new()
    };

    // Delta vs. the last recorded run, if any.
    let delta = entries.last().map(|prev| HistoryDelta {
        precision: summary.precision - prev.precision,
        precision_at_1: summary.precision_at_1 - prev.precision_at_1,
        latency_avg_ms: summary.latency_ms.avg - prev.latency_ms.avg,
    });

    let entry = HistoryEntry {
        generated_at: format_timestamp(&summary.generated_at),
        run_label: summary.run_label.clone(),
        dataset_id: summary.dataset_id.clone(),
        dataset_label: summary.dataset_label.clone(),
        slice_id: summary.slice_id.clone(),
        slice_seed: summary.slice_seed,
        slice_window_offset: summary.slice_window_offset,
        slice_window_length: summary.slice_window_length,
        slice_cases: summary.slice_cases,
        slice_total_cases: summary.slice_total_cases,
        k: summary.k,
        limit: summary.limit,
        precision: summary.precision,
        precision_at_1: summary.precision_at_1,
        precision_at_2: summary.precision_at_2,
        precision_at_3: summary.precision_at_3,
        duration_ms: summary.duration_ms,
        latency_ms: summary.latency_ms.clone(),
        embedding_backend: summary.embedding_backend.clone(),
        embedding_model: summary.embedding_model.clone(),
        ingestion_reused: summary.ingestion_reused,
        ingestion_embeddings_reused: summary.ingestion_embeddings_reused,
        rerank_enabled: summary.rerank_enabled,
        rerank_keep_top: summary.rerank_keep_top,
        rerank_pool_size: summary.rerank_pool_size,
        delta,
        openai_base_url: summary.perf.openai_base_url.clone(),
        ingestion_ms: summary.perf.ingestion_ms,
        namespace_seed_ms: summary.perf.namespace_seed_ms,
    };

    entries.push(entry);

    let blob = serde_json::to_vec_pretty(&entries).context("serialising evaluation log")?;
    fs::write(&path, blob).with_context(|| format!("writing evaluation log {}", path.display()))?;
    Ok(())
}
||||
27
eval/src/slice.rs
Normal file
27
eval/src/slice.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use crate::slices::SliceConfig as CoreSliceConfig;
|
||||
|
||||
pub use crate::slices::*;
|
||||
|
||||
use crate::args::Config;
|
||||
|
||||
/// Builds a slice configuration from the CLI config without overriding the
/// case limit (i.e. `config.limit` is used as-is).
impl<'a> From<&'a Config> for CoreSliceConfig<'a> {
    fn from(config: &'a Config) -> Self {
        slice_config_with_limit(config, None)
    }
}
||||
|
||||
/// Maps the CLI `Config` onto a `SliceConfig`, optionally overriding the
/// case limit: `limit_override` takes precedence over `config.limit`.
pub fn slice_config_with_limit<'a>(
    config: &'a Config,
    limit_override: Option<usize>,
) -> CoreSliceConfig<'a> {
    CoreSliceConfig {
        cache_dir: config.cache_dir.as_path(),
        force_convert: config.force_convert,
        explicit_slice: config.slice.as_deref(),
        // Override wins; fall back to the config-level limit.
        limit: limit_override.or(config.limit),
        corpus_limit: config.corpus_limit,
        slice_seed: config.slice_seed,
        llm_mode: config.llm_mode,
        negative_multiplier: config.negative_multiplier,
    }
}
||||
941
eval/src/slices.rs
Normal file
941
eval/src/slices.rs
Normal file
@@ -0,0 +1,941 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
/// Manifest schema version; bumped when the format changes so stale cached
/// slice ledgers are regenerated instead of misread.
const SLICE_VERSION: u32 = 2;
/// Default number of negative (distractor) paragraphs per positive paragraph.
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
||||
|
||||
/// Inputs controlling how a dataset slice is selected or (re)generated.
#[derive(Debug, Clone)]
pub struct SliceConfig<'a> {
    /// Root directory holding cached slice manifests (under `slices/<dataset>/`).
    pub cache_dir: &'a Path,
    /// When true, ignore any cached manifest and rebuild.
    pub force_convert: bool,
    /// Explicit slice id or manifest path; bypasses automatic selection.
    pub explicit_slice: Option<&'a str>,
    /// Maximum number of evaluation cases (questions) to include.
    pub limit: Option<usize>,
    /// Maximum total paragraphs (positives + negatives) in the corpus.
    pub corpus_limit: Option<usize>,
    /// Seed driving the deterministic shuffles.
    pub slice_seed: u64,
    /// When true, unanswerable questions are also eligible as cases.
    pub llm_mode: bool,
    /// Desired negatives per positive paragraph; see `DEFAULT_NEGATIVE_MULTIPLIER`.
    pub negative_multiplier: f32,
}
|
||||
|
||||
/// Persisted ledger describing one slice: which questions are evaluation
/// cases and which paragraphs (positive or negative) form the corpus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceManifest {
    /// Schema version; compared against `SLICE_VERSION` on load.
    pub version: u32,
    pub slice_id: String,
    pub dataset_id: String,
    pub dataset_label: String,
    pub dataset_source: String,
    pub includes_unanswerable: bool,
    /// User-supplied seed the slice was generated with.
    pub seed: u64,
    pub requested_limit: Option<usize>,
    pub requested_corpus: usize,
    pub generated_at: DateTime<Utc>,
    // Derived counts, kept in sync by `refresh_manifest_stats`.
    pub case_count: usize,
    pub positive_paragraphs: usize,
    pub negative_paragraphs: usize,
    pub total_paragraphs: usize,
    pub negative_multiplier: f32,
    pub cases: Vec<SliceCaseEntry>,
    pub paragraphs: Vec<SliceParagraphEntry>,
}
|
||||
|
||||
/// One evaluation case in the manifest: a question id together with the
/// paragraph it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceCaseEntry {
    pub question_id: String,
    pub paragraph_id: String,
}
|
||||
|
||||
/// A paragraph included in the slice's corpus, tagged positive or negative.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceParagraphEntry {
    pub id: String,
    pub kind: SliceParagraphKind,
    /// Relative path of the cached shard; omitted from JSON when absent and
    /// backfilled by `ensure_shard_paths` on load.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shard_path: Option<String>,
}
|
||||
|
||||
/// Role of a paragraph in the slice. Serialised with an internal `kind` tag
/// in snake_case (`positive` / `negative`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SliceParagraphKind {
    /// A paragraph that answers at least one case; carries the question ids.
    Positive { question_ids: Vec<String> },
    /// A distractor paragraph with no associated cases.
    Negative,
}
|
||||
|
||||
/// Relative path of the cached shard for `paragraph_id`:
/// `paragraphs/<sanitised-id>.json`.
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
    let sanitized = sanitize_identifier(paragraph_id);
    format!("paragraphs/{sanitized}.json")
}
|
||||
|
||||
fn sanitize_identifier(input: &str) -> String {
|
||||
let mut sanitized = String::with_capacity(input.len());
|
||||
for ch in input.chars() {
|
||||
if ch.is_ascii_alphanumeric() {
|
||||
sanitized.push(ch);
|
||||
} else {
|
||||
sanitized.push('-');
|
||||
}
|
||||
}
|
||||
let trimmed = sanitized.trim_matches('-').to_string();
|
||||
if trimmed.is_empty() {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(input.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
digest[..6]
|
||||
.iter()
|
||||
.map(|byte| format!("{byte:02x}"))
|
||||
.collect::<String>()
|
||||
} else {
|
||||
trimmed
|
||||
}
|
||||
}
|
||||
|
||||
/// A validated slice: the manifest plus borrowed references into the loaded
/// dataset for every paragraph and case it mentions.
#[derive(Debug, Clone)]
pub struct ResolvedSlice<'a> {
    pub manifest: SliceManifest,
    /// Where the manifest was read from / written to.
    pub path: PathBuf,
    /// All corpus paragraphs (positives and negatives), in manifest order.
    pub paragraphs: Vec<&'a ConvertedParagraph>,
    /// Evaluation cases, in manifest order.
    pub cases: Vec<CaseRef<'a>>,
}
|
||||
|
||||
/// A contiguous window of cases selected from a resolved slice.
#[derive(Debug, Clone)]
pub struct SliceWindow<'a> {
    /// Index of the first selected case within the slice.
    pub offset: usize,
    /// Number of cases actually selected (may be less than requested).
    pub length: usize,
    /// Total cases available in the underlying slice.
    pub total_cases: usize,
    pub cases: Vec<CaseRef<'a>>,
    // Ids of positive paragraphs referenced by the window, first-seen order.
    positive_paragraph_ids: Vec<String>,
}
|
||||
|
||||
impl<'a> SliceWindow<'a> {
    /// Ids of the positive paragraphs this window references, deduplicated
    /// in first-seen order.
    pub fn positive_ids(&self) -> impl Iterator<Item = &str> {
        self.positive_paragraph_ids.iter().map(|id| id.as_str())
    }
}
|
||||
|
||||
/// One evaluation case: a question paired with its source paragraph.
#[derive(Debug, Clone)]
pub struct CaseRef<'a> {
    pub paragraph: &'a ConvertedParagraph,
    pub question: &'a ConvertedQuestion,
}
|
||||
|
||||
/// Id → index lookup tables over a `ConvertedDataset`.
struct DatasetIndex {
    // paragraph id -> index into `dataset.paragraphs`.
    paragraph_by_id: HashMap<String, usize>,
    // question id -> (paragraph index, question index within that paragraph).
    question_by_id: HashMap<String, (usize, usize)>,
}
|
||||
|
||||
impl DatasetIndex {
    /// Indexes every paragraph and question by id for O(1) lookups during
    /// slice validation.
    fn build(dataset: &ConvertedDataset) -> Self {
        let mut paragraph_by_id = HashMap::new();
        let mut question_by_id = HashMap::new();

        for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
            paragraph_by_id.insert(paragraph.id.clone(), p_idx);
            for (q_idx, question) in paragraph.questions.iter().enumerate() {
                question_by_id.insert(question.id.clone(), (p_idx, q_idx));
            }
        }

        Self {
            paragraph_by_id,
            question_by_id,
        }
    }

    /// Resolves a paragraph id to its dataset entry.
    ///
    /// # Errors
    /// Fails when the id is unknown (stale or hand-edited slice manifest).
    fn paragraph<'a>(
        &self,
        dataset: &'a ConvertedDataset,
        id: &str,
    ) -> Result<&'a ConvertedParagraph> {
        let idx = self
            .paragraph_by_id
            .get(id)
            .ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?;
        Ok(&dataset.paragraphs[*idx])
    }

    /// Resolves a question id to its owning paragraph and the question itself.
    ///
    /// # Errors
    /// Fails when the question id is unknown or maps to a missing index.
    fn question<'a>(
        &self,
        dataset: &'a ConvertedDataset,
        question_id: &str,
    ) -> Result<(&'a ConvertedParagraph, &'a ConvertedQuestion)> {
        let (p_idx, q_idx) = self
            .question_by_id
            .get(question_id)
            .ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?;
        let paragraph = &dataset.paragraphs[*p_idx];
        let question = paragraph
            .questions
            .get(*q_idx)
            .ok_or_else(|| anyhow!("slice maps question '{question_id}' to missing index"))?;
        Ok((paragraph, question))
    }
}
|
||||
|
||||
/// Canonical identity of an auto-generated slice; serialised to JSON and
/// hashed by `compute_slice_id` to derive the slice id.
#[derive(Debug, Serialize)]
struct SliceKey<'a> {
    dataset_id: &'a str,
    includes_unanswerable: bool,
    requested_corpus: usize,
    seed: u64,
}
|
||||
|
||||
/// Parameters used while (re)building a slice ledger.
#[derive(Debug)]
struct BuildParams {
    /// Include unanswerable questions as cases (driven by LLM mode).
    include_impossible: bool,
    /// User-supplied seed, stored verbatim in the manifest.
    base_seed: u64,
    /// Seed mixed with the dataset id; drives the question shuffle.
    rng_seed: u64,
}
|
||||
|
||||
/// Resolves (loading, regenerating, or extending a cached ledger as needed)
/// the slice to evaluate against.
///
/// Flow:
/// 1. An explicit slice, if configured, is loaded and validated as-is.
/// 2. Otherwise a deterministic slice id is derived from dataset id, corpus
///    size and seed; a cached manifest is reused when compatible (same
///    dataset, same unanswerable setting, current schema version).
/// 3. The manifest is grown as required: shard paths backfilled, cases added
///    up to the requested limit, negatives topped up to the multiplier
///    target. Any change rewrites the cached file.
///
/// # Errors
/// Propagates manifest I/O/parse failures, insufficient eligible questions,
/// and validation failures from `manifest_to_resolved`.
pub fn resolve_slice<'a>(
    dataset: &'a ConvertedDataset,
    config: &SliceConfig<'_>,
) -> Result<ResolvedSlice<'a>> {
    let index = DatasetIndex::build(dataset);

    // Explicit slice: load, validate, and return without touching the cache.
    if let Some(slice_arg) = config.explicit_slice {
        let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
        let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
        info!(
            slice = %resolved.manifest.slice_id,
            path = %resolved.path.display(),
            cases = resolved.manifest.case_count,
            positives = resolved.manifest.positive_paragraphs,
            negatives = resolved.manifest.negative_paragraphs,
            "Using explicitly selected slice"
        );
        return Ok(resolved);
    }

    // Clamp the corpus request into [1, dataset size].
    let requested_corpus = config
        .corpus_limit
        .unwrap_or(dataset.paragraphs.len())
        .min(dataset.paragraphs.len())
        .max(1);
    let key = SliceKey {
        dataset_id: dataset.metadata.id.as_str(),
        includes_unanswerable: dataset.metadata.include_unanswerable,
        requested_corpus,
        seed: config.slice_seed,
    };
    let slice_id = compute_slice_id(&key);
    let base = config
        .cache_dir
        .join("slices")
        .join(dataset.metadata.id.as_str());
    let path = base.join(format!("{slice_id}.json"));

    // Clamp the case request into [1, total questions].
    let total_questions = dataset
        .paragraphs
        .iter()
        .map(|p| p.questions.len())
        .sum::<usize>()
        .max(1);
    let requested_limit = config
        .limit
        .unwrap_or(total_questions)
        .min(total_questions)
        .max(1);

    // Try the cached manifest unless a rebuild was forced; discard it when it
    // targets a different dataset or mismatches the unanswerable setting.
    let mut manifest = if !config.force_convert && path.exists() {
        match read_manifest(&path) {
            Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
                if manifest.includes_unanswerable != dataset.metadata.include_unanswerable {
                    warn!(
                        slice = manifest.slice_id,
                        path = %path.display(),
                        "Slice manifest includes_unanswerable mismatch; regenerating"
                    );
                    None
                } else {
                    Some(manifest)
                }
            }
            Ok(manifest) => {
                warn!(
                    slice = manifest.slice_id,
                    path = %path.display(),
                    loaded_dataset = %manifest.dataset_id,
                    expected = %dataset.metadata.id,
                    "Slice manifest targets different dataset; regenerating"
                );
                None
            }
            Err(err) => {
                warn!(
                    path = %path.display(),
                    error = %err,
                    "Failed to read cached slice; regenerating"
                );
                None
            }
        }
    } else {
        None
    };

    let params = BuildParams {
        include_impossible: config.llm_mode,
        base_seed: config.slice_seed,
        rng_seed: mix_seed(dataset.metadata.id.as_str(), config.slice_seed),
    };

    // Discard manifests written by an older schema version.
    if manifest
        .as_ref()
        .map(|manifest| manifest.version != SLICE_VERSION)
        .unwrap_or(false)
    {
        warn!(
            slice = manifest
                .as_ref()
                .map(|m| m.slice_id.as_str())
                .unwrap_or("unknown"),
            found = manifest.as_ref().map(|m| m.version).unwrap_or(0),
            expected = SLICE_VERSION,
            "Slice manifest version mismatch; regenerating"
        );
        manifest = None;
    }

    // Fall back to a fresh, empty manifest when nothing reusable was found.
    let mut manifest = manifest.unwrap_or_else(|| {
        empty_manifest(
            dataset,
            slice_id.clone(),
            &params,
            requested_corpus,
            config.negative_multiplier,
            config.limit,
        )
    });

    // Always reflect the current request in the manifest.
    manifest.requested_limit = config.limit;
    manifest.requested_corpus = requested_corpus;
    manifest.negative_multiplier = config.negative_multiplier;

    let mut changed = ensure_shard_paths(&mut manifest);

    changed |= ensure_case_capacity(dataset, &mut manifest, &params, requested_limit)?;
    refresh_manifest_stats(&mut manifest);

    // Top up negatives to the multiplier target, bounded by the corpus cap.
    let desired_negatives = desired_negative_target(
        manifest.positive_paragraphs,
        requested_corpus,
        dataset.paragraphs.len(),
        config.negative_multiplier,
    );
    changed |= ensure_negative_pool(
        dataset,
        &mut manifest,
        &params,
        desired_negatives,
        requested_corpus,
    )?;
    refresh_manifest_stats(&mut manifest);

    if changed {
        manifest.generated_at = Utc::now();
        write_manifest(&path, &manifest)?;
        info!(
            slice = %manifest.slice_id,
            path = %path.display(),
            cases = manifest.case_count,
            positives = manifest.positive_paragraphs,
            negatives = manifest.negative_paragraphs,
            "Updated dataset slice ledger"
        );
    } else {
        info!(
            slice = %manifest.slice_id,
            path = %path.display(),
            cases = manifest.case_count,
            positives = manifest.positive_paragraphs,
            negatives = manifest.negative_paragraphs,
            "Reusing cached slice ledger"
        );
    }

    let resolved = manifest_to_resolved(dataset, &index, manifest.clone(), path.clone())?;

    Ok(resolved)
}
|
||||
|
||||
/// Selects up to `limit` cases starting at `offset` from the resolved slice.
///
/// The window also records, deduplicated in first-seen order, the ids of the
/// positive paragraphs the selected cases reference.
///
/// # Errors
/// Fails when the slice has no cases or `offset` is past the last case.
pub fn select_window<'a>(
    resolved: &'a ResolvedSlice<'a>,
    offset: usize,
    limit: Option<usize>,
) -> Result<SliceWindow<'a>> {
    let total = resolved.manifest.case_count;
    if total == 0 {
        return Err(anyhow!(
            "slice '{}' contains no cases",
            resolved.manifest.slice_id
        ));
    }
    if offset >= total {
        return Err(anyhow!(
            "slice offset {} exceeds available cases ({})",
            offset,
            total
        ));
    }
    let available = total - offset;
    // A requested limit of zero is clamped up to one case.
    let requested = limit.unwrap_or(available).max(1);
    let length = requested.min(available);
    let cases = resolved.cases[offset..offset + length].to_vec();
    // Collect distinct positive paragraph ids, preserving first-seen order.
    let mut seen = HashSet::new();
    let mut positive_ids = Vec::new();
    for case in &cases {
        if seen.insert(case.paragraph.id.as_str()) {
            positive_ids.push(case.paragraph.id.clone());
        }
    }
    Ok(SliceWindow {
        offset,
        length,
        total_cases: total,
        cases,
        positive_paragraph_ids: positive_ids,
    })
}
|
||||
|
||||
/// Convenience wrapper selecting every case in the slice.
#[allow(dead_code)]
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
    select_window(resolved, 0, None)
}
|
||||
|
||||
/// Loads a user-specified slice, interpreting the argument either as a
/// filesystem path or as a slice id under the dataset's cache directory.
///
/// The manifest must target the loaded dataset, and is fully validated (via
/// `manifest_to_resolved`) before being returned.
fn load_explicit_slice<'a>(
    dataset: &'a ConvertedDataset,
    index: &DatasetIndex,
    config: &SliceConfig<'_>,
    slice_arg: &str,
) -> Result<(PathBuf, SliceManifest)> {
    let explicit_path = Path::new(slice_arg);
    // Prefer a literal path; otherwise treat the argument as a slice id.
    let candidate_path = if explicit_path.exists() {
        explicit_path.to_path_buf()
    } else {
        config
            .cache_dir
            .join("slices")
            .join(dataset.metadata.id.as_str())
            .join(format!("{slice_arg}.json"))
    };

    let manifest = read_manifest(&candidate_path)
        .with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;

    if manifest.dataset_id != dataset.metadata.id {
        return Err(anyhow!(
            "slice '{}' targets dataset '{}', but '{}' is loaded",
            manifest.slice_id,
            manifest.dataset_id,
            dataset.metadata.id
        ));
    }

    // Validate the manifest before returning.
    manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?;

    Ok((candidate_path, manifest))
}
|
||||
|
||||
/// Builds a fresh manifest with no cases or paragraphs, carrying the current
/// request parameters; content is filled in by `ensure_case_capacity` and
/// `ensure_negative_pool`.
fn empty_manifest(
    dataset: &ConvertedDataset,
    slice_id: String,
    params: &BuildParams,
    requested_corpus: usize,
    negative_multiplier: f32,
    requested_limit: Option<usize>,
) -> SliceManifest {
    SliceManifest {
        version: SLICE_VERSION,
        slice_id,
        dataset_id: dataset.metadata.id.clone(),
        dataset_label: dataset.metadata.label.clone(),
        dataset_source: dataset.source.clone(),
        includes_unanswerable: dataset.metadata.include_unanswerable,
        seed: params.base_seed,
        requested_limit,
        requested_corpus,
        negative_multiplier,
        generated_at: Utc::now(),
        case_count: 0,
        positive_paragraphs: 0,
        negative_paragraphs: 0,
        total_paragraphs: 0,
        cases: Vec::new(),
        paragraphs: Vec::new(),
    }
}
|
||||
|
||||
/// Grows the manifest's case list up to `target_cases`, drawing questions in
/// the deterministic shuffled order from `ordered_question_refs`.
///
/// New cases may promote an existing negative paragraph to positive or add a
/// brand-new positive paragraph entry. Existing cases are never removed, so a
/// cached slice only ever grows — keeping earlier windows stable.
///
/// Returns `true` when the manifest was modified.
///
/// # Errors
/// Fails when the dataset has fewer eligible questions than `target_cases`.
fn ensure_case_capacity(
    dataset: &ConvertedDataset,
    manifest: &mut SliceManifest,
    params: &BuildParams,
    target_cases: usize,
) -> Result<bool> {
    if manifest.case_count >= target_cases {
        return Ok(false);
    }

    let question_refs = ordered_question_refs(dataset, params)?;
    // Track what's already present so additions are idempotent.
    let mut existing_questions: HashSet<String> = manifest
        .cases
        .iter()
        .map(|case| case.question_id.clone())
        .collect();
    let mut paragraph_positions: HashMap<String, usize> = manifest
        .paragraphs
        .iter()
        .enumerate()
        .map(|(idx, entry)| (entry.id.clone(), idx))
        .collect();

    let mut changed = false;

    for (p_idx, q_idx) in question_refs {
        if manifest.case_count >= target_cases {
            break;
        }
        let paragraph = &dataset.paragraphs[p_idx];
        let question = &paragraph.questions[q_idx];
        // Skip questions already recorded as cases.
        if !existing_questions.insert(question.id.clone()) {
            continue;
        }

        if let Some(idx) = paragraph_positions.get(paragraph.id.as_str()).copied() {
            match &mut manifest.paragraphs[idx].kind {
                SliceParagraphKind::Positive { question_ids } => {
                    if !question_ids.contains(&question.id) {
                        question_ids.push(question.id.clone());
                    }
                }
                // A paragraph previously used as a negative becomes positive.
                SliceParagraphKind::Negative => {
                    manifest.paragraphs[idx].kind = SliceParagraphKind::Positive {
                        question_ids: vec![question.id.clone()],
                    };
                }
            }
        } else {
            manifest.paragraphs.push(SliceParagraphEntry {
                id: paragraph.id.clone(),
                kind: SliceParagraphKind::Positive {
                    question_ids: vec![question.id.clone()],
                },
                shard_path: Some(default_shard_path(&paragraph.id)),
            });
            let idx = manifest.paragraphs.len() - 1;
            paragraph_positions.insert(paragraph.id.clone(), idx);
        }

        manifest.cases.push(SliceCaseEntry {
            question_id: question.id.clone(),
            paragraph_id: paragraph.id.clone(),
        });
        manifest.case_count += 1;
        changed = true;
    }

    if manifest.case_count < target_cases {
        return Err(anyhow!(
            "only {}/{} eligible questions available for dataset {}",
            manifest.case_count,
            target_cases,
            dataset.metadata.id
        ));
    }

    Ok(changed)
}
|
||||
|
||||
/// Collects `(paragraph index, question index)` pairs for all eligible
/// questions, deterministically shuffled with `params.rng_seed`.
///
/// Unless `include_impossible` is set, unanswerable questions and questions
/// without answers are excluded.
///
/// # Errors
/// Fails when no question is eligible at all.
fn ordered_question_refs(
    dataset: &ConvertedDataset,
    params: &BuildParams,
) -> Result<Vec<(usize, usize)>> {
    let mut question_refs = Vec::new();
    for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
        for (q_idx, question) in paragraph.questions.iter().enumerate() {
            let include = if params.include_impossible {
                true
            } else {
                !question.is_impossible && !question.answers.is_empty()
            };
            if include {
                question_refs.push((p_idx, q_idx));
            }
        }
    }

    if question_refs.is_empty() {
        return Err(anyhow!(
            "no eligible questions found for dataset {}; cannot build slice",
            dataset.metadata.id
        ));
    }

    // Seeded shuffle keeps the order reproducible across runs.
    let mut rng = StdRng::seed_from_u64(params.rng_seed);
    question_refs.shuffle(&mut rng);
    Ok(question_refs)
}
|
||||
|
||||
/// Tops up the manifest's negative (distractor) paragraphs to
/// `target_negatives`, drawing candidates in a deterministic shuffled order.
///
/// Existing entries are never removed; if the dataset runs out of candidate
/// paragraphs a warning is logged and the pool stays short.
///
/// Returns `true` when any negative was added.
fn ensure_negative_pool(
    dataset: &ConvertedDataset,
    manifest: &mut SliceManifest,
    params: &BuildParams,
    target_negatives: usize,
    requested_corpus: usize,
) -> Result<bool> {
    let current_negatives = manifest
        .paragraphs
        .iter()
        .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative))
        .count();
    if current_negatives >= target_negatives {
        return Ok(false);
    }

    // Existing membership, split by kind, so additions stay disjoint.
    let positive_ids: HashSet<String> = manifest
        .paragraphs
        .iter()
        .filter_map(|entry| match entry.kind {
            SliceParagraphKind::Positive { .. } => Some(entry.id.clone()),
            _ => None,
        })
        .collect();
    let mut negative_ids: HashSet<String> = manifest
        .paragraphs
        .iter()
        .filter_map(|entry| match entry.kind {
            SliceParagraphKind::Negative => Some(entry.id.clone()),
            _ => None,
        })
        .collect();

    // Distinct seed stream from the question shuffle.
    let negative_seed = mix_seed(
        &format!("{}::negatives", dataset.metadata.id),
        params.base_seed,
    );
    let candidates = ordered_negative_indices(dataset, &positive_ids, negative_seed);
    let mut added = false;
    for idx in candidates {
        if negative_ids.len() >= target_negatives {
            break;
        }
        let paragraph = &dataset.paragraphs[idx];
        if negative_ids.contains(paragraph.id.as_str())
            || positive_ids.contains(paragraph.id.as_str())
        {
            continue;
        }
        manifest.paragraphs.push(SliceParagraphEntry {
            id: paragraph.id.clone(),
            kind: SliceParagraphKind::Negative,
            shard_path: Some(default_shard_path(&paragraph.id)),
        });
        negative_ids.insert(paragraph.id.clone());
        added = true;
    }

    if negative_ids.len() < target_negatives {
        warn!(
            dataset = %dataset.metadata.id,
            desired = target_negatives,
            available = negative_ids.len(),
            requested_corpus,
            "Insufficient negative paragraphs to satisfy multiplier"
        );
    }

    Ok(added)
}
|
||||
|
||||
/// Indices of paragraphs that are not currently positive, deterministically
/// shuffled with `rng_seed` so negative selection is reproducible.
fn ordered_negative_indices(
    dataset: &ConvertedDataset,
    positive_ids: &HashSet<String>,
    rng_seed: u64,
) -> Vec<usize> {
    let mut candidates: Vec<usize> = dataset
        .paragraphs
        .iter()
        .enumerate()
        .filter_map(|(idx, paragraph)| {
            if positive_ids.contains(paragraph.id.as_str()) {
                None
            } else {
                Some(idx)
            }
        })
        .collect();
    let mut rng = StdRng::seed_from_u64(rng_seed);
    candidates.shuffle(&mut rng);
    candidates
}
|
||||
|
||||
/// Recomputes the derived count fields (`case_count`, positive/negative/
/// total paragraph counts) from the manifest's actual contents.
fn refresh_manifest_stats(manifest: &mut SliceManifest) {
    manifest.case_count = manifest.cases.len();
    manifest.positive_paragraphs = manifest
        .paragraphs
        .iter()
        .filter(|entry| matches!(entry.kind, SliceParagraphKind::Positive { .. }))
        .count();
    manifest.negative_paragraphs = manifest
        .paragraphs
        .iter()
        .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative))
        .count();
    manifest.total_paragraphs = manifest.paragraphs.len();
}
|
||||
|
||||
/// Backfills `shard_path` on paragraph entries that predate the field
/// (older manifests omit it). Returns `true` when any entry was updated.
fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool {
    let mut changed = false;
    for entry in &mut manifest.paragraphs {
        if entry.shard_path.is_none() {
            entry.shard_path = Some(default_shard_path(&entry.id));
            changed = true;
        }
    }
    changed
}
|
||||
|
||||
/// Computes how many negative paragraphs the slice should hold:
/// `ceil(positives * multiplier)`, capped so positives + negatives never
/// exceed the requested corpus (or the dataset size, whichever is smaller).
/// Zero positives means zero negatives; a negative multiplier acts as zero.
fn desired_negative_target(
    positive_count: usize,
    requested_corpus: usize,
    dataset_paragraphs: usize,
    multiplier: f32,
) -> usize {
    if positive_count == 0 {
        return 0;
    }
    // Clamp the ratio so a negative multiplier cannot underflow the product.
    let ratio = multiplier.max(0.0);
    let wanted = ((positive_count as f32) * ratio).ceil() as usize;
    // The corpus cap bounds positives and negatives together; it can never
    // shrink below the positives already present.
    let corpus_cap = requested_corpus.min(dataset_paragraphs).max(positive_count);
    let room_for_negatives = corpus_cap.saturating_sub(positive_count);
    wanted.min(room_for_negatives)
}
|
||||
|
||||
/// Validates a manifest against the loaded dataset and resolves every id
/// into borrowed dataset references.
///
/// Checks: manifest schema version, every paragraph/question id exists, each
/// positive paragraph's question ids actually belong to it, and each case's
/// question lives in the paragraph the case claims.
///
/// # Errors
/// Fails on any of the above mismatches, or when the slice ends up with no
/// cases.
fn manifest_to_resolved<'a>(
    dataset: &'a ConvertedDataset,
    index: &DatasetIndex,
    manifest: SliceManifest,
    path: PathBuf,
) -> Result<ResolvedSlice<'a>> {
    if manifest.version != SLICE_VERSION {
        return Err(anyhow!(
            "slice version {} does not match expected {}",
            manifest.version,
            SLICE_VERSION
        ));
    }

    let mut paragraphs = Vec::with_capacity(manifest.paragraphs.len());
    for entry in &manifest.paragraphs {
        let paragraph = index.paragraph(dataset, &entry.id)?;
        // Cross-check that each positive's questions really belong to it.
        if let SliceParagraphKind::Positive { question_ids } = &entry.kind {
            for question_id in question_ids {
                let (linked_paragraph, _) = index.question(dataset, question_id)?;
                if linked_paragraph.id != entry.id {
                    return Err(anyhow!(
                        "slice question '{}' expected paragraph '{}', found '{}'",
                        question_id,
                        entry.id,
                        linked_paragraph.id
                    ));
                }
            }
        }
        paragraphs.push(paragraph);
    }

    let mut cases = Vec::with_capacity(manifest.cases.len());
    for entry in &manifest.cases {
        let (paragraph, question) = index.question(dataset, &entry.question_id)?;
        if paragraph.id != entry.paragraph_id {
            return Err(anyhow!(
                "slice case '{}' expected paragraph '{}', found '{}'",
                entry.question_id,
                entry.paragraph_id,
                paragraph.id
            ));
        }
        cases.push(CaseRef {
            paragraph,
            question,
        });
    }

    if cases.is_empty() {
        return Err(anyhow!(
            "slice '{}' contains no cases after validation",
            manifest.slice_id
        ));
    }

    Ok(ResolvedSlice {
        manifest,
        path,
        paragraphs,
        cases,
    })
}
|
||||
|
||||
/// Derives a stable slice id: the first 16 bytes of the SHA-256 digest of
/// the JSON-serialised `SliceKey`, hex-encoded (32 characters).
fn compute_slice_id(key: &SliceKey<'_>) -> String {
    // SliceKey contains only plain fields, so serialisation cannot fail.
    let payload = serde_json::to_vec(key).expect("SliceKey serialisation should not fail");
    let mut hasher = Sha256::new();
    hasher.update(payload);
    let digest = hasher.finalize();
    digest[..16]
        .iter()
        .map(|byte| format!("{byte:02x}"))
        .collect::<String>()
}
|
||||
|
||||
/// Derives an RNG seed by hashing `dataset_id` together with the user seed,
/// so different datasets get independent shuffle streams for the same seed.
fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
    let mut hasher = Sha256::new();
    hasher.update(dataset_id.as_bytes());
    hasher.update(seed.to_le_bytes());
    let digest = hasher.finalize();
    // Fold the first 8 digest bytes into a u64.
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(&digest[..8]);
    u64::from_le_bytes(bytes)
}
|
||||
|
||||
fn read_manifest(path: &Path) -> Result<SliceManifest> {
|
||||
let raw = fs::read_to_string(path)
|
||||
.with_context(|| format!("reading slice manifest {}", path.display()))?;
|
||||
let manifest: SliceManifest = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing slice manifest {}", path.display()))?;
|
||||
Ok(manifest)
|
||||
}
|
||||
|
||||
fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating slice directory {}", parent.display()))?;
|
||||
}
|
||||
let json = serde_json::to_vec_pretty(manifest).context("serialising slice manifest to JSON")?;
|
||||
fs::write(path, json).with_context(|| format!("writing slice manifest {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Tests for slice resolution: manifest caching / forced regeneration and
    //! window selection over a small deterministic in-memory dataset.

    use super::*;
    use crate::datasets::{
        ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, DatasetMetadata,
    };
    use tempfile::tempdir;

    /// Builds a three-paragraph dataset (one answerable question per
    /// paragraph) shared by the tests below.
    fn sample_dataset() -> ConvertedDataset {
        let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false, None);
        ConvertedDataset {
            generated_at: Utc::now(),
            metadata,
            source: "test-source".to_string(),
            paragraphs: vec![
                ConvertedParagraph {
                    id: "p1".to_string(),
                    title: "Alpha".to_string(),
                    context: "Alpha context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q1".to_string(),
                        question: "What is alpha?".to_string(),
                        answers: vec!["Alpha".to_string()],
                        is_impossible: false,
                    }],
                },
                ConvertedParagraph {
                    id: "p2".to_string(),
                    title: "Beta".to_string(),
                    context: "Beta context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q2".to_string(),
                        question: "What is beta?".to_string(),
                        answers: vec!["Beta".to_string()],
                        is_impossible: false,
                    }],
                },
                ConvertedParagraph {
                    id: "p3".to_string(),
                    title: "Gamma".to_string(),
                    context: "Gamma context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q3".to_string(),
                        question: "What is gamma?".to_string(),
                        answers: vec!["Gamma".to_string()],
                        is_impossible: false,
                    }],
                },
            ],
        }
    }

    #[test]
    fn resolve_slice_reuses_cached_manifest() -> Result<()> {
        let dataset = sample_dataset();
        let temp = tempdir().context("creating temp directory")?;

        let mut config = SliceConfig {
            cache_dir: temp.path(),
            force_convert: false,
            explicit_slice: None,
            limit: Some(2),
            corpus_limit: Some(3),
            slice_seed: 0x5eed_2025,
            slice_seed: 0x5eed_2025,
            llm_mode: false,
            negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
        };

        let first = resolve_slice(&dataset, &config)?;
        assert!(first.path.exists());
        let initial_generated = first.manifest.generated_at;

        // A second resolve with an identical config must hit the cached
        // manifest: same slice id and an untouched generation timestamp.
        let second = resolve_slice(&dataset, &config)?;
        assert_eq!(first.manifest.slice_id, second.manifest.slice_id);
        assert_eq!(initial_generated, second.manifest.generated_at);

        // `force_convert` regenerates the manifest in place: the
        // deterministic slice id is unchanged, but the timestamp is fresh.
        config.force_convert = true;
        let third = resolve_slice(&dataset, &config)?;
        assert_eq!(first.manifest.slice_id, third.manifest.slice_id);
        assert_ne!(third.manifest.generated_at, initial_generated);

        Ok(())
    }

    #[test]
    fn select_window_yields_expected_cases() -> Result<()> {
        let dataset = sample_dataset();
        let temp = tempdir().context("creating temp directory")?;
        let config = SliceConfig {
            cache_dir: temp.path(),
            force_convert: false,
            explicit_slice: None,
            limit: Some(3),
            corpus_limit: Some(3),
            slice_seed: 0x5eed_2025,
            llm_mode: false,
            negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
        };
        let resolved = resolve_slice(&dataset, &config)?;
        // Request a one-case window starting at offset 1 and verify the
        // window bookkeeping matches the request.
        let window = select_window(&resolved, 1, Some(1))?;
        assert_eq!(window.offset, 1);
        assert_eq!(window.length, 1);
        assert_eq!(window.total_cases, resolved.manifest.case_count);
        assert_eq!(window.cases.len(), 1);
        // The single positive paragraph id must come from the slice manifest.
        let positive_ids: Vec<&str> = window.positive_ids().collect();
        assert_eq!(positive_ids.len(), 1);
        assert!(resolved
            .manifest
            .paragraphs
            .iter()
            .any(|entry| entry.id == positive_ids[0]));
        Ok(())
    }
}
|
||||
182
eval/src/snapshot.rs
Normal file
182
eval/src/snapshot.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::fs;
|
||||
|
||||
use crate::{args::Config, embedding::EmbeddingProvider, slice};
|
||||
|
||||
/// Configuration fingerprint for an ingestion snapshot.
///
/// Serialised to JSON and hashed (see `compute_hash`) so cached database
/// state can be invalidated whenever any of these parameters change.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Identifier of the source dataset.
    pub dataset_id: String,
    /// Identifier of the dataset slice the snapshot covers.
    pub slice_id: String,
    /// Backend label reported by the embedding provider.
    pub embedding_backend: String,
    /// Model code reported by the embedding provider, if any.
    pub embedding_model: Option<String>,
    /// Dimensionality of the embedding vectors.
    pub embedding_dimension: usize,
    /// Minimum chunk size (characters) configured for ingestion.
    pub chunk_min_chars: usize,
    /// Maximum chunk size (characters) configured for ingestion.
    pub chunk_max_chars: usize,
    /// Whether reranking was enabled when the snapshot was produced.
    pub rerank_enabled: bool,
}
|
||||
|
||||
/// Persisted record describing the database state backing a snapshot.
///
/// Written as pretty-printed JSON to `<snapshot dir>/db/state.json` and
/// round-tripped by `Descriptor::load_db_state` / `Descriptor::store_db_state`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbSnapshotState {
    /// Dataset the stored data was built from.
    pub dataset_id: String,
    /// Slice the stored data was built from.
    pub slice_id: String,
    /// Fingerprint of the ingestion run that produced the stored data.
    pub ingestion_fingerprint: String,
    /// Metadata hash of the snapshot this state belongs to
    /// (compare with `Descriptor::metadata_hash`).
    pub snapshot_hash: String,
    /// When this state file was last written.
    pub updated_at: DateTime<Utc>,
    /// Database namespace holding the data; `None` when absent from older
    /// state files (serde default keeps them readable).
    #[serde(default)]
    pub namespace: Option<String>,
    /// Database name holding the data; `None` when absent from older state
    /// files.
    #[serde(default)]
    pub database: Option<String>,
    /// Number of slice cases at ingestion time; `0` when absent from older
    /// state files.
    #[serde(default)]
    pub slice_case_count: usize,
}
|
||||
|
||||
/// Locates a snapshot on disk and fingerprints its configuration.
pub struct Descriptor {
    // Kept for completeness; currently only consumed via `compute_hash`
    // during construction.
    #[allow(dead_code)]
    metadata: SnapshotMetadata,
    // Snapshot root: `<cache_dir>/snapshots/<dataset_id>/<slice_id>`.
    dir: PathBuf,
    // Hex SHA-256 digest of the serialised metadata (see `compute_hash`).
    metadata_hash: String,
}
|
||||
|
||||
impl Descriptor {
|
||||
pub fn new(
|
||||
config: &Config,
|
||||
slice: &slice::ResolvedSlice<'_>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Self {
|
||||
let metadata = SnapshotMetadata {
|
||||
dataset_id: slice.manifest.dataset_id.clone(),
|
||||
slice_id: slice.manifest.slice_id.clone(),
|
||||
embedding_backend: embedding_provider.backend_label().to_string(),
|
||||
embedding_model: embedding_provider.model_code(),
|
||||
embedding_dimension: embedding_provider.dimension(),
|
||||
chunk_min_chars: config.chunk_min_chars,
|
||||
chunk_max_chars: config.chunk_max_chars,
|
||||
rerank_enabled: config.rerank,
|
||||
};
|
||||
|
||||
let dir = config
|
||||
.cache_dir
|
||||
.join("snapshots")
|
||||
.join(&metadata.dataset_id)
|
||||
.join(&metadata.slice_id);
|
||||
let metadata_hash = compute_hash(&metadata);
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
dir,
|
||||
metadata_hash,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metadata_hash(&self) -> &str {
|
||||
&self.metadata_hash
|
||||
}
|
||||
|
||||
pub async fn load_db_state(&self) -> Result<Option<DbSnapshotState>> {
|
||||
let path = self.db_state_path();
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading namespace state {}", path.display()))?;
|
||||
let state = serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("deserialising namespace state {}", path.display()))?;
|
||||
Ok(Some(state))
|
||||
}
|
||||
|
||||
pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> {
|
||||
let path = self.db_state_path();
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await.with_context(|| {
|
||||
format!("creating namespace state directory {}", parent.display())
|
||||
})?;
|
||||
}
|
||||
let blob =
|
||||
serde_json::to_vec_pretty(state).context("serialising namespace state payload")?;
|
||||
fs::write(&path, blob)
|
||||
.await
|
||||
.with_context(|| format!("writing namespace state {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn db_dir(&self) -> PathBuf {
|
||||
self.dir.join("db")
|
||||
}
|
||||
|
||||
fn db_state_path(&self) -> PathBuf {
|
||||
self.db_dir().join("state.json")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self {
|
||||
let metadata_hash = compute_hash(&metadata);
|
||||
Self {
|
||||
metadata,
|
||||
dir,
|
||||
metadata_hash,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_hash(metadata: &SnapshotMetadata) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(
|
||||
serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed"),
|
||||
);
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores a `DbSnapshotState` through a `Descriptor` and reads it back,
    /// verifying the fields survive the JSON round trip.
    #[tokio::test]
    async fn state_round_trip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let metadata = SnapshotMetadata {
            dataset_id: "dataset".into(),
            slice_id: "slice".into(),
            embedding_backend: "hashed".into(),
            embedding_model: None,
            embedding_dimension: 128,
            chunk_min_chars: 10,
            chunk_max_chars: 100,
            rerank_enabled: true,
        };
        // `from_parts` skips the real Config/provider wiring, pointing the
        // descriptor at a directory inside the temp dir.
        let descriptor = Descriptor::from_parts(
            metadata,
            temp_dir
                .path()
                .join("snapshots")
                .join("dataset")
                .join("slice"),
        );

        let state = DbSnapshotState {
            dataset_id: "dataset".into(),
            slice_id: "slice".into(),
            ingestion_fingerprint: "fingerprint".into(),
            snapshot_hash: descriptor.metadata_hash().to_string(),
            updated_at: Utc::now(),
            namespace: Some("ns".into()),
            database: Some("db".into()),
            slice_case_count: 42,
        };
        descriptor.store_db_state(&state).await.unwrap();

        let loaded = descriptor.load_db_state().await.unwrap().unwrap();
        assert_eq!(loaded.dataset_id, state.dataset_id);
        assert_eq!(loaded.slice_id, state.slice_id);
        assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint);
        assert_eq!(loaded.snapshot_hash, state.snapshot_hash);
        assert_eq!(loaded.namespace, state.namespace);
        assert_eq!(loaded.database, state.database);
        assert_eq!(loaded.slice_case_count, state.slice_case_count);
        // NOTE(review): `updated_at` is deliberately not asserted —
        // presumably to sidestep timestamp serialisation concerns; confirm.
    }
}
|
||||
@@ -7,7 +7,7 @@ use axum::{
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use chat_handlers::{
|
||||
pub use chat_handlers::{
|
||||
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
|
||||
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat,
|
||||
show_initialized_chat,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
mod handlers;
|
||||
|
||||
use axum::{extract::FromRef, routing::get, Router};
|
||||
use handlers::search_result_handler;
|
||||
pub use handlers::{search_result_handler, SearchParams};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS},
|
||||
};
|
||||
use pipeline::IngestionPipeline;
|
||||
pub use pipeline::{IngestionConfig, IngestionPipeline, IngestionTuning};
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{ingestion_task::IngestionTask, text_content::TextContent},
|
||||
types::{
|
||||
ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
text_content::TextContent,
|
||||
},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::RetrievedEntity;
|
||||
@@ -24,6 +30,14 @@ pub struct PipelineContext<'a> {
|
||||
pub analysis: Option<LLMEnrichmentResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineArtifacts {
|
||||
pub text_content: TextContent,
|
||||
pub entities: Vec<KnowledgeEntity>,
|
||||
pub relationships: Vec<KnowledgeRelationship>,
|
||||
pub chunks: Vec<TextChunk>,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
task: &'a IngestionTask,
|
||||
@@ -73,4 +87,30 @@ impl<'a> PipelineContext<'a> {
|
||||
);
|
||||
err
|
||||
}
|
||||
|
||||
    /// Converts the context's prepared content and LLM analysis into the set
    /// of persistable artifacts (text content, entities, relationships, and
    /// text chunks).
    ///
    /// Consumes the stored content and analysis via `take_*`, so it can only
    /// be called once per pipeline run.
    pub async fn build_artifacts(&mut self) -> Result<PipelineArtifacts, AppError> {
        // Both `take_*` calls move the values out of the context; they are
        // fallible — presumably erroring when an earlier stage has not
        // populated them (confirm against `take_text_content`/`take_analysis`).
        let content = self.take_text_content()?;
        let analysis = self.take_analysis()?;

        // Turn the LLM analysis into graph entities and relationships, using
        // the configured embedding concurrency.
        let (entities, relationships) = self
            .services
            .convert_analysis(
                &content,
                &analysis,
                self.pipeline_config.tuning.entity_embedding_concurrency,
            )
            .await?;

        // Chunk size bounds come straight from the pipeline tuning knobs.
        let chunk_range: Range<usize> = self.pipeline_config.tuning.chunk_min_chars
            ..self.pipeline_config.tuning.chunk_max_chars;

        let chunks = self.services.prepare_chunks(&content, chunk_range).await?;

        Ok(PipelineArtifacts {
            text_content: content,
            entities,
            relationships,
            chunks,
        })
    }
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{IngestionConfig, IngestionTuning};
|
||||
pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship};
|
||||
pub use services::{DefaultPipelineServices, PipelineServices};
|
||||
|
||||
use std::{
|
||||
@@ -31,7 +32,7 @@ use composite_retrieval::reranking::RerankerPool;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use self::{
|
||||
context::PipelineContext,
|
||||
context::{PipelineArtifacts, PipelineContext},
|
||||
stages::{enrich, persist, prepare_content, retrieve_related},
|
||||
state::ready,
|
||||
};
|
||||
@@ -224,6 +225,33 @@ impl IngestionPipeline {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Runs the ingestion pipeline up to (but excluding) persistence and returns the prepared artifacts.
    ///
    /// Drives the same stages as a normal run — content preparation,
    /// related-entity retrieval, enrichment — then materialises the artifacts
    /// via `build_artifacts` instead of persisting them. Any stage failure is
    /// routed through `ctx.abort` before being propagated.
    pub async fn produce_artifacts(
        &self,
        task: &IngestionTask,
    ) -> Result<PipelineArtifacts, AppError> {
        let payload = task.content.clone();
        let mut ctx = PipelineContext::new(
            task,
            self.db.as_ref(),
            &self.pipeline_config,
            self.services.as_ref(),
        );

        // Walk the state machine through the non-persisting stages in order.
        let machine = ready();
        let machine = prepare_content(machine, &mut ctx, payload)
            .await
            .map_err(|err| ctx.abort(err))?;
        let machine = retrieve_related(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;
        // The enriched machine itself is not needed further; only the
        // context's accumulated state feeds `build_artifacts`.
        let _machine = enrich(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;

        ctx.build_artifacts().await.map_err(|err| ctx.abort(err))
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -29,6 +29,8 @@ use crate::utils::llm_instructions::{
|
||||
get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE,
|
||||
};
|
||||
|
||||
const EMBEDDING_QUERY_CHAR_LIMIT: usize = 12_000;
|
||||
|
||||
#[async_trait]
|
||||
pub trait PipelineServices: Send + Sync {
|
||||
async fn prepare_text_content(
|
||||
@@ -162,9 +164,13 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
&self,
|
||||
content: &TextContent,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let truncated_body = truncate_for_embedding(&content.text, EMBEDDING_QUERY_CHAR_LIMIT);
|
||||
let input_text = format!(
|
||||
"content: {}, category: {}, user_context: {:?}",
|
||||
content.text, content.category, content.context
|
||||
"content: {}\n[truncated={}], category: {}, user_context: {:?}",
|
||||
truncated_body,
|
||||
truncated_body.len() < content.text.len(),
|
||||
content.category,
|
||||
content.context
|
||||
);
|
||||
|
||||
let rerank_lease = match &self.reranker_pool {
|
||||
@@ -239,3 +245,19 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
Ok(chunks)
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates `text` to at most `max_chars` characters, appending a single
/// `…` marker when truncation occurred.
///
/// Operates on `char` boundaries rather than bytes, so multi-byte UTF-8
/// input is never split mid-character. Text that already fits is returned
/// unchanged, without the marker.
fn truncate_for_embedding(text: &str, max_chars: usize) -> String {
    if text.chars().count() <= max_chars {
        return text.to_string();
    }

    // `take(max_chars)` replaces the original manual counted loop; note the
    // old `with_capacity(max_chars + 3)` also conflated chars with bytes and
    // under-reserved for multi-byte input, so we let `collect` size it.
    let mut truncated: String = text.chars().take(max_chars).collect();
    truncated.push('…');
    truncated
}
|
||||
|
||||
@@ -7,7 +7,6 @@ use common::{
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
text_content::TextContent,
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -16,8 +15,7 @@ use tokio::time::{sleep, Duration};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use super::{
|
||||
context::PipelineContext,
|
||||
services::PipelineServices,
|
||||
context::{PipelineArtifacts, PipelineContext},
|
||||
state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
|
||||
};
|
||||
|
||||
@@ -134,37 +132,26 @@ pub async fn persist(
|
||||
machine: IngestionMachine<(), Enriched>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<IngestionMachine<(), Persisted>, AppError> {
|
||||
let content = ctx.take_text_content()?;
|
||||
let analysis = ctx.take_analysis()?;
|
||||
|
||||
let (entities, relationships) = ctx
|
||||
.services
|
||||
.convert_analysis(
|
||||
&content,
|
||||
&analysis,
|
||||
ctx.pipeline_config.tuning.entity_embedding_concurrency,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let PipelineArtifacts {
|
||||
text_content,
|
||||
entities,
|
||||
relationships,
|
||||
chunks,
|
||||
} = ctx.build_artifacts().await?;
|
||||
let entity_count = entities.len();
|
||||
let relationship_count = relationships.len();
|
||||
|
||||
let chunk_range =
|
||||
ctx.pipeline_config.tuning.chunk_min_chars..ctx.pipeline_config.tuning.chunk_max_chars;
|
||||
|
||||
let ((), chunk_count) = tokio::try_join!(
|
||||
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
|
||||
store_vector_chunks(
|
||||
ctx.db,
|
||||
ctx.services,
|
||||
ctx.task_id.as_str(),
|
||||
&content,
|
||||
chunk_range,
|
||||
&chunks,
|
||||
&ctx.pipeline_config.tuning
|
||||
)
|
||||
)?;
|
||||
|
||||
ctx.db.store_item(content).await?;
|
||||
ctx.db.store_item(text_content).await?;
|
||||
ctx.db.rebuild_indexes().await?;
|
||||
|
||||
debug!(
|
||||
@@ -252,17 +239,14 @@ async fn store_graph_entities(
|
||||
|
||||
async fn store_vector_chunks(
|
||||
db: &SurrealDbClient,
|
||||
services: &dyn PipelineServices,
|
||||
task_id: &str,
|
||||
content: &TextContent,
|
||||
chunk_range: std::ops::Range<usize>,
|
||||
chunks: &[TextChunk],
|
||||
tuning: &super::config::IngestionTuning,
|
||||
) -> Result<usize, AppError> {
|
||||
let prepared_chunks = services.prepare_chunks(content, chunk_range).await?;
|
||||
let chunk_count = prepared_chunks.len();
|
||||
let chunk_count = chunks.len();
|
||||
|
||||
let batch_size = tuning.chunk_insert_concurrency.max(1);
|
||||
for chunk in &prepared_chunks {
|
||||
for chunk in chunks {
|
||||
debug!(
|
||||
task_id = %task_id,
|
||||
chunk_id = %chunk.id,
|
||||
@@ -271,7 +255,7 @@ async fn store_vector_chunks(
|
||||
);
|
||||
}
|
||||
|
||||
for batch in prepared_chunks.chunks(batch_size) {
|
||||
for batch in chunks.chunks(batch_size) {
|
||||
store_chunk_batch(db, batch, tuning).await?;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user