mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-30 10:01:40 +02:00
benchmarks: v1
Benchmarking ingestion, retrieval precision and performance
This commit is contained in:
@@ -12,6 +12,7 @@ pub struct RetrievalTuning {
|
||||
pub token_budget_estimate: usize,
|
||||
pub avg_chars_per_token: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub lexical_match_weight: f32,
|
||||
pub graph_traversal_seed_limit: usize,
|
||||
pub graph_neighbor_limit: usize,
|
||||
pub graph_score_decay: f32,
|
||||
@@ -31,9 +32,10 @@ impl Default for RetrievalTuning {
|
||||
chunk_fts_take: 20,
|
||||
score_threshold: 0.35,
|
||||
fallback_min_results: 10,
|
||||
token_budget_estimate: 2800,
|
||||
token_budget_estimate: 10000,
|
||||
avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
lexical_match_weight: 0.15,
|
||||
graph_traversal_seed_limit: 5,
|
||||
graph_neighbor_limit: 6,
|
||||
graph_score_decay: 0.75,
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
///
/// Every field starts as `None` (via `Default`) and is filled in by the
/// corresponding `PipelineContext::record_*` call once that stage has run.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct PipelineDiagnostics {
|
||||
/// Stats recorded by the `collect_candidates` stage.
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
/// Stats recorded by the `attach_chunks` stage after chunk enrichment.
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
/// Stats recorded by the `assemble` stage.
pub assemble: Option<AssembleStats>,
|
||||
}
|
||||
|
||||
/// Candidate counts and score samples captured by the `collect_candidates` stage.
///
/// The snapshot is taken before FTS scores are normalized and before fusion is applied.
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct CollectCandidatesStats {
|
||||
/// Length of the `vector_entities` candidate list.
pub vector_entity_candidates: usize,
|
||||
/// Length of the `vector_chunks` candidate list.
pub vector_chunk_candidates: usize,
|
||||
/// Length of the `fts_entities` candidate list.
pub fts_entity_candidates: usize,
|
||||
/// Length of the `fts_chunks` candidate list.
pub fts_chunk_candidates: usize,
|
||||
/// Vector scores of the first few vector chunk candidates (at most
/// `SCORE_SAMPLE_LIMIT`); a missing vector score is reported as `0.0`.
pub vector_chunk_scores: Vec<f32>,
|
||||
/// FTS scores of the first few FTS chunk candidates (at most
/// `SCORE_SAMPLE_LIMIT`), sampled before normalization; a missing score is `0.0`.
pub fts_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Snapshot of chunk enrichment captured by the `attach_chunks` stage.
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct ChunkEnrichmentStats {
|
||||
/// Number of entities in `filtered_entities` when the stats were recorded.
pub filtered_entity_count: usize,
|
||||
/// The `fallback_min_results` tuning value in effect for this run.
pub fallback_min_results: usize,
|
||||
/// Number of distinct groups produced by grouping chunk candidates by source.
pub chunk_sources_considered: usize,
|
||||
/// Size of the chunk candidate map at the start of the stage.
pub chunk_candidates_before_enrichment: usize,
|
||||
/// Number of chunks available after enrichment from entities.
pub chunk_candidates_after_enrichment: usize,
|
||||
/// Fused scores of the first few chunks (at most `SCORE_SAMPLE_LIMIT`)
/// after sorting by fused score in descending order.
pub top_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Token-budget accounting and per-entity traces captured by the `assemble` stage.
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct AssembleStats {
|
||||
/// The `token_budget_estimate` tuning value the stage started with.
pub token_budget_start: usize,
|
||||
/// Sum of the estimated token counts of all selected chunks.
pub token_budget_spent: usize,
|
||||
/// Budget left over when assembly finished.
pub token_budget_remaining: usize,
|
||||
/// `true` when the remaining budget reached exactly zero.
pub budget_exhausted: bool,
|
||||
/// Total number of chunks attached to entities.
pub chunks_selected: usize,
|
||||
/// Total number of chunks skipped because their estimated token count
/// exceeded the remaining budget.
pub chunks_skipped_due_budget: usize,
|
||||
/// Number of entities in `filtered_entities` (not only those that were assembled).
pub entity_count: usize,
|
||||
/// One trace per entity that was processed before assembly stopped.
pub entity_traces: Vec<EntityAssemblyTrace>,
|
||||
}
|
||||
|
||||
/// Per-entity record of how chunks were chosen during the `assemble` stage.
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct EntityAssemblyTrace {
|
||||
/// Id of the entity being assembled.
pub entity_id: String,
|
||||
/// Source id of the entity, used to look up its candidate chunks.
pub source_id: String,
|
||||
/// Number of candidate chunks visited for this entity. The counter is bumped
/// before the per-entity chunk limit is checked, so it includes the candidate
/// on which the loop stopped.
pub inspected_candidates: usize,
|
||||
/// Ids of the chunks selected for this entity, in selection order.
pub selected_chunk_ids: Vec<String>,
|
||||
/// Fused scores of the selected chunks, parallel to `selected_chunk_ids`.
pub selected_chunk_scores: Vec<f32>,
|
||||
/// Number of this entity's candidates skipped for exceeding the remaining token budget.
pub skipped_due_budget: usize,
|
||||
}
|
||||
@@ -1,14 +1,57 @@
|
||||
mod config;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use tracing::info;
|
||||
|
||||
/// Everything a pipeline run produces: the retrieved entities plus optional
/// diagnostics and per-stage timings.
#[derive(Debug)]
|
||||
pub struct PipelineRunOutput {
|
||||
/// Final retrieved entities returned by the assemble stage.
pub results: Vec<RetrievedEntity>,
|
||||
/// Whatever diagnostics the pipeline context held at the end of the run;
/// `run_pipeline_internal` only enables capture when asked to.
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
/// Accumulated per-stage timings taken from the pipeline context.
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
/// Accumulated duration of each pipeline stage, in whole milliseconds.
///
/// All counters start at zero (via `Default`) and are only ever added to.
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
/// Milliseconds spent in the `collect_candidates` stage.
pub collect_candidates_ms: u128,
|
||||
/// Milliseconds spent in the `expand_graph` stage.
pub graph_expansion_ms: u128,
|
||||
/// Milliseconds spent in the `attach_chunks` stage.
pub chunk_attach_ms: u128,
|
||||
/// Milliseconds spent in the `rerank` stage.
pub rerank_ms: u128,
|
||||
/// Milliseconds spent in the `assemble` stage.
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
fn record_collect_candidates(&mut self, duration: std::time::Duration) {
|
||||
self.collect_candidates_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_graph_expansion(&mut self, duration: std::time::Duration) {
|
||||
self.graph_expansion_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_chunk_attach(&mut self, duration: std::time::Duration) {
|
||||
self.chunk_attach_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_rerank(&mut self, duration: std::time::Duration) {
|
||||
self.rerank_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_assemble(&mut self, duration: std::time::Duration) {
|
||||
self.assemble_ms += duration.as_millis() as u128;
|
||||
}
|
||||
}
|
||||
|
||||
/// Drives the retrieval pipeline from embedding through final assembly.
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
@@ -18,7 +61,6 @@ pub async fn run_pipeline(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
@@ -30,7 +72,7 @@ pub async fn run_pipeline(
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
let mut ctx = stages::PipelineContext::new(
|
||||
let ctx = stages::PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
@@ -38,17 +80,11 @@ pub async fn run_pipeline(
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
||||
|
||||
Ok(results)
|
||||
Ok(outcome.results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
@@ -58,8 +94,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let mut ctx = stages::PipelineContext::with_embedding(
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
@@ -68,14 +103,54 @@ pub async fn run_pipeline_with_embedding(
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
||||
|
||||
Ok(results)
|
||||
Ok(outcome.results)
|
||||
}
|
||||
|
||||
/// Runs the pipeline with a precomputed embedding and returns stage metrics.
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
|
||||
run_pipeline_internal(ctx, false).await
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
|
||||
run_pipeline_internal(ctx, true).await
|
||||
}
|
||||
|
||||
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
||||
@@ -101,6 +176,37 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
async fn run_pipeline_internal(
|
||||
mut ctx: stages::PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
let results = drive_pipeline(&mut ctx).await?;
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings: ctx.take_stage_timings(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drive_pipeline(
|
||||
ctx: &mut stages::PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let machine = stages::embed(machine, ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, ctx).await?;
|
||||
let machine = stages::expand_graph(machine, ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, ctx).await?;
|
||||
let machine = stages::rerank(machine, ctx).await?;
|
||||
let results = stages::assemble(machine, ctx)?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Widens `value` to `f64` and rounds it to three decimal places.
fn round_score(value: f32) -> f64 {
    let scaled = f64::from(value) * 1000.0;
    scaled.round() / 1000.0
}
|
||||
|
||||
@@ -10,7 +10,11 @@ use common::{
|
||||
use fastembed::RerankResult;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use state_machines::core::GuardError;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::{HashMap, HashSet},
|
||||
time::Instant,
|
||||
};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
@@ -27,10 +31,15 @@ use crate::{
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
},
|
||||
state::{
|
||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||
Reranked,
|
||||
},
|
||||
PipelineStageTimings,
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
@@ -45,6 +54,8 @@ pub struct PipelineContext<'a> {
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
@@ -68,10 +79,11 @@ impl<'a> PipelineContext<'a> {
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
reranker,
|
||||
diagnostics: None,
|
||||
stage_timings: PipelineStageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
@@ -100,6 +112,62 @@ impl<'a> PipelineContext<'a> {
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(PipelineDiagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when a diagnostics value is currently held by this context.
pub fn diagnostics_enabled(&self) -> bool {
|
||||
self.diagnostics.is_some()
|
||||
}
|
||||
|
||||
/// Stores candidate-collection stats, replacing any previous value.
/// When diagnostics are disabled the stats are simply dropped.
pub fn record_collect_candidates(&mut self, stats: CollectCandidatesStats) {
    if let Some(diagnostics) = &mut self.diagnostics {
        diagnostics.collect_candidates = Some(stats);
    }
}
|
||||
|
||||
/// Stores chunk-enrichment stats, replacing any previous value.
/// When diagnostics are disabled the stats are simply dropped.
pub fn record_chunk_enrichment(&mut self, stats: ChunkEnrichmentStats) {
    if let Some(diagnostics) = &mut self.diagnostics {
        diagnostics.enrich_chunks_from_entities = Some(stats);
    }
}
|
||||
|
||||
/// Stores assemble-stage stats, replacing any previous value.
/// When diagnostics are disabled the stats are simply dropped.
pub fn record_assemble(&mut self, stats: AssembleStats) {
    if let Some(diagnostics) = &mut self.diagnostics {
        diagnostics.assemble = Some(stats);
    }
}
|
||||
|
||||
/// Moves the collected diagnostics out of the context, leaving `None` behind
/// (so diagnostics are disabled afterwards). Returns `None` if capture was never enabled.
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
/// Adds `duration` to the accumulated collect-candidates stage timing.
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_collect_candidates(duration);
|
||||
}
|
||||
|
||||
/// Adds `duration` to the accumulated graph-expansion stage timing.
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_graph_expansion(duration);
|
||||
}
|
||||
|
||||
/// Adds `duration` to the accumulated chunk-attach stage timing.
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_chunk_attach(duration);
|
||||
}
|
||||
|
||||
/// Adds `duration` to the accumulated rerank stage timing.
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_rerank(duration);
|
||||
}
|
||||
|
||||
/// Adds `duration` to the accumulated assemble stage timing.
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_assemble(duration);
|
||||
}
|
||||
|
||||
/// Moves the accumulated stage timings out of the context, leaving a
/// default (all-zero) `PipelineStageTimings` in their place.
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -127,6 +195,7 @@ pub async fn collect_candidates(
|
||||
machine: HybridRetrievalMachine<(), Embedded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
@@ -172,6 +241,19 @@ pub async fn collect_candidates(
|
||||
"Hybrid retrieval initial candidate counts"
|
||||
);
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_collect_candidates(CollectCandidatesStats {
|
||||
vector_entity_candidates: vector_entities.len(),
|
||||
vector_chunk_candidates: vector_chunks.len(),
|
||||
fts_entity_candidates: fts_entities.len(),
|
||||
fts_chunk_candidates: fts_chunks.len(),
|
||||
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
||||
chunk.scores.vector.unwrap_or(0.0)
|
||||
}),
|
||||
fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
|
||||
});
|
||||
}
|
||||
|
||||
normalize_fts_scores(&mut fts_entities);
|
||||
normalize_fts_scores(&mut fts_chunks);
|
||||
|
||||
@@ -183,9 +265,11 @@ pub async fn collect_candidates(
|
||||
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||
|
||||
machine
|
||||
let next = machine
|
||||
.collect_candidates()
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
|
||||
ctx.record_collect_candidates_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -193,82 +277,84 @@ pub async fn expand_graph(
|
||||
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Expanding candidates using graph relationships");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
let next = {
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
} else {
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
if graph_seeds.is_empty() {
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
} else {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
let limit = tuning.graph_neighbor_limit;
|
||||
futures.push(async move {
|
||||
let neighbors =
|
||||
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
if graph_seeds.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
futures.push(async move {
|
||||
let neighbors = find_entities_by_relationship_by_id(
|
||||
db,
|
||||
&seed.id,
|
||||
&user,
|
||||
tuning.graph_neighbor_limit,
|
||||
)
|
||||
.await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector =
|
||||
clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}
|
||||
}
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}?;
|
||||
ctx.record_graph_expansion_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -276,11 +362,14 @@ pub async fn attach_chunks(
|
||||
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Attaching chunks to surviving entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
|
||||
let chunk_candidates_before = ctx.chunk_candidates.len();
|
||||
let chunk_sources_considered = chunk_by_source.len();
|
||||
|
||||
backfill_entities_from_chunks(
|
||||
&mut ctx.entity_candidates,
|
||||
@@ -312,6 +401,8 @@ pub async fn attach_chunks(
|
||||
|
||||
ctx.filtered_entities = filtered_entities;
|
||||
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> =
|
||||
ctx.chunk_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
@@ -327,17 +418,31 @@ pub async fn attach_chunks(
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
&query_embedding,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_values);
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_chunk_enrichment(ChunkEnrichmentStats {
|
||||
filtered_entity_count: ctx.filtered_entities.len(),
|
||||
fallback_min_results: tuning.fallback_min_results,
|
||||
chunk_sources_considered,
|
||||
chunk_candidates_before_enrichment: chunk_candidates_before,
|
||||
chunk_candidates_after_enrichment: chunk_values.len(),
|
||||
top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused),
|
||||
});
|
||||
}
|
||||
|
||||
ctx.chunk_values = chunk_values;
|
||||
|
||||
machine
|
||||
let next = machine
|
||||
.attach_chunks()
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
|
||||
ctx.record_chunk_attach_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -345,6 +450,7 @@ pub async fn rerank(
|
||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
let mut applied = false;
|
||||
|
||||
if let Some(reranker) = ctx.reranker.as_ref() {
|
||||
@@ -384,9 +490,11 @@ pub async fn rerank(
|
||||
debug!("Applied reranking adjustments to candidate ordering");
|
||||
}
|
||||
|
||||
machine
|
||||
let next = machine
|
||||
.rerank()
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
|
||||
ctx.record_rerank_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
@@ -394,8 +502,11 @@ pub fn assemble(
|
||||
machine: HybridRetrievalMachine<(), Reranked>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
let question_terms = extract_keywords(&ctx.input_text);
|
||||
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in ctx.chunk_values.drain(..) {
|
||||
@@ -406,39 +517,68 @@ pub fn assemble(
|
||||
}
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
sort_by_fused_desc(chunk_list);
|
||||
chunk_list.sort_by(|a, b| {
|
||||
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
|
||||
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
|
||||
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
|
||||
});
|
||||
}
|
||||
|
||||
let mut token_budget_remaining = tuning.token_budget_estimate;
|
||||
let mut results = Vec::new();
|
||||
let diagnostics_enabled = ctx.diagnostics_enabled();
|
||||
let mut per_entity_traces = Vec::new();
|
||||
let mut chunks_skipped_due_budget = 0usize;
|
||||
let mut chunks_selected = 0usize;
|
||||
let mut tokens_spent = 0usize;
|
||||
|
||||
for entity in &ctx.filtered_entities {
|
||||
let mut selected_chunks = Vec::new();
|
||||
let mut entity_trace = if diagnostics_enabled {
|
||||
Some(EntityAssemblyTrace {
|
||||
entity_id: entity.item.id.clone(),
|
||||
source_id: entity.item.source_id.clone(),
|
||||
inspected_candidates: 0,
|
||||
selected_chunk_ids: Vec::new(),
|
||||
selected_chunk_scores: Vec::new(),
|
||||
skipped_due_budget: 0,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||
rank_chunks_by_combined_score(candidates, &question_terms, tuning.lexical_match_weight);
|
||||
let mut per_entity_count = 0;
|
||||
candidates.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.inspected_candidates += 1;
|
||||
}
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
}
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
chunks_skipped_due_budget += 1;
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.skipped_due_budget += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
tokens_spent += estimated_tokens;
|
||||
per_entity_count += 1;
|
||||
chunks_selected += 1;
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
score: candidate.fused,
|
||||
});
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.selected_chunk_ids.push(candidate.item.id.clone());
|
||||
trace.selected_chunk_scores.push(candidate.fused);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,17 +588,48 @@ pub fn assemble(
|
||||
chunks: selected_chunks,
|
||||
});
|
||||
|
||||
if let Some(trace) = entity_trace {
|
||||
per_entity_traces.push(trace);
|
||||
}
|
||||
|
||||
if token_budget_remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if diagnostics_enabled {
|
||||
ctx.record_assemble(AssembleStats {
|
||||
token_budget_start: tuning.token_budget_estimate,
|
||||
token_budget_spent: tokens_spent,
|
||||
token_budget_remaining,
|
||||
budget_exhausted: token_budget_remaining == 0,
|
||||
chunks_selected,
|
||||
chunks_skipped_due_budget,
|
||||
entity_count: ctx.filtered_entities.len(),
|
||||
entity_traces: per_entity_traces,
|
||||
});
|
||||
}
|
||||
|
||||
machine
|
||||
.assemble()
|
||||
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
||||
ctx.record_assemble_timing(stage_start.elapsed());
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items
|
||||
.iter()
|
||||
.take(SCORE_SAMPLE_LIMIT)
|
||||
.map(|item| extractor(item))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
||||
AppError::InternalError(format!(
|
||||
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
||||
@@ -582,6 +753,7 @@ async fn enrich_chunks_from_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
query_embedding: &[f32],
|
||||
) -> Result<(), AppError> {
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
@@ -615,7 +787,16 @@ async fn enrich_chunks_from_entities(
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
|
||||
|
||||
entry.scores.vector = Some(
|
||||
entry
|
||||
.scores
|
||||
.vector
|
||||
.unwrap_or(0.0)
|
||||
.max(entity_score * 0.8)
|
||||
.max(similarity),
|
||||
);
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
@@ -624,6 +805,24 @@ async fn enrich_chunks_from_entities(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cosine similarity between `query` and `embedding`.
///
/// Returns `0.0` when either slice is empty, when the lengths differ, or when
/// either vector has a zero squared norm.
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    // Equal lengths plus a non-empty query imply a non-empty embedding.
    let comparable = !query.is_empty() && query.len() == embedding.len();
    if !comparable {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, query_sq, embedding_sq) = query.iter().zip(embedding).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, query_sq, embedding_sq), (q, e)| {
            (dot + q * e, query_sq + q * q, embedding_sq + e * e)
        },
    );

    if query_sq == 0.0 || embedding_sq == 0.0 {
        0.0
    } else {
        dot / (query_sq.sqrt() * embedding_sq.sqrt())
    }
}
|
||||
|
||||
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
|
||||
if ctx.filtered_entities.is_empty() {
|
||||
return Vec::new();
|
||||
@@ -736,6 +935,48 @@ fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
}
|
||||
|
||||
fn rank_chunks_by_combined_score(
|
||||
candidates: &mut [Scored<TextChunk>],
|
||||
question_terms: &[String],
|
||||
lexical_weight: f32,
|
||||
) {
|
||||
if lexical_weight > 0.0 && !question_terms.is_empty() {
|
||||
for candidate in candidates.iter_mut() {
|
||||
let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk);
|
||||
let combined = clamp_unit(candidate.fused + lexical_weight * lexical);
|
||||
candidate.update_fused(combined);
|
||||
}
|
||||
}
|
||||
candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||
}
|
||||
|
||||
/// Splits `text` on non-alphanumeric characters and returns the lowercased,
/// sorted, de-duplicated tokens whose original byte length is at least 3.
///
/// Case folding is Unicode-aware (`str::to_lowercase`) so that it matches the
/// Unicode-aware tokenization: "Émile" and "émile" yield the same term.
fn extract_keywords(text: &str) -> Vec<String> {
    let mut terms: Vec<String> = text
        .split(|c: char| !c.is_alphanumeric())
        // Tokens contain only alphanumeric characters, so no trimming is needed.
        // The threshold is applied to the raw token (byte length).
        .filter(|raw| raw.len() >= 3)
        .map(str::to_lowercase)
        .collect();
    terms.sort_unstable();
    terms.dedup();
    terms
}

/// Fraction of `terms` that occur as substrings of the lowercased `haystack`.
///
/// `terms` are expected to be lowercase already (as produced by
/// `extract_keywords`); the haystack is lowercased with the same Unicode-aware
/// folding so both sides compare consistently. Returns `0.0` for empty `terms`.
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
    if terms.is_empty() {
        return 0.0;
    }
    let lower = haystack.to_lowercase();
    let matches = terms
        .iter()
        .filter(|term| lower.contains(term.as_str()))
        .count();
    (matches as f32) / (terms.len() as f32)
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct GraphSeed {
|
||||
id: String,
|
||||
|
||||
@@ -68,6 +68,7 @@ where
|
||||
{
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, vector::distance::knn() AS distance \
|
||||
FROM {table} \
|
||||
|
||||
Reference in New Issue
Block a user