chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.

Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
Per Stark
2026-05-30 22:19:08 +02:00
parent c70141de35
commit 5c2d2e24d3
38 changed files with 1049 additions and 2614 deletions
+7 -18
View File
@@ -5,7 +5,6 @@ use std::{
use anyhow::{anyhow, Context, Result};
use clap::{Args, Parser, ValueEnum};
use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind;
@@ -55,10 +54,6 @@ pub struct RetrievalSettings {
#[arg(long)]
pub chunk_fts_take: Option<usize>,
/// Override average characters per token used for budgeting
#[arg(long)]
pub chunk_avg_chars_per_token: Option<usize>,
/// Override maximum chunks attached per entity
#[arg(long)]
pub max_chunks_per_entity: Option<usize>,
@@ -71,41 +66,37 @@ pub struct RetrievalSettings {
#[arg(long, default_value_t = 4)]
pub rerank_pool_size: usize,
/// Keep top-N entities after reranking
/// Keep top-N chunks after reranking
#[arg(long, default_value_t = 10)]
pub rerank_keep_top: usize,
/// Cap the number of chunks returned by retrieval (revised strategy)
/// Cap the number of chunks returned by retrieval
#[arg(long, default_value_t = 5)]
pub chunk_result_cap: usize,
/// Reciprocal rank fusion k value for revised chunk merging
/// Reciprocal rank fusion k value for chunk merging
#[arg(long)]
pub chunk_rrf_k: Option<f32>,
/// Weight for vector ranks in revised RRF
/// Weight for vector ranks in RRF
#[arg(long)]
pub chunk_rrf_vector_weight: Option<f32>,
/// Weight for chunk FTS ranks in revised RRF
/// Weight for chunk FTS ranks in RRF
#[arg(long)]
pub chunk_rrf_fts_weight: Option<f32>,
/// Include vector ranks in revised RRF (default: true)
/// Include vector ranks in RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_vector: Option<bool>,
/// Include chunk FTS ranks in revised RRF (default: true)
/// Include chunk FTS ranks in RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_fts: Option<bool>,
/// Require verified chunks (disable with --llm-mode)
#[arg(skip = true)]
pub require_verified_chunks: bool,
/// Select the retrieval pipeline strategy
#[arg(long, default_value_t = RetrievalStrategy::Default)]
pub strategy: RetrievalStrategy,
}
impl Default for RetrievalSettings {
@@ -113,7 +104,6 @@ impl Default for RetrievalSettings {
Self {
chunk_vector_take: None,
chunk_fts_take: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None,
rerank: false,
rerank_pool_size: 4,
@@ -125,7 +115,6 @@ impl Default for RetrievalSettings {
chunk_rrf_use_vector: None,
chunk_rrf_use_fts: None,
require_verified_chunks: true,
strategy: RetrievalStrategy::Default,
}
}
}
+18 -20
View File
@@ -51,8 +51,8 @@ pub fn mirror_perf_outputs(
pub fn print_console_summary(record: &EvaluationReport) {
let perf = &record.performance;
println!(
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.strategy,
"[perf] resolve_entities={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.resolve_entities,
record.retrieval.concurrency,
record.retrieval.rerank_enabled,
record.retrieval.rerank_pool_size,
@@ -63,16 +63,14 @@ pub fn print_console_summary(record: &EvaluationReport) {
perf.ingestion_ms,
format_duration(perf.namespace_seed_ms),
);
let stage = &perf.stage_latency;
println!(
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
stage.embed.avg,
stage.collect_candidates.avg,
stage.graph_expansion.avg,
stage.chunk_attach.avg,
stage.rerank.avg,
stage.assemble.avg,
);
let stage_summary = perf
.stage_latency
.stages
.iter()
.map(|s| format!("{} {:.1}", s.stage, s.stats.avg))
.collect::<Vec<_>>()
.join(" | ");
println!("[perf] stage avg ms → {stage_summary}");
let eval = &perf.evaluation_stages_ms;
println!(
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
@@ -107,12 +105,13 @@ mod tests {
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
crate::eval::StageLatencyBreakdown {
embed: sample_latency(),
collect_candidates: sample_latency(),
graph_expansion: sample_latency(),
chunk_attach: sample_latency(),
rerank: sample_latency(),
assemble: sample_latency(),
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
.into_iter()
.map(|stage| crate::eval::StageLatency {
stage: stage.to_string(),
stats: sample_latency(),
})
.collect(),
}
}
@@ -193,7 +192,7 @@ mod tests {
rerank_keep_top: 10,
concurrency: 2,
detailed_report: false,
retrieval_strategy: "initial".into(),
resolve_entities: false,
chunk_result_cap: 5,
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
@@ -206,7 +205,6 @@ mod tests {
ingest_chunk_overlap_tokens: 50,
chunk_vector_take: 20,
chunk_fts_take: 20,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4,
cases: Vec::new(),
}
+7 -23
View File
@@ -6,7 +6,7 @@ use futures::stream::{self, StreamExt};
use tracing::{debug, info};
use crate::eval::{
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
CaseSummary, RetrievedSummary,
};
use retrieval_pipeline::{
@@ -51,14 +51,8 @@ pub(crate) async fn run_queries(
None
};
let mut retrieval_config = RetrievalConfig {
strategy: config.retrieval.strategy,
..Default::default()
};
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
}
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
if let Some(value) = config.retrieval.chunk_vector_take {
retrieval_config.tuning.chunk_vector_take = value;
@@ -81,9 +75,6 @@ pub(crate) async fn run_queries(
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
}
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value;
}
if let Some(value) = config.retrieval.max_chunks_per_entity {
retrieval_config.tuning.max_chunks_per_entity = value;
}
@@ -187,7 +178,7 @@ pub(crate) async fn run_queries(
None => None,
};
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: Some(&embedding_provider),
@@ -196,26 +187,19 @@ pub(crate) async fn run_queries(
config: (*retrieval_config).clone(),
reranker,
};
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
let (result_output, pipeline_diagnostics, stage_timings) = {
let outcome = pipeline::run_with_embedding_instrumented(
params,
query_embedding,
diagnostics_enabled,
)
.await
.with_context(|| format!("running pipeline for question {question_id}"))?;
(outcome.results, outcome.diagnostics, outcome.stage_timings)
} else {
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
params,
query_embedding,
)
.await
.with_context(|| format!("running pipeline for question {question_id}"))?;
(outcome.results, None, outcome.stage_timings)
};
let query_latency = query_start.elapsed().as_millis();
let candidates = adapt_strategy_output(result_output);
let candidates = adapt_retrieval_output(result_output);
let mut retrieved = Vec::new();
let mut match_rank = None;
let answers_lower: Vec<String> =
+4 -2
View File
@@ -201,7 +201,10 @@ pub(crate) async fn summarize(
rerank_keep_top: config.retrieval.rerank_keep_top,
concurrency: config.concurrency.max(1),
detailed_report: config.detailed_report,
retrieval_strategy: config.retrieval.strategy.to_string(),
resolve_entities: ctx
.retrieval_config
.as_ref()
.is_some_and(|config| config.resolve_entities),
chunk_result_cap: config.retrieval.chunk_result_cap,
chunk_rrf_k: active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
@@ -214,7 +217,6 @@ pub(crate) async fn summarize(
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
chunk_vector_take: active_tuning.chunk_vector_take,
chunk_fts_take: active_tuning.chunk_fts_take,
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
cases: summaries,
});
+54 -49
View File
@@ -85,7 +85,7 @@ pub struct RetrievalSection {
pub average_ndcg: f64,
pub latency: LatencyStats,
pub concurrency: usize,
pub strategy: String,
pub resolve_entities: bool,
pub rerank_enabled: bool,
pub rerank_pool_size: Option<usize>,
pub rerank_keep_top: usize,
@@ -226,7 +226,7 @@ impl EvaluationReport {
average_ndcg: summary.average_ndcg,
latency: summary.latency_ms.clone(),
concurrency: summary.concurrency,
strategy: summary.retrieval_strategy.clone(),
resolve_entities: summary.resolve_entities,
rerank_enabled: summary.rerank_enabled,
rerank_pool_size: summary.rerank_pool_size,
rerank_keep_top: summary.rerank_keep_top,
@@ -463,7 +463,12 @@ fn render_markdown(report: &EvaluationReport) -> String {
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap();
write!(md, "| Strategy | `{}` |\\n", report.retrieval.strategy).unwrap();
write!(
md,
"| Resolve entities | {} |\\n",
bool_badge(report.retrieval.resolve_entities)
)
.unwrap();
write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap();
if report.retrieval.rerank_enabled {
let pool = report
@@ -510,28 +515,9 @@ fn render_markdown(report: &EvaluationReport) -> String {
md.push_str("\\n## Retrieval Stage Timings\\n\\n");
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed);
write_stage_row(
&mut md,
"Collect Candidates",
&report.performance.stage_latency.collect_candidates,
);
write_stage_row(
&mut md,
"Graph Expansion",
&report.performance.stage_latency.graph_expansion,
);
write_stage_row(
&mut md,
"Chunk Attach",
&report.performance.stage_latency.chunk_attach,
);
write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank);
write_stage_row(
&mut md,
"Assemble",
&report.performance.stage_latency.assemble,
);
for stage in &report.performance.stage_latency.stages {
write_stage_row(&mut md, &prettify_stage(&stage.stage), &stage.stats);
}
if report.misses.is_empty() {
if report.detailed_report {
@@ -623,6 +609,20 @@ fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
.unwrap();
}
/// Turn a stable stage label (e.g. `resolve_entities`) into a display title (`Resolve Entities`).
fn prettify_stage(label: &str) -> String {
label
.split('_')
.map(|word| {
let mut chars = word.chars();
chars.next().map_or_else(String::new, |first| {
first.to_uppercase().collect::<String>() + chars.as_str()
})
})
.collect::<Vec<_>>()
.join(" ")
}
fn bool_badge(value: bool) -> &'static str {
if value {
""
@@ -740,17 +740,6 @@ struct LegacyHistoryDelta {
latency_avg_ms: f64,
}
fn default_stage_latency() -> StageLatencyBreakdown {
StageLatencyBreakdown {
embed: LatencyStats::default(),
collect_candidates: LatencyStats::default(),
graph_expansion: LatencyStats::default(),
chunk_attach: LatencyStats::default(),
rerank: LatencyStats::default(),
assemble: LatencyStats::default(),
}
}
#[allow(clippy::too_many_lines)]
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
let overview = OverviewSection {
@@ -807,7 +796,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
average_ndcg: entry.average_ndcg,
latency: entry.latency_ms,
concurrency: 0,
strategy: "unknown".into(),
resolve_entities: false,
rerank_enabled: entry.rerank_enabled,
rerank_pool_size: entry.rerank_pool_size,
rerank_keep_top: entry.rerank_keep_top,
@@ -840,7 +829,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
ingestion_ms: entry.ingestion_ms,
namespace_seed_ms: entry.namespace_seed_ms,
evaluation_stages_ms: EvaluationStageTimings::default(),
stage_latency: default_stage_latency(),
stage_latency: StageLatencyBreakdown::default(),
namespace_reused: false,
ingestion_reused: entry.ingestion_reused,
embeddings_reused: entry.ingestion_embeddings_reused,
@@ -915,7 +904,8 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
mod tests {
use super::*;
use crate::eval::{
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatencyBreakdown,
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
StageLatencyBreakdown,
};
use chrono::Utc;
use tempfile::tempdir;
@@ -931,12 +921,28 @@ mod tests {
fn sample_stage_latency() -> StageLatencyBreakdown {
StageLatencyBreakdown {
embed: latency(9.0),
collect_candidates: latency(10.0),
graph_expansion: latency(11.0),
chunk_attach: latency(12.0),
rerank: latency(13.0),
assemble: latency(14.0),
stages: vec![
StageLatency {
stage: "embed".to_string(),
stats: latency(9.0),
},
StageLatency {
stage: "search".to_string(),
stats: latency(10.0),
},
StageLatency {
stage: "rerank".to_string(),
stats: latency(13.0),
},
StageLatency {
stage: "resolve_entities".to_string(),
stats: latency(11.0),
},
StageLatency {
stage: "assemble".to_string(),
stats: latency(14.0),
},
],
}
}
@@ -1058,7 +1064,7 @@ mod tests {
rerank_keep_top: 5,
concurrency: 2,
detailed_report: true,
retrieval_strategy: "initial".into(),
resolve_entities: false,
chunk_result_cap: 5,
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
@@ -1071,7 +1077,6 @@ mod tests {
ingest_chunks_only: false,
chunk_vector_take: 50,
chunk_fts_take: 50,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4,
cases,
}
@@ -1097,7 +1102,7 @@ mod tests {
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
#[test]
fn evaluations_history_captures_strategy_and_concurrency() {
fn evaluations_history_captures_resolve_entities_and_concurrency() {
let tmp = tempdir().unwrap();
let summary = sample_summary(false);
@@ -1109,7 +1114,7 @@ mod tests {
assert_eq!(entries.len(), 1);
let stored = &entries[0];
assert_eq!(stored.retrieval.concurrency, summary.concurrency);
assert_eq!(stored.retrieval.strategy, summary.retrieval_strategy);
assert_eq!(stored.retrieval.resolve_entities, summary.resolve_entities);
assert_eq!(
stored.performance.evaluation_stages_ms.run_queries_ms,
summary.perf.evaluation_stage_ms.run_queries_ms
+33 -38
View File
@@ -3,7 +3,7 @@ use std::collections::HashSet;
use chrono::{DateTime, Utc};
use common::storage::types::StoredObject;
use retrieval_pipeline::{
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
};
use serde::{Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization;
@@ -69,7 +69,7 @@ pub struct EvaluationSummary {
pub rerank_keep_top: usize,
pub concurrency: usize,
pub detailed_report: bool,
pub retrieval_strategy: String,
pub resolve_entities: bool,
pub chunk_result_cap: usize,
pub chunk_rrf_k: f32,
pub chunk_rrf_vector_weight: f32,
@@ -82,7 +82,6 @@ pub struct EvaluationSummary {
pub ingest_chunk_overlap_tokens: usize,
pub chunk_vector_take: usize,
pub chunk_fts_take: usize,
pub chunk_avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub cases: Vec<CaseSummary>,
}
@@ -129,14 +128,20 @@ impl Default for LatencyStats {
}
}
/// Latency statistics for a single retrieval stage, keyed by the stage's stable label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageLatency {
pub stage: String,
pub stats: LatencyStats,
}
/// Per-stage retrieval latency, in canonical pipeline order.
///
/// The set of stages is driven entirely by [`StageKind::ALL`], so adding a retrieval stage
/// surfaces here automatically without changes to the evaluation harness.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StageLatencyBreakdown {
pub embed: LatencyStats,
pub collect_candidates: LatencyStats,
pub graph_expansion: LatencyStats,
pub chunk_attach: LatencyStats,
pub rerank: LatencyStats,
pub assemble: LatencyStats,
pub stages: Vec<StageLatency>,
}
#[allow(clippy::struct_field_names)]
@@ -232,13 +237,12 @@ fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidat
.collect()
}
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidate> {
match output {
StrategyOutput::Entities(entities) => candidates_from_entities(entities),
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
StrategyOutput::Search(search_result) => {
let mut candidates = candidates_from_entities(search_result.entities);
candidates.extend(candidates_from_chunks(search_result.chunks));
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
RetrievalOutput::WithEntities { chunks, entities } => {
let mut candidates = candidates_from_entities(entities);
candidates.extend(candidates_from_chunks(chunks));
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
candidates
}
@@ -262,7 +266,7 @@ pub struct CaseDiagnostics {
pub attached_chunk_ids: Vec<String>,
pub retrieved: Vec<EntityDiagnostics>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pipeline: Option<PipelineDiagnostics>,
pub pipeline: Option<Diagnostics>,
}
#[derive(Debug, Serialize)]
@@ -366,28 +370,19 @@ pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
LatencyStats { avg, p50, p95 }
}
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
where
F: Fn(&PipelineStageTimings) -> u128,
{
samples.iter().map(selector).collect()
}
pub fn build_stage_latency_breakdown(samples: &[StageTimings]) -> StageLatencyBreakdown {
let stages = StageKind::ALL
.iter()
.map(|kind| {
let latencies: Vec<u128> = samples.iter().map(|s| s.stage_ms(*kind)).collect();
StageLatency {
stage: kind.label().to_string(),
stats: compute_latency_stats(&latencies),
}
})
.collect();
StageLatencyBreakdown {
embed: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::embed_ms)),
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms()
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms()
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
entry.chunk_attach_ms()
})),
rerank: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::rerank_ms)),
assemble: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::assemble_ms)),
}
StageLatencyBreakdown { stages }
}
#[allow(
@@ -412,7 +407,7 @@ pub fn build_case_diagnostics(
expected_chunk_ids: &[String],
answers_lower: &[String],
candidates: &[EvaluationCandidate],
pipeline_stats: Option<PipelineDiagnostics>,
pipeline_stats: Option<Diagnostics>,
) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect();
let mut seen_chunks: HashSet<String> = HashSet::new();