benchmarks: fin

This commit is contained in:
Per Stark
2025-12-08 21:57:53 +01:00
parent 0cb1abc6db
commit a8d10f265c
39 changed files with 774 additions and 714 deletions

515
evaluations/src/args.rs Normal file
View File

@@ -0,0 +1,515 @@
use std::{
env,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use clap::{Args, Parser, ValueEnum};
use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind;
fn workspace_root() -> PathBuf {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
manifest_dir.parent().unwrap_or(&manifest_dir).to_path_buf()
}
fn default_report_dir() -> PathBuf {
workspace_root().join("evaluations/reports")
}
fn default_cache_dir() -> PathBuf {
workspace_root().join("evaluations/cache")
}
fn default_ingestion_cache_dir() -> PathBuf {
workspace_root().join("evaluations/cache/ingested")
}
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
#[value(rename_all = "lowercase")]
pub enum EmbeddingBackend {
Hashed,
FastEmbed,
}
impl Default for EmbeddingBackend {
fn default() -> Self {
Self::FastEmbed
}
}
impl std::fmt::Display for EmbeddingBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Hashed => write!(f, "hashed"),
Self::FastEmbed => write!(f, "fastembed"),
}
}
}
#[derive(Debug, Clone, Args)]
pub struct RetrievalSettings {
/// Override chunk vector candidate cap
#[arg(long)]
pub chunk_vector_take: Option<usize>,
/// Override chunk FTS candidate cap
#[arg(long)]
pub chunk_fts_take: Option<usize>,
/// Override average characters per token used for budgeting
#[arg(long)]
pub chunk_avg_chars_per_token: Option<usize>,
/// Override maximum chunks attached per entity
#[arg(long)]
pub max_chunks_per_entity: Option<usize>,
/// Enable the FastEmbed reranking stage
#[arg(long = "rerank", action = clap::ArgAction::SetTrue, default_value_t = false)]
pub rerank: bool,
/// Reranking engine pool size / parallelism
#[arg(long, default_value_t = 4)]
pub rerank_pool_size: usize,
/// Keep top-N entities after reranking
#[arg(long, default_value_t = 10)]
pub rerank_keep_top: usize,
/// Cap the number of chunks returned by retrieval (revised strategy)
#[arg(long, default_value_t = 5)]
pub chunk_result_cap: usize,
/// Reciprocal rank fusion k value for revised chunk merging
#[arg(long)]
pub chunk_rrf_k: Option<f32>,
/// Weight for vector ranks in revised RRF
#[arg(long)]
pub chunk_rrf_vector_weight: Option<f32>,
/// Weight for chunk FTS ranks in revised RRF
#[arg(long)]
pub chunk_rrf_fts_weight: Option<f32>,
/// Include vector ranks in revised RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_vector: Option<bool>,
/// Include chunk FTS ranks in revised RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_fts: Option<bool>,
/// Require verified chunks (disable with --llm-mode)
#[arg(skip = true)]
pub require_verified_chunks: bool,
/// Select the retrieval pipeline strategy
#[arg(long, default_value_t = RetrievalStrategy::Initial)]
pub strategy: RetrievalStrategy,
}
impl Default for RetrievalSettings {
fn default() -> Self {
Self {
chunk_vector_take: None,
chunk_fts_take: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None,
rerank: false,
rerank_pool_size: 4,
rerank_keep_top: 10,
chunk_result_cap: 5,
chunk_rrf_k: None,
chunk_rrf_vector_weight: None,
chunk_rrf_fts_weight: None,
chunk_rrf_use_vector: None,
chunk_rrf_use_fts: None,
require_verified_chunks: true,
strategy: RetrievalStrategy::Initial,
}
}
}
#[derive(Debug, Clone, Args)]
pub struct IngestConfig {
/// Directory for ingestion corpora caches
#[arg(long, default_value_os_t = default_ingestion_cache_dir())]
pub ingestion_cache_dir: PathBuf,
/// Minimum tokens per chunk for ingestion
#[arg(long, default_value_t = 256)]
pub ingest_chunk_min_tokens: usize,
/// Maximum tokens per chunk for ingestion
#[arg(long, default_value_t = 512)]
pub ingest_chunk_max_tokens: usize,
/// Overlap between chunks during ingestion (tokens)
#[arg(long, default_value_t = 50)]
pub ingest_chunk_overlap_tokens: usize,
/// Run ingestion in chunk-only mode (skip analyzer/graph generation)
#[arg(long)]
pub ingest_chunks_only: bool,
/// Number of paragraphs to ingest concurrently
#[arg(long, default_value_t = 10)]
pub ingestion_batch_size: usize,
/// Maximum retries for ingestion failures per paragraph
#[arg(long, default_value_t = 3)]
pub ingestion_max_retries: usize,
/// Recompute embeddings for cached corpora without re-running ingestion
#[arg(long, alias = "refresh-embeddings")]
pub refresh_embeddings_only: bool,
/// Delete cached paragraph shards before rebuilding the ingestion corpus
#[arg(long)]
pub slice_reset_ingestion: bool,
}
#[derive(Debug, Clone, Args)]
pub struct DatabaseArgs {
/// SurrealDB server endpoint
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
pub db_endpoint: String,
/// SurrealDB root username
#[arg(long, default_value = "root_user", env = "EVAL_DB_USERNAME")]
pub db_username: String,
/// SurrealDB root password
#[arg(long, default_value = "root_password", env = "EVAL_DB_PASSWORD")]
pub db_password: String,
/// Override the namespace used on the SurrealDB server
#[arg(long, env = "EVAL_DB_NAMESPACE")]
pub db_namespace: Option<String>,
/// Override the database used on the SurrealDB server
#[arg(long, env = "EVAL_DB_DATABASE")]
pub db_database: Option<String>,
/// Path to inspect DB state
#[arg(long)]
pub inspect_db_state: Option<PathBuf>,
}
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Config {
/// Convert the selected dataset and exit
#[arg(long)]
pub convert_only: bool,
/// Regenerate the converted dataset even if it already exists
#[arg(long, alias = "refresh")]
pub force_convert: bool,
/// Dataset to evaluate
#[arg(long, default_value_t = DatasetKind::default())]
pub dataset: DatasetKind,
/// Enable LLM-assisted evaluation features (includes unanswerable cases)
#[arg(long)]
pub llm_mode: bool,
/// Cap the slice corpus size (positives + negatives)
#[arg(long)]
pub corpus_limit: Option<usize>,
/// Path to the raw dataset (defaults per dataset)
#[arg(long)]
pub raw: Option<PathBuf>,
/// Path to write/read the converted dataset (defaults per dataset)
#[arg(long)]
pub converted: Option<PathBuf>,
/// Directory to write evaluation reports
#[arg(long, default_value_os_t = default_report_dir())]
pub report_dir: PathBuf,
/// Precision@k cutoff
#[arg(long, default_value_t = 5)]
pub k: usize,
/// Limit the number of questions evaluated (0 = all)
#[arg(long = "limit", default_value_t = 200)]
pub limit_arg: usize,
/// Number of mismatches to surface in the Markdown summary
#[arg(long, default_value_t = 5)]
pub sample: usize,
/// Disable context cropping when converting datasets (ingest entire documents)
#[arg(long)]
pub full_context: bool,
#[command(flatten)]
pub retrieval: RetrievalSettings,
/// Concurrency level
#[arg(long, default_value_t = 1)]
pub concurrency: usize,
/// Embedding backend
#[arg(long, default_value_t = EmbeddingBackend::FastEmbed)]
pub embedding_backend: EmbeddingBackend,
/// FastEmbed model code
#[arg(long)]
pub embedding_model: Option<String>,
/// Directory for embedding caches
#[arg(long, default_value_os_t = default_cache_dir())]
pub cache_dir: PathBuf,
#[command(flatten)]
pub ingest: IngestConfig,
/// Include entity descriptions and categories in JSON reports
#[arg(long)]
pub detailed_report: bool,
/// Use a cached dataset slice by id or path
#[arg(long)]
pub slice: Option<String>,
/// Ignore cached corpus state and rebuild the slice's SurrealDB corpus
#[arg(long)]
pub reseed_slice: bool,
/// Slice seed
#[arg(skip = DEFAULT_SLICE_SEED)]
pub slice_seed: u64,
/// Grow the slice ledger to contain at least this many answerable cases, then exit
#[arg(long)]
pub slice_grow: Option<usize>,
/// Evaluate questions starting at this offset within the slice
#[arg(long, default_value_t = 0)]
pub slice_offset: usize,
/// Target negative-to-positive paragraph ratio for slice growth
#[arg(long, default_value_t = crate::slice::DEFAULT_NEGATIVE_MULTIPLIER)]
pub negative_multiplier: f32,
/// Annotate the run; label is stored in JSON/Markdown reports
#[arg(long)]
pub label: Option<String>,
/// Write per-query chunk diagnostics JSONL to the provided path
#[arg(long, alias = "chunk-diagnostics")]
pub chunk_diagnostics_path: Option<PathBuf>,
/// Inspect an ingestion cache question and exit
#[arg(long)]
pub inspect_question: Option<String>,
/// Path to an ingestion cache manifest JSON for inspection mode
#[arg(long)]
pub inspect_manifest: Option<PathBuf>,
/// Override the SurrealDB system settings query model
#[arg(long)]
pub query_model: Option<String>,
/// Write structured performance telemetry JSON to the provided path
#[arg(long)]
pub perf_log_json: Option<PathBuf>,
/// Directory that receives timestamped perf JSON copies
#[arg(long)]
pub perf_log_dir: Option<PathBuf>,
/// Print per-stage performance timings to stdout after the run
#[arg(long, alias = "perf-log")]
pub perf_log_console: bool,
#[command(flatten)]
pub database: DatabaseArgs,
// Computed fields (not arguments)
#[arg(skip)]
pub raw_dataset_path: PathBuf,
#[arg(skip)]
pub converted_dataset_path: PathBuf,
#[arg(skip)]
pub limit: Option<usize>,
#[arg(skip)]
pub summary_sample: usize,
}
impl Config {
pub fn context_token_limit(&self) -> Option<usize> {
None
}
pub fn finalize(&mut self) -> Result<()> {
// Handle dataset paths
if let Some(raw) = &self.raw {
self.raw_dataset_path = raw.clone();
} else {
self.raw_dataset_path = self.dataset.default_raw_path();
}
if let Some(converted) = &self.converted {
self.converted_dataset_path = converted.clone();
} else {
self.converted_dataset_path = self.dataset.default_converted_path();
}
// Handle limit
if self.limit_arg == 0 {
self.limit = None;
} else {
self.limit = Some(self.limit_arg);
}
// Handle sample
self.summary_sample = self.sample.max(1);
// Handle retrieval settings
if self.llm_mode {
self.retrieval.require_verified_chunks = false;
} else {
self.retrieval.require_verified_chunks = true;
}
if self.dataset == DatasetKind::Beir {
self.negative_multiplier = 9.0;
}
// Validations
if self.ingest.ingest_chunk_min_tokens == 0
|| self.ingest.ingest_chunk_min_tokens >= self.ingest.ingest_chunk_max_tokens
{
return Err(anyhow!(
"--ingest-chunk-min-tokens must be greater than zero and less than --ingest-chunk-max-tokens (got {} >= {})",
self.ingest.ingest_chunk_min_tokens,
self.ingest.ingest_chunk_max_tokens
));
}
if self.ingest.ingest_chunk_overlap_tokens >= self.ingest.ingest_chunk_min_tokens {
return Err(anyhow!(
"--ingest-chunk-overlap-tokens ({}) must be less than --ingest-chunk-min-tokens ({})",
self.ingest.ingest_chunk_overlap_tokens,
self.ingest.ingest_chunk_min_tokens
));
}
if self.retrieval.rerank && self.retrieval.rerank_pool_size == 0 {
return Err(anyhow!(
"--rerank-pool must be greater than zero when reranking is enabled"
));
}
if let Some(k) = self.retrieval.chunk_rrf_k {
if k <= 0.0 || !k.is_finite() {
return Err(anyhow!(
"--chunk-rrf-k must be a positive, finite number (got {k})"
));
}
}
if let Some(weight) = self.retrieval.chunk_rrf_vector_weight {
if weight < 0.0 || !weight.is_finite() {
return Err(anyhow!(
"--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})"
));
}
}
if let Some(weight) = self.retrieval.chunk_rrf_fts_weight {
if weight < 0.0 || !weight.is_finite() {
return Err(anyhow!(
"--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})"
));
}
}
if self.concurrency == 0 {
return Err(anyhow!("--concurrency must be greater than zero"));
}
if self.embedding_backend == EmbeddingBackend::Hashed && self.embedding_model.is_some() {
return Err(anyhow!(
"--embedding-model cannot be used with the 'hashed' embedding backend"
));
}
if let Some(query_model) = &self.query_model {
if query_model.trim().is_empty() {
return Err(anyhow!("--query-model requires a non-empty model name"));
}
}
if let Some(grow) = self.slice_grow {
if grow == 0 {
return Err(anyhow!("--slice-grow must be greater than zero"));
}
}
if self.negative_multiplier <= 0.0 || !self.negative_multiplier.is_finite() {
return Err(anyhow!(
"--negative-multiplier must be a positive finite number"
));
}
// Handle corpus limit logic
if let Some(limit) = self.limit {
if let Some(corpus_limit) = self.corpus_limit {
if corpus_limit < limit {
self.corpus_limit = Some(limit);
}
} else {
let default_multiplier = 10usize;
let mut computed = limit.saturating_mul(default_multiplier);
if computed < limit {
computed = limit;
}
let max_cap = 1_000usize;
if computed > max_cap {
computed = max_cap;
}
self.corpus_limit = Some(computed);
}
}
// Handle perf log dir env var fallback
if self.perf_log_dir.is_none() {
if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") {
if !dir.trim().is_empty() {
self.perf_log_dir = Some(PathBuf::from(dir));
}
}
}
Ok(())
}
}
pub struct ParsedArgs {
pub config: Config,
}
pub fn parse() -> Result<ParsedArgs> {
let mut config = Config::parse();
config.finalize()?;
Ok(ParsedArgs { config })
}
pub fn ensure_parent(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("creating parent directory for {}", path.display()))?;
}
Ok(())
}

88
evaluations/src/cache.rs Normal file
View File

@@ -0,0 +1,88 @@
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
#[derive(Debug, Default, Serialize, Deserialize)]
struct EmbeddingCacheData {
entities: HashMap<String, Vec<f32>>,
chunks: HashMap<String, Vec<f32>>,
}
#[derive(Clone)]
pub struct EmbeddingCache {
path: Arc<PathBuf>,
data: Arc<Mutex<EmbeddingCacheData>>,
dirty: Arc<AtomicBool>,
}
#[allow(dead_code)]
impl EmbeddingCache {
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let data = if path.exists() {
let raw = tokio::fs::read(&path)
.await
.with_context(|| format!("reading embedding cache {}", path.display()))?;
serde_json::from_slice(&raw)
.with_context(|| format!("parsing embedding cache {}", path.display()))?
} else {
EmbeddingCacheData::default()
};
Ok(Self {
path: Arc::new(path),
data: Arc::new(Mutex::new(data)),
dirty: Arc::new(AtomicBool::new(false)),
})
}
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
let guard = self.data.lock().await;
guard.entities.get(id).cloned()
}
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
let mut guard = self.data.lock().await;
guard.entities.insert(id, embedding);
self.dirty.store(true, Ordering::Relaxed);
}
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
let guard = self.data.lock().await;
guard.chunks.get(id).cloned()
}
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
let mut guard = self.data.lock().await;
guard.chunks.insert(id, embedding);
self.dirty.store(true, Ordering::Relaxed);
}
pub async fn persist(&self) -> Result<()> {
if !self.dirty.load(Ordering::Relaxed) {
return Ok(());
}
let guard = self.data.lock().await;
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
if let Some(parent) = self.path.parent() {
tokio::fs::create_dir_all(parent)
.await
.with_context(|| format!("creating cache directory {}", parent.display()))?;
}
tokio::fs::write(&*self.path, body)
.await
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
self.dirty.store(false, Ordering::Relaxed);
Ok(())
}
}

187
evaluations/src/cases.rs Normal file
View File

@@ -0,0 +1,187 @@
//! Case generation from corpus manifests.
use std::collections::HashMap;
use crate::corpus;
/// A test case for retrieval evaluation derived from a manifest question.
pub(crate) struct SeededCase {
pub question_id: String,
pub question: String,
pub expected_source: String,
pub answers: Vec<String>,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_chunk_ids: Vec<String>,
pub is_impossible: bool,
pub has_verified_chunks: bool,
}
/// Convert a corpus manifest into seeded evaluation cases.
pub(crate) fn cases_from_manifest(manifest: &corpus::CorpusManifest) -> Vec<SeededCase> {
let mut title_map = HashMap::new();
for paragraph in &manifest.paragraphs {
title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone());
}
let include_impossible = manifest.metadata.include_unanswerable;
let require_verified_chunks = manifest.metadata.require_verified_chunks;
manifest
.questions
.iter()
.filter(|question| {
should_include_question(question, include_impossible, require_verified_chunks)
})
.map(|question| {
let title = title_map
.get(question.paragraph_id.as_str())
.cloned()
.unwrap_or_else(|| "Untitled".to_string());
SeededCase {
question_id: question.question_id.clone(),
question: question.question_text.clone(),
expected_source: question.text_content_id.clone(),
answers: question.answers.clone(),
paragraph_id: question.paragraph_id.clone(),
paragraph_title: title,
expected_chunk_ids: question.matching_chunk_ids.clone(),
is_impossible: question.is_impossible,
has_verified_chunks: !question.matching_chunk_ids.is_empty(),
}
})
.collect()
}
fn should_include_question(
question: &corpus::CorpusQuestion,
include_impossible: bool,
require_verified_chunks: bool,
) -> bool {
if !include_impossible && question.is_impossible {
return false;
}
if require_verified_chunks && question.matching_chunk_ids.is_empty() {
return false;
}
true
}
#[cfg(test)]
mod tests {
use super::*;
use crate::corpus::store::{CorpusParagraph, EmbeddedKnowledgeEntity, EmbeddedTextChunk};
use crate::corpus::{CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION};
use chrono::Utc;
use common::storage::types::text_content::TextContent;
fn sample_manifest() -> CorpusManifest {
let paragraphs = vec![
CorpusParagraph {
paragraph_id: "p1".to_string(),
title: "Alpha".to_string(),
text_content: TextContent::new(
"alpha context".to_string(),
None,
"test".to_string(),
None,
None,
"user".to_string(),
),
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
relationships: Vec::new(),
chunks: Vec::<EmbeddedTextChunk>::new(),
},
CorpusParagraph {
paragraph_id: "p2".to_string(),
title: "Beta".to_string(),
text_content: TextContent::new(
"beta context".to_string(),
None,
"test".to_string(),
None,
None,
"user".to_string(),
),
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
relationships: Vec::new(),
chunks: Vec::<EmbeddedTextChunk>::new(),
},
];
let questions = vec![
CorpusQuestion {
question_id: "q1".to_string(),
paragraph_id: "p1".to_string(),
text_content_id: "tc-alpha".to_string(),
question_text: "What is Alpha?".to_string(),
answers: vec!["Alpha".to_string()],
is_impossible: false,
matching_chunk_ids: vec!["chunk-alpha".to_string()],
},
CorpusQuestion {
question_id: "q2".to_string(),
paragraph_id: "p1".to_string(),
text_content_id: "tc-alpha".to_string(),
question_text: "Unanswerable?".to_string(),
answers: Vec::new(),
is_impossible: true,
matching_chunk_ids: Vec::new(),
},
CorpusQuestion {
question_id: "q3".to_string(),
paragraph_id: "p2".to_string(),
text_content_id: "tc-beta".to_string(),
question_text: "Where is Beta?".to_string(),
answers: vec!["Beta".to_string()],
is_impossible: false,
matching_chunk_ids: Vec::new(),
},
];
CorpusManifest {
version: MANIFEST_VERSION,
metadata: CorpusMetadata {
dataset_id: "ds".to_string(),
dataset_label: "Dataset".to_string(),
slice_id: "slice".to_string(),
include_unanswerable: true,
require_verified_chunks: true,
ingestion_fingerprint: "fp".to_string(),
embedding_backend: "test".to_string(),
embedding_model: None,
embedding_dimension: 3,
converted_checksum: "chk".to_string(),
generated_at: Utc::now(),
paragraph_count: paragraphs.len(),
question_count: questions.len(),
chunk_min_tokens: 1,
chunk_max_tokens: 10,
chunk_only: false,
},
paragraphs,
questions,
}
}
#[test]
fn cases_respect_mode_filters() {
let mut manifest = sample_manifest();
manifest.metadata.include_unanswerable = false;
manifest.metadata.require_verified_chunks = true;
let strict_cases = cases_from_manifest(&manifest);
assert_eq!(strict_cases.len(), 1);
assert_eq!(strict_cases[0].question_id, "q1");
assert_eq!(strict_cases[0].paragraph_title, "Alpha");
let mut llm_manifest = manifest.clone();
llm_manifest.metadata.include_unanswerable = true;
llm_manifest.metadata.require_verified_chunks = false;
let llm_cases = cases_from_manifest(&llm_manifest);
let ids: Vec<_> = llm_cases
.iter()
.map(|case| case.question_id.as_str())
.collect();
assert_eq!(ids, vec!["q1", "q2", "q3"]);
}
}

