mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-30 18:11:34 +02:00
fix: arc-share retrieved chunks, centralize entity embeddings, and trim hot-path clones.
This commit is contained in:
@@ -7,6 +7,8 @@ pub mod scoring;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use scoring::RetrievalCandidate;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
@@ -45,7 +47,7 @@ pub(crate) fn round_score(value: f32) -> f64 {
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedChunk {
|
||||
pub chunk: TextChunk,
|
||||
pub chunk: Arc<TextChunk>,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
@@ -159,7 +161,9 @@ mod tests {
|
||||
|
||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||
assert!(
|
||||
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
|
||||
chunks
|
||||
.first()
|
||||
.is_some_and(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Expected chunk about Tokio"
|
||||
);
|
||||
Ok(())
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
diagnostics::{AssembleStats, Diagnostics, SearchStats},
|
||||
StageKind, StageTimings, RetrievalParams,
|
||||
RetrievalParams, StageKind, StageTimings,
|
||||
};
|
||||
|
||||
/// Mutable working state threaded through every retrieval stage.
|
||||
@@ -22,7 +22,7 @@ pub(crate) struct PipelineContext<'a> {
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
pub query_embedding: Option<Vec<f32>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub chunk_values: Vec<Scored<std::sync::Arc<TextChunk>>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
|
||||
@@ -131,14 +131,14 @@ pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
||||
let vector_candidates = vector_rows.len();
|
||||
let fts_candidates = fts_rows.len();
|
||||
|
||||
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
||||
let vector_scored: Vec<Scored<Arc<TextChunk>>> = vector_rows
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||
.map(|row| Scored::new(Arc::new(row.chunk)).with_vector_score(row.score))
|
||||
.collect();
|
||||
|
||||
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
||||
let fts_scored: Vec<Scored<Arc<TextChunk>>> = fts_rows
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
||||
.map(|row| Scored::new(Arc::new(row.chunk)).with_fts_score(row.score))
|
||||
.collect();
|
||||
|
||||
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
||||
@@ -222,40 +222,63 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
||||
/// and the contributing chunks are attached.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.chunk_values.is_empty() {
|
||||
let chunk_values = std::mem::take(&mut ctx.chunk_values);
|
||||
if chunk_values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
|
||||
|
||||
struct IndexedChunk {
|
||||
idx: usize,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let mut source_order: Vec<String> = Vec::new();
|
||||
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
|
||||
let mut chunks_by_source: HashMap<String, Vec<IndexedChunk>> = HashMap::new();
|
||||
let mut best_score: HashMap<String, f32> = HashMap::new();
|
||||
|
||||
for scored in &ctx.chunk_values {
|
||||
let source_id = &scored.item.source_id;
|
||||
let is_new_source = !chunks_by_source.contains_key(source_id);
|
||||
if is_new_source {
|
||||
source_order.push(source_id.clone());
|
||||
for (idx, scored) in chunk_values.iter().enumerate() {
|
||||
if let Some(attached) = chunks_by_source.get_mut(&scored.item.source_id) {
|
||||
if attached.len() < max_chunks {
|
||||
attached.push(IndexedChunk {
|
||||
idx,
|
||||
score: scored.fused,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
let source_id = scored.item.source_id.clone();
|
||||
best_score.insert(source_id.clone(), scored.fused);
|
||||
}
|
||||
|
||||
let attached = chunks_by_source
|
||||
.entry(source_id.clone())
|
||||
.or_default();
|
||||
if attached.len() < max_chunks {
|
||||
attached.push(RetrievedChunk {
|
||||
chunk: scored.item.clone(),
|
||||
score: scored.fused,
|
||||
});
|
||||
source_order.push(source_id.clone());
|
||||
chunks_by_source.insert(
|
||||
source_id,
|
||||
vec![IndexedChunk {
|
||||
idx,
|
||||
score: scored.fused,
|
||||
}],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source
|
||||
.into_iter()
|
||||
.map(|(source, chunks)| (source, Arc::new(chunks)))
|
||||
.map(|(source, indices)| {
|
||||
let chunks = indices
|
||||
.into_iter()
|
||||
.filter_map(|indexed| {
|
||||
let scored = chunk_values.get(indexed.idx)?;
|
||||
Some(RetrievedChunk {
|
||||
chunk: Arc::clone(&scored.item),
|
||||
score: indexed.score,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
(source, Arc::new(chunks))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ctx.chunk_values = chunk_values;
|
||||
|
||||
let entities =
|
||||
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
||||
|
||||
@@ -336,10 +359,17 @@ fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
|
||||
items
|
||||
.iter()
|
||||
.take(SCORE_SAMPLE_LIMIT)
|
||||
.map(extractor)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
||||
fn build_chunk_rerank_documents(
|
||||
chunks: &[Scored<Arc<TextChunk>>],
|
||||
max_chunks: usize,
|
||||
) -> Vec<String> {
|
||||
let take = chunks.len().min(max_chunks);
|
||||
let mut documents = Vec::with_capacity(take);
|
||||
let mut buffer = String::with_capacity(512);
|
||||
@@ -363,7 +393,7 @@ fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize)
|
||||
}
|
||||
|
||||
fn apply_chunk_rerank_results(
|
||||
chunks: &mut Vec<Scored<TextChunk>>,
|
||||
chunks: &mut Vec<Scored<Arc<TextChunk>>>,
|
||||
tuning: &RetrievalTuning,
|
||||
results: Vec<RerankResult>,
|
||||
) {
|
||||
@@ -371,7 +401,7 @@ fn apply_chunk_rerank_results(
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<TextChunk>>> =
|
||||
let mut remaining: Vec<Option<Scored<Arc<TextChunk>>>> =
|
||||
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
@@ -384,7 +414,7 @@ fn apply_chunk_rerank_results(
|
||||
clamp_unit(tuning.rerank_blend_weight)
|
||||
};
|
||||
|
||||
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
|
||||
let mut reranked: Vec<Scored<Arc<TextChunk>>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
|
||||
@@ -29,8 +29,7 @@ impl RerankerPool {
|
||||
/// Build the pool at startup.
|
||||
/// `pool_size` controls max parallel reranks.
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
|
||||
let init_options =
|
||||
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
|
||||
let init_options = RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
|
||||
Self::new_with_options(pool_size, &init_options)
|
||||
}
|
||||
|
||||
@@ -44,8 +43,7 @@ impl RerankerPool {
|
||||
)));
|
||||
}
|
||||
|
||||
fs::create_dir_all(&init_options.cache_dir)
|
||||
.map_err(|e| Box::new(AppError::from(e)))?;
|
||||
fs::create_dir_all(&init_options.cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
|
||||
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for x in 0..pool_size {
|
||||
@@ -77,10 +75,7 @@ impl RerankerPool {
|
||||
/// This returns a lease that can perform `rerank()`.
|
||||
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
|
||||
// Acquire a permit. This enforces backpressure.
|
||||
let permit = Arc::clone(&self.semaphore)
|
||||
.acquire_owned()
|
||||
.await
|
||||
.ok()?;
|
||||
let permit = Arc::clone(&self.semaphore).acquire_owned().await.ok()?;
|
||||
|
||||
// Pick an engine.
|
||||
// This is naive: just pick based on a simple modulo counter.
|
||||
@@ -165,9 +160,9 @@ impl RerankerLease {
|
||||
let engine = Arc::clone(&self.engine);
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut guard = engine.lock().map_err(|_| {
|
||||
AppError::InternalError("reranker engine mutex poisoned".into())
|
||||
})?;
|
||||
let mut guard = engine
|
||||
.lock()
|
||||
.map_err(|_| AppError::InternalError("reranker engine mutex poisoned".into()))?;
|
||||
guard
|
||||
.rerank(query, documents, false, None)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
|
||||
@@ -1,9 +1,35 @@
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject,
|
||||
};
|
||||
|
||||
/// Identifier access for retrieval fusion and sorting.
|
||||
pub trait RetrievalCandidate {
|
||||
fn candidate_id(&self) -> &str;
|
||||
}
|
||||
|
||||
impl RetrievalCandidate for TextChunk {
|
||||
fn candidate_id(&self) -> &str {
|
||||
self.id()
|
||||
}
|
||||
}
|
||||
|
||||
impl RetrievalCandidate for Arc<TextChunk> {
|
||||
fn candidate_id(&self) -> &str {
|
||||
self.as_ref().id()
|
||||
}
|
||||
}
|
||||
|
||||
impl RetrievalCandidate for KnowledgeEntity {
|
||||
fn candidate_id(&self) -> &str {
|
||||
self.id()
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
@@ -102,13 +128,13 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||
|
||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||
where
|
||||
T: StoredObject,
|
||||
T: RetrievalCandidate,
|
||||
{
|
||||
items.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.id().cmp(b.item.id()))
|
||||
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
|
||||
});
|
||||
}
|
||||
|
||||
@@ -122,7 +148,7 @@ pub fn reciprocal_rank_fusion<T>(
|
||||
config: RrfConfig,
|
||||
) -> Vec<Scored<T>>
|
||||
where
|
||||
T: StoredObject,
|
||||
T: RetrievalCandidate,
|
||||
{
|
||||
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
||||
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
||||
@@ -144,11 +170,11 @@ where
|
||||
b_score
|
||||
.partial_cmp(&a_score)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.id().cmp(b.item.id()))
|
||||
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
|
||||
});
|
||||
|
||||
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
||||
let id = candidate.item.id().to_owned();
|
||||
let id = candidate.item.candidate_id().to_owned();
|
||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||
let contribution = vector_weight / (k + rank_f32 + 1.0);
|
||||
|
||||
@@ -183,11 +209,11 @@ where
|
||||
b_score
|
||||
.partial_cmp(&a_score)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.id().cmp(b.item.id()))
|
||||
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
|
||||
});
|
||||
|
||||
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
||||
let id = candidate.item.id().to_owned();
|
||||
let id = candidate.item.candidate_id().to_owned();
|
||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||
let contribution = fts_weight / (k + rank_f32 + 1.0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user