benchmarks: v1

Benchmarking ingestion, retrieval precision and performance
This commit is contained in:
Per Stark
2025-11-04 11:22:45 +01:00
parent 7f30c8ff6e
commit 6f08429faa
46 changed files with 8407 additions and 144 deletions
+3 -1
View File
@@ -12,6 +12,7 @@ pub struct RetrievalTuning {
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub lexical_match_weight: f32,
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
@@ -31,9 +32,10 @@ impl Default for RetrievalTuning {
chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 2800,
token_budget_estimate: 10000,
avg_chars_per_token: 4,
max_chunks_per_entity: 4,
lexical_match_weight: 0.15,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,
@@ -0,0 +1,51 @@
use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
///
/// Every field starts out as `None` (via `Default`) and is filled in by the
/// stage that produces the corresponding stats.
#[derive(Debug, Clone, Default, Serialize)]
pub struct PipelineDiagnostics {
/// Candidate counts and score samples recorded by the `collect_candidates` stage.
pub collect_candidates: Option<CollectCandidatesStats>,
/// Chunk enrichment counters recorded by the `attach_chunks` stage.
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
/// Token budget accounting and per-entity traces recorded by the `assemble` stage.
pub assemble: Option<AssembleStats>,
}
/// Snapshot of the initial vector and full-text search results.
///
/// The pipeline records this before FTS score normalization and fusion run,
/// so the FTS score samples are the raw search scores.
#[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats {
/// Number of entities returned by the vector search.
pub vector_entity_candidates: usize,
/// Number of chunks returned by the vector search.
pub vector_chunk_candidates: usize,
/// Number of entities returned by the full-text search.
pub fts_entity_candidates: usize,
/// Number of chunks returned by the full-text search.
pub fts_chunk_candidates: usize,
/// Vector scores of the first few vector-search chunks (a missing score is reported as 0.0).
pub vector_chunk_scores: Vec<f32>,
/// FTS scores of the first few full-text chunks (a missing score is reported as 0.0).
pub fts_chunk_scores: Vec<f32>,
}
/// Counters describing the chunk attachment / enrichment stage.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
/// Number of entities that survived filtering in this stage.
pub filtered_entity_count: usize,
/// The `fallback_min_results` tuning value in effect for the run.
pub fallback_min_results: usize,
/// Number of distinct sources the chunk candidates were grouped into.
pub chunk_sources_considered: usize,
/// Size of the chunk candidate set at the start of the stage.
pub chunk_candidates_before_enrichment: usize,
/// Number of chunks available after enrichment from entities.
pub chunk_candidates_after_enrichment: usize,
/// Fused scores of the first few chunks after sorting by fused score, descending.
pub top_chunk_scores: Vec<f32>,
}
/// Token budget accounting for the final assembly stage.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats {
/// Token budget available when assembly started (`token_budget_estimate` from the tuning).
pub token_budget_start: usize,
/// Sum of the estimated token costs of all selected chunks.
pub token_budget_spent: usize,
/// Budget left once assembly finished.
pub token_budget_remaining: usize,
/// `true` when the remaining budget reached exactly zero.
pub budget_exhausted: bool,
/// Total number of chunks selected across all entities.
pub chunks_selected: usize,
/// Total number of chunks skipped because their estimated cost exceeded the remaining budget.
pub chunks_skipped_due_budget: usize,
/// Number of filtered entities available to assembly (not only those that were visited).
pub entity_count: usize,
/// One trace per entity that assembly actually processed.
pub entity_traces: Vec<EntityAssemblyTrace>,
}
/// Per-entity record of which chunks assembly looked at and kept.
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
/// Id of the entity being assembled.
pub entity_id: String,
/// Source id of the entity; chunk candidates are looked up by this key.
pub source_id: String,
/// Number of chunk candidates visited for this entity.
pub inspected_candidates: usize,
/// Ids of the chunks selected for this entity, in selection order.
pub selected_chunk_ids: Vec<String>,
/// Fused scores of the selected chunks, parallel to `selected_chunk_ids`.
pub selected_chunk_scores: Vec<f32>,
/// Number of this entity's candidates skipped for exceeding the remaining token budget.
pub skipped_due_budget: usize,
}
+125 -19
View File
@@ -1,14 +1,57 @@
mod config;
mod diagnostics;
mod stages;
mod state;
pub use config::{RetrievalConfig, RetrievalTuning};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
};
use crate::{reranking::RerankerLease, RetrievedEntity};
use async_openai::Client;
use common::{error::AppError, storage::db::SurrealDbClient};
use tracing::info;
/// Everything a pipeline run produces: the results plus optional instrumentation.
#[derive(Debug)]
pub struct PipelineRunOutput {
/// Final retrieved entities returned by the assembly stage.
pub results: Vec<RetrievedEntity>,
/// Diagnostics taken from the pipeline context after the run; `None` if none were collected.
pub diagnostics: Option<PipelineDiagnostics>,
/// Per-stage timings taken from the pipeline context after the run.
pub stage_timings: PipelineStageTimings,
}
/// Accumulated duration of each instrumented pipeline stage, in whole milliseconds.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct PipelineStageTimings {
/// Time recorded for the candidate collection stage.
pub collect_candidates_ms: u128,
/// Time recorded for the graph expansion stage.
pub graph_expansion_ms: u128,
/// Time recorded for the chunk attachment stage.
pub chunk_attach_ms: u128,
/// Time recorded for the rerank stage.
pub rerank_ms: u128,
/// Time recorded for the assembly stage.
pub assemble_ms: u128,
}
impl PipelineStageTimings {
    // `Duration::as_millis` already returns `u128`, so no cast is needed.
    // Sub-millisecond remainders are truncated on every call.