View File

@@ -0,0 +1,42 @@
use std::path::PathBuf;
use crate::args::Config;
#[derive(Debug, Clone)]
pub struct CorpusCacheConfig {
pub ingestion_cache_dir: PathBuf,
pub force_refresh: bool,
pub refresh_embeddings_only: bool,
pub ingestion_batch_size: usize,
pub ingestion_max_retries: usize,
}
impl CorpusCacheConfig {
pub fn new(
ingestion_cache_dir: impl Into<PathBuf>,
force_refresh: bool,
refresh_embeddings_only: bool,
ingestion_batch_size: usize,
ingestion_max_retries: usize,
) -> Self {
Self {
ingestion_cache_dir: ingestion_cache_dir.into(),
force_refresh,
refresh_embeddings_only,
ingestion_batch_size,
ingestion_max_retries,
}
}
}
impl From<&Config> for CorpusCacheConfig {
fn from(config: &Config) -> Self {
CorpusCacheConfig::new(
config.ingest.ingestion_cache_dir.clone(),
config.force_convert || config.ingest.slice_reset_ingestion,
config.ingest.refresh_embeddings_only,
config.ingest.ingestion_batch_size,
config.ingest.ingestion_max_retries,
)
}
}

View File

@@ -0,0 +1,26 @@
mod config;
mod orchestrator;
pub(crate) mod store;
pub use config::CorpusCacheConfig;
pub use orchestrator::{
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
load_cached_manifest,
};
pub use store::{
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard,
ParagraphShardStore, MANIFEST_VERSION,
};
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
let mut tuning = ingestion_pipeline::IngestionTuning::default();
tuning.chunk_min_tokens = config.ingest.ingest_chunk_min_tokens;
tuning.chunk_max_tokens = config.ingest.ingest_chunk_max_tokens;
tuning.chunk_overlap_tokens = config.ingest.ingest_chunk_overlap_tokens;
ingestion_pipeline::IngestionConfig {
tuning,
chunk_only: config.ingest.ingest_chunks_only,
}
}

View File

