mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-26 10:48:37 +02:00
beir-rff
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::{cmp::Ordering, collections::HashMap};
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -71,6 +71,28 @@ impl Default for FusionWeights {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for reciprocal rank fusion.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RrfConfig {
|
||||
pub k: f32,
|
||||
pub vector_weight: f32,
|
||||
pub fts_weight: f32,
|
||||
pub use_vector: bool,
|
||||
pub use_fts: bool,
|
||||
}
|
||||
|
||||
impl Default for RrfConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
k: 60.0,
|
||||
vector_weight: 1.0,
|
||||
fts_weight: 1.0,
|
||||
use_vector: true,
|
||||
use_fts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn clamp_unit(value: f32) -> f32 {
|
||||
value.clamp(0.0, 1.0)
|
||||
}
|
||||
@@ -196,3 +218,83 @@ where
|
||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||
});
|
||||
}
|
||||
|
||||
pub fn reciprocal_rank_fusion<T>(
|
||||
mut vector_ranked: Vec<Scored<T>>,
|
||||
mut fts_ranked: Vec<Scored<T>>,
|
||||
config: RrfConfig,
|
||||
) -> Vec<Scored<T>>
|
||||
where
|
||||
T: StoredObject + Clone,
|
||||
{
|
||||
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
||||
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
||||
let vector_weight = if config.vector_weight.is_finite() {
|
||||
config.vector_weight.max(0.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let fts_weight = if config.fts_weight.is_finite() {
|
||||
config.fts_weight.max(0.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if config.use_vector && !vector_ranked.is_empty() {
|
||||
vector_ranked.sort_by(|a, b| {
|
||||
let a_score = a.scores.vector.unwrap_or(0.0);
|
||||
let b_score = b.scores.vector.unwrap_or(0.0);
|
||||
b_score
|
||||
.partial_cmp(&a_score)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||
});
|
||||
|
||||
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
||||
let id = candidate.item.get_id().to_owned();
|
||||
let entry = merged
|
||||
.entry(id.clone())
|
||||
.or_insert_with(|| Scored::new(candidate.item.clone()));
|
||||
|
||||
if let Some(score) = candidate.scores.vector {
|
||||
let existing = entry.scores.vector.unwrap_or(f32::MIN);
|
||||
if score > existing {
|
||||
entry.scores.vector = Some(score);
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
entry.fused += vector_weight / (k + rank as f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
if config.use_fts && !fts_ranked.is_empty() {
|
||||
fts_ranked.sort_by(|a, b| {
|
||||
let a_score = a.scores.fts.unwrap_or(0.0);
|
||||
let b_score = b.scores.fts.unwrap_or(0.0);
|
||||
b_score
|
||||
.partial_cmp(&a_score)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||
});
|
||||
|
||||
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
||||
let id = candidate.item.get_id().to_owned();
|
||||
let entry = merged
|
||||
.entry(id.clone())
|
||||
.or_insert_with(|| Scored::new(candidate.item.clone()));
|
||||
|
||||
if let Some(score) = candidate.scores.fts {
|
||||
let existing = entry.scores.fts.unwrap_or(f32::MIN);
|
||||
if score > existing {
|
||||
entry.scores.fts = Some(score);
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
entry.fused += fts_weight / (k + rank as f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
let mut fused: Vec<Scored<T>> = merged.into_values().collect();
|
||||
sort_by_fused_desc(&mut fused);
|
||||
fused
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user