    /// Adds `duration` (whole milliseconds) to the candidate collection total.
    fn record_collect_candidates(&mut self, duration: std::time::Duration) {
        self.collect_candidates_ms += duration.as_millis();
    }

    /// Adds `duration` (whole milliseconds) to the graph expansion total.
    fn record_graph_expansion(&mut self, duration: std::time::Duration) {
        self.graph_expansion_ms += duration.as_millis();
    }

    /// Adds `duration` (whole milliseconds) to the chunk attachment total.
    fn record_chunk_attach(&mut self, duration: std::time::Duration) {
        self.chunk_attach_ms += duration.as_millis();
    }

    /// Adds `duration` (whole milliseconds) to the rerank total.
    fn record_rerank(&mut self, duration: std::time::Duration) {
        self.rerank_ms += duration.as_millis();
    }

    /// Adds `duration` (whole milliseconds) to the assembly total.
    fn record_assemble(&mut self, duration: std::time::Duration) {
        self.assemble_ms += duration.as_millis();
    }
}
/// Drives the retrieval pipeline from embedding through final assembly.
pub async fn run_pipeline(
db_client: &SurrealDbClient,
@@ -18,7 +61,6 @@ pub async fn run_pipeline(
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
@@ -30,7 +72,7 @@ pub async fn run_pipeline(
preview = %input_preview_clean,
"Starting ingestion retrieval pipeline"
);
let mut ctx = stages::PipelineContext::new(
let ctx = stages::PipelineContext::new(
db_client,
openai_client,
input_text.to_owned(),
@@ -38,17 +80,11 @@ pub async fn run_pipeline(
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
let outcome = run_pipeline_internal(ctx, false).await?;
Ok(results)
Ok(outcome.results)
}
#[cfg(test)]
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
@@ -58,8 +94,7 @@ pub async fn run_pipeline_with_embedding(
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let mut ctx = stages::PipelineContext::with_embedding(
let ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
@@ -68,14 +103,54 @@ pub async fn run_pipeline_with_embedding(
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
let outcome = run_pipeline_internal(ctx, false).await?;
Ok(results)
Ok(outcome.results)
}
/// Runs the retrieval pipeline with a caller-supplied query embedding and
/// returns the full run output, including per-stage timings.
///
/// Diagnostics capture is not requested for this entry point.
pub async fn run_pipeline_with_embedding_with_metrics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput, AppError> {
    let capture_diagnostics = false;
    let owned_input = input_text.to_owned();
    let owned_user = user_id.to_owned();
    let pipeline_ctx = stages::PipelineContext::with_embedding(
        db_client,
        openai_client,
        query_embedding,
        owned_input,
        owned_user,
        config,
        reranker,
    );
    run_pipeline_internal(pipeline_ctx, capture_diagnostics).await
}
/// Runs the retrieval pipeline with a caller-supplied query embedding and
/// requests diagnostics capture, returning the full run output.
pub async fn run_pipeline_with_embedding_with_diagnostics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput, AppError> {
    let capture_diagnostics = true;
    let owned_input = input_text.to_owned();
    let owned_user = user_id.to_owned();
    let pipeline_ctx = stages::PipelineContext::with_embedding(
        db_client,
        openai_client,
        query_embedding,
        owned_input,
        owned_user,
        config,
        reranker,
    );
    run_pipeline_internal(pipeline_ctx, capture_diagnostics).await
}
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
@@ -101,6 +176,37 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
.collect::<Vec<_>>())
}
/// Shared driver behind the public entry points.
///
/// Optionally enables diagnostics on the context, runs every stage, then
/// takes the diagnostics and stage timings out of the context and bundles
/// them with the results. If any stage fails, the error is returned as-is.
async fn run_pipeline_internal(
    mut ctx: stages::PipelineContext<'_>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput, AppError> {
    if capture_diagnostics {
        ctx.enable_diagnostics();
    }

    let results = drive_pipeline(&mut ctx).await?;

    // Take diagnostics first, then timings, matching the order callers have always seen.
    let diagnostics = ctx.take_diagnostics();
    let stage_timings = ctx.take_stage_timings();

    Ok(PipelineRunOutput {
        results,
        diagnostics,
        stage_timings,
    })
}
/// Advances the retrieval state machine through every stage in order:
/// embed, collect candidates, expand graph, attach chunks, rerank, assemble.
///
/// The first stage error aborts the run and is returned unchanged.
async fn drive_pipeline(
    ctx: &mut stages::PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
    let ready = state::ready();
    let embedded = stages::embed(ready, ctx).await?;
    let candidates_loaded = stages::collect_candidates(embedded, ctx).await?;
    let graph_expanded = stages::expand_graph(candidates_loaded, ctx).await?;
    let chunks_attached = stages::attach_chunks(graph_expanded, ctx).await?;
    let reranked = stages::rerank(chunks_attached, ctx).await?;
    stages::assemble(reranked, ctx)
}
/// Widens `value` to `f64` and rounds it to three decimal places.
fn round_score(value: f32) -> f64 {
    const SCALE: f64 = 1000.0;
    let scaled = f64::from(value) * SCALE;
    scaled.round() / SCALE
}
+326 -85
View File
@@ -10,7 +10,11 @@ use common::{
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
use state_machines::core::GuardError;
use std::collections::{HashMap, HashSet};
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
time::Instant,
};
use tracing::{debug, instrument, warn};
use crate::{
@@ -27,10 +31,15 @@ use crate::{
use super::{
config::RetrievalConfig,
diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
},
state::{
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
Reranked,
},
PipelineStageTimings,
};
pub struct PipelineContext<'a> {
@@ -45,6 +54,8 @@ pub struct PipelineContext<'a> {
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>,
stage_timings: PipelineStageTimings,
}
impl<'a> PipelineContext<'a> {
@@ -68,10 +79,11 @@ impl<'a> PipelineContext<'a> {
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
reranker,
diagnostics: None,
stage_timings: PipelineStageTimings::default(),
}
}
#[cfg(test)]
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
@@ -100,6 +112,62 @@ impl<'a> PipelineContext<'a> {
)
})
}
/// Turns diagnostics collection on. Diagnostics that were already enabled
/// (and anything recorded in them) are left untouched.
pub fn enable_diagnostics(&mut self) {
    match self.diagnostics {
        Some(_) => {}
        None => self.diagnostics = Some(PipelineDiagnostics::default()),
    }
}
/// Returns `true` when a diagnostics container is currently present on the context.
pub fn diagnostics_enabled(&self) -> bool {
self.diagnostics.is_some()
}
/// Stores the candidate collection stats, replacing any earlier value.
/// Does nothing when diagnostics are not enabled.
pub fn record_collect_candidates(&mut self, stats: CollectCandidatesStats) {
    if let Some(diagnostics) = &mut self.diagnostics {
        diagnostics.collect_candidates = Some(stats);
    }
}
/// Stores the chunk enrichment stats, replacing any earlier value.
/// Does nothing when diagnostics are not enabled.
pub fn record_chunk_enrichment(&mut self, stats: ChunkEnrichmentStats) {
    if let Some(diagnostics) = &mut self.diagnostics {
        diagnostics.enrich_chunks_from_entities = Some(stats);
    }
}
/// Stores the assembly stats, replacing any earlier value.
/// Does nothing when diagnostics are not enabled.
pub fn record_assemble(&mut self, stats: AssembleStats) {
    if let Some(diagnostics) = &mut self.diagnostics {
        diagnostics.assemble = Some(stats);
    }
}
/// Moves the collected diagnostics out of the context, leaving `None` behind
/// (so diagnostics are disabled afterwards).
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
    std::mem::take(&mut self.diagnostics)
}
/// Adds `duration` to the accumulated candidate collection stage timing.
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_collect_candidates(duration);
}
/// Adds `duration` to the accumulated graph expansion stage timing.
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_graph_expansion(duration);
}
/// Adds `duration` to the accumulated chunk attachment stage timing.
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_chunk_attach(duration);
}
/// Adds `duration` to the accumulated rerank stage timing.
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_rerank(duration);
}
/// Adds `duration` to the accumulated assembly stage timing.
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_assemble(duration);
}
/// Moves the accumulated stage timings out of the context, resetting the
/// stored timings to their default value.
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
std::mem::take(&mut self.stage_timings)
}
}
#[instrument(level = "trace", skip_all)]
@@ -127,6 +195,7 @@ pub async fn collect_candidates(
machine: HybridRetrievalMachine<(), Embedded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
let stage_start = Instant::now();
debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
@@ -172,6 +241,19 @@ pub async fn collect_candidates(
"Hybrid retrieval initial candidate counts"
);
if ctx.diagnostics_enabled() {
ctx.record_collect_candidates(CollectCandidatesStats {
vector_entity_candidates: vector_entities.len(),
vector_chunk_candidates: vector_chunks.len(),
fts_entity_candidates: fts_entities.len(),
fts_chunk_candidates: fts_chunks.len(),
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
chunk.scores.vector.unwrap_or(0.0)
}),
fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
});
}
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
@@ -183,9 +265,11 @@ pub async fn collect_candidates(
apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
machine
let next = machine
.collect_candidates()
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
ctx.record_collect_candidates_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -193,82 +277,84 @@ pub async fn expand_graph(
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
let stage_start = Instant::now();
debug!("Expanding candidates using graph relationships");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let next = {
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
if ctx.entity_candidates.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
if ctx.entity_candidates.is_empty() {
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
} else {
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
if graph_seeds.is_empty() {
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
} else {
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
let limit = tuning.graph_neighbor_limit;
futures.push(async move {
let neighbors =
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
(seed, neighbors)
});
}
if graph_seeds.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
futures.push(async move {
let neighbors = find_entities_by_relationship_by_id(
db,
&seed.id,
&user,
tuning.graph_neighbor_limit,
)
.await;
(seed, neighbors)
});
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector =
clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}?;
ctx.record_graph_expansion_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -276,11 +362,14 @@ pub async fn attach_chunks(
machine: HybridRetrievalMachine<(), GraphExpanded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
let stage_start = Instant::now();
debug!("Attaching chunks to surviving entities");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
let chunk_candidates_before = ctx.chunk_candidates.len();
let chunk_sources_considered = chunk_by_source.len();
backfill_entities_from_chunks(
&mut ctx.entity_candidates,
@@ -312,6 +401,8 @@ pub async fn attach_chunks(
ctx.filtered_entities = filtered_entities;
let query_embedding = ctx.ensure_embedding()?.clone();
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
@@ -327,17 +418,31 @@ pub async fn attach_chunks(
ctx.db_client,
&ctx.user_id,
weights,
&query_embedding,
)
.await?;
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
if ctx.diagnostics_enabled() {
ctx.record_chunk_enrichment(ChunkEnrichmentStats {
filtered_entity_count: ctx.filtered_entities.len(),
fallback_min_results: tuning.fallback_min_results,
chunk_sources_considered,
chunk_candidates_before_enrichment: chunk_candidates_before,
chunk_candidates_after_enrichment: chunk_values.len(),
top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused),
});
}
ctx.chunk_values = chunk_values;
machine
let next = machine
.attach_chunks()
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
ctx.record_chunk_attach_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -345,6 +450,7 @@ pub async fn rerank(
machine: HybridRetrievalMachine<(), ChunksAttached>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
let stage_start = Instant::now();
let mut applied = false;
if let Some(reranker) = ctx.reranker.as_ref() {
@@ -384,9 +490,11 @@ pub async fn rerank(
debug!("Applied reranking adjustments to candidate ordering");
}
machine
let next = machine
.rerank()
.map_err(|(_, guard)| map_guard_error("rerank", guard))
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
ctx.record_rerank_timing(stage_start.elapsed());
Ok(next)
}
#[instrument(level = "trace", skip_all)]
@@ -394,8 +502,11 @@ pub fn assemble(
machine: HybridRetrievalMachine<(), Reranked>,
ctx: &mut PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let stage_start = Instant::now();
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let query_embedding = ctx.ensure_embedding()?.clone();
let question_terms = extract_keywords(&ctx.input_text);
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in ctx.chunk_values.drain(..) {
@@ -406,39 +517,68 @@ pub fn assemble(
}
for chunk_list in chunk_by_source.values_mut() {
sort_by_fused_desc(chunk_list);
chunk_list.sort_by(|a, b| {
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
});
}
let mut token_budget_remaining = tuning.token_budget_estimate;
let mut results = Vec::new();
let diagnostics_enabled = ctx.diagnostics_enabled();
let mut per_entity_traces = Vec::new();
let mut chunks_skipped_due_budget = 0usize;
let mut chunks_selected = 0usize;
let mut tokens_spent = 0usize;
for entity in &ctx.filtered_entities {
let mut selected_chunks = Vec::new();
let mut entity_trace = if diagnostics_enabled {
Some(EntityAssemblyTrace {
entity_id: entity.item.id.clone(),
source_id: entity.item.source_id.clone(),
inspected_candidates: 0,
selected_chunk_ids: Vec::new(),
selected_chunk_scores: Vec::new(),
skipped_due_budget: 0,
})
} else {
None
};
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
rank_chunks_by_combined_score(candidates, &question_terms, tuning.lexical_match_weight);
let mut per_entity_count = 0;
candidates.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
for candidate in candidates.iter() {
if let Some(trace) = entity_trace.as_mut() {
trace.inspected_candidates += 1;
}
if per_entity_count >= tuning.max_chunks_per_entity {
break;
}
let estimated_tokens =
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
if estimated_tokens > token_budget_remaining {
chunks_skipped_due_budget += 1;
if let Some(trace) = entity_trace.as_mut() {
trace.skipped_due_budget += 1;
}
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
tokens_spent += estimated_tokens;
per_entity_count += 1;
chunks_selected += 1;
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
score: candidate.fused,
});
if let Some(trace) = entity_trace.as_mut() {
trace.selected_chunk_ids.push(candidate.item.id.clone());
trace.selected_chunk_scores.push(candidate.fused);
}
}
}
@@ -448,17 +588,48 @@ pub fn assemble(
chunks: selected_chunks,
});
if let Some(trace) = entity_trace {
per_entity_traces.push(trace);
}
if token_budget_remaining == 0 {
break;
}
}
if diagnostics_enabled {
ctx.record_assemble(AssembleStats {
token_budget_start: tuning.token_budget_estimate,
token_budget_spent: tokens_spent,
token_budget_remaining,
budget_exhausted: token_budget_remaining == 0,
chunks_selected,
chunks_skipped_due_budget,
entity_count: ctx.filtered_entities.len(),
entity_traces: per_entity_traces,
});
}
machine
.assemble()
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
ctx.record_assemble_timing(stage_start.elapsed());
Ok(results)
}
/// Maximum number of scores captured in a single diagnostics sample.
const SCORE_SAMPLE_LIMIT: usize = 8;

/// Applies `extractor` to the first `SCORE_SAMPLE_LIMIT` items (in slice
/// order) and returns the extracted scores.
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
where
    F: FnMut(&Scored<T>) -> f32,
{
    let sample_len = items.len().min(SCORE_SAMPLE_LIMIT);
    let mut sampled = Vec::with_capacity(sample_len);
    for item in &items[..sample_len] {
        sampled.push(extractor(item));
    }
    sampled
}
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
AppError::InternalError(format!(
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
@@ -582,6 +753,7 @@ async fn enrich_chunks_from_entities(
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
query_embedding: &[f32],
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
@@ -615,7 +787,16 @@ async fn enrich_chunks_from_entities(
.copied()
.unwrap_or(0.0);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
entry.scores.vector = Some(
entry
.scores
.vector
.unwrap_or(0.0)
.max(entity_score * 0.8)
.max(similarity),
);
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
@@ -624,6 +805,24 @@ async fn enrich_chunks_from_entities(
Ok(())
}
/// Cosine similarity between `query` and `embedding`.
///
/// Returns `0.0` when either slice is empty, when the lengths differ, or
/// when either vector has zero magnitude.
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    // Equal, non-zero lengths imply that neither slice is empty.
    let comparable = !query.is_empty() && query.len() == embedding.len();
    if !comparable {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared norms.
    let (dot, query_sq, embedding_sq) = query.iter().zip(embedding.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, query_sq, embedding_sq), (q, e)| {
            (dot + q * e, query_sq + q * q, embedding_sq + e * e)
        },
    );

    if query_sq == 0.0 || embedding_sq == 0.0 {
        return 0.0;
    }
    dot / (query_sq.sqrt() * embedding_sq.sqrt())
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();
@@ -736,6 +935,48 @@ fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
(chars / avg_chars_per_token).max(1)
}
/// Boosts each candidate's fused score by its lexical overlap with the
/// question terms, then sorts the candidates by fused score, descending.
///
/// The boost is `lexical_weight * overlap`, and the combined score is passed
/// through `clamp_unit`. No boost is applied when `lexical_weight` is not
/// positive or there are no question terms; the sort always runs.
///
/// NOTE(review): the boost is written back through `update_fused`, so calling
/// this more than once on the same candidates presumably compounds it.
/// Confirm callers invoke it once per candidate list.
fn rank_chunks_by_combined_score(
    candidates: &mut [Scored<TextChunk>],
    question_terms: &[String],
    lexical_weight: f32,
) {
    let boost_enabled = lexical_weight > 0.0 && !question_terms.is_empty();
    if boost_enabled {
        candidates.iter_mut().for_each(|candidate| {
            let overlap = lexical_overlap_score(question_terms, &candidate.item.chunk);
            let boosted = clamp_unit(candidate.fused + lexical_weight * overlap);
            candidate.update_fused(boosted);
        });
    }

    // Stable sort: candidates with equal or incomparable scores keep their relative order.
    candidates.sort_by(|left, right| {
        right
            .fused
            .partial_cmp(&left.fused)
            .unwrap_or(Ordering::Equal)
    });
}
/// Splits `text` on non-alphanumeric characters and returns the distinct
/// lowercased terms, sorted.
///
/// Terms shorter than three bytes are dropped (the threshold is a byte
/// length, not a character count). Lowercasing is Unicode-aware, so it
/// matches the Unicode-aware splitting and the folding done in
/// `lexical_overlap_score`.
fn extract_keywords(text: &str) -> Vec<String> {
    // Splitting on non-alphanumerics already removes whitespace, so no trim is needed.
    let mut terms: Vec<String> = text
        .split(|c: char| !c.is_alphanumeric())
        .map(str::to_lowercase)
        .filter(|term| term.len() >= 3)
        .collect();
    terms.sort();
    terms.dedup();
    terms
}

/// Fraction of `terms` that occur as substrings of `haystack`, ignoring case.
///
/// `terms` are expected to be lowercased already (as produced by
/// `extract_keywords`); only `haystack` is case-folded here. Returns `0.0`
/// for an empty term list.
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
    if terms.is_empty() {
        return 0.0;
    }
    // Unicode-aware folding, so non-ASCII text such as "ÜBER" matches the term "über".
    let lowered = haystack.to_lowercase();
    let matches = terms
        .iter()
        .filter(|term| lowered.contains(term.as_str()))
        .count();
    (matches as f32) / (terms.len() as f32)
}
#[derive(Clone)]
struct GraphSeed {
id: String,
+1
View File
@@ -68,6 +68,7 @@ where
{
let embedding_literal = serde_json::to_string(&query_embedding)
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
let closest_query = format!(
"SELECT id, vector::distance::knn() AS distance \
FROM {table} \