@@ -0,0 +1,783 @@
use std::{
collections::{HashMap, HashSet},
fs,
io::Read,
path::{Path, PathBuf},
sync::Arc,
};
use anyhow::{anyhow, Context, Result};
use async_openai::Client;
use chrono::Utc;
use common::{
storage::{
db::SurrealDbClient,
store::{DynStore, StorageManager},
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
},
utils::config::{AppConfig, StorageKind},
};
use futures::future::try_join_all;
use ingestion_pipeline::{IngestionConfig, IngestionPipeline};
use object_store::memory::InMemory;
use sha2::{Digest, Sha256};
use tracing::{info, warn};
use uuid::Uuid;
use crate::{
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
slice::{self, ResolvedSlice, SliceParagraphKind},
};
use crate::corpus::{
CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore,
MANIFEST_VERSION,
};
const INGESTION_SPEC_VERSION: u32 = 2;
type OpenAIClient = Client<async_openai::config::OpenAIConfig>;
#[derive(Clone)]
struct ParagraphShardRecord {
shard: ParagraphShard,
dirty: bool,
needs_reembed: bool,
}
#[derive(Clone)]
struct IngestRequest<'a> {
slot: usize,
paragraph: &'a ConvertedParagraph,
shard_path: String,
question_refs: Vec<&'a ConvertedQuestion>,
}
impl<'a> IngestRequest<'a> {
fn from_entry(
slot: usize,
paragraph: &'a ConvertedParagraph,
entry: &'a slice::SliceParagraphEntry,
) -> Result<Self> {
let shard_path = entry
.shard_path
.clone()
.unwrap_or_else(|| slice::default_shard_path(&entry.id));
let question_refs = match &entry.kind {
SliceParagraphKind::Positive { question_ids } => question_ids
.iter()
.map(|id| {
paragraph
.questions
.iter()
.find(|question| question.id == *id)
.ok_or_else(|| {
anyhow!(
"paragraph '{}' missing question '{}' referenced by slice",
paragraph.id,
id
)
})
})
.collect::<Result<Vec<_>>>()?,
SliceParagraphKind::Negative => Vec::new(),
};
Ok(Self {
slot,
paragraph,
shard_path,
question_refs,
})
}
}
struct ParagraphPlan<'a> {
slot: usize,
entry: &'a slice::SliceParagraphEntry,
paragraph: &'a ConvertedParagraph,
}
#[derive(Default)]
struct IngestionStats {
positive_reused: usize,
positive_ingested: usize,
negative_reused: usize,
negative_ingested: usize,
}
pub async fn ensure_corpus(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
window: &slice::SliceWindow<'_>,
cache: &CorpusCacheConfig,
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
openai: Arc<OpenAIClient>,
user_id: &str,
converted_path: &Path,
ingestion_config: IngestionConfig,
) -> Result<CorpusHandle> {
let checksum = compute_file_checksum(converted_path)
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
let ingestion_fingerprint =
build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config);
let base_dir = cached_corpus_dir(
cache,
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
);
if cache.force_refresh && !cache.refresh_embeddings_only {
let _ = fs::remove_dir_all(&base_dir);
}
let store = ParagraphShardStore::new(base_dir.clone());
store.ensure_base_dir()?;
let positive_set: HashSet<&str> = window.positive_ids().collect();
let require_verified_chunks = slice.manifest.require_verified_chunks;
let embedding_backend_label = embedding.backend_label().to_string();
let embedding_model_code = embedding.model_code();
let embedding_dimension = embedding.dimension();
if positive_set.is_empty() {
return Err(anyhow!(
"window selection contains zero positive paragraphs for slice '{}'",
slice.manifest.slice_id
));
}
let desired_negatives =
((positive_set.len() as f32) * slice.manifest.negative_multiplier).ceil() as usize;
let mut plan = Vec::new();
let mut negatives_added = 0usize;
for (idx, entry) in slice.manifest.paragraphs.iter().enumerate() {
let include = match &entry.kind {
SliceParagraphKind::Positive { .. } => positive_set.contains(entry.id.as_str()),
SliceParagraphKind::Negative => {
negatives_added < desired_negatives && {
negatives_added += 1;
true
}
}
};
if include {
let paragraph = slice
.paragraphs
.get(idx)
.copied()
.ok_or_else(|| anyhow!("slice missing paragraph index {}", idx))?;
plan.push(ParagraphPlan {
slot: plan.len(),
entry,
paragraph,
});
}
}
if plan.is_empty() {
return Err(anyhow!(
"no paragraphs selected for ingestion (slice '{}')",
slice.manifest.slice_id
));
}
let mut records: Vec<Option<ParagraphShardRecord>> = vec![None; plan.len()];
let mut ingest_requests = Vec::new();
let mut stats = IngestionStats::default();
for plan_entry in &plan {
let shard_path = plan_entry
.entry
.shard_path
.clone()
.unwrap_or_else(|| slice::default_shard_path(&plan_entry.entry.id));
let shard = if cache.force_refresh {
None
} else {
store.load(&shard_path, &ingestion_fingerprint)?
};
if let Some(shard) = shard {
let model_matches = shard.embedding_model.as_deref() == embedding_model_code.as_deref();
let needs_reembed = shard.embedding_backend != embedding_backend_label
|| shard.embedding_dimension != embedding_dimension
|| !model_matches;
match plan_entry.entry.kind {
SliceParagraphKind::Positive { .. } => stats.positive_reused += 1,
SliceParagraphKind::Negative => stats.negative_reused += 1,
}
records[plan_entry.slot] = Some(ParagraphShardRecord {
shard,
dirty: false,
needs_reembed,
});
} else {
match plan_entry.entry.kind {
SliceParagraphKind::Positive { .. } => stats.positive_ingested += 1,
SliceParagraphKind::Negative => stats.negative_ingested += 1,
}
let request =
IngestRequest::from_entry(plan_entry.slot, plan_entry.paragraph, plan_entry.entry)?;
ingest_requests.push(request);
}
}
if cache.refresh_embeddings_only && !ingest_requests.is_empty() {
return Err(anyhow!(
"--refresh-embeddings requested but {} shard(s) missing for dataset '{}' slice '{}'",
ingest_requests.len(),
dataset.metadata.id,
slice.manifest.slice_id
));
}
if !ingest_requests.is_empty() {
let new_shards = ingest_paragraph_batch(
dataset,
&ingest_requests,
embedding.clone(),
openai.clone(),
user_id,
&ingestion_fingerprint,
&embedding_backend_label,
embedding_model_code.clone(),
embedding_dimension,
cache.ingestion_batch_size,
cache.ingestion_max_retries,
ingestion_config.clone(),
)
.await
.context("ingesting missing slice paragraphs")?;
for (request, shard) in ingest_requests.into_iter().zip(new_shards.into_iter()) {
store.persist(&shard)?;
records[request.slot] = Some(ParagraphShardRecord {
shard,
dirty: false,
needs_reembed: false,
});
}
}
for record in &mut records {
let shard_record = record
.as_mut()
.context("shard record missing after ingestion run")?;
if cache.refresh_embeddings_only || shard_record.needs_reembed {
// Embeddings are now generated by the pipeline using FastEmbed - no need to re-embed
shard_record.shard.ingestion_fingerprint = ingestion_fingerprint.clone();
shard_record.shard.ingested_at = Utc::now();
shard_record.shard.embedding_backend = embedding_backend_label.clone();
shard_record.shard.embedding_model = embedding_model_code.clone();
shard_record.shard.embedding_dimension = embedding_dimension;
shard_record.dirty = true;
shard_record.needs_reembed = false;
}
}
let mut record_index = HashMap::new();
for (idx, plan_entry) in plan.iter().enumerate() {
record_index.insert(plan_entry.entry.id.as_str(), idx);
}
let mut corpus_paragraphs = Vec::with_capacity(plan.len());
for record in &records {
let shard = &record.as_ref().expect("record missing").shard;
corpus_paragraphs.push(shard.to_corpus_paragraph());
}
let mut corpus_questions = Vec::with_capacity(window.cases.len());
for case in &window.cases {
let slot = record_index
.get(case.paragraph.id.as_str())
.copied()
.ok_or_else(|| {
anyhow!(
"slice case references paragraph '{}' that is not part of the window",
case.paragraph.id
)
})?;
let record_slot = records
.get_mut(slot)
.context("shard record slot missing for question binding")?;
let record = record_slot
.as_mut()
.context("shard record missing for question binding")?;
let (chunk_ids, updated) = match record.shard.ensure_question_binding(case.question) {
Ok(result) => result,
Err(err) => {
if require_verified_chunks {
return Err(err).context(format!(
"locating answer text for question '{}' in paragraph '{}'",
case.question.id, case.paragraph.id
));
}
warn!(
question_id = %case.question.id,
paragraph_id = %case.paragraph.id,
error = %err,
"Failed to locate answer text in ingested content; recording empty chunk bindings"
);
record
.shard
.question_bindings
.insert(case.question.id.clone(), Vec::new());
record.dirty = true;
(Vec::new(), true)
}
};
if updated {
record.dirty = true;
}
corpus_questions.push(CorpusQuestion {
question_id: case.question.id.clone(),
paragraph_id: case.paragraph.id.clone(),
text_content_id: record.shard.text_content.get_id().to_string(),
question_text: case.question.question.clone(),
answers: case.question.answers.clone(),
is_impossible: case.question.is_impossible,
matching_chunk_ids: chunk_ids,
});
}
for record in &mut records {
if let Some(ref mut entry) = record {
if entry.dirty {
store.persist(&entry.shard)?;
}
}
}
let manifest = CorpusManifest {
version: MANIFEST_VERSION,
metadata: CorpusMetadata {
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),
slice_id: slice.manifest.slice_id.clone(),
include_unanswerable: slice.manifest.includes_unanswerable,
require_verified_chunks: slice.manifest.require_verified_chunks,
ingestion_fingerprint: ingestion_fingerprint.clone(),
embedding_backend: embedding.backend_label().to_string(),
embedding_model: embedding.model_code(),
embedding_dimension: embedding.dimension(),
converted_checksum: checksum,
generated_at: Utc::now(),
paragraph_count: corpus_paragraphs.len(),
question_count: corpus_questions.len(),
chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens,
chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens,
chunk_only: ingestion_config.chunk_only,
},
paragraphs: corpus_paragraphs,
questions: corpus_questions,
};
let ingested_count = stats.positive_ingested + stats.negative_ingested;
let reused_ingestion = ingested_count == 0 && !cache.force_refresh;
let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only;
info!(
dataset = %dataset.metadata.id,
slice = %slice.manifest.slice_id,
fingerprint = %ingestion_fingerprint,
reused_ingestion,
reused_embeddings,
positive_reused = stats.positive_reused,
positive_ingested = stats.positive_ingested,
negative_reused = stats.negative_reused,
negative_ingested = stats.negative_ingested,
shard_dir = %base_dir.display(),
"Corpus cache outcome"
);
let handle = CorpusHandle {
manifest,
path: base_dir,
reused_ingestion,
reused_embeddings,
positive_reused: stats.positive_reused,
positive_ingested: stats.positive_ingested,
negative_reused: stats.negative_reused,
negative_ingested: stats.negative_ingested,
};
persist_manifest(&handle).context("persisting corpus manifest")?;
Ok(handle)
}
async fn ingest_paragraph_batch(
dataset: &ConvertedDataset,
targets: &[IngestRequest<'_>],
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
openai: Arc<OpenAIClient>,
user_id: &str,
ingestion_fingerprint: &str,
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
batch_size: usize,
max_retries: usize,
ingestion_config: IngestionConfig,
) -> Result<Vec<ParagraphShard>> {
if targets.is_empty() {
return Ok(Vec::new());
}
let namespace = format!("ingest_eval_{}", Uuid::new_v4());
let db = Arc::new(
SurrealDbClient::memory(&namespace, "corpus")
.await
.context("creating in-memory surrealdb for ingestion")?,
);
db.apply_migrations()
.await
.context("applying migrations for ingestion")?;
let mut app_config = AppConfig::default();
app_config.storage = StorageKind::Memory;
let backend: DynStore = Arc::new(InMemory::new());
let storage = StorageManager::with_backend(backend, StorageKind::Memory);
let pipeline_config = ingestion_config.clone();
let pipeline = IngestionPipeline::new_with_config(
db,
openai.clone(),
app_config,
None::<Arc<retrieval_pipeline::reranking::RerankerPool>>,
storage,
embedding.clone(),
pipeline_config,
)
.await?;
let pipeline = Arc::new(pipeline);
let mut shards = Vec::with_capacity(targets.len());
let category = dataset.metadata.category.clone();
for (batch_index, batch) in targets.chunks(batch_size).enumerate() {
info!(
batch = batch_index,
batch_size = batch.len(),
total_batches = (targets.len() + batch_size - 1) / batch_size,
"Ingesting paragraph batch"
);
let model_clone = embedding_model.clone();
let backend_clone = embedding_backend.to_string();
let pipeline_clone = pipeline.clone();
let category_clone = category.clone();
let tasks = batch.iter().cloned().map(move |request| {
ingest_single_paragraph(
pipeline_clone.clone(),
request,
category_clone.clone(),
user_id,
ingestion_fingerprint,
backend_clone.clone(),
model_clone.clone(),
embedding_dimension,
max_retries,
ingestion_config.tuning.chunk_min_tokens,
ingestion_config.tuning.chunk_max_tokens,
ingestion_config.chunk_only,
)
});
let batch_results: Vec<ParagraphShard> = try_join_all(tasks)
.await
.context("ingesting batch of paragraphs")?;
shards.extend(batch_results);
}
Ok(shards)
}
async fn ingest_single_paragraph(
pipeline: Arc<IngestionPipeline>,
request: IngestRequest<'_>,
category: String,
user_id: &str,
ingestion_fingerprint: &str,
embedding_backend: String,
embedding_model: Option<String>,
embedding_dimension: usize,
max_retries: usize,
chunk_min_tokens: usize,
chunk_max_tokens: usize,
chunk_only: bool,
) -> Result<ParagraphShard> {
let paragraph = request.paragraph;
let mut last_err: Option<anyhow::Error> = None;
for attempt in 1..=max_retries {
let payload = IngestionPayload::Text {
text: paragraph.context.clone(),
context: paragraph.title.clone(),
category: category.clone(),
user_id: user_id.to_string(),
};
let task = IngestionTask::new(payload, user_id.to_string());
match pipeline.produce_artifacts(&task).await {
Ok(artifacts) => {
let entities: Vec<EmbeddedKnowledgeEntity> = artifacts
.entities
.into_iter()
.map(|e| EmbeddedKnowledgeEntity {
entity: e.entity,
embedding: e.embedding,
})
.collect();
let chunks: Vec<EmbeddedTextChunk> = artifacts
.chunks
.into_iter()
.map(|c| EmbeddedTextChunk {
chunk: c.chunk,
embedding: c.embedding,
})
.collect();
// No need to reembed - pipeline now uses FastEmbed internally
let mut shard = ParagraphShard::new(
paragraph,
request.shard_path,
ingestion_fingerprint,
artifacts.text_content,
entities,
artifacts.relationships,
chunks,
&embedding_backend,
embedding_model.clone(),
embedding_dimension,
chunk_min_tokens,
chunk_max_tokens,
chunk_only,
);
for question in &request.question_refs {
if let Err(err) = shard.ensure_question_binding(question) {
warn!(
question_id = %question.id,
paragraph_id = %paragraph.id,
error = %err,
"Failed to locate answer text in ingested content; recording empty chunk bindings"
);
shard
.question_bindings
.insert(question.id.clone(), Vec::new());
}
}
return Ok(shard);
}
Err(err) => {
warn!(
paragraph_id = %paragraph.id,
attempt,
max_attempts = max_retries,
error = ?err,
"ingestion attempt failed for paragraph; retrying"
);
last_err = Some(err.into());
}
}
}
Err(last_err
.unwrap_or_else(|| anyhow!("ingestion failed"))
.context(format!("running ingestion for paragraph {}", paragraph.id)))
}
pub fn cached_corpus_dir(cache: &CorpusCacheConfig, dataset_id: &str, slice_id: &str) -> PathBuf {
cache.ingestion_cache_dir.join(dataset_id).join(slice_id)
}
pub fn build_ingestion_fingerprint(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
checksum: &str,
ingestion_config: &IngestionConfig,
) -> String {
let config_repr = format!("{:?}", ingestion_config);
let mut hasher = Sha256::new();
hasher.update(config_repr.as_bytes());
let config_hash = format!("{:x}", hasher.finalize());
format!(
"v{INGESTION_SPEC_VERSION}:{}:{}:{}:{}:{}",
dataset.metadata.id,
slice.manifest.slice_id,
slice.manifest.includes_unanswerable,
checksum,
config_hash
)
}
pub fn compute_ingestion_fingerprint(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
converted_path: &Path,
ingestion_config: &IngestionConfig,
) -> Result<String> {
let checksum = compute_file_checksum(converted_path)?;
Ok(build_ingestion_fingerprint(
dataset,
slice,
&checksum,
ingestion_config,
))
}
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
let path = base_dir.join("manifest.json");
if !path.exists() {
return Ok(None);
}
let mut file = fs::File::open(&path)
.with_context(|| format!("opening cached manifest {}", path.display()))?;
let mut buf = Vec::new();
file.read_to_end(&mut buf)
.with_context(|| format!("reading cached manifest {}", path.display()))?;
let manifest: CorpusManifest = serde_json::from_slice(&buf)
.with_context(|| format!("deserialising cached manifest {}", path.display()))?;
Ok(Some(manifest))
}
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
let path = handle.path.join("manifest.json");
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating manifest directory {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let blob =
serde_json::to_vec_pretty(&handle.manifest).context("serialising corpus manifest")?;
fs::write(&tmp_path, &blob)
.with_context(|| format!("writing temporary manifest {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("replacing manifest {}", path.display()))?;
Ok(())
}
pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf) -> CorpusHandle {
CorpusHandle {
manifest,
path: base_dir,
reused_ingestion: true,
reused_embeddings: true,
positive_reused: 0,
positive_ingested: 0,
negative_reused: 0,
negative_ingested: 0,
}
}
fn compute_file_checksum(path: &Path) -> Result<String> {
let mut file = fs::File::open(path)
.with_context(|| format!("opening file {} for checksum", path.display()))?;
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192];
loop {
let read = file
.read(&mut buffer)
.with_context(|| format!("reading {} for checksum", path.display()))?;
if read == 0 {
break;
}
hasher.update(&buffer[..read]);
}
Ok(format!("{:x}", hasher.finalize()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
slice::{CaseRef, SliceCaseEntry, SliceManifest, SliceParagraphEntry, SliceParagraphKind},
};
use chrono::Utc;
fn dummy_dataset() -> ConvertedDataset {
let question = ConvertedQuestion {
id: "q1".to_string(),
question: "What?".to_string(),
answers: vec!["A".to_string()],
is_impossible: false,
};
let paragraph = ConvertedParagraph {
id: "p1".to_string(),
title: "title".to_string(),
context: "context".to_string(),
questions: vec![question],
};
ConvertedDataset {
generated_at: Utc::now(),
metadata: crate::datasets::DatasetMetadata::for_kind(
DatasetKind::default(),
false,
None,
),
source: "src".to_string(),
paragraphs: vec![paragraph],
}
}
fn dummy_slice<'a>(dataset: &'a ConvertedDataset) -> ResolvedSlice<'a> {
let paragraph = &dataset.paragraphs[0];
let question = &paragraph.questions[0];
let manifest = SliceManifest {
version: 1,
slice_id: "slice-1".to_string(),
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),
dataset_source: dataset.source.clone(),
includes_unanswerable: false,
require_verified_chunks: false,
seed: 1,
requested_limit: Some(1),
requested_corpus: 1,
generated_at: Utc::now(),
case_count: 1,
positive_paragraphs: 1,
negative_paragraphs: 0,
total_paragraphs: 1,
negative_multiplier: 1.0,
cases: vec![SliceCaseEntry {
question_id: question.id.clone(),
paragraph_id: paragraph.id.clone(),
}],
paragraphs: vec![SliceParagraphEntry {
id: paragraph.id.clone(),
kind: SliceParagraphKind::Positive {
question_ids: vec![question.id.clone()],
},
shard_path: None,
}],
};
ResolvedSlice {
manifest,
path: PathBuf::from("cache"),
paragraphs: dataset.paragraphs.iter().collect(),
cases: vec![CaseRef {
paragraph,
question,
}],
}
}
#[test]
fn fingerprint_changes_with_chunk_settings() {
let dataset = dummy_dataset();
let slice = dummy_slice(&dataset);
let checksum = "deadbeef";
let base_config = IngestionConfig::default();
let fp_base = build_ingestion_fingerprint(&dataset, &slice, checksum, &base_config);
let mut token_config = base_config.clone();
token_config.tuning.chunk_min_tokens += 1;
let fp_token = build_ingestion_fingerprint(&dataset, &slice, checksum, &token_config);
assert_ne!(fp_base, fp_token, "token bounds should affect fingerprint");
let mut chunk_only_config = base_config;
chunk_only_config.chunk_only = true;
let fp_chunk_only =
build_ingestion_fingerprint(&dataset, &slice, checksum, &chunk_only_config);
assert_ne!(
fp_base, fp_chunk_only,
"chunk-only mode should affect fingerprint"
);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,341 @@
use std::{
collections::{BTreeMap, HashMap},
fs::File,
io::{BufRead, BufReader},
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use serde::Deserialize;
use tracing::warn;
use super::{ConvertedParagraph, ConvertedQuestion, DatasetKind};
const ANSWER_SNIPPET_CHARS: usize = 240;
#[derive(Debug, Deserialize)]
struct BeirCorpusRow {
#[serde(rename = "_id")]
id: String,
#[serde(default)]
title: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct BeirQueryRow {
#[serde(rename = "_id")]
id: String,
text: String,
}
#[derive(Debug, Clone)]
struct BeirParagraph {
title: String,
context: String,
}
#[derive(Debug, Clone)]
struct BeirQuery {
text: String,
}
#[derive(Debug, Clone)]
struct QrelEntry {
doc_id: String,
score: i32,
}
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
let corpus_path = raw_dir.join("corpus.jsonl");
let queries_path = raw_dir.join("queries.jsonl");
let qrels_path = resolve_qrels_path(raw_dir)?;
let corpus = load_corpus(&corpus_path)?;
let queries = load_queries(&queries_path)?;
let qrels = load_qrels(&qrels_path)?;
let mut paragraphs = Vec::with_capacity(corpus.len());
let mut paragraph_index = HashMap::new();
for (doc_id, entry) in corpus.iter() {
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
let paragraph = ConvertedParagraph {
id: paragraph_id.clone(),
title: entry.title.clone(),
context: entry.context.clone(),
questions: Vec::new(),
};
paragraph_index.insert(doc_id.clone(), paragraphs.len());
paragraphs.push(paragraph);
}
let mut missing_queries = 0usize;
let mut missing_docs = 0usize;
let mut skipped_answers = 0usize;
for (query_id, entries) in qrels {
let query = match queries.get(&query_id) {
Some(query) => query,
None => {
missing_queries += 1;
warn!(query_id = %query_id, "Skipping qrels entry for missing query");
continue;
}
};
let best = match select_best_doc(&entries) {
Some(entry) => entry,
None => continue,
};
let paragraph_slot = match paragraph_index.get(&best.doc_id) {
Some(slot) => *slot,
None => {
missing_docs += 1;
warn!(
query_id = %query_id,
doc_id = %best.doc_id,
"Skipping qrels entry referencing missing corpus document"
);
continue;
}
};
let answer = answer_snippet(&paragraphs[paragraph_slot].context);
let answers = match answer {
Some(snippet) => vec![snippet],
None => {
skipped_answers += 1;
warn!(
query_id = %query_id,
doc_id = %best.doc_id,
"Skipping query because no non-empty answer snippet could be derived"
);
continue;
}
};
let question_id = format!("{}-{query_id}", dataset.source_prefix());
paragraphs[paragraph_slot]
.questions
.push(ConvertedQuestion {
id: question_id,
question: query.text.clone(),
answers,
is_impossible: false,
});
}
if missing_queries + missing_docs + skipped_answers > 0 {
warn!(
missing_queries,
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
);
}
Ok(paragraphs)
}
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
let qrels_dir = raw_dir.join("qrels");
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
for name in candidates {
let candidate = qrels_dir.join(name);
if candidate.exists() {
return Ok(candidate);
}
}
Err(anyhow!(
"No qrels file found under {}; expected one of {:?}",
qrels_dir.display(),
candidates
))
}
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
let file =
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
let reader = BufReader::new(file);
let mut corpus = BTreeMap::new();
for (idx, line) in reader.lines().enumerate() {
let raw = line
.with_context(|| format!("reading corpus line {} from {}", idx + 1, path.display()))?;
if raw.trim().is_empty() {
continue;
}
let row: BeirCorpusRow = serde_json::from_str(&raw).with_context(|| {
format!(
"parsing corpus JSON on line {} from {}",
idx + 1,
path.display()
)
})?;
let title = row.title.unwrap_or_else(|| row.id.clone());
let text = row.text.unwrap_or_default();
let context = build_context(&title, &text);
if context.is_empty() {
warn!(doc_id = %row.id, "Skipping empty corpus document");
continue;
}
corpus.insert(row.id, BeirParagraph { title, context });
}
Ok(corpus)
}
fn load_queries(path: &Path) -> Result<BTreeMap<String, BeirQuery>> {
let file = File::open(path)
.with_context(|| format!("opening BEIR queries file at {}", path.display()))?;
let reader = BufReader::new(file);
let mut queries = BTreeMap::new();
for (idx, line) in reader.lines().enumerate() {
let raw = line
.with_context(|| format!("reading query line {} from {}", idx + 1, path.display()))?;
if raw.trim().is_empty() {
continue;
}
let row: BeirQueryRow = serde_json::from_str(&raw).with_context(|| {
format!(
"parsing query JSON on line {} from {}",
idx + 1,
path.display()
)
})?;
queries.insert(
row.id,
BeirQuery {
text: row.text.trim().to_string(),
},
);
}
Ok(queries)
}
fn load_qrels(path: &Path) -> Result<BTreeMap<String, Vec<QrelEntry>>> {
let file =
File::open(path).with_context(|| format!("opening BEIR qrels at {}", path.display()))?;
let reader = BufReader::new(file);
let mut qrels: BTreeMap<String, Vec<QrelEntry>> = BTreeMap::new();
for (idx, line) in reader.lines().enumerate() {
let raw = line
.with_context(|| format!("reading qrels line {} from {}", idx + 1, path.display()))?;
let trimmed = raw.trim();
if trimmed.is_empty() || trimmed.starts_with("query-id") {
continue;
}
let mut parts = trimmed.split_whitespace();
let query_id = parts
.next()
.ok_or_else(|| anyhow!("missing query id on line {}", idx + 1))?;
let doc_id = parts
.next()
.ok_or_else(|| anyhow!("missing document id on line {}", idx + 1))?;
let score_raw = parts
.next()
.ok_or_else(|| anyhow!("missing score on line {}", idx + 1))?;
let score: i32 = score_raw.parse().with_context(|| {
format!(
"parsing qrels score '{}' on line {} from {}",
score_raw,
idx + 1,
path.display()
)
})?;
qrels
.entry(query_id.to_string())
.or_default()
.push(QrelEntry {
doc_id: doc_id.to_string(),
score,
});
}
Ok(qrels)
}
fn select_best_doc(entries: &[QrelEntry]) -> Option<&QrelEntry> {
entries
.iter()
.max_by(|a, b| a.score.cmp(&b.score).then_with(|| b.doc_id.cmp(&a.doc_id)))
}
fn answer_snippet(text: &str) -> Option<String> {
let trimmed = text.trim();
if trimmed.is_empty() {
return None;
}
let snippet: String = trimmed.chars().take(ANSWER_SNIPPET_CHARS).collect();
let snippet = snippet.trim();
if snippet.is_empty() {
None
} else {
Some(snippet.to_string())
}
}
fn build_context(title: &str, text: &str) -> String {
let title = title.trim();
let text = text.trim();
match (title.is_empty(), text.is_empty()) {
(true, true) => String::new(),
(true, false) => text.to_string(),
(false, true) => title.to_string(),
(false, false) => format!("{title}\n\n{text}"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
#[test]
fn converts_basic_beir_layout() {
let dir = tempdir().unwrap();
let corpus = r#"
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
{"_id":"d2","title":"Doc 2","text":"Second document content."}
"#;
let queries = r#"
{"_id":"q1","text":"What is in doc one?"}
"#;
let qrels = "query-id\tcorpus-id\tscore\nq1\td1\t2\n";
fs::write(dir.path().join("corpus.jsonl"), corpus.trim()).unwrap();
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
fs::create_dir_all(dir.path().join("qrels")).unwrap();
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
assert_eq!(paragraphs.len(), 2);
let doc_one = paragraphs
.iter()
.find(|p| p.id == "fever-d1")
.expect("missing paragraph for d1");
assert_eq!(doc_one.questions.len(), 1);
let question = &doc_one.questions[0];
assert_eq!(question.id, "fever-q1");
assert!(!question.answers.is_empty());
assert!(doc_one.context.contains(&question.answers[0]));
let doc_two = paragraphs
.iter()
.find(|p| p.id == "fever-d2")
.expect("missing paragraph for d2");
assert!(doc_two.questions.is_empty());
}
}

View File

@@ -0,0 +1,628 @@
mod beir;
mod nq;
mod squad;
use std::{
collections::{BTreeMap, HashMap},
fs::{self},
path::{Path, PathBuf},
str::FromStr,
};
use anyhow::{anyhow, bail, Context, Result};
use chrono::{DateTime, TimeZone, Utc};
use clap::ValueEnum;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use tracing::warn;
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetCatalog {
datasets: BTreeMap<String, DatasetEntry>,
slices: HashMap<String, SliceLocation>,
default_dataset: String,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetEntry {
pub metadata: DatasetMetadata,
pub raw_path: PathBuf,
pub converted_path: PathBuf,
pub include_unanswerable: bool,
pub slices: Vec<SliceEntry>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SliceEntry {
pub id: String,
pub dataset_id: String,
pub label: String,
pub description: Option<String>,
pub limit: Option<usize>,
pub corpus_limit: Option<usize>,
pub include_unanswerable: Option<bool>,
pub seed: Option<u64>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct SliceLocation {
dataset_id: String,
slice_index: usize,
}
#[derive(Debug, Deserialize)]
struct ManifestFile {
default_dataset: Option<String>,
datasets: Vec<ManifestDataset>,
}
#[derive(Debug, Deserialize)]
struct ManifestDataset {
id: String,
label: String,
category: String,
#[serde(default)]
entity_suffix: Option<String>,
#[serde(default)]
source_prefix: Option<String>,
raw: String,
converted: String,
#[serde(default)]
include_unanswerable: bool,
#[serde(default)]
slices: Vec<ManifestSlice>,
}
#[derive(Debug, Deserialize)]
struct ManifestSlice {
id: String,
label: String,
#[serde(default)]
description: Option<String>,
#[serde(default)]
limit: Option<usize>,
#[serde(default)]
corpus_limit: Option<usize>,
#[serde(default)]
include_unanswerable: Option<bool>,
#[serde(default)]
seed: Option<u64>,
}
impl DatasetCatalog {
pub fn load() -> Result<Self> {
let manifest_raw = fs::read_to_string(MANIFEST_PATH)
.with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?;
let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw)
.with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?;
let root = Path::new(env!("CARGO_MANIFEST_DIR"));
let mut datasets = BTreeMap::new();
let mut slices = HashMap::new();
for dataset in manifest.datasets {
let raw_path = resolve_path(root, &dataset.raw);
let converted_path = resolve_path(root, &dataset.converted);
if !raw_path.exists() {
bail!(
"dataset '{}' raw file missing at {}",
dataset.id,
raw_path.display()
);
}
if !converted_path.exists() {
warn!(
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
dataset.id,
converted_path.display()
);
}
let metadata = DatasetMetadata {
id: dataset.id.clone(),
label: dataset.label.clone(),
category: dataset.category.clone(),
entity_suffix: dataset
.entity_suffix
.clone()
.unwrap_or_else(|| dataset.label.clone()),
source_prefix: dataset
.source_prefix
.clone()
.unwrap_or_else(|| dataset.id.clone()),
include_unanswerable: dataset.include_unanswerable,
context_token_limit: None,
};
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
for (index, manifest_slice) in dataset.slices.into_iter().enumerate() {
if slices.contains_key(&manifest_slice.id) {
bail!(
"slice '{}' defined multiple times in manifest",
manifest_slice.id
);
}
entry_slices.push(SliceEntry {
id: manifest_slice.id.clone(),
dataset_id: dataset.id.clone(),
label: manifest_slice.label,
description: manifest_slice.description,
limit: manifest_slice.limit,
corpus_limit: manifest_slice.corpus_limit,
include_unanswerable: manifest_slice.include_unanswerable,
seed: manifest_slice.seed,
});
slices.insert(
manifest_slice.id,
SliceLocation {
dataset_id: dataset.id.clone(),
slice_index: index,
},
);
}
datasets.insert(
metadata.id.clone(),
DatasetEntry {
metadata,
raw_path,
converted_path,
include_unanswerable: dataset.include_unanswerable,
slices: entry_slices,
},
);
}
let default_dataset = manifest
.default_dataset
.or_else(|| datasets.keys().next().cloned())
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
Ok(Self {
datasets,
slices,
default_dataset,
})
}
pub fn global() -> Result<&'static Self> {
DATASET_CATALOG.get_or_try_init(Self::load)
}
pub fn dataset(&self, id: &str) -> Result<&DatasetEntry> {
self.datasets
.get(id)
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
}
#[allow(dead_code)]
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
self.dataset(&self.default_dataset)
}
#[allow(dead_code)]
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
let location = self
.slices
.get(slice_id)
.ok_or_else(|| anyhow!("unknown slice '{slice_id}' in manifest"))?;
let dataset = self
.datasets
.get(&location.dataset_id)
.ok_or_else(|| anyhow!("slice '{slice_id}' references missing dataset"))?;
let slice = dataset
.slices
.get(location.slice_index)
.ok_or_else(|| anyhow!("slice index out of bounds for '{slice_id}'"))?;
Ok((dataset, slice))
}
}
fn resolve_path(root: &Path, value: &str) -> PathBuf {
let path = Path::new(value);
if path.is_absolute() {
path.to_path_buf()
} else {
root.join(path)
}
}
pub fn catalog() -> Result<&'static DatasetCatalog> {
DatasetCatalog::global()
}
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
let catalog = catalog()?;
catalog.dataset(kind.id())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum DatasetKind {
SquadV2,
NaturalQuestions,
Beir,
#[value(name = "fever")]
Fever,
#[value(name = "fiqa")]
Fiqa,
#[value(name = "hotpotqa", alias = "hotpot-qa")]
HotpotQa,
#[value(name = "nfcorpus", alias = "nf-corpus")]
Nfcorpus,
#[value(name = "quora")]
Quora,
#[value(name = "trec-covid", alias = "treccovid", alias = "trec_covid")]
TrecCovid,
#[value(name = "scifact")]
Scifact,
#[value(name = "nq-beir", alias = "natural-questions-beir")]
NqBeir,
}
impl DatasetKind {
pub fn id(self) -> &'static str {
match self {
Self::SquadV2 => "squad-v2",
Self::NaturalQuestions => "natural-questions-dev",
Self::Beir => "beir",
Self::Fever => "fever",
Self::Fiqa => "fiqa",
Self::HotpotQa => "hotpotqa",
Self::Nfcorpus => "nfcorpus",
Self::Quora => "quora",
Self::TrecCovid => "trec-covid",
Self::Scifact => "scifact",
Self::NqBeir => "nq-beir",
}
}
pub fn label(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD v2.0",
Self::NaturalQuestions => "Natural Questions (dev)",
Self::Beir => "BEIR mix",
Self::Fever => "FEVER (BEIR)",
Self::Fiqa => "FiQA-2018 (BEIR)",
Self::HotpotQa => "HotpotQA (BEIR)",
Self::Nfcorpus => "NFCorpus (BEIR)",
Self::Quora => "Quora (IR)",
Self::TrecCovid => "TREC-COVID (BEIR)",
Self::Scifact => "SciFact (BEIR)",
Self::NqBeir => "Natural Questions (BEIR)",
}
}
pub fn category(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD v2.0",
Self::NaturalQuestions => "Natural Questions",
Self::Beir => "BEIR",
Self::Fever => "FEVER",
Self::Fiqa => "FiQA-2018",
Self::HotpotQa => "HotpotQA",
Self::Nfcorpus => "NFCorpus",
Self::Quora => "Quora",
Self::TrecCovid => "TREC-COVID",
Self::Scifact => "SciFact",
Self::NqBeir => "Natural Questions",
}
}
pub fn entity_suffix(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD",
Self::NaturalQuestions => "Natural Questions",
Self::Beir => "BEIR",
Self::Fever => "FEVER",
Self::Fiqa => "FiQA",
Self::HotpotQa => "HotpotQA",
Self::Nfcorpus => "NFCorpus",
Self::Quora => "Quora",
Self::TrecCovid => "TREC-COVID",
Self::Scifact => "SciFact",
Self::NqBeir => "Natural Questions",
}
}
pub fn source_prefix(self) -> &'static str {
match self {
Self::SquadV2 => "squad",
Self::NaturalQuestions => "nq",
Self::Beir => "beir",
Self::Fever => "fever",
Self::Fiqa => "fiqa",
Self::HotpotQa => "hotpotqa",
Self::Nfcorpus => "nfcorpus",
Self::Quora => "quora",
Self::TrecCovid => "trec-covid",
Self::Scifact => "scifact",
Self::NqBeir => "nq-beir",
}
}
pub fn default_raw_path(self) -> PathBuf {
dataset_entry_for_kind(self)
.map(|entry| entry.raw_path.clone())
.unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
}
pub fn default_converted_path(self) -> PathBuf {
dataset_entry_for_kind(self)
.map(|entry| entry.converted_path.clone())
.unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
}
}
impl std::fmt::Display for DatasetKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.id())
}
}
impl Default for DatasetKind {
fn default() -> Self {
Self::SquadV2
}
}
impl FromStr for DatasetKind {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2),
"nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => {
Ok(Self::NaturalQuestions)
}
"beir" => Ok(Self::Beir),
"fever" => Ok(Self::Fever),
"fiqa" | "fiqa-2018" => Ok(Self::Fiqa),
"hotpotqa" | "hotpot-qa" => Ok(Self::HotpotQa),
"nfcorpus" | "nf-corpus" => Ok(Self::Nfcorpus),
"quora" => Ok(Self::Quora),
"trec-covid" | "treccovid" | "trec_covid" => Ok(Self::TrecCovid),
"scifact" => Ok(Self::Scifact),
"nq-beir" | "natural-questions-beir" => Ok(Self::NqBeir),
other => {
anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid, scifact, nq-beir.")
}
}
}
}
pub const BEIR_DATASETS: [DatasetKind; 8] = [
DatasetKind::Fever,
DatasetKind::Fiqa,
DatasetKind::HotpotQa,
DatasetKind::Nfcorpus,
DatasetKind::Quora,
DatasetKind::TrecCovid,
DatasetKind::Scifact,
DatasetKind::NqBeir,
];
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
pub id: String,
pub label: String,
pub category: String,
pub entity_suffix: String,
pub source_prefix: String,
#[serde(default)]
pub include_unanswerable: bool,
#[serde(default)]
pub context_token_limit: Option<usize>,
}
impl DatasetMetadata {
pub fn for_kind(
kind: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Self {
if let Ok(entry) = dataset_entry_for_kind(kind) {
return Self {
id: entry.metadata.id.clone(),
label: entry.metadata.label.clone(),
category: entry.metadata.category.clone(),
entity_suffix: entry.metadata.entity_suffix.clone(),
source_prefix: entry.metadata.source_prefix.clone(),
include_unanswerable,
context_token_limit,
};
}
Self {
id: kind.id().to_string(),
label: kind.label().to_string(),
category: kind.category().to_string(),
entity_suffix: kind.entity_suffix().to_string(),
source_prefix: kind.source_prefix().to_string(),
include_unanswerable,
context_token_limit,
}
}
}
fn default_metadata() -> DatasetMetadata {
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedDataset {
pub generated_at: DateTime<Utc>,
#[serde(default = "default_metadata")]
pub metadata: DatasetMetadata,
pub source: String,
pub paragraphs: Vec<ConvertedParagraph>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedParagraph {
pub id: String,
pub title: String,
pub context: String,
pub questions: Vec<ConvertedQuestion>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedQuestion {
pub id: String,
pub question: String,
pub answers: Vec<String>,
pub is_impossible: bool,
}
pub fn convert(
raw_path: &Path,
dataset: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
let paragraphs = match dataset {
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
DatasetKind::NaturalQuestions => {
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
}
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
DatasetKind::Fever
| DatasetKind::Fiqa
| DatasetKind::HotpotQa
| DatasetKind::Nfcorpus
| DatasetKind::Quora
| DatasetKind::TrecCovid
| DatasetKind::Scifact
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
};
let metadata_limit = match dataset {
DatasetKind::NaturalQuestions => None,
_ => context_token_limit,
};
let generated_at = match dataset {
DatasetKind::Beir
| DatasetKind::Fever
| DatasetKind::Fiqa
| DatasetKind::HotpotQa
| DatasetKind::Nfcorpus
| DatasetKind::Quora
| DatasetKind::TrecCovid
| DatasetKind::Scifact
| DatasetKind::NqBeir => base_timestamp(),
_ => Utc::now(),
};
let source_label = match dataset {
DatasetKind::Beir => "beir-mix".to_string(),
_ => raw_path.display().to_string(),
};
Ok(ConvertedDataset {
generated_at,
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
source: source_label,
paragraphs,
})
}
fn convert_beir_mix(
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
if include_unanswerable {
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
}
let mut paragraphs = Vec::new();
for subset in BEIR_DATASETS {
let entry = dataset_entry_for_kind(subset)?;
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
paragraphs.extend(subset_paragraphs);
}
Ok(paragraphs)
}
fn ensure_parent(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating parent directory for {}", path.display()))?;
}
Ok(())
}
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
ensure_parent(converted_path)?;
let json =
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
fs::write(converted_path, json)
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
}
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
let raw = fs::read_to_string(converted_path)
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
if dataset.metadata.id.trim().is_empty() {
dataset.metadata = default_metadata();
}
if dataset.source.is_empty() {
dataset.source = converted_path.display().to_string();
}
Ok(dataset)
}
pub fn ensure_converted(
dataset_kind: DatasetKind,
raw_path: &Path,
converted_path: &Path,
force: bool,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
if force || !converted_path.exists() {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
return Ok(dataset);
}
match read_converted(converted_path) {
Ok(dataset)
if dataset.metadata.id == dataset_kind.id()
&& dataset.metadata.include_unanswerable == include_unanswerable
&& dataset.metadata.context_token_limit == context_token_limit =>
{
Ok(dataset)
}
_ => {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
Ok(dataset)
}
}
}
pub fn base_timestamp() -> DateTime<Utc> {
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
}

View File

@@ -0,0 +1,234 @@
use std::{
collections::BTreeSet,
fs::File,
io::{BufRead, BufReader},
path::Path,
};
use anyhow::{Context, Result};
use serde::Deserialize;
use tracing::warn;
use super::{ConvertedParagraph, ConvertedQuestion};
pub fn convert_nq(
raw_path: &Path,
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqExample {
question_text: String,
document_title: String,
example_id: i64,
document_tokens: Vec<NqToken>,
long_answer_candidates: Vec<NqLongAnswerCandidate>,
annotations: Vec<NqAnnotation>,
}
#[derive(Debug, Deserialize)]
struct NqToken {
token: String,
#[serde(default)]
html_token: bool,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqLongAnswerCandidate {
start_token: i32,
end_token: i32,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqAnnotation {
short_answers: Vec<NqShortAnswer>,
#[serde(default)]
yes_no_answer: String,
long_answer: NqLongAnswer,
}
#[derive(Debug, Deserialize)]
struct NqShortAnswer {
start_token: i32,
end_token: i32,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqLongAnswer {
candidate_index: i32,
}
fn join_tokens(tokens: &[NqToken], start: usize, end: usize) -> String {
let mut buffer = String::new();
let end = end.min(tokens.len());
for token in tokens.iter().skip(start).take(end.saturating_sub(start)) {
if token.html_token {
continue;
}
let text = token.token.trim();
if text.is_empty() {
continue;
}
let attach = matches!(
text,
"," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "" | "..."
) || text.starts_with('\'')
|| text == "n't"
|| text == "'s"
|| text == "'re"
|| text == "'ve"
|| text == "'d"
|| text == "'ll";
if buffer.is_empty() || attach {
buffer.push_str(text);
} else {
buffer.push(' ');
buffer.push_str(text);
}
}
buffer.trim().to_string()
}
let file = File::open(raw_path).with_context(|| {
format!(
"opening Natural Questions dataset at {}",
raw_path.display()
)
})?;
let reader = BufReader::new(file);
let mut paragraphs = Vec::new();
for (line_idx, line) in reader.lines().enumerate() {
let line = line.with_context(|| {
format!(
"reading Natural Questions line {} from {}",
line_idx + 1,
raw_path.display()
)
})?;
if line.trim().is_empty() {
continue;
}
let example: NqExample = serde_json::from_str(&line).with_context(|| {
format!(
"parsing Natural Questions JSON (line {}) at {}",
line_idx + 1,
raw_path.display()
)
})?;
let mut answer_texts: Vec<String> = Vec::new();
let mut short_answer_texts: Vec<String> = Vec::new();
let mut has_short_or_yesno = false;
let mut has_short_answer = false;
for annotation in &example.annotations {
for short in &annotation.short_answers {
if short.start_token < 0 || short.end_token <= short.start_token {
continue;
}
let start = short.start_token as usize;
let end = short.end_token as usize;
if start >= example.document_tokens.len() || end > example.document_tokens.len() {
continue;
}
let text = join_tokens(&example.document_tokens, start, end);
if !text.is_empty() {
answer_texts.push(text.clone());
short_answer_texts.push(text);
has_short_or_yesno = true;
has_short_answer = true;
}
}
match annotation
.yes_no_answer
.trim()
.to_ascii_lowercase()
.as_str()
{
"yes" => {
answer_texts.push("yes".to_string());
has_short_or_yesno = true;
}
"no" => {
answer_texts.push("no".to_string());
has_short_or_yesno = true;
}
_ => {}
}
}
let mut answers = dedupe_strings(answer_texts);
let is_unanswerable = !has_short_or_yesno || answers.is_empty();
if is_unanswerable {
if !include_unanswerable {
continue;
}
answers.clear();
}
let paragraph_id = format!("nq-{}", example.example_id);
let question_id = format!("nq-{}", example.example_id);
let context = join_tokens(&example.document_tokens, 0, example.document_tokens.len());
if context.is_empty() {
continue;
}
if has_short_answer && !short_answer_texts.is_empty() {
let normalized_context = context.to_ascii_lowercase();
let missing_answer = short_answer_texts.iter().any(|answer| {
let needle = answer.trim().to_ascii_lowercase();
!needle.is_empty() && !normalized_context.contains(&needle)
});
if missing_answer {
warn!(
question_id = %question_id,
"Skipping Natural Questions example because answers were not found in the assembled context"
);
continue;
}
}
if !include_unanswerable && (!has_short_answer || short_answer_texts.is_empty()) {
// yes/no-only questions are excluded by default unless --llm-mode is used
continue;
}
let question = ConvertedQuestion {
id: question_id,
question: example.question_text.trim().to_string(),
answers,
is_impossible: is_unanswerable,
};
paragraphs.push(ConvertedParagraph {
id: paragraph_id,
title: example.document_title.trim().to_string(),
context,
questions: vec![question],
});
}
Ok(paragraphs)
}
fn dedupe_strings<I>(values: I) -> Vec<String>
where
I: IntoIterator<Item = String>,
{
let mut set = BTreeSet::new();
for value in values {
let trimmed = value.trim();
if !trimmed.is_empty() {
set.insert(trimmed.to_string());
}
}
set.into_iter().collect()
}

View File

@@ -0,0 +1,107 @@
use std::{collections::BTreeSet, fs, path::Path};
use anyhow::{Context, Result};
use serde::Deserialize;
use super::{ConvertedParagraph, ConvertedQuestion};
pub fn convert_squad(raw_path: &Path) -> Result<Vec<ConvertedParagraph>> {
#[derive(Debug, Deserialize)]
struct SquadDataset {
data: Vec<SquadArticle>,
}
#[derive(Debug, Deserialize)]
struct SquadArticle {
title: String,
paragraphs: Vec<SquadParagraph>,
}
#[derive(Debug, Deserialize)]
struct SquadParagraph {
context: String,
qas: Vec<SquadQuestion>,
}
#[derive(Debug, Deserialize)]
struct SquadQuestion {
id: String,
question: String,
answers: Vec<SquadAnswer>,
#[serde(default)]
is_impossible: bool,
}
#[derive(Debug, Deserialize)]
struct SquadAnswer {
text: String,
}
let raw = fs::read_to_string(raw_path)
.with_context(|| format!("reading raw SQuAD dataset at {}", raw_path.display()))?;
let parsed: SquadDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing SQuAD dataset at {}", raw_path.display()))?;
let mut paragraphs = Vec::new();
for (article_idx, article) in parsed.data.into_iter().enumerate() {
for (paragraph_idx, paragraph) in article.paragraphs.into_iter().enumerate() {
let mut questions = Vec::new();
for qa in paragraph.qas {
let answers = dedupe_strings(qa.answers.into_iter().map(|answer| answer.text));
questions.push(ConvertedQuestion {
id: qa.id,
question: qa.question.trim().to_string(),
answers,
is_impossible: qa.is_impossible,
});
}
let paragraph_id =
format!("{}-{}", slugify(&article.title, article_idx), paragraph_idx);
paragraphs.push(ConvertedParagraph {
id: paragraph_id,
title: article.title.trim().to_string(),
context: paragraph.context.trim().to_string(),
questions,
});
}
}
Ok(paragraphs)
}
fn dedupe_strings<I>(values: I) -> Vec<String>
where
I: IntoIterator<Item = String>,
{
let mut set = BTreeSet::new();
for value in values {
let trimmed = value.trim();
if !trimmed.is_empty() {
set.insert(trimmed.to_string());
}
}
set.into_iter().collect()
}
fn slugify(input: &str, fallback_idx: usize) -> String {
let mut slug = String::new();
let mut last_dash = false;
for ch in input.chars() {
let c = ch.to_ascii_lowercase();
if c.is_ascii_alphanumeric() {
slug.push(c);
last_dash = false;
} else if !last_dash {
slug.push('-');
last_dash = true;
}
}
slug = slug.trim_matches('-').to_string();
if slug.is_empty() {
slug = format!("article-{fallback_idx}");
}
slug
}

View File

@@ -0,0 +1,109 @@
use anyhow::{Context, Result};
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes};
use tracing::info;
// Helper functions for index management during namespace reseed
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
let _ = db;
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
Ok(())
}
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
ensure_runtime_indexes(db, dimension)
.await
.context("creating runtime indexes")
}
pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> {
let query = format!(
"REMOVE NAMESPACE {ns};
DEFINE NAMESPACE {ns};
DEFINE DATABASE {db};",
ns = namespace,
db = database
);
db.client
.query(query)
.await
.context("resetting SurrealDB namespace")?;
db.client
.use_ns(namespace)
.use_db(database)
.await
.context("selecting namespace/database after reset")?;
Ok(())
}
// Test helper to force index dimension change
pub async fn change_embedding_length_in_hnsw_indexes(
db: &SurrealDbClient,
dimension: usize,
) -> Result<()> {
recreate_indexes(db, dimension).await
}
#[cfg(test)]
mod tests {
use super::*;
use serde::Deserialize;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
struct FooRow {
label: String,
}
#[tokio::test]
async fn reset_namespace_drops_existing_rows() {
let namespace = format!("reset_ns_{}", Uuid::new_v4().simple());
let database = format!("reset_db_{}", Uuid::new_v4().simple());
let db = SurrealDbClient::memory(&namespace, &database)
.await
.expect("in-memory db");
db.client
.query(
"DEFINE TABLE foo SCHEMALESS;
CREATE foo:foo SET label = 'before';",
)
.await
.expect("seed namespace")
.check()
.expect("seed response");
let mut before = db
.client
.query("SELECT * FROM foo")
.await
.expect("select before reset");
let existing: Vec<FooRow> = before.take(0).expect("rows before reset");
assert_eq!(existing.len(), 1);
assert_eq!(existing[0].label, "before");
reset_namespace(&db, &namespace, &database)
.await
.expect("namespace reset");
match db.client.query("SELECT * FROM foo").await {
Ok(mut response) => {
let rows: Vec<FooRow> = response.take(0).unwrap_or_default();
assert!(
rows.is_empty(),
"reset namespace should drop rows, found {:?}",
rows
);
}
Err(error) => {
let message = error.to_string();
assert!(
message.to_ascii_lowercase().contains("table")
|| message.to_ascii_lowercase().contains("namespace")
|| message.to_ascii_lowercase().contains("foo"),
"unexpected error after namespace reset: {message}"
);
}
}
}
}

128
evaluations/src/eval.rs Normal file
View File

@@ -0,0 +1,128 @@
//! Evaluation utilities module - re-exports from focused submodules.
// Re-export types from the root types module
pub use crate::types::*;
// Re-export from focused modules at crate root (crate-internal only)
pub(crate) use crate::cases::{cases_from_manifest, SeededCase};
pub(crate) use crate::namespace::{
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
record_namespace_state,
};
pub(crate) use crate::settings::{enforce_system_settings, load_or_init_system_settings};
use std::path::Path;
use anyhow::{Context, Result};
use common::storage::db::SurrealDbClient;
use tokio::io::AsyncWriteExt;
use tracing::info;
use crate::{
args::{self, Config},
datasets::ConvertedDataset,
slice::{self},
};
/// Grow the slice ledger to contain the target number of cases.
pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
let ledger_limit = ledger_target(config);
let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
let slice =
slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
info!(
slice = slice.manifest.slice_id.as_str(),
cases = slice.manifest.case_count,
positives = slice.manifest.positive_paragraphs,
negatives = slice.manifest.negative_paragraphs,
total_paragraphs = slice.manifest.total_paragraphs,
"Slice ledger ready"
);
println!(
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
slice.manifest.slice_id,
slice.manifest.case_count,
slice.manifest.positive_paragraphs,
slice.manifest.negative_paragraphs
);
Ok(())
}
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
match (config.slice_grow, config.limit) {
(Some(grow), Some(limit)) => Some(limit.max(grow)),
(Some(grow), None) => Some(grow),
(None, limit) => limit,
}
}
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
args::ensure_parent(path)?;
let mut file = tokio::fs::File::create(path)
.await
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
for case in cases {
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
file.write_all(&line).await?;
file.write_all(b"\n").await?;
}
file.flush().await?;
Ok(())
}
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
// Create a dummy embedding for cache warming
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
info!("Warming HNSW caches with sample queries");
// Warm up chunk embedding index - just query the embedding table to load HNSW index
let _ = db
.client
.query(
r#"SELECT chunk_id
FROM text_chunk_embedding
WHERE embedding <|1,1|> $embedding
LIMIT 5"#,
)
.bind(("embedding", dummy_embedding.clone()))
.await
.context("warming text chunk HNSW cache")?;
// Warm up entity embedding index
let _ = db
.client
.query(
r#"SELECT entity_id
FROM knowledge_entity_embedding
WHERE embedding <|1,1|> $embedding
LIMIT 5"#,
)
.bind(("embedding", dummy_embedding))
.await
.context("warming knowledge entity HNSW cache")?;
info!("HNSW cache warming completed");
Ok(())
}
use chrono::{DateTime, SecondsFormat, Utc};
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
}
pub(crate) fn sanitize_model_code(code: &str) -> String {
code.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() {
ch.to_ascii_lowercase()
} else {
'_'
}
})
.collect()
}
// Re-export run_evaluation from the pipeline module at crate root
pub use crate::pipeline::run_evaluation;

