This commit is contained in:
Per Stark
2025-12-08 20:39:12 +01:00
parent d1a6d9abdf
commit 0cb1abc6db
13 changed files with 405 additions and 160 deletions

View File

@@ -1,4 +1,4 @@
use std::cmp::Ordering;
use std::{cmp::Ordering, collections::HashMap};
use common::storage::types::StoredObject;
use serde::{Deserialize, Serialize};
@@ -71,6 +71,28 @@ impl Default for FusionWeights {
}
}
/// Configuration for reciprocal rank fusion.
#[derive(Debug, Clone, Copy)]
pub struct RrfConfig {
pub k: f32,
pub vector_weight: f32,
pub fts_weight: f32,
pub use_vector: bool,
pub use_fts: bool,
}
impl Default for RrfConfig {
fn default() -> Self {
Self {
k: 60.0,
vector_weight: 1.0,
fts_weight: 1.0,
use_vector: true,
use_fts: true,
}
}
}
pub const fn clamp_unit(value: f32) -> f32 {
value.clamp(0.0, 1.0)
}
@@ -196,3 +218,83 @@ where
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
}
pub fn reciprocal_rank_fusion<T>(
mut vector_ranked: Vec<Scored<T>>,
mut fts_ranked: Vec<Scored<T>>,
config: RrfConfig,
) -> Vec<Scored<T>>
where
T: StoredObject + Clone,
{
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
let k = if config.k <= 0.0 { 60.0 } else { config.k };
let vector_weight = if config.vector_weight.is_finite() {
config.vector_weight.max(0.0)
} else {
0.0
};
let fts_weight = if config.fts_weight.is_finite() {
config.fts_weight.max(0.0)
} else {
0.0
};
if config.use_vector && !vector_ranked.is_empty() {
vector_ranked.sort_by(|a, b| {
let a_score = a.scores.vector.unwrap_or(0.0);
let b_score = b.scores.vector.unwrap_or(0.0);
b_score
.partial_cmp(&a_score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
let id = candidate.item.get_id().to_owned();
let entry = merged
.entry(id.clone())
.or_insert_with(|| Scored::new(candidate.item.clone()));
if let Some(score) = candidate.scores.vector {
let existing = entry.scores.vector.unwrap_or(f32::MIN);
if score > existing {
entry.scores.vector = Some(score);
}
}
entry.item = candidate.item;
entry.fused += vector_weight / (k + rank as f32 + 1.0);
}
}
if config.use_fts && !fts_ranked.is_empty() {
fts_ranked.sort_by(|a, b| {
let a_score = a.scores.fts.unwrap_or(0.0);
let b_score = b.scores.fts.unwrap_or(0.0);
b_score
.partial_cmp(&a_score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
let id = candidate.item.get_id().to_owned();
let entry = merged
.entry(id.clone())
.or_insert_with(|| Scored::new(candidate.item.clone()));
if let Some(score) = candidate.scores.fts {
let existing = entry.scores.fts.unwrap_or(f32::MIN);
if score > existing {
entry.scores.fts = Some(score);
}
}
entry.item = candidate.item;
entry.fused += fts_weight / (k + rank as f32 + 1.0);
}
}
let mut fused: Vec<Scored<T>> = merged.into_values().collect();
sort_by_fused_desc(&mut fused);
fused
}