mirror of https://github.com/perstarkse/minne.git (synced 2026-04-17 22:49:43 +02:00)
retrieval: hybrid search, linear fusion
@@ -1,6 +1,8 @@
 use serde::{Deserialize, Serialize};
 use std::fmt;
 
+use crate::scoring::FusionWeights;
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
 #[serde(rename_all = "snake_case")]
 pub enum RetrievalStrategy {
@@ -64,6 +66,12 @@ pub struct RetrievalTuning {
     pub rerank_scores_only: bool,
     pub rerank_keep_top: usize,
     pub chunk_result_cap: usize,
+    /// Optional fusion weights for hybrid search. If None, uses default weights.
+    pub fusion_weights: Option<FusionWeights>,
+    /// Normalize vector similarity scores before fusion (default: false)
+    pub normalize_vector_scores: bool,
+    /// Normalize FTS (BM25) scores before fusion (default: true)
+    pub normalize_fts_scores: bool,
 }
 
 impl Default for RetrievalTuning {
@@ -88,6 +96,12 @@ impl Default for RetrievalTuning {
             rerank_scores_only: false,
             rerank_keep_top: 8,
             chunk_result_cap: 5,
+            fusion_weights: None,
+            // Vector scores (cosine similarity) are already in [0,1] range
+            // Normalization only helps when there's significant variation
+            normalize_vector_scores: false,
+            // FTS scores (BM25) are unbounded, normalization helps more
+            normalize_fts_scores: true,
         }
     }
 }
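For illustration only (not part of this commit): the new tuning fields could be overridden like this, using only the field names shown above, with everything else falling back to the defaults.

// Illustrative sketch only: override the new hybrid-search knobs while
// keeping every other tuning value at its default.
let tuning = RetrievalTuning {
    fusion_weights: Some(FusionWeights {
        vector: 0.7,
        fts: 0.3,
        graph: 0.2,
        multi_bonus: 0.1,
    }),
    normalize_vector_scores: true,
    normalize_fts_scores: true,
    ..RetrievalTuning::default()
};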
@@ -593,38 +593,158 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
     debug!("Collecting vector chunk candidates for revised strategy");
     let embedding = ctx.ensure_embedding()?.clone();
     let tuning = &ctx.config.tuning;
-    let weights = FusionWeights::default();
+    let weights = tuning.fusion_weights.unwrap_or_else(FusionWeights::default);
+    let fts_take = tuning.chunk_fts_take;
 
-    let mut vector_chunks: Vec<Scored<TextChunk>> = TextChunk::vector_search(
-        tuning.chunk_vector_take,
-        embedding,
-        ctx.db_client,
-        &ctx.user_id,
-    )
-    .await?
-    .into_iter()
-    .map(|row| {
-        let mut scored = Scored::new(row.chunk).with_vector_score(row.score);
+    let (vector_rows, fts_rows) = tokio::try_join!(
+        TextChunk::vector_search(
+            tuning.chunk_vector_take,
+            embedding,
+            ctx.db_client,
+            &ctx.user_id,
+        ),
+        async {
+            if fts_take == 0 {
+                Ok(Vec::new())
+            } else {
+                TextChunk::fts_search(fts_take, &ctx.input_text, ctx.db_client, &ctx.user_id).await
+            }
+        }
+    )?;
+
+    let mut merged: HashMap<String, Scored<TextChunk>> = HashMap::new();
+    let vector_candidates = vector_rows.len();
+    let fts_candidates = fts_rows.len();
+
+    // Collect vector results
+    let vector_scored: Vec<Scored<TextChunk>> = vector_rows
+        .into_iter()
+        .map(|row| Scored::new(row.chunk).with_vector_score(row.score))
+        .collect();
+
+    // Collect FTS results
+    let fts_scored: Vec<Scored<TextChunk>> = fts_rows
+        .into_iter()
+        .map(|row| Scored::new(row.chunk).with_fts_score(row.score))
+        .collect();
+
+    // Merge by ID first (before normalization)
+    merge_scored_by_id(&mut merged, vector_scored);
+    merge_scored_by_id(&mut merged, fts_scored);
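merge_scored_by_id is called here but defined elsewhere in the crate; a minimal sketch of such a helper follows (not part of this commit), where the id accessor candidate.item.id is an assumption rather than the real API.

// Illustrative sketch only: merge candidates from both searches by chunk id,
// keeping whichever subscores each side contributed.
use std::collections::hash_map::Entry;

fn merge_scored_by_id_sketch(
    merged: &mut HashMap<String, Scored<TextChunk>>,
    incoming: Vec<Scored<TextChunk>>,
) {
    for candidate in incoming {
        // Assumed: Scored exposes the wrapped chunk and its id; the real accessor may differ.
        let key = candidate.item.id.clone();
        match merged.entry(key) {
            // Same chunk seen from the other search: fold its subscores into one entry.
            Entry::Occupied(mut slot) => {
                let existing = slot.get_mut();
                existing.scores.vector = existing.scores.vector.or(candidate.scores.vector);
                existing.scores.fts = existing.scores.fts.or(candidate.scores.fts);
            }
            // First time we see this chunk: insert it as-is.
            Entry::Vacant(slot) => {
                slot.insert(candidate);
            }
        }
    }
}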
+
+    let mut vector_chunks: Vec<Scored<TextChunk>> = merged.into_values().collect();
+
+    debug!(
+        total_merged = vector_chunks.len(),
+        vector_only = vector_chunks
+            .iter()
+            .filter(|c| c.scores.fts.is_none())
+            .count(),
+        fts_only = vector_chunks
+            .iter()
+            .filter(|c| c.scores.vector.is_none())
+            .count(),
+        both_signals = vector_chunks
+            .iter()
+            .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
+            .count(),
+        "Merged chunk candidates before normalization"
+    );
+
+    // Normalize scores AFTER merging, on the final merged set
+    // This ensures we normalize all vector scores together and all FTS scores together
+    // for the actual candidates that will be fused
+    if tuning.normalize_vector_scores && !vector_chunks.is_empty() {
+        let before_sample: Vec<f32> = vector_chunks
+            .iter()
+            .filter_map(|c| c.scores.vector)
+            .take(5)
+            .collect();
+        normalize_vector_scores(&mut vector_chunks);
+        let after_sample: Vec<f32> = vector_chunks
+            .iter()
+            .filter_map(|c| c.scores.vector)
+            .take(5)
+            .collect();
+        debug!(
+            vector_before = ?before_sample,
+            vector_after = ?after_sample,
+            "Vector score normalization applied"
+        );
+    }
+    if tuning.normalize_fts_scores && !vector_chunks.is_empty() {
+        let before_sample: Vec<f32> = vector_chunks
+            .iter()
+            .filter_map(|c| c.scores.fts)
+            .take(5)
+            .collect();
+        normalize_fts_scores_in_merged(&mut vector_chunks);
+        let after_sample: Vec<f32> = vector_chunks
+            .iter()
+            .filter_map(|c| c.scores.fts)
+            .take(5)
+            .collect();
+        debug!(
+            fts_before = ?before_sample,
+            fts_after = ?after_sample,
+            "FTS score normalization applied"
+        );
+    }
+
+    // Fuse scores after normalization
+    for scored in &mut vector_chunks {
         let fused = fuse_scores(&scored.scores, weights);
         scored.update_fused(fused);
-        scored
-    })
-    .collect();
+    }
+
+    // Filter out FTS-only chunks if they're likely to be low quality
+    // (when overlap is low, FTS-only chunks are usually noise)
+    // Always keep chunks with vector scores (vector-only or both signals)
+    let fts_only_count = vector_chunks
+        .iter()
+        .filter(|c| c.scores.vector.is_none())
+        .count();
+    let both_count = vector_chunks
+        .iter()
+        .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
+        .count();
+
+    // If we have very low overlap (few chunks with both signals), filter out FTS-only chunks
+    // They're likely diluting the good vector results
+    // This preserves vector-only chunks and golden chunks (both signals)
+    if fts_only_count > 0 && both_count < 3 {
+        let before_filter = vector_chunks.len();
+        vector_chunks.retain(|c| c.scores.vector.is_some());
+        let after_filter = vector_chunks.len();
+        debug!(
+            fts_only_filtered = before_filter - after_filter,
+            both_signals_preserved = both_count,
+            "Filtered out FTS-only chunks due to low overlap, preserved golden chunks"
+        );
+    }
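A quick numeric illustration of this overlap heuristic:

// Example: 10 merged candidates, 6 vector-only, 3 FTS-only, 1 with both signals.
// both_count = 1 < 3 and fts_only_count = 3 > 0, so the 3 FTS-only chunks are dropped
// and 7 candidates remain (the 6 vector-only ones plus the golden chunk).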
 
+    debug!(
+        fusion_weights = ?weights,
+        top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::<Vec<_>>(),
+        "Fused scores after normalization"
+    );
+
     if ctx.diagnostics_enabled() {
         ctx.record_collect_candidates(CollectCandidatesStats {
             vector_entity_candidates: 0,
-            vector_chunk_candidates: vector_chunks.len(),
+            vector_chunk_candidates: vector_candidates,
             fts_entity_candidates: 0,
-            fts_chunk_candidates: 0,
+            fts_chunk_candidates: fts_candidates,
             vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
                 chunk.scores.vector.unwrap_or(0.0)
             }),
-            fts_chunk_scores: Vec::new(),
+            fts_chunk_scores: sample_scores(&vector_chunks, |chunk| {
+                chunk.scores.fts.unwrap_or(0.0)
+            }),
         });
     }
 
-    vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
+    sort_by_fused_desc(&mut vector_chunks);
     ctx.revised_chunk_values = vector_chunks;
 
     Ok(())
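sort_by_fused_desc is not shown in this diff; since it replaces the inline sort above, an equivalent sketch (not part of this commit) would be:

// Illustrative sketch only: descending sort by the fused score, mirroring the removed line.
fn sort_by_fused_desc_sketch<T>(results: &mut [Scored<T>]) {
    results.sort_by(|a, b| {
        b.fused
            .partial_cmp(&a.fused)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
}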
@@ -668,13 +788,6 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
 pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
     debug!("Assembling chunk-only retrieval results");
     let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
-    let question_terms = extract_keywords(&ctx.input_text);
-    rank_chunks_by_combined_score(
-        &mut chunk_values,
-        &question_terms,
-        ctx.config.tuning.lexical_match_weight,
-    );
-
     // Limit how many chunks we return to keep context size reasonable.
     let limit = ctx
         .config
@@ -682,7 +795,13 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
         .chunk_result_cap
         .max(1)
         .min(ctx.config.tuning.chunk_vector_take.max(1));
 
+    if chunk_values.len() > limit {
+        println!(
+            "We removed chunks! we had {:?}, now going for {:?}",
+            chunk_values.len(),
+            limit
+        );
+        chunk_values.truncate(limit);
+    }
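A worked example of the cap above:

// Example: chunk_result_cap = 5 (the default above) and chunk_vector_take = 20 (illustrative)
// give limit = max(5, 1).min(max(20, 1)) = 5, so only the top 5 fused chunks survive;
// a chunk_vector_take of 3 would shrink the limit to 3 instead.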
@@ -847,6 +966,89 @@ fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
     }
 }
 
+fn normalize_vector_scores<T>(results: &mut [Scored<T>]) {
+    // Only normalize scores for items that actually have vector scores
+    let items_with_scores: Vec<(usize, f32)> = results
+        .iter()
+        .enumerate()
+        .filter_map(|(idx, candidate)| candidate.scores.vector.map(|score| (idx, score)))
+        .collect();
+
+    if items_with_scores.len() < 2 {
+        // Don't normalize if we have 0 or 1 scores - nothing to normalize against
+        return;
+    }
+
+    let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
+
+    // For cosine similarity scores (already in [0,1]), use a gentler normalization
+    // that preserves more of the original distribution
+    // Only normalize if the range is significant (more than 0.1 difference)
+    let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
+    let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
+    let range = max - min;
+
+    if range < 0.1 {
+        // Scores are too similar, don't normalize (would compress too much)
+        debug!(
+            vector_score_range = range,
+            min = min,
+            max = max,
+            "Skipping vector normalization - scores too similar"
+        );
+        return;
+    }
+
+    let normalized = min_max_normalize(&raw_scores);
+
+    for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
+        results[*idx].scores.vector = Some(normalized_score);
+        results[*idx].update_fused(0.0);
+    }
+}
+
+fn normalize_fts_scores_in_merged<T>(results: &mut [Scored<T>]) {
+    // Only normalize scores for items that actually have FTS scores
+    let items_with_scores: Vec<(usize, f32)> = results
+        .iter()
+        .enumerate()
+        .filter_map(|(idx, candidate)| candidate.scores.fts.map(|score| (idx, score)))
+        .collect();
+
+    if items_with_scores.len() < 2 {
+        // Don't normalize if we have 0 or 1 scores - nothing to normalize against
+        // Single FTS score would become 1.0, which doesn't help
+        return;
+    }
+
+    let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
+
+    // BM25 scores can be negative or very high, so normalization is more important
+    // But check if we have enough variation to normalize
+    let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
+    let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
+    let range = max - min;
+
+    // For BM25, even small differences can be meaningful, but if all scores are
+    // very similar, normalization won't help
+    if range < 0.01 {
+        debug!(
+            fts_score_range = range,
+            min = min,
+            max = max,
+            "Skipping FTS normalization - scores too similar"
+        );
+        return;
+    }
+
+    let normalized = min_max_normalize(&raw_scores);
+
+    for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
+        results[*idx].scores.fts = Some(normalized_score);
+        results[*idx].update_fused(0.0);
+    }
+}
+
 fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
 where
     T: StoredObject,
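min_max_normalize is called by both new functions above but is not shown in this diff; a minimal sketch consistent with those call sites (not part of this commit), plus a worked example of the range guard:

// Illustrative sketch only: map raw scores onto [0, 1]. The callers above already
// guarantee at least two scores and a non-trivial range before calling it.
fn min_max_normalize_sketch(raw: &[f32]) -> Vec<f32> {
    let min = raw.iter().fold(f32::MAX, |a, &b| a.min(b));
    let max = raw.iter().fold(f32::MIN, |a, &b| a.max(b));
    let range = max - min;
    raw.iter().map(|&s| (s - min) / range).collect()
}
// e.g. raw vector scores [0.62, 0.55, 0.48] have range 0.14 (> 0.1), so they become
// [1.0, 0.5, 0.0]; a set like [0.61, 0.58, 0.55] (range 0.06) is left untouched.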
@@ -1,6 +1,7 @@
 use std::cmp::Ordering;
 
 use common::storage::types::StoredObject;
+use serde::{Deserialize, Serialize};
 
 /// Holds optional subscores gathered from different retrieval signals.
 #[derive(Debug, Clone, Copy, Default)]
@@ -48,7 +49,7 @@ impl<T> Scored<T> {
 }
 
 /// Weights used for linear score fusion.
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
 pub struct FusionWeights {
     pub vector: f32,
     pub fts: f32,
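With Serialize and Deserialize derived, FusionWeights can round-trip through user configuration; an illustrative sketch (serde_json is assumed here purely for the example, and the graph and multi_bonus fields are the ones shown further down):

// Illustrative sketch only: read fusion weights from a config string.
let weights: FusionWeights = serde_json::from_str(
    r#"{ "vector": 0.8, "fts": 0.2, "graph": 0.2, "multi_bonus": 0.3 }"#,
)
.expect("valid fusion weights JSON");
assert!((weights.vector - 0.8).abs() < f32::EPSILON);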
@@ -58,11 +59,14 @@ pub struct FusionWeights {
 
 impl Default for FusionWeights {
     fn default() -> Self {
+        // Default weights favor vector search, which typically performs better
+        // FTS is used as a complement when there's good overlap
+        // Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
         Self {
-            vector: 0.5,
-            fts: 0.3,
+            vector: 0.8,
+            fts: 0.2,
             graph: 0.2,
-            multi_bonus: 0.02,
+            multi_bonus: 0.3, // Increased to boost chunks with both signals
         }
     }
 }
@@ -134,8 +138,19 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
         .chain(scores.fts.iter())
         .chain(scores.graph.iter())
         .count();
 
+    // Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
     if signals_present >= 2 {
-        fused += weights.multi_bonus;
+        // For chunks with both vector and FTS, give a significant boost
+        // This helps identify the "golden chunk" that appears in both searches
+        if scores.vector.is_some() && scores.fts.is_some() {
+            // Multiplicative boost: multiply by (1 + bonus) to scale with the base score
+            // This ensures high-scoring golden chunks get boosted more than low-scoring ones
+            fused = fused * (1.0 + weights.multi_bonus);
+        } else {
+            // For other multi-signal combinations (e.g., vector + graph), use additive bonus
+            fused += weights.multi_bonus;
+        }
     }
 
     clamp_unit(fused)
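A worked example with the new defaults, assuming fused starts as the weighted sum of whichever subscores are present:

// Worked example with the defaults above (vector 0.8, fts 0.2, multi_bonus 0.3):
//   golden chunk (vector 0.6, fts 0.4): 0.8*0.6 + 0.2*0.4 = 0.56, both signals present
//                                        -> 0.56 * (1.0 + 0.3) = 0.728
//   vector-only chunk (vector 0.7):      0.8*0.7 = 0.56, single signal, no bonus
// A chunk confirmed by both searches now outranks a slightly stronger vector-only hit,
// and clamp_unit keeps the boosted value inside [0, 1].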