View File

@@ -0,0 +1,184 @@
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
use crate::{args::Config, eval::connect_eval_db, corpus, snapshot::DbSnapshotState};
pub async fn inspect_question(config: &Config) -> Result<()> {
let question_id = config
.inspect_question
.as_ref()
.ok_or_else(|| anyhow!("--inspect-question is required for inspection mode"))?;
let manifest_path = config
.inspect_manifest
.as_ref()
.ok_or_else(|| anyhow!("--inspect-manifest must be provided for inspection mode"))?;
let manifest = load_manifest(manifest_path)?;
let chunk_lookup = build_chunk_lookup(&manifest);
let question = manifest
.questions
.iter()
.find(|q| q.question_id == *question_id)
.ok_or_else(|| {
anyhow!(
"question '{}' not found in manifest {}",
question_id,
manifest_path.display()
)
})?;
println!("Question: {}", question.question_text);
println!("Answers: {:?}", question.answers);
println!(
"matching_chunk_ids ({}):",
question.matching_chunk_ids.len()
);
let mut missing_in_manifest = Vec::new();
for chunk_id in &question.matching_chunk_ids {
if let Some(entry) = chunk_lookup.get(chunk_id) {
println!(
" - {} (paragraph: {})\n snippet: {}",
chunk_id, entry.paragraph_title, entry.snippet
);
} else {
println!(" - {} (missing from manifest)", chunk_id);
missing_in_manifest.push(chunk_id.clone());
}
}
if missing_in_manifest.is_empty() {
println!("All matching_chunk_ids are present in the ingestion manifest");
} else {
println!(
"Missing chunk IDs in manifest {}: {:?}",
manifest_path.display(),
missing_in_manifest
);
}
let db_state_path = config
.database
.inspect_db_state
.clone()
.unwrap_or_else(|| default_state_path(config, &manifest));
if let Some(state) = load_db_state(&db_state_path)? {
if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
match connect_eval_db(config, ns, db_name).await {
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
MissingChunks::None => println!(
"All matching_chunk_ids exist in namespace '{}', database '{}'",
ns, db_name
),
MissingChunks::Missing(list) => println!(
"Missing chunks in namespace '{}', database '{}': {:?}",
ns, db_name, list
),
},
Err(err) => {
println!(
"Failed to connect to SurrealDB namespace '{}' / database '{}': {err}",
ns, db_name
);
}
}
} else {
println!(
"State file {} is missing namespace/database fields; skipping live DB validation",
db_state_path.display()
);
}
} else {
println!(
"State file {} not found; skipping live DB validation",
db_state_path.display()
);
}
Ok(())
}
struct ChunkEntry {
paragraph_title: String,
snippet: String,
}
fn load_manifest(path: &Path) -> Result<corpus::CorpusManifest> {
let bytes =
fs::read(path).with_context(|| format!("reading ingestion manifest {}", path.display()))?;
serde_json::from_slice(&bytes)
.with_context(|| format!("parsing ingestion manifest {}", path.display()))
}
fn build_chunk_lookup(manifest: &corpus::CorpusManifest) -> HashMap<String, ChunkEntry> {
let mut lookup = HashMap::new();
for paragraph in &manifest.paragraphs {
for chunk in &paragraph.chunks {
let snippet = chunk
.chunk
.chunk
.chars()
.take(160)
.collect::<String>()
.replace('\n', " ");
lookup.insert(
chunk.chunk.id.clone(),
ChunkEntry {
paragraph_title: paragraph.title.clone(),
snippet,
},
);
}
}
lookup
}
fn default_state_path(config: &Config, manifest: &corpus::CorpusManifest) -> PathBuf {
config
.cache_dir
.join("snapshots")
.join(&manifest.metadata.dataset_id)
.join(&manifest.metadata.slice_id)
.join("db/state.json")
}
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
if !path.exists() {
return Ok(None);
}
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
let state = serde_json::from_slice(&bytes)
.with_context(|| format!("parsing db state {}", path.display()))?;
Ok(Some(state))
}
enum MissingChunks {
None,
Missing(Vec<String>),
}
async fn verify_chunks_in_db(db: &SurrealDbClient, chunk_ids: &[String]) -> Result<MissingChunks> {
let mut missing = Vec::new();
for chunk_id in chunk_ids {
let exists = db
.get_item::<TextChunk>(chunk_id)
.await
.with_context(|| format!("fetching text_chunk {}", chunk_id))?
.is_some();
if !exists {
missing.push(chunk_id.clone());
}
}
if missing.is_empty() {
Ok(MissingChunks::None)
} else {
Ok(MissingChunks::Missing(missing))
}
}

