From 0cb1abc6db526fa0eba7c493d93c2e7a6b1256e6 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Mon, 8 Dec 2025 20:39:12 +0100 Subject: [PATCH] beir-rff --- eval/manifest.yaml | 34 ++- eval/src/args.rs | 47 ++++ eval/src/datasets/mod.rs | 41 +++- eval/src/eval/pipeline/stages/run_queries.rs | 20 ++ eval/src/eval/pipeline/stages/summarize.rs | 5 + eval/src/eval/types.rs | 5 + eval/src/ingest/orchestrator.rs | 14 ++ eval/src/ingest/store.rs | 8 +- eval/src/perf.rs | 5 + eval/src/report.rs | 37 ++++ retrieval-pipeline/src/pipeline/config.rs | 40 ++++ retrieval-pipeline/src/pipeline/stages/mod.rs | 205 +++++------------- retrieval-pipeline/src/scoring.rs | 104 ++++++++- 13 files changed, 405 insertions(+), 160 deletions(-) diff --git a/eval/manifest.yaml b/eval/manifest.yaml index c73edba..28c32ed 100644 --- a/eval/manifest.yaml +++ b/eval/manifest.yaml @@ -20,7 +20,7 @@ datasets: category: "Natural Questions" entity_suffix: "Natural Questions" source_prefix: "nq" - raw: "data/raw/nq/dev-all.jsonl" + raw: "data/raw/nq-dev/dev-all.jsonl" converted: "data/converted/nq-dev-minne.json" include_unanswerable: true slices: @@ -42,7 +42,7 @@ datasets: slices: - id: beir-mix-600 label: "BEIR mix (600)" - description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID" + description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR" limit: 600 corpus_limit: 6000 seed: 0x5eed2025 @@ -136,3 +136,33 @@ datasets: limit: 200 corpus_limit: 5000 seed: 0x5eed2025 + - id: scifact + label: "SciFact (BEIR)" + category: "SciFact" + entity_suffix: "SciFact" + source_prefix: "scifact" + raw: "data/raw/scifact" + converted: "data/converted/scifact-minne.json" + include_unanswerable: false + slices: + - id: scifact-test-200 + label: "SciFact test (200)" + description: "200-case slice from BEIR test qrels" + limit: 200 + corpus_limit: 3000 + seed: 0x5eed2025 + - id: nq-beir + label: "Natural Questions (BEIR)" + category: "Natural 
Questions" + entity_suffix: "Natural Questions" + source_prefix: "nq-beir" + raw: "data/raw/nq" + converted: "data/converted/nq-beir-minne.json" + include_unanswerable: false + slices: + - id: nq-beir-test-200 + label: "NQ (BEIR) test (200)" + description: "200-case slice from BEIR test qrels" + limit: 200 + corpus_limit: 5000 + seed: 0x5eed2025 diff --git a/eval/src/args.rs b/eval/src/args.rs index f9b0b5b..2d0d642 100644 --- a/eval/src/args.rs +++ b/eval/src/args.rs @@ -84,6 +84,26 @@ pub struct RetrievalSettings { #[arg(long, default_value_t = 5)] pub chunk_result_cap: usize, + /// Reciprocal rank fusion k value for revised chunk merging + #[arg(long)] + pub chunk_rrf_k: Option, + + /// Weight for vector ranks in revised RRF + #[arg(long)] + pub chunk_rrf_vector_weight: Option, + + /// Weight for chunk FTS ranks in revised RRF + #[arg(long)] + pub chunk_rrf_fts_weight: Option, + + /// Include vector ranks in revised RRF (default: true) + #[arg(long)] + pub chunk_rrf_use_vector: Option, + + /// Include chunk FTS ranks in revised RRF (default: true) + #[arg(long)] + pub chunk_rrf_use_fts: Option, + /// Require verified chunks (disable with --llm-mode) #[arg(skip = true)] pub require_verified_chunks: bool, @@ -104,6 +124,11 @@ impl Default for RetrievalSettings { rerank_pool_size: 4, rerank_keep_top: 10, chunk_result_cap: 5, + chunk_rrf_k: None, + chunk_rrf_vector_weight: None, + chunk_rrf_fts_weight: None, + chunk_rrf_use_vector: None, + chunk_rrf_use_fts: None, require_verified_chunks: true, strategy: RetrievalStrategy::Initial, } @@ -376,6 +401,28 @@ impl Config { )); } + if let Some(k) = self.retrieval.chunk_rrf_k { + if k <= 0.0 || !k.is_finite() { + return Err(anyhow!( + "--chunk-rrf-k must be a positive, finite number (got {k})" + )); + } + } + if let Some(weight) = self.retrieval.chunk_rrf_vector_weight { + if weight < 0.0 || !weight.is_finite() { + return Err(anyhow!( + "--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})" + )); 
+ } + } + if let Some(weight) = self.retrieval.chunk_rrf_fts_weight { + if weight < 0.0 || !weight.is_finite() { + return Err(anyhow!( + "--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})" + )); + } + } + if self.concurrency == 0 { return Err(anyhow!("--concurrency must be greater than zero")); } diff --git a/eval/src/datasets/mod.rs b/eval/src/datasets/mod.rs index bef38a7..108c36f 100644 --- a/eval/src/datasets/mod.rs +++ b/eval/src/datasets/mod.rs @@ -262,6 +262,10 @@ pub enum DatasetKind { Quora, #[value(name = "trec-covid", alias = "treccovid", alias = "trec_covid")] TrecCovid, + #[value(name = "scifact")] + Scifact, + #[value(name = "nq-beir", alias = "natural-questions-beir")] + NqBeir, } impl DatasetKind { @@ -276,6 +280,8 @@ impl DatasetKind { Self::Nfcorpus => "nfcorpus", Self::Quora => "quora", Self::TrecCovid => "trec-covid", + Self::Scifact => "scifact", + Self::NqBeir => "nq-beir", } } @@ -290,6 +296,8 @@ impl DatasetKind { Self::Nfcorpus => "NFCorpus (BEIR)", Self::Quora => "Quora (IR)", Self::TrecCovid => "TREC-COVID (BEIR)", + Self::Scifact => "SciFact (BEIR)", + Self::NqBeir => "Natural Questions (BEIR)", } } @@ -304,6 +312,8 @@ impl DatasetKind { Self::Nfcorpus => "NFCorpus", Self::Quora => "Quora", Self::TrecCovid => "TREC-COVID", + Self::Scifact => "SciFact", + Self::NqBeir => "Natural Questions", } } @@ -318,6 +328,8 @@ impl DatasetKind { Self::Nfcorpus => "NFCorpus", Self::Quora => "Quora", Self::TrecCovid => "TREC-COVID", + Self::Scifact => "SciFact", + Self::NqBeir => "Natural Questions", } } @@ -332,6 +344,8 @@ impl DatasetKind { Self::Nfcorpus => "nfcorpus", Self::Quora => "quora", Self::TrecCovid => "trec-covid", + Self::Scifact => "scifact", + Self::NqBeir => "nq-beir", } } @@ -376,20 +390,24 @@ impl FromStr for DatasetKind { "nfcorpus" | "nf-corpus" => Ok(Self::Nfcorpus), "quora" => Ok(Self::Quora), "trec-covid" | "treccovid" | "trec_covid" => Ok(Self::TrecCovid), + "scifact" => Ok(Self::Scifact), + "nq-beir" 
| "natural-questions-beir" => Ok(Self::NqBeir), other => { - anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid.") + anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid, scifact, nq-beir.") } } } } -pub const BEIR_DATASETS: [DatasetKind; 6] = [ +pub const BEIR_DATASETS: [DatasetKind; 8] = [ DatasetKind::Fever, DatasetKind::Fiqa, DatasetKind::HotpotQa, DatasetKind::Nfcorpus, DatasetKind::Quora, DatasetKind::TrecCovid, + DatasetKind::Scifact, + DatasetKind::NqBeir, ]; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -481,7 +499,9 @@ pub fn convert( | DatasetKind::HotpotQa | DatasetKind::Nfcorpus | DatasetKind::Quora - | DatasetKind::TrecCovid => beir::convert_beir(raw_path, dataset)?, + | DatasetKind::TrecCovid + | DatasetKind::Scifact + | DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?, }; let metadata_limit = match dataset { @@ -489,13 +509,26 @@ pub fn convert( _ => context_token_limit, }; + let generated_at = match dataset { + DatasetKind::Beir + | DatasetKind::Fever + | DatasetKind::Fiqa + | DatasetKind::HotpotQa + | DatasetKind::Nfcorpus + | DatasetKind::Quora + | DatasetKind::TrecCovid + | DatasetKind::Scifact + | DatasetKind::NqBeir => base_timestamp(), + _ => Utc::now(), + }; + let source_label = match dataset { DatasetKind::Beir => "beir-mix".to_string(), _ => raw_path.display().to_string(), }; Ok(ConvertedDataset { - generated_at: Utc::now(), + generated_at, metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit), source: source_label, paragraphs, diff --git a/eval/src/eval/pipeline/stages/run_queries.rs b/eval/src/eval/pipeline/stages/run_queries.rs index b2a6f08..37edefc 100644 --- a/eval/src/eval/pipeline/stages/run_queries.rs +++ b/eval/src/eval/pipeline/stages/run_queries.rs @@ -63,6 +63,21 @@ pub(crate) async fn run_queries( if 
let Some(value) = config.retrieval.chunk_fts_take { retrieval_config.tuning.chunk_fts_take = value; } + if let Some(value) = config.retrieval.chunk_rrf_k { + retrieval_config.tuning.chunk_rrf_k = value; + } + if let Some(value) = config.retrieval.chunk_rrf_vector_weight { + retrieval_config.tuning.chunk_rrf_vector_weight = value; + } + if let Some(value) = config.retrieval.chunk_rrf_fts_weight { + retrieval_config.tuning.chunk_rrf_fts_weight = value; + } + if let Some(value) = config.retrieval.chunk_rrf_use_vector { + retrieval_config.tuning.chunk_rrf_use_vector = value; + } + if let Some(value) = config.retrieval.chunk_rrf_use_fts { + retrieval_config.tuning.chunk_rrf_use_fts = value; + } if let Some(value) = config.retrieval.chunk_avg_chars_per_token { retrieval_config.tuning.avg_chars_per_token = value; } @@ -93,6 +108,11 @@ pub(crate) async fn run_queries( rerank_keep_top = config.retrieval.rerank_keep_top, chunk_vector_take = effective_chunk_vector, chunk_fts_take = effective_chunk_fts, + chunk_rrf_k = active_tuning.chunk_rrf_k, + chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight, + chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight, + chunk_rrf_use_vector = active_tuning.chunk_rrf_use_vector, + chunk_rrf_use_fts = active_tuning.chunk_rrf_use_fts, embedding_backend = ctx.embedding_provider().backend_label(), embedding_model = ctx .embedding_provider() diff --git a/eval/src/eval/pipeline/stages/summarize.rs b/eval/src/eval/pipeline/stages/summarize.rs index 2da9b9d..e4cee7d 100644 --- a/eval/src/eval/pipeline/stages/summarize.rs +++ b/eval/src/eval/pipeline/stages/summarize.rs @@ -202,6 +202,11 @@ pub(crate) async fn summarize( detailed_report: config.detailed_report, retrieval_strategy: config.retrieval.strategy.to_string(), chunk_result_cap: config.retrieval.chunk_result_cap, + chunk_rrf_k: active_tuning.chunk_rrf_k, + chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight, + chunk_rrf_fts_weight: 
active_tuning.chunk_rrf_fts_weight, + chunk_rrf_use_vector: active_tuning.chunk_rrf_use_vector, + chunk_rrf_use_fts: active_tuning.chunk_rrf_use_fts, ingest_chunk_min_tokens: config.ingest_chunk_min_tokens, ingest_chunk_max_tokens: config.ingest_chunk_max_tokens, ingest_chunks_only: config.ingest_chunks_only, diff --git a/eval/src/eval/types.rs b/eval/src/eval/types.rs index f599b27..0a2d21d 100644 --- a/eval/src/eval/types.rs +++ b/eval/src/eval/types.rs @@ -70,6 +70,11 @@ pub struct EvaluationSummary { pub detailed_report: bool, pub retrieval_strategy: String, pub chunk_result_cap: usize, + pub chunk_rrf_k: f32, + pub chunk_rrf_vector_weight: f32, + pub chunk_rrf_fts_weight: f32, + pub chunk_rrf_use_vector: bool, + pub chunk_rrf_use_fts: bool, pub ingest_chunk_min_tokens: usize, pub ingest_chunk_max_tokens: usize, pub ingest_chunks_only: bool, diff --git a/eval/src/ingest/orchestrator.rs b/eval/src/ingest/orchestrator.rs index 5f26a96..da7ce3b 100644 --- a/eval/src/ingest/orchestrator.rs +++ b/eval/src/ingest/orchestrator.rs @@ -373,6 +373,20 @@ pub async fn ensure_corpus( let reused_ingestion = ingested_count == 0 && !cache.force_refresh; let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only; + info!( + dataset = %dataset.metadata.id, + slice = %slice.manifest.slice_id, + fingerprint = %ingestion_fingerprint, + reused_ingestion, + reused_embeddings, + positive_reused = stats.positive_reused, + positive_ingested = stats.positive_ingested, + negative_reused = stats.negative_reused, + negative_ingested = stats.negative_ingested, + shard_dir = %base_dir.display(), + "Corpus cache outcome" + ); + let handle = CorpusHandle { manifest, path: base_dir, diff --git a/eval/src/ingest/store.rs b/eval/src/ingest/store.rs index 5256fec..14061b0 100644 --- a/eval/src/ingest/store.rs +++ b/eval/src/ingest/store.rs @@ -22,7 +22,7 @@ use common::storage::{ use serde::Deserialize; use serde::Serialize; use surrealdb::sql::Thing; -use tracing::warn; +use 
tracing::{debug, warn}; use crate::datasets::{ConvertedParagraph, ConvertedQuestion}; @@ -440,6 +440,12 @@ impl ParagraphShardStore { .with_context(|| format!("parsing shard {}", path.display()))?; if shard.ingestion_fingerprint != fingerprint { + debug!( + path = %path.display(), + expected = fingerprint, + found = shard.ingestion_fingerprint, + "Shard fingerprint mismatch; will rebuild" + ); return Ok(None); } if shard.version != PARAGRAPH_SHARD_VERSION { diff --git a/eval/src/perf.rs b/eval/src/perf.rs index e4e61b5..c5e4573 100644 --- a/eval/src/perf.rs +++ b/eval/src/perf.rs @@ -197,6 +197,11 @@ mod tests { detailed_report: false, retrieval_strategy: "initial".into(), chunk_result_cap: 5, + chunk_rrf_k: 60.0, + chunk_rrf_vector_weight: 1.0, + chunk_rrf_fts_weight: 1.0, + chunk_rrf_use_vector: true, + chunk_rrf_use_fts: true, ingest_chunk_min_tokens: 256, ingest_chunk_max_tokens: 512, ingest_chunks_only: false, diff --git a/eval/src/report.rs b/eval/src/report.rs index 8084c3a..abc7c55 100644 --- a/eval/src/report.rs +++ b/eval/src/report.rs @@ -88,6 +88,16 @@ pub struct RetrievalSection { pub rerank_pool_size: Option, pub rerank_keep_top: usize, pub chunk_result_cap: usize, + #[serde(default = "default_chunk_rrf_k")] + pub chunk_rrf_k: f32, + #[serde(default = "default_chunk_rrf_weight")] + pub chunk_rrf_vector_weight: f32, + #[serde(default = "default_chunk_rrf_weight")] + pub chunk_rrf_fts_weight: f32, + #[serde(default = "default_chunk_rrf_use")] + pub chunk_rrf_use_vector: bool, + #[serde(default = "default_chunk_rrf_use")] + pub chunk_rrf_use_fts: bool, #[serde(default)] pub chunk_vector_take: usize, #[serde(default)] @@ -98,6 +108,18 @@ pub struct RetrievalSection { pub ingest_chunks_only: bool, } +const fn default_chunk_rrf_k() -> f32 { + 60.0 +} + +const fn default_chunk_rrf_weight() -> f32 { + 1.0 +} + +const fn default_chunk_rrf_use() -> bool { + true +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmSection { pub cases: usize, @@ 
-206,6 +228,11 @@ impl EvaluationReport { rerank_pool_size: summary.rerank_pool_size, rerank_keep_top: summary.rerank_keep_top, chunk_result_cap: summary.chunk_result_cap, + chunk_rrf_k: summary.chunk_rrf_k, + chunk_rrf_vector_weight: summary.chunk_rrf_vector_weight, + chunk_rrf_fts_weight: summary.chunk_rrf_fts_weight, + chunk_rrf_use_vector: summary.chunk_rrf_use_vector, + chunk_rrf_use_fts: summary.chunk_rrf_use_fts, chunk_vector_take: summary.chunk_vector_take, chunk_fts_take: summary.chunk_fts_take, ingest_chunk_min_tokens: summary.ingest_chunk_min_tokens, @@ -856,6 +883,11 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport { rerank_pool_size: entry.rerank_pool_size, rerank_keep_top: entry.rerank_keep_top, chunk_result_cap: entry.chunk_result_cap.unwrap_or(5), + chunk_rrf_k: default_chunk_rrf_k(), + chunk_rrf_vector_weight: default_chunk_rrf_weight(), + chunk_rrf_fts_weight: default_chunk_rrf_weight(), + chunk_rrf_use_vector: default_chunk_rrf_use(), + chunk_rrf_use_fts: default_chunk_rrf_use(), chunk_vector_take: 0, chunk_fts_take: 0, ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256), @@ -1098,6 +1130,11 @@ mod tests { detailed_report: true, retrieval_strategy: "initial".into(), chunk_result_cap: 5, + chunk_rrf_k: 60.0, + chunk_rrf_vector_weight: 1.0, + chunk_rrf_fts_weight: 1.0, + chunk_rrf_use_vector: true, + chunk_rrf_use_fts: true, ingest_chunk_min_tokens: 256, ingest_chunk_max_tokens: 512, ingest_chunk_overlap_tokens: 50, diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index c12ff02..42f3c50 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -72,6 +72,21 @@ pub struct RetrievalTuning { pub normalize_vector_scores: bool, /// Normalize FTS (BM25) scores before fusion (default: true) pub normalize_fts_scores: bool, + /// Reciprocal rank fusion k value for chunk merging in Revised strategy. 
+ #[serde(default = "default_chunk_rrf_k")] + pub chunk_rrf_k: f32, + /// Weight applied to vector ranks in RRF. + #[serde(default = "default_chunk_rrf_vector_weight")] + pub chunk_rrf_vector_weight: f32, + /// Weight applied to chunk FTS ranks in RRF. + #[serde(default = "default_chunk_rrf_fts_weight")] + pub chunk_rrf_fts_weight: f32, + /// Whether to include vector rankings in RRF. + #[serde(default = "default_chunk_rrf_use_vector")] + pub chunk_rrf_use_vector: bool, + /// Whether to include chunk FTS rankings in RRF. + #[serde(default = "default_chunk_rrf_use_fts")] + pub chunk_rrf_use_fts: bool, } impl Default for RetrievalTuning { @@ -102,6 +117,11 @@ impl Default for RetrievalTuning { normalize_vector_scores: false, // FTS scores (BM25) are unbounded, normalization helps more normalize_fts_scores: true, + chunk_rrf_k: default_chunk_rrf_k(), + chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(), + chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(), + chunk_rrf_use_vector: default_chunk_rrf_use_vector(), + chunk_rrf_use_fts: default_chunk_rrf_use_fts(), } } } @@ -156,3 +176,23 @@ impl Default for RetrievalConfig { } } } + +const fn default_chunk_rrf_k() -> f32 { + 60.0 +} + +const fn default_chunk_rrf_vector_weight() -> f32 { + 1.0 +} + +const fn default_chunk_rrf_fts_weight() -> f32 { + 1.0 +} + +const fn default_chunk_rrf_use_vector() -> bool { + true +} + +const fn default_chunk_rrf_use_fts() -> bool { + true +} diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs index af082a9..587ec52 100644 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -21,8 +21,8 @@ use crate::{ graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, reranking::RerankerLease, scoring::{ - clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc, - FusionWeights, Scored, + clamp_unit, fuse_scores, merge_scored_by_id, 
min_max_normalize, reciprocal_rank_fusion, + sort_by_fused_desc, FusionWeights, RrfConfig, Scored, }, RetrievedChunk, RetrievedEntity, }; @@ -593,8 +593,9 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), debug!("Collecting vector chunk candidates for revised strategy"); let embedding = ctx.ensure_embedding()?.clone(); let tuning = &ctx.config.tuning; - let weights = tuning.fusion_weights.unwrap_or_else(FusionWeights::default); let fts_take = tuning.chunk_fts_take; + let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text); + let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty(); let (vector_rows, fts_rows) = tokio::try_join!( TextChunk::vector_search( @@ -604,35 +605,42 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), &ctx.user_id, ), async { - if fts_take == 0 { - Ok(Vec::new()) + if fts_enabled { + TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await } else { - TextChunk::fts_search(fts_take, &ctx.input_text, ctx.db_client, &ctx.user_id).await + Ok(Vec::new()) } } )?; - let mut merged: HashMap> = HashMap::new(); let vector_candidates = vector_rows.len(); let fts_candidates = fts_rows.len(); - // Collect vector results let vector_scored: Vec> = vector_rows .into_iter() .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) .collect(); - // Collect FTS results let fts_scored: Vec> = fts_rows .into_iter() .map(|row| Scored::new(row.chunk).with_fts_score(row.score)) .collect(); - // Merge by ID first (before normalization) - merge_scored_by_id(&mut merged, vector_scored); - merge_scored_by_id(&mut merged, fts_scored); + let mut fts_weight = tuning.chunk_rrf_fts_weight; + if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 { + // For very short keyword queries, lean more on lexical ranking. 
+ fts_weight *= 1.5; + } - let mut vector_chunks: Vec> = merged.into_values().collect(); + let rrf_config = RrfConfig { + k: tuning.chunk_rrf_k, + vector_weight: tuning.chunk_rrf_vector_weight, + fts_weight, + use_vector: tuning.chunk_rrf_use_vector, + use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0, + }; + + let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config); debug!( total_merged = vector_chunks.len(), @@ -648,58 +656,24 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), .iter() .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) .count(), - "Merged chunk candidates before normalization" + rrf_k = rrf_config.k, + rrf_vector_weight = rrf_config.vector_weight, + rrf_fts_weight = rrf_config.fts_weight, + "Merged chunk candidates with RRF" ); - // Normalize scores AFTER merging, on the final merged set - // This ensures we normalize all vector scores together and all FTS scores together - // for the actual candidates that will be fused - if tuning.normalize_vector_scores && !vector_chunks.is_empty() { - let before_sample: Vec = vector_chunks - .iter() - .filter_map(|c| c.scores.vector) - .take(5) - .collect(); - normalize_vector_scores(&mut vector_chunks); - let after_sample: Vec = vector_chunks - .iter() - .filter_map(|c| c.scores.vector) - .take(5) - .collect(); - debug!( - vector_before = ?before_sample, - vector_after = ?after_sample, - "Vector score normalization applied" - ); - } - if tuning.normalize_fts_scores && !vector_chunks.is_empty() { - let before_sample: Vec = vector_chunks - .iter() - .filter_map(|c| c.scores.fts) - .take(5) - .collect(); - normalize_fts_scores_in_merged(&mut vector_chunks); - let after_sample: Vec = vector_chunks - .iter() - .filter_map(|c| c.scores.fts) - .take(5) - .collect(); - debug!( - fts_before = ?before_sample, - fts_after = ?after_sample, - "FTS score normalization applied" - ); - } + // let fts_only_count = vector_chunks + // .iter() + // 
.filter(|c| c.scores.vector.is_none()) + // .count(); + // let both_count = vector_chunks + // .iter() + // .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) + // .count(); - // Fuse scores after normalization - for scored in &mut vector_chunks { - let fused = fuse_scores(&scored.scores, weights); - scored.update_fused(fused); - } - - // Filter out FTS-only chunks if they're likely to be low quality - // (when overlap is low, FTS-only chunks are usually noise) - // Always keep chunks with vector scores (vector-only or both signals) + // If we have very low overlap (few chunks with both signals), drop FTS-only chunks. + // These are often noisy on keyword-heavy datasets and dilute strong vector hits. + // Keep vector-only and “golden” (vector+FTS) chunks. let fts_only_count = vector_chunks .iter() .filter(|c| c.scores.vector.is_none()) @@ -708,10 +682,6 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), .iter() .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) .count(); - - // If we have very low overlap (few chunks with both signals), filter out FTS-only chunks - // They're likely diluting the good vector results - // This preserves vector-only chunks and golden chunks (both signals) if fts_only_count > 0 && both_count < 3 { let before_filter = vector_chunks.len(); vector_chunks.retain(|c| c.scores.vector.is_some()); @@ -724,9 +694,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), } debug!( - fusion_weights = ?weights, top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::>(), - "Fused scores after normalization" + "Fused scores after RRF ordering" ); if ctx.diagnostics_enabled() { @@ -797,11 +766,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { .min(ctx.config.tuning.chunk_vector_take.max(1)); if chunk_values.len() > limit { - println!( - "We removed chunks! 
we had {:?}, now going for {:?}", - chunk_values.len(), - limit - ); chunk_values.truncate(limit); } @@ -966,87 +930,24 @@ fn normalize_fts_scores(results: &mut [Scored]) { } } -fn normalize_vector_scores(results: &mut [Scored]) { - // Only normalize scores for items that actually have vector scores - let items_with_scores: Vec<(usize, f32)> = results - .iter() - .enumerate() - .filter_map(|(idx, candidate)| candidate.scores.vector.map(|score| (idx, score))) - .collect(); - - if items_with_scores.len() < 2 { - // Don't normalize if we have 0 or 1 scores - nothing to normalize against - return; +fn normalize_fts_query(input: &str) -> (String, usize) { + const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"]; + let mut cleaned = String::with_capacity(input.len()); + for ch in input.chars() { + if ch.is_alphanumeric() { + cleaned.extend(ch.to_lowercase()); + } else if ch.is_whitespace() { + cleaned.push(' '); + } } - - let raw_scores: Vec = items_with_scores.iter().map(|(_, score)| *score).collect(); - - // For cosine similarity scores (already in [0,1]), use a gentler normalization - // that preserves more of the original distribution - // Only normalize if the range is significant (more than 0.1 difference) - let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b)); - let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b)); - let range = max - min; - - if range < 0.1 { - // Scores are too similar, don't normalize (would compress too much) - debug!( - vector_score_range = range, - min = min, - max = max, - "Skipping vector normalization - scores too similar" - ); - return; - } - - let normalized = min_max_normalize(&raw_scores); - - for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) { - results[*idx].scores.vector = Some(normalized_score); - results[*idx].update_fused(0.0); - } -} - -fn normalize_fts_scores_in_merged(results: &mut [Scored]) { - // Only normalize scores for items that 
actually have FTS scores - let items_with_scores: Vec<(usize, f32)> = results - .iter() - .enumerate() - .filter_map(|(idx, candidate)| candidate.scores.fts.map(|score| (idx, score))) - .collect(); - - if items_with_scores.len() < 2 { - // Don't normalize if we have 0 or 1 scores - nothing to normalize against - // Single FTS score would become 1.0, which doesn't help - return; - } - - let raw_scores: Vec = items_with_scores.iter().map(|(_, score)| *score).collect(); - - // BM25 scores can be negative or very high, so normalization is more important - // But check if we have enough variation to normalize - let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b)); - let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b)); - let range = max - min; - - // For BM25, even small differences can be meaningful, but if all scores are - // very similar, normalization won't help - if range < 0.01 { - debug!( - fts_score_range = range, - min = min, - max = max, - "Skipping FTS normalization - scores too similar" - ); - return; - } - - let normalized = min_max_normalize(&raw_scores); - - for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) { - results[*idx].scores.fts = Some(normalized_score); - results[*idx].update_fused(0.0); + let mut tokens = Vec::new(); + for token in cleaned.split_whitespace() { + if !STOPWORDS.contains(&token) && !token.is_empty() { + tokens.push(token.to_string()); + } } + let normalized = tokens.join(" "); + (normalized, tokens.len()) } fn apply_fusion(candidates: &mut HashMap>, weights: FusionWeights) diff --git a/retrieval-pipeline/src/scoring.rs b/retrieval-pipeline/src/scoring.rs index 458709d..8fce2e7 100644 --- a/retrieval-pipeline/src/scoring.rs +++ b/retrieval-pipeline/src/scoring.rs @@ -1,4 +1,4 @@ -use std::cmp::Ordering; +use std::{cmp::Ordering, collections::HashMap}; use common::storage::types::StoredObject; use serde::{Deserialize, Serialize}; @@ -71,6 +71,28 @@ impl Default for FusionWeights { 
} } +/// Configuration for reciprocal rank fusion. +#[derive(Debug, Clone, Copy)] +pub struct RrfConfig { + pub k: f32, + pub vector_weight: f32, + pub fts_weight: f32, + pub use_vector: bool, + pub use_fts: bool, +} + +impl Default for RrfConfig { + fn default() -> Self { + Self { + k: 60.0, + vector_weight: 1.0, + fts_weight: 1.0, + use_vector: true, + use_fts: true, + } + } +} + pub const fn clamp_unit(value: f32) -> f32 { value.clamp(0.0, 1.0) } @@ -196,3 +218,83 @@ where .then_with(|| a.item.get_id().cmp(b.item.get_id())) }); } + +pub fn reciprocal_rank_fusion<T>( + mut vector_ranked: Vec<Scored<T>>, + mut fts_ranked: Vec<Scored<T>>, + config: RrfConfig, +) -> Vec<Scored<T>> +where + T: StoredObject + Clone, +{ + let mut merged: HashMap<String, Scored<T>> = HashMap::new(); + let k = if config.k <= 0.0 { 60.0 } else { config.k }; + let vector_weight = if config.vector_weight.is_finite() { + config.vector_weight.max(0.0) + } else { + 0.0 + }; + let fts_weight = if config.fts_weight.is_finite() { + config.fts_weight.max(0.0) + } else { + 0.0 + }; + + if config.use_vector && !vector_ranked.is_empty() { + vector_ranked.sort_by(|a, b| { + let a_score = a.scores.vector.unwrap_or(0.0); + let b_score = b.scores.vector.unwrap_or(0.0); + b_score + .partial_cmp(&a_score) + .unwrap_or(Ordering::Equal) + .then_with(|| a.item.get_id().cmp(b.item.get_id())) + }); + + for (rank, candidate) in vector_ranked.into_iter().enumerate() { + let id = candidate.item.get_id().to_owned(); + let entry = merged + .entry(id.clone()) + .or_insert_with(|| Scored::new(candidate.item.clone())); + + if let Some(score) = candidate.scores.vector { + let existing = entry.scores.vector.unwrap_or(f32::MIN); + if score > existing { + entry.scores.vector = Some(score); + } + } + entry.item = candidate.item; + entry.fused += vector_weight / (k + rank as f32 + 1.0); + } + } + + if config.use_fts && !fts_ranked.is_empty() { + fts_ranked.sort_by(|a, b| { + let a_score = a.scores.fts.unwrap_or(0.0); + let b_score = b.scores.fts.unwrap_or(0.0); + b_score + 
.partial_cmp(&a_score) + .unwrap_or(Ordering::Equal) + .then_with(|| a.item.get_id().cmp(b.item.get_id())) + }); + + for (rank, candidate) in fts_ranked.into_iter().enumerate() { + let id = candidate.item.get_id().to_owned(); + let entry = merged + .entry(id.clone()) + .or_insert_with(|| Scored::new(candidate.item.clone())); + + if let Some(score) = candidate.scores.fts { + let existing = entry.scores.fts.unwrap_or(f32::MIN); + if score > existing { + entry.scores.fts = Some(score); + } + } + entry.item = candidate.item; + entry.fused += fts_weight / (k + rank as f32 + 1.0); + } + } + + let mut fused: Vec<Scored<T>> = merged.into_values().collect(); + sort_by_fused_desc(&mut fused); + fused +}