Refactor CLI argument parsing to clap; add MRR and NDCG retrieval metrics

This commit is contained in:
Per Stark
2025-11-28 21:26:51 +01:00
parent 67004c9646
commit 08b1612fcb
12 changed files with 434 additions and 618 deletions

View File

@@ -1,6 +1,8 @@
# Changelog
## Unreleased
- Added a shared `benchmarks` crate with deterministic fixtures and Criterion suites for ingestion, retrieval, and HTML handlers, plus documented baseline results for local performance checks.
- Added a benchmarks crate for evaluating the retrieval process
- Added fastembed embedding support, enabling the use of local, CPU-generated embeddings
- Embeddings are now stored in their own table
## Version 0.2.6 (2025-10-29)
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.

2
Cargo.lock generated
View File

@@ -2177,6 +2177,7 @@ dependencies = [
"async-openai",
"async-trait",
"chrono",
"clap",
"common",
"criterion",
"fastembed",
@@ -5464,6 +5465,7 @@ dependencies = [
"async-openai",
"async-trait",
"axum",
"clap",
"common",
"fastembed",
"futures",

View File

@@ -28,6 +28,7 @@ once_cell = "1.19"
serde_yaml = "0.9"
criterion = "0.5"
state-machines = { workspace = true }
clap = { version = "4.4", features = ["derive", "env"] }
[dev-dependencies]
tempfile = { workspace = true }

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@ use chrono::{DateTime, TimeZone, Utc};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use tracing::warn;
use clap::ValueEnum;
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
@@ -243,7 +244,7 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
catalog.dataset(kind.id())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum DatasetKind {
SquadV2,
NaturalQuestions,
@@ -298,6 +299,12 @@ impl DatasetKind {
}
}
impl std::fmt::Display for DatasetKind {
    // Renders the dataset kind via `self.id()` — presumably the stable
    // string identifier also used for catalog lookups (see
    // `dataset_entry_for_kind`, which calls `catalog.dataset(kind.id())`).
    // NOTE(review): `id()` is defined elsewhere; confirm it returns the
    // user-facing identifier intended for display/logging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.id())
    }
}
impl Default for DatasetKind {
fn default() -> Self {
Self::SquadV2

View File

@@ -272,6 +272,9 @@ pub(crate) async fn run_queries(
}
let overall_match = match_rank.is_some();
let reciprocal_rank = calculate_reciprocal_rank(match_rank);
let ndcg = calculate_ndcg(&retrieved, config.k);
let summary = CaseSummary {
question_id,
question,
@@ -286,6 +289,8 @@ pub(crate) async fn run_queries(
is_impossible,
has_verified_chunks,
match_rank,
reciprocal_rank: Some(reciprocal_rank),
ndcg: Some(ndcg),
latency_ms: query_latency,
retrieved,
};
@@ -359,3 +364,34 @@ pub(crate) async fn run_queries(
.run_queries()
.map_err(|(_, guard)| map_guard_error("run_queries", guard))
}
/// Reciprocal rank for a single query: `1 / rank` for a 1-based match
/// position, or `0.0` when no match was retrieved (or the rank is a
/// nonsensical zero).
fn calculate_reciprocal_rank(rank: Option<usize>) -> f64 {
    rank.filter(|&r| r > 0).map_or(0.0, |r| 1.0 / r as f64)
}
/// Normalized discounted cumulative gain over the top-`k` retrieved items,
/// using binary relevance (`matched` == relevant).
///
/// DCG uses the standard discount `1 / log2(rank + 1)` with 1-based ranks,
/// i.e. `1 / log2(i + 2)` for the 0-based index `i`. The ideal DCG places
/// every matched item at the top of the ranking, so NDCG is always in
/// `[0.0, 1.0]`.
///
/// Bug fixed: the previous version hard-coded IDCG to 1.0 (assuming a
/// single relevant item), which let NDCG exceed 1.0 whenever more than one
/// retrieved item within the top-`k` was matched.
fn calculate_ndcg(retrieved: &[RetrievedSummary], k: usize) -> f64 {
    let mut dcg = 0.0;
    let mut relevant = 0usize;
    for (i, item) in retrieved.iter().take(k).enumerate() {
        if item.matched {
            relevant += 1;
            dcg += 1.0 / (i as f64 + 2.0).log2();
        }
    }
    // Ideal DCG: all `relevant` matched items ranked first.
    let idcg: f64 = (0..relevant).map(|i| 1.0 / (i as f64 + 2.0).log2()).sum();
    if idcg == 0.0 {
        0.0
    } else {
        dcg / idcg
    }
}

View File

@@ -45,6 +45,8 @@ pub(crate) async fn summarize(
let mut retrieval_cases = 0usize;
let mut llm_cases = 0usize;
let mut llm_answered = 0usize;
let mut sum_reciprocal_rank = 0.0;
let mut sum_ndcg = 0.0;
for summary in &summaries {
if summary.is_impossible {
llm_cases += 1;
@@ -54,6 +56,12 @@ pub(crate) async fn summarize(
continue;
}
retrieval_cases += 1;
if let Some(rr) = summary.reciprocal_rank {
sum_reciprocal_rank += rr;
}
if let Some(ndcg) = summary.ndcg {
sum_ndcg += ndcg;
}
if summary.matched {
correct += 1;
if let Some(rank) = summary.match_rank {
@@ -99,6 +107,16 @@ pub(crate) async fn summarize(
} else {
(correct_at_3 as f64) / (retrieval_cases as f64)
};
let mrr = if retrieval_cases == 0 {
0.0
} else {
sum_reciprocal_rank / (retrieval_cases as f64)
};
let average_ndcg = if retrieval_cases == 0 {
0.0
} else {
sum_ndcg / (retrieval_cases as f64)
};
let active_tuning = ctx
.retrieval_config
@@ -131,6 +149,8 @@ pub(crate) async fn summarize(
precision_at_1,
precision_at_2,
precision_at_3,
mrr,
average_ndcg,
duration_ms,
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),

View File

@@ -23,6 +23,8 @@ pub struct EvaluationSummary {
pub precision_at_1: f64,
pub precision_at_2: f64,
pub precision_at_3: f64,
pub mrr: f64,
pub average_ndcg: f64,
pub duration_ms: u128,
pub dataset_id: String,
pub dataset_label: String,
@@ -90,6 +92,10 @@ pub struct CaseSummary {
pub has_verified_chunks: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub match_rank: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reciprocal_rank: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ndcg: Option<f64>,
pub latency_ms: u128,
pub retrieved: Vec<RetrievedSummary>,
}

View File

@@ -92,10 +92,8 @@ async fn async_main() -> anyhow::Result<()> {
let parsed = args::parse()?;
if parsed.show_help {
args::print_help();
return Ok(());
}
// Clap handles help automatically, so we don't need to check for it manually
if parsed.config.inspect_question.is_some() {
inspection::inspect_question(&parsed.config).await?;

View File

@@ -78,6 +78,8 @@ pub struct RetrievalSection {
pub precision_at_1: f64,
pub precision_at_2: f64,
pub precision_at_3: f64,
pub mrr: f64,
pub average_ndcg: f64,
pub latency: LatencyStats,
pub concurrency: usize,
pub strategy: String,
@@ -178,6 +180,8 @@ impl EvaluationReport {
precision_at_1: summary.precision_at_1,
precision_at_2: summary.precision_at_2,
precision_at_3: summary.precision_at_3,
mrr: summary.mrr,
average_ndcg: summary.average_ndcg,
latency: summary.latency_ms.clone(),
concurrency: summary.concurrency,
strategy: summary.retrieval_strategy.clone(),
@@ -435,6 +439,14 @@ fn render_markdown(report: &EvaluationReport) -> String {
report.retrieval.precision_at_2,
report.retrieval.precision_at_3
));
md.push_str(&format!(
"| MRR | {:.3} |\\n",
report.retrieval.mrr
));
md.push_str(&format!(
"| NDCG | {:.3} |\\n",
report.retrieval.average_ndcg
));
md.push_str(&format!(
"| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n",
report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95
@@ -687,6 +699,10 @@ struct HistoryEntry {
precision_at_2: f64,
precision_at_3: f64,
#[serde(default)]
mrr: f64,
#[serde(default)]
average_ndcg: f64,
#[serde(default)]
retrieval_cases: usize,
#[serde(default)]
retrieval_precision: f64,
@@ -771,6 +787,8 @@ fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()>
precision_at_1: summary.precision_at_1,
precision_at_2: summary.precision_at_2,
precision_at_3: summary.precision_at_3,
mrr: summary.mrr,
average_ndcg: summary.average_ndcg,
retrieval_cases: summary.retrieval_cases,
retrieval_precision: summary.retrieval_precision,
llm_cases: summary.llm_cases,

View File

@@ -21,5 +21,6 @@ async-openai = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
fastembed = { workspace = true }
clap = { version = "4.4", features = ["derive"] }
common = { path = "../common", features = ["test-utils"] }

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
Initial,