247
evaluations/src/main.rs Normal file
View File

@@ -0,0 +1,247 @@
mod args;
mod cache;
mod cases;
mod corpus;
mod datasets;
mod db_helpers;
mod eval;
mod inspection;
mod namespace;
mod openai;
mod perf;
mod pipeline;
mod report;
mod settings;
mod slice;
mod snapshot;
mod types;
use anyhow::Context;
use tokio::runtime::Builder;
use tracing::info;
use tracing_subscriber::{fmt, EnvFilter};
/// Configure SurrealDB environment variables for optimal performance
fn configure_surrealdb_performance(cpu_count: usize) {
// Set environment variables only if they're not already set
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
.unwrap_or_else(|_| (cpu_count * 2).to_string());
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
let max_order_queue = std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE")
.unwrap_or_else(|_| (cpu_count * 4).to_string());
std::env::set_var(
"SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
max_order_queue,
);
let websocket_concurrent = std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
.unwrap_or_else(|_| cpu_count.to_string());
std::env::set_var(
"SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
websocket_concurrent,
);
let websocket_buffer = std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE")
.unwrap_or_else(|_| (cpu_count * 8).to_string());
std::env::set_var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE", websocket_buffer);
let transaction_cache = std::env::var("SURREAL_TRANSACTION_CACHE_SIZE")
.unwrap_or_else(|_| (cpu_count * 16).to_string());
std::env::set_var("SURREAL_TRANSACTION_CACHE_SIZE", transaction_cache);
info!(
indexing_batch_size = %std::env::var("SURREAL_INDEXING_BATCH_SIZE").unwrap(),
max_order_queue = %std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE").unwrap(),
websocket_concurrent = %std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS").unwrap(),
websocket_buffer = %std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE").unwrap(),
transaction_cache = %std::env::var("SURREAL_TRANSACTION_CACHE_SIZE").unwrap(),
"Configured SurrealDB performance variables"
);
}
fn main() -> anyhow::Result<()> {
// Create an explicit multi-threaded runtime with optimized configuration
let runtime = Builder::new_multi_thread()
.enable_all()
.worker_threads(std::thread::available_parallelism()?.get())
.max_blocking_threads(std::thread::available_parallelism()?.get())
.thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
.thread_name("eval-retrieval-worker")
.build()
.context("failed to create tokio runtime")?;
runtime.block_on(async_main())
}
async fn async_main() -> anyhow::Result<()> {
// Log runtime configuration
let cpu_count = std::thread::available_parallelism()?.get();
info!(
cpu_cores = cpu_count,
worker_threads = cpu_count,
blocking_threads = cpu_count,
thread_stack_size = "10MiB",
"Started multi-threaded tokio runtime"
);
// Configure SurrealDB environment variables for better performance
configure_surrealdb_performance(cpu_count);
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
let _ = fmt()
.with_env_filter(EnvFilter::try_new(&filter).unwrap_or_else(|_| EnvFilter::new("info")))
.try_init();
let parsed = args::parse()?;
// Clap handles help automatically, so we don't need to check for it manually
if parsed.config.inspect_question.is_some() {
inspection::inspect_question(&parsed.config).await?;
return Ok(());
}
let dataset_kind = parsed.config.dataset;
if parsed.config.convert_only {
info!(
dataset = dataset_kind.id(),
"Starting dataset conversion only run"
);
let dataset = crate::datasets::convert(
parsed.config.raw_dataset_path.as_path(),
dataset_kind,
parsed.config.llm_mode,
parsed.config.context_token_limit(),
)
.with_context(|| {
format!(
"converting {} dataset at {}",
dataset_kind.label(),
parsed.config.raw_dataset_path.display()
)
})?;
crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
.with_context(|| {
format!(
"writing converted dataset to {}",
parsed.config.converted_dataset_path.display()
)
})?;
println!(
"Converted dataset written to {}",
parsed.config.converted_dataset_path.display()
);
return Ok(());
}
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
let dataset = crate::datasets::ensure_converted(
dataset_kind,
parsed.config.raw_dataset_path.as_path(),
parsed.config.converted_dataset_path.as_path(),
parsed.config.force_convert,
parsed.config.llm_mode,
parsed.config.context_token_limit(),
)
.with_context(|| {
format!(
"preparing converted dataset at {}",
parsed.config.converted_dataset_path.display()
)
})?;
info!(
questions = dataset
.paragraphs
.iter()
.map(|p| p.questions.len())
.sum::<usize>(),
paragraphs = dataset.paragraphs.len(),
dataset = dataset.metadata.id.as_str(),
"Dataset ready"
);
if parsed.config.slice_grow.is_some() {
eval::grow_slice(&dataset, &parsed.config)
.await
.context("growing slice ledger")?;
return Ok(());
}
info!("Running retrieval evaluation");
let summary = eval::run_evaluation(&dataset, &parsed.config)
.await
.context("running retrieval evaluation")?;
let report = report::write_reports(
&summary,
parsed.config.report_dir.as_path(),
parsed.config.summary_sample,
)
.with_context(|| format!("writing reports to {}", parsed.config.report_dir.display()))?;
let perf_mirrors = perf::mirror_perf_outputs(
&report.record,
&summary,
parsed.config.report_dir.as_path(),
parsed.config.perf_log_json.as_deref(),
parsed.config.perf_log_dir.as_deref(),
)
.with_context(|| {
format!(
"writing perf mirrors under {}",
parsed.config.report_dir.display()
)
})?;
let perf_note = if perf_mirrors.is_empty() {
String::new()
} else {
format!(
" | Perf mirrors: {}",
perf_mirrors
.iter()
.map(|path| path.display().to_string())
.collect::<Vec<_>>()
.join(", ")
)
};
if summary.llm_cases > 0 {
println!(
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) + LLM: {llm_answered}/{llm_total} ({llm_precision:.3}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
retrieval_total = summary.retrieval_cases,
llm_answered = summary.llm_answered,
llm_total = summary.llm_cases,
llm_precision = summary.llm_precision,
json = report.paths.json.display(),
md = report.paths.markdown.display(),
history = report.history_path.display(),
perf_note = perf_note,
);
} else {
println!(
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
retrieval_total = summary.retrieval_cases,
json = report.paths.json.display(),
md = report.paths.markdown.display(),
history = report.history_path.display(),
perf_note = perf_note,
);
}
if parsed.config.perf_log_console {
perf::print_console_summary(&report.record);
}
Ok(())
}

View File

@@ -0,0 +1,224 @@
//! Database namespace management utilities.
use anyhow::{anyhow, Context, Result};
use chrono::Utc;
use common::storage::{db::SurrealDbClient, types::user::User, types::StoredObject};
use serde::Deserialize;
use tracing::{info, warn};
use crate::{
args::Config,
datasets,
snapshot::{self, DbSnapshotState},
};
/// Connect to the evaluation database with fallback auth strategies.
pub(crate) async fn connect_eval_db(
config: &Config,
namespace: &str,
database: &str,
) -> Result<SurrealDbClient> {
match SurrealDbClient::new(
&config.database.db_endpoint,
&config.database.db_username,
&config.database.db_password,
namespace,
database,
)
.await
{
Ok(client) => {
info!(
endpoint = %config.database.db_endpoint,
namespace,
database,
auth = "root",
"Connected to SurrealDB"
);
Ok(client)
}
Err(root_err) => {
info!(
endpoint = %config.database.db_endpoint,
namespace,
database,
"Root authentication failed; trying namespace-level auth"
);
let namespace_client = SurrealDbClient::new_with_namespace_user(
&config.database.db_endpoint,
namespace,
&config.database.db_username,
&config.database.db_password,
database,
)
.await
.map_err(|ns_err| {
anyhow!(
"failed to connect to SurrealDB via root ({root_err}) or namespace ({ns_err}) credentials"
)
})?;
info!(
endpoint = %config.database.db_endpoint,
namespace,
database,
auth = "namespace",
"Connected to SurrealDB"
);
Ok(namespace_client)
}
}
}
/// Check if the namespace contains any corpus data.
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
#[derive(Deserialize)]
struct CountRow {
count: i64,
}
let mut response = db
.client
.query("SELECT count() AS count FROM text_chunk")
.await
.context("checking namespace corpus state")?;
let rows: Vec<CountRow> = response.take(0).unwrap_or_default();
Ok(rows.first().map(|row| row.count).unwrap_or(0) > 0)
}
/// Determine if we can reuse an existing namespace based on cached state.
pub(crate) async fn can_reuse_namespace(
db: &SurrealDbClient,
descriptor: &snapshot::Descriptor,
namespace: &str,
database: &str,
dataset_id: &str,
slice_id: &str,
ingestion_fingerprint: &str,
slice_case_count: usize,
) -> Result<bool> {
let state = match descriptor.load_db_state().await? {
Some(state) => state,
None => {
info!("No namespace state recorded; reseeding corpus from cached shards");
return Ok(false);
}
};
if state.slice_case_count != slice_case_count {
info!(
requested_cases = slice_case_count,
stored_cases = state.slice_case_count,
"Skipping live namespace reuse; cached state does not match requested window"
);
return Ok(false);
}
if state.dataset_id != dataset_id
|| state.slice_id != slice_id
|| state.ingestion_fingerprint != ingestion_fingerprint
|| state.namespace.as_deref() != Some(namespace)
|| state.database.as_deref() != Some(database)
{
info!(
namespace,
database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
);
return Ok(false);
}
if namespace_has_corpus(db).await? {
Ok(true)
} else {
info!(
namespace,
database,
"Namespace metadata matches but tables are empty; reseeding from ingestion cache"
);
Ok(false)
}
}
/// Record the current namespace state to allow future reuse checks.
pub(crate) async fn record_namespace_state(
descriptor: &snapshot::Descriptor,
dataset_id: &str,
slice_id: &str,
ingestion_fingerprint: &str,
namespace: &str,
database: &str,
slice_case_count: usize,
) {
let state = DbSnapshotState {
dataset_id: dataset_id.to_string(),
slice_id: slice_id.to_string(),
ingestion_fingerprint: ingestion_fingerprint.to_string(),
snapshot_hash: descriptor.metadata_hash().to_string(),
updated_at: Utc::now(),
namespace: Some(namespace.to_string()),
database: Some(database.to_string()),
slice_case_count,
};
if let Err(err) = descriptor.store_db_state(&state).await {
warn!(error = %err, "Failed to record namespace state");
}
}
fn sanitize_identifier(input: &str) -> String {
let mut cleaned: String = input
.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() {
ch.to_ascii_lowercase()
} else {
'_'
}
})
.collect();
if cleaned.is_empty() {
cleaned.push('x');
}
if cleaned.len() > 64 {
cleaned.truncate(64);
}
cleaned
}
/// Generate a default namespace name based on dataset and limit.
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
let dataset_component = sanitize_identifier(dataset_id);
let limit_component = match limit {
Some(value) if value > 0 => format!("limit{}", value),
_ => "all".to_string(),
};
format!("eval_{}_{}", dataset_component, limit_component)
}
/// Generate the default database name for evaluations.
pub(crate) fn default_database() -> String {
"retrieval_eval".to_string()
}
/// Ensure the evaluation user exists in the database.
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
let timestamp = datasets::base_timestamp();
let user = User {
id: "eval-user".to_string(),
created_at: timestamp,
updated_at: timestamp,
email: "eval-retrieval@minne.dev".to_string(),
password: "not-used".to_string(),
anonymous: false,
api_key: None,
admin: false,
timezone: "UTC".to_string(),
};
if let Some(existing) = db.get_item::<User>(&user.get_id()).await? {
return Ok(existing);
}
db.store_item(user.clone())
.await
.context("storing evaluation user")?;
Ok(user)
}

16
evaluations/src/openai.rs Normal file
View File

@@ -0,0 +1,16 @@
use anyhow::{Context, Result};
use async_openai::{config::OpenAIConfig, Client};
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
let api_key = std::env::var("OPENAI_API_KEY")
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
let base_url =
std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
let config = OpenAIConfig::new()
.with_api_key(api_key)
.with_api_base(&base_url);
Ok((Client::with_config(config), base_url))
}

248
evaluations/src/perf.rs Normal file
View File

