mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-19 16:21:30 +01:00
refactored to clap, mrr and ndcg
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
- Added a shared `benchmarks` crate with deterministic fixtures and Criterion suites for ingestion, retrieval, and HTML handlers, plus documented baseline results for local performance checks.
|
||||
- Added a benchmarks crate for evaluating the retrieval process
|
||||
- Added FastEmbed embedding support, which enables the use of embeddings generated locally on the CPU
|
||||
- Embeddings are now stored in their own table
|
||||
|
||||
## Version 0.2.6 (2025-10-29)
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2177,6 +2177,7 @@ dependencies = [
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"clap",
|
||||
"common",
|
||||
"criterion",
|
||||
"fastembed",
|
||||
@@ -5464,6 +5465,7 @@ dependencies = [
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"clap",
|
||||
"common",
|
||||
"fastembed",
|
||||
"futures",
|
||||
|
||||
@@ -28,6 +28,7 @@ once_cell = "1.19"
|
||||
serde_yaml = "0.9"
|
||||
criterion = "0.5"
|
||||
state-machines = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive", "env"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
||||
947
eval/src/args.rs
947
eval/src/args.rs
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ use chrono::{DateTime, TimeZone, Utc};
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
use clap::ValueEnum;
|
||||
|
||||
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
|
||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||
@@ -243,7 +244,7 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
catalog.dataset(kind.id())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||
pub enum DatasetKind {
|
||||
SquadV2,
|
||||
NaturalQuestions,
|
||||
@@ -298,6 +299,12 @@ impl DatasetKind {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DatasetKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.id())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DatasetKind {
|
||||
fn default() -> Self {
|
||||
Self::SquadV2
|
||||
|
||||
@@ -272,6 +272,9 @@ pub(crate) async fn run_queries(
|
||||
}
|
||||
|
||||
let overall_match = match_rank.is_some();
|
||||
let reciprocal_rank = calculate_reciprocal_rank(match_rank);
|
||||
let ndcg = calculate_ndcg(&retrieved, config.k);
|
||||
|
||||
let summary = CaseSummary {
|
||||
question_id,
|
||||
question,
|
||||
@@ -286,6 +289,8 @@ pub(crate) async fn run_queries(
|
||||
is_impossible,
|
||||
has_verified_chunks,
|
||||
match_rank,
|
||||
reciprocal_rank: Some(reciprocal_rank),
|
||||
ndcg: Some(ndcg),
|
||||
latency_ms: query_latency,
|
||||
retrieved,
|
||||
};
|
||||
@@ -359,3 +364,34 @@ pub(crate) async fn run_queries(
|
||||
.run_queries()
|
||||
.map_err(|(_, guard)| map_guard_error("run_queries", guard))
|
||||
}
|
||||
|
||||
/// Turns an optional match rank into a reciprocal-rank score.
///
/// Returns `1.0 / rank` when `rank` is `Some` and strictly positive. Returns
/// `0.0` when there is no rank or the rank is `0` (which would otherwise
/// divide by zero).
fn calculate_reciprocal_rank(rank: Option<usize>) -> f64 {
    rank.filter(|&position| position > 0)
        .map_or(0.0, |position| 1.0 / (position as f64))
}
|
||||
|
||||
fn calculate_ndcg(retrieved: &[RetrievedSummary], k: usize) -> f64 {
|
||||
let mut dcg = 0.0;
|
||||
|
||||
for (i, item) in retrieved.iter().enumerate() {
|
||||
if i >= k {
|
||||
break;
|
||||
}
|
||||
if item.matched {
|
||||
let rel = 1.0;
|
||||
dcg += rel / (i as f64 + 2.0).log2();
|
||||
}
|
||||
}
|
||||
|
||||
// IDCG for a single relevant item at rank 1 is 1.0 / log2(2) = 1.0
|
||||
let idcg = 1.0;
|
||||
|
||||
if dcg == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
dcg / idcg
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ pub(crate) async fn summarize(
|
||||
let mut retrieval_cases = 0usize;
|
||||
let mut llm_cases = 0usize;
|
||||
let mut llm_answered = 0usize;
|
||||
let mut sum_reciprocal_rank = 0.0;
|
||||
let mut sum_ndcg = 0.0;
|
||||
for summary in &summaries {
|
||||
if summary.is_impossible {
|
||||
llm_cases += 1;
|
||||
@@ -54,6 +56,12 @@ pub(crate) async fn summarize(
|
||||
continue;
|
||||
}
|
||||
retrieval_cases += 1;
|
||||
if let Some(rr) = summary.reciprocal_rank {
|
||||
sum_reciprocal_rank += rr;
|
||||
}
|
||||
if let Some(ndcg) = summary.ndcg {
|
||||
sum_ndcg += ndcg;
|
||||
}
|
||||
if summary.matched {
|
||||
correct += 1;
|
||||
if let Some(rank) = summary.match_rank {
|
||||
@@ -99,6 +107,16 @@ pub(crate) async fn summarize(
|
||||
} else {
|
||||
(correct_at_3 as f64) / (retrieval_cases as f64)
|
||||
};
|
||||
let mrr = if retrieval_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
sum_reciprocal_rank / (retrieval_cases as f64)
|
||||
};
|
||||
let average_ndcg = if retrieval_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
sum_ndcg / (retrieval_cases as f64)
|
||||
};
|
||||
|
||||
let active_tuning = ctx
|
||||
.retrieval_config
|
||||
@@ -131,6 +149,8 @@ pub(crate) async fn summarize(
|
||||
precision_at_1,
|
||||
precision_at_2,
|
||||
precision_at_3,
|
||||
mrr,
|
||||
average_ndcg,
|
||||
duration_ms,
|
||||
dataset_id: dataset.metadata.id.clone(),
|
||||
dataset_label: dataset.metadata.label.clone(),
|
||||
|
||||
@@ -23,6 +23,8 @@ pub struct EvaluationSummary {
|
||||
pub precision_at_1: f64,
|
||||
pub precision_at_2: f64,
|
||||
pub precision_at_3: f64,
|
||||
pub mrr: f64,
|
||||
pub average_ndcg: f64,
|
||||
pub duration_ms: u128,
|
||||
pub dataset_id: String,
|
||||
pub dataset_label: String,
|
||||
@@ -90,6 +92,10 @@ pub struct CaseSummary {
|
||||
pub has_verified_chunks: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub match_rank: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reciprocal_rank: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ndcg: Option<f64>,
|
||||
pub latency_ms: u128,
|
||||
pub retrieved: Vec<RetrievedSummary>,
|
||||
}
|
||||
|
||||
@@ -92,10 +92,8 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
|
||||
let parsed = args::parse()?;
|
||||
|
||||
if parsed.show_help {
|
||||
args::print_help();
|
||||
return Ok(());
|
||||
}
|
||||
// Clap handles help automatically, so we don't need to check for it manually
|
||||
|
||||
|
||||
if parsed.config.inspect_question.is_some() {
|
||||
inspection::inspect_question(&parsed.config).await?;
|
||||
|
||||
@@ -78,6 +78,8 @@ pub struct RetrievalSection {
|
||||
pub precision_at_1: f64,
|
||||
pub precision_at_2: f64,
|
||||
pub precision_at_3: f64,
|
||||
pub mrr: f64,
|
||||
pub average_ndcg: f64,
|
||||
pub latency: LatencyStats,
|
||||
pub concurrency: usize,
|
||||
pub strategy: String,
|
||||
@@ -178,6 +180,8 @@ impl EvaluationReport {
|
||||
precision_at_1: summary.precision_at_1,
|
||||
precision_at_2: summary.precision_at_2,
|
||||
precision_at_3: summary.precision_at_3,
|
||||
mrr: summary.mrr,
|
||||
average_ndcg: summary.average_ndcg,
|
||||
latency: summary.latency_ms.clone(),
|
||||
concurrency: summary.concurrency,
|
||||
strategy: summary.retrieval_strategy.clone(),
|
||||
@@ -435,6 +439,14 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
||||
report.retrieval.precision_at_2,
|
||||
report.retrieval.precision_at_3
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| MRR | {:.3} |\\n",
|
||||
report.retrieval.mrr
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| NDCG | {:.3} |\\n",
|
||||
report.retrieval.average_ndcg
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n",
|
||||
report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95
|
||||
@@ -687,6 +699,10 @@ struct HistoryEntry {
|
||||
precision_at_2: f64,
|
||||
precision_at_3: f64,
|
||||
#[serde(default)]
|
||||
mrr: f64,
|
||||
#[serde(default)]
|
||||
average_ndcg: f64,
|
||||
#[serde(default)]
|
||||
retrieval_cases: usize,
|
||||
#[serde(default)]
|
||||
retrieval_precision: f64,
|
||||
@@ -771,6 +787,8 @@ fn record_history(summary: &EvaluationSummary, report_dir: &Path) -> Result<()>
|
||||
precision_at_1: summary.precision_at_1,
|
||||
precision_at_2: summary.precision_at_2,
|
||||
precision_at_3: summary.precision_at_3,
|
||||
mrr: summary.mrr,
|
||||
average_ndcg: summary.average_ndcg,
|
||||
retrieval_cases: summary.retrieval_cases,
|
||||
retrieval_precision: summary.retrieval_precision,
|
||||
llm_cases: summary.llm_cases,
|
||||
|
||||
@@ -21,5 +21,6 @@ async-openai = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
Initial,
|
||||
|
||||
Reference in New Issue
Block a user