This commit is contained in:
Per Stark
2025-12-08 20:39:12 +01:00
parent d1a6d9abdf
commit 0cb1abc6db
13 changed files with 405 additions and 160 deletions

View File

@@ -72,6 +72,21 @@ pub struct RetrievalTuning {
pub normalize_vector_scores: bool,
/// Normalize FTS (BM25) scores before fusion (default: true)
pub normalize_fts_scores: bool,
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
#[serde(default = "default_chunk_rrf_k")]
pub chunk_rrf_k: f32,
/// Weight applied to vector ranks in RRF.
#[serde(default = "default_chunk_rrf_vector_weight")]
pub chunk_rrf_vector_weight: f32,
/// Weight applied to chunk FTS ranks in RRF.
#[serde(default = "default_chunk_rrf_fts_weight")]
pub chunk_rrf_fts_weight: f32,
/// Whether to include vector rankings in RRF.
#[serde(default = "default_chunk_rrf_use_vector")]
pub chunk_rrf_use_vector: bool,
/// Whether to include chunk FTS rankings in RRF.
#[serde(default = "default_chunk_rrf_use_fts")]
pub chunk_rrf_use_fts: bool,
}
impl Default for RetrievalTuning {
@@ -102,6 +117,11 @@ impl Default for RetrievalTuning {
normalize_vector_scores: false,
// FTS scores (BM25) are unbounded, normalization helps more
normalize_fts_scores: true,
chunk_rrf_k: default_chunk_rrf_k(),
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
}
}
}
@@ -156,3 +176,23 @@ impl Default for RetrievalConfig {
}
}
}
const fn default_chunk_rrf_k() -> f32 {
60.0
}
const fn default_chunk_rrf_vector_weight() -> f32 {
1.0
}
const fn default_chunk_rrf_fts_weight() -> f32 {
1.0
}
const fn default_chunk_rrf_use_vector() -> bool {
true
}
const fn default_chunk_rrf_use_fts() -> bool {
true
}

View File