@@ -0,0 +1,248 @@
use std::{
fs,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use crate::{
args,
eval::EvaluationSummary,
report::{self, EvaluationReport},
};
pub fn mirror_perf_outputs(
record: &EvaluationReport,
summary: &EvaluationSummary,
report_root: &Path,
extra_json: Option<&Path>,
extra_dir: Option<&Path>,
) -> Result<Vec<PathBuf>> {
let mut written = Vec::new();
if let Some(path) = extra_json {
args::ensure_parent(path)?;
let blob = serde_json::to_vec_pretty(record).context("serialising perf log JSON")?;
fs::write(path, blob)
.with_context(|| format!("writing perf log copy to {}", path.display()))?;
written.push(path.to_path_buf());
}
if let Some(dir) = extra_dir {
fs::create_dir_all(dir)
.with_context(|| format!("creating perf log directory {}", dir.display()))?;
let dataset_dir = report::dataset_report_dir(report_root, &summary.dataset_id);
let dataset_slug = dataset_dir
.file_name()
.and_then(|os| os.to_str())
.unwrap_or("dataset");
let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S").to_string();
let filename = format!("perf-{}-{}.json", dataset_slug, timestamp);
let path = dir.join(filename);
let blob = serde_json::to_vec_pretty(record).context("serialising perf log JSON")?;
fs::write(&path, blob)
.with_context(|| format!("writing perf log mirror {}", path.display()))?;
written.push(path);
}
Ok(written)
}
pub fn print_console_summary(record: &EvaluationReport) {
let perf = &record.performance;
println!(
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.strategy,
record.retrieval.concurrency,
record.retrieval.rerank_enabled,
record.retrieval.rerank_pool_size,
record.retrieval.rerank_keep_top
);
println!(
"[perf] ingestion={}ms | namespace_seed={}",
perf.ingestion_ms,
format_duration(perf.namespace_seed_ms),
);
let stage = &perf.stage_latency;
println!(
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
stage.embed.avg,
stage.collect_candidates.avg,
stage.graph_expansion.avg,
stage.chunk_attach.avg,
stage.rerank.avg,
stage.assemble.avg,
);
let eval = &perf.evaluation_stages_ms;
println!(
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
eval.prepare_slice_ms,
eval.prepare_db_ms,
eval.prepare_corpus_ms,
eval.prepare_namespace_ms,
eval.run_queries_ms,
eval.summarize_ms,
eval.finalize_ms,
);
}
fn format_duration(value: Option<u128>) -> String {
value
.map(|ms| format!("{ms}ms"))
.unwrap_or_else(|| "-".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::eval::{EvaluationStageTimings, PerformanceTimings};
use chrono::Utc;
use tempfile::tempdir;
fn sample_latency() -> crate::eval::LatencyStats {
crate::eval::LatencyStats {
avg: 10.0,
p50: 8,
p95: 15,
}
}
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
crate::eval::StageLatencyBreakdown {
embed: sample_latency(),
collect_candidates: sample_latency(),
graph_expansion: sample_latency(),
chunk_attach: sample_latency(),
rerank: sample_latency(),
assemble: sample_latency(),
}
}
fn sample_eval_stage() -> EvaluationStageTimings {
EvaluationStageTimings {
prepare_slice_ms: 10,
prepare_db_ms: 20,
prepare_corpus_ms: 30,
prepare_namespace_ms: 40,
run_queries_ms: 50,
summarize_ms: 60,
finalize_ms: 70,
}
}
fn sample_summary() -> EvaluationSummary {
EvaluationSummary {
generated_at: Utc::now(),
k: 5,
limit: Some(10),
run_label: Some("test".into()),
total_cases: 2,
correct: 1,
precision: 0.5,
correct_at_1: 1,
correct_at_2: 1,
correct_at_3: 1,
precision_at_1: 0.5,
precision_at_2: 0.5,
precision_at_3: 0.5,
mrr: 0.0,
average_ndcg: 0.0,
duration_ms: 1234,
dataset_id: "squad-v2".into(),
dataset_label: "SQuAD v2".into(),
dataset_includes_unanswerable: false,
dataset_source: "dev".into(),
includes_impossible_cases: false,
require_verified_chunks: true,
filtered_questions: 0,
retrieval_cases: 2,
retrieval_correct: 1,
retrieval_precision: 0.5,
llm_cases: 0,
llm_answered: 0,
llm_precision: 0.0,
slice_id: "slice123".into(),
slice_seed: 42,
slice_total_cases: 400,
slice_window_offset: 0,
slice_window_length: 10,
slice_cases: 10,
slice_positive_paragraphs: 10,
slice_negative_paragraphs: 40,
slice_total_paragraphs: 50,
slice_negative_multiplier: 4.0,
namespace_reused: true,
corpus_paragraphs: 50,
ingestion_cache_path: "/tmp/cache".into(),
ingestion_reused: true,
ingestion_embeddings_reused: true,
ingestion_fingerprint: "fingerprint".into(),
positive_paragraphs_reused: 10,
negative_paragraphs_reused: 40,
latency_ms: sample_latency(),
perf: PerformanceTimings {
openai_base_url: "https://example.com".into(),
ingestion_ms: 1000,
namespace_seed_ms: Some(150),
evaluation_stage_ms: sample_eval_stage(),
stage_latency: sample_stage_latency(),
},
embedding_backend: "fastembed".into(),
embedding_model: Some("test-model".into()),
embedding_dimension: 32,
rerank_enabled: true,
rerank_pool_size: Some(4),
rerank_keep_top: 10,
concurrency: 2,
detailed_report: false,
retrieval_strategy: "initial".into(),
chunk_result_cap: 5,
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
chunk_rrf_fts_weight: 1.0,
chunk_rrf_use_vector: true,
chunk_rrf_use_fts: true,
ingest_chunk_min_tokens: 256,
ingest_chunk_max_tokens: 512,
ingest_chunks_only: false,
ingest_chunk_overlap_tokens: 50,
chunk_vector_take: 20,
chunk_fts_take: 20,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4,
cases: Vec::new(),
}
}
#[test]
fn writes_perf_mirrors_from_record() {
let tmp = tempdir().unwrap();
let report_root = tmp.path().join("reports");
let summary = sample_summary();
let record = report::EvaluationReport::from_summary(&summary, 5);
let json_path = tmp.path().join("extra.json");
let dir_path = tmp.path().join("copies");
let outputs = mirror_perf_outputs(
&record,
&summary,
&report_root,
Some(json_path.as_path()),
Some(dir_path.as_path()),
)
.expect("perf mirrors");
assert!(json_path.exists());
let content = std::fs::read_to_string(&json_path).expect("reading mirror json");
assert!(
content.contains("\"evaluation_stages_ms\""),
"perf mirror should include evaluation stage timings"
);
assert_eq!(outputs.len(), 2);
let mirrored = outputs
.into_iter()
.filter(|path| path.starts_with(&dir_path))
.collect::<Vec<_>>();
assert_eq!(mirrored.len(), 1, "expected timestamped mirror in dir");
}
}

View File

@@ -0,0 +1,197 @@
use std::{
path::PathBuf,
sync::Arc,
time::{Duration, Instant},
};
use async_openai::Client;
use common::{
storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
},
utils::embedding::EmbeddingProvider,
};
use retrieval_pipeline::{
pipeline::{PipelineStageTimings, RetrievalConfig},
reranking::RerankerPool,
};
use crate::{
args::Config,
cache::EmbeddingCache,
datasets::ConvertedDataset,
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
corpus, slice, snapshot,
};
pub(super) struct EvaluationContext<'a> {
dataset: &'a ConvertedDataset,
config: &'a Config,
pub stage_timings: EvaluationStageTimings,
pub ledger_limit: Option<usize>,
pub slice_settings: Option<slice::SliceConfig<'a>>,
pub slice: Option<slice::ResolvedSlice<'a>>,
pub window_offset: usize,
pub window_length: usize,
pub window_total_cases: usize,
pub namespace: String,
pub database: String,
pub db: Option<SurrealDbClient>,
pub descriptor: Option<snapshot::Descriptor>,
pub settings: Option<SystemSettings>,
pub settings_missing: bool,
pub must_reapply_settings: bool,
pub embedding_provider: Option<EmbeddingProvider>,
pub embedding_cache: Option<EmbeddingCache>,
pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
pub openai_base_url: Option<String>,
pub expected_fingerprint: Option<String>,
pub ingestion_duration_ms: u128,
pub namespace_seed_ms: Option<u128>,
pub namespace_reused: bool,
pub evaluation_start: Option<Instant>,
pub eval_user: Option<User>,
pub corpus_handle: Option<corpus::CorpusHandle>,
pub cases: Vec<SeededCase>,
pub filtered_questions: usize,
pub stage_latency_samples: Vec<PipelineStageTimings>,
pub latencies: Vec<u128>,
pub diagnostics_output: Vec<CaseDiagnostics>,
pub query_summaries: Vec<CaseSummary>,
pub rerank_pool: Option<Arc<RerankerPool>>,
pub retrieval_config: Option<Arc<RetrievalConfig>>,
pub summary: Option<EvaluationSummary>,
pub diagnostics_path: Option<PathBuf>,
pub diagnostics_enabled: bool,
}
impl<'a> EvaluationContext<'a> {
pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
Self {
dataset,
config,
stage_timings: EvaluationStageTimings::default(),
ledger_limit: None,
slice_settings: None,
slice: None,
window_offset: 0,
window_length: 0,
window_total_cases: 0,
namespace: String::new(),
database: String::new(),
db: None,
descriptor: None,
settings: None,
settings_missing: false,
must_reapply_settings: false,
embedding_provider: None,
embedding_cache: None,
openai_client: None,
openai_base_url: None,
expected_fingerprint: None,
ingestion_duration_ms: 0,
namespace_seed_ms: None,
namespace_reused: false,
evaluation_start: None,
eval_user: None,
corpus_handle: None,
cases: Vec::new(),
filtered_questions: 0,
stage_latency_samples: Vec::new(),
latencies: Vec::new(),
diagnostics_output: Vec::new(),
query_summaries: Vec::new(),
rerank_pool: None,
retrieval_config: None,
summary: None,
diagnostics_path: config.chunk_diagnostics_path.clone(),
diagnostics_enabled: config.chunk_diagnostics_path.is_some(),
}
}
pub fn dataset(&self) -> &'a ConvertedDataset {
self.dataset
}
pub fn config(&self) -> &'a Config {
self.config
}
pub fn slice(&self) -> &slice::ResolvedSlice<'a> {
self.slice.as_ref().expect("slice has not been prepared")
}
pub fn db(&self) -> &SurrealDbClient {
self.db.as_ref().expect("database connection missing")
}
pub fn descriptor(&self) -> &snapshot::Descriptor {
self.descriptor
.as_ref()
.expect("snapshot descriptor unavailable")
}
pub fn embedding_provider(&self) -> &EmbeddingProvider {
self.embedding_provider
.as_ref()
.expect("embedding provider not initialised")
}
pub fn openai_client(&self) -> Arc<Client<async_openai::config::OpenAIConfig>> {
self.openai_client
.as_ref()
.expect("openai client missing")
.clone()
}
pub fn corpus_handle(&self) -> &corpus::CorpusHandle {
self.corpus_handle.as_ref().expect("corpus handle missing")
}
pub fn evaluation_user(&self) -> &User {
self.eval_user.as_ref().expect("evaluation user missing")
}
pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) {
let elapsed = duration.as_millis() as u128;
match stage {
EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed,
EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed,
EvalStage::PrepareCorpus => self.stage_timings.prepare_corpus_ms += elapsed,
EvalStage::PrepareNamespace => self.stage_timings.prepare_namespace_ms += elapsed,
EvalStage::RunQueries => self.stage_timings.run_queries_ms += elapsed,
EvalStage::Summarize => self.stage_timings.summarize_ms += elapsed,
EvalStage::Finalize => self.stage_timings.finalize_ms += elapsed,
}
}
pub fn into_summary(self) -> EvaluationSummary {
self.summary.expect("evaluation summary missing")
}
}
#[derive(Copy, Clone)]
pub(super) enum EvalStage {
PrepareSlice,
PrepareDb,
PrepareCorpus,
PrepareNamespace,
RunQueries,
Summarize,
Finalize,
}
impl EvalStage {
pub fn label(&self) -> &'static str {
match self {
EvalStage::PrepareSlice => "prepare-slice",
EvalStage::PrepareDb => "prepare-db",
EvalStage::PrepareCorpus => "prepare-corpus",
EvalStage::PrepareNamespace => "prepare-namespace",
EvalStage::RunQueries => "run-queries",
EvalStage::Summarize => "summarize",
EvalStage::Finalize => "finalize",
}
}
}

View File

@@ -0,0 +1,29 @@
mod context;
mod stages;
mod state;
use anyhow::Result;
use crate::{args::Config, datasets::ConvertedDataset, types::EvaluationSummary};
use context::EvaluationContext;
pub async fn run_evaluation(
dataset: &ConvertedDataset,
config: &Config,
) -> Result<EvaluationSummary> {
let mut ctx = EvaluationContext::new(dataset, config);
let machine = state::ready();
let machine = stages::prepare_slice(machine, &mut ctx).await?;
let machine = stages::prepare_db(machine, &mut ctx).await?;
let machine = stages::prepare_corpus(machine, &mut ctx).await?;
let machine = stages::prepare_namespace(machine, &mut ctx).await?;
let machine = stages::run_queries(machine, &mut ctx).await?;
let machine = stages::summarize(machine, &mut ctx).await?;
let machine = stages::finalize(machine, &mut ctx).await?;
drop(machine);
Ok(ctx.into_summary())
}

View File

@@ -0,0 +1,59 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::eval::write_chunk_diagnostics;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{Completed, EvaluationMachine, Summarized},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn finalize(
machine: EvaluationMachine<(), Summarized>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<Completed> {
let stage = EvalStage::Finalize;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
if let Some(cache) = ctx.embedding_cache.as_ref() {
cache
.persist()
.await
.context("persisting embedding cache")?;
}
if let Some(path) = ctx.diagnostics_path.as_ref() {
if ctx.diagnostics_enabled {
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
.await
.with_context(|| format!("writing chunk diagnostics to {}", path.display()))?;
}
}
info!(
total_cases = ctx.summary.as_ref().map(|s| s.total_cases).unwrap_or(0),
correct = ctx.summary.as_ref().map(|s| s.correct).unwrap_or(0),
precision = ctx.summary.as_ref().map(|s| s.precision).unwrap_or(0.0),
dataset = ctx.dataset().metadata.id.as_str(),
"Evaluation complete"
);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.finalize()
.map_err(|(_, guard)| map_guard_error("finalize", guard))
}

View File

@@ -0,0 +1,26 @@
mod finalize;
mod prepare_corpus;
mod prepare_db;
mod prepare_namespace;
mod prepare_slice;
mod run_queries;
mod summarize;
pub(crate) use finalize::finalize;
pub(crate) use prepare_corpus::prepare_corpus;
pub(crate) use prepare_db::prepare_db;
pub(crate) use prepare_namespace::prepare_namespace;
pub(crate) use prepare_slice::prepare_slice;
pub(crate) use run_queries::run_queries;
pub(crate) use summarize::summarize;
use anyhow::Result;
use state_machines::core::GuardError;
use super::state::EvaluationMachine;
fn map_guard_error(event: &str, guard: GuardError) -> anyhow::Error {
anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
}
type StageResult<S> = Result<EvaluationMachine<(), S>>;

View File

@@ -0,0 +1,142 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{eval::can_reuse_namespace, corpus, slice, snapshot};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{CorpusReady, DbReady, EvaluationMachine},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn prepare_corpus(
machine: EvaluationMachine<(), DbReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<CorpusReady> {
let stage = EvalStage::PrepareCorpus;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let cache_settings = corpus::CorpusCacheConfig::from(config);
let embedding_provider = ctx.embedding_provider().clone();
let openai_client = ctx.openai_client();
let slice = ctx.slice();
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
.context("selecting slice window for corpus preparation")?;
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider());
let ingestion_config = corpus::make_ingestion_config(config);
let expected_fingerprint = corpus::compute_ingestion_fingerprint(
ctx.dataset(),
slice,
config.converted_dataset_path.as_path(),
&ingestion_config,
)?;
let base_dir = corpus::cached_corpus_dir(
&cache_settings,
ctx.dataset().metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
);
if !config.reseed_slice {
let requested_cases = window.cases.len();
if can_reuse_namespace(
ctx.db(),
&descriptor,
&ctx.namespace,
&ctx.database,
ctx.dataset().metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
requested_cases,
)
.await?
{
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
info!(
cache = %base_dir.display(),
namespace = ctx.namespace.as_str(),
database = ctx.database.as_str(),
"Namespace already seeded; reusing cached corpus manifest"
);
let corpus_handle = corpus::corpus_handle_from_manifest(manifest, base_dir);
ctx.corpus_handle = Some(corpus_handle);
ctx.expected_fingerprint = Some(expected_fingerprint);
ctx.ingestion_duration_ms = 0;
ctx.descriptor = Some(descriptor);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
return machine
.prepare_corpus()
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard));
} else {
info!(
cache = %base_dir.display(),
"Namespace reusable but cached manifest missing; regenerating corpus"
);
}
}
}
let eval_user_id = "eval-user".to_string();
let ingestion_timer = Instant::now();
let corpus_handle = {
corpus::ensure_corpus(
ctx.dataset(),
slice,
&window,
&cache_settings,
embedding_provider.clone().into(),
openai_client,
&eval_user_id,
config.converted_dataset_path.as_path(),
ingestion_config.clone(),
)
.await
.context("ensuring ingestion-backed corpus")?
};
let expected_fingerprint = corpus_handle
.manifest
.metadata
.ingestion_fingerprint
.clone();
let ingestion_duration_ms = ingestion_timer.elapsed().as_millis() as u128;
info!(
cache = %corpus_handle.path.display(),
reused_ingestion = corpus_handle.reused_ingestion,
reused_embeddings = corpus_handle.reused_embeddings,
positive_ingested = corpus_handle.positive_ingested,
negative_ingested = corpus_handle.negative_ingested,
"Ingestion corpus ready"
);
ctx.corpus_handle = Some(corpus_handle);
ctx.expected_fingerprint = Some(expected_fingerprint);
ctx.ingestion_duration_ms = ingestion_duration_ms;
ctx.descriptor = Some(descriptor);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_corpus()
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard))
}

View File

@@ -0,0 +1,121 @@
use std::{sync::Arc, time::Instant};
use anyhow::{anyhow, Context};
use tracing::info;
use crate::{
args::EmbeddingBackend,
cache::EmbeddingCache,
eval::{
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
},
openai,
};
use common::utils::embedding::EmbeddingProvider;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{DbReady, EvaluationMachine, SlicePrepared},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn prepare_db(
machine: EvaluationMachine<(), SlicePrepared>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<DbReady> {
let stage = EvalStage::PrepareDb;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let namespace = ctx.namespace.clone();
let database = ctx.database.clone();
let config = ctx.config();
let db = connect_eval_db(config, &namespace, &database).await?;
let (raw_openai_client, openai_base_url) =
openai::build_client_from_env().context("building OpenAI client")?;
let openai_client = Arc::new(raw_openai_client);
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
let embedding_provider = match config.embedding_backend {
crate::args::EmbeddingBackend::FastEmbed => {
EmbeddingProvider::new_fastembed(config.embedding_model.clone())
.await
.context("creating FastEmbed provider")?
}
crate::args::EmbeddingBackend::Hashed => {
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
}
};
let provider_dimension = embedding_provider.dimension();
if provider_dimension == 0 {
return Err(anyhow!(
"embedding provider reported zero dimensions; cannot continue"
));
}
info!(
backend = embedding_provider.backend_label(),
model = embedding_provider
.model_code()
.as_deref()
.unwrap_or("<none>"),
dimension = provider_dimension,
"Embedding provider initialised"
);
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
let (mut settings, settings_missing) =
load_or_init_system_settings(&db, provider_dimension).await?;
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
if let Some(model_code) = embedding_provider.model_code() {
let sanitized = sanitize_model_code(&model_code);
let path = config.cache_dir.join(format!("{sanitized}.json"));
if config.force_convert && path.exists() {
tokio::fs::remove_file(&path)
.await
.with_context(|| format!("removing stale cache {}", path.display()))
.ok();
}
let cache = EmbeddingCache::load(&path).await?;
info!(path = %path.display(), "Embedding cache ready");
Some(cache)
} else {
None
}
} else {
None
};
let must_reapply_settings = settings_missing;
let defer_initial_enforce = settings_missing && !config.reseed_slice;
if !defer_initial_enforce {
settings = enforce_system_settings(&db, settings, provider_dimension, config).await?;
}
ctx.db = Some(db);
ctx.settings_missing = settings_missing;
ctx.must_reapply_settings = must_reapply_settings;
ctx.settings = Some(settings);
ctx.embedding_provider = Some(embedding_provider);
ctx.embedding_cache = embedding_cache;
ctx.openai_client = Some(openai_client);
ctx.openai_base_url = Some(openai_base_url);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_db()
.map_err(|(_, guard)| map_guard_error("prepare_db", guard))
}

View File

