From 08b1612fcbaf83996e72f70e4a69dee2411c17b8 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Fri, 28 Nov 2025 21:26:51 +0100 Subject: [PATCH] refactored to clap, mrr and ndcg --- CHANGELOG.md | 4 +- Cargo.lock | 2 + eval/Cargo.toml | 1 + eval/src/args.rs | 947 +++++++------------ eval/src/datasets/mod.rs | 9 +- eval/src/eval/pipeline/stages/run_queries.rs | 36 + eval/src/eval/pipeline/stages/summarize.rs | 20 + eval/src/eval/types.rs | 6 + eval/src/main.rs | 6 +- eval/src/report.rs | 18 + retrieval-pipeline/Cargo.toml | 1 + retrieval-pipeline/src/pipeline/config.rs | 2 +- 12 files changed, 434 insertions(+), 618 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cd3fdb3..5d1321a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog ## Unreleased -- Added a shared `benchmarks` crate with deterministic fixtures and Criterion suites for ingestion, retrieval, and HTML handlers, plus documented baseline results for local performance checks. +- Added a benchmarks crate for evaluating the retrieval process +- Added fastembed embedding support, enabling the use of local CPU-generated embeddings +- Embeddings are now stored in their own table ## Version 0.2.6 (2025-10-29) - Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results. 
diff --git a/Cargo.lock b/Cargo.lock index 65389a4..c7588be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2177,6 +2177,7 @@ dependencies = [ "async-openai", "async-trait", "chrono", + "clap", "common", "criterion", "fastembed", @@ -5464,6 +5465,7 @@ dependencies = [ "async-openai", "async-trait", "axum", + "clap", "common", "fastembed", "futures", diff --git a/eval/Cargo.toml b/eval/Cargo.toml index ced47b2..f5d8ca9 100644 --- a/eval/Cargo.toml +++ b/eval/Cargo.toml @@ -28,6 +28,7 @@ once_cell = "1.19" serde_yaml = "0.9" criterion = "0.5" state-machines = { workspace = true } +clap = { version = "4.4", features = ["derive", "env"] } [dev-dependencies] tempfile = { workspace = true } diff --git a/eval/src/args.rs b/eval/src/args.rs index 3106c12..6debf08 100644 --- a/eval/src/args.rs +++ b/eval/src/args.rs @@ -4,6 +4,7 @@ use std::{ }; use anyhow::{anyhow, Context, Result}; +use clap::{Args, Parser, ValueEnum}; use retrieval_pipeline::RetrievalStrategy; use crate::datasets::DatasetKind; @@ -27,7 +28,8 @@ fn default_ingestion_cache_dir() -> PathBuf { pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +#[value(rename_all = "lowercase")] pub enum EmbeddingBackend { Hashed, FastEmbed, @@ -39,33 +41,63 @@ impl Default for EmbeddingBackend { } } -impl std::str::FromStr for EmbeddingBackend { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().as_str() { - "hashed" => Ok(Self::Hashed), - "fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed), - other => Err(anyhow!( - "unknown embedding backend '{other}'. Expected 'hashed' or 'fastembed'." 
- )), +impl std::fmt::Display for EmbeddingBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Hashed => write!(f, "hashed"), + Self::FastEmbed => write!(f, "fastembed"), } } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Args)] pub struct RetrievalSettings { + /// Minimum characters per chunk for text splitting + #[arg(long, default_value_t = 500)] pub chunk_min_chars: usize, + + /// Maximum characters per chunk for text splitting + #[arg(long, default_value_t = 2000)] pub chunk_max_chars: usize, + + /// Override chunk vector candidate cap + #[arg(long)] pub chunk_vector_take: Option, + + /// Override chunk FTS candidate cap + #[arg(long)] pub chunk_fts_take: Option, + + /// Override chunk token budget estimate for assembly + #[arg(long)] pub chunk_token_budget: Option, + + /// Override average characters per token used for budgeting + #[arg(long)] pub chunk_avg_chars_per_token: Option, + + /// Override maximum chunks attached per entity + #[arg(long)] pub max_chunks_per_entity: Option, + + /// Disable the FastEmbed reranking stage + #[arg(long = "no-rerank", action = clap::ArgAction::SetFalse)] pub rerank: bool, + + /// Reranking engine pool size / parallelism + #[arg(long, default_value_t = 16)] pub rerank_pool_size: usize, + + /// Keep top-N entities after reranking + #[arg(long, default_value_t = 10)] pub rerank_keep_top: usize, + + /// Require verified chunks (disable with --llm-mode) + #[arg(skip = true)] pub require_verified_chunks: bool, + + /// Select the retrieval pipeline strategy + #[arg(long, default_value_t = RetrievalStrategy::Initial)] pub strategy: RetrievalStrategy, } @@ -88,622 +120,315 @@ impl Default for RetrievalSettings { } } -#[derive(Debug, Clone)] +#[derive(Parser, Debug, Clone)] +#[command(author, version, about, long_about = None)] pub struct Config { + /// Convert the selected dataset and exit + #[arg(long)] pub convert_only: bool, - pub force_convert: bool, - pub dataset: 
DatasetKind, - pub llm_mode: bool, - pub corpus_limit: Option, - pub raw_dataset_path: PathBuf, - pub converted_dataset_path: PathBuf, - pub report_dir: PathBuf, - pub k: usize, - pub limit: Option, - pub summary_sample: usize, - pub full_context: bool, - pub retrieval: RetrievalSettings, - pub concurrency: usize, - pub embedding_backend: EmbeddingBackend, - pub embedding_model: Option, - pub cache_dir: PathBuf, - pub ingestion_cache_dir: PathBuf, - pub ingestion_batch_size: usize, - pub ingestion_max_retries: usize, - pub refresh_embeddings_only: bool, - pub detailed_report: bool, - pub slice: Option, - pub reseed_slice: bool, - pub slice_seed: u64, - pub slice_grow: Option, - pub slice_offset: usize, - pub slice_reset_ingestion: bool, - pub negative_multiplier: f32, - pub label: Option, - pub chunk_diagnostics_path: Option, - pub inspect_question: Option, - pub inspect_manifest: Option, - pub query_model: Option, - pub perf_log_json: Option, - pub perf_log_dir: Option, - pub perf_log_console: bool, - pub db_endpoint: String, - pub db_username: String, - pub db_password: String, - pub db_namespace: Option, - pub db_database: Option, - pub inspect_db_state: Option, -} -impl Default for Config { - fn default() -> Self { - let dataset = DatasetKind::default(); - Self { - convert_only: false, - force_convert: false, - dataset, - llm_mode: false, - corpus_limit: None, - raw_dataset_path: dataset.default_raw_path(), - converted_dataset_path: dataset.default_converted_path(), - report_dir: default_report_dir(), - k: 5, - limit: Some(200), - summary_sample: 5, - full_context: false, - retrieval: RetrievalSettings::default(), - concurrency: 4, - embedding_backend: EmbeddingBackend::FastEmbed, - embedding_model: None, - cache_dir: default_cache_dir(), - ingestion_cache_dir: default_ingestion_cache_dir(), - ingestion_batch_size: 5, - ingestion_max_retries: 3, - refresh_embeddings_only: false, - detailed_report: false, - slice: None, - reseed_slice: false, - slice_seed: 
DEFAULT_SLICE_SEED, - slice_grow: None, - slice_offset: 0, - slice_reset_ingestion: false, - negative_multiplier: crate::slices::DEFAULT_NEGATIVE_MULTIPLIER, - label: None, - chunk_diagnostics_path: None, - inspect_question: None, - inspect_manifest: None, - query_model: None, - inspect_db_state: None, - perf_log_json: None, - perf_log_dir: None, - perf_log_console: false, - db_endpoint: "ws://127.0.0.1:8000".to_string(), - db_username: "root_user".to_string(), - db_password: "root_password".to_string(), - db_namespace: None, - db_database: None, - } - } + /// Regenerate the converted dataset even if it already exists + #[arg(long, alias = "refresh")] + pub force_convert: bool, + + /// Dataset to evaluate + #[arg(long, default_value_t = DatasetKind::default())] + pub dataset: DatasetKind, + + /// Enable LLM-assisted evaluation features (includes unanswerable cases) + #[arg(long)] + pub llm_mode: bool, + + /// Cap the slice corpus size (positives + negatives) + #[arg(long)] + pub corpus_limit: Option, + + /// Path to the raw dataset (defaults per dataset) + #[arg(long)] + pub raw: Option, + + /// Path to write/read the converted dataset (defaults per dataset) + #[arg(long)] + pub converted: Option, + + /// Directory to write evaluation reports + #[arg(long, default_value_os_t = default_report_dir())] + pub report_dir: PathBuf, + + /// Precision@k cutoff + #[arg(long, default_value_t = 5)] + pub k: usize, + + /// Limit the number of questions evaluated (0 = all) + #[arg(long = "limit", default_value_t = 200)] + pub limit_arg: usize, + + /// Number of mismatches to surface in the Markdown summary + #[arg(long, default_value_t = 5)] + pub sample: usize, + + /// Disable context cropping when converting datasets (ingest entire documents) + #[arg(long)] + pub full_context: bool, + + #[command(flatten)] + pub retrieval: RetrievalSettings, + + /// Concurrency level + #[arg(long, default_value_t = 4)] + pub concurrency: usize, + + /// Embedding backend + #[arg(long, 
default_value_t = EmbeddingBackend::FastEmbed)] + pub embedding_backend: EmbeddingBackend, + + /// FastEmbed model code + #[arg(long)] + pub embedding_model: Option, + + /// Directory for embedding caches + #[arg(long, default_value_os_t = default_cache_dir())] + pub cache_dir: PathBuf, + + /// Directory for ingestion corpora caches + #[arg(long, default_value_os_t = default_ingestion_cache_dir())] + pub ingestion_cache_dir: PathBuf, + + /// Number of paragraphs to ingest concurrently + #[arg(long, default_value_t = 5)] + pub ingestion_batch_size: usize, + + /// Maximum retries for ingestion failures per paragraph + #[arg(long, default_value_t = 3)] + pub ingestion_max_retries: usize, + + /// Recompute embeddings for cached corpora without re-running ingestion + #[arg(long, alias = "refresh-embeddings")] + pub refresh_embeddings_only: bool, + + /// Include entity descriptions and categories in JSON reports + #[arg(long)] + pub detailed_report: bool, + + /// Use a cached dataset slice by id or path + #[arg(long)] + pub slice: Option, + + /// Ignore cached corpus state and rebuild the slice's SurrealDB corpus + #[arg(long)] + pub reseed_slice: bool, + + /// Slice seed + #[arg(skip = DEFAULT_SLICE_SEED)] + pub slice_seed: u64, + + /// Grow the slice ledger to contain at least this many answerable cases, then exit + #[arg(long)] + pub slice_grow: Option, + + /// Evaluate questions starting at this offset within the slice + #[arg(long, default_value_t = 0)] + pub slice_offset: usize, + + /// Delete cached paragraph shards before rebuilding the ingestion corpus + #[arg(long)] + pub slice_reset_ingestion: bool, + + /// Target negative-to-positive paragraph ratio for slice growth + #[arg(long, default_value_t = crate::slices::DEFAULT_NEGATIVE_MULTIPLIER)] + pub negative_multiplier: f32, + + /// Annotate the run; label is stored in JSON/Markdown reports + #[arg(long)] + pub label: Option, + + /// Write per-query chunk diagnostics JSONL to the provided path + #[arg(long, 
alias = "chunk-diagnostics")] + pub chunk_diagnostics_path: Option, + + /// Inspect an ingestion cache question and exit + #[arg(long)] + pub inspect_question: Option, + + /// Path to an ingestion cache manifest JSON for inspection mode + #[arg(long)] + pub inspect_manifest: Option, + + /// Override the SurrealDB system settings query model + #[arg(long)] + pub query_model: Option, + + /// Write structured performance telemetry JSON to the provided path + #[arg(long)] + pub perf_log_json: Option, + + /// Directory that receives timestamped perf JSON copies + #[arg(long)] + pub perf_log_dir: Option, + + /// Print per-stage performance timings to stdout after the run + #[arg(long, alias = "perf-log")] + pub perf_log_console: bool, + + /// SurrealDB server endpoint + #[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")] + pub db_endpoint: String, + + /// SurrealDB root username + #[arg(long, default_value = "root_user", env = "EVAL_DB_USERNAME")] + pub db_username: String, + + /// SurrealDB root password + #[arg(long, default_value = "root_password", env = "EVAL_DB_PASSWORD")] + pub db_password: String, + + /// Override the namespace used on the SurrealDB server + #[arg(long, env = "EVAL_DB_NAMESPACE")] + pub db_namespace: Option, + + /// Override the database used on the SurrealDB server + #[arg(long, env = "EVAL_DB_DATABASE")] + pub db_database: Option, + + /// Path to inspect DB state + #[arg(long)] + pub inspect_db_state: Option, + + // Computed fields (not arguments) + #[arg(skip)] + pub raw_dataset_path: PathBuf, + #[arg(skip)] + pub converted_dataset_path: PathBuf, + #[arg(skip)] + pub limit: Option, + #[arg(skip)] + pub summary_sample: usize, } impl Config { pub fn context_token_limit(&self) -> Option { None } + + pub fn finalize(&mut self) -> Result<()> { + // Handle dataset paths + if let Some(raw) = &self.raw { + self.raw_dataset_path = raw.clone(); + } else { + self.raw_dataset_path = self.dataset.default_raw_path(); + } + + if let 
Some(converted) = &self.converted { + self.converted_dataset_path = converted.clone(); + } else { + self.converted_dataset_path = self.dataset.default_converted_path(); + } + + // Handle limit + if self.limit_arg == 0 { + self.limit = None; + } else { + self.limit = Some(self.limit_arg); + } + + // Handle sample + self.summary_sample = self.sample.max(1); + + // Handle retrieval settings + if self.llm_mode { + self.retrieval.require_verified_chunks = false; + } else { + self.retrieval.require_verified_chunks = true; + } + + // Validations + if self.retrieval.chunk_min_chars >= self.retrieval.chunk_max_chars { + return Err(anyhow!( + "--chunk-min must be less than --chunk-max (got {} >= {})", + self.retrieval.chunk_min_chars, + self.retrieval.chunk_max_chars + )); + } + + if self.retrieval.rerank && self.retrieval.rerank_pool_size == 0 { + return Err(anyhow!( + "--rerank-pool must be greater than zero when reranking is enabled" + )); + } + + if self.concurrency == 0 { + return Err(anyhow!("--concurrency must be greater than zero")); + } + + if self.embedding_backend == EmbeddingBackend::Hashed && self.embedding_model.is_some() { + return Err(anyhow!( + "--embedding-model cannot be used with the 'hashed' embedding backend" + )); + } + + if let Some(query_model) = &self.query_model { + if query_model.trim().is_empty() { + return Err(anyhow!("--query-model requires a non-empty model name")); + } + } + + if let Some(grow) = self.slice_grow { + if grow == 0 { + return Err(anyhow!("--slice-grow must be greater than zero")); + } + } + + if self.negative_multiplier <= 0.0 || !self.negative_multiplier.is_finite() { + return Err(anyhow!( + "--negative-multiplier must be a positive finite number" + )); + } + + // Handle corpus limit logic + if let Some(limit) = self.limit { + if let Some(corpus_limit) = self.corpus_limit { + if corpus_limit < limit { + self.corpus_limit = Some(limit); + } + } else { + let default_multiplier = 10usize; + let mut computed = 
limit.saturating_mul(default_multiplier); + if computed < limit { + computed = limit; + } + let max_cap = 1_000usize; + if computed > max_cap { + computed = max_cap; + } + self.corpus_limit = Some(computed); + } + } + + // Handle perf log dir env var fallback + if self.perf_log_dir.is_none() { + if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") { + if !dir.trim().is_empty() { + self.perf_log_dir = Some(PathBuf::from(dir)); + } + } + } + + Ok(()) + } } -#[derive(Debug)] pub struct ParsedArgs { pub config: Config, pub show_help: bool, } pub fn parse() -> Result { - let mut config = Config::default(); - let mut show_help = false; - let mut raw_overridden = false; - let mut converted_overridden = false; - - let mut args = env::args().skip(1).peekable(); - while let Some(arg) = args.next() { - match arg.as_str() { - "-h" | "--help" => { - show_help = true; - break; - } - "--convert-only" => config.convert_only = true, - "--force" | "--refresh" => config.force_convert = true, - "--llm-mode" => { - config.llm_mode = true; - config.retrieval.require_verified_chunks = false; - } - "--dataset" => { - let value = take_value("--dataset", &mut args)?; - let parsed = value.parse::()?; - config.dataset = parsed; - if !raw_overridden { - config.raw_dataset_path = parsed.default_raw_path(); - } - if !converted_overridden { - config.converted_dataset_path = parsed.default_converted_path(); - } - } - "--slice" => { - let value = take_value("--slice", &mut args)?; - config.slice = Some(value); - } - "--label" => { - let value = take_value("--label", &mut args)?; - config.label = Some(value); - } - "--query-model" => { - let value = take_value("--query-model", &mut args)?; - if value.trim().is_empty() { - return Err(anyhow!("--query-model requires a non-empty model name")); - } - config.query_model = Some(value.trim().to_string()); - } - "--slice-grow" => { - let value = take_value("--slice-grow", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse 
--slice-grow value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--slice-grow must be greater than zero")); - } - config.slice_grow = Some(parsed); - } - "--slice-offset" => { - let value = take_value("--slice-offset", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --slice-offset value '{value}' as usize") - })?; - config.slice_offset = parsed; - } - "--raw" => { - let value = take_value("--raw", &mut args)?; - config.raw_dataset_path = PathBuf::from(value); - raw_overridden = true; - } - "--converted" => { - let value = take_value("--converted", &mut args)?; - config.converted_dataset_path = PathBuf::from(value); - converted_overridden = true; - } - "--corpus-limit" => { - let value = take_value("--corpus-limit", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --corpus-limit value '{value}' as usize") - })?; - config.corpus_limit = if parsed == 0 { None } else { Some(parsed) }; - } - "--reseed-slice" => { - config.reseed_slice = true; - } - "--slice-reset-ingestion" => { - config.slice_reset_ingestion = true; - } - "--report-dir" => { - let value = take_value("--report-dir", &mut args)?; - config.report_dir = PathBuf::from(value); - } - "--k" => { - let value = take_value("--k", &mut args)?; - let parsed = value - .parse::() - .with_context(|| format!("failed to parse --k value '{value}' as usize"))?; - if parsed == 0 { - return Err(anyhow!("--k must be greater than zero")); - } - config.k = parsed; - } - "--limit" => { - let value = take_value("--limit", &mut args)?; - let parsed = value - .parse::() - .with_context(|| format!("failed to parse --limit value '{value}' as usize"))?; - config.limit = if parsed == 0 { None } else { Some(parsed) }; - } - "--sample" => { - let value = take_value("--sample", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --sample value '{value}' as usize") - })?; - config.summary_sample = 
parsed.max(1); - } - "--full-context" => { - config.full_context = true; - } - "--chunk-min" => { - let value = take_value("--chunk-min", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --chunk-min value '{value}' as usize") - })?; - config.retrieval.chunk_min_chars = parsed.max(1); - } - "--chunk-max" => { - let value = take_value("--chunk-max", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --chunk-max value '{value}' as usize") - })?; - config.retrieval.chunk_max_chars = parsed.max(1); - } - "--chunk-vector-take" => { - let value = take_value("--chunk-vector-take", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --chunk-vector-take value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--chunk-vector-take must be greater than zero")); - } - config.retrieval.chunk_vector_take = Some(parsed); - } - "--chunk-fts-take" => { - let value = take_value("--chunk-fts-take", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --chunk-fts-take value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--chunk-fts-take must be greater than zero")); - } - config.retrieval.chunk_fts_take = Some(parsed); - } - "--chunk-token-budget" => { - let value = take_value("--chunk-token-budget", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --chunk-token-budget value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--chunk-token-budget must be greater than zero")); - } - config.retrieval.chunk_token_budget = Some(parsed); - } - "--chunk-token-chars" => { - let value = take_value("--chunk-token-chars", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --chunk-token-chars value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--chunk-token-chars must be greater than zero")); - 
} - config.retrieval.chunk_avg_chars_per_token = Some(parsed); - } - "--retrieval-strategy" => { - let value = take_value("--retrieval-strategy", &mut args)?; - let parsed = value.parse::().map_err(|err| { - anyhow!("failed to parse --retrieval-strategy value '{value}': {err}") - })?; - config.retrieval.strategy = parsed; - } - "--max-chunks-per-entity" => { - let value = take_value("--max-chunks-per-entity", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --max-chunks-per-entity value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--max-chunks-per-entity must be greater than zero")); - } - config.retrieval.max_chunks_per_entity = Some(parsed); - } - "--embedding" => { - let value = take_value("--embedding", &mut args)?; - config.embedding_backend = value.parse()?; - } - "--embedding-model" => { - let value = take_value("--embedding-model", &mut args)?; - config.embedding_model = Some(value.trim().to_string()); - } - "--cache-dir" => { - let value = take_value("--cache-dir", &mut args)?; - config.cache_dir = PathBuf::from(value); - } - "--ingestion-cache-dir" => { - let value = take_value("--ingestion-cache-dir", &mut args)?; - config.ingestion_cache_dir = PathBuf::from(value); - } - "--ingestion-batch-size" => { - let value = take_value("--ingestion-batch-size", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --ingestion-batch-size value '{value}' as usize") - })?; - if parsed == 0 { - return Err(anyhow!("--ingestion-batch-size must be greater than zero")); - } - config.ingestion_batch_size = parsed; - } - "--ingestion-max-retries" => { - let value = take_value("--ingestion-max-retries", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --ingestion-max-retries value '{value}' as usize") - })?; - config.ingestion_max_retries = parsed; - } - "--negative-multiplier" => { - let value = take_value("--negative-multiplier", &mut 
args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --negative-multiplier value '{value}' as f32") - })?; - if !(parsed.is_finite() && parsed > 0.0) { - return Err(anyhow!( - "--negative-multiplier must be a positive finite number" - )); - } - config.negative_multiplier = parsed; - } - "--no-rerank" => { - config.retrieval.rerank = false; - } - "--rerank-pool" => { - let value = take_value("--rerank-pool", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --rerank-pool value '{value}' as usize") - })?; - config.retrieval.rerank_pool_size = parsed.max(1); - } - "--rerank-keep" => { - let value = take_value("--rerank-keep", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --rerank-keep value '{value}' as usize") - })?; - config.retrieval.rerank_keep_top = parsed.max(1); - } - "--concurrency" => { - let value = take_value("--concurrency", &mut args)?; - let parsed = value.parse::().with_context(|| { - format!("failed to parse --concurrency value '{value}' as usize") - })?; - config.concurrency = parsed.max(1); - } - "--refresh-embeddings" => { - config.refresh_embeddings_only = true; - } - "--detailed-report" => { - config.detailed_report = true; - } - "--chunk-diagnostics" => { - let value = take_value("--chunk-diagnostics", &mut args)?; - config.chunk_diagnostics_path = Some(PathBuf::from(value)); - } - "--inspect-question" => { - let value = take_value("--inspect-question", &mut args)?; - config.inspect_question = Some(value); - } - "--inspect-manifest" => { - let value = take_value("--inspect-manifest", &mut args)?; - config.inspect_manifest = Some(PathBuf::from(value)); - } - "--inspect-db-state" => { - let value = take_value("--inspect-db-state", &mut args)?; - config.inspect_db_state = Some(PathBuf::from(value)); - } - "--perf-log-json" => { - let value = take_value("--perf-log-json", &mut args)?; - config.perf_log_json = Some(PathBuf::from(value)); - 
} - "--perf-log-dir" => { - let value = take_value("--perf-log-dir", &mut args)?; - config.perf_log_dir = Some(PathBuf::from(value)); - } - "--perf-log" => { - config.perf_log_console = true; - } - "--db-endpoint" => { - let value = take_value("--db-endpoint", &mut args)?; - config.db_endpoint = value; - } - "--db-user" => { - let value = take_value("--db-user", &mut args)?; - config.db_username = value; - } - "--db-pass" => { - let value = take_value("--db-pass", &mut args)?; - config.db_password = value; - } - "--db-namespace" => { - let value = take_value("--db-namespace", &mut args)?; - config.db_namespace = Some(value); - } - "--db-database" => { - let value = take_value("--db-database", &mut args)?; - config.db_database = Some(value); - } - unknown => { - return Err(anyhow!( - "unknown argument '{unknown}'. Use --help to see available options." - )); - } - } - } - - if config.retrieval.chunk_min_chars >= config.retrieval.chunk_max_chars { - return Err(anyhow!( - "--chunk-min must be less than --chunk-max (got {} >= {})", - config.retrieval.chunk_min_chars, - config.retrieval.chunk_max_chars - )); - } - - if config.retrieval.rerank && config.retrieval.rerank_pool_size == 0 { - return Err(anyhow!( - "--rerank-pool must be greater than zero when reranking is enabled" - )); - } - - if config.concurrency == 0 { - return Err(anyhow!("--concurrency must be greater than zero")); - } - - if config.embedding_backend == EmbeddingBackend::Hashed && config.embedding_model.is_some() { - return Err(anyhow!( - "--embedding-model cannot be used with the 'hashed' embedding backend" - )); - } - - if let Some(limit) = config.limit { - if let Some(corpus_limit) = config.corpus_limit { - if corpus_limit < limit { - config.corpus_limit = Some(limit); - } - } else { - let default_multiplier = 10usize; - let mut computed = limit.saturating_mul(default_multiplier); - if computed < limit { - computed = limit; - } - let max_cap = 1_000usize; - if computed > max_cap { - computed = 
max_cap; - } - config.corpus_limit = Some(computed); - } - } - - if config.perf_log_dir.is_none() { - if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") { - if !dir.trim().is_empty() { - config.perf_log_dir = Some(PathBuf::from(dir)); - } - } - } - - if let Ok(endpoint) = env::var("EVAL_DB_ENDPOINT") { - if !endpoint.trim().is_empty() { - config.db_endpoint = endpoint; - } - } - if let Ok(username) = env::var("EVAL_DB_USERNAME") { - if !username.trim().is_empty() { - config.db_username = username; - } - } - if let Ok(password) = env::var("EVAL_DB_PASSWORD") { - if !password.trim().is_empty() { - config.db_password = password; - } - } - if let Ok(ns) = env::var("EVAL_DB_NAMESPACE") { - if !ns.trim().is_empty() { - config.db_namespace = Some(ns); - } - } - if let Ok(db) = env::var("EVAL_DB_DATABASE") { - if !db.trim().is_empty() { - config.db_database = Some(db); - } - } - Ok(ParsedArgs { config, show_help }) -} - -fn take_value<'a, I>(flag: &str, iter: &mut std::iter::Peekable) -> Result -where - I: Iterator, -{ - iter.next().ok_or_else(|| anyhow!("{flag} expects a value")) -} - -pub fn print_help() { - let report_default = default_report_dir(); - let cache_default = default_cache_dir(); - let ingestion_cache_default = default_ingestion_cache_dir(); - let report_default_display = report_default.display(); - let cache_default_display = cache_default.display(); - let ingestion_cache_default_display = ingestion_cache_default.display(); - - println!( - "\ -eval — dataset conversion, ingestion, and retrieval evaluation CLI - -USAGE: - cargo eval -- [options] - # or - cargo run -p eval -- [options] - -OPTIONS: - --convert-only Convert the selected dataset and exit. - --force, --refresh Regenerate the converted dataset even if it already exists. - --dataset Dataset to evaluate: 'squad' (default) or 'natural-questions'. - --llm-mode Enable LLM-assisted evaluation features (includes unanswerable cases). 
- --slice Use a cached dataset slice by id (under eval/cache/slices) or by explicit path. - --label Annotate the run; label is stored in JSON/Markdown reports. - --query-model Override the SurrealDB system settings query model (e.g., gpt-4o-mini) for this run. - --slice-grow Grow the slice ledger to contain at least this many answerable cases, then exit. - --slice-offset Evaluate questions starting at this offset within the slice (default: 0). - --reseed-slice Ignore cached corpus state and rebuild the slice's SurrealDB corpus. - --slice-reset-ingestion - Delete cached paragraph shards before rebuilding the ingestion corpus. - --corpus-limit Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000. - --raw Path to the raw dataset (defaults per dataset). - --converted Path to write/read the converted dataset (defaults per dataset). - --report-dir Directory to write evaluation reports (default: {report_default_display}). - --k Precision@k cutoff (default: 5). - --limit Limit the number of questions evaluated (default: 200, 0 = all). - --sample Number of mismatches to surface in the Markdown summary (default: 5). - --full-context Disable context cropping when converting datasets (ingest entire documents). - --chunk-min Minimum characters per chunk for text splitting (default: 500). - --chunk-max Maximum characters per chunk for text splitting (default: 2000). - --chunk-vector-take - Override chunk vector candidate cap (default: 20). - --chunk-fts-take - Override chunk FTS candidate cap (default: 20). - --chunk-token-budget - Override chunk token budget estimate for assembly (default: 10000). - --chunk-token-chars - Override average characters per token used for budgeting (default: 4). - --retrieval-strategy - Select the retrieval pipeline strategy (default: initial). - --max-chunks-per-entity - Override maximum chunks attached per entity (default: 4). - --embedding Embedding backend: 'fastembed' (default) or 'hashed'. 
- --embedding-model - FastEmbed model code (defaults to crate preset when omitted). - --cache-dir Directory for embedding caches (default: {cache_default_display}). - --ingestion-cache-dir - Directory for ingestion corpora caches (default: {ingestion_cache_default_display}). - --ingestion-batch-size - Number of paragraphs to ingest concurrently (default: 5). - --ingestion-max-retries - Maximum retries for ingestion failures per paragraph (default: 3). - --negative-multiplier - Target negative-to-positive paragraph ratio for slice growth (default: 4.0). - --refresh-embeddings Recompute embeddings for cached corpora without re-running ingestion. - --detailed-report Include entity descriptions and categories in JSON reports. - --chunk-diagnostics - Write per-query chunk diagnostics JSONL to the provided path. - --no-rerank Disable the FastEmbed reranking stage (enabled by default). - --rerank-pool Reranking engine pool size / parallelism (default: 16). - --rerank-keep Keep top-N entities after reranking (default: 10). - --inspect-question - Inspect an ingestion cache question and exit (requires --inspect-manifest). - --inspect-manifest - Path to an ingestion cache manifest JSON for inspection mode. - --inspect-db-state - Optional override for the SurrealDB state.json used during inspection; defaults to the state recorded for the selected dataset slice. - --db-endpoint SurrealDB server endpoint (use http:// or https:// to enable SurQL export/import; ws:// endpoints reuse existing namespaces but skip SurQL exports; default: ws://127.0.0.1:8000). - --db-user SurrealDB root username (default: root_user). - --db-pass SurrealDB root password (default: root_password). - --db-namespace Override the namespace used on the SurrealDB server; state.json tracks this value and the ledger case count so changing it or requesting more cases via --limit triggers a rebuild/import (default: derived from dataset). 
- --db-database Override the database used on the SurrealDB server; recorded alongside namespace in state.json (default: derived from slice). - --perf-log Print per-stage performance timings to stdout after the run. - --perf-log-json - Write structured performance telemetry JSON to the provided path. - --perf-log-dir - Directory that receives timestamped perf JSON copies (defaults to $EVAL_PERF_LOG_DIR). - -Examples: - cargo eval -- --dataset squad --limit 10 --detailed-report - cargo eval -- --dataset natural-questions --limit 1 --rerank-pool 1 --detailed-report - -Notes: - The latest run's JSON/Markdown reports are saved as eval/reports/latest.json and latest.md, making it easy to script automated checks. - -h, --help Show this help text. - -Dataset defaults (from eval/manifest.yaml): - squad raw: eval/data/raw/squad/dev-v2.0.json - converted: eval/data/converted/squad-minne.json - natural-questions raw: eval/data/raw/nq/dev-all.jsonl - converted: eval/data/converted/nq-dev-minne.json -" - ); + let mut config = Config::parse(); + config.finalize()?; + Ok(ParsedArgs { + config, + show_help: false, // Clap handles help automatically + }) } pub fn ensure_parent(path: &Path) -> Result<()> { diff --git a/eval/src/datasets/mod.rs b/eval/src/datasets/mod.rs index 6d106cb..84991e3 100644 --- a/eval/src/datasets/mod.rs +++ b/eval/src/datasets/mod.rs @@ -13,6 +13,7 @@ use chrono::{DateTime, TimeZone, Utc}; use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use tracing::warn; +use clap::ValueEnum; const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"); static DATASET_CATALOG: OnceCell = OnceCell::new(); @@ -243,7 +244,7 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> { catalog.dataset(kind.id()) } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] pub enum DatasetKind { SquadV2, NaturalQuestions, @@ -298,6 +299,12 @@ impl DatasetKind { } } +impl 
std::fmt::Display for DatasetKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.id()) + } +} + impl Default for DatasetKind { fn default() -> Self { Self::SquadV2 diff --git a/eval/src/eval/pipeline/stages/run_queries.rs b/eval/src/eval/pipeline/stages/run_queries.rs index d530a10..4e12b8b 100644 --- a/eval/src/eval/pipeline/stages/run_queries.rs +++ b/eval/src/eval/pipeline/stages/run_queries.rs @@ -272,6 +272,9 @@ pub(crate) async fn run_queries( } let overall_match = match_rank.is_some(); + let reciprocal_rank = calculate_reciprocal_rank(match_rank); + let ndcg = calculate_ndcg(&retrieved, config.k); + let summary = CaseSummary { question_id, question, @@ -286,6 +289,8 @@ pub(crate) async fn run_queries( is_impossible, has_verified_chunks, match_rank, + reciprocal_rank: Some(reciprocal_rank), + ndcg: Some(ndcg), latency_ms: query_latency, retrieved, }; @@ -359,3 +364,34 @@ pub(crate) async fn run_queries( .run_queries() .map_err(|(_, guard)| map_guard_error("run_queries", guard)) } + +fn calculate_reciprocal_rank(rank: Option) -> f64 { + match rank { + Some(r) if r > 0 => 1.0 / (r as f64), + _ => 0.0, + } +} + +fn calculate_ndcg(retrieved: &[RetrievedSummary], k: usize) -> f64 { + let mut dcg = 0.0; + + for (i, item) in retrieved.iter().enumerate() { + if i >= k { + break; + } + if item.matched { + let rel = 1.0; + dcg += rel / (i as f64 + 2.0).log2(); + } + } + + // IDCG for a single relevant item at rank 1 is 1.0 / log2(2) = 1.0 + let idcg = 1.0; + + if dcg == 0.0 { + 0.0 + } else { + dcg / idcg + } +} + diff --git a/eval/src/eval/pipeline/stages/summarize.rs b/eval/src/eval/pipeline/stages/summarize.rs index 341b84b..3d5c851 100644 --- a/eval/src/eval/pipeline/stages/summarize.rs +++ b/eval/src/eval/pipeline/stages/summarize.rs @@ -45,6 +45,8 @@ pub(crate) async fn summarize( let mut retrieval_cases = 0usize; let mut llm_cases = 0usize; let mut llm_answered = 0usize; + let mut sum_reciprocal_rank = 0.0; + let 
mut sum_ndcg = 0.0; for summary in &summaries { if summary.is_impossible { llm_cases += 1; @@ -54,6 +56,12 @@ pub(crate) async fn summarize( continue; } retrieval_cases += 1; + if let Some(rr) = summary.reciprocal_rank { + sum_reciprocal_rank += rr; + } + if let Some(ndcg) = summary.ndcg { + sum_ndcg += ndcg; + } if summary.matched { correct += 1; if let Some(rank) = summary.match_rank { @@ -99,6 +107,16 @@ pub(crate) async fn summarize( } else { (correct_at_3 as f64) / (retrieval_cases as f64) }; + let mrr = if retrieval_cases == 0 { + 0.0 + } else { + sum_reciprocal_rank / (retrieval_cases as f64) + }; + let average_ndcg = if retrieval_cases == 0 { + 0.0 + } else { + sum_ndcg / (retrieval_cases as f64) + }; let active_tuning = ctx .retrieval_config @@ -131,6 +149,8 @@ pub(crate) async fn summarize( precision_at_1, precision_at_2, precision_at_3, + mrr, + average_ndcg, duration_ms, dataset_id: dataset.metadata.id.clone(), dataset_label: dataset.metadata.label.clone(), diff --git a/eval/src/eval/types.rs b/eval/src/eval/types.rs index 57a6d6b..0395992 100644 --- a/eval/src/eval/types.rs +++ b/eval/src/eval/types.rs @@ -23,6 +23,8 @@ pub struct EvaluationSummary { pub precision_at_1: f64, pub precision_at_2: f64, pub precision_at_3: f64, + pub mrr: f64, + pub average_ndcg: f64, pub duration_ms: u128, pub dataset_id: String, pub dataset_label: String, @@ -90,6 +92,10 @@ pub struct CaseSummary { pub has_verified_chunks: bool, #[serde(skip_serializing_if = "Option::is_none")] pub match_rank: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reciprocal_rank: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ndcg: Option, pub latency_ms: u128, pub retrieved: Vec, } diff --git a/eval/src/main.rs b/eval/src/main.rs index 653dacb..f312acf 100644 --- a/eval/src/main.rs +++ b/eval/src/main.rs @@ -92,10 +92,8 @@ async fn async_main() -> anyhow::Result<()> { let parsed = args::parse()?; - if parsed.show_help { - args::print_help(); - return 
Ok(()); - } + // Clap handles help automatically, so we don't need to check for it manually + if parsed.config.inspect_question.is_some() { inspection::inspect_question(&parsed.config).await?; diff --git a/eval/src/report.rs b/eval/src/report.rs index 9c6e995..92792cd 100644 --- a/eval/src/report.rs +++ b/eval/src/report.rs @@ -78,6 +78,8 @@ pub struct RetrievalSection { pub precision_at_1: f64, pub precision_at_2: f64, pub precision_at_3: f64, + pub mrr: f64, + pub average_ndcg: f64, pub latency: LatencyStats, pub concurrency: usize, pub strategy: String, @@ -178,6 +180,8 @@ impl EvaluationReport { precision_at_1: summary.precision_at_1, precision_at_2: summary.precision_at_2, precision_at_3: summary.precision_at_3, + mrr: summary.mrr, + average_ndcg: summary.average_ndcg, latency: summary.latency_ms.clone(), concurrency: summary.concurrency, strategy: summary.retrieval_strategy.clone(), @@ -435,6 +439,14 @@ fn render_markdown(report: &EvaluationReport) -> String { report.retrieval.precision_at_2, report.retrieval.precision_at_3 )); + md.push_str(&format!( + "| MRR | {:.3} |\\n", + report.retrieval.mrr + )); + md.push_str(&format!( + "| NDCG | {:.3} |\\n", + report.retrieval.average_ndcg + )); md.push_str(&format!( "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95 @@ -687,6 +699,10 @@ struct HistoryEntry { precision_at_2: f64, precision_at_3: f64, #[serde(default)] + mrr: f64, + #[serde(default)] + average_ndcg: f64, + #[serde(default)] retrieval_cases: usize, #[serde(default)] retrieval_precision: f64, @@ -771,6 +787,8 @@ fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> precision_at_1: summary.precision_at_1, precision_at_2: summary.precision_at_2, precision_at_3: summary.precision_at_3, + mrr: summary.mrr, + average_ndcg: summary.average_ndcg, retrieval_cases: summary.retrieval_cases, retrieval_precision: summary.retrieval_precision, 
llm_cases: summary.llm_cases, diff --git a/retrieval-pipeline/Cargo.toml b/retrieval-pipeline/Cargo.toml index ad792ce..8b65f85 100644 --- a/retrieval-pipeline/Cargo.toml +++ b/retrieval-pipeline/Cargo.toml @@ -21,5 +21,6 @@ async-openai = { workspace = true } async-trait = { workspace = true } uuid = { workspace = true } fastembed = { workspace = true } +clap = { version = "4.4", features = ["derive"] } common = { path = "../common", features = ["test-utils"] } diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index f5bdb1a..40d2650 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)] #[serde(rename_all = "snake_case")] pub enum RetrievalStrategy { Initial,