From bd519ab2699a108792f826a14c74625c5b93cfb2 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Tue, 18 Nov 2025 22:51:06 +0100 Subject: [PATCH] benchmarks: v2 Minor refactor --- eval/Cargo.toml | 2 +- eval/src/args.rs | 125 +- eval/src/datasets.rs | 1003 ----------------- eval/src/datasets/mod.rs | 493 ++++++++ eval/src/datasets/nq.rs | 234 ++++ eval/src/datasets/squad.rs | 107 ++ eval/src/eval/mod.rs | 469 +++----- eval/src/eval/pipeline/context.rs | 4 +- .../eval/pipeline/stages/prepare_namespace.rs | 19 +- eval/src/eval/pipeline/stages/run_queries.rs | 104 +- eval/src/eval/pipeline/stages/summarize.rs | 52 +- eval/src/eval/types.rs | 396 +++++++ eval/src/ingest/config.rs | 72 ++ eval/src/ingest/mod.rs | 10 + .../src/{ingest.rs => ingest/orchestrator.rs} | 446 +------- eval/src/ingest/store.rs | 299 +++++ eval/src/main.rs | 39 +- eval/src/perf.rs | 33 +- eval/src/report.rs | 866 +++++++++++--- eval/src/slice.rs | 1 + eval/src/slices.rs | 49 +- eval/src/snapshot.rs | 6 +- 22 files changed, 2794 insertions(+), 2035 deletions(-) delete mode 100644 eval/src/datasets.rs create mode 100644 eval/src/datasets/mod.rs create mode 100644 eval/src/datasets/nq.rs create mode 100644 eval/src/datasets/squad.rs create mode 100644 eval/src/eval/types.rs create mode 100644 eval/src/ingest/config.rs create mode 100644 eval/src/ingest/mod.rs rename eval/src/{ingest.rs => ingest/orchestrator.rs} (59%) create mode 100644 eval/src/ingest/store.rs diff --git a/eval/Cargo.toml b/eval/Cargo.toml index 82d0319..ced47b2 100644 --- a/eval/Cargo.toml +++ b/eval/Cargo.toml @@ -8,7 +8,7 @@ anyhow = { workspace = true } async-openai = { workspace = true } chrono = { workspace = true } common = { path = "../common" } -composite-retrieval = { path = "../composite-retrieval" } +retrieval-pipeline = { path = "../retrieval-pipeline" } ingestion-pipeline = { path = "../ingestion-pipeline" } futures = { workspace = true } fastembed = { workspace = true } diff --git a/eval/src/args.rs b/eval/src/args.rs 
index c42a20e..d1d132c 100644 --- a/eval/src/args.rs +++ b/eval/src/args.rs @@ -4,6 +4,7 @@ use std::{ }; use anyhow::{anyhow, Context, Result}; +use retrieval_pipeline::RetrievalStrategy; use crate::datasets::DatasetKind; @@ -35,6 +36,41 @@ impl std::str::FromStr for EmbeddingBackend { } } +#[derive(Debug, Clone)] +pub struct RetrievalSettings { + pub chunk_min_chars: usize, + pub chunk_max_chars: usize, + pub chunk_vector_take: Option, + pub chunk_fts_take: Option, + pub chunk_token_budget: Option, + pub chunk_avg_chars_per_token: Option, + pub max_chunks_per_entity: Option, + pub rerank: bool, + pub rerank_pool_size: usize, + pub rerank_keep_top: usize, + pub require_verified_chunks: bool, + pub strategy: RetrievalStrategy, +} + +impl Default for RetrievalSettings { + fn default() -> Self { + Self { + chunk_min_chars: 500, + chunk_max_chars: 2_000, + chunk_vector_take: None, + chunk_fts_take: None, + chunk_token_budget: None, + chunk_avg_chars_per_token: None, + max_chunks_per_entity: None, + rerank: true, + rerank_pool_size: 16, + rerank_keep_top: 10, + require_verified_chunks: true, + strategy: RetrievalStrategy::Initial, + } + } +} + #[derive(Debug, Clone)] pub struct Config { pub convert_only: bool, @@ -49,21 +85,14 @@ pub struct Config { pub limit: Option, pub summary_sample: usize, pub full_context: bool, - pub chunk_min_chars: usize, - pub chunk_max_chars: usize, - pub chunk_vector_take: Option, - pub chunk_fts_take: Option, - pub chunk_token_budget: Option, - pub chunk_avg_chars_per_token: Option, - pub max_chunks_per_entity: Option, - pub rerank: bool, - pub rerank_pool_size: usize, - pub rerank_keep_top: usize, + pub retrieval: RetrievalSettings, pub concurrency: usize, pub embedding_backend: EmbeddingBackend, pub embedding_model: Option, pub cache_dir: PathBuf, pub ingestion_cache_dir: PathBuf, + pub ingestion_batch_size: usize, + pub ingestion_max_retries: usize, pub refresh_embeddings_only: bool, pub detailed_report: bool, pub slice: Option, @@ 
-105,21 +134,14 @@ impl Default for Config { limit: Some(200), summary_sample: 5, full_context: false, - chunk_min_chars: 500, - chunk_max_chars: 2_000, - chunk_vector_take: None, - chunk_fts_take: None, - chunk_token_budget: None, - chunk_avg_chars_per_token: None, - max_chunks_per_entity: None, - rerank: true, - rerank_pool_size: 16, - rerank_keep_top: 10, + retrieval: RetrievalSettings::default(), concurrency: 4, embedding_backend: EmbeddingBackend::FastEmbed, embedding_model: None, cache_dir: PathBuf::from("eval/cache"), ingestion_cache_dir: PathBuf::from("eval/cache/ingested"), + ingestion_batch_size: 5, + ingestion_max_retries: 3, refresh_embeddings_only: false, detailed_report: false, slice: None, @@ -176,6 +198,7 @@ pub fn parse() -> Result { "--force" | "--refresh" => config.force_convert = true, "--llm-mode" => { config.llm_mode = true; + config.retrieval.require_verified_chunks = false; } "--dataset" => { let value = take_value("--dataset", &mut args)?; @@ -279,14 +302,14 @@ pub fn parse() -> Result { let parsed = value.parse::().with_context(|| { format!("failed to parse --chunk-min value '{value}' as usize") })?; - config.chunk_min_chars = parsed.max(1); + config.retrieval.chunk_min_chars = parsed.max(1); } "--chunk-max" => { let value = take_value("--chunk-max", &mut args)?; let parsed = value.parse::().with_context(|| { format!("failed to parse --chunk-max value '{value}' as usize") })?; - config.chunk_max_chars = parsed.max(1); + config.retrieval.chunk_max_chars = parsed.max(1); } "--chunk-vector-take" => { let value = take_value("--chunk-vector-take", &mut args)?; @@ -296,7 +319,7 @@ pub fn parse() -> Result { if parsed == 0 { return Err(anyhow!("--chunk-vector-take must be greater than zero")); } - config.chunk_vector_take = Some(parsed); + config.retrieval.chunk_vector_take = Some(parsed); } "--chunk-fts-take" => { let value = take_value("--chunk-fts-take", &mut args)?; @@ -306,7 +329,7 @@ pub fn parse() -> Result { if parsed == 0 { return 
Err(anyhow!("--chunk-fts-take must be greater than zero")); } - config.chunk_fts_take = Some(parsed); + config.retrieval.chunk_fts_take = Some(parsed); } "--chunk-token-budget" => { let value = take_value("--chunk-token-budget", &mut args)?; @@ -316,7 +339,7 @@ pub fn parse() -> Result { if parsed == 0 { return Err(anyhow!("--chunk-token-budget must be greater than zero")); } - config.chunk_token_budget = Some(parsed); + config.retrieval.chunk_token_budget = Some(parsed); } "--chunk-token-chars" => { let value = take_value("--chunk-token-chars", &mut args)?; @@ -326,7 +349,14 @@ pub fn parse() -> Result { if parsed == 0 { return Err(anyhow!("--chunk-token-chars must be greater than zero")); } - config.chunk_avg_chars_per_token = Some(parsed); + config.retrieval.chunk_avg_chars_per_token = Some(parsed); + } + "--retrieval-strategy" => { + let value = take_value("--retrieval-strategy", &mut args)?; + let parsed = value.parse::().map_err(|err| { + anyhow!("failed to parse --retrieval-strategy value '{value}': {err}") + })?; + config.retrieval.strategy = parsed; } "--max-chunks-per-entity" => { let value = take_value("--max-chunks-per-entity", &mut args)?; @@ -336,7 +366,7 @@ pub fn parse() -> Result { if parsed == 0 { return Err(anyhow!("--max-chunks-per-entity must be greater than zero")); } - config.max_chunks_per_entity = Some(parsed); + config.retrieval.max_chunks_per_entity = Some(parsed); } "--embedding" => { let value = take_value("--embedding", &mut args)?; @@ -354,6 +384,23 @@ pub fn parse() -> Result { let value = take_value("--ingestion-cache-dir", &mut args)?; config.ingestion_cache_dir = PathBuf::from(value); } + "--ingestion-batch-size" => { + let value = take_value("--ingestion-batch-size", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --ingestion-batch-size value '{value}' as usize") + })?; + if parsed == 0 { + return Err(anyhow!("--ingestion-batch-size must be greater than zero")); + } + 
config.ingestion_batch_size = parsed; + } + "--ingestion-max-retries" => { + let value = take_value("--ingestion-max-retries", &mut args)?; + let parsed = value.parse::().with_context(|| { + format!("failed to parse --ingestion-max-retries value '{value}' as usize") + })?; + config.ingestion_max_retries = parsed; + } "--negative-multiplier" => { let value = take_value("--negative-multiplier", &mut args)?; let parsed = value.parse::().with_context(|| { @@ -367,21 +414,21 @@ pub fn parse() -> Result { config.negative_multiplier = parsed; } "--no-rerank" => { - config.rerank = false; + config.retrieval.rerank = false; } "--rerank-pool" => { let value = take_value("--rerank-pool", &mut args)?; let parsed = value.parse::().with_context(|| { format!("failed to parse --rerank-pool value '{value}' as usize") })?; - config.rerank_pool_size = parsed.max(1); + config.retrieval.rerank_pool_size = parsed.max(1); } "--rerank-keep" => { let value = take_value("--rerank-keep", &mut args)?; let parsed = value.parse::().with_context(|| { format!("failed to parse --rerank-keep value '{value}' as usize") })?; - config.rerank_keep_top = parsed.max(1); + config.retrieval.rerank_keep_top = parsed.max(1); } "--concurrency" => { let value = take_value("--concurrency", &mut args)?; @@ -451,15 +498,15 @@ pub fn parse() -> Result { } } - if config.chunk_min_chars >= config.chunk_max_chars { + if config.retrieval.chunk_min_chars >= config.retrieval.chunk_max_chars { return Err(anyhow!( "--chunk-min must be less than --chunk-max (got {} >= {})", - config.chunk_min_chars, - config.chunk_max_chars + config.retrieval.chunk_min_chars, + config.retrieval.chunk_max_chars )); } - if config.rerank && config.rerank_pool_size == 0 { + if config.retrieval.rerank && config.retrieval.rerank_pool_size == 0 { return Err(anyhow!( "--rerank-pool must be greater than zero when reranking is enabled" )); @@ -578,14 +625,20 @@ OPTIONS: Override chunk token budget estimate for assembly (default: 10000). 
--chunk-token-chars Override average characters per token used for budgeting (default: 4). + --retrieval-strategy + Select the retrieval pipeline strategy (default: initial). --max-chunks-per-entity Override maximum chunks attached per entity (default: 4). --embedding Embedding backend: 'fastembed' (default) or 'hashed'. --embedding-model FastEmbed model code (defaults to crate preset when omitted). --cache-dir Directory for embedding caches (default: eval/cache). - --ingestion-cache-dir - Directory for ingestion corpora caches (default: eval/cache/ingested). + --ingestion-cache-dir + Directory for ingestion corpora caches (default: eval/cache/ingested). + --ingestion-batch-size + Number of paragraphs to ingest concurrently (default: 5). + --ingestion-max-retries + Maximum retries for ingestion failures per paragraph (default: 3). --negative-multiplier Target negative-to-positive paragraph ratio for slice growth (default: 4.0). --refresh-embeddings Recompute embeddings for cached corpora without re-running ingestion. 
diff --git a/eval/src/datasets.rs b/eval/src/datasets.rs deleted file mode 100644 index c1f9271..0000000 --- a/eval/src/datasets.rs +++ /dev/null @@ -1,1003 +0,0 @@ -use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, - fs::{self, File}, - io::{BufRead, BufReader}, - path::{Path, PathBuf}, - str::FromStr, -}; - -use anyhow::{anyhow, bail, Context, Result}; -use chrono::{TimeZone, Utc}; -use once_cell::sync::OnceCell; -use serde::{Deserialize, Serialize}; -use tracing::warn; - -const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"); -static DATASET_CATALOG: OnceCell = OnceCell::new(); - -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct DatasetCatalog { - datasets: BTreeMap, - slices: HashMap, - default_dataset: String, -} - -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct DatasetEntry { - pub metadata: DatasetMetadata, - pub raw_path: PathBuf, - pub converted_path: PathBuf, - pub include_unanswerable: bool, - pub slices: Vec, -} - -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub struct SliceEntry { - pub id: String, - pub dataset_id: String, - pub label: String, - pub description: Option, - pub limit: Option, - pub corpus_limit: Option, - pub include_unanswerable: Option, - pub seed: Option, -} - -#[derive(Debug, Clone)] -#[allow(dead_code)] -struct SliceLocation { - dataset_id: String, - slice_index: usize, -} - -#[derive(Debug, Deserialize)] -struct ManifestFile { - default_dataset: Option, - datasets: Vec, -} - -#[derive(Debug, Deserialize)] -struct ManifestDataset { - id: String, - label: String, - category: String, - #[serde(default)] - entity_suffix: Option, - #[serde(default)] - source_prefix: Option, - raw: String, - converted: String, - #[serde(default)] - include_unanswerable: bool, - #[serde(default)] - slices: Vec, -} - -#[derive(Debug, Deserialize)] -struct ManifestSlice { - id: String, - label: String, - #[serde(default)] - description: Option, - #[serde(default)] - limit: Option, - #[serde(default)] - 
corpus_limit: Option, - #[serde(default)] - include_unanswerable: Option, - #[serde(default)] - seed: Option, -} - -impl DatasetCatalog { - pub fn load() -> Result { - let manifest_raw = fs::read_to_string(MANIFEST_PATH) - .with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?; - let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw) - .with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?; - - let root = Path::new(env!("CARGO_MANIFEST_DIR")); - let mut datasets = BTreeMap::new(); - let mut slices = HashMap::new(); - - for dataset in manifest.datasets { - let raw_path = resolve_path(root, &dataset.raw); - let converted_path = resolve_path(root, &dataset.converted); - - if !raw_path.exists() { - bail!( - "dataset '{}' raw file missing at {}", - dataset.id, - raw_path.display() - ); - } - if !converted_path.exists() { - warn!( - "dataset '{}' converted file missing at {}; the next conversion run will regenerate it", - dataset.id, - converted_path.display() - ); - } - - let metadata = DatasetMetadata { - id: dataset.id.clone(), - label: dataset.label.clone(), - category: dataset.category.clone(), - entity_suffix: dataset - .entity_suffix - .clone() - .unwrap_or_else(|| dataset.label.clone()), - source_prefix: dataset - .source_prefix - .clone() - .unwrap_or_else(|| dataset.id.clone()), - include_unanswerable: dataset.include_unanswerable, - context_token_limit: None, - }; - - let mut entry_slices = Vec::with_capacity(dataset.slices.len()); - - for (index, manifest_slice) in dataset.slices.into_iter().enumerate() { - if slices.contains_key(&manifest_slice.id) { - bail!( - "slice '{}' defined multiple times in manifest", - manifest_slice.id - ); - } - entry_slices.push(SliceEntry { - id: manifest_slice.id.clone(), - dataset_id: dataset.id.clone(), - label: manifest_slice.label, - description: manifest_slice.description, - limit: manifest_slice.limit, - corpus_limit: manifest_slice.corpus_limit, - include_unanswerable: 
manifest_slice.include_unanswerable, - seed: manifest_slice.seed, - }); - slices.insert( - manifest_slice.id, - SliceLocation { - dataset_id: dataset.id.clone(), - slice_index: index, - }, - ); - } - - datasets.insert( - metadata.id.clone(), - DatasetEntry { - metadata, - raw_path, - converted_path, - include_unanswerable: dataset.include_unanswerable, - slices: entry_slices, - }, - ); - } - - let default_dataset = manifest - .default_dataset - .or_else(|| datasets.keys().next().cloned()) - .ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?; - - Ok(Self { - datasets, - slices, - default_dataset, - }) - } - - pub fn global() -> Result<&'static Self> { - DATASET_CATALOG.get_or_try_init(Self::load) - } - - pub fn dataset(&self, id: &str) -> Result<&DatasetEntry> { - self.datasets - .get(id) - .ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest")) - } - - #[allow(dead_code)] - pub fn default_dataset(&self) -> Result<&DatasetEntry> { - self.dataset(&self.default_dataset) - } - - #[allow(dead_code)] - pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> { - let location = self - .slices - .get(slice_id) - .ok_or_else(|| anyhow!("unknown slice '{slice_id}' in manifest"))?; - let dataset = self - .datasets - .get(&location.dataset_id) - .ok_or_else(|| anyhow!("slice '{slice_id}' references missing dataset"))?; - let slice = dataset - .slices - .get(location.slice_index) - .ok_or_else(|| anyhow!("slice index out of bounds for '{slice_id}'"))?; - Ok((dataset, slice)) - } -} - -fn resolve_path(root: &Path, value: &str) -> PathBuf { - let path = Path::new(value); - if path.is_absolute() { - path.to_path_buf() - } else { - root.join(path) - } -} - -pub fn catalog() -> Result<&'static DatasetCatalog> { - DatasetCatalog::global() -} - -fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> { - let catalog = catalog()?; - catalog.dataset(kind.id()) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub 
enum DatasetKind { - SquadV2, - NaturalQuestions, -} - -impl DatasetKind { - pub fn id(self) -> &'static str { - match self { - Self::SquadV2 => "squad-v2", - Self::NaturalQuestions => "natural-questions-dev", - } - } - - pub fn label(self) -> &'static str { - match self { - Self::SquadV2 => "SQuAD v2.0", - Self::NaturalQuestions => "Natural Questions (dev)", - } - } - - pub fn category(self) -> &'static str { - match self { - Self::SquadV2 => "SQuAD v2.0", - Self::NaturalQuestions => "Natural Questions", - } - } - - pub fn entity_suffix(self) -> &'static str { - match self { - Self::SquadV2 => "SQuAD", - Self::NaturalQuestions => "Natural Questions", - } - } - - pub fn source_prefix(self) -> &'static str { - match self { - Self::SquadV2 => "squad", - Self::NaturalQuestions => "nq", - } - } - - pub fn default_raw_path(self) -> PathBuf { - dataset_entry_for_kind(self) - .map(|entry| entry.raw_path.clone()) - .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self)) - } - - pub fn default_converted_path(self) -> PathBuf { - dataset_entry_for_kind(self) - .map(|entry| entry.converted_path.clone()) - .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self)) - } -} - -impl Default for DatasetKind { - fn default() -> Self { - Self::SquadV2 - } -} - -impl FromStr for DatasetKind { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().as_str() { - "squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2), - "nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => { - Ok(Self::NaturalQuestions) - } - other => { - anyhow::bail!("unknown dataset '{other}'. 
Expected 'squad' or 'natural-questions'.") - } - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DatasetMetadata { - pub id: String, - pub label: String, - pub category: String, - pub entity_suffix: String, - pub source_prefix: String, - #[serde(default)] - pub include_unanswerable: bool, - #[serde(default)] - pub context_token_limit: Option, -} - -impl DatasetMetadata { - pub fn for_kind( - kind: DatasetKind, - include_unanswerable: bool, - context_token_limit: Option, - ) -> Self { - if let Ok(entry) = dataset_entry_for_kind(kind) { - return Self { - id: entry.metadata.id.clone(), - label: entry.metadata.label.clone(), - category: entry.metadata.category.clone(), - entity_suffix: entry.metadata.entity_suffix.clone(), - source_prefix: entry.metadata.source_prefix.clone(), - include_unanswerable, - context_token_limit, - }; - } - - Self { - id: kind.id().to_string(), - label: kind.label().to_string(), - category: kind.category().to_string(), - entity_suffix: kind.entity_suffix().to_string(), - source_prefix: kind.source_prefix().to_string(), - include_unanswerable, - context_token_limit, - } - } -} - -fn default_metadata() -> DatasetMetadata { - DatasetMetadata::for_kind(DatasetKind::default(), false, None) -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConvertedDataset { - pub generated_at: chrono::DateTime, - #[serde(default = "default_metadata")] - pub metadata: DatasetMetadata, - pub source: String, - pub paragraphs: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConvertedParagraph { - pub id: String, - pub title: String, - pub context: String, - pub questions: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConvertedQuestion { - pub id: String, - pub question: String, - pub answers: Vec, - pub is_impossible: bool, -} - -pub fn convert( - raw_path: &Path, - dataset: DatasetKind, - include_unanswerable: bool, - context_token_limit: Option, -) -> Result { - let paragraphs = 
match dataset { - DatasetKind::SquadV2 => convert_squad(raw_path)?, - DatasetKind::NaturalQuestions => { - convert_nq(raw_path, include_unanswerable, context_token_limit)? - } - }; - - let metadata_limit = match dataset { - DatasetKind::NaturalQuestions => None, - _ => context_token_limit, - }; - - Ok(ConvertedDataset { - generated_at: Utc::now(), - metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit), - source: raw_path.display().to_string(), - paragraphs, - }) -} - -fn convert_squad(raw_path: &Path) -> Result> { - #[derive(Debug, Deserialize)] - struct SquadDataset { - data: Vec, - } - - #[derive(Debug, Deserialize)] - struct SquadArticle { - title: String, - paragraphs: Vec, - } - - #[derive(Debug, Deserialize)] - struct SquadParagraph { - context: String, - qas: Vec, - } - - #[derive(Debug, Deserialize)] - struct SquadQuestion { - id: String, - question: String, - answers: Vec, - #[serde(default)] - is_impossible: bool, - } - - #[derive(Debug, Deserialize)] - struct SquadAnswer { - text: String, - } - - let raw = fs::read_to_string(raw_path) - .with_context(|| format!("reading raw SQuAD dataset at {}", raw_path.display()))?; - let parsed: SquadDataset = serde_json::from_str(&raw) - .with_context(|| format!("parsing SQuAD dataset at {}", raw_path.display()))?; - - let mut paragraphs = Vec::new(); - for (article_idx, article) in parsed.data.into_iter().enumerate() { - for (paragraph_idx, paragraph) in article.paragraphs.into_iter().enumerate() { - let mut questions = Vec::new(); - for qa in paragraph.qas { - let answers = dedupe_strings(qa.answers.into_iter().map(|answer| answer.text)); - questions.push(ConvertedQuestion { - id: qa.id, - question: qa.question.trim().to_string(), - answers, - is_impossible: qa.is_impossible, - }); - } - - let paragraph_id = - format!("{}-{}", slugify(&article.title, article_idx), paragraph_idx); - - paragraphs.push(ConvertedParagraph { - id: paragraph_id, - title: article.title.trim().to_string(), - 
context: paragraph.context.trim().to_string(), - questions, - }); - } - } - - Ok(paragraphs) -} - -#[allow(dead_code)] -pub const DEFAULT_CONTEXT_TOKEN_LIMIT: usize = 1_500; // retained for backwards compatibility (unused) - -fn convert_nq( - raw_path: &Path, - include_unanswerable: bool, - _context_token_limit: Option, -) -> Result> { - #[allow(dead_code)] - #[derive(Debug, Deserialize)] - struct NqExample { - question_text: String, - document_title: String, - example_id: i64, - document_tokens: Vec, - long_answer_candidates: Vec, - annotations: Vec, - } - - #[derive(Debug, Deserialize)] - struct NqToken { - token: String, - #[serde(default)] - html_token: bool, - } - - #[allow(dead_code)] - #[derive(Debug, Deserialize)] - struct NqLongAnswerCandidate { - start_token: i32, - end_token: i32, - } - - #[allow(dead_code)] - #[derive(Debug, Deserialize)] - struct NqAnnotation { - short_answers: Vec, - #[serde(default)] - yes_no_answer: String, - long_answer: NqLongAnswer, - } - - #[derive(Debug, Deserialize)] - struct NqShortAnswer { - start_token: i32, - end_token: i32, - } - - #[allow(dead_code)] - #[derive(Debug, Deserialize)] - struct NqLongAnswer { - candidate_index: i32, - } - - fn join_tokens(tokens: &[NqToken], start: usize, end: usize) -> String { - let mut buffer = String::new(); - let end = end.min(tokens.len()); - for token in tokens.iter().skip(start).take(end.saturating_sub(start)) { - if token.html_token { - continue; - } - let text = token.token.trim(); - if text.is_empty() { - continue; - } - let attach = matches!( - text, - "," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "…" | "..." 
- ) || text.starts_with('\'') - || text == "n't" - || text == "'s" - || text == "'re" - || text == "'ve" - || text == "'d" - || text == "'ll"; - - if buffer.is_empty() || attach { - buffer.push_str(text); - } else { - buffer.push(' '); - buffer.push_str(text); - } - } - - buffer.trim().to_string() - } - - let file = File::open(raw_path).with_context(|| { - format!( - "opening Natural Questions dataset at {}", - raw_path.display() - ) - })?; - let reader = BufReader::new(file); - - let mut paragraphs = Vec::new(); - for (line_idx, line) in reader.lines().enumerate() { - let line = line.with_context(|| { - format!( - "reading Natural Questions line {} from {}", - line_idx + 1, - raw_path.display() - ) - })?; - if line.trim().is_empty() { - continue; - } - let example: NqExample = serde_json::from_str(&line).with_context(|| { - format!( - "parsing Natural Questions JSON (line {}) at {}", - line_idx + 1, - raw_path.display() - ) - })?; - - let mut answer_texts: Vec = Vec::new(); - let mut short_answer_texts: Vec = Vec::new(); - let mut has_short_or_yesno = false; - let mut has_short_answer = false; - for annotation in &example.annotations { - for short in &annotation.short_answers { - if short.start_token < 0 || short.end_token <= short.start_token { - continue; - } - let start = short.start_token as usize; - let end = short.end_token as usize; - if start >= example.document_tokens.len() || end > example.document_tokens.len() { - continue; - } - let text = join_tokens(&example.document_tokens, start, end); - if !text.is_empty() { - answer_texts.push(text.clone()); - short_answer_texts.push(text); - has_short_or_yesno = true; - has_short_answer = true; - } - } - - match annotation - .yes_no_answer - .trim() - .to_ascii_lowercase() - .as_str() - { - "yes" => { - answer_texts.push("yes".to_string()); - has_short_or_yesno = true; - } - "no" => { - answer_texts.push("no".to_string()); - has_short_or_yesno = true; - } - _ => {} - } - } - - let mut answers = 
dedupe_strings(answer_texts); - let is_unanswerable = !has_short_or_yesno || answers.is_empty(); - if is_unanswerable { - if !include_unanswerable { - continue; - } - answers.clear(); - } - - let paragraph_id = format!("nq-{}", example.example_id); - let question_id = format!("nq-{}", example.example_id); - - let context = join_tokens(&example.document_tokens, 0, example.document_tokens.len()); - if context.is_empty() { - continue; - } - - if has_short_answer && !short_answer_texts.is_empty() { - let normalized_context = context.to_ascii_lowercase(); - let missing_answer = short_answer_texts.iter().any(|answer| { - let needle = answer.trim().to_ascii_lowercase(); - !needle.is_empty() && !normalized_context.contains(&needle) - }); - if missing_answer { - warn!( - question_id = %question_id, - "Skipping Natural Questions example because answers were not found in the assembled context" - ); - continue; - } - } - - if !include_unanswerable && (!has_short_answer || short_answer_texts.is_empty()) { - // yes/no-only questions are excluded by default unless --llm-mode is used - continue; - } - - let question = ConvertedQuestion { - id: question_id, - question: example.question_text.trim().to_string(), - answers, - is_impossible: is_unanswerable, - }; - - paragraphs.push(ConvertedParagraph { - id: paragraph_id, - title: example.document_title.trim().to_string(), - context, - questions: vec![question], - }); - } - - Ok(paragraphs) -} - -fn ensure_parent(path: &Path) -> Result<()> { - if let Some(parent) = path.parent() { - fs::create_dir_all(parent) - .with_context(|| format!("creating parent directory for {}", path.display()))?; - } - Ok(()) -} - -pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> { - ensure_parent(converted_path)?; - let json = - serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?; - fs::write(converted_path, json) - .with_context(|| format!("writing converted dataset to {}", 
converted_path.display())) -} - -pub fn read_converted(converted_path: &Path) -> Result { - let raw = fs::read_to_string(converted_path) - .with_context(|| format!("reading converted dataset at {}", converted_path.display()))?; - let mut dataset: ConvertedDataset = serde_json::from_str(&raw) - .with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?; - if dataset.metadata.id.trim().is_empty() { - dataset.metadata = default_metadata(); - } - if dataset.source.is_empty() { - dataset.source = converted_path.display().to_string(); - } - Ok(dataset) -} - -pub fn ensure_converted( - dataset_kind: DatasetKind, - raw_path: &Path, - converted_path: &Path, - force: bool, - include_unanswerable: bool, - context_token_limit: Option, -) -> Result { - if force || !converted_path.exists() { - let dataset = convert( - raw_path, - dataset_kind, - include_unanswerable, - context_token_limit, - )?; - write_converted(&dataset, converted_path)?; - return Ok(dataset); - } - - match read_converted(converted_path) { - Ok(dataset) - if dataset.metadata.id == dataset_kind.id() - && dataset.metadata.include_unanswerable == include_unanswerable - && dataset.metadata.context_token_limit == context_token_limit => - { - Ok(dataset) - } - _ => { - let dataset = convert( - raw_path, - dataset_kind, - include_unanswerable, - context_token_limit, - )?; - write_converted(&dataset, converted_path)?; - Ok(dataset) - } - } -} - -fn dedupe_strings(values: I) -> Vec -where - I: IntoIterator, -{ - let mut set = BTreeSet::new(); - for value in values { - let trimmed = value.trim(); - if !trimmed.is_empty() { - set.insert(trimmed.to_string()); - } - } - set.into_iter().collect() -} - -fn slugify(input: &str, fallback_idx: usize) -> String { - let mut slug = String::new(); - let mut last_dash = false; - for ch in input.chars() { - let c = ch.to_ascii_lowercase(); - if c.is_ascii_alphanumeric() { - slug.push(c); - last_dash = false; - } else if !last_dash { - slug.push('-'); - 
last_dash = true; - } - } - - slug = slug.trim_matches('-').to_string(); - if slug.is_empty() { - slug = format!("article-{fallback_idx}"); - } - slug -} - -pub fn base_timestamp() -> chrono::DateTime { - Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0) - .single() - .expect("valid base timestamp") -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - use std::io::Write; - use tempfile::NamedTempFile; - - #[test] - fn convert_nq_handles_answers_and_skips_unanswerable() { - let mut file = NamedTempFile::new().expect("temp file"); - - let record_with_short_answers = json!({ - "question_text": "What is foo?", - "document_title": "Foo Title", - "example_id": 123, - "document_tokens": [ - {"token": "Foo", "html_token": false}, - {"token": "is", "html_token": false}, - {"token": "bar", "html_token": false}, - {"token": ".", "html_token": false} - ], - "long_answer_candidates": [ - {"start_token": 0, "end_token": 4, "top_level": true} - ], - "annotations": [ - { - "long_answer": {"start_token": 0, "end_token": 4, "candidate_index": 0}, - "short_answers": [ - {"start_token": 2, "end_token": 3}, - {"start_token": 2, "end_token": 3} - ], - "yes_no_answer": "NONE" - } - ] - }); - - let record_with_yes_no = json!({ - "question_text": "Is bar real?", - "document_title": "Bar Title", - "example_id": 456, - "document_tokens": [ - {"token": "Yes", "html_token": false}, - {"token": ",", "html_token": false}, - {"token": "bar", "html_token": false}, - {"token": "is", "html_token": false} - ], - "long_answer_candidates": [ - {"start_token": 0, "end_token": 4, "top_level": true} - ], - "annotations": [ - { - "long_answer": {"start_token": 0, "end_token": 4, "candidate_index": 0}, - "short_answers": [], - "yes_no_answer": "YES" - } - ] - }); - - let unanswerable_record = json!({ - "question_text": "Unknown?", - "document_title": "Unknown Title", - "example_id": 789, - "document_tokens": [ - {"token": "No", "html_token": false}, - {"token": "answer", "html_token": false} - 
], - "long_answer_candidates": [ - {"start_token": 0, "end_token": 2, "top_level": true} - ], - "annotations": [ - { - "long_answer": {"start_token": 0, "end_token": 2, "candidate_index": 0}, - "short_answers": [], - "yes_no_answer": "NONE" - } - ] - }); - - writeln!(file, "{}", record_with_short_answers).unwrap(); - writeln!(file, "{}", record_with_yes_no).unwrap(); - writeln!(file, "{}", unanswerable_record).unwrap(); - file.flush().unwrap(); - - let dataset = convert( - file.path(), - DatasetKind::NaturalQuestions, - false, - Some(DEFAULT_CONTEXT_TOKEN_LIMIT), - ) - .expect("convert natural questions"); - - assert_eq!(dataset.metadata.id, DatasetKind::NaturalQuestions.id()); - assert!(!dataset.metadata.include_unanswerable); - assert_eq!(dataset.paragraphs.len(), 2); - - let first = &dataset.paragraphs[0]; - assert_eq!(first.id, "nq-123"); - assert!(first.context.contains("Foo")); - let first_answers = &first.questions.first().expect("question present").answers; - assert_eq!(first_answers, &vec!["bar".to_string()]); - - let second = &dataset.paragraphs[1]; - assert_eq!(second.id, "nq-456"); - let second_answers = &second.questions.first().expect("question present").answers; - assert_eq!(second_answers, &vec!["yes".to_string()]); - - assert!(dataset - .paragraphs - .iter() - .all(|paragraph| paragraph.id != "nq-789")); - } - - #[test] - fn convert_nq_includes_unanswerable_when_flagged() { - let mut file = NamedTempFile::new().expect("temp file"); - - let answerable = json!({ - "question_text": "What is foo?", - "document_title": "Foo Title", - "example_id": 123, - "document_tokens": [ - {"token": "Foo", "html_token": false}, - {"token": "is", "html_token": false}, - {"token": "bar", "html_token": false} - ], - "long_answer_candidates": [ - {"start_token": 0, "end_token": 3, "top_level": true} - ], - "annotations": [ - { - "long_answer": {"start_token": 0, "end_token": 3, "candidate_index": 0}, - "short_answers": [ - {"start_token": 2, "end_token": 3} - ], - 
"yes_no_answer": "NONE" - } - ] - }); - - let unanswerable = json!({ - "question_text": "Unknown?", - "document_title": "Unknown Title", - "example_id": 456, - "document_tokens": [ - {"token": "No", "html_token": false}, - {"token": "answer", "html_token": false} - ], - "long_answer_candidates": [ - {"start_token": 0, "end_token": 2, "top_level": true} - ], - "annotations": [ - { - "long_answer": {"start_token": 0, "end_token": 2, "candidate_index": -1}, - "short_answers": [], - "yes_no_answer": "NONE" - } - ] - }); - - writeln!(file, "{}", answerable).unwrap(); - writeln!(file, "{}", unanswerable).unwrap(); - file.flush().unwrap(); - - let dataset = convert( - file.path(), - DatasetKind::NaturalQuestions, - true, - Some(DEFAULT_CONTEXT_TOKEN_LIMIT), - ) - .expect("convert natural questions with unanswerable"); - - assert!(dataset.metadata.include_unanswerable); - assert_eq!(dataset.paragraphs.len(), 2); - let impossible = dataset - .paragraphs - .iter() - .find(|p| p.id == "nq-456") - .expect("unanswerable paragraph present"); - let question = impossible.questions.first().expect("question present"); - assert!(question.answers.is_empty()); - assert!(question.is_impossible); - } - - #[test] - fn catalog_lists_datasets_and_slices() { - let catalog = catalog().expect("catalog"); - let squad = catalog.dataset("squad-v2").expect("squad dataset"); - assert!(squad.raw_path.exists()); - assert!(squad.converted_path.exists()); - assert!(!squad.slices.is_empty()); - - let (dataset, slice) = catalog.slice("squad-dev-200").expect("slice"); - assert_eq!(dataset.metadata.id, squad.metadata.id); - assert_eq!(slice.dataset_id, squad.metadata.id); - assert!(slice.limit.is_some()); - } -} diff --git a/eval/src/datasets/mod.rs b/eval/src/datasets/mod.rs new file mode 100644 index 0000000..6d106cb --- /dev/null +++ b/eval/src/datasets/mod.rs @@ -0,0 +1,493 @@ +mod nq; +mod squad; + +use std::{ + collections::{BTreeMap, HashMap}, + fs::{self}, + path::{Path, PathBuf}, + str::FromStr, 
+}; + +use anyhow::{anyhow, bail, Context, Result}; +use chrono::{DateTime, TimeZone, Utc}; +use once_cell::sync::OnceCell; +use serde::{Deserialize, Serialize}; +use tracing::warn; + +const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"); +static DATASET_CATALOG: OnceCell = OnceCell::new(); + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct DatasetCatalog { + datasets: BTreeMap, + slices: HashMap, + default_dataset: String, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct DatasetEntry { + pub metadata: DatasetMetadata, + pub raw_path: PathBuf, + pub converted_path: PathBuf, + pub include_unanswerable: bool, + pub slices: Vec, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct SliceEntry { + pub id: String, + pub dataset_id: String, + pub label: String, + pub description: Option, + pub limit: Option, + pub corpus_limit: Option, + pub include_unanswerable: Option, + pub seed: Option, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct SliceLocation { + dataset_id: String, + slice_index: usize, +} + +#[derive(Debug, Deserialize)] +struct ManifestFile { + default_dataset: Option, + datasets: Vec, +} + +#[derive(Debug, Deserialize)] +struct ManifestDataset { + id: String, + label: String, + category: String, + #[serde(default)] + entity_suffix: Option, + #[serde(default)] + source_prefix: Option, + raw: String, + converted: String, + #[serde(default)] + include_unanswerable: bool, + #[serde(default)] + slices: Vec, +} + +#[derive(Debug, Deserialize)] +struct ManifestSlice { + id: String, + label: String, + #[serde(default)] + description: Option, + #[serde(default)] + limit: Option, + #[serde(default)] + corpus_limit: Option, + #[serde(default)] + include_unanswerable: Option, + #[serde(default)] + seed: Option, +} + +impl DatasetCatalog { + pub fn load() -> Result { + let manifest_raw = fs::read_to_string(MANIFEST_PATH) + .with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?; + 
let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw) + .with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?; + + let root = Path::new(env!("CARGO_MANIFEST_DIR")); + let mut datasets = BTreeMap::new(); + let mut slices = HashMap::new(); + + for dataset in manifest.datasets { + let raw_path = resolve_path(root, &dataset.raw); + let converted_path = resolve_path(root, &dataset.converted); + + if !raw_path.exists() { + bail!( + "dataset '{}' raw file missing at {}", + dataset.id, + raw_path.display() + ); + } + if !converted_path.exists() { + warn!( + "dataset '{}' converted file missing at {}; the next conversion run will regenerate it", + dataset.id, + converted_path.display() + ); + } + + let metadata = DatasetMetadata { + id: dataset.id.clone(), + label: dataset.label.clone(), + category: dataset.category.clone(), + entity_suffix: dataset + .entity_suffix + .clone() + .unwrap_or_else(|| dataset.label.clone()), + source_prefix: dataset + .source_prefix + .clone() + .unwrap_or_else(|| dataset.id.clone()), + include_unanswerable: dataset.include_unanswerable, + context_token_limit: None, + }; + + let mut entry_slices = Vec::with_capacity(dataset.slices.len()); + + for (index, manifest_slice) in dataset.slices.into_iter().enumerate() { + if slices.contains_key(&manifest_slice.id) { + bail!( + "slice '{}' defined multiple times in manifest", + manifest_slice.id + ); + } + entry_slices.push(SliceEntry { + id: manifest_slice.id.clone(), + dataset_id: dataset.id.clone(), + label: manifest_slice.label, + description: manifest_slice.description, + limit: manifest_slice.limit, + corpus_limit: manifest_slice.corpus_limit, + include_unanswerable: manifest_slice.include_unanswerable, + seed: manifest_slice.seed, + }); + slices.insert( + manifest_slice.id, + SliceLocation { + dataset_id: dataset.id.clone(), + slice_index: index, + }, + ); + } + + datasets.insert( + metadata.id.clone(), + DatasetEntry { + metadata, + raw_path, + converted_path, + 
/// Resolve a manifest path entry against the crate root.
///
/// Absolute entries are used verbatim; relative entries are anchored at
/// `root` (the crate's `CARGO_MANIFEST_DIR`).
fn resolve_path(root: &Path, value: &str) -> PathBuf {
    let candidate = Path::new(value);
    match candidate.is_absolute() {
        true => candidate.to_path_buf(),
        false => root.join(candidate),
    }
}
v2.0", + Self::NaturalQuestions => "Natural Questions (dev)", + } + } + + pub fn category(self) -> &'static str { + match self { + Self::SquadV2 => "SQuAD v2.0", + Self::NaturalQuestions => "Natural Questions", + } + } + + pub fn entity_suffix(self) -> &'static str { + match self { + Self::SquadV2 => "SQuAD", + Self::NaturalQuestions => "Natural Questions", + } + } + + pub fn source_prefix(self) -> &'static str { + match self { + Self::SquadV2 => "squad", + Self::NaturalQuestions => "nq", + } + } + + pub fn default_raw_path(self) -> PathBuf { + dataset_entry_for_kind(self) + .map(|entry| entry.raw_path.clone()) + .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self)) + } + + pub fn default_converted_path(self) -> PathBuf { + dataset_entry_for_kind(self) + .map(|entry| entry.converted_path.clone()) + .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self)) + } +} + +impl Default for DatasetKind { + fn default() -> Self { + Self::SquadV2 + } +} + +impl FromStr for DatasetKind { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2), + "nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => { + Ok(Self::NaturalQuestions) + } + other => { + anyhow::bail!("unknown dataset '{other}'. 
Expected 'squad' or 'natural-questions'.") + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatasetMetadata { + pub id: String, + pub label: String, + pub category: String, + pub entity_suffix: String, + pub source_prefix: String, + #[serde(default)] + pub include_unanswerable: bool, + #[serde(default)] + pub context_token_limit: Option, +} + +impl DatasetMetadata { + pub fn for_kind( + kind: DatasetKind, + include_unanswerable: bool, + context_token_limit: Option, + ) -> Self { + if let Ok(entry) = dataset_entry_for_kind(kind) { + return Self { + id: entry.metadata.id.clone(), + label: entry.metadata.label.clone(), + category: entry.metadata.category.clone(), + entity_suffix: entry.metadata.entity_suffix.clone(), + source_prefix: entry.metadata.source_prefix.clone(), + include_unanswerable, + context_token_limit, + }; + } + + Self { + id: kind.id().to_string(), + label: kind.label().to_string(), + category: kind.category().to_string(), + entity_suffix: kind.entity_suffix().to_string(), + source_prefix: kind.source_prefix().to_string(), + include_unanswerable, + context_token_limit, + } + } +} + +fn default_metadata() -> DatasetMetadata { + DatasetMetadata::for_kind(DatasetKind::default(), false, None) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvertedDataset { + pub generated_at: DateTime, + #[serde(default = "default_metadata")] + pub metadata: DatasetMetadata, + pub source: String, + pub paragraphs: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvertedParagraph { + pub id: String, + pub title: String, + pub context: String, + pub questions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvertedQuestion { + pub id: String, + pub question: String, + pub answers: Vec, + pub is_impossible: bool, +} + +pub fn convert( + raw_path: &Path, + dataset: DatasetKind, + include_unanswerable: bool, + context_token_limit: Option, +) -> Result { + let paragraphs = match 
dataset { + DatasetKind::SquadV2 => squad::convert_squad(raw_path)?, + DatasetKind::NaturalQuestions => { + nq::convert_nq(raw_path, include_unanswerable, context_token_limit)? + } + }; + + let metadata_limit = match dataset { + DatasetKind::NaturalQuestions => None, + _ => context_token_limit, + }; + + Ok(ConvertedDataset { + generated_at: Utc::now(), + metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit), + source: raw_path.display().to_string(), + paragraphs, + }) +} + +fn ensure_parent(path: &Path) -> Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("creating parent directory for {}", path.display()))?; + } + Ok(()) +} + +pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> { + ensure_parent(converted_path)?; + let json = + serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?; + fs::write(converted_path, json) + .with_context(|| format!("writing converted dataset to {}", converted_path.display())) +} + +pub fn read_converted(converted_path: &Path) -> Result { + let raw = fs::read_to_string(converted_path) + .with_context(|| format!("reading converted dataset at {}", converted_path.display()))?; + let mut dataset: ConvertedDataset = serde_json::from_str(&raw) + .with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?; + if dataset.metadata.id.trim().is_empty() { + dataset.metadata = default_metadata(); + } + if dataset.source.is_empty() { + dataset.source = converted_path.display().to_string(); + } + Ok(dataset) +} + +pub fn ensure_converted( + dataset_kind: DatasetKind, + raw_path: &Path, + converted_path: &Path, + force: bool, + include_unanswerable: bool, + context_token_limit: Option, +) -> Result { + if force || !converted_path.exists() { + let dataset = convert( + raw_path, + dataset_kind, + include_unanswerable, + context_token_limit, + )?; + write_converted(&dataset, 
converted_path)?; + return Ok(dataset); + } + + match read_converted(converted_path) { + Ok(dataset) + if dataset.metadata.id == dataset_kind.id() + && dataset.metadata.include_unanswerable == include_unanswerable + && dataset.metadata.context_token_limit == context_token_limit => + { + Ok(dataset) + } + _ => { + let dataset = convert( + raw_path, + dataset_kind, + include_unanswerable, + context_token_limit, + )?; + write_converted(&dataset, converted_path)?; + Ok(dataset) + } + } +} + +pub fn base_timestamp() -> DateTime { + Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap() +} diff --git a/eval/src/datasets/nq.rs b/eval/src/datasets/nq.rs new file mode 100644 index 0000000..dde9f2b --- /dev/null +++ b/eval/src/datasets/nq.rs @@ -0,0 +1,234 @@ +use std::{ + collections::BTreeSet, + fs::File, + io::{BufRead, BufReader}, + path::Path, +}; + +use anyhow::{Context, Result}; +use serde::Deserialize; +use tracing::warn; + +use super::{ConvertedParagraph, ConvertedQuestion}; + +pub fn convert_nq( + raw_path: &Path, + include_unanswerable: bool, + _context_token_limit: Option, +) -> Result> { + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqExample { + question_text: String, + document_title: String, + example_id: i64, + document_tokens: Vec, + long_answer_candidates: Vec, + annotations: Vec, + } + + #[derive(Debug, Deserialize)] + struct NqToken { + token: String, + #[serde(default)] + html_token: bool, + } + + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqLongAnswerCandidate { + start_token: i32, + end_token: i32, + } + + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqAnnotation { + short_answers: Vec, + #[serde(default)] + yes_no_answer: String, + long_answer: NqLongAnswer, + } + + #[derive(Debug, Deserialize)] + struct NqShortAnswer { + start_token: i32, + end_token: i32, + } + + #[allow(dead_code)] + #[derive(Debug, Deserialize)] + struct NqLongAnswer { + candidate_index: i32, + } + + fn join_tokens(tokens: 
&[NqToken], start: usize, end: usize) -> String { + let mut buffer = String::new(); + let end = end.min(tokens.len()); + for token in tokens.iter().skip(start).take(end.saturating_sub(start)) { + if token.html_token { + continue; + } + let text = token.token.trim(); + if text.is_empty() { + continue; + } + let attach = matches!( + text, + "," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "…" | "..." + ) || text.starts_with('\'') + || text == "n't" + || text == "'s" + || text == "'re" + || text == "'ve" + || text == "'d" + || text == "'ll"; + + if buffer.is_empty() || attach { + buffer.push_str(text); + } else { + buffer.push(' '); + buffer.push_str(text); + } + } + + buffer.trim().to_string() + } + + let file = File::open(raw_path).with_context(|| { + format!( + "opening Natural Questions dataset at {}", + raw_path.display() + ) + })?; + let reader = BufReader::new(file); + + let mut paragraphs = Vec::new(); + for (line_idx, line) in reader.lines().enumerate() { + let line = line.with_context(|| { + format!( + "reading Natural Questions line {} from {}", + line_idx + 1, + raw_path.display() + ) + })?; + if line.trim().is_empty() { + continue; + } + let example: NqExample = serde_json::from_str(&line).with_context(|| { + format!( + "parsing Natural Questions JSON (line {}) at {}", + line_idx + 1, + raw_path.display() + ) + })?; + + let mut answer_texts: Vec = Vec::new(); + let mut short_answer_texts: Vec = Vec::new(); + let mut has_short_or_yesno = false; + let mut has_short_answer = false; + for annotation in &example.annotations { + for short in &annotation.short_answers { + if short.start_token < 0 || short.end_token <= short.start_token { + continue; + } + let start = short.start_token as usize; + let end = short.end_token as usize; + if start >= example.document_tokens.len() || end > example.document_tokens.len() { + continue; + } + let text = join_tokens(&example.document_tokens, start, end); + if !text.is_empty() { + answer_texts.push(text.clone()); 
/// Trim, drop empties, and return the remaining strings sorted and unique.
fn dedupe_strings<I>(values: I) -> Vec<String>
where
    I: IntoIterator<Item = String>,
{
    // BTreeSet gives both deduplication and a stable (sorted) order.
    let unique: BTreeSet<String> = values
        .into_iter()
        .filter_map(|value| {
            let trimmed = value.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        })
        .collect();
    unique.into_iter().collect()
}
/// Build a URL-safe slug from an article title.
///
/// Lowercased ASCII alphanumerics pass through; every run of other
/// characters collapses into a single `-`. Leading/trailing dashes are
/// trimmed, and a title that produces nothing falls back to
/// `article-{fallback_idx}`.
fn slugify(input: &str, fallback_idx: usize) -> String {
    let mut out = String::with_capacity(input.len());
    for ch in input.chars() {
        let lowered = ch.to_ascii_lowercase();
        if lowered.is_ascii_alphanumeric() {
            out.push(lowered);
        } else if !out.ends_with('-') {
            out.push('-');
        }
    }

    let trimmed = out.trim_matches('-');
    if trimmed.is_empty() {
        format!("article-{fallback_idx}")
    } else {
        trimmed.to_string()
    }
}
correct_at_1: usize, - pub correct_at_2: usize, - pub correct_at_3: usize, - pub precision_at_1: f64, - pub precision_at_2: f64, - pub precision_at_3: f64, - pub duration_ms: u128, - pub dataset_id: String, - pub dataset_label: String, - pub dataset_includes_unanswerable: bool, - pub dataset_source: String, - pub slice_id: String, - pub slice_seed: u64, - pub slice_total_cases: usize, - pub slice_window_offset: usize, - pub slice_window_length: usize, - pub slice_cases: usize, - pub slice_positive_paragraphs: usize, - pub slice_negative_paragraphs: usize, - pub slice_total_paragraphs: usize, - pub slice_negative_multiplier: f32, - pub namespace_reused: bool, - pub corpus_paragraphs: usize, - pub ingestion_cache_path: String, - pub ingestion_reused: bool, - pub ingestion_embeddings_reused: bool, - pub ingestion_fingerprint: String, - pub positive_paragraphs_reused: usize, - pub negative_paragraphs_reused: usize, - pub latency_ms: LatencyStats, - pub perf: PerformanceTimings, - pub embedding_backend: String, - pub embedding_model: Option, - pub embedding_dimension: usize, - pub rerank_enabled: bool, - pub rerank_pool_size: Option, - pub rerank_keep_top: usize, - pub concurrency: usize, - pub detailed_report: bool, - pub chunk_vector_take: usize, - pub chunk_fts_take: usize, - pub chunk_token_budget: usize, - pub chunk_avg_chars_per_token: usize, - pub max_chunks_per_entity: usize, - pub cases: Vec, -} - -#[derive(Debug, Serialize)] -pub struct CaseSummary { - pub question_id: String, - pub question: String, - pub paragraph_id: String, - pub paragraph_title: String, - pub expected_source: String, - pub answers: Vec, - pub matched: bool, - pub entity_match: bool, - pub chunk_text_match: bool, - pub chunk_id_match: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub match_rank: Option, - pub latency_ms: u128, - pub retrieved: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LatencyStats { - pub avg: f64, - pub p50: u128, - pub p95: 
u128, -} - -#[derive(Debug, Clone, Serialize)] -pub struct StageLatencyBreakdown { - pub collect_candidates: LatencyStats, - pub graph_expansion: LatencyStats, - pub chunk_attach: LatencyStats, - pub rerank: LatencyStats, - pub assemble: LatencyStats, -} - -#[derive(Debug, Default, Clone, Serialize)] -pub struct EvaluationStageTimings { - pub prepare_slice_ms: u128, - pub prepare_db_ms: u128, - pub prepare_corpus_ms: u128, - pub prepare_namespace_ms: u128, - pub run_queries_ms: u128, - pub summarize_ms: u128, - pub finalize_ms: u128, -} - -#[derive(Debug, Serialize)] -pub struct PerformanceTimings { - pub openai_base_url: String, - pub ingestion_ms: u128, - #[serde(skip_serializing_if = "Option::is_none")] - pub namespace_seed_ms: Option, - pub evaluation_stage_ms: EvaluationStageTimings, - pub stage_latency: StageLatencyBreakdown, -} - -#[derive(Debug, Serialize)] -pub struct RetrievedSummary { - pub rank: usize, - pub entity_id: String, - pub source_id: String, - pub entity_name: String, - pub score: f32, - pub matched: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub entity_description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub entity_category: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub chunk_text_match: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub chunk_id_match: Option, -} - -#[derive(Debug, Serialize)] -pub(crate) struct CaseDiagnostics { - question_id: String, - question: String, - paragraph_id: String, - paragraph_title: String, - expected_source: String, - expected_chunk_ids: Vec, - answers: Vec, - entity_match: bool, - chunk_text_match: bool, - chunk_id_match: bool, - failure_reasons: Vec, - missing_expected_chunk_ids: Vec, - attached_chunk_ids: Vec, - retrieved: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pipeline: Option, -} - -#[derive(Debug, Serialize)] -struct EntityDiagnostics { - rank: usize, - entity_id: String, - source_id: String, - name: String, - 
score: f32, - entity_match: bool, - chunk_text_match: bool, - chunk_id_match: bool, - chunks: Vec, -} - -#[derive(Debug, Serialize)] -struct ChunkDiagnosticsEntry { - chunk_id: String, - score: f32, - contains_answer: bool, - expected_chunk: bool, - snippet: String, -} - pub(crate) struct SeededCase { question_id: String, question: String, @@ -213,6 +37,8 @@ pub(crate) struct SeededCase { paragraph_id: String, paragraph_title: String, expected_chunk_ids: Vec, + is_impossible: bool, + has_verified_chunks: bool, } pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec { @@ -221,10 +47,15 @@ pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec Vec bool { - if answers.is_empty() { - return true; +fn should_include_question( + question: &ingest::CorpusQuestion, + include_impossible: bool, + require_verified_chunks: bool, +) -> bool { + if !include_impossible && question.is_impossible { + return false; } - let haystack = text.to_ascii_lowercase(); - answers.iter().any(|needle| haystack.contains(needle)) -} - -pub(crate) fn compute_latency_stats(latencies: &[u128]) -> LatencyStats { - if latencies.is_empty() { - return LatencyStats { - avg: 0.0, - p50: 0, - p95: 0, - }; + if require_verified_chunks && question.matching_chunk_ids.is_empty() { + return false; } - let mut sorted = latencies.to_vec(); - sorted.sort_unstable(); - let sum: u128 = sorted.iter().copied().sum(); - let avg = sum as f64 / (sorted.len() as f64); - let p50 = percentile(&sorted, 0.50); - let p95 = percentile(&sorted, 0.95); - LatencyStats { avg, p50, p95 } -} - -pub(crate) fn build_stage_latency_breakdown( - samples: &[PipelineStageTimings], -) -> StageLatencyBreakdown { - fn collect_stage(samples: &[PipelineStageTimings], selector: F) -> Vec - where - F: Fn(&PipelineStageTimings) -> u128, - { - samples.iter().map(selector).collect() - } - - StageLatencyBreakdown { - collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| { - 
entry.collect_candidates_ms - })), - graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| { - entry.graph_expansion_ms - })), - chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)), - rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)), - assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)), - } -} - -fn percentile(sorted: &[u128], fraction: f64) -> u128 { - if sorted.is_empty() { - return 0; - } - let clamped = fraction.clamp(0.0, 1.0); - let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize; - sorted[idx.min(sorted.len() - 1)] + true } pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> { @@ -345,16 +135,16 @@ pub(crate) fn apply_dataset_tuning_overrides( return; } - if config.chunk_vector_take.is_none() { + if config.retrieval.chunk_vector_take.is_none() { tuning.chunk_vector_take = tuning.chunk_vector_take.max(80); } - if config.chunk_fts_take.is_none() { + if config.retrieval.chunk_fts_take.is_none() { tuning.chunk_fts_take = tuning.chunk_fts_take.max(80); } - if config.chunk_token_budget.is_none() { + if config.retrieval.chunk_token_budget.is_none() { tuning.token_budget_estimate = tuning.token_budget_estimate.max(20_000); } - if config.max_chunks_per_entity.is_none() { + if config.retrieval.max_chunks_per_entity.is_none() { tuning.max_chunks_per_entity = tuning.max_chunks_per_entity.max(12); } if tuning.lexical_match_weight < 0.25 { @@ -362,92 +152,6 @@ pub(crate) fn apply_dataset_tuning_overrides( } } -pub(crate) fn build_case_diagnostics( - summary: &CaseSummary, - expected_chunk_ids: &[String], - answers_lower: &[String], - entities: &[composite_retrieval::RetrievedEntity], - pipeline_stats: Option, -) -> CaseDiagnostics { - let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect(); - let mut seen_chunks: HashSet = HashSet::new(); - let mut attached_chunk_ids = 
Vec::new(); - let mut entity_diagnostics = Vec::new(); - - for (idx, entity) in entities.iter().enumerate() { - let mut chunk_entries = Vec::new(); - for chunk in &entity.chunks { - let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower); - let expected_chunk = expected_set.contains(chunk.chunk.id.as_str()); - seen_chunks.insert(chunk.chunk.id.clone()); - attached_chunk_ids.push(chunk.chunk.id.clone()); - chunk_entries.push(ChunkDiagnosticsEntry { - chunk_id: chunk.chunk.id.clone(), - score: chunk.score, - contains_answer, - expected_chunk, - snippet: chunk_preview(&chunk.chunk.chunk), - }); - } - entity_diagnostics.push(EntityDiagnostics { - rank: idx + 1, - entity_id: entity.entity.id.clone(), - source_id: entity.entity.source_id.clone(), - name: entity.entity.name.clone(), - score: entity.score, - entity_match: entity.entity.source_id == summary.expected_source, - chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer), - chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk), - chunks: chunk_entries, - }); - } - - let missing_expected_chunk_ids = expected_chunk_ids - .iter() - .filter(|id| !seen_chunks.contains(id.as_str())) - .cloned() - .collect::>(); - - let mut failure_reasons = Vec::new(); - if !summary.entity_match { - failure_reasons.push("entity_miss".to_string()); - } - if !summary.chunk_text_match { - failure_reasons.push("chunk_text_missing".to_string()); - } - if !summary.chunk_id_match { - failure_reasons.push("chunk_id_missing".to_string()); - } - if !missing_expected_chunk_ids.is_empty() { - failure_reasons.push("expected_chunk_absent".to_string()); - } - - CaseDiagnostics { - question_id: summary.question_id.clone(), - question: summary.question.clone(), - paragraph_id: summary.paragraph_id.clone(), - paragraph_title: summary.paragraph_title.clone(), - expected_source: summary.expected_source.clone(), - expected_chunk_ids: expected_chunk_ids.to_vec(), - answers: summary.answers.clone(), - 
entity_match: summary.entity_match, - chunk_text_match: summary.chunk_text_match, - chunk_id_match: summary.chunk_id_match, - failure_reasons, - missing_expected_chunk_ids, - attached_chunk_ids, - retrieved: entity_diagnostics, - pipeline: pipeline_stats, - } -} - -fn chunk_preview(text: &str) -> String { - text.chars() - .take(200) - .collect::() - .replace('\n', " ") -} - pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> { args::ensure_parent(path)?; let mut file = tokio::fs::File::create(path) @@ -765,3 +469,118 @@ pub(crate) async fn load_or_init_system_settings( Err(err) => Err(err).context("loading system settings"), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusParagraph, CorpusQuestion}; + use chrono::Utc; + use common::storage::types::text_content::TextContent; + + fn sample_manifest() -> CorpusManifest { + let paragraphs = vec![ + CorpusParagraph { + paragraph_id: "p1".to_string(), + title: "Alpha".to_string(), + text_content: TextContent::new( + "alpha context".to_string(), + None, + "test".to_string(), + None, + None, + "user".to_string(), + ), + entities: Vec::new(), + relationships: Vec::new(), + chunks: Vec::new(), + }, + CorpusParagraph { + paragraph_id: "p2".to_string(), + title: "Beta".to_string(), + text_content: TextContent::new( + "beta context".to_string(), + None, + "test".to_string(), + None, + None, + "user".to_string(), + ), + entities: Vec::new(), + relationships: Vec::new(), + chunks: Vec::new(), + }, + ]; + let questions = vec![ + CorpusQuestion { + question_id: "q1".to_string(), + paragraph_id: "p1".to_string(), + text_content_id: "tc-alpha".to_string(), + question_text: "What is Alpha?".to_string(), + answers: vec!["Alpha".to_string()], + is_impossible: false, + matching_chunk_ids: vec!["chunk-alpha".to_string()], + }, + CorpusQuestion { + question_id: "q2".to_string(), + paragraph_id: "p1".to_string(), + text_content_id: 
"tc-alpha".to_string(), + question_text: "Unanswerable?".to_string(), + answers: Vec::new(), + is_impossible: true, + matching_chunk_ids: Vec::new(), + }, + CorpusQuestion { + question_id: "q3".to_string(), + paragraph_id: "p2".to_string(), + text_content_id: "tc-beta".to_string(), + question_text: "Where is Beta?".to_string(), + answers: vec!["Beta".to_string()], + is_impossible: false, + matching_chunk_ids: Vec::new(), + }, + ]; + CorpusManifest { + version: 1, + metadata: CorpusMetadata { + dataset_id: "ds".to_string(), + dataset_label: "Dataset".to_string(), + slice_id: "slice".to_string(), + include_unanswerable: true, + require_verified_chunks: true, + ingestion_fingerprint: "fp".to_string(), + embedding_backend: "test".to_string(), + embedding_model: None, + embedding_dimension: 3, + converted_checksum: "chk".to_string(), + generated_at: Utc::now(), + paragraph_count: paragraphs.len(), + question_count: questions.len(), + }, + paragraphs, + questions, + } + } + + #[test] + fn cases_respect_mode_filters() { + let mut manifest = sample_manifest(); + manifest.metadata.include_unanswerable = false; + manifest.metadata.require_verified_chunks = true; + + let strict_cases = cases_from_manifest(&manifest); + assert_eq!(strict_cases.len(), 1); + assert_eq!(strict_cases[0].question_id, "q1"); + assert_eq!(strict_cases[0].paragraph_title, "Alpha"); + + let mut llm_manifest = manifest.clone(); + llm_manifest.metadata.include_unanswerable = true; + llm_manifest.metadata.require_verified_chunks = false; + + let llm_cases = cases_from_manifest(&llm_manifest); + let ids: Vec<_> = llm_cases + .iter() + .map(|case| case.question_id.as_str()) + .collect(); + assert_eq!(ids, vec!["q1", "q2", "q3"]); + } +} diff --git a/eval/src/eval/pipeline/context.rs b/eval/src/eval/pipeline/context.rs index 877fde5..4a8834d 100644 --- a/eval/src/eval/pipeline/context.rs +++ b/eval/src/eval/pipeline/context.rs @@ -9,7 +9,7 @@ use common::storage::{ db::SurrealDbClient, 
types::{system_settings::SystemSettings, user::User}, }; -use composite_retrieval::{ +use retrieval_pipeline::{ pipeline::{PipelineStageTimings, RetrievalConfig}, reranking::RerankerPool, }; @@ -52,6 +52,7 @@ pub(super) struct EvaluationContext<'a> { pub eval_user: Option, pub corpus_handle: Option, pub cases: Vec, + pub filtered_questions: usize, pub stage_latency_samples: Vec, pub latencies: Vec, pub diagnostics_output: Vec, @@ -94,6 +95,7 @@ impl<'a> EvaluationContext<'a> { eval_user: None, corpus_handle: None, cases: Vec::new(), + filtered_questions: 0, stage_latency_samples: Vec::new(), latencies: Vec::new(), diagnostics_output: Vec::new(), diff --git a/eval/src/eval/pipeline/stages/prepare_namespace.rs b/eval/src/eval/pipeline/stages/prepare_namespace.rs index 03fc282..25e336d 100644 --- a/eval/src/eval/pipeline/stages/prepare_namespace.rs +++ b/eval/src/eval/pipeline/stages/prepare_namespace.rs @@ -128,13 +128,28 @@ pub(crate) async fn prepare_namespace( let user = ensure_eval_user(ctx.db()).await?; ctx.eval_user = Some(user); - let cases = cases_from_manifest(&ctx.corpus_handle().manifest); + let corpus_handle = ctx.corpus_handle(); + let total_manifest_questions = corpus_handle.manifest.questions.len(); + let cases = cases_from_manifest(&corpus_handle.manifest); + let include_impossible = corpus_handle.manifest.metadata.include_unanswerable; + let require_verified_chunks = corpus_handle.manifest.metadata.require_verified_chunks; + let filtered = total_manifest_questions.saturating_sub(cases.len()); + if filtered > 0 { + info!( + filtered_questions = filtered, + total_questions = total_manifest_questions, + includes_impossible = include_impossible, + require_verified_chunks = require_verified_chunks, + "Filtered questions not eligible for this evaluation mode (impossible or unverifiable)" + ); + } if cases.is_empty() { return Err(anyhow!( - "no answerable questions found in converted dataset for evaluation" + "no eligible questions found in converted 
dataset for evaluation (consider --llm-mode or refreshing ingestion data)" )); } ctx.cases = cases; + ctx.filtered_questions = filtered; ctx.namespace_reused = namespace_reused; ctx.namespace_seed_ms = namespace_seed_ms; diff --git a/eval/src/eval/pipeline/stages/run_queries.rs b/eval/src/eval/pipeline/stages/run_queries.rs index 1b1cf94..d6913a3 100644 --- a/eval/src/eval/pipeline/stages/run_queries.rs +++ b/eval/src/eval/pipeline/stages/run_queries.rs @@ -1,15 +1,17 @@ use std::{collections::HashSet, sync::Arc, time::Instant}; use anyhow::Context; -use futures::stream::{self, StreamExt, TryStreamExt}; +use futures::stream::{self, StreamExt}; use tracing::{debug, info}; use crate::eval::{ - apply_dataset_tuning_overrides, build_case_diagnostics, text_contains_answer, CaseDiagnostics, - CaseSummary, RetrievedSummary, + adapt_strategy_output, apply_dataset_tuning_overrides, build_case_diagnostics, + text_contains_answer, CaseDiagnostics, CaseSummary, RetrievedSummary, +}; +use retrieval_pipeline::{ + pipeline::{self, PipelineStageTimings, RetrievalConfig}, + reranking::RerankerPool, }; -use composite_retrieval::pipeline::{self, PipelineStageTimings, RetrievalConfig}; -use composite_retrieval::reranking::RerankerPool; use tokio::sync::Semaphore; use super::super::{ @@ -38,30 +40,34 @@ pub(crate) async fn run_queries( let total_cases = ctx.cases.len(); let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate(); - let rerank_pool = if config.rerank { - Some(RerankerPool::new(config.rerank_pool_size).context("initialising reranker pool")?) 
+ let rerank_pool = if config.retrieval.rerank { + Some( + RerankerPool::new(config.retrieval.rerank_pool_size) + .context("initialising reranker pool")?, + ) } else { None }; let mut retrieval_config = RetrievalConfig::default(); - retrieval_config.tuning.rerank_keep_top = config.rerank_keep_top; - if retrieval_config.tuning.fallback_min_results < config.rerank_keep_top { - retrieval_config.tuning.fallback_min_results = config.rerank_keep_top; + retrieval_config.strategy = config.retrieval.strategy; + retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top; + if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top { + retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top; } - if let Some(value) = config.chunk_vector_take { + if let Some(value) = config.retrieval.chunk_vector_take { retrieval_config.tuning.chunk_vector_take = value; } - if let Some(value) = config.chunk_fts_take { + if let Some(value) = config.retrieval.chunk_fts_take { retrieval_config.tuning.chunk_fts_take = value; } - if let Some(value) = config.chunk_token_budget { + if let Some(value) = config.retrieval.chunk_token_budget { retrieval_config.tuning.token_budget_estimate = value; } - if let Some(value) = config.chunk_avg_chars_per_token { + if let Some(value) = config.retrieval.chunk_avg_chars_per_token { retrieval_config.tuning.avg_chars_per_token = value; } - if let Some(value) = config.max_chunks_per_entity { + if let Some(value) = config.retrieval.max_chunks_per_entity { retrieval_config.tuning.max_chunks_per_entity = value; } @@ -69,9 +75,11 @@ pub(crate) async fn run_queries( let active_tuning = retrieval_config.tuning.clone(); let effective_chunk_vector = config + .retrieval .chunk_vector_take .unwrap_or(active_tuning.chunk_vector_take); let effective_chunk_fts = config + .retrieval .chunk_fts_take .unwrap_or(active_tuning.chunk_fts_take); @@ -83,11 +91,11 @@ pub(crate) async fn run_queries( .limit 
.unwrap_or(ctx.window_total_cases), negative_multiplier = %slice_settings.negative_multiplier, - rerank_enabled = config.rerank, - rerank_pool_size = config.rerank_pool_size, - rerank_keep_top = config.rerank_keep_top, - chunk_min = config.chunk_min_chars, - chunk_max = config.chunk_max_chars, + rerank_enabled = config.retrieval.rerank, + rerank_pool_size = config.retrieval.rerank_pool_size, + rerank_keep_top = config.retrieval.rerank_keep_top, + chunk_min = config.retrieval.chunk_min_chars, + chunk_max = config.retrieval.chunk_max_chars, chunk_vector_take = effective_chunk_vector, chunk_fts_take = effective_chunk_fts, chunk_token_budget = active_tuning.token_budget_estimate, @@ -122,12 +130,7 @@ pub(crate) async fn run_queries( let db = ctx.db().clone(); let openai_client = ctx.openai_client(); - let results: Vec<( - usize, - CaseSummary, - Option, - PipelineStageTimings, - )> = stream::iter(cases_iter) + let raw_results = stream::iter(cases_iter) .map(move |(idx, case)| { let db = db.clone(); let openai_client = openai_client.clone(); @@ -152,6 +155,8 @@ pub(crate) async fn run_queries( paragraph_id, paragraph_title, expected_chunk_ids, + is_impossible, + has_verified_chunks, } = case; let query_start = Instant::now(); @@ -165,7 +170,7 @@ pub(crate) async fn run_queries( None => None, }; - let (results, pipeline_diagnostics, stage_timings) = if diagnostics_enabled { + let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled { let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics( &db, &openai_client, @@ -194,26 +199,27 @@ pub(crate) async fn run_queries( }; let query_latency = query_start.elapsed().as_millis() as u128; + let candidates = adapt_strategy_output(result_output); let mut retrieved = Vec::new(); let mut match_rank = None; let answers_lower: Vec = answers.iter().map(|ans| ans.to_ascii_lowercase()).collect(); let expected_chunk_ids_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect(); - 
let chunk_id_required = !expected_chunk_ids_set.is_empty(); + let chunk_id_required = has_verified_chunks; let mut entity_hit = false; let mut chunk_text_hit = false; let mut chunk_id_hit = !chunk_id_required; - for (idx_entity, entity) in results.iter().enumerate() { + for (idx_entity, candidate) in candidates.iter().enumerate() { if idx_entity >= config.k { break; } - let entity_match = entity.entity.source_id == expected_source; + let entity_match = candidate.source_id == expected_source; if entity_match { entity_hit = true; } - let chunk_text_for_entity = entity + let chunk_text_for_entity = candidate .chunks .iter() .any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower)); @@ -221,8 +227,8 @@ pub(crate) async fn run_queries( chunk_text_hit = true; } let chunk_id_for_entity = if chunk_id_required { - expected_chunk_ids_set.contains(entity.entity.source_id.as_str()) - || entity.chunks.iter().any(|chunk| { + expected_chunk_ids_set.contains(candidate.source_id.as_str()) + || candidate.chunks.iter().any(|chunk| { expected_chunk_ids_set.contains(chunk.chunk.id.as_str()) }) } else { @@ -236,9 +242,11 @@ pub(crate) async fn run_queries( match_rank = Some(idx_entity + 1); } let detail_fields = if config.detailed_report { + let description = candidate.entity_description.clone(); + let category = candidate.entity_category.clone(); ( - Some(entity.entity.description.clone()), - Some(format!("{:?}", entity.entity.entity_type)), + description, + category, Some(chunk_text_for_entity), Some(chunk_id_for_entity), ) @@ -247,10 +255,10 @@ pub(crate) async fn run_queries( }; retrieved.push(RetrievedSummary { rank: idx_entity + 1, - entity_id: entity.entity.id.clone(), - source_id: entity.entity.source_id.clone(), - entity_name: entity.entity.name.clone(), - score: entity.score, + entity_id: candidate.entity_id.clone(), + source_id: candidate.source_id.clone(), + entity_name: candidate.entity_name.clone(), + score: candidate.score, matched: success, 
entity_description: detail_fields.0, entity_category: detail_fields.1, @@ -271,6 +279,8 @@ pub(crate) async fn run_queries( entity_match: entity_hit, chunk_text_match: chunk_text_hit, chunk_id_match: chunk_id_hit, + is_impossible, + has_verified_chunks, match_rank, latency_ms: query_latency, retrieved, @@ -281,7 +291,7 @@ pub(crate) async fn run_queries( &summary, &expected_chunk_ids, &answers_lower, - &results, + &candidates, pipeline_diagnostics, )) } else { @@ -300,8 +310,18 @@ pub(crate) async fn run_queries( } }) .buffer_unordered(concurrency) - .try_collect() - .await?; + .collect::>() + .await; + + let mut results = Vec::with_capacity(raw_results.len()); + for result in raw_results { + match result { + Ok(val) => results.push(val), + Err(err) => { + tracing::error!(error = ?err, "Query execution failed"); + } + } + } let mut ordered = results; ordered.sort_by_key(|(idx, ..)| *idx); diff --git a/eval/src/eval/pipeline/stages/summarize.rs b/eval/src/eval/pipeline/stages/summarize.rs index a5508b3..341b84b 100644 --- a/eval/src/eval/pipeline/stages/summarize.rs +++ b/eval/src/eval/pipeline/stages/summarize.rs @@ -42,7 +42,18 @@ pub(crate) async fn summarize( let mut correct_at_1 = 0usize; let mut correct_at_2 = 0usize; let mut correct_at_3 = 0usize; + let mut retrieval_cases = 0usize; + let mut llm_cases = 0usize; + let mut llm_answered = 0usize; for summary in &summaries { + if summary.is_impossible { + llm_cases += 1; + if summary.matched { + llm_answered += 1; + } + continue; + } + retrieval_cases += 1; if summary.matched { correct += 1; if let Some(rank) = summary.match_rank { @@ -62,25 +73,31 @@ pub(crate) async fn summarize( let latency_stats = compute_latency_stats(&latencies); let stage_latency = build_stage_latency_breakdown(&stage_latency_samples); - let precision = if total_cases == 0 { + let retrieval_precision = if retrieval_cases == 0 { 0.0 } else { - (correct as f64) / (total_cases as f64) + (correct as f64) / (retrieval_cases as f64) }; - let 
precision_at_1 = if total_cases == 0 { + let llm_precision = if llm_cases == 0 { 0.0 } else { - (correct_at_1 as f64) / (total_cases as f64) + (llm_answered as f64) / (llm_cases as f64) }; - let precision_at_2 = if total_cases == 0 { + let precision = retrieval_precision; + let precision_at_1 = if retrieval_cases == 0 { 0.0 } else { - (correct_at_2 as f64) / (total_cases as f64) + (correct_at_1 as f64) / (retrieval_cases as f64) }; - let precision_at_3 = if total_cases == 0 { + let precision_at_2 = if retrieval_cases == 0 { 0.0 } else { - (correct_at_3 as f64) / (total_cases as f64) + (correct_at_2 as f64) / (retrieval_cases as f64) + }; + let precision_at_3 = if retrieval_cases == 0 { + 0.0 + } else { + (correct_at_3 as f64) / (retrieval_cases as f64) }; let active_tuning = ctx @@ -119,6 +136,15 @@ pub(crate) async fn summarize( dataset_label: dataset.metadata.label.clone(), dataset_includes_unanswerable: dataset.metadata.include_unanswerable, dataset_source: dataset.source.clone(), + includes_impossible_cases: slice.manifest.includes_unanswerable, + require_verified_chunks: slice.manifest.require_verified_chunks, + filtered_questions: ctx.filtered_questions, + retrieval_cases, + retrieval_correct: correct, + retrieval_precision, + llm_cases, + llm_answered, + llm_precision, slice_id: slice.manifest.slice_id.clone(), slice_seed: slice.manifest.seed, slice_total_cases: slice.manifest.case_count, @@ -146,11 +172,15 @@ pub(crate) async fn summarize( embedding_backend: ctx.embedding_provider().backend_label().to_string(), embedding_model: ctx.embedding_provider().model_code(), embedding_dimension: ctx.embedding_provider().dimension(), - rerank_enabled: config.rerank, - rerank_pool_size: ctx.rerank_pool.as_ref().map(|_| config.rerank_pool_size), - rerank_keep_top: config.rerank_keep_top, + rerank_enabled: config.retrieval.rerank, + rerank_pool_size: ctx + .rerank_pool + .as_ref() + .map(|_| config.retrieval.rerank_pool_size), + rerank_keep_top: 
config.retrieval.rerank_keep_top, concurrency: config.concurrency.max(1), detailed_report: config.detailed_report, + retrieval_strategy: config.retrieval.strategy.to_string(), chunk_vector_take: active_tuning.chunk_vector_take, chunk_fts_take: active_tuning.chunk_fts_take, chunk_token_budget: active_tuning.token_budget_estimate, diff --git a/eval/src/eval/types.rs b/eval/src/eval/types.rs new file mode 100644 index 0000000..5e567d5 --- /dev/null +++ b/eval/src/eval/types.rs @@ -0,0 +1,396 @@ +use std::collections::HashSet; + +use chrono::{DateTime, Utc}; +use retrieval_pipeline::{ + PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +pub struct EvaluationSummary { + pub generated_at: DateTime, + pub k: usize, + pub limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub run_label: Option, + pub total_cases: usize, + pub correct: usize, + pub precision: f64, + pub correct_at_1: usize, + pub correct_at_2: usize, + pub correct_at_3: usize, + pub precision_at_1: f64, + pub precision_at_2: f64, + pub precision_at_3: f64, + pub duration_ms: u128, + pub dataset_id: String, + pub dataset_label: String, + pub dataset_includes_unanswerable: bool, + pub dataset_source: String, + pub includes_impossible_cases: bool, + pub require_verified_chunks: bool, + pub filtered_questions: usize, + pub retrieval_cases: usize, + pub retrieval_correct: usize, + pub retrieval_precision: f64, + pub llm_cases: usize, + pub llm_answered: usize, + pub llm_precision: f64, + pub slice_id: String, + pub slice_seed: u64, + pub slice_total_cases: usize, + pub slice_window_offset: usize, + pub slice_window_length: usize, + pub slice_cases: usize, + pub slice_positive_paragraphs: usize, + pub slice_negative_paragraphs: usize, + pub slice_total_paragraphs: usize, + pub slice_negative_multiplier: f32, + pub namespace_reused: bool, + pub corpus_paragraphs: usize, + pub 
ingestion_cache_path: String, + pub ingestion_reused: bool, + pub ingestion_embeddings_reused: bool, + pub ingestion_fingerprint: String, + pub positive_paragraphs_reused: usize, + pub negative_paragraphs_reused: usize, + pub latency_ms: LatencyStats, + pub perf: PerformanceTimings, + pub embedding_backend: String, + pub embedding_model: Option, + pub embedding_dimension: usize, + pub rerank_enabled: bool, + pub rerank_pool_size: Option, + pub rerank_keep_top: usize, + pub concurrency: usize, + pub detailed_report: bool, + pub retrieval_strategy: String, + pub chunk_vector_take: usize, + pub chunk_fts_take: usize, + pub chunk_token_budget: usize, + pub chunk_avg_chars_per_token: usize, + pub max_chunks_per_entity: usize, + pub cases: Vec, +} + +#[derive(Debug, Serialize)] +pub struct CaseSummary { + pub question_id: String, + pub question: String, + pub paragraph_id: String, + pub paragraph_title: String, + pub expected_source: String, + pub answers: Vec, + pub matched: bool, + pub entity_match: bool, + pub chunk_text_match: bool, + pub chunk_id_match: bool, + pub is_impossible: bool, + pub has_verified_chunks: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub match_rank: Option, + pub latency_ms: u128, + pub retrieved: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LatencyStats { + pub avg: f64, + pub p50: u128, + pub p95: u128, +} + +#[derive(Debug, Clone, Serialize)] +pub struct StageLatencyBreakdown { + pub embed: LatencyStats, + pub collect_candidates: LatencyStats, + pub graph_expansion: LatencyStats, + pub chunk_attach: LatencyStats, + pub rerank: LatencyStats, + pub assemble: LatencyStats, +} + +#[derive(Debug, Default, Clone, Serialize)] +pub struct EvaluationStageTimings { + pub prepare_slice_ms: u128, + pub prepare_db_ms: u128, + pub prepare_corpus_ms: u128, + pub prepare_namespace_ms: u128, + pub run_queries_ms: u128, + pub summarize_ms: u128, + pub finalize_ms: u128, +} + +#[derive(Debug, Serialize)] +pub struct 
PerformanceTimings { + pub openai_base_url: String, + pub ingestion_ms: u128, + #[serde(skip_serializing_if = "Option::is_none")] + pub namespace_seed_ms: Option, + pub evaluation_stage_ms: EvaluationStageTimings, + pub stage_latency: StageLatencyBreakdown, +} + +#[derive(Debug, Serialize)] +pub struct RetrievedSummary { + pub rank: usize, + pub entity_id: String, + pub source_id: String, + pub entity_name: String, + pub score: f32, + pub matched: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub entity_description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub entity_category: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk_text_match: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub chunk_id_match: Option, +} + +#[derive(Debug, Clone)] +pub struct EvaluationCandidate { + pub entity_id: String, + pub source_id: String, + pub entity_name: String, + pub entity_description: Option, + pub entity_category: Option, + pub score: f32, + pub chunks: Vec, +} + +impl EvaluationCandidate { + fn from_entity(entity: RetrievedEntity) -> Self { + let entity_category = Some(format!("{:?}", entity.entity.entity_type)); + Self { + entity_id: entity.entity.id.clone(), + source_id: entity.entity.source_id.clone(), + entity_name: entity.entity.name.clone(), + entity_description: Some(entity.entity.description.clone()), + entity_category, + score: entity.score, + chunks: entity.chunks, + } + } + + fn from_chunk(chunk: RetrievedChunk) -> Self { + let snippet = chunk_snippet(&chunk.chunk.chunk); + Self { + entity_id: chunk.chunk.id.clone(), + source_id: chunk.chunk.source_id.clone(), + entity_name: chunk.chunk.source_id.clone(), + entity_description: Some(snippet), + entity_category: Some("Chunk".to_string()), + score: chunk.score, + chunks: vec![chunk], + } + } +} + +pub fn adapt_strategy_output(output: StrategyOutput) -> Vec { + match output { + StrategyOutput::Entities(entities) => entities + .into_iter() + 
.map(EvaluationCandidate::from_entity) + .collect(), + StrategyOutput::Chunks(chunks) => chunks + .into_iter() + .map(EvaluationCandidate::from_chunk) + .collect(), + } +} + +#[derive(Debug, Serialize)] +pub struct CaseDiagnostics { + pub question_id: String, + pub question: String, + pub paragraph_id: String, + pub paragraph_title: String, + pub expected_source: String, + pub expected_chunk_ids: Vec, + pub answers: Vec, + pub entity_match: bool, + pub chunk_text_match: bool, + pub chunk_id_match: bool, + pub failure_reasons: Vec, + pub missing_expected_chunk_ids: Vec, + pub attached_chunk_ids: Vec, + pub retrieved: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub pipeline: Option, +} + +#[derive(Debug, Serialize)] +pub struct EntityDiagnostics { + pub rank: usize, + pub entity_id: String, + pub source_id: String, + pub name: String, + pub score: f32, + pub entity_match: bool, + pub chunk_text_match: bool, + pub chunk_id_match: bool, + pub chunks: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ChunkDiagnosticsEntry { + pub chunk_id: String, + pub score: f32, + pub contains_answer: bool, + pub expected_chunk: bool, + pub snippet: String, +} + +pub fn text_contains_answer(text: &str, answers: &[String]) -> bool { + if answers.is_empty() { + return true; + } + let haystack = text.to_ascii_lowercase(); + answers.iter().any(|needle| haystack.contains(needle)) +} + +fn chunk_snippet(text: &str) -> String { + const MAX_CHARS: usize = 160; + let trimmed = text.trim(); + if trimmed.chars().count() <= MAX_CHARS { + return trimmed.to_string(); + } + let mut acc = String::with_capacity(MAX_CHARS + 3); + for (idx, ch) in trimmed.chars().enumerate() { + if idx >= MAX_CHARS { + break; + } + acc.push(ch); + } + acc.push_str("..."); + acc +} + +pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats { + if latencies.is_empty() { + return LatencyStats { + avg: 0.0, + p50: 0, + p95: 0, + }; + } + let mut sorted = latencies.to_vec(); + 
sorted.sort_unstable(); + let sum: u128 = sorted.iter().copied().sum(); + let avg = sum as f64 / (sorted.len() as f64); + let p50 = percentile(&sorted, 0.50); + let p95 = percentile(&sorted, 0.95); + LatencyStats { avg, p50, p95 } +} + +pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown { + fn collect_stage(samples: &[PipelineStageTimings], selector: F) -> Vec + where + F: Fn(&PipelineStageTimings) -> u128, + { + samples.iter().map(selector).collect() + } + + StageLatencyBreakdown { + embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms)), + collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| { + entry.collect_candidates_ms + })), + graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| { + entry.graph_expansion_ms + })), + chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)), + rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)), + assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)), + } +} + +fn percentile(sorted: &[u128], fraction: f64) -> u128 { + if sorted.is_empty() { + return 0; + } + let clamped = fraction.clamp(0.0, 1.0); + let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} + +pub fn build_case_diagnostics( + summary: &CaseSummary, + expected_chunk_ids: &[String], + answers_lower: &[String], + candidates: &[EvaluationCandidate], + pipeline_stats: Option, +) -> CaseDiagnostics { + let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect(); + let mut seen_chunks: HashSet = HashSet::new(); + let mut attached_chunk_ids = Vec::new(); + let mut entity_diagnostics = Vec::new(); + + for (idx, candidate) in candidates.iter().enumerate() { + let mut chunk_entries = Vec::new(); + for chunk in &candidate.chunks { + let contains_answer = text_contains_answer(&chunk.chunk.chunk, 
answers_lower); + let expected_chunk = expected_set.contains(chunk.chunk.id.as_str()); + seen_chunks.insert(chunk.chunk.id.clone()); + attached_chunk_ids.push(chunk.chunk.id.clone()); + chunk_entries.push(ChunkDiagnosticsEntry { + chunk_id: chunk.chunk.id.clone(), + score: chunk.score, + contains_answer, + expected_chunk, + snippet: chunk_snippet(&chunk.chunk.chunk), + }); + } + entity_diagnostics.push(EntityDiagnostics { + rank: idx + 1, + entity_id: candidate.entity_id.clone(), + source_id: candidate.source_id.clone(), + name: candidate.entity_name.clone(), + score: candidate.score, + entity_match: candidate.source_id == summary.expected_source, + chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer), + chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk), + chunks: chunk_entries, + }); + } + + let missing_expected_chunk_ids = expected_chunk_ids + .iter() + .filter(|id| !seen_chunks.contains(id.as_str())) + .cloned() + .collect::>(); + + let mut failure_reasons = Vec::new(); + if !summary.entity_match { + failure_reasons.push("entity_miss".to_string()); + } + if !summary.chunk_text_match { + failure_reasons.push("chunk_text_missing".to_string()); + } + if !summary.chunk_id_match { + failure_reasons.push("chunk_id_missing".to_string()); + } + if !missing_expected_chunk_ids.is_empty() { + failure_reasons.push("expected_chunk_absent".to_string()); + } + + CaseDiagnostics { + question_id: summary.question_id.clone(), + question: summary.question.clone(), + paragraph_id: summary.paragraph_id.clone(), + paragraph_title: summary.paragraph_title.clone(), + expected_source: summary.expected_source.clone(), + expected_chunk_ids: expected_chunk_ids.to_vec(), + answers: summary.answers.clone(), + entity_match: summary.entity_match, + chunk_text_match: summary.chunk_text_match, + chunk_id_match: summary.chunk_id_match, + failure_reasons, + missing_expected_chunk_ids, + attached_chunk_ids, + retrieved: entity_diagnostics, + pipeline: 
pipeline_stats, + } +} diff --git a/eval/src/ingest/config.rs b/eval/src/ingest/config.rs new file mode 100644 index 0000000..fd2eed1 --- /dev/null +++ b/eval/src/ingest/config.rs @@ -0,0 +1,72 @@ +use std::path::PathBuf; + +use anyhow::Result; +use async_trait::async_trait; + +use crate::{args::Config, embedding::EmbeddingProvider}; + +#[derive(Debug, Clone)] +pub struct CorpusCacheConfig { + pub ingestion_cache_dir: PathBuf, + pub force_refresh: bool, + pub refresh_embeddings_only: bool, + pub ingestion_batch_size: usize, + pub ingestion_max_retries: usize, +} + +impl CorpusCacheConfig { + pub fn new( + ingestion_cache_dir: impl Into, + force_refresh: bool, + refresh_embeddings_only: bool, + ingestion_batch_size: usize, + ingestion_max_retries: usize, + ) -> Self { + Self { + ingestion_cache_dir: ingestion_cache_dir.into(), + force_refresh, + refresh_embeddings_only, + ingestion_batch_size, + ingestion_max_retries, + } + } +} + +#[async_trait] +pub trait CorpusEmbeddingProvider: Send + Sync { + fn backend_label(&self) -> &str; + fn model_code(&self) -> Option; + fn dimension(&self) -> usize; + async fn embed_batch(&self, texts: Vec) -> Result>>; +} + +#[async_trait] +impl CorpusEmbeddingProvider for EmbeddingProvider { + fn backend_label(&self) -> &str { + EmbeddingProvider::backend_label(self) + } + + fn model_code(&self) -> Option { + EmbeddingProvider::model_code(self) + } + + fn dimension(&self) -> usize { + EmbeddingProvider::dimension(self) + } + + async fn embed_batch(&self, texts: Vec) -> Result>> { + EmbeddingProvider::embed_batch(self, texts).await + } +} + +impl From<&Config> for CorpusCacheConfig { + fn from(config: &Config) -> Self { + CorpusCacheConfig::new( + config.ingestion_cache_dir.clone(), + config.force_convert || config.slice_reset_ingestion, + config.refresh_embeddings_only, + config.ingestion_batch_size, + config.ingestion_max_retries, + ) + } +} diff --git a/eval/src/ingest/mod.rs b/eval/src/ingest/mod.rs new file mode 100644 index 
0000000..cf589d4 --- /dev/null +++ b/eval/src/ingest/mod.rs @@ -0,0 +1,10 @@ +mod config; +mod orchestrator; +mod store; + +pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider}; +pub use orchestrator::ensure_corpus; +pub use store::{ + seed_manifest_into_db, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion, + ParagraphShard, ParagraphShardStore, MANIFEST_VERSION, +}; diff --git a/eval/src/ingest.rs b/eval/src/ingest/orchestrator.rs similarity index 59% rename from eval/src/ingest.rs rename to eval/src/ingest/orchestrator.rs index c02fe23..e6ea9ad 100644 --- a/eval/src/ingest.rs +++ b/eval/src/ingest/orchestrator.rs @@ -1,23 +1,21 @@ use std::{ collections::{HashMap, HashSet}, fs, - io::{BufReader, Read}, - path::{Path, PathBuf}, + io::Read, + path::Path, sync::Arc, }; use anyhow::{anyhow, Context, Result}; use async_openai::Client; -use async_trait::async_trait; -use chrono::{DateTime, Utc}; +use chrono::Utc; use common::{ storage::{ db::SurrealDbClient, store::{DynStore, StorageManager}, types::{ ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, - knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, - text_chunk::TextChunk, text_content::TextContent, + knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, }, }, utils::config::{AppConfig, StorageKind}, @@ -30,274 +28,19 @@ use tracing::{info, warn}; use uuid::Uuid; use crate::{ - args::Config, datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion}, - embedding::EmbeddingProvider, slices::{self, ResolvedSlice, SliceParagraphKind}, }; -const MANIFEST_VERSION: u32 = 1; +use crate::ingest::{ + CorpusCacheConfig, CorpusEmbeddingProvider, CorpusHandle, CorpusManifest, CorpusMetadata, + CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION, +}; + const INGESTION_SPEC_VERSION: u32 = 1; -const INGESTION_MAX_RETRIES: usize = 3; -const INGESTION_BATCH_SIZE: usize = 5; -const PARAGRAPH_SHARD_VERSION: u32 = 1; - -#[derive(Debug, 
Clone)] -pub struct CorpusCacheConfig { - pub ingestion_cache_dir: PathBuf, - pub force_refresh: bool, - pub refresh_embeddings_only: bool, -} - -impl CorpusCacheConfig { - pub fn new( - ingestion_cache_dir: impl Into, - force_refresh: bool, - refresh_embeddings_only: bool, - ) -> Self { - Self { - ingestion_cache_dir: ingestion_cache_dir.into(), - force_refresh, - refresh_embeddings_only, - } - } -} - -#[async_trait] -pub trait CorpusEmbeddingProvider: Send + Sync { - fn backend_label(&self) -> &str; - fn model_code(&self) -> Option; - fn dimension(&self) -> usize; - async fn embed_batch(&self, texts: Vec) -> Result>>; -} type OpenAIClient = Client; -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct CorpusManifest { - pub version: u32, - pub metadata: CorpusMetadata, - pub paragraphs: Vec, - pub questions: Vec, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct CorpusMetadata { - pub dataset_id: String, - pub dataset_label: String, - pub slice_id: String, - pub include_unanswerable: bool, - pub ingestion_fingerprint: String, - pub embedding_backend: String, - pub embedding_model: Option, - pub embedding_dimension: usize, - pub converted_checksum: String, - pub generated_at: DateTime, - pub paragraph_count: usize, - pub question_count: usize, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct CorpusParagraph { - pub paragraph_id: String, - pub title: String, - pub text_content: TextContent, - pub entities: Vec, - pub relationships: Vec, - pub chunks: Vec, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct CorpusQuestion { - pub question_id: String, - pub paragraph_id: String, - pub text_content_id: String, - pub question_text: String, - pub answers: Vec, - pub is_impossible: bool, - pub matching_chunk_ids: Vec, -} - -pub struct CorpusHandle { - pub manifest: CorpusManifest, - pub path: PathBuf, - pub reused_ingestion: bool, - pub reused_embeddings: bool, - 
pub positive_reused: usize, - pub positive_ingested: usize, - pub negative_reused: usize, - pub negative_ingested: usize, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -struct ParagraphShard { - version: u32, - paragraph_id: String, - shard_path: String, - ingestion_fingerprint: String, - ingested_at: DateTime, - title: String, - text_content: TextContent, - entities: Vec, - relationships: Vec, - chunks: Vec, - #[serde(default)] - question_bindings: HashMap>, - #[serde(default)] - embedding_backend: String, - #[serde(default)] - embedding_model: Option, - #[serde(default)] - embedding_dimension: usize, -} - -struct ParagraphShardStore { - base_dir: PathBuf, -} - -impl ParagraphShardStore { - fn new(base_dir: PathBuf) -> Self { - Self { base_dir } - } - - fn ensure_base_dir(&self) -> Result<()> { - fs::create_dir_all(&self.base_dir) - .with_context(|| format!("creating shard base dir {}", self.base_dir.display())) - } - - fn resolve(&self, relative: &str) -> PathBuf { - self.base_dir.join(relative) - } - - fn load(&self, relative: &str, fingerprint: &str) -> Result> { - let path = self.resolve(relative); - let file = match fs::File::open(&path) { - Ok(file) => file, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None), - Err(err) => { - return Err(err).with_context(|| format!("opening shard {}", path.display())) - } - }; - let reader = BufReader::new(file); - let mut shard: ParagraphShard = serde_json::from_reader(reader) - .with_context(|| format!("parsing shard {}", path.display()))?; - if shard.version != PARAGRAPH_SHARD_VERSION { - warn!( - path = %path.display(), - version = shard.version, - expected = PARAGRAPH_SHARD_VERSION, - "Skipping shard due to version mismatch" - ); - return Ok(None); - } - if shard.ingestion_fingerprint != fingerprint { - return Ok(None); - } - shard.shard_path = relative.to_string(); - Ok(Some(shard)) - } - - fn persist(&self, shard: &ParagraphShard) -> Result<()> { - let path = 
self.resolve(&shard.shard_path); - if let Some(parent) = path.parent() { - fs::create_dir_all(parent) - .with_context(|| format!("creating shard dir {}", parent.display()))?; - } - let tmp_path = path.with_extension("json.tmp"); - let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?; - fs::write(&tmp_path, &body) - .with_context(|| format!("writing shard tmp {}", tmp_path.display()))?; - fs::rename(&tmp_path, &path) - .with_context(|| format!("renaming shard tmp {}", path.display()))?; - Ok(()) - } -} - -#[async_trait] -impl CorpusEmbeddingProvider for EmbeddingProvider { - fn backend_label(&self) -> &str { - EmbeddingProvider::backend_label(self) - } - - fn model_code(&self) -> Option { - EmbeddingProvider::model_code(self) - } - - fn dimension(&self) -> usize { - EmbeddingProvider::dimension(self) - } - - async fn embed_batch(&self, texts: Vec) -> Result>> { - EmbeddingProvider::embed_batch(self, texts).await - } -} - -impl From<&Config> for CorpusCacheConfig { - fn from(config: &Config) -> Self { - CorpusCacheConfig::new( - config.ingestion_cache_dir.clone(), - config.force_convert || config.slice_reset_ingestion, - config.refresh_embeddings_only, - ) - } -} - -impl ParagraphShard { - fn new( - paragraph: &ConvertedParagraph, - shard_path: String, - ingestion_fingerprint: &str, - text_content: TextContent, - entities: Vec, - relationships: Vec, - chunks: Vec, - embedding_backend: &str, - embedding_model: Option, - embedding_dimension: usize, - ) -> Self { - Self { - version: PARAGRAPH_SHARD_VERSION, - paragraph_id: paragraph.id.clone(), - shard_path, - ingestion_fingerprint: ingestion_fingerprint.to_string(), - ingested_at: Utc::now(), - title: paragraph.title.clone(), - text_content, - entities, - relationships, - chunks, - question_bindings: HashMap::new(), - embedding_backend: embedding_backend.to_string(), - embedding_model, - embedding_dimension, - } - } - - fn to_corpus_paragraph(&self) -> CorpusParagraph { - CorpusParagraph 
{ - paragraph_id: self.paragraph_id.clone(), - title: self.title.clone(), - text_content: self.text_content.clone(), - entities: self.entities.clone(), - relationships: self.relationships.clone(), - chunks: self.chunks.clone(), - } - } - - fn ensure_question_binding( - &mut self, - question: &ConvertedQuestion, - ) -> Result<(Vec, bool)> { - if let Some(existing) = self.question_bindings.get(&question.id) { - return Ok((existing.clone(), false)); - } - let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?; - self.question_bindings - .insert(question.id.clone(), chunk_ids.clone()); - Ok((chunk_ids, true)) - } -} - #[derive(Clone)] struct ParagraphShardRecord { shard: ParagraphShard, @@ -390,6 +133,7 @@ pub async fn ensure_corpus( store.ensure_base_dir()?; let positive_set: HashSet<&str> = window.positive_ids().collect(); + let require_verified_chunks = slice.manifest.require_verified_chunks; let embedding_backend_label = embedding.backend_label().to_string(); let embedding_model_code = embedding.model_code(); let embedding_dimension = embedding.dimension(); @@ -487,6 +231,8 @@ pub async fn ensure_corpus( &embedding_backend_label, embedding_model_code.clone(), embedding_dimension, + cache.ingestion_batch_size, + cache.ingestion_max_retries, ) .await .context("ingesting missing slice paragraphs")?; @@ -548,6 +294,12 @@ pub async fn ensure_corpus( let (chunk_ids, updated) = match record.shard.ensure_question_binding(case.question) { Ok(result) => result, Err(err) => { + if require_verified_chunks { + return Err(err).context(format!( + "locating answer text for question '{}' in paragraph '{}'", + case.question.id, case.paragraph.id + )); + } warn!( question_id = %case.question.id, paragraph_id = %case.paragraph.id, @@ -591,6 +343,7 @@ pub async fn ensure_corpus( dataset_label: dataset.metadata.label.clone(), slice_id: slice.manifest.slice_id.clone(), include_unanswerable: slice.manifest.includes_unanswerable, + require_verified_chunks: 
slice.manifest.require_verified_chunks, ingestion_fingerprint: ingestion_fingerprint.clone(), embedding_backend: embedding.backend_label().to_string(), embedding_model: embedding.model_code(), @@ -681,6 +434,8 @@ async fn ingest_paragraph_batch( embedding_backend: &str, embedding_model: Option, embedding_dimension: usize, + batch_size: usize, + max_retries: usize, ) -> Result> { if targets.is_empty() { return Ok(Vec::new()); @@ -704,7 +459,7 @@ async fn ingest_paragraph_batch( db, openai.clone(), app_config, - None::>, + None::>, storage, ) .await?; @@ -712,11 +467,11 @@ async fn ingest_paragraph_batch( let mut shards = Vec::with_capacity(targets.len()); let category = dataset.metadata.category.clone(); - for (batch_index, batch) in targets.chunks(INGESTION_BATCH_SIZE).enumerate() { + for (batch_index, batch) in targets.chunks(batch_size).enumerate() { info!( batch = batch_index, batch_size = batch.len(), - total_batches = (targets.len() + INGESTION_BATCH_SIZE - 1) / INGESTION_BATCH_SIZE, + total_batches = (targets.len() + batch_size - 1) / batch_size, "Ingesting paragraph batch" ); let model_clone = embedding_model.clone(); @@ -734,6 +489,7 @@ async fn ingest_paragraph_batch( backend_clone.clone(), model_clone.clone(), embedding_dimension, + max_retries, ) }); let batch_results: Vec = try_join_all(tasks) @@ -755,10 +511,11 @@ async fn ingest_single_paragraph( embedding_backend: String, embedding_model: Option, embedding_dimension: usize, + max_retries: usize, ) -> Result { let paragraph = request.paragraph; let mut last_err: Option = None; - for attempt in 1..=INGESTION_MAX_RETRIES { + for attempt in 1..=max_retries { let payload = IngestionPayload::Text { text: paragraph.context.clone(), context: paragraph.title.clone(), @@ -801,7 +558,7 @@ async fn ingest_single_paragraph( warn!( paragraph_id = %paragraph.id, attempt, - max_attempts = INGESTION_MAX_RETRIES, + max_attempts = max_retries, error = ?err, "ingestion attempt failed for paragraph; retrying" ); @@ 
-815,49 +572,6 @@ async fn ingest_single_paragraph( .context(format!("running ingestion for paragraph {}", paragraph.id))) } -fn validate_answers( - content: &TextContent, - chunks: &[TextChunk], - question: &ConvertedQuestion, -) -> Result> { - if question.is_impossible || question.answers.is_empty() { - return Ok(Vec::new()); - } - - let mut matches = std::collections::BTreeSet::new(); - let mut found_any = false; - let haystack = content.text.to_ascii_lowercase(); - let haystack_norm = normalize_answer_text(&haystack); - for answer in &question.answers { - let needle: String = answer.to_ascii_lowercase(); - let needle_norm = normalize_answer_text(&needle); - let text_match = haystack.contains(&needle) - || (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm)); - if text_match { - found_any = true; - } - for chunk in chunks { - let chunk_text = chunk.chunk.to_ascii_lowercase(); - let chunk_norm = normalize_answer_text(&chunk_text); - if chunk_text.contains(&needle) - || (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm)) - { - matches.insert(chunk.id.clone()); - found_any = true; - } - } - } - - if !found_any { - Err(anyhow!( - "expected answer for question '{}' was not found in ingested content", - question.id - )) - } else { - Ok(matches.into_iter().collect()) - } -} - fn build_ingestion_fingerprint( dataset: &ConvertedDataset, slice: &ResolvedSlice<'_>, @@ -894,107 +608,3 @@ fn compute_file_checksum(path: &Path) -> Result { } Ok(format!("{:x}", hasher.finalize())) } - -pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> { - for paragraph in &manifest.paragraphs { - db.store_item(paragraph.text_content.clone()) - .await - .context("storing text_content from manifest")?; - for entity in ¶graph.entities { - db.store_item(entity.clone()) - .await - .context("storing knowledge_entity from manifest")?; - } - for relationship in ¶graph.relationships { - relationship - .store_relationship(db) - 
.await - .context("storing knowledge_relationship from manifest")?; - } - for chunk in ¶graph.chunks { - db.store_item(chunk.clone()) - .await - .context("storing text_chunk from manifest")?; - } - } - - Ok(()) -} - -fn normalize_answer_text(text: &str) -> String { - text.chars() - .map(|ch| { - if ch.is_alphanumeric() || ch.is_whitespace() { - ch.to_ascii_lowercase() - } else { - ' ' - } - }) - .collect::() - .split_whitespace() - .collect::>() - .join(" ") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::datasets::ConvertedQuestion; - - fn mock_text_content() -> TextContent { - TextContent { - id: "tc1".into(), - created_at: Utc::now(), - updated_at: Utc::now(), - text: "alpha beta gamma".into(), - file_info: None, - url_info: None, - context: Some("ctx".into()), - category: "cat".into(), - user_id: "user".into(), - } - } - - fn mock_chunk(id: &str, text: &str) -> TextChunk { - TextChunk { - id: id.into(), - created_at: Utc::now(), - updated_at: Utc::now(), - source_id: "src".into(), - chunk: text.into(), - embedding: vec![], - user_id: "user".into(), - } - } - - #[test] - fn validate_answers_passes_when_present() { - let content = mock_text_content(); - let chunk = mock_chunk("chunk1", "alpha chunk"); - let question = ConvertedQuestion { - id: "q1".into(), - question: "?".into(), - answers: vec!["Alpha".into()], - is_impossible: false, - }; - let matches = validate_answers(&content, &[chunk], &question).expect("answers match"); - assert_eq!(matches, vec!["chunk1".to_string()]); - } - - #[test] - fn validate_answers_fails_when_missing() { - let question = ConvertedQuestion { - id: "q1".into(), - question: "?".into(), - answers: vec!["delta".into()], - is_impossible: false, - }; - let err = validate_answers( - &mock_text_content(), - &[mock_chunk("chunk", "alpha")], - &question, - ) - .expect_err("missing answer should fail"); - assert!(err.to_string().contains("not found")); - } -} diff --git a/eval/src/ingest/store.rs b/eval/src/ingest/store.rs new 
file mode 100644 index 0000000..be786dd --- /dev/null +++ b/eval/src/ingest/store.rs @@ -0,0 +1,299 @@ +use std::{collections::HashMap, fs, io::BufReader, path::PathBuf}; + +use anyhow::{anyhow, Context, Result}; +use chrono::{DateTime, Utc}; +use common::storage::{ + db::SurrealDbClient, + types::{ + knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, + text_chunk::TextChunk, text_content::TextContent, + }, +}; +use tracing::warn; + +use crate::datasets::{ConvertedParagraph, ConvertedQuestion}; + +pub const MANIFEST_VERSION: u32 = 1; +pub const PARAGRAPH_SHARD_VERSION: u32 = 1; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CorpusManifest { + pub version: u32, + pub metadata: CorpusMetadata, + pub paragraphs: Vec, + pub questions: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CorpusMetadata { + pub dataset_id: String, + pub dataset_label: String, + pub slice_id: String, + pub include_unanswerable: bool, + #[serde(default)] + pub require_verified_chunks: bool, + pub ingestion_fingerprint: String, + pub embedding_backend: String, + pub embedding_model: Option, + pub embedding_dimension: usize, + pub converted_checksum: String, + pub generated_at: DateTime, + pub paragraph_count: usize, + pub question_count: usize, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CorpusParagraph { + pub paragraph_id: String, + pub title: String, + pub text_content: TextContent, + pub entities: Vec, + pub relationships: Vec, + pub chunks: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CorpusQuestion { + pub question_id: String, + pub paragraph_id: String, + pub text_content_id: String, + pub question_text: String, + pub answers: Vec, + pub is_impossible: bool, + pub matching_chunk_ids: Vec, +} + +pub struct CorpusHandle { + pub manifest: CorpusManifest, + pub path: PathBuf, + pub reused_ingestion: bool, + pub 
reused_embeddings: bool, + pub positive_reused: usize, + pub positive_ingested: usize, + pub negative_reused: usize, + pub negative_ingested: usize, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ParagraphShard { + pub version: u32, + pub paragraph_id: String, + pub shard_path: String, + pub ingestion_fingerprint: String, + pub ingested_at: DateTime, + pub title: String, + pub text_content: TextContent, + pub entities: Vec, + pub relationships: Vec, + pub chunks: Vec, + #[serde(default)] + pub question_bindings: HashMap>, + #[serde(default)] + pub embedding_backend: String, + #[serde(default)] + pub embedding_model: Option, + #[serde(default)] + pub embedding_dimension: usize, +} + +pub struct ParagraphShardStore { + base_dir: PathBuf, +} + +impl ParagraphShardStore { + pub fn new(base_dir: PathBuf) -> Self { + Self { base_dir } + } + + pub fn ensure_base_dir(&self) -> Result<()> { + fs::create_dir_all(&self.base_dir) + .with_context(|| format!("creating shard base dir {}", self.base_dir.display())) + } + + fn resolve(&self, relative: &str) -> PathBuf { + self.base_dir.join(relative) + } + + pub fn load(&self, relative: &str, fingerprint: &str) -> Result> { + let path = self.resolve(relative); + let file = match fs::File::open(&path) { + Ok(file) => file, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None), + Err(err) => { + return Err(err).with_context(|| format!("opening shard {}", path.display())) + } + }; + let reader = BufReader::new(file); + let mut shard: ParagraphShard = serde_json::from_reader(reader) + .with_context(|| format!("parsing shard {}", path.display()))?; + if shard.version != PARAGRAPH_SHARD_VERSION { + warn!( + path = %path.display(), + version = shard.version, + expected = PARAGRAPH_SHARD_VERSION, + "Skipping shard due to version mismatch" + ); + return Ok(None); + } + if shard.ingestion_fingerprint != fingerprint { + return Ok(None); + } + shard.shard_path = relative.to_string(); + 
Ok(Some(shard)) + } + + pub fn persist(&self, shard: &ParagraphShard) -> Result<()> { + let path = self.resolve(&shard.shard_path); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("creating shard dir {}", parent.display()))?; + } + let tmp_path = path.with_extension("json.tmp"); + let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?; + fs::write(&tmp_path, &body) + .with_context(|| format!("writing shard tmp {}", tmp_path.display()))?; + fs::rename(&tmp_path, &path) + .with_context(|| format!("renaming shard tmp {}", path.display()))?; + Ok(()) + } +} + +impl ParagraphShard { + pub fn new( + paragraph: &ConvertedParagraph, + shard_path: String, + ingestion_fingerprint: &str, + text_content: TextContent, + entities: Vec, + relationships: Vec, + chunks: Vec, + embedding_backend: &str, + embedding_model: Option, + embedding_dimension: usize, + ) -> Self { + Self { + version: PARAGRAPH_SHARD_VERSION, + paragraph_id: paragraph.id.clone(), + shard_path, + ingestion_fingerprint: ingestion_fingerprint.to_string(), + ingested_at: Utc::now(), + title: paragraph.title.clone(), + text_content, + entities, + relationships, + chunks, + question_bindings: HashMap::new(), + embedding_backend: embedding_backend.to_string(), + embedding_model, + embedding_dimension, + } + } + + pub fn to_corpus_paragraph(&self) -> CorpusParagraph { + CorpusParagraph { + paragraph_id: self.paragraph_id.clone(), + title: self.title.clone(), + text_content: self.text_content.clone(), + entities: self.entities.clone(), + relationships: self.relationships.clone(), + chunks: self.chunks.clone(), + } + } + + pub fn ensure_question_binding( + &mut self, + question: &ConvertedQuestion, + ) -> Result<(Vec, bool)> { + if let Some(existing) = self.question_bindings.get(&question.id) { + return Ok((existing.clone(), false)); + } + let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?; + self.question_bindings + 
.insert(question.id.clone(), chunk_ids.clone()); + Ok((chunk_ids, true)) + } +} + +fn validate_answers( + content: &TextContent, + chunks: &[TextChunk], + question: &ConvertedQuestion, +) -> Result> { + if question.is_impossible || question.answers.is_empty() { + return Ok(Vec::new()); + } + + let mut matches = std::collections::BTreeSet::new(); + let mut found_any = false; + let haystack = content.text.to_ascii_lowercase(); + let haystack_norm = normalize_answer_text(&haystack); + for answer in &question.answers { + let needle: String = answer.to_ascii_lowercase(); + let needle_norm = normalize_answer_text(&needle); + let text_match = haystack.contains(&needle) + || (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm)); + if text_match { + found_any = true; + } + for chunk in chunks { + let chunk_text = chunk.chunk.to_ascii_lowercase(); + let chunk_norm = normalize_answer_text(&chunk_text); + if chunk_text.contains(&needle) + || (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm)) + { + matches.insert(chunk.id.clone()); + found_any = true; + } + } + } + + if !found_any { + Err(anyhow!( + "expected answer for question '{}' was not found in ingested content", + question.id + )) + } else { + Ok(matches.into_iter().collect()) + } +} + +fn normalize_answer_text(text: &str) -> String { + text.chars() + .map(|ch| { + if ch.is_alphanumeric() || ch.is_whitespace() { + ch.to_ascii_lowercase() + } else { + ' ' + } + }) + .collect::() + .split_whitespace() + .collect::>() + .join(" ") +} + +pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> { + for paragraph in &manifest.paragraphs { + db.upsert_item(paragraph.text_content.clone()) + .await + .context("storing text_content from manifest")?; + for entity in ¶graph.entities { + db.upsert_item(entity.clone()) + .await + .context("storing knowledge_entity from manifest")?; + } + for relationship in ¶graph.relationships { + relationship + .store_relationship(db) 
+ .await + .context("storing knowledge_relationship from manifest")?; + } + for chunk in ¶graph.chunks { + db.upsert_item(chunk.clone()) + .await + .context("storing text_chunk from manifest")?; + } + } + + Ok(()) +} diff --git a/eval/src/main.rs b/eval/src/main.rs index ed35a8f..653dacb 100644 --- a/eval/src/main.rs +++ b/eval/src/main.rs @@ -194,17 +194,34 @@ async fn async_main() -> anyhow::Result<()> { ) })?; - println!( - "[{}] Precision@{k}: {precision:.3} ({correct}/{total}) → JSON: {json} | Markdown: {md} | Perf: {perf}", - summary.dataset_label, - k = summary.k, - precision = summary.precision, - correct = summary.correct, - total = summary.total_cases, - json = report_paths.json.display(), - md = report_paths.markdown.display(), - perf = perf_log_path.display() - ); + if summary.llm_cases > 0 { + println!( + "[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) + LLM: {llm_answered}/{llm_total} ({llm_precision:.3}) → JSON: {json} | Markdown: {md} | Perf: {perf}", + summary.dataset_label, + k = summary.k, + precision = summary.precision, + correct = summary.correct, + retrieval_total = summary.retrieval_cases, + llm_answered = summary.llm_answered, + llm_total = summary.llm_cases, + llm_precision = summary.llm_precision, + json = report_paths.json.display(), + md = report_paths.markdown.display(), + perf = perf_log_path.display() + ); + } else { + println!( + "[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | Perf: {perf}", + summary.dataset_label, + k = summary.k, + precision = summary.precision, + correct = summary.correct, + retrieval_total = summary.retrieval_cases, + json = report_paths.json.display(), + md = report_paths.markdown.display(), + perf = perf_log_path.display() + ); + } if parsed.config.perf_log_console { perf::print_console_summary(&summary); diff --git a/eval/src/perf.rs b/eval/src/perf.rs index 7315df2..bea16ec 100644 --- a/eval/src/perf.rs +++ 
b/eval/src/perf.rs @@ -19,6 +19,7 @@ struct PerformanceLogEntry { dataset_id: String, dataset_label: String, run_label: Option, + retrieval_strategy: String, slice_id: String, slice_seed: u64, slice_window_offset: usize, @@ -27,6 +28,10 @@ struct PerformanceLogEntry { total_cases: usize, correct: usize, precision: f64, + retrieval_cases: usize, + llm_cases: usize, + llm_answered: usize, + llm_precision: f64, k: usize, openai_base_url: String, ingestion: IngestionPerf, @@ -87,7 +92,7 @@ impl PerformanceLogEntry { rerank_enabled: summary.rerank_enabled, rerank_pool_size: summary.rerank_pool_size, rerank_keep_top: summary.rerank_keep_top, - evaluated_cases: summary.total_cases, + evaluated_cases: summary.retrieval_cases, }; Self { @@ -95,6 +100,7 @@ impl PerformanceLogEntry { dataset_id: summary.dataset_id.clone(), dataset_label: summary.dataset_label.clone(), run_label: summary.run_label.clone(), + retrieval_strategy: summary.retrieval_strategy.clone(), slice_id: summary.slice_id.clone(), slice_seed: summary.slice_seed, slice_window_offset: summary.slice_window_offset, @@ -103,6 +109,10 @@ impl PerformanceLogEntry { total_cases: summary.total_cases, correct: summary.correct, precision: summary.precision, + retrieval_cases: summary.retrieval_cases, + llm_cases: summary.llm_cases, + llm_answered: summary.llm_answered, + llm_precision: summary.llm_precision, k: summary.k, openai_base_url: summary.perf.openai_base_url.clone(), ingestion, @@ -162,6 +172,13 @@ pub fn write_perf_logs( pub fn print_console_summary(summary: &EvaluationSummary) { let perf = &summary.perf; + println!( + "[perf] retrieval strategy={} | rerank={} (pool {:?}, keep {})", + summary.retrieval_strategy, + summary.rerank_enabled, + summary.rerank_pool_size, + summary.rerank_keep_top + ); println!( "[perf] ingestion={}ms | namespace_seed={}", perf.ingestion_ms, @@ -169,7 +186,8 @@ pub fn print_console_summary(summary: &EvaluationSummary) { ); let stage = &perf.stage_latency; println!( - "[perf] stage 
avg ms → collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}", + "[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}", + stage.embed.avg, stage.collect_candidates.avg, stage.graph_expansion.avg, stage.chunk_attach.avg, @@ -212,6 +230,7 @@ mod tests { fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown { crate::eval::StageLatencyBreakdown { + embed: sample_latency(), collect_candidates: sample_latency(), graph_expansion: sample_latency(), chunk_attach: sample_latency(), @@ -252,6 +271,15 @@ mod tests { dataset_label: "SQuAD v2".into(), dataset_includes_unanswerable: false, dataset_source: "dev".into(), + includes_impossible_cases: false, + require_verified_chunks: true, + filtered_questions: 0, + retrieval_cases: 2, + retrieval_correct: 1, + retrieval_precision: 0.5, + llm_cases: 0, + llm_answered: 0, + llm_precision: 0.0, slice_id: "slice123".into(), slice_seed: 42, slice_total_cases: 400, @@ -285,6 +313,7 @@ mod tests { rerank_pool_size: Some(4), rerank_keep_top: 10, concurrency: 2, + retrieval_strategy: "initial".into(), detailed_report: false, chunk_vector_take: 20, chunk_fts_take: 20, diff --git a/eval/src/report.rs b/eval/src/report.rs index b6990af..9c6e995 100644 --- a/eval/src/report.rs +++ b/eval/src/report.rs @@ -6,7 +6,12 @@ use std::{ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -use crate::eval::{format_timestamp, CaseSummary, EvaluationSummary, LatencyStats}; +use crate::eval::{ + format_timestamp, CaseSummary, EvaluationStageTimings, EvaluationSummary, LatencyStats, + StageLatencyBreakdown, +}; +use chrono::Utc; +use tracing::warn; #[derive(Debug)] pub struct ReportPaths { @@ -14,6 +19,278 @@ pub struct ReportPaths { pub markdown: PathBuf, } +#[derive(Debug, Serialize)] +pub struct EvaluationReport { + pub overview: OverviewSection, + pub dataset: DatasetSection, + pub slice: SliceSection, + pub retrieval: RetrievalSection, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub llm: Option, + pub performance: PerformanceSection, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub misses: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub llm_cases: Vec, + pub detailed_report: bool, +} + +#[derive(Debug, Serialize)] +pub struct OverviewSection { + pub generated_at: String, + pub run_label: Option, + pub total_cases: usize, + pub filtered_questions: usize, +} + +#[derive(Debug, Serialize)] +pub struct DatasetSection { + pub id: String, + pub label: String, + pub source: String, + pub includes_unanswerable: bool, + pub require_verified_chunks: bool, + pub embedding_backend: String, + pub embedding_model: Option, + pub embedding_dimension: usize, +} + +#[derive(Debug, Serialize)] +pub struct SliceSection { + pub id: String, + pub seed: u64, + pub window_offset: usize, + pub window_length: usize, + pub slice_cases: usize, + pub ledger_total_cases: usize, + pub positives: usize, + pub negatives: usize, + pub total_paragraphs: usize, + pub negative_multiplier: f32, +} + +#[derive(Debug, Serialize)] +pub struct RetrievalSection { + pub k: usize, + pub cases: usize, + pub correct: usize, + pub precision: f64, + pub precision_at_1: f64, + pub precision_at_2: f64, + pub precision_at_3: f64, + pub latency: LatencyStats, + pub concurrency: usize, + pub strategy: String, + pub rerank_enabled: bool, + pub rerank_pool_size: Option, + pub rerank_keep_top: usize, +} + +#[derive(Debug, Serialize)] +pub struct LlmSection { + pub cases: usize, + pub answered: usize, + pub precision: f64, +} + +#[derive(Debug, Serialize)] +pub struct PerformanceSection { + pub openai_base_url: String, + pub ingestion_ms: u128, + pub namespace_seed_ms: Option, + pub evaluation_stages_ms: EvaluationStageTimings, + pub stage_latency: StageLatencyBreakdown, + pub namespace_reused: bool, + pub ingestion_reused: bool, + pub embeddings_reused: bool, + pub ingestion_cache_path: String, + pub corpus_paragraphs: 
usize, + pub positive_paragraphs_reused: usize, + pub negative_paragraphs_reused: usize, +} + +#[derive(Debug, Serialize)] +pub struct MissEntry { + pub question_id: String, + pub paragraph_title: String, + pub expected_source: String, + pub entity_match: bool, + pub chunk_text_match: bool, + pub chunk_id_match: bool, + pub retrieved: Vec, +} + +#[derive(Debug, Serialize)] +pub struct LlmCaseEntry { + pub question_id: String, + pub answered: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub match_rank: Option, + pub retrieved: Vec, +} + +#[derive(Debug, Serialize)] +pub struct RetrievedSnippet { + pub rank: usize, + pub source_id: String, + pub entity_name: String, + pub matched: bool, +} + +impl EvaluationReport { + pub fn from_summary(summary: &EvaluationSummary, sample: usize) -> Self { + let overview = OverviewSection { + generated_at: format_timestamp(&summary.generated_at), + run_label: summary.run_label.clone(), + total_cases: summary.total_cases, + filtered_questions: summary.filtered_questions, + }; + + let dataset = DatasetSection { + id: summary.dataset_id.clone(), + label: summary.dataset_label.clone(), + source: summary.dataset_source.clone(), + includes_unanswerable: summary.includes_impossible_cases, + require_verified_chunks: summary.require_verified_chunks, + embedding_backend: summary.embedding_backend.clone(), + embedding_model: summary.embedding_model.clone(), + embedding_dimension: summary.embedding_dimension, + }; + + let slice = SliceSection { + id: summary.slice_id.clone(), + seed: summary.slice_seed, + window_offset: summary.slice_window_offset, + window_length: summary.slice_window_length, + slice_cases: summary.slice_cases, + ledger_total_cases: summary.slice_total_cases, + positives: summary.slice_positive_paragraphs, + negatives: summary.slice_negative_paragraphs, + total_paragraphs: summary.slice_total_paragraphs, + negative_multiplier: summary.slice_negative_multiplier, + }; + + let retrieval = RetrievalSection { + k: 
summary.k, + cases: summary.retrieval_cases, + correct: summary.retrieval_correct, + precision: summary.retrieval_precision, + precision_at_1: summary.precision_at_1, + precision_at_2: summary.precision_at_2, + precision_at_3: summary.precision_at_3, + latency: summary.latency_ms.clone(), + concurrency: summary.concurrency, + strategy: summary.retrieval_strategy.clone(), + rerank_enabled: summary.rerank_enabled, + rerank_pool_size: summary.rerank_pool_size, + rerank_keep_top: summary.rerank_keep_top, + }; + + let llm = if summary.llm_cases > 0 { + Some(LlmSection { + cases: summary.llm_cases, + answered: summary.llm_answered, + precision: summary.llm_precision, + }) + } else { + None + }; + + let performance = PerformanceSection { + openai_base_url: summary.perf.openai_base_url.clone(), + ingestion_ms: summary.perf.ingestion_ms, + namespace_seed_ms: summary.perf.namespace_seed_ms, + evaluation_stages_ms: summary.perf.evaluation_stage_ms.clone(), + stage_latency: summary.perf.stage_latency.clone(), + namespace_reused: summary.namespace_reused, + ingestion_reused: summary.ingestion_reused, + embeddings_reused: summary.ingestion_embeddings_reused, + ingestion_cache_path: summary.ingestion_cache_path.clone(), + corpus_paragraphs: summary.corpus_paragraphs, + positive_paragraphs_reused: summary.positive_paragraphs_reused, + negative_paragraphs_reused: summary.negative_paragraphs_reused, + }; + + let misses = summary + .cases + .iter() + .filter(|case| !case.matched && !case.is_impossible) + .take(sample) + .map(MissEntry::from_case) + .collect(); + + let llm_cases = if llm.is_some() { + summary + .cases + .iter() + .filter(|case| case.is_impossible) + .take(sample) + .map(LlmCaseEntry::from_case) + .collect() + } else { + Vec::new() + }; + + Self { + overview, + dataset, + slice, + retrieval, + llm, + performance, + misses, + llm_cases, + detailed_report: summary.detailed_report, + } + } +} + +impl MissEntry { + fn from_case(case: &CaseSummary) -> Self { + Self { + 
question_id: case.question_id.clone(), + paragraph_title: case.paragraph_title.clone(), + expected_source: case.expected_source.clone(), + entity_match: case.entity_match, + chunk_text_match: case.chunk_text_match, + chunk_id_match: case.chunk_id_match, + retrieved: case + .retrieved + .iter() + .take(3) + .map(RetrievedSnippet::from_summary) + .collect(), + } + } +} + +impl LlmCaseEntry { + fn from_case(case: &CaseSummary) -> Self { + Self { + question_id: case.question_id.clone(), + answered: case.matched, + match_rank: case.match_rank, + retrieved: case + .retrieved + .iter() + .take(3) + .map(RetrievedSnippet::from_summary) + .collect(), + } + } +} + +impl RetrievedSnippet { + fn from_summary(entry: &crate::eval::RetrievedSummary) -> Self { + Self { + rank: entry.rank, + source_id: entry.source_id.clone(), + entity_name: entry.entity_name.clone(), + matched: entry.matched, + } + } +} + pub fn write_reports( summary: &EvaluationSummary, report_dir: &Path, @@ -30,14 +307,15 @@ pub fn write_reports( })?; let stem = build_report_stem(summary); + let report = EvaluationReport::from_summary(summary, sample); let json_path = dataset_dir.join(format!("{stem}.json")); - let json_blob = serde_json::to_string_pretty(summary).context("serialising JSON report")?; + let json_blob = serde_json::to_string_pretty(&report).context("serialising JSON report")?; fs::write(&json_path, &json_blob) .with_context(|| format!("writing JSON report to {}", json_path.display()))?; let md_path = dataset_dir.join(format!("{stem}.md")); - let markdown = render_markdown(summary, sample); + let markdown = render_markdown(&report); fs::write(&md_path, &markdown) .with_context(|| format!("writing Markdown report to {}", md_path.display()))?; @@ -57,234 +335,236 @@ pub fn write_reports( }) } -fn render_markdown(summary: &EvaluationSummary, sample: usize) -> String { +fn render_markdown(report: &EvaluationReport) -> String { let mut md = String::new(); - md.push_str(&format!("# Retrieval 
Precision@{}\n\n", summary.k)); - md.push_str("| Metric | Value |\n"); - md.push_str("| --- | --- |\n"); md.push_str(&format!( - "| Generated | {} |\n", - format_timestamp(&summary.generated_at) + "# Retrieval Evaluation (k={})\n\n", + report.retrieval.k + )); + + md.push_str("## Overview\n\n"); + md.push_str("| Metric | Value |\n| --- | --- |\n"); + md.push_str(&format!( + "| Generated | {} |\n", + report.overview.generated_at )); md.push_str(&format!( - "| Dataset | {} (`{}`) |\n", - summary.dataset_label, summary.dataset_id - )); - md.push_str(&format!( - "| Run Label | {} |\n", - summary + "| Run Label | {} |\n", + report + .overview .run_label .as_deref() .filter(|label| !label.is_empty()) .unwrap_or("-") )); md.push_str(&format!( - "| Unanswerable Included | {} |\n", - if summary.dataset_includes_unanswerable { - "yes" - } else { - "no" - } + "| Total Cases | {} |\n", + report.overview.total_cases )); md.push_str(&format!( - "| Dataset Source | {} |\n", - summary.dataset_source + "| Filtered Questions | {} |\n", + report.overview.filtered_questions + )); + + md.push_str("\n## Dataset & Slice\n\n"); + md.push_str("| Metric | Value |\n| --- | --- |\n"); + md.push_str(&format!( + "| Dataset | {} (`{}`) |\n", + report.dataset.label, report.dataset.id )); md.push_str(&format!( - "| OpenAI Base URL | {} |\n", - summary.perf.openai_base_url - )); - md.push_str(&format!("| Slice ID | `{}` |\n", summary.slice_id)); - md.push_str(&format!("| Slice Seed | {} |\n", summary.slice_seed)); - md.push_str(&format!( - "| Slice Total Questions | {} |\n", - summary.slice_total_cases + "| Dataset Source | {} |\n", + report.dataset.source )); md.push_str(&format!( - "| Slice Window (offset/length) | {}/{} |\n", - summary.slice_window_offset, summary.slice_window_length + "| Includes Unanswerable | {} |\n", + bool_badge(report.dataset.includes_unanswerable) )); md.push_str(&format!( - "| Slice Window Questions | {} |\n", - summary.slice_cases + "| Require Verified 
Chunks | {} |\n", + bool_badge(report.dataset.require_verified_chunks) + )); + let embedding_label = if let Some(model) = report.dataset.embedding_model.as_ref() { + format!("{} ({model})", report.dataset.embedding_backend) + } else { + report.dataset.embedding_backend.clone() + }; + md.push_str(&format!("| Embedding | {} |\n", embedding_label)); + md.push_str(&format!( + "| Embedding Dim | {} |\n", + report.dataset.embedding_dimension + )); + md.push_str(&format!("| Slice ID | `{}` |\n", report.slice.id)); + md.push_str(&format!("| Slice Seed | {} |\n", report.slice.seed)); + md.push_str(&format!( + "| Slice Window (offset/length) | {}/{} |\n", + report.slice.window_offset, report.slice.window_length )); md.push_str(&format!( - "| Slice Negatives | {} |\n", - summary.slice_negative_paragraphs + "| Slice Questions (window/ledger) | {}/{} |\n", + report.slice.slice_cases, report.slice.ledger_total_cases )); md.push_str(&format!( - "| Slice Total Paragraphs | {} |\n", - summary.slice_total_paragraphs + "| Slice Positives / Negatives | {}/{} |\n", + report.slice.positives, report.slice.negatives )); md.push_str(&format!( - "| Slice Negative Multiplier | {:.2} |\n", - summary.slice_negative_multiplier + "| Slice Paragraphs | {} |\n", + report.slice.total_paragraphs )); md.push_str(&format!( - "| Namespace State | {} |\n", - if summary.namespace_reused { + "| Negative Multiplier | {:.2} |\n", + report.slice.negative_multiplier + )); + + md.push_str("\n## Retrieval Metrics\n\n"); + md.push_str("| Metric | Value |\n| --- | --- |\n"); + md.push_str(&format!("| Cases | {} |\n", report.retrieval.cases)); + md.push_str(&format!( + "| Correct@{} | {}/{} |\n", + report.retrieval.k, report.retrieval.correct, report.retrieval.cases + )); + md.push_str(&format!( + "| Precision@{} | {:.3} |\n", + report.retrieval.k, report.retrieval.precision + )); + md.push_str(&format!( + "| Precision@1/2/3 | {:.3} / {:.3} / {:.3} |\n", + report.retrieval.precision_at_1, + 
report.retrieval.precision_at_2, + report.retrieval.precision_at_3 + )); + md.push_str(&format!( + "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\n", + report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95 + )); + md.push_str(&format!( + "| Strategy | `{}` |\n", + report.retrieval.strategy + )); + md.push_str(&format!( + "| Concurrency | {} |\n", + report.retrieval.concurrency + )); + if report.retrieval.rerank_enabled { + let pool = report + .retrieval + .rerank_pool_size + .map(|size| size.to_string()) + .unwrap_or_else(|| "?".into()); + md.push_str(&format!( + "| Rerank | enabled (pool {pool}, keep top {}) |\n", + report.retrieval.rerank_keep_top + )); + } else { + md.push_str("| Rerank | disabled |\n"); + } + + if let Some(llm) = &report.llm { + md.push_str("\n## LLM Mode Metrics\n\n"); + md.push_str("| Metric | Value |\n| --- | --- |\n"); + md.push_str(&format!("| Cases | {} |\n", llm.cases)); + md.push_str(&format!("| Answered | {} |\n", llm.answered)); + md.push_str(&format!("| Precision | {:.3} |\n", llm.precision)); + } + + md.push_str("\n## Performance\n\n"); + md.push_str("| Metric | Value |\n| --- | --- |\n"); + md.push_str(&format!( + "| OpenAI Base URL | {} |\n", + report.performance.openai_base_url + )); + md.push_str(&format!( + "| Ingestion Duration | {} ms |\n", + report.performance.ingestion_ms + )); + if let Some(seed) = report.performance.namespace_seed_ms { + md.push_str(&format!("| Namespace Seed | {} ms |\n", seed)); + } + md.push_str(&format!( + "| Namespace State | {} |\n", + if report.performance.namespace_reused { "reused" } else { "seeded" } )); md.push_str(&format!( - "| Corpus Paragraphs | {} |\n", - summary.corpus_paragraphs + "| Corpus Paragraphs | {} |\n", + report.performance.corpus_paragraphs )); - md.push_str(&format!( - "| Ingestion Duration | {} ms |\n", - summary.perf.ingestion_ms - )); - if let Some(seed) = summary.perf.namespace_seed_ms { - 
md.push_str(&format!("| Namespace Seed | {} ms |\n", seed)); - } - if summary.detailed_report { + if report.detailed_report { md.push_str(&format!( - "| Ingestion Cache | `{}` |\n", - summary.ingestion_cache_path + "| Ingestion Cache | `{}` |\n", + report.performance.ingestion_cache_path )); md.push_str(&format!( - "| Ingestion Reused | {} |\n", - if summary.ingestion_reused { - "yes" - } else { - "no" - } + "| Ingestion Reused | {} |\n", + bool_badge(report.performance.ingestion_reused) )); md.push_str(&format!( - "| Embeddings Reused | {} |\n", - if summary.ingestion_embeddings_reused { - "yes" - } else { - "no" - } + "| Embeddings Reused | {} |\n", + bool_badge(report.performance.embeddings_reused) )); } md.push_str(&format!( - "| Positives Cached | {} | -", - summary.positive_paragraphs_reused + "| Positives Cached | {} |\n", + report.performance.positive_paragraphs_reused )); md.push_str(&format!( - "| Negatives Cached | {} | -", - summary.negative_paragraphs_reused - )); - let embedding_label = if let Some(model) = summary.embedding_model.as_ref() { - format!("{} ({model})", summary.embedding_backend) - } else { - summary.embedding_backend.clone() - }; - md.push_str(&format!("| Embedding | {} |\n", embedding_label)); - md.push_str(&format!( - "| Embedding Dim | {} |\n", - summary.embedding_dimension - )); - if let Some(limit) = summary.limit { - md.push_str(&format!( - "| Evaluated Queries | {} (limit {}) |\n", - summary.total_cases, limit - )); - } else { - md.push_str(&format!( - "| Evaluated Queries | {} |\n", - summary.total_cases - )); - } - if summary.rerank_enabled { - let pool = summary - .rerank_pool_size - .map(|size| size.to_string()) - .unwrap_or_else(|| "?".to_string()); - md.push_str(&format!( - "| Rerank | enabled (pool {pool}, keep top {}) |\n", - summary.rerank_keep_top - )); - } else { - md.push_str("| Rerank | disabled |\n"); - } - md.push_str(&format!("| Concurrency | {} |\n", summary.concurrency)); - md.push_str(&format!( - "| 
Correct@{} | {}/{} |\n", - summary.k, summary.correct, summary.total_cases - )); - md.push_str(&format!( - "| Precision@{} | {:.3} |\n", - summary.k, summary.precision - )); - md.push_str(&format!( - "| Precision@1 | {:.3} |\n", - summary.precision_at_1 - )); - md.push_str(&format!( - "| Precision@2 | {:.3} |\n", - summary.precision_at_2 - )); - md.push_str(&format!( - "| Precision@3 | {:.3} |\n", - summary.precision_at_3 - )); - md.push_str(&format!("| Duration | {} ms |\n", summary.duration_ms)); - md.push_str(&format!( - "| Latency Avg (ms) | {:.1} |\n", - summary.latency_ms.avg - )); - md.push_str(&format!( - "| Latency P50 (ms) | {} |\n", - summary.latency_ms.p50 - )); - md.push_str(&format!( - "| Latency P95 (ms) | {} |\n", - summary.latency_ms.p95 + "| Negatives Cached | {} |\n", + report.performance.negative_paragraphs_reused )); - md.push_str("\n## Retrieval Stage Timings\n\n"); - md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\n"); - md.push_str("| --- | --- | --- | --- |\n"); + md.push_str("\n## Retrieval Stage Timings\n\n"); + md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\n| --- | --- | --- | --- |\n"); + write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed); write_stage_row( &mut md, "Collect Candidates", - &summary.perf.stage_latency.collect_candidates, + &report.performance.stage_latency.collect_candidates, ); write_stage_row( &mut md, "Graph Expansion", - &summary.perf.stage_latency.graph_expansion, + &report.performance.stage_latency.graph_expansion, ); write_stage_row( &mut md, "Chunk Attach", - &summary.perf.stage_latency.chunk_attach, + &report.performance.stage_latency.chunk_attach, + ); + write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank); + write_stage_row( + &mut md, + "Assemble", + &report.performance.stage_latency.assemble, ); - write_stage_row(&mut md, "Rerank", &summary.perf.stage_latency.rerank); - write_stage_row(&mut md, "Assemble", 
&summary.perf.stage_latency.assemble); - let misses: Vec<&CaseSummary> = summary.cases.iter().filter(|case| !case.matched).collect(); - if !misses.is_empty() { - md.push_str("\n## Missed Queries (sample)\n\n"); - if summary.detailed_report { + if report.misses.is_empty() { + md.push_str("\n_All evaluated retrieval queries matched within the top-k window._\n"); + if report.detailed_report { md.push_str( - "| Question ID | Paragraph | Expected Source | Entity Match | Chunk Text | Chunk ID | Top Retrieved |\n", + "\nSuccess measures were captured for each query (entity, chunk text, chunk ID).\n", ); - md.push_str("| --- | --- | --- | --- | --- | --- | --- |\n"); - } else { - md.push_str("| Question ID | Paragraph | Expected Source | Top Retrieved |\n"); - md.push_str("| --- | --- | --- | --- |\n"); } - - for case in misses.iter().take(sample) { - let retrieved = case - .retrieved - .iter() - .map(|entry| format!("{} (rank {})", entry.source_id, entry.rank)) - .take(3) - .collect::>() - .join("
"); - if summary.detailed_report { + } else { + md.push_str("\n## Missed Retrieval Queries (sample)\n\n"); + if report.detailed_report { + md.push_str( + "| Question ID | Paragraph | Expected Source | Entity Match | Chunk Text | Chunk ID | Top Retrieved |\n", + ); + md.push_str("| --- | --- | --- | --- | --- | --- | --- |\n"); + } else { + md.push_str("| Question ID | Paragraph | Expected Source | Top Retrieved |\n"); + md.push_str("| --- | --- | --- | --- |\n"); + } + for case in &report.misses { + let retrieved = render_retrieved(&case.retrieved); + if report.detailed_report { md.push_str(&format!( - "| `{}` | {} | `{}` | {} | {} | {} | {} |\n", + "| `{}` | {} | `{}` | {} | {} | {} | {} |\n", case.question_id, case.paragraph_title, case.expected_source, @@ -295,23 +575,39 @@ fn render_markdown(summary: &EvaluationSummary, sample: usize) -> String { )); } else { md.push_str(&format!( - "| `{}` | {} | `{}` | {} |\n", + "| `{}` | {} | `{}` | {} |\n", case.question_id, case.paragraph_title, case.expected_source, retrieved )); } } - } else { - md.push_str("\n_All evaluated queries matched within the top-k window._\n"); - if summary.detailed_report { - md.push_str( - "\nSuccess measures were captured for each query (entity, chunk text, chunk ID).\n", - ); + } + + if report.llm.is_some() { + md.push_str("\n## LLM-Only Cases (sample)\n\n"); + if report.llm_cases.is_empty() { + md.push_str("All LLM-only cases matched within the evaluation window.\n"); + } else { + md.push_str("| Question ID | Answered | Match Rank | Top Retrieved |\n"); + md.push_str("| --- | --- | --- | --- |\n"); + for case in &report.llm_cases { + let retrieved = render_retrieved(&case.retrieved); + let rank = case + .match_rank + .map(|rank| rank.to_string()) + .unwrap_or_else(|| "-".into()); + md.push_str(&format!( + "| `{}` | {} | {} | {} |\n", + case.question_id, + bool_badge(case.answered), + rank, + retrieved + )); + } } } md } - fn write_stage_row(buf: &mut String, 
stats: &LatencyStats) { buf.push_str(&format!( "| {} | {:.1} | {} | {} |\n", @@ -327,6 +623,19 @@ fn bool_badge(value: bool) -> &'static str { } } +fn render_retrieved(entries: &[RetrievedSnippet]) -> String { + if entries.is_empty() { + "-".to_string() + } else { + entries + .iter() + .map(|entry| format!("{} (rank {})", entry.source_id, entry.rank)) + .take(3) + .collect::>() + .join("
") + } +} + fn build_report_stem(summary: &EvaluationSummary) -> String { let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S"); let backend = sanitize_component(&summary.embedding_backend); @@ -377,6 +686,14 @@ struct HistoryEntry { precision_at_1: f64, precision_at_2: f64, precision_at_3: f64, + #[serde(default)] + retrieval_cases: usize, + #[serde(default)] + retrieval_precision: f64, + #[serde(default)] + llm_cases: usize, + #[serde(default)] + llm_precision: f64, duration_ms: u128, latency_ms: LatencyStats, embedding_backend: String, @@ -405,7 +722,28 @@ fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> let mut entries: Vec = if path.exists() { let contents = fs::read(&path) .with_context(|| format!("reading evaluation log {}", path.display()))?; - serde_json::from_slice(&contents).unwrap_or_default() + match serde_json::from_slice(&contents) { + Ok(entries) => entries, + Err(err) => { + let timestamp = Utc::now().format("%Y%m%dT%H%M%S"); + let backup_path = + report_dir.join(format!("evaluations.json.corrupted.{}", timestamp)); + warn!( + path = %path.display(), + backup = %backup_path.display(), + error = %err, + "Evaluation history file is corrupted; backing up and starting fresh" + ); + if let Err(e) = fs::rename(&path, &backup_path) { + warn!( + path = %path.display(), + error = %e, + "Failed to backup corrupted evaluation history" + ); + } + Vec::new() + } + } } else { Vec::new() }; @@ -433,6 +771,10 @@ fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> precision_at_1: summary.precision_at_1, precision_at_2: summary.precision_at_2, precision_at_3: summary.precision_at_3, + retrieval_cases: summary.retrieval_cases, + retrieval_precision: summary.retrieval_precision, + llm_cases: summary.llm_cases, + llm_precision: summary.llm_precision, duration_ms: summary.duration_ms, latency_ms: summary.latency_ms.clone(), embedding_backend: summary.embedding_backend.clone(), @@ -454,3 +796,173 @@ fn 
record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()> fs::write(&path, blob).with_context(|| format!("writing evaluation log {}", path.display()))?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::eval::{ + EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatencyBreakdown, + }; + use chrono::Utc; + + fn latency(ms: f64) -> LatencyStats { + LatencyStats { + avg: ms, + p50: ms as u128, + p95: ms as u128, + } + } + + fn sample_stage_latency() -> StageLatencyBreakdown { + StageLatencyBreakdown { + embed: latency(9.0), + collect_candidates: latency(10.0), + graph_expansion: latency(11.0), + chunk_attach: latency(12.0), + rerank: latency(13.0), + assemble: latency(14.0), + } + } + + fn sample_eval_stage() -> EvaluationStageTimings { + EvaluationStageTimings { + prepare_slice_ms: 1, + prepare_db_ms: 2, + prepare_corpus_ms: 3, + prepare_namespace_ms: 4, + run_queries_ms: 5, + summarize_ms: 6, + finalize_ms: 7, + } + } + + fn sample_case(is_impossible: bool, matched: bool) -> CaseSummary { + CaseSummary { + question_id: if is_impossible { + "llm-q".into() + } else { + "retrieval-q".into() + }, + question: "Who is the hero?".into(), + paragraph_id: "p1".into(), + paragraph_title: "Hero".into(), + expected_source: "src1".into(), + answers: vec!["answer".into()], + matched, + entity_match: matched, + chunk_text_match: matched, + chunk_id_match: matched, + is_impossible, + has_verified_chunks: !is_impossible, + match_rank: if matched { Some(1) } else { None }, + latency_ms: 42, + retrieved: vec![RetrievedSummary { + rank: 1, + entity_id: "entity1".into(), + source_id: "src1".into(), + entity_name: "Entity".into(), + score: 1.0, + matched, + entity_description: None, + entity_category: None, + chunk_text_match: Some(matched), + chunk_id_match: Some(matched), + }], + } + } + + fn sample_summary(include_llm: bool) -> EvaluationSummary { + let mut cases = vec![sample_case(false, true)]; + if include_llm { + 
cases.push(sample_case(true, false)); + } + EvaluationSummary { + generated_at: Utc::now(), + k: 5, + limit: Some(10), + run_label: Some("test".into()), + total_cases: cases.len(), + correct: 1, + precision: 1.0, + correct_at_1: 1, + correct_at_2: 1, + correct_at_3: 1, + precision_at_1: 1.0, + precision_at_2: 1.0, + precision_at_3: 1.0, + duration_ms: 100, + dataset_id: "ds".into(), + dataset_label: "Dataset".into(), + dataset_includes_unanswerable: include_llm, + dataset_source: "dev".into(), + includes_impossible_cases: include_llm, + require_verified_chunks: !include_llm, + filtered_questions: 0, + retrieval_cases: 1, + retrieval_correct: 1, + retrieval_precision: 1.0, + llm_cases: if include_llm { 1 } else { 0 }, + llm_answered: 0, + llm_precision: 0.0, + slice_id: "slice".into(), + slice_seed: 1, + slice_total_cases: cases.len(), + slice_window_offset: 0, + slice_window_length: cases.len(), + slice_cases: cases.len(), + slice_positive_paragraphs: 1, + slice_negative_paragraphs: 0, + slice_total_paragraphs: 1, + slice_negative_multiplier: 1.0, + namespace_reused: true, + corpus_paragraphs: 1, + ingestion_cache_path: "/cache".into(), + ingestion_reused: true, + ingestion_embeddings_reused: true, + ingestion_fingerprint: "fp".into(), + positive_paragraphs_reused: 1, + negative_paragraphs_reused: 0, + latency_ms: latency(10.0), + perf: PerformanceTimings { + openai_base_url: "https://example.com".into(), + ingestion_ms: 100, + namespace_seed_ms: Some(50), + evaluation_stage_ms: sample_eval_stage(), + stage_latency: sample_stage_latency(), + }, + embedding_backend: "fastembed".into(), + embedding_model: Some("model".into()), + embedding_dimension: 32, + rerank_enabled: true, + rerank_pool_size: Some(4), + rerank_keep_top: 5, + concurrency: 2, + detailed_report: true, + retrieval_strategy: "initial".into(), + chunk_vector_take: 50, + chunk_fts_take: 50, + chunk_token_budget: 10_000, + chunk_avg_chars_per_token: 4, + max_chunks_per_entity: 4, + cases, + } + } + + 
#[test] + fn markdown_includes_llm_section() { + let summary = sample_summary(true); + let report = EvaluationReport::from_summary(&summary, 5); + let md = render_markdown(&report); + assert!(md.contains("LLM Mode Metrics")); + assert!(md.contains("LLM-Only Cases (sample)")); + } + + #[test] + fn markdown_hides_llm_section_when_not_present() { + let summary = sample_summary(false); + let report = EvaluationReport::from_summary(&summary, 5); + let md = render_markdown(&report); + assert!(!md.contains("LLM Mode Metrics")); + assert!(!md.contains("LLM-Only Cases")); + } +} diff --git a/eval/src/slice.rs b/eval/src/slice.rs index 3bf456e..dda9d80 100644 --- a/eval/src/slice.rs +++ b/eval/src/slice.rs @@ -23,5 +23,6 @@ pub fn slice_config_with_limit<'a>( slice_seed: config.slice_seed, llm_mode: config.llm_mode, negative_multiplier: config.negative_multiplier, + require_verified_chunks: config.retrieval.require_verified_chunks, } } diff --git a/eval/src/slices.rs b/eval/src/slices.rs index 8c4231c..79523eb 100644 --- a/eval/src/slices.rs +++ b/eval/src/slices.rs @@ -26,6 +26,7 @@ pub struct SliceConfig<'a> { pub slice_seed: u64, pub llm_mode: bool, pub negative_multiplier: f32, + pub require_verified_chunks: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -36,6 +37,8 @@ pub struct SliceManifest { pub dataset_label: String, pub dataset_source: String, pub includes_unanswerable: bool, + #[serde(default = "default_require_verified_chunks")] + pub require_verified_chunks: bool, pub seed: u64, pub requested_limit: Option, pub requested_corpus: usize, @@ -49,6 +52,10 @@ pub struct SliceManifest { pub paragraphs: Vec, } +fn default_require_verified_chunks() -> bool { + false +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SliceCaseEntry { pub question_id: String, @@ -184,6 +191,7 @@ impl DatasetIndex { struct SliceKey<'a> { dataset_id: &'a str, includes_unanswerable: bool, + require_verified_chunks: bool, requested_corpus: usize, seed: u64, } @@ 
-222,7 +230,8 @@ pub fn resolve_slice<'a>( .max(1); let key = SliceKey { dataset_id: dataset.metadata.id.as_str(), - includes_unanswerable: dataset.metadata.include_unanswerable, + includes_unanswerable: config.llm_mode, + require_verified_chunks: config.require_verified_chunks, requested_corpus, seed: config.slice_seed, }; @@ -248,13 +257,24 @@ pub fn resolve_slice<'a>( let mut manifest = if !config.force_convert && path.exists() { match read_manifest(&path) { Ok(manifest) if manifest.dataset_id == dataset.metadata.id => { - if manifest.includes_unanswerable != dataset.metadata.include_unanswerable { + if manifest.includes_unanswerable != config.llm_mode { warn!( slice = manifest.slice_id, path = %path.display(), + expected = config.llm_mode, + found = manifest.includes_unanswerable, "Slice manifest includes_unanswerable mismatch; regenerating" ); None + } else if manifest.require_verified_chunks != config.require_verified_chunks { + warn!( + slice = manifest.slice_id, + path = %path.display(), + expected = config.require_verified_chunks, + found = manifest.require_verified_chunks, + "Slice manifest verified-chunk requirement mismatch; regenerating" + ); + None } else { Some(manifest) } @@ -312,6 +332,7 @@ pub fn resolve_slice<'a>( ¶ms, requested_corpus, config.negative_multiplier, + config.require_verified_chunks, config.limit, ) }); @@ -319,6 +340,8 @@ pub fn resolve_slice<'a>( manifest.requested_limit = config.limit; manifest.requested_corpus = requested_corpus; manifest.negative_multiplier = config.negative_multiplier; + manifest.includes_unanswerable = config.llm_mode; + manifest.require_verified_chunks = config.require_verified_chunks; let mut changed = ensure_shard_paths(&mut manifest); @@ -439,6 +462,22 @@ fn load_explicit_slice<'a>( dataset.metadata.id )); } + if manifest.includes_unanswerable != config.llm_mode { + return Err(anyhow!( + "slice '{}' includes_unanswerable mismatch (expected {}, found {})", + manifest.slice_id, + config.llm_mode, + 
manifest.includes_unanswerable + )); + } + if manifest.require_verified_chunks != config.require_verified_chunks { + return Err(anyhow!( + "slice '{}' verified-chunk requirement mismatch (expected {}, found {})", + manifest.slice_id, + config.require_verified_chunks, + manifest.require_verified_chunks + )); + } // Validate the manifest before returning. manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?; @@ -452,6 +491,7 @@ fn empty_manifest( params: &BuildParams, requested_corpus: usize, negative_multiplier: f32, + require_verified_chunks: bool, requested_limit: Option, ) -> SliceManifest { SliceManifest { @@ -460,7 +500,8 @@ fn empty_manifest( dataset_id: dataset.metadata.id.clone(), dataset_label: dataset.metadata.label.clone(), dataset_source: dataset.source.clone(), - includes_unanswerable: dataset.metadata.include_unanswerable, + includes_unanswerable: params.include_impossible, + require_verified_chunks, seed: params.base_seed, requested_limit, requested_corpus, @@ -891,6 +932,7 @@ mod tests { slice_seed: 0x5eed_2025, llm_mode: false, negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER, + require_verified_chunks: true, }; let first = resolve_slice(&dataset, &config)?; @@ -922,6 +964,7 @@ mod tests { slice_seed: 0x5eed_2025, llm_mode: false, negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER, + require_verified_chunks: true, }; let resolved = resolve_slice(&dataset, &config)?; let window = select_window(&resolved, 1, Some(1))?; diff --git a/eval/src/snapshot.rs b/eval/src/snapshot.rs index 6068417..18440b9 100644 --- a/eval/src/snapshot.rs +++ b/eval/src/snapshot.rs @@ -54,9 +54,9 @@ impl Descriptor { embedding_backend: embedding_provider.backend_label().to_string(), embedding_model: embedding_provider.model_code(), embedding_dimension: embedding_provider.dimension(), - chunk_min_chars: config.chunk_min_chars, - chunk_max_chars: config.chunk_max_chars, - rerank_enabled: config.rerank, + chunk_min_chars: 
config.retrieval.chunk_min_chars, + chunk_max_chars: config.retrieval.chunk_max_chars, + rerank_enabled: config.retrieval.rerank, }; let dir = config