@@ -0,0 +1,203 @@
use std::time::Instant;
use anyhow::{anyhow, Context};
use common::storage::types::system_settings::SystemSettings;
use tracing::{info, warn};
use crate::{
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
eval::{
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
record_namespace_state, warm_hnsw_cache,
},
corpus,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{CorpusReady, EvaluationMachine, NamespaceReady},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn prepare_namespace(
machine: EvaluationMachine<(), CorpusReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<NamespaceReady> {
let stage = EvalStage::PrepareNamespace;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let dataset = ctx.dataset();
let expected_fingerprint = ctx
.expected_fingerprint
.as_deref()
.unwrap_or_default()
.to_string();
let namespace = ctx.namespace.clone();
let database = ctx.database.clone();
let embedding_provider = ctx.embedding_provider().clone();
let corpus_handle = ctx.corpus_handle();
let base_manifest = &corpus_handle.manifest;
let manifest_for_seed =
if ctx.window_offset == 0 && ctx.window_length >= base_manifest.questions.len() {
base_manifest.clone()
} else {
corpus::window_manifest(
base_manifest,
ctx.window_offset,
ctx.window_length,
ctx.config().negative_multiplier,
)
.context("selecting manifest window for seeding")?
};
let requested_cases = manifest_for_seed.questions.len();
let mut namespace_reused = false;
if !config.reseed_slice {
namespace_reused = {
let slice = ctx.slice();
can_reuse_namespace(
ctx.db(),
ctx.descriptor(),
&namespace,
&database,
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
requested_cases,
)
.await?
};
}
let mut namespace_seed_ms = None;
if !namespace_reused {
ctx.must_reapply_settings = true;
if let Err(err) = reset_namespace(ctx.db(), &namespace, &database).await {
warn!(
error = %err,
namespace,
database = %database,
"Failed to reset namespace before reseeding; continuing with existing data"
);
} else if let Err(err) = ctx.db().apply_migrations().await {
warn!(error = %err, "Failed to reapply migrations after namespace reset");
}
{
let slice = ctx.slice();
info!(
slice = slice.manifest.slice_id.as_str(),
window_offset = ctx.window_offset,
window_length = ctx.window_length,
positives = manifest_for_seed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect::<std::collections::HashSet<_>>()
.len(),
negatives = manifest_for_seed.paragraphs.len().saturating_sub(
manifest_for_seed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect::<std::collections::HashSet<_>>()
.len(),
),
total = manifest_for_seed.paragraphs.len(),
"Seeding ingestion corpus into SurrealDB"
);
}
let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
let seed_start = Instant::now();
corpus::seed_manifest_into_db(ctx.db(), &manifest_for_seed)
.await
.context("seeding ingestion corpus from manifest")?;
namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
if indexes_disabled {
info!("Recreating indexes after seeding data");
recreate_indexes(ctx.db(), embedding_provider.dimension())
.await
.context("recreating indexes with correct dimension")?;
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
}
{
let slice = ctx.slice();
record_namespace_state(
ctx.descriptor(),
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
&namespace,
&database,
requested_cases,
)
.await;
}
}
if ctx.must_reapply_settings {
let mut settings = SystemSettings::get_current(ctx.db())
.await
.context("reloading system settings after namespace reset")?;
settings =
enforce_system_settings(ctx.db(), settings, embedding_provider.dimension(), config)
.await?;
ctx.settings = Some(settings);
ctx.must_reapply_settings = false;
}
let user = ensure_eval_user(ctx.db()).await?;
ctx.eval_user = Some(user);
let total_manifest_questions = manifest_for_seed.questions.len();
let cases = cases_from_manifest(&manifest_for_seed);
let include_impossible = manifest_for_seed.metadata.include_unanswerable;
let require_verified_chunks = manifest_for_seed.metadata.require_verified_chunks;
let filtered = total_manifest_questions.saturating_sub(cases.len());
if filtered > 0 {
info!(
filtered_questions = filtered,
total_questions = total_manifest_questions,
includes_impossible = include_impossible,
require_verified_chunks = require_verified_chunks,
"Filtered questions not eligible for this evaluation mode (impossible or unverifiable)"
);
}
if cases.is_empty() {
return Err(anyhow!(
"no eligible questions found in converted dataset for evaluation (consider --llm-mode or refreshing ingestion data)"
));
}
ctx.cases = cases;
ctx.filtered_questions = filtered;
ctx.namespace_reused = namespace_reused;
ctx.namespace_seed_ms = namespace_seed_ms;
info!(
cases = ctx.cases.len(),
window_offset = ctx.window_offset,
namespace_reused = namespace_reused,
"Dataset ready"
);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_namespace()
.map_err(|(_, guard)| map_guard_error("prepare_namespace", guard))
}

View File

@@ -0,0 +1,70 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{
eval::{default_database, default_namespace, ledger_target},
slice,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, Ready, SlicePrepared},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn prepare_slice(
machine: EvaluationMachine<(), Ready>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<SlicePrepared> {
let stage = EvalStage::PrepareSlice;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let ledger_limit = ledger_target(ctx.config());
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
let resolved_slice =
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
let window = slice::select_window(
&resolved_slice,
ctx.config().slice_offset,
ctx.config().limit,
)
.context("selecting slice window (use --slice-grow to extend the ledger first)")?;
ctx.ledger_limit = ledger_limit;
ctx.slice_settings = Some(slice_settings);
ctx.slice = Some(resolved_slice.clone());
ctx.window_offset = window.offset;
ctx.window_length = window.length;
ctx.window_total_cases = window.total_cases;
ctx.namespace = ctx
.config()
.database
.db_namespace
.clone()
.unwrap_or_else(|| default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit));
ctx.database = ctx
.config()
.database
.db_database
.clone()
.unwrap_or_else(default_database);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_slice()
.map_err(|(_, guard)| map_guard_error("prepare_slice", guard))
}

View File

@@ -0,0 +1,420 @@
use std::{collections::HashSet, sync::Arc, time::Instant};
use anyhow::Context;
use common::storage::types::StoredObject;
use futures::stream::{self, StreamExt};
use tracing::{debug, info};
use crate::eval::{
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
CaseSummary, RetrievedSummary,
};
use retrieval_pipeline::{
pipeline::{self, PipelineStageTimings, RetrievalConfig},
reranking::RerankerPool,
};
use tokio::sync::Semaphore;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, NamespaceReady, QueriesFinished},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn run_queries(
machine: EvaluationMachine<(), NamespaceReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<QueriesFinished> {
let stage = EvalStage::RunQueries;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let dataset = ctx.dataset();
let slice_settings = ctx
.slice_settings
.as_ref()
.expect("slice settings missing during query stage");
let total_cases = ctx.cases.len();
let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate();
let rerank_pool = if config.retrieval.rerank {
Some(
RerankerPool::new(config.retrieval.rerank_pool_size)
.context("initialising reranker pool")?,
)
} else {
None
};
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.strategy = config.retrieval.strategy;
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
}
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
if let Some(value) = config.retrieval.chunk_vector_take {
retrieval_config.tuning.chunk_vector_take = value;
}
if let Some(value) = config.retrieval.chunk_fts_take {
retrieval_config.tuning.chunk_fts_take = value;
}
if let Some(value) = config.retrieval.chunk_rrf_k {
retrieval_config.tuning.chunk_rrf_k = value;
}
if let Some(value) = config.retrieval.chunk_rrf_vector_weight {
retrieval_config.tuning.chunk_rrf_vector_weight = value;
}
if let Some(value) = config.retrieval.chunk_rrf_fts_weight {
retrieval_config.tuning.chunk_rrf_fts_weight = value;
}
if let Some(value) = config.retrieval.chunk_rrf_use_vector {
retrieval_config.tuning.chunk_rrf_use_vector = value;
}
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
retrieval_config.tuning.chunk_rrf_use_fts = value;
}
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value;
}
if let Some(value) = config.retrieval.max_chunks_per_entity {
retrieval_config.tuning.max_chunks_per_entity = value;
}
let active_tuning = retrieval_config.tuning.clone();
let effective_chunk_vector = config
.retrieval
.chunk_vector_take
.unwrap_or(active_tuning.chunk_vector_take);
let effective_chunk_fts = config
.retrieval
.chunk_fts_take
.unwrap_or(active_tuning.chunk_fts_take);
info!(
dataset = dataset.metadata.id.as_str(),
slice_seed = config.slice_seed,
slice_offset = config.slice_offset,
slice_limit = config
.limit
.unwrap_or(ctx.window_total_cases),
negative_multiplier = %slice_settings.negative_multiplier,
rerank_enabled = config.retrieval.rerank,
rerank_pool_size = config.retrieval.rerank_pool_size,
rerank_keep_top = config.retrieval.rerank_keep_top,
chunk_vector_take = effective_chunk_vector,
chunk_fts_take = effective_chunk_fts,
chunk_rrf_k = active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight,
chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight,
chunk_rrf_use_vector = active_tuning.chunk_rrf_use_vector,
chunk_rrf_use_fts = active_tuning.chunk_rrf_use_fts,
embedding_backend = ctx.embedding_provider().backend_label(),
embedding_model = ctx
.embedding_provider()
.model_code()
.as_deref()
.unwrap_or("<default>"),
"Starting evaluation run"
);
let retrieval_config = Arc::new(retrieval_config);
ctx.rerank_pool = rerank_pool.clone();
ctx.retrieval_config = Some(retrieval_config.clone());
ctx.evaluation_start = Some(Instant::now());
let user_id = ctx.evaluation_user().id.clone();
let concurrency = config.concurrency.max(1);
let diagnostics_enabled = ctx.diagnostics_enabled;
let query_semaphore = Arc::new(Semaphore::new(concurrency));
info!(
total_cases = total_cases,
max_concurrent_queries = concurrency,
"Starting evaluation with staged query execution"
);
let embedding_provider_for_queries = ctx.embedding_provider().clone();
let rerank_pool_for_queries = rerank_pool.clone();
let db = ctx.db().clone();
let openai_client = ctx.openai_client();
let raw_results = stream::iter(cases_iter)
.map(move |(idx, case)| {
let db = db.clone();
let openai_client = openai_client.clone();
let user_id = user_id.clone();
let retrieval_config = retrieval_config.clone();
let embedding_provider = embedding_provider_for_queries.clone();
let rerank_pool = rerank_pool_for_queries.clone();
let semaphore = query_semaphore.clone();
let diagnostics_enabled = diagnostics_enabled;
async move {
let _permit = semaphore
.acquire()
.await
.context("acquiring query semaphore permit")?;
let crate::eval::SeededCase {
question_id,
question,
expected_source,
answers,
paragraph_id,
paragraph_title,
expected_chunk_ids,
is_impossible,
has_verified_chunks,
} = case;
let query_start = Instant::now();
debug!(question_id = %question_id, "Evaluating query");
let query_embedding =
embedding_provider.embed(&question).await.with_context(|| {
format!("generating embedding for question {}", question_id)
})?;
let reranker = match &rerank_pool {
Some(pool) => Some(pool.checkout().await),
None => None,
};
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
&db,
&openai_client,
Some(&embedding_provider),
query_embedding,
&question,
&user_id,
(*retrieval_config).clone(),
reranker,
)
.await
.with_context(|| format!("running pipeline for question {}", question_id))?;
(outcome.results, outcome.diagnostics, outcome.stage_timings)
} else {
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
&db,
&openai_client,
Some(&embedding_provider),
query_embedding,
&question,
&user_id,
(*retrieval_config).clone(),
reranker,
)
.await
.with_context(|| format!("running pipeline for question {}", question_id))?;
(outcome.results, None, outcome.stage_timings)
};
let query_latency = query_start.elapsed().as_millis() as u128;
let candidates = adapt_strategy_output(result_output);
let mut retrieved = Vec::new();
let mut match_rank = None;
let answers_lower: Vec<String> =
answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
let expected_chunk_ids_set: HashSet<&str> =
expected_chunk_ids.iter().map(|id| id.as_str()).collect();
let chunk_id_required = has_verified_chunks;
let mut entity_hit = false;
let mut chunk_text_hit = false;
let mut chunk_id_hit = !chunk_id_required;
for (idx_entity, candidate) in candidates.iter().enumerate() {
if idx_entity >= config.k {
break;
}
let entity_match = candidate.source_id == expected_source;
if entity_match {
entity_hit = true;
}
let chunk_text_for_entity = candidate
.chunks
.iter()
.any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower));
if chunk_text_for_entity {
chunk_text_hit = true;
}
let chunk_id_for_entity = if chunk_id_required {
expected_chunk_ids_set.contains(candidate.source_id.as_str())
|| candidate
.chunks
.iter()
.any(|chunk| expected_chunk_ids_set.contains(&chunk.chunk.get_id()))
} else {
true
};
if chunk_id_for_entity {
chunk_id_hit = true;
}
let success = entity_match && chunk_text_for_entity && chunk_id_for_entity;
if success && match_rank.is_none() {
match_rank = Some(idx_entity + 1);
}
let detail_fields = if config.detailed_report {
let description = candidate.entity_description.clone();
let category = candidate.entity_category.clone();
(
description,
category,
Some(chunk_text_for_entity),
Some(chunk_id_for_entity),
)
} else {
(None, None, None, None)
};
retrieved.push(RetrievedSummary {
rank: idx_entity + 1,
entity_id: candidate.entity_id.clone(),
source_id: candidate.source_id.clone(),
entity_name: candidate.entity_name.clone(),
score: candidate.score,
matched: success,
entity_description: detail_fields.0,
entity_category: detail_fields.1,
chunk_text_match: detail_fields.2,
chunk_id_match: detail_fields.3,
});
}
let overall_match = match_rank.is_some();
let reciprocal_rank = calculate_reciprocal_rank(match_rank);
let ndcg = calculate_ndcg(&retrieved, config.k);
let summary = CaseSummary {
question_id,
question,
paragraph_id,
paragraph_title,
expected_source,
answers,
matched: overall_match,
entity_match: entity_hit,
chunk_text_match: chunk_text_hit,
chunk_id_match: chunk_id_hit,
is_impossible,
has_verified_chunks,
match_rank,
reciprocal_rank: Some(reciprocal_rank),
ndcg: Some(ndcg),
latency_ms: query_latency,
retrieved,
};
let diagnostics = if diagnostics_enabled {
Some(build_case_diagnostics(
&summary,
&expected_chunk_ids,
&answers_lower,
&candidates,
pipeline_diagnostics,
))
} else {
None
};
Ok::<
(
usize,
CaseSummary,
Option<CaseDiagnostics>,
PipelineStageTimings,
),
anyhow::Error,
>((idx, summary, diagnostics, stage_timings))
}
})
.buffer_unordered(concurrency)
.collect::<Vec<_>>()
.await;
let mut results = Vec::with_capacity(raw_results.len());
for result in raw_results {
match result {
Ok(val) => results.push(val),
Err(err) => {
tracing::error!(error = ?err, "Query execution failed");
}
}
}
let mut ordered = results;
ordered.sort_by_key(|(idx, ..)| *idx);
let mut summaries = Vec::with_capacity(ordered.len());
let mut latencies = Vec::with_capacity(ordered.len());
let mut diagnostics_output = Vec::new();
let mut stage_latency_samples = Vec::with_capacity(ordered.len());
for (_, summary, diagnostics, stage_timings) in ordered {
latencies.push(summary.latency_ms);
summaries.push(summary);
if let Some(diag) = diagnostics {
diagnostics_output.push(diag);
}
stage_latency_samples.push(stage_timings);
}
ctx.query_summaries = summaries;
ctx.latencies = latencies;
ctx.diagnostics_output = diagnostics_output;
ctx.stage_latency_samples = stage_latency_samples;
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.run_queries()
.map_err(|(_, guard)| map_guard_error("run_queries", guard))
}
fn calculate_reciprocal_rank(rank: Option<usize>) -> f64 {
match rank {
Some(r) if r > 0 => 1.0 / (r as f64),
_ => 0.0,
}
}
fn calculate_ndcg(retrieved: &[RetrievedSummary], k: usize) -> f64 {
let mut dcg = 0.0;
let mut relevant_count = 0;
for (i, item) in retrieved.iter().enumerate() {
if i >= k {
break;
}
if item.matched {
let rel = 1.0;
dcg += rel / (i as f64 + 2.0).log2();
relevant_count += 1;
}
}
if dcg == 0.0 {
return 0.0;
}
// Calculate IDCG based on the number of relevant items found
// We assume ideal ordering would place all 'relevant_count' items at the top
let mut idcg = 0.0;
for i in 0..relevant_count {
let rel = 1.0;
idcg += rel / (i as f64 + 2.0).log2();
}
if idcg == 0.0 {
0.0
} else {
dcg / idcg
}
}

View File

@@ -0,0 +1,232 @@
use std::time::Instant;
use chrono::Utc;
use tracing::info;
use crate::eval::{
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, QueriesFinished, Summarized},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn summarize(
machine: EvaluationMachine<(), QueriesFinished>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<Summarized> {
let stage = EvalStage::Summarize;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let summaries = std::mem::take(&mut ctx.query_summaries);
let latencies = std::mem::take(&mut ctx.latencies);
let stage_latency_samples = std::mem::take(&mut ctx.stage_latency_samples);
let duration_ms = ctx
.evaluation_start
.take()
.map(|start| start.elapsed().as_millis())
.unwrap_or_default();
let config = ctx.config();
let dataset = ctx.dataset();
let slice = ctx.slice();
let corpus_handle = ctx.corpus_handle();
let total_cases = summaries.len();
let mut correct = 0usize;
let mut correct_at_1 = 0usize;
let mut correct_at_2 = 0usize;
let mut correct_at_3 = 0usize;
let mut retrieval_cases = 0usize;
let mut llm_cases = 0usize;
let mut llm_answered = 0usize;
let mut sum_reciprocal_rank = 0.0;
let mut sum_ndcg = 0.0;
for summary in &summaries {
if summary.is_impossible {
llm_cases += 1;
if summary.matched {
llm_answered += 1;
}
continue;
}
retrieval_cases += 1;
if let Some(rr) = summary.reciprocal_rank {
sum_reciprocal_rank += rr;
}
if let Some(ndcg) = summary.ndcg {
sum_ndcg += ndcg;
}
if summary.matched {
correct += 1;
if let Some(rank) = summary.match_rank {
if rank <= 1 {
correct_at_1 += 1;
}
if rank <= 2 {
correct_at_2 += 1;
}
if rank <= 3 {
correct_at_3 += 1;
}
}
}
}
let latency_stats = compute_latency_stats(&latencies);
let stage_latency = build_stage_latency_breakdown(&stage_latency_samples);
let retrieval_precision = if retrieval_cases == 0 {
0.0
} else {
(correct as f64) / (retrieval_cases as f64)
};
let llm_precision = if llm_cases == 0 {
0.0
} else {
(llm_answered as f64) / (llm_cases as f64)
};
let precision = retrieval_precision;
let precision_at_1 = if retrieval_cases == 0 {
0.0
} else {
(correct_at_1 as f64) / (retrieval_cases as f64)
};
let precision_at_2 = if retrieval_cases == 0 {
0.0
} else {
(correct_at_2 as f64) / (retrieval_cases as f64)
};
let precision_at_3 = if retrieval_cases == 0 {
0.0
} else {
(correct_at_3 as f64) / (retrieval_cases as f64)
};
let mrr = if retrieval_cases == 0 {
0.0
} else {
sum_reciprocal_rank / (retrieval_cases as f64)
};
let average_ndcg = if retrieval_cases == 0 {
0.0
} else {
sum_ndcg / (retrieval_cases as f64)
};
let active_tuning = ctx
.retrieval_config
.as_ref()
.map(|cfg| cfg.tuning.clone())
.unwrap_or_default();
let perf_timings = PerformanceTimings {
openai_base_url: ctx
.openai_base_url
.clone()
.unwrap_or_else(|| "<unknown>".to_string()),
ingestion_ms: ctx.ingestion_duration_ms,
namespace_seed_ms: ctx.namespace_seed_ms,
evaluation_stage_ms: ctx.stage_timings.clone(),
stage_latency,
};
ctx.summary = Some(EvaluationSummary {
generated_at: Utc::now(),
k: config.k,
limit: config.limit,
run_label: config.label.clone(),
total_cases,
correct,
precision,
correct_at_1,
correct_at_2,
correct_at_3,
precision_at_1,
precision_at_2,
precision_at_3,
mrr,
average_ndcg,
duration_ms,
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),
dataset_includes_unanswerable: dataset.metadata.include_unanswerable,
dataset_source: dataset.source.clone(),
includes_impossible_cases: slice.manifest.includes_unanswerable,
require_verified_chunks: slice.manifest.require_verified_chunks,
filtered_questions: ctx.filtered_questions,
retrieval_cases,
retrieval_correct: correct,
retrieval_precision,
llm_cases,
llm_answered,
llm_precision,
slice_id: slice.manifest.slice_id.clone(),
slice_seed: slice.manifest.seed,
slice_total_cases: slice.manifest.case_count,
slice_window_offset: ctx.window_offset,
slice_window_length: ctx.window_length,
slice_cases: total_cases,
slice_positive_paragraphs: slice.manifest.positive_paragraphs,
slice_negative_paragraphs: slice.manifest.negative_paragraphs,
slice_total_paragraphs: slice.manifest.total_paragraphs,
slice_negative_multiplier: slice.manifest.negative_multiplier,
namespace_reused: ctx.namespace_reused,
corpus_paragraphs: ctx.corpus_handle().manifest.metadata.paragraph_count,
ingestion_cache_path: corpus_handle.path.display().to_string(),
ingestion_reused: corpus_handle.reused_ingestion,
ingestion_embeddings_reused: corpus_handle.reused_embeddings,
ingestion_fingerprint: corpus_handle
.manifest
.metadata
.ingestion_fingerprint
.clone(),
positive_paragraphs_reused: corpus_handle.positive_reused,
negative_paragraphs_reused: corpus_handle.negative_reused,
latency_ms: latency_stats,
perf: perf_timings,
embedding_backend: ctx.embedding_provider().backend_label().to_string(),
embedding_model: ctx.embedding_provider().model_code(),
embedding_dimension: ctx.embedding_provider().dimension(),
rerank_enabled: config.retrieval.rerank,
rerank_pool_size: ctx
.rerank_pool
.as_ref()
.map(|_| config.retrieval.rerank_pool_size),
rerank_keep_top: config.retrieval.rerank_keep_top,
concurrency: config.concurrency.max(1),
detailed_report: config.detailed_report,
retrieval_strategy: config.retrieval.strategy.to_string(),
chunk_result_cap: config.retrieval.chunk_result_cap,
chunk_rrf_k: active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
chunk_rrf_fts_weight: active_tuning.chunk_rrf_fts_weight,
chunk_rrf_use_vector: active_tuning.chunk_rrf_use_vector,
chunk_rrf_use_fts: active_tuning.chunk_rrf_use_fts,
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
ingest_chunks_only: config.ingest.ingest_chunks_only,
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
chunk_vector_take: active_tuning.chunk_vector_take,
chunk_fts_take: active_tuning.chunk_fts_take,
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
cases: summaries,
});
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.summarize()
.map_err(|(_, guard)| map_guard_error("summarize", guard))
}