@@ -21,8 +21,8 @@ use crate::{
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
reranking::RerankerLease,
scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion,
sort_by_fused_desc, FusionWeights, RrfConfig, Scored,
},
RetrievedChunk, RetrievedEntity,
};
@@ -593,8 +593,9 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
debug!("Collecting vector chunk candidates for revised strategy");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
let weights = tuning.fusion_weights.unwrap_or_else(FusionWeights::default);
let fts_take = tuning.chunk_fts_take;
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty();
let (vector_rows, fts_rows) = tokio::try_join!(
TextChunk::vector_search(
@@ -604,35 +605,42 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
&ctx.user_id,
),
async {
if fts_take == 0 {
Ok(Vec::new())
if fts_enabled {
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
} else {
TextChunk::fts_search(fts_take, &ctx.input_text, ctx.db_client, &ctx.user_id).await
Ok(Vec::new())
}
}
)?;
let mut merged: HashMap<String, Scored<TextChunk>> = HashMap::new();
let vector_candidates = vector_rows.len();
let fts_candidates = fts_rows.len();
// Collect vector results
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.collect();
// Collect FTS results
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
.collect();
// Merge by ID first (before normalization)
merge_scored_by_id(&mut merged, vector_scored);
merge_scored_by_id(&mut merged, fts_scored);
let mut fts_weight = tuning.chunk_rrf_fts_weight;
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
// For very short keyword queries, lean more on lexical ranking.
fts_weight *= 1.5;
}
let mut vector_chunks: Vec<Scored<TextChunk>> = merged.into_values().collect();
let rrf_config = RrfConfig {
k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight,
use_vector: tuning.chunk_rrf_use_vector,
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0,
};
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
debug!(
total_merged = vector_chunks.len(),
@@ -648,58 +656,24 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
.iter()
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
.count(),
"Merged chunk candidates before normalization"
rrf_k = rrf_config.k,
rrf_vector_weight = rrf_config.vector_weight,
rrf_fts_weight = rrf_config.fts_weight,
"Merged chunk candidates with RRF"
);
// Normalize scores AFTER merging, on the final merged set
// This ensures we normalize all vector scores together and all FTS scores together
// for the actual candidates that will be fused
if tuning.normalize_vector_scores && !vector_chunks.is_empty() {
let before_sample: Vec<f32> = vector_chunks
.iter()
.filter_map(|c| c.scores.vector)
.take(5)
.collect();
normalize_vector_scores(&mut vector_chunks);
let after_sample: Vec<f32> = vector_chunks
.iter()
.filter_map(|c| c.scores.vector)
.take(5)
.collect();
debug!(
vector_before = ?before_sample,
vector_after = ?after_sample,
"Vector score normalization applied"
);
}
if tuning.normalize_fts_scores && !vector_chunks.is_empty() {
let before_sample: Vec<f32> = vector_chunks
.iter()
.filter_map(|c| c.scores.fts)
.take(5)
.collect();
normalize_fts_scores_in_merged(&mut vector_chunks);
let after_sample: Vec<f32> = vector_chunks
.iter()
.filter_map(|c| c.scores.fts)
.take(5)
.collect();
debug!(
fts_before = ?before_sample,
fts_after = ?after_sample,
"FTS score normalization applied"
);
}
// let fts_only_count = vector_chunks
// .iter()
// .filter(|c| c.scores.vector.is_none())
// .count();
// let both_count = vector_chunks
// .iter()
// .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
// .count();
// Fuse scores after normalization
for scored in &mut vector_chunks {
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
}
// Filter out FTS-only chunks if they're likely to be low quality
// (when overlap is low, FTS-only chunks are usually noise)
// Always keep chunks with vector scores (vector-only or both signals)
// If we have very low overlap (few chunks with both signals), drop FTS-only chunks.
// These are often noisy on keyword-heavy datasets and dilute strong vector hits.
// Keep vector-only and “golden” (vector+FTS) chunks.
let fts_only_count = vector_chunks
.iter()
.filter(|c| c.scores.vector.is_none())
@@ -708,10 +682,6 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
.iter()
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
.count();
// If we have very low overlap (few chunks with both signals), filter out FTS-only chunks
// They're likely diluting the good vector results
// This preserves vector-only chunks and golden chunks (both signals)
if fts_only_count > 0 && both_count < 3 {
let before_filter = vector_chunks.len();
vector_chunks.retain(|c| c.scores.vector.is_some());
@@ -724,9 +694,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
}
debug!(
fusion_weights = ?weights,
top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::<Vec<_>>(),
"Fused scores after normalization"
"Fused scores after RRF ordering"
);
if ctx.diagnostics_enabled() {
@@ -797,11 +766,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
.min(ctx.config.tuning.chunk_vector_take.max(1));
if chunk_values.len() > limit {
println!(
"We removed chunks! we had {:?}, now going for {:?}",
chunk_values.len(),
limit
);
chunk_values.truncate(limit);
}
@@ -966,87 +930,24 @@ fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
}
}
fn normalize_vector_scores<T>(results: &mut [Scored<T>]) {
// Only normalize scores for items that actually have vector scores
let items_with_scores: Vec<(usize, f32)> = results
.iter()
.enumerate()
.filter_map(|(idx, candidate)| candidate.scores.vector.map(|score| (idx, score)))
.collect();
if items_with_scores.len() < 2 {
// Don't normalize if we have 0 or 1 scores - nothing to normalize against
return;
fn normalize_fts_query(input: &str) -> (String, usize) {
const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];
let mut cleaned = String::with_capacity(input.len());
for ch in input.chars() {
if ch.is_alphanumeric() {
cleaned.extend(ch.to_lowercase());
} else if ch.is_whitespace() {
cleaned.push(' ');
}
}
let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
// For cosine similarity scores (already in [0,1]), use a gentler normalization
// that preserves more of the original distribution
// Only normalize if the range is significant (more than 0.1 difference)
let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
let range = max - min;
if range < 0.1 {
// Scores are too similar, don't normalize (would compress too much)
debug!(
vector_score_range = range,
min = min,
max = max,
"Skipping vector normalization - scores too similar"
);
return;
}
let normalized = min_max_normalize(&raw_scores);
for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
results[*idx].scores.vector = Some(normalized_score);
results[*idx].update_fused(0.0);
}
}
fn normalize_fts_scores_in_merged<T>(results: &mut [Scored<T>]) {
// Only normalize scores for items that actually have FTS scores
let items_with_scores: Vec<(usize, f32)> = results
.iter()
.enumerate()
.filter_map(|(idx, candidate)| candidate.scores.fts.map(|score| (idx, score)))
.collect();
if items_with_scores.len() < 2 {
// Don't normalize if we have 0 or 1 scores - nothing to normalize against
// Single FTS score would become 1.0, which doesn't help
return;
}
let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
// BM25 scores can be negative or very high, so normalization is more important
// But check if we have enough variation to normalize
let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
let range = max - min;
// For BM25, even small differences can be meaningful, but if all scores are
// very similar, normalization won't help
if range < 0.01 {
debug!(
fts_score_range = range,
min = min,
max = max,
"Skipping FTS normalization - scores too similar"
);
return;
}
let normalized = min_max_normalize(&raw_scores);
for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
results[*idx].scores.fts = Some(normalized_score);
results[*idx].update_fused(0.0);
let mut tokens = Vec::new();
for token in cleaned.split_whitespace() {
if !STOPWORDS.contains(&token) && !token.is_empty() {
tokens.push(token.to_string());
}
}
let normalized = tokens.join(" ");
(normalized, tokens.len())
}
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)

