fix: arc-share retrieved chunks, centralize entity embeddings, and trim hot-path clones.

This commit is contained in:
Per Stark
2026-06-06 23:05:53 +02:00
parent 676fdbc132
commit 4559ee0aa8
41 changed files with 368 additions and 289 deletions
+6 -2
View File
@@ -7,6 +7,8 @@ pub mod scoring;
use std::sync::Arc;
pub use scoring::RetrievalCandidate;
use common::{
error::AppError,
storage::{
@@ -45,7 +47,7 @@ pub(crate) fn round_score(value: f32) -> f64 {
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
pub chunk: TextChunk,
pub chunk: Arc<TextChunk>,
pub score: f32,
}
@@ -159,7 +161,9 @@ mod tests {
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
assert!(
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
chunks
.first()
.is_some_and(|c| c.chunk.chunk.contains("Tokio")),
"Expected chunk about Tokio"
);
Ok(())
+2 -2
View File
@@ -11,7 +11,7 @@ use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
use super::{
config::RetrievalConfig,
diagnostics::{AssembleStats, Diagnostics, SearchStats},
StageKind, StageTimings, RetrievalParams,
RetrievalParams, StageKind, StageTimings,
};
/// Mutable working state threaded through every retrieval stage.
@@ -22,7 +22,7 @@ pub(crate) struct PipelineContext<'a> {
pub user_id: String,
pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub chunk_values: Vec<Scored<std::sync::Arc<TextChunk>>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<Diagnostics>,
pub entity_results: Vec<RetrievedEntity>,
+57 -27
View File
@@ -131,14 +131,14 @@ pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
let vector_candidates = vector_rows.len();
let fts_candidates = fts_rows.len();
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
let vector_scored: Vec<Scored<Arc<TextChunk>>> = vector_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.map(|row| Scored::new(Arc::new(row.chunk)).with_vector_score(row.score))
.collect();
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
let fts_scored: Vec<Scored<Arc<TextChunk>>> = fts_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
.map(|row| Scored::new(Arc::new(row.chunk)).with_fts_score(row.score))
.collect();
let mut fts_weight = tuning.chunk_rrf_fts_weight;
@@ -222,40 +222,63 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
/// and the contributing chunks are attached.
#[instrument(level = "trace", skip_all)]
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.is_empty() {
let chunk_values = std::mem::take(&mut ctx.chunk_values);
if chunk_values.is_empty() {
return Ok(());
}
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
struct IndexedChunk {
idx: usize,
score: f32,
}
let mut source_order: Vec<String> = Vec::new();
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
let mut chunks_by_source: HashMap<String, Vec<IndexedChunk>> = HashMap::new();
let mut best_score: HashMap<String, f32> = HashMap::new();
for scored in &ctx.chunk_values {
let source_id = &scored.item.source_id;
let is_new_source = !chunks_by_source.contains_key(source_id);
if is_new_source {
source_order.push(source_id.clone());
for (idx, scored) in chunk_values.iter().enumerate() {
if let Some(attached) = chunks_by_source.get_mut(&scored.item.source_id) {
if attached.len() < max_chunks {
attached.push(IndexedChunk {
idx,
score: scored.fused,
});
}
} else {
let source_id = scored.item.source_id.clone();
best_score.insert(source_id.clone(), scored.fused);
}
let attached = chunks_by_source
.entry(source_id.clone())
.or_default();
if attached.len() < max_chunks {
attached.push(RetrievedChunk {
chunk: scored.item.clone(),
score: scored.fused,
});
source_order.push(source_id.clone());
chunks_by_source.insert(
source_id,
vec![IndexedChunk {
idx,
score: scored.fused,
}],
);
}
}
let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source
.into_iter()
.map(|(source, chunks)| (source, Arc::new(chunks)))
.map(|(source, indices)| {
let chunks = indices
.into_iter()
.filter_map(|indexed| {
let scored = chunk_values.get(indexed.idx)?;
Some(RetrievedChunk {
chunk: Arc::clone(&scored.item),
score: indexed.score,
})
})
.collect();
(source, Arc::new(chunks))
})
.collect();
ctx.chunk_values = chunk_values;
let entities =
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
@@ -336,10 +359,17 @@ fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where
F: FnMut(&Scored<T>) -> f32,
{
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
items
.iter()
.take(SCORE_SAMPLE_LIMIT)
.map(extractor)
.collect()
}
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
fn build_chunk_rerank_documents(
chunks: &[Scored<Arc<TextChunk>>],
max_chunks: usize,
) -> Vec<String> {
let take = chunks.len().min(max_chunks);
let mut documents = Vec::with_capacity(take);
let mut buffer = String::with_capacity(512);
@@ -363,7 +393,7 @@ fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize)
}
fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>,
chunks: &mut Vec<Scored<Arc<TextChunk>>>,
tuning: &RetrievalTuning,
results: Vec<RerankResult>,
) {
@@ -371,7 +401,7 @@ fn apply_chunk_rerank_results(
return;
}
let mut remaining: Vec<Option<Scored<TextChunk>>> =
let mut remaining: Vec<Option<Scored<Arc<TextChunk>>>> =
std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
@@ -384,7 +414,7 @@ fn apply_chunk_rerank_results(
clamp_unit(tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
let mut reranked: Vec<Scored<Arc<TextChunk>>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
+6 -11
View File
@@ -29,8 +29,7 @@ impl RerankerPool {
/// Build the pool at startup.
/// `pool_size` controls max parallel reranks.
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
let init_options =
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
let init_options = RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
Self::new_with_options(pool_size, &init_options)
}
@@ -44,8 +43,7 @@ impl RerankerPool {
)));
}
fs::create_dir_all(&init_options.cache_dir)
.map_err(|e| Box::new(AppError::from(e)))?;
fs::create_dir_all(&init_options.cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
let mut engines = Vec::with_capacity(pool_size);
for x in 0..pool_size {
@@ -77,10 +75,7 @@ impl RerankerPool {
/// This returns a lease that can perform `rerank()`.
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
// Acquire a permit. This enforces backpressure.
let permit = Arc::clone(&self.semaphore)
.acquire_owned()
.await
.ok()?;
let permit = Arc::clone(&self.semaphore).acquire_owned().await.ok()?;
// Pick an engine.
// This is naive: just pick based on a simple modulo counter.
@@ -165,9 +160,9 @@ impl RerankerLease {
let engine = Arc::clone(&self.engine);
tokio::task::spawn_blocking(move || {
let mut guard = engine.lock().map_err(|_| {
AppError::InternalError("reranker engine mutex poisoned".into())
})?;
let mut guard = engine
.lock()
.map_err(|_| AppError::InternalError("reranker engine mutex poisoned".into()))?;
guard
.rerank(query, documents, false, None)
.map_err(|e| AppError::InternalError(e.to_string()))
+34 -8
View File
@@ -1,9 +1,35 @@
use std::{
cmp::Ordering,
collections::{hash_map::Entry, HashMap},
sync::Arc,
};
use common::storage::types::StoredObject;
use common::storage::types::{
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject,
};
/// Identifier access for retrieval fusion and sorting.
///
/// The scoring helpers in this module are generic over `T: RetrievalCandidate`:
/// they compare candidate ids to break ties between equal scores and take an
/// owned copy of each ranked candidate's id while fusing result lists.
pub trait RetrievalCandidate {
/// Returns this candidate's identifier as a borrowed string slice.
fn candidate_id(&self) -> &str;
}
/// Lets an owned `TextChunk` act as a retrieval candidate.
impl RetrievalCandidate for TextChunk {
/// Forwards to the chunk's own `id()` accessor.
// NOTE(review): `id()` is presumably supplied by the imported `StoredObject` trait — confirm.
fn candidate_id(&self) -> &str {
self.id()
}
}
/// Lets a shared `Arc<TextChunk>` act as a retrieval candidate, so
/// `Scored<Arc<TextChunk>>` satisfies the `T: RetrievalCandidate` bound.
impl RetrievalCandidate for Arc<TextChunk> {
/// Borrows the inner `TextChunk` through the `Arc` and returns its `id()`.
fn candidate_id(&self) -> &str {
self.as_ref().id()
}
}
/// Lets a `KnowledgeEntity` act as a retrieval candidate.
impl RetrievalCandidate for KnowledgeEntity {
/// Forwards to the entity's own `id()` accessor.
// NOTE(review): `id()` is presumably supplied by the imported `StoredObject` trait — confirm.
fn candidate_id(&self) -> &str {
self.id()
}
}
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
@@ -102,13 +128,13 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
T: StoredObject,
T: RetrievalCandidate,
{
items.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.id().cmp(b.item.id()))
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
});
}
@@ -122,7 +148,7 @@ pub fn reciprocal_rank_fusion<T>(
config: RrfConfig,
) -> Vec<Scored<T>>
where
T: StoredObject,
T: RetrievalCandidate,
{
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
let k = if config.k <= 0.0 { 60.0 } else { config.k };
@@ -144,11 +170,11 @@ where
b_score
.partial_cmp(&a_score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.id().cmp(b.item.id()))
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
});
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
let id = candidate.item.id().to_owned();
let id = candidate.item.candidate_id().to_owned();
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
let contribution = vector_weight / (k + rank_f32 + 1.0);
@@ -183,11 +209,11 @@ where
b_score
.partial_cmp(&a_score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.id().cmp(b.item.id()))
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
});
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
let id = candidate.item.id().to_owned();
let id = candidate.item.candidate_id().to_owned();
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
let contribution = fts_weight / (k + rank_f32 + 1.0);