mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-23 17:28:34 +02:00
beir-rff
This commit is contained in:
@@ -20,7 +20,7 @@ datasets:
|
|||||||
category: "Natural Questions"
|
category: "Natural Questions"
|
||||||
entity_suffix: "Natural Questions"
|
entity_suffix: "Natural Questions"
|
||||||
source_prefix: "nq"
|
source_prefix: "nq"
|
||||||
raw: "data/raw/nq/dev-all.jsonl"
|
raw: "data/raw/nq-dev/dev-all.jsonl"
|
||||||
converted: "data/converted/nq-dev-minne.json"
|
converted: "data/converted/nq-dev-minne.json"
|
||||||
include_unanswerable: true
|
include_unanswerable: true
|
||||||
slices:
|
slices:
|
||||||
@@ -42,7 +42,7 @@ datasets:
|
|||||||
slices:
|
slices:
|
||||||
- id: beir-mix-600
|
- id: beir-mix-600
|
||||||
label: "BEIR mix (600)"
|
label: "BEIR mix (600)"
|
||||||
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID"
|
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
|
||||||
limit: 600
|
limit: 600
|
||||||
corpus_limit: 6000
|
corpus_limit: 6000
|
||||||
seed: 0x5eed2025
|
seed: 0x5eed2025
|
||||||
@@ -136,3 +136,33 @@ datasets:
|
|||||||
limit: 200
|
limit: 200
|
||||||
corpus_limit: 5000
|
corpus_limit: 5000
|
||||||
seed: 0x5eed2025
|
seed: 0x5eed2025
|
||||||
|
- id: scifact
|
||||||
|
label: "SciFact (BEIR)"
|
||||||
|
category: "SciFact"
|
||||||
|
entity_suffix: "SciFact"
|
||||||
|
source_prefix: "scifact"
|
||||||
|
raw: "data/raw/scifact"
|
||||||
|
converted: "data/converted/scifact-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: scifact-test-200
|
||||||
|
label: "SciFact test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 3000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: nq-beir
|
||||||
|
label: "Natural Questions (BEIR)"
|
||||||
|
category: "Natural Questions"
|
||||||
|
entity_suffix: "Natural Questions"
|
||||||
|
source_prefix: "nq-beir"
|
||||||
|
raw: "data/raw/nq"
|
||||||
|
converted: "data/converted/nq-beir-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: nq-beir-test-200
|
||||||
|
label: "NQ (BEIR) test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
|||||||
@@ -84,6 +84,26 @@ pub struct RetrievalSettings {
|
|||||||
#[arg(long, default_value_t = 5)]
|
#[arg(long, default_value_t = 5)]
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
|
||||||
|
/// Reciprocal rank fusion k value for revised chunk merging
|
||||||
|
#[arg(long)]
|
||||||
|
pub chunk_rrf_k: Option<f32>,
|
||||||
|
|
||||||
|
/// Weight for vector ranks in revised RRF
|
||||||
|
#[arg(long)]
|
||||||
|
pub chunk_rrf_vector_weight: Option<f32>,
|
||||||
|
|
||||||
|
/// Weight for chunk FTS ranks in revised RRF
|
||||||
|
#[arg(long)]
|
||||||
|
pub chunk_rrf_fts_weight: Option<f32>,
|
||||||
|
|
||||||
|
/// Include vector ranks in revised RRF (default: true)
|
||||||
|
#[arg(long)]
|
||||||
|
pub chunk_rrf_use_vector: Option<bool>,
|
||||||
|
|
||||||
|
/// Include chunk FTS ranks in revised RRF (default: true)
|
||||||
|
#[arg(long)]
|
||||||
|
pub chunk_rrf_use_fts: Option<bool>,
|
||||||
|
|
||||||
/// Require verified chunks (disable with --llm-mode)
|
/// Require verified chunks (disable with --llm-mode)
|
||||||
#[arg(skip = true)]
|
#[arg(skip = true)]
|
||||||
pub require_verified_chunks: bool,
|
pub require_verified_chunks: bool,
|
||||||
@@ -104,6 +124,11 @@ impl Default for RetrievalSettings {
|
|||||||
rerank_pool_size: 4,
|
rerank_pool_size: 4,
|
||||||
rerank_keep_top: 10,
|
rerank_keep_top: 10,
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
|
chunk_rrf_k: None,
|
||||||
|
chunk_rrf_vector_weight: None,
|
||||||
|
chunk_rrf_fts_weight: None,
|
||||||
|
chunk_rrf_use_vector: None,
|
||||||
|
chunk_rrf_use_fts: None,
|
||||||
require_verified_chunks: true,
|
require_verified_chunks: true,
|
||||||
strategy: RetrievalStrategy::Initial,
|
strategy: RetrievalStrategy::Initial,
|
||||||
}
|
}
|
||||||
@@ -376,6 +401,28 @@ impl Config {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(k) = self.retrieval.chunk_rrf_k {
|
||||||
|
if k <= 0.0 || !k.is_finite() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"--chunk-rrf-k must be a positive, finite number (got {k})"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(weight) = self.retrieval.chunk_rrf_vector_weight {
|
||||||
|
if weight < 0.0 || !weight.is_finite() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(weight) = self.retrieval.chunk_rrf_fts_weight {
|
||||||
|
if weight < 0.0 || !weight.is_finite() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if self.concurrency == 0 {
|
if self.concurrency == 0 {
|
||||||
return Err(anyhow!("--concurrency must be greater than zero"));
|
return Err(anyhow!("--concurrency must be greater than zero"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -262,6 +262,10 @@ pub enum DatasetKind {
|
|||||||
Quora,
|
Quora,
|
||||||
#[value(name = "trec-covid", alias = "treccovid", alias = "trec_covid")]
|
#[value(name = "trec-covid", alias = "treccovid", alias = "trec_covid")]
|
||||||
TrecCovid,
|
TrecCovid,
|
||||||
|
#[value(name = "scifact")]
|
||||||
|
Scifact,
|
||||||
|
#[value(name = "nq-beir", alias = "natural-questions-beir")]
|
||||||
|
NqBeir,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatasetKind {
|
impl DatasetKind {
|
||||||
@@ -276,6 +280,8 @@ impl DatasetKind {
|
|||||||
Self::Nfcorpus => "nfcorpus",
|
Self::Nfcorpus => "nfcorpus",
|
||||||
Self::Quora => "quora",
|
Self::Quora => "quora",
|
||||||
Self::TrecCovid => "trec-covid",
|
Self::TrecCovid => "trec-covid",
|
||||||
|
Self::Scifact => "scifact",
|
||||||
|
Self::NqBeir => "nq-beir",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,6 +296,8 @@ impl DatasetKind {
|
|||||||
Self::Nfcorpus => "NFCorpus (BEIR)",
|
Self::Nfcorpus => "NFCorpus (BEIR)",
|
||||||
Self::Quora => "Quora (IR)",
|
Self::Quora => "Quora (IR)",
|
||||||
Self::TrecCovid => "TREC-COVID (BEIR)",
|
Self::TrecCovid => "TREC-COVID (BEIR)",
|
||||||
|
Self::Scifact => "SciFact (BEIR)",
|
||||||
|
Self::NqBeir => "Natural Questions (BEIR)",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,6 +312,8 @@ impl DatasetKind {
|
|||||||
Self::Nfcorpus => "NFCorpus",
|
Self::Nfcorpus => "NFCorpus",
|
||||||
Self::Quora => "Quora",
|
Self::Quora => "Quora",
|
||||||
Self::TrecCovid => "TREC-COVID",
|
Self::TrecCovid => "TREC-COVID",
|
||||||
|
Self::Scifact => "SciFact",
|
||||||
|
Self::NqBeir => "Natural Questions",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,6 +328,8 @@ impl DatasetKind {
|
|||||||
Self::Nfcorpus => "NFCorpus",
|
Self::Nfcorpus => "NFCorpus",
|
||||||
Self::Quora => "Quora",
|
Self::Quora => "Quora",
|
||||||
Self::TrecCovid => "TREC-COVID",
|
Self::TrecCovid => "TREC-COVID",
|
||||||
|
Self::Scifact => "SciFact",
|
||||||
|
Self::NqBeir => "Natural Questions",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -332,6 +344,8 @@ impl DatasetKind {
|
|||||||
Self::Nfcorpus => "nfcorpus",
|
Self::Nfcorpus => "nfcorpus",
|
||||||
Self::Quora => "quora",
|
Self::Quora => "quora",
|
||||||
Self::TrecCovid => "trec-covid",
|
Self::TrecCovid => "trec-covid",
|
||||||
|
Self::Scifact => "scifact",
|
||||||
|
Self::NqBeir => "nq-beir",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -376,20 +390,24 @@ impl FromStr for DatasetKind {
|
|||||||
"nfcorpus" | "nf-corpus" => Ok(Self::Nfcorpus),
|
"nfcorpus" | "nf-corpus" => Ok(Self::Nfcorpus),
|
||||||
"quora" => Ok(Self::Quora),
|
"quora" => Ok(Self::Quora),
|
||||||
"trec-covid" | "treccovid" | "trec_covid" => Ok(Self::TrecCovid),
|
"trec-covid" | "treccovid" | "trec_covid" => Ok(Self::TrecCovid),
|
||||||
|
"scifact" => Ok(Self::Scifact),
|
||||||
|
"nq-beir" | "natural-questions-beir" => Ok(Self::NqBeir),
|
||||||
other => {
|
other => {
|
||||||
anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid.")
|
anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid, scifact, nq-beir.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const BEIR_DATASETS: [DatasetKind; 6] = [
|
pub const BEIR_DATASETS: [DatasetKind; 8] = [
|
||||||
DatasetKind::Fever,
|
DatasetKind::Fever,
|
||||||
DatasetKind::Fiqa,
|
DatasetKind::Fiqa,
|
||||||
DatasetKind::HotpotQa,
|
DatasetKind::HotpotQa,
|
||||||
DatasetKind::Nfcorpus,
|
DatasetKind::Nfcorpus,
|
||||||
DatasetKind::Quora,
|
DatasetKind::Quora,
|
||||||
DatasetKind::TrecCovid,
|
DatasetKind::TrecCovid,
|
||||||
|
DatasetKind::Scifact,
|
||||||
|
DatasetKind::NqBeir,
|
||||||
];
|
];
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -481,7 +499,9 @@ pub fn convert(
|
|||||||
| DatasetKind::HotpotQa
|
| DatasetKind::HotpotQa
|
||||||
| DatasetKind::Nfcorpus
|
| DatasetKind::Nfcorpus
|
||||||
| DatasetKind::Quora
|
| DatasetKind::Quora
|
||||||
| DatasetKind::TrecCovid => beir::convert_beir(raw_path, dataset)?,
|
| DatasetKind::TrecCovid
|
||||||
|
| DatasetKind::Scifact
|
||||||
|
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let metadata_limit = match dataset {
|
let metadata_limit = match dataset {
|
||||||
@@ -489,13 +509,26 @@ pub fn convert(
|
|||||||
_ => context_token_limit,
|
_ => context_token_limit,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let generated_at = match dataset {
|
||||||
|
DatasetKind::Beir
|
||||||
|
| DatasetKind::Fever
|
||||||
|
| DatasetKind::Fiqa
|
||||||
|
| DatasetKind::HotpotQa
|
||||||
|
| DatasetKind::Nfcorpus
|
||||||
|
| DatasetKind::Quora
|
||||||
|
| DatasetKind::TrecCovid
|
||||||
|
| DatasetKind::Scifact
|
||||||
|
| DatasetKind::NqBeir => base_timestamp(),
|
||||||
|
_ => Utc::now(),
|
||||||
|
};
|
||||||
|
|
||||||
let source_label = match dataset {
|
let source_label = match dataset {
|
||||||
DatasetKind::Beir => "beir-mix".to_string(),
|
DatasetKind::Beir => "beir-mix".to_string(),
|
||||||
_ => raw_path.display().to_string(),
|
_ => raw_path.display().to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(ConvertedDataset {
|
Ok(ConvertedDataset {
|
||||||
generated_at: Utc::now(),
|
generated_at,
|
||||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
||||||
source: source_label,
|
source: source_label,
|
||||||
paragraphs,
|
paragraphs,
|
||||||
|
|||||||
@@ -63,6 +63,21 @@ pub(crate) async fn run_queries(
|
|||||||
if let Some(value) = config.retrieval.chunk_fts_take {
|
if let Some(value) = config.retrieval.chunk_fts_take {
|
||||||
retrieval_config.tuning.chunk_fts_take = value;
|
retrieval_config.tuning.chunk_fts_take = value;
|
||||||
}
|
}
|
||||||
|
if let Some(value) = config.retrieval.chunk_rrf_k {
|
||||||
|
retrieval_config.tuning.chunk_rrf_k = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = config.retrieval.chunk_rrf_vector_weight {
|
||||||
|
retrieval_config.tuning.chunk_rrf_vector_weight = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = config.retrieval.chunk_rrf_fts_weight {
|
||||||
|
retrieval_config.tuning.chunk_rrf_fts_weight = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = config.retrieval.chunk_rrf_use_vector {
|
||||||
|
retrieval_config.tuning.chunk_rrf_use_vector = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
|
||||||
|
retrieval_config.tuning.chunk_rrf_use_fts = value;
|
||||||
|
}
|
||||||
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
|
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
|
||||||
retrieval_config.tuning.avg_chars_per_token = value;
|
retrieval_config.tuning.avg_chars_per_token = value;
|
||||||
}
|
}
|
||||||
@@ -93,6 +108,11 @@ pub(crate) async fn run_queries(
|
|||||||
rerank_keep_top = config.retrieval.rerank_keep_top,
|
rerank_keep_top = config.retrieval.rerank_keep_top,
|
||||||
chunk_vector_take = effective_chunk_vector,
|
chunk_vector_take = effective_chunk_vector,
|
||||||
chunk_fts_take = effective_chunk_fts,
|
chunk_fts_take = effective_chunk_fts,
|
||||||
|
chunk_rrf_k = active_tuning.chunk_rrf_k,
|
||||||
|
chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight,
|
||||||
|
chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight,
|
||||||
|
chunk_rrf_use_vector = active_tuning.chunk_rrf_use_vector,
|
||||||
|
chunk_rrf_use_fts = active_tuning.chunk_rrf_use_fts,
|
||||||
embedding_backend = ctx.embedding_provider().backend_label(),
|
embedding_backend = ctx.embedding_provider().backend_label(),
|
||||||
embedding_model = ctx
|
embedding_model = ctx
|
||||||
.embedding_provider()
|
.embedding_provider()
|
||||||
|
|||||||
@@ -202,6 +202,11 @@ pub(crate) async fn summarize(
|
|||||||
detailed_report: config.detailed_report,
|
detailed_report: config.detailed_report,
|
||||||
retrieval_strategy: config.retrieval.strategy.to_string(),
|
retrieval_strategy: config.retrieval.strategy.to_string(),
|
||||||
chunk_result_cap: config.retrieval.chunk_result_cap,
|
chunk_result_cap: config.retrieval.chunk_result_cap,
|
||||||
|
chunk_rrf_k: active_tuning.chunk_rrf_k,
|
||||||
|
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
|
||||||
|
chunk_rrf_fts_weight: active_tuning.chunk_rrf_fts_weight,
|
||||||
|
chunk_rrf_use_vector: active_tuning.chunk_rrf_use_vector,
|
||||||
|
chunk_rrf_use_fts: active_tuning.chunk_rrf_use_fts,
|
||||||
ingest_chunk_min_tokens: config.ingest_chunk_min_tokens,
|
ingest_chunk_min_tokens: config.ingest_chunk_min_tokens,
|
||||||
ingest_chunk_max_tokens: config.ingest_chunk_max_tokens,
|
ingest_chunk_max_tokens: config.ingest_chunk_max_tokens,
|
||||||
ingest_chunks_only: config.ingest_chunks_only,
|
ingest_chunks_only: config.ingest_chunks_only,
|
||||||
|
|||||||
@@ -70,6 +70,11 @@ pub struct EvaluationSummary {
|
|||||||
pub detailed_report: bool,
|
pub detailed_report: bool,
|
||||||
pub retrieval_strategy: String,
|
pub retrieval_strategy: String,
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
pub chunk_rrf_k: f32,
|
||||||
|
pub chunk_rrf_vector_weight: f32,
|
||||||
|
pub chunk_rrf_fts_weight: f32,
|
||||||
|
pub chunk_rrf_use_vector: bool,
|
||||||
|
pub chunk_rrf_use_fts: bool,
|
||||||
pub ingest_chunk_min_tokens: usize,
|
pub ingest_chunk_min_tokens: usize,
|
||||||
pub ingest_chunk_max_tokens: usize,
|
pub ingest_chunk_max_tokens: usize,
|
||||||
pub ingest_chunks_only: bool,
|
pub ingest_chunks_only: bool,
|
||||||
|
|||||||
@@ -373,6 +373,20 @@ pub async fn ensure_corpus(
|
|||||||
let reused_ingestion = ingested_count == 0 && !cache.force_refresh;
|
let reused_ingestion = ingested_count == 0 && !cache.force_refresh;
|
||||||
let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only;
|
let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
dataset = %dataset.metadata.id,
|
||||||
|
slice = %slice.manifest.slice_id,
|
||||||
|
fingerprint = %ingestion_fingerprint,
|
||||||
|
reused_ingestion,
|
||||||
|
reused_embeddings,
|
||||||
|
positive_reused = stats.positive_reused,
|
||||||
|
positive_ingested = stats.positive_ingested,
|
||||||
|
negative_reused = stats.negative_reused,
|
||||||
|
negative_ingested = stats.negative_ingested,
|
||||||
|
shard_dir = %base_dir.display(),
|
||||||
|
"Corpus cache outcome"
|
||||||
|
);
|
||||||
|
|
||||||
let handle = CorpusHandle {
|
let handle = CorpusHandle {
|
||||||
manifest,
|
manifest,
|
||||||
path: base_dir,
|
path: base_dir,
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ use common::storage::{
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use surrealdb::sql::Thing;
|
use surrealdb::sql::Thing;
|
||||||
use tracing::warn;
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
||||||
|
|
||||||
@@ -440,6 +440,12 @@ impl ParagraphShardStore {
|
|||||||
.with_context(|| format!("parsing shard {}", path.display()))?;
|
.with_context(|| format!("parsing shard {}", path.display()))?;
|
||||||
|
|
||||||
if shard.ingestion_fingerprint != fingerprint {
|
if shard.ingestion_fingerprint != fingerprint {
|
||||||
|
debug!(
|
||||||
|
path = %path.display(),
|
||||||
|
expected = fingerprint,
|
||||||
|
found = shard.ingestion_fingerprint,
|
||||||
|
"Shard fingerprint mismatch; will rebuild"
|
||||||
|
);
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
if shard.version != PARAGRAPH_SHARD_VERSION {
|
if shard.version != PARAGRAPH_SHARD_VERSION {
|
||||||
|
|||||||
@@ -197,6 +197,11 @@ mod tests {
|
|||||||
detailed_report: false,
|
detailed_report: false,
|
||||||
retrieval_strategy: "initial".into(),
|
retrieval_strategy: "initial".into(),
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
|
chunk_rrf_k: 60.0,
|
||||||
|
chunk_rrf_vector_weight: 1.0,
|
||||||
|
chunk_rrf_fts_weight: 1.0,
|
||||||
|
chunk_rrf_use_vector: true,
|
||||||
|
chunk_rrf_use_fts: true,
|
||||||
ingest_chunk_min_tokens: 256,
|
ingest_chunk_min_tokens: 256,
|
||||||
ingest_chunk_max_tokens: 512,
|
ingest_chunk_max_tokens: 512,
|
||||||
ingest_chunks_only: false,
|
ingest_chunks_only: false,
|
||||||
|
|||||||
@@ -88,6 +88,16 @@ pub struct RetrievalSection {
|
|||||||
pub rerank_pool_size: Option<usize>,
|
pub rerank_pool_size: Option<usize>,
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
#[serde(default = "default_chunk_rrf_k")]
|
||||||
|
pub chunk_rrf_k: f32,
|
||||||
|
#[serde(default = "default_chunk_rrf_weight")]
|
||||||
|
pub chunk_rrf_vector_weight: f32,
|
||||||
|
#[serde(default = "default_chunk_rrf_weight")]
|
||||||
|
pub chunk_rrf_fts_weight: f32,
|
||||||
|
#[serde(default = "default_chunk_rrf_use")]
|
||||||
|
pub chunk_rrf_use_vector: bool,
|
||||||
|
#[serde(default = "default_chunk_rrf_use")]
|
||||||
|
pub chunk_rrf_use_fts: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub chunk_vector_take: usize,
|
pub chunk_vector_take: usize,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -98,6 +108,18 @@ pub struct RetrievalSection {
|
|||||||
pub ingest_chunks_only: bool,
|
pub ingest_chunks_only: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_k() -> f32 {
|
||||||
|
60.0
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_weight() -> f32 {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_use() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct LlmSection {
|
pub struct LlmSection {
|
||||||
pub cases: usize,
|
pub cases: usize,
|
||||||
@@ -206,6 +228,11 @@ impl EvaluationReport {
|
|||||||
rerank_pool_size: summary.rerank_pool_size,
|
rerank_pool_size: summary.rerank_pool_size,
|
||||||
rerank_keep_top: summary.rerank_keep_top,
|
rerank_keep_top: summary.rerank_keep_top,
|
||||||
chunk_result_cap: summary.chunk_result_cap,
|
chunk_result_cap: summary.chunk_result_cap,
|
||||||
|
chunk_rrf_k: summary.chunk_rrf_k,
|
||||||
|
chunk_rrf_vector_weight: summary.chunk_rrf_vector_weight,
|
||||||
|
chunk_rrf_fts_weight: summary.chunk_rrf_fts_weight,
|
||||||
|
chunk_rrf_use_vector: summary.chunk_rrf_use_vector,
|
||||||
|
chunk_rrf_use_fts: summary.chunk_rrf_use_fts,
|
||||||
chunk_vector_take: summary.chunk_vector_take,
|
chunk_vector_take: summary.chunk_vector_take,
|
||||||
chunk_fts_take: summary.chunk_fts_take,
|
chunk_fts_take: summary.chunk_fts_take,
|
||||||
ingest_chunk_min_tokens: summary.ingest_chunk_min_tokens,
|
ingest_chunk_min_tokens: summary.ingest_chunk_min_tokens,
|
||||||
@@ -856,6 +883,11 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
|||||||
rerank_pool_size: entry.rerank_pool_size,
|
rerank_pool_size: entry.rerank_pool_size,
|
||||||
rerank_keep_top: entry.rerank_keep_top,
|
rerank_keep_top: entry.rerank_keep_top,
|
||||||
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
||||||
|
chunk_rrf_k: default_chunk_rrf_k(),
|
||||||
|
chunk_rrf_vector_weight: default_chunk_rrf_weight(),
|
||||||
|
chunk_rrf_fts_weight: default_chunk_rrf_weight(),
|
||||||
|
chunk_rrf_use_vector: default_chunk_rrf_use(),
|
||||||
|
chunk_rrf_use_fts: default_chunk_rrf_use(),
|
||||||
chunk_vector_take: 0,
|
chunk_vector_take: 0,
|
||||||
chunk_fts_take: 0,
|
chunk_fts_take: 0,
|
||||||
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
||||||
@@ -1098,6 +1130,11 @@ mod tests {
|
|||||||
detailed_report: true,
|
detailed_report: true,
|
||||||
retrieval_strategy: "initial".into(),
|
retrieval_strategy: "initial".into(),
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
|
chunk_rrf_k: 60.0,
|
||||||
|
chunk_rrf_vector_weight: 1.0,
|
||||||
|
chunk_rrf_fts_weight: 1.0,
|
||||||
|
chunk_rrf_use_vector: true,
|
||||||
|
chunk_rrf_use_fts: true,
|
||||||
ingest_chunk_min_tokens: 256,
|
ingest_chunk_min_tokens: 256,
|
||||||
ingest_chunk_max_tokens: 512,
|
ingest_chunk_max_tokens: 512,
|
||||||
ingest_chunk_overlap_tokens: 50,
|
ingest_chunk_overlap_tokens: 50,
|
||||||
|
|||||||
@@ -72,6 +72,21 @@ pub struct RetrievalTuning {
|
|||||||
pub normalize_vector_scores: bool,
|
pub normalize_vector_scores: bool,
|
||||||
/// Normalize FTS (BM25) scores before fusion (default: true)
|
/// Normalize FTS (BM25) scores before fusion (default: true)
|
||||||
pub normalize_fts_scores: bool,
|
pub normalize_fts_scores: bool,
|
||||||
|
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
||||||
|
#[serde(default = "default_chunk_rrf_k")]
|
||||||
|
pub chunk_rrf_k: f32,
|
||||||
|
/// Weight applied to vector ranks in RRF.
|
||||||
|
#[serde(default = "default_chunk_rrf_vector_weight")]
|
||||||
|
pub chunk_rrf_vector_weight: f32,
|
||||||
|
/// Weight applied to chunk FTS ranks in RRF.
|
||||||
|
#[serde(default = "default_chunk_rrf_fts_weight")]
|
||||||
|
pub chunk_rrf_fts_weight: f32,
|
||||||
|
/// Whether to include vector rankings in RRF.
|
||||||
|
#[serde(default = "default_chunk_rrf_use_vector")]
|
||||||
|
pub chunk_rrf_use_vector: bool,
|
||||||
|
/// Whether to include chunk FTS rankings in RRF.
|
||||||
|
#[serde(default = "default_chunk_rrf_use_fts")]
|
||||||
|
pub chunk_rrf_use_fts: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalTuning {
|
impl Default for RetrievalTuning {
|
||||||
@@ -102,6 +117,11 @@ impl Default for RetrievalTuning {
|
|||||||
normalize_vector_scores: false,
|
normalize_vector_scores: false,
|
||||||
// FTS scores (BM25) are unbounded, normalization helps more
|
// FTS scores (BM25) are unbounded, normalization helps more
|
||||||
normalize_fts_scores: true,
|
normalize_fts_scores: true,
|
||||||
|
chunk_rrf_k: default_chunk_rrf_k(),
|
||||||
|
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
||||||
|
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
||||||
|
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
|
||||||
|
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -156,3 +176,23 @@ impl Default for RetrievalConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_k() -> f32 {
|
||||||
|
60.0
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_vector_weight() -> f32 {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_fts_weight() -> f32 {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_use_vector() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn default_chunk_rrf_use_fts() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ use crate::{
|
|||||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||||
reranking::RerankerLease,
|
reranking::RerankerLease,
|
||||||
scoring::{
|
scoring::{
|
||||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion,
|
||||||
FusionWeights, Scored,
|
sort_by_fused_desc, FusionWeights, RrfConfig, Scored,
|
||||||
},
|
},
|
||||||
RetrievedChunk, RetrievedEntity,
|
RetrievedChunk, RetrievedEntity,
|
||||||
};
|
};
|
||||||
@@ -593,8 +593,9 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
|||||||
debug!("Collecting vector chunk candidates for revised strategy");
|
debug!("Collecting vector chunk candidates for revised strategy");
|
||||||
let embedding = ctx.ensure_embedding()?.clone();
|
let embedding = ctx.ensure_embedding()?.clone();
|
||||||
let tuning = &ctx.config.tuning;
|
let tuning = &ctx.config.tuning;
|
||||||
let weights = tuning.fusion_weights.unwrap_or_else(FusionWeights::default);
|
|
||||||
let fts_take = tuning.chunk_fts_take;
|
let fts_take = tuning.chunk_fts_take;
|
||||||
|
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||||
|
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty();
|
||||||
|
|
||||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||||
TextChunk::vector_search(
|
TextChunk::vector_search(
|
||||||
@@ -604,35 +605,42 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
|||||||
&ctx.user_id,
|
&ctx.user_id,
|
||||||
),
|
),
|
||||||
async {
|
async {
|
||||||
if fts_take == 0 {
|
if fts_enabled {
|
||||||
Ok(Vec::new())
|
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
|
||||||
} else {
|
} else {
|
||||||
TextChunk::fts_search(fts_take, &ctx.input_text, ctx.db_client, &ctx.user_id).await
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut merged: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
|
||||||
let vector_candidates = vector_rows.len();
|
let vector_candidates = vector_rows.len();
|
||||||
let fts_candidates = fts_rows.len();
|
let fts_candidates = fts_rows.len();
|
||||||
|
|
||||||
// Collect vector results
|
|
||||||
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Collect FTS results
|
|
||||||
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Merge by ID first (before normalization)
|
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
||||||
merge_scored_by_id(&mut merged, vector_scored);
|
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
|
||||||
merge_scored_by_id(&mut merged, fts_scored);
|
// For very short keyword queries, lean more on lexical ranking.
|
||||||
|
fts_weight *= 1.5;
|
||||||
|
}
|
||||||
|
|
||||||
let mut vector_chunks: Vec<Scored<TextChunk>> = merged.into_values().collect();
|
let rrf_config = RrfConfig {
|
||||||
|
k: tuning.chunk_rrf_k,
|
||||||
|
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||||
|
fts_weight,
|
||||||
|
use_vector: tuning.chunk_rrf_use_vector,
|
||||||
|
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
total_merged = vector_chunks.len(),
|
total_merged = vector_chunks.len(),
|
||||||
@@ -648,58 +656,24 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
|||||||
.iter()
|
.iter()
|
||||||
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||||
.count(),
|
.count(),
|
||||||
"Merged chunk candidates before normalization"
|
rrf_k = rrf_config.k,
|
||||||
|
rrf_vector_weight = rrf_config.vector_weight,
|
||||||
|
rrf_fts_weight = rrf_config.fts_weight,
|
||||||
|
"Merged chunk candidates with RRF"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Normalize scores AFTER merging, on the final merged set
|
// let fts_only_count = vector_chunks
|
||||||
// This ensures we normalize all vector scores together and all FTS scores together
|
// .iter()
|
||||||
// for the actual candidates that will be fused
|
// .filter(|c| c.scores.vector.is_none())
|
||||||
if tuning.normalize_vector_scores && !vector_chunks.is_empty() {
|
// .count();
|
||||||
let before_sample: Vec<f32> = vector_chunks
|
// let both_count = vector_chunks
|
||||||
.iter()
|
// .iter()
|
||||||
.filter_map(|c| c.scores.vector)
|
// .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||||
.take(5)
|
// .count();
|
||||||
.collect();
|
|
||||||
normalize_vector_scores(&mut vector_chunks);
|
|
||||||
let after_sample: Vec<f32> = vector_chunks
|
|
||||||
.iter()
|
|
||||||
.filter_map(|c| c.scores.vector)
|
|
||||||
.take(5)
|
|
||||||
.collect();
|
|
||||||
debug!(
|
|
||||||
vector_before = ?before_sample,
|
|
||||||
vector_after = ?after_sample,
|
|
||||||
"Vector score normalization applied"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if tuning.normalize_fts_scores && !vector_chunks.is_empty() {
|
|
||||||
let before_sample: Vec<f32> = vector_chunks
|
|
||||||
.iter()
|
|
||||||
.filter_map(|c| c.scores.fts)
|
|
||||||
.take(5)
|
|
||||||
.collect();
|
|
||||||
normalize_fts_scores_in_merged(&mut vector_chunks);
|
|
||||||
let after_sample: Vec<f32> = vector_chunks
|
|
||||||
.iter()
|
|
||||||
.filter_map(|c| c.scores.fts)
|
|
||||||
.take(5)
|
|
||||||
.collect();
|
|
||||||
debug!(
|
|
||||||
fts_before = ?before_sample,
|
|
||||||
fts_after = ?after_sample,
|
|
||||||
"FTS score normalization applied"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fuse scores after normalization
|
// If we have very low overlap (few chunks with both signals), drop FTS-only chunks.
|
||||||
for scored in &mut vector_chunks {
|
// These are often noisy on keyword-heavy datasets and dilute strong vector hits.
|
||||||
let fused = fuse_scores(&scored.scores, weights);
|
// Keep vector-only and “golden” (vector+FTS) chunks.
|
||||||
scored.update_fused(fused);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out FTS-only chunks if they're likely to be low quality
|
|
||||||
// (when overlap is low, FTS-only chunks are usually noise)
|
|
||||||
// Always keep chunks with vector scores (vector-only or both signals)
|
|
||||||
let fts_only_count = vector_chunks
|
let fts_only_count = vector_chunks
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|c| c.scores.vector.is_none())
|
.filter(|c| c.scores.vector.is_none())
|
||||||
@@ -708,10 +682,6 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
|||||||
.iter()
|
.iter()
|
||||||
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
// If we have very low overlap (few chunks with both signals), filter out FTS-only chunks
|
|
||||||
// They're likely diluting the good vector results
|
|
||||||
// This preserves vector-only chunks and golden chunks (both signals)
|
|
||||||
if fts_only_count > 0 && both_count < 3 {
|
if fts_only_count > 0 && both_count < 3 {
|
||||||
let before_filter = vector_chunks.len();
|
let before_filter = vector_chunks.len();
|
||||||
vector_chunks.retain(|c| c.scores.vector.is_some());
|
vector_chunks.retain(|c| c.scores.vector.is_some());
|
||||||
@@ -724,9 +694,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
|||||||
}
|
}
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
fusion_weights = ?weights,
|
|
||||||
top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::<Vec<_>>(),
|
top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::<Vec<_>>(),
|
||||||
"Fused scores after normalization"
|
"Fused scores after RRF ordering"
|
||||||
);
|
);
|
||||||
|
|
||||||
if ctx.diagnostics_enabled() {
|
if ctx.diagnostics_enabled() {
|
||||||
@@ -797,11 +766,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
|||||||
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
||||||
|
|
||||||
if chunk_values.len() > limit {
|
if chunk_values.len() > limit {
|
||||||
println!(
|
|
||||||
"We removed chunks! we had {:?}, now going for {:?}",
|
|
||||||
chunk_values.len(),
|
|
||||||
limit
|
|
||||||
);
|
|
||||||
chunk_values.truncate(limit);
|
chunk_values.truncate(limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -966,87 +930,24 @@ fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_vector_scores<T>(results: &mut [Scored<T>]) {
|
fn normalize_fts_query(input: &str) -> (String, usize) {
|
||||||
// Only normalize scores for items that actually have vector scores
|
const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];
|
||||||
let items_with_scores: Vec<(usize, f32)> = results
|
let mut cleaned = String::with_capacity(input.len());
|
||||||
.iter()
|
for ch in input.chars() {
|
||||||
.enumerate()
|
if ch.is_alphanumeric() {
|
||||||
.filter_map(|(idx, candidate)| candidate.scores.vector.map(|score| (idx, score)))
|
cleaned.extend(ch.to_lowercase());
|
||||||
.collect();
|
} else if ch.is_whitespace() {
|
||||||
|
cleaned.push(' ');
|
||||||
if items_with_scores.len() < 2 {
|
}
|
||||||
// Don't normalize if we have 0 or 1 scores - nothing to normalize against
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
let mut tokens = Vec::new();
|
||||||
let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
|
for token in cleaned.split_whitespace() {
|
||||||
|
if !STOPWORDS.contains(&token) && !token.is_empty() {
|
||||||
// For cosine similarity scores (already in [0,1]), use a gentler normalization
|
tokens.push(token.to_string());
|
||||||
// that preserves more of the original distribution
|
}
|
||||||
// Only normalize if the range is significant (more than 0.1 difference)
|
|
||||||
let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
|
|
||||||
let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
|
|
||||||
let range = max - min;
|
|
||||||
|
|
||||||
if range < 0.1 {
|
|
||||||
// Scores are too similar, don't normalize (would compress too much)
|
|
||||||
debug!(
|
|
||||||
vector_score_range = range,
|
|
||||||
min = min,
|
|
||||||
max = max,
|
|
||||||
"Skipping vector normalization - scores too similar"
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let normalized = min_max_normalize(&raw_scores);
|
|
||||||
|
|
||||||
for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
|
|
||||||
results[*idx].scores.vector = Some(normalized_score);
|
|
||||||
results[*idx].update_fused(0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_fts_scores_in_merged<T>(results: &mut [Scored<T>]) {
|
|
||||||
// Only normalize scores for items that actually have FTS scores
|
|
||||||
let items_with_scores: Vec<(usize, f32)> = results
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.filter_map(|(idx, candidate)| candidate.scores.fts.map(|score| (idx, score)))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if items_with_scores.len() < 2 {
|
|
||||||
// Don't normalize if we have 0 or 1 scores - nothing to normalize against
|
|
||||||
// Single FTS score would become 1.0, which doesn't help
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
|
|
||||||
|
|
||||||
// BM25 scores can be negative or very high, so normalization is more important
|
|
||||||
// But check if we have enough variation to normalize
|
|
||||||
let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
|
|
||||||
let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
|
|
||||||
let range = max - min;
|
|
||||||
|
|
||||||
// For BM25, even small differences can be meaningful, but if all scores are
|
|
||||||
// very similar, normalization won't help
|
|
||||||
if range < 0.01 {
|
|
||||||
debug!(
|
|
||||||
fts_score_range = range,
|
|
||||||
min = min,
|
|
||||||
max = max,
|
|
||||||
"Skipping FTS normalization - scores too similar"
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let normalized = min_max_normalize(&raw_scores);
|
|
||||||
|
|
||||||
for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
|
|
||||||
results[*idx].scores.fts = Some(normalized_score);
|
|
||||||
results[*idx].update_fused(0.0);
|
|
||||||
}
|
}
|
||||||
|
let normalized = tokens.join(" ");
|
||||||
|
(normalized, tokens.len())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use std::cmp::Ordering;
|
use std::{cmp::Ordering, collections::HashMap};
|
||||||
|
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::StoredObject;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -71,6 +71,28 @@ impl Default for FusionWeights {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Configuration for reciprocal rank fusion (RRF).
///
/// Consumed by `reciprocal_rank_fusion`, where every enabled ranking adds
/// `weight / (k + rank + 1)` to a candidate's fused score (`rank` is the
/// candidate's zero-based position in that ranking).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RrfConfig {
    /// Smoothing constant in the RRF denominator; larger values flatten the
    /// difference between neighbouring ranks. Non-positive values are replaced
    /// by `60.0` when fusing.
    pub k: f32,
    /// Multiplier applied to contributions from the vector ranking. Negative
    /// or non-finite values are treated as `0.0` when fusing.
    pub vector_weight: f32,
    /// Multiplier applied to contributions from the full-text ranking.
    /// Negative or non-finite values are treated as `0.0` when fusing.
    pub fts_weight: f32,
    /// Whether the vector ranking takes part in the fusion at all.
    pub use_vector: bool,
    /// Whether the full-text ranking takes part in the fusion at all.
    pub use_fts: bool,
}

impl Default for RrfConfig {
    /// Both rankings enabled with equal weight and `k = 60`.
    fn default() -> Self {
        Self {
            k: 60.0,
            vector_weight: 1.0,
            fts_weight: 1.0,
            use_vector: true,
            use_fts: true,
        }
    }
}
|
||||||
|
|
||||||
/// Clamps `value` to the unit interval `[0.0, 1.0]`.
///
/// Unlike `f32::clamp`, which returns `NaN` for a `NaN` input, this maps `NaN`
/// to `0.0`, so the result always lies inside the interval and can be ordered
/// safely.
pub const fn clamp_unit(value: f32) -> f32 {
    if value > 0.0 {
        if value > 1.0 {
            1.0
        } else {
            value
        }
    } else {
        // Reached for negatives, zero of either sign, and NaN: every
        // comparison involving NaN is false, so NaN falls through to here.
        0.0
    }
}
|
||||||
@@ -196,3 +218,83 @@ where
|
|||||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reciprocal_rank_fusion<T>(
|
||||||
|
mut vector_ranked: Vec<Scored<T>>,
|
||||||
|
mut fts_ranked: Vec<Scored<T>>,
|
||||||
|
config: RrfConfig,
|
||||||
|
) -> Vec<Scored<T>>
|
||||||
|
where
|
||||||
|
T: StoredObject + Clone,
|
||||||
|
{
|
||||||
|
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
||||||
|
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
||||||
|
let vector_weight = if config.vector_weight.is_finite() {
|
||||||
|
config.vector_weight.max(0.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let fts_weight = if config.fts_weight.is_finite() {
|
||||||
|
config.fts_weight.max(0.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
if config.use_vector && !vector_ranked.is_empty() {
|
||||||
|
vector_ranked.sort_by(|a, b| {
|
||||||
|
let a_score = a.scores.vector.unwrap_or(0.0);
|
||||||
|
let b_score = b.scores.vector.unwrap_or(0.0);
|
||||||
|
b_score
|
||||||
|
.partial_cmp(&a_score)
|
||||||
|
.unwrap_or(Ordering::Equal)
|
||||||
|
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||||
|
});
|
||||||
|
|
||||||
|
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
||||||
|
let id = candidate.item.get_id().to_owned();
|
||||||
|
let entry = merged
|
||||||
|
.entry(id.clone())
|
||||||
|
.or_insert_with(|| Scored::new(candidate.item.clone()));
|
||||||
|
|
||||||
|
if let Some(score) = candidate.scores.vector {
|
||||||
|
let existing = entry.scores.vector.unwrap_or(f32::MIN);
|
||||||
|
if score > existing {
|
||||||
|
entry.scores.vector = Some(score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entry.item = candidate.item;
|
||||||
|
entry.fused += vector_weight / (k + rank as f32 + 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.use_fts && !fts_ranked.is_empty() {
|
||||||
|
fts_ranked.sort_by(|a, b| {
|
||||||
|
let a_score = a.scores.fts.unwrap_or(0.0);
|
||||||
|
let b_score = b.scores.fts.unwrap_or(0.0);
|
||||||
|
b_score
|
||||||
|
.partial_cmp(&a_score)
|
||||||
|
.unwrap_or(Ordering::Equal)
|
||||||
|
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||||
|
});
|
||||||
|
|
||||||
|
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
||||||
|
let id = candidate.item.get_id().to_owned();
|
||||||
|
let entry = merged
|
||||||
|
.entry(id.clone())
|
||||||
|
.or_insert_with(|| Scored::new(candidate.item.clone()));
|
||||||
|
|
||||||
|
if let Some(score) = candidate.scores.fts {
|
||||||
|
let existing = entry.scores.fts.unwrap_or(f32::MIN);
|
||||||
|
if score > existing {
|
||||||
|
entry.scores.fts = Some(score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entry.item = candidate.item;
|
||||||
|
entry.fused += fts_weight / (k + rank as f32 + 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut fused: Vec<Scored<T>> = merged.into_values().collect();
|
||||||
|
sort_by_fused_desc(&mut fused);
|
||||||
|
fused
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user