View File

@@ -1,4 +1,4 @@
use std::cmp::Ordering;
use std::{cmp::Ordering, collections::HashMap};
use common::storage::types::StoredObject;
use serde::{Deserialize, Serialize};
@@ -71,6 +71,28 @@ impl Default for FusionWeights {
}
}
/// Configuration for reciprocal rank fusion.
#[derive(Debug, Clone, Copy)]
pub struct RrfConfig {
pub k: f32,
pub vector_weight: f32,
pub fts_weight: f32,
pub use_vector: bool,
pub use_fts: bool,
}
impl Default for RrfConfig {
fn default() -> Self {
Self {
k: 60.0,
vector_weight: 1.0,
fts_weight: 1.0,
use_vector: true,
use_fts: true,
}
}
}
pub const fn clamp_unit(value: f32) -> f32 {
value.clamp(0.0, 1.0)
}
@@ -196,3 +218,83 @@ where
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
}
pub fn reciprocal_rank_fusion<T>(
mut vector_ranked: Vec<Scored<T>>,
mut fts_ranked: Vec<Scored<T>>,
config: RrfConfig,
) -> Vec<Scored<T>>
where
T: StoredObject + Clone,
{
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
let k = if config.k <= 0.0 { 60.0 } else { config.k };
let vector_weight = if config.vector_weight.is_finite() {
config.vector_weight.max(0.0)
} else {
0.0
};
let fts_weight = if config.fts_weight.is_finite() {
config.fts_weight.max(0.0)
} else {
0.0
};
if config.use_vector && !vector_ranked.is_empty() {
vector_ranked.sort_by(|a, b| {
let a_score = a.scores.vector.unwrap_or(0.0);
let b_score = b.scores.vector.unwrap_or(0.0);
b_score
.partial_cmp(&a_score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
let id = candidate.item.get_id().to_owned();
let entry = merged
.entry(id.clone())
.or_insert_with(|| Scored::new(candidate.item.clone()));
if let Some(score) = candidate.scores.vector {
let existing = entry.scores.vector.unwrap_or(f32::MIN);
if score > existing {
entry.scores.vector = Some(score);
}
}
entry.item = candidate.item;
entry.fused += vector_weight / (k + rank as f32 + 1.0);
}
}
if config.use_fts && !fts_ranked.is_empty() {
fts_ranked.sort_by(|a, b| {
let a_score = a.scores.fts.unwrap_or(0.0);
let b_score = b.scores.fts.unwrap_or(0.0);
b_score
.partial_cmp(&a_score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
let id = candidate.item.get_id().to_owned();
let entry = merged
.entry(id.clone())
.or_insert_with(|| Scored::new(candidate.item.clone()));
if let Some(score) = candidate.scores.fts {
let existing = entry.scores.fts.unwrap_or(f32::MIN);
if score > existing {
entry.scores.fts = Some(score);
}
}
entry.item = candidate.item;
entry.fused += fts_weight / (k + rank as f32 + 1.0);
}
}
let mut fused: Vec<Scored<T>> = merged.into_values().collect();
sort_by_fused_desc(&mut fused);
fused
}