View File

@@ -0,0 +1,31 @@
use state_machines::state_machine;
state_machine! {
name: EvaluationMachine,
state: EvaluationState,
initial: Ready,
states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
events {
prepare_slice { transition: { from: Ready, to: SlicePrepared } }
prepare_db { transition: { from: SlicePrepared, to: DbReady } }
prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
summarize { transition: { from: QueriesFinished, to: Summarized } }
finalize { transition: { from: Summarized, to: Completed } }
abort {
transition: { from: Ready, to: Failed }
transition: { from: SlicePrepared, to: Failed }
transition: { from: DbReady, to: Failed }
transition: { from: CorpusReady, to: Failed }
transition: { from: NamespaceReady, to: Failed }
transition: { from: QueriesFinished, to: Failed }
transition: { from: Summarized, to: Failed }
transition: { from: Completed, to: Failed }
}
}
}
pub fn ready() -> EvaluationMachine<(), Ready> {
EvaluationMachine::new(())
}

1187
evaluations/src/report.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,63 @@
//! System settings enforcement for evaluations.
use anyhow::{Context, Result};
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
use tracing::info;
use crate::args::Config;
/// Enforce evaluation-specific system settings overrides.
pub(crate) async fn enforce_system_settings(
db: &SurrealDbClient,
mut settings: SystemSettings,
provider_dimension: usize,
config: &Config,
) -> Result<SystemSettings> {
let mut updated_settings = settings.clone();
let mut needs_settings_update = false;
if provider_dimension != settings.embedding_dimensions as usize {
updated_settings.embedding_dimensions = provider_dimension as u32;
needs_settings_update = true;
}
if let Some(query_override) = config.query_model.as_deref() {
if settings.query_model != query_override {
info!(
model = query_override,
"Overriding system query model for this run"
);
updated_settings.query_model = query_override.to_string();
needs_settings_update = true;
}
}
if needs_settings_update {
settings = SystemSettings::update(db, updated_settings)
.await
.context("updating system settings overrides")?;
}
Ok(settings)
}
/// Load existing system settings or initialize them via migrations.
pub(crate) async fn load_or_init_system_settings(
db: &SurrealDbClient,
_dimension: usize,
) -> Result<(SystemSettings, bool)> {
match SystemSettings::get_current(db).await {
Ok(settings) => Ok((settings, false)),
Err(AppError::NotFound(_)) => {
info!("System settings missing; applying database migrations for namespace");
db.apply_migrations()
.await
.context("applying database migrations after missing system settings")?;
let settings = SystemSettings::get_current(db)
.await
.context("loading system settings after migrations")?;
Ok((settings, true))
}
Err(err) => Err(err).context("loading system settings"),
}
}

1243
evaluations/src/slice.rs Normal file

File diff suppressed because it is too large Load Diff

177
evaluations/src/snapshot.rs Normal file
View File

@@ -0,0 +1,177 @@
use std::path::PathBuf;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::fs;
use crate::{args::Config, slice};
use common::utils::embedding::EmbeddingProvider;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {
pub dataset_id: String,
pub slice_id: String,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub rerank_enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbSnapshotState {
pub dataset_id: String,
pub slice_id: String,
pub ingestion_fingerprint: String,
pub snapshot_hash: String,
pub updated_at: DateTime<Utc>,
#[serde(default)]
pub namespace: Option<String>,
#[serde(default)]
pub database: Option<String>,
#[serde(default)]
pub slice_case_count: usize,
}
pub struct Descriptor {
#[allow(dead_code)]
metadata: SnapshotMetadata,
dir: PathBuf,
metadata_hash: String,
}
impl Descriptor {
pub fn new(
config: &Config,
slice: &slice::ResolvedSlice<'_>,
embedding_provider: &EmbeddingProvider,
) -> Self {
let metadata = SnapshotMetadata {
dataset_id: slice.manifest.dataset_id.clone(),
slice_id: slice.manifest.slice_id.clone(),
embedding_backend: embedding_provider.backend_label().to_string(),
embedding_model: embedding_provider.model_code(),
embedding_dimension: embedding_provider.dimension(),
rerank_enabled: config.retrieval.rerank,
};
let dir = config
.cache_dir
.join("snapshots")
.join(&metadata.dataset_id)
.join(&metadata.slice_id);
let metadata_hash = compute_hash(&metadata);
Self {
metadata,
dir,
metadata_hash,
}
}
pub fn metadata_hash(&self) -> &str {
&self.metadata_hash
}
pub async fn load_db_state(&self) -> Result<Option<DbSnapshotState>> {
let path = self.db_state_path();
if !path.exists() {
return Ok(None);
}
let bytes = fs::read(&path)
.await
.with_context(|| format!("reading namespace state {}", path.display()))?;
let state = serde_json::from_slice(&bytes)
.with_context(|| format!("deserialising namespace state {}", path.display()))?;
Ok(Some(state))
}
pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> {
let path = self.db_state_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await.with_context(|| {
format!("creating namespace state directory {}", parent.display())
})?;
}
let blob =
serde_json::to_vec_pretty(state).context("serialising namespace state payload")?;
fs::write(&path, blob)
.await
.with_context(|| format!("writing namespace state {}", path.display()))?;
Ok(())
}
fn db_dir(&self) -> PathBuf {
self.dir.join("db")
}
fn db_state_path(&self) -> PathBuf {
self.db_dir().join("state.json")
}
#[cfg(test)]
pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self {
let metadata_hash = compute_hash(&metadata);
Self {
metadata,
dir,
metadata_hash,
}
}
}
fn compute_hash(metadata: &SnapshotMetadata) -> String {
let mut hasher = Sha256::new();
hasher.update(
serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed"),
);
format!("{:x}", hasher.finalize())
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn state_round_trip() {
let temp_dir = tempfile::tempdir().unwrap();
let metadata = SnapshotMetadata {
dataset_id: "dataset".into(),
slice_id: "slice".into(),
embedding_backend: "hashed".into(),
embedding_model: None,
embedding_dimension: 128,
rerank_enabled: true,
};
let descriptor = Descriptor::from_parts(
metadata,
temp_dir
.path()
.join("snapshots")
.join("dataset")
.join("slice"),
);
let state = DbSnapshotState {
dataset_id: "dataset".into(),
slice_id: "slice".into(),
ingestion_fingerprint: "fingerprint".into(),
snapshot_hash: descriptor.metadata_hash().to_string(),
updated_at: Utc::now(),
namespace: Some("ns".into()),
database: Some("db".into()),
slice_case_count: 42,
};
descriptor.store_db_state(&state).await.unwrap();
let loaded = descriptor.load_db_state().await.unwrap().unwrap();
assert_eq!(loaded.dataset_id, state.dataset_id);
assert_eq!(loaded.slice_id, state.slice_id);
assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint);
assert_eq!(loaded.snapshot_hash, state.snapshot_hash);
assert_eq!(loaded.namespace, state.namespace);
assert_eq!(loaded.database, state.database);
assert_eq!(loaded.slice_case_count, state.slice_case_count);
}
}

461
evaluations/src/types.rs Normal file
View File

@@ -0,0 +1,461 @@
use std::collections::HashSet;
use chrono::{DateTime, Utc};
use common::storage::types::StoredObject;
use retrieval_pipeline::{
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
};
use serde::{Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization;
#[derive(Debug, Serialize)]
pub struct EvaluationSummary {
pub generated_at: DateTime<Utc>,
pub k: usize,
pub limit: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub run_label: Option<String>,
pub total_cases: usize,
pub correct: usize,
pub precision: f64,
pub correct_at_1: usize,
pub correct_at_2: usize,
pub correct_at_3: usize,
pub precision_at_1: f64,
pub precision_at_2: f64,
pub precision_at_3: f64,
pub mrr: f64,
pub average_ndcg: f64,
pub duration_ms: u128,
pub dataset_id: String,
pub dataset_label: String,
pub dataset_includes_unanswerable: bool,
pub dataset_source: String,
pub includes_impossible_cases: bool,
pub require_verified_chunks: bool,
pub filtered_questions: usize,
pub retrieval_cases: usize,
pub retrieval_correct: usize,
pub retrieval_precision: f64,
pub llm_cases: usize,
pub llm_answered: usize,
pub llm_precision: f64,
pub slice_id: String,
pub slice_seed: u64,
pub slice_total_cases: usize,
pub slice_window_offset: usize,
pub slice_window_length: usize,
pub slice_cases: usize,
pub slice_positive_paragraphs: usize,
pub slice_negative_paragraphs: usize,
pub slice_total_paragraphs: usize,
pub slice_negative_multiplier: f32,
pub namespace_reused: bool,
pub corpus_paragraphs: usize,
pub ingestion_cache_path: String,
pub ingestion_reused: bool,
pub ingestion_embeddings_reused: bool,
pub ingestion_fingerprint: String,
pub positive_paragraphs_reused: usize,
pub negative_paragraphs_reused: usize,
pub latency_ms: LatencyStats,
pub perf: PerformanceTimings,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub rerank_enabled: bool,
pub rerank_pool_size: Option<usize>,
pub rerank_keep_top: usize,
pub concurrency: usize,
pub detailed_report: bool,
pub retrieval_strategy: String,
pub chunk_result_cap: usize,
pub chunk_rrf_k: f32,
pub chunk_rrf_vector_weight: f32,
pub chunk_rrf_fts_weight: f32,
pub chunk_rrf_use_vector: bool,
pub chunk_rrf_use_fts: bool,
pub ingest_chunk_min_tokens: usize,
pub ingest_chunk_max_tokens: usize,
pub ingest_chunks_only: bool,
pub ingest_chunk_overlap_tokens: usize,
pub chunk_vector_take: usize,
pub chunk_fts_take: usize,
pub chunk_avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub cases: Vec<CaseSummary>,
}
#[derive(Debug, Serialize)]
pub struct CaseSummary {
pub question_id: String,
pub question: String,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_source: String,
pub answers: Vec<String>,
pub matched: bool,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
pub is_impossible: bool,
pub has_verified_chunks: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub match_rank: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reciprocal_rank: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ndcg: Option<f64>,
pub latency_ms: u128,
pub retrieved: Vec<RetrievedSummary>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
pub avg: f64,
pub p50: u128,
pub p95: u128,
}
impl Default for LatencyStats {
fn default() -> Self {
Self {
avg: 0.0,
p50: 0,
p95: 0,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StageLatencyBreakdown {
pub embed: LatencyStats,
pub collect_candidates: LatencyStats,
pub graph_expansion: LatencyStats,
pub chunk_attach: LatencyStats,
pub rerank: LatencyStats,
pub assemble: LatencyStats,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct EvaluationStageTimings {
pub prepare_slice_ms: u128,
pub prepare_db_ms: u128,
pub prepare_corpus_ms: u128,
pub prepare_namespace_ms: u128,
pub run_queries_ms: u128,
pub summarize_ms: u128,
pub finalize_ms: u128,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PerformanceTimings {
pub openai_base_url: String,
pub ingestion_ms: u128,
#[serde(skip_serializing_if = "Option::is_none")]
pub namespace_seed_ms: Option<u128>,
pub evaluation_stage_ms: EvaluationStageTimings,
pub stage_latency: StageLatencyBreakdown,
}
#[derive(Debug, Serialize)]
pub struct RetrievedSummary {
pub rank: usize,
pub entity_id: String,
pub source_id: String,
pub entity_name: String,
pub score: f32,
pub matched: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub entity_description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub entity_category: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk_text_match: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk_id_match: Option<bool>,
}
#[derive(Debug, Clone)]
pub struct EvaluationCandidate {
pub entity_id: String,
pub source_id: String,
pub entity_name: String,
pub entity_description: Option<String>,
pub entity_category: Option<String>,
pub score: f32,
pub chunks: Vec<RetrievedChunk>,
}
impl EvaluationCandidate {
fn from_entity(entity: RetrievedEntity) -> Self {
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
Self {
entity_id: entity.entity.get_id().to_string(),
source_id: entity.entity.source_id.clone(),
entity_name: entity.entity.name.clone(),
entity_description: Some(entity.entity.description.clone()),
entity_category,
score: entity.score,
chunks: entity.chunks,
}
}
fn from_chunk(chunk: RetrievedChunk) -> Self {
let snippet = chunk_snippet(&chunk.chunk.chunk);
Self {
entity_id: chunk.chunk.get_id().to_string(),
source_id: chunk.chunk.source_id.clone(),
entity_name: chunk.chunk.source_id.clone(),
entity_description: Some(snippet),
entity_category: Some("Chunk".to_string()),
score: chunk.score,
chunks: vec![chunk],
}
}
}
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
match output {
StrategyOutput::Entities(entities) => entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect(),
StrategyOutput::Chunks(chunks) => chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect(),
}
}
#[derive(Debug, Serialize)]
pub struct CaseDiagnostics {
pub question_id: String,
pub question: String,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_source: String,
pub expected_chunk_ids: Vec<String>,
pub answers: Vec<String>,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
pub failure_reasons: Vec<String>,
pub missing_expected_chunk_ids: Vec<String>,
pub attached_chunk_ids: Vec<String>,
pub retrieved: Vec<EntityDiagnostics>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pipeline: Option<PipelineDiagnostics>,
}
#[derive(Debug, Serialize)]
pub struct EntityDiagnostics {
pub rank: usize,
pub entity_id: String,
pub source_id: String,
pub name: String,
pub score: f32,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
pub chunks: Vec<ChunkDiagnosticsEntry>,
}
#[derive(Debug, Serialize)]
pub struct ChunkDiagnosticsEntry {
pub chunk_id: String,
pub score: f32,
pub contains_answer: bool,
pub expected_chunk: bool,
pub snippet: String,
}
pub fn text_contains_answer(text: &str, answers: &[String]) -> bool {
if answers.is_empty() {
return true;
}
let haystack = normalize_for_match(text);
answers
.iter()
.map(|needle| normalize_for_match(needle))
.any(|needle| !needle.is_empty() && haystack.contains(&needle))
}
fn normalize_for_match(input: &str) -> String {
// NFKC normalize, lowercase, and collapse whitespace/punctuation to a single space
// to reduce false negatives from formatting or punctuation differences.
let mut out = String::with_capacity(input.len());
let mut last_space = false;
for ch in input.nfkc().flat_map(|c| c.to_lowercase()) {
let is_space = ch.is_whitespace();
let is_punct = ch.is_ascii_punctuation()
|| matches!(
ch,
'“' | '”' | '' | '' | '«' | '»' | '' | '—' | '…' | '·' | '•'
);
if is_space || is_punct {
if !last_space {
out.push(' ');
last_space = true;
}
} else {
out.push(ch);
last_space = false;
}
}
let trimmed = out.trim();
if trimmed.is_empty() {
return String::new();
}
trimmed
.trim_matches(|c: char| c.is_ascii_punctuation() || c.is_whitespace())
.to_string()
}
fn chunk_snippet(text: &str) -> String {
const MAX_CHARS: usize = 160;
let trimmed = text.trim();
if trimmed.chars().count() <= MAX_CHARS {
return trimmed.to_string();
}
let mut acc = String::with_capacity(MAX_CHARS + 3);
for (idx, ch) in trimmed.chars().enumerate() {
if idx >= MAX_CHARS {
break;
}
acc.push(ch);
}
acc.push_str("...");
acc
}
pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
if latencies.is_empty() {
return LatencyStats {
avg: 0.0,
p50: 0,
p95: 0,
};
}
let mut sorted = latencies.to_vec();
sorted.sort_unstable();
let sum: u128 = sorted.iter().copied().sum();
let avg = sum as f64 / (sorted.len() as f64);
let p50 = percentile(&sorted, 0.50);
let p95 = percentile(&sorted, 0.95);
LatencyStats { avg, p50, p95 }
}
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
where
F: Fn(&PipelineStageTimings) -> u128,
{
samples.iter().map(selector).collect()
}
StageLatencyBreakdown {
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms())),
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms()
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms()
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
entry.chunk_attach_ms()
})),
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())),
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())),
}
}
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
if sorted.is_empty() {
return 0;
}
let clamped = fraction.clamp(0.0, 1.0);
let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize;
sorted[idx.min(sorted.len() - 1)]
}
pub fn build_case_diagnostics(
summary: &CaseSummary,
expected_chunk_ids: &[String],
answers_lower: &[String],
candidates: &[EvaluationCandidate],
pipeline_stats: Option<PipelineDiagnostics>,
) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
let mut seen_chunks: HashSet<String> = HashSet::new();
let mut attached_chunk_ids = Vec::new();
let mut entity_diagnostics = Vec::new();
for (idx, candidate) in candidates.iter().enumerate() {
let mut chunk_entries = Vec::new();
for chunk in &candidate.chunks {
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
let expected_chunk = expected_set.contains(chunk.chunk.get_id());
seen_chunks.insert(chunk.chunk.get_id().to_string());
attached_chunk_ids.push(chunk.chunk.get_id().to_string());
chunk_entries.push(ChunkDiagnosticsEntry {
chunk_id: chunk.chunk.get_id().to_string(),
score: chunk.score,
contains_answer,
expected_chunk,
snippet: chunk_snippet(&chunk.chunk.chunk),
});
}
entity_diagnostics.push(EntityDiagnostics {
rank: idx + 1,
entity_id: candidate.entity_id.clone(),
source_id: candidate.source_id.clone(),
name: candidate.entity_name.clone(),
score: candidate.score,
entity_match: candidate.source_id == summary.expected_source,
chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
chunks: chunk_entries,
});
}
let missing_expected_chunk_ids = expected_chunk_ids
.iter()
.filter(|id| !seen_chunks.contains(id.as_str()))
.cloned()
.collect::<Vec<_>>();
let mut failure_reasons = Vec::new();
if !summary.entity_match {
failure_reasons.push("entity_miss".to_string());
}
if !summary.chunk_text_match {
failure_reasons.push("chunk_text_missing".to_string());
}
if !summary.chunk_id_match {
failure_reasons.push("chunk_id_missing".to_string());
}
if !missing_expected_chunk_ids.is_empty() {
failure_reasons.push("expected_chunk_absent".to_string());
}
CaseDiagnostics {
question_id: summary.question_id.clone(),
question: summary.question.clone(),
paragraph_id: summary.paragraph_id.clone(),
paragraph_title: summary.paragraph_title.clone(),
expected_source: summary.expected_source.clone(),
expected_chunk_ids: expected_chunk_ids.to_vec(),
answers: summary.answers.clone(),
entity_match: summary.entity_match,
chunk_text_match: summary.chunk_text_match,
chunk_id_match: summary.chunk_id_match,
failure_reasons,
missing_expected_chunk_ids,
attached_chunk_ids,
retrieved: entity_diagnostics,
pipeline: pipeline_stats,
}
}