refactor: implemented state machines for retrieval pipeline, improved tracing

This commit is contained in:
Per Stark
2025-10-18 17:43:10 +02:00
parent 21e4ab1f42
commit 83d39afad4
15 changed files with 899 additions and 566 deletions

View File

@@ -17,9 +17,9 @@ use common::{
},
};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::Value;
use crate::{retrieve_entities, RetrievedEntity};
use crate::{retrieve_entities, retrieved_entities_to_json};
use super::answer_retrieval_helper::get_query_response_schema;
@@ -65,7 +65,7 @@ pub async fn get_answer_with_references(
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
let settings = SystemSettings::get_current(surreal_db_client).await?;
let entities_json = format_entities_json(&entities);
let entities_json = retrieved_entities_to_json(&entities);
let user_message = create_user_message(&entities_json, query);
let request = create_chat_request(user_message, &settings)?;
@@ -83,31 +83,6 @@ pub async fn get_answer_with_references(
})
}
pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value {
json!(entities
.iter()
.map(|entry| {
json!({
"KnowledgeEntity": {
"id": entry.entity.id,
"name": entry.entity.name,
"description": entry.entity.description,
"score": round_score(entry.score),
"chunks": entry.chunks.iter().map(|chunk| {
json!({
"score": round_score(chunk.score),
"content": chunk.chunk.chunk
})
}).collect::<Vec<_>>()
}
})
})
.collect::<Vec<_>>())
}
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r"

View File

@@ -2,44 +2,20 @@ pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod pipeline;
pub mod scoring;
pub mod vector;
use std::collections::{HashMap, HashSet};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
utils::embedding::generate_embedding,
};
use futures::{stream::FuturesUnordered, StreamExt};
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
use scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
};
use tracing::{debug, instrument, trace};
use tracing::instrument;
use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding};
// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping.
const ENTITY_VECTOR_TAKE: usize = 15;
const CHUNK_VECTOR_TAKE: usize = 20;
const ENTITY_FTS_TAKE: usize = 10;
const CHUNK_FTS_TAKE: usize = 20;
const SCORE_THRESHOLD: f32 = 0.35;
const FALLBACK_MIN_RESULTS: usize = 10;
const TOKEN_BUDGET_ESTIMATE: usize = 2800;
const AVG_CHARS_PER_TOKEN: usize = 4;
const MAX_CHUNKS_PER_ENTITY: usize = 4;
const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5;
const GRAPH_NEIGHBOR_LIMIT: usize = 6;
const GRAPH_SCORE_DECAY: f32 = 0.75;
const GRAPH_SEED_MIN_SCORE: f32 = 0.4;
const GRAPH_VECTOR_INHERITANCE: f32 = 0.6;
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
@@ -56,435 +32,34 @@ pub struct RetrievedEntity {
pub chunks: Vec<RetrievedChunk>,
}
// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
input_text: &str,
user_id: &str,
) -> Result<Vec<RetrievedEntity>, AppError> {
trace!("Generating query embedding for hybrid retrieval");
let query_embedding = generate_embedding(openai_client, query, db_client).await?;
retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await
}
pub(crate) async fn retrieve_entities_with_embedding(
db_client: &SurrealDbClient,
query_embedding: Vec<f32>,
query: &str,
user_id: &str,
) -> Result<Vec<RetrievedEntity>, AppError> {
// 1) Gather first-pass candidates from vector search and BM25.
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
ENTITY_VECTOR_TAKE,
query_embedding.clone(),
db_client,
"knowledge_entity",
user_id,
),
find_items_by_vector_similarity_with_embedding(
CHUNK_VECTOR_TAKE,
query_embedding,
db_client,
"text_chunk",
user_id,
),
find_items_by_fts(
ENTITY_FTS_TAKE,
query,
db_client,
"knowledge_entity",
user_id
),
find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id),
)?;
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts"
);
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
let mut entity_candidates: HashMap<String, Scored<KnowledgeEntity>> = HashMap::new();
let mut chunk_candidates: HashMap<String, Scored<TextChunk>> = HashMap::new();
// Collate raw retrieval results so each ID accumulates all available signals.
merge_scored_by_id(&mut entity_candidates, vector_entities);
merge_scored_by_id(&mut entity_candidates, fts_entities);
merge_scored_by_id(&mut chunk_candidates, vector_chunks);
merge_scored_by_id(&mut chunk_candidates, fts_chunks);
// 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph.
apply_fusion(&mut entity_candidates, weights);
apply_fusion(&mut chunk_candidates, weights);
enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?;
// 3) Track high-signal chunk sources so we can backfill missing entities.
let chunk_by_source = group_chunks_by_source(&chunk_candidates);
let mut missing_sources = Vec::new();
for source_id in chunk_by_source.keys() {
if !entity_candidates
.values()
.any(|entity| entity.item.source_id == *source_id)
{
missing_sources.push(source_id.clone());
}
}
if !missing_sources.is_empty() {
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
missing_sources.clone(),
"knowledge_entity",
user_id,
db_client,
)
.await
.unwrap_or_default();
for entity in related_entities {
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
entity_candidates.insert(entity.id.clone(), scored);
}
}
}
// Boost entities with evidence from high scoring chunks.
for entity in entity_candidates.values_mut() {
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
if best_chunk_score > 0.0 {
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
entity.scores.vector = Some(boosted);
let fused = fuse_scores(&entity.scores, weights);
entity.update_fused(fused);
}
}
}
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
entity_candidates.into_values().collect();
sort_by_fused_desc(&mut entity_results);
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
.iter()
.filter(|candidate| candidate.fused >= SCORE_THRESHOLD)
.cloned()
.collect();
if filtered_entities.len() < FALLBACK_MIN_RESULTS {
// Low recall scenarios still benefit from some context; take the top N regardless of score.
filtered_entities = entity_results
.into_iter()
.take(FALLBACK_MIN_RESULTS)
.collect();
}
// 4) Re-rank chunks and prepare for attachment to surviving entities.
let mut chunk_results: Vec<Scored<TextChunk>> = chunk_candidates.into_values().collect();
sort_by_fused_desc(&mut chunk_results);
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
for chunk in chunk_results {
chunk_by_id.insert(chunk.item.id.clone(), chunk);
}
enrich_chunks_from_entities(
&mut chunk_by_id,
&filtered_entities,
pipeline::run_pipeline(
db_client,
openai_client,
input_text,
user_id,
weights,
RetrievalConfig::default(),
)
.await?;
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
Ok(assemble_results(filtered_entities, chunk_values))
}
// Minimal record used while seeding graph expansion so we can retain the original fused score.
#[derive(Clone)]
struct GraphSeed {
id: String,
fused: f32,
}
async fn enrich_entities_from_graph(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
if entity_candidates.is_empty() {
return Ok(());
}
// Select a small frontier of high-confidence entities to seed the relationship walk.
let mut seeds: Vec<GraphSeed> = entity_candidates
.values()
.filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE)
.map(|entity| GraphSeed {
id: entity.item.id.clone(),
fused: entity.fused,
})
.collect();
if seeds.is_empty() {
return Ok(());
}
// Prioritise the strongest seeds so we explore the most grounded context first.
seeds.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT);
let mut futures = FuturesUnordered::new();
for seed in seeds.clone() {
let user_id = user_id.to_owned();
futures.push(async move {
// Fetch neighbors concurrently to avoid serial graph round trips.
let neighbors = find_entities_by_relationship_by_id(
db_client,
&seed.id,
&user_id,
GRAPH_NEIGHBOR_LIMIT,
)
.await;
(seed, neighbors)
});
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
// Fold neighbors back into the candidate map and let them inherit attenuated signal.
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY);
let entry = entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
Ok(())
}
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
// Scale BM25 outputs into [0,1] to keep fusion weights predictable.
let raw_scores: Vec<f32> = results
.iter()
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
.collect();
let normalized = min_max_normalize(&raw_scores);
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
candidate.scores.fts = Some(normalized_score);
candidate.update_fused(0.0);
}
}
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
where
T: StoredObject,
{
// Collapse individual signals into a single fused score used for ranking.
for candidate in candidates.values_mut() {
let fused = fuse_scores(&candidate.scores, weights);
candidate.update_fused(fused);
}
}
fn group_chunks_by_source(
chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
// Preserve chunk candidates keyed by their originating source entity.
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.values() {
by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk.clone());
}
by_source
}
async fn enrich_chunks_from_entities(
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
entities: &[Scored<KnowledgeEntity>],
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
// Fetch additional chunks referenced by entities that survived the fusion stage.
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
source_ids.insert(entity.item.source_id.clone());
}
if source_ids.is_empty() {
return Ok(());
}
let chunks = find_entities_by_source_ids::<TextChunk>(
source_ids.into_iter().collect(),
"text_chunk",
user_id,
db_client,
)
.await?;
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
// Cache fused scores per source so chunks inherit the strength of their parent entity.
for entity in entities {
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
}
for chunk in chunks {
// Ensure each chunk is represented so downstream selection sees the latest content.
let entry = chunk_candidates
.entry(chunk.id.clone())
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
let entity_score = entity_score_lookup
.get(&chunk.source_id)
.copied()
.unwrap_or(0.0);
// Lift chunk score toward the entity score so supporting evidence is prioritised.
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
}
Ok(())
}
fn assemble_results(
entities: Vec<Scored<KnowledgeEntity>>,
mut chunks: Vec<Scored<TextChunk>>,
) -> Vec<RetrievedEntity> {
// Re-associate chunk candidates with their parent entity for ranked selection.
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.drain(..) {
chunk_by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk);
}
for chunk_list in chunk_by_source.values_mut() {
sort_by_fused_desc(chunk_list);
}
let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE;
let mut results = Vec::new();
for entity in entities {
// Attach best chunks first while respecting per-entity and global token caps.
let mut selected_chunks = Vec::new();
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
let mut per_entity_count = 0;
candidates.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
for candidate in candidates.iter() {
if per_entity_count >= MAX_CHUNKS_PER_ENTITY {
break;
}
let estimated_tokens = estimate_tokens(&candidate.item.chunk);
if estimated_tokens > token_budget_remaining {
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
per_entity_count += 1;
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
score: candidate.fused,
});
}
}
results.push(RetrievedEntity {
entity: entity.item.clone(),
score: entity.fused,
chunks: selected_chunks,
});
if token_budget_remaining == 0 {
break;
}
}
results
}
fn estimate_tokens(text: &str) -> usize {
// Simple heuristic to avoid calling a tokenizer in hot code paths.
let chars = text.chars().count().max(1);
(chars / AVG_CHARS_PER_TOKEN).max(1)
.await
}
#[cfg(test)]
mod tests {
use super::*;
use async_openai::Client;
use common::storage::types::{
knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship,
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
};
use pipeline::RetrievalConfig;
use uuid::Uuid;
fn test_embedding() -> Vec<f32> {
@@ -527,70 +102,46 @@ mod tests {
COMMIT TRANSACTION;",
)
.await
.expect("Failed to redefine vector indexes for tests");
.expect("Failed to configure indices");
db
}
async fn seed_test_data(db: &SurrealDbClient, user_id: &str) {
let entity_relevant = KnowledgeEntity::new(
"source_a".into(),
"Rust Concurrency Patterns".into(),
"Discussion about async concurrency in Rust.".into(),
#[tokio::test]
async fn test_retrieve_entities_with_embedding_basic_flow() {
let db = setup_test_db().await;
let user_id = "test_user";
let entity = KnowledgeEntity::new(
"source_1".into(),
"Rust async guide".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let entity_irrelevant = KnowledgeEntity::new(
"source_b".into(),
"Python Tips".into(),
"General Python programming tips.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_low(),
user_id.into(),
);
db.store_item(entity_relevant.clone())
.await
.expect("Failed to store relevant entity");
db.store_item(entity_irrelevant.clone())
.await
.expect("Failed to store irrelevant entity");
let chunk_primary = TextChunk::new(
entity_relevant.source_id.clone(),
"Tokio enables async concurrency with lightweight tasks.".into(),
let chunk = TextChunk::new(
entity.source_id.clone(),
"Tokio uses cooperative scheduling for fairness.".into(),
chunk_embedding_primary(),
user_id.into(),
);
let chunk_secondary = TextChunk::new(
entity_irrelevant.source_id.clone(),
"Python focuses on readability and dynamic typing.".into(),
chunk_embedding_secondary(),
user_id.into(),
);
db.store_item(chunk_primary)
db.store_item(entity.clone())
.await
.expect("Failed to store primary chunk");
db.store_item(chunk_secondary)
.expect("Failed to store entity");
db.store_item(chunk.clone())
.await
.expect("Failed to store secondary chunk");
}
.expect("Failed to store chunk");
#[tokio::test]
async fn test_hybrid_retrieval_prioritises_relevant_entity() {
let db = setup_test_db().await;
let user_id = "user123";
seed_test_data(&db, user_id).await;
let results = retrieve_entities_with_embedding(
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
test_embedding(),
"Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
)
.await
.expect("Hybrid retrieval failed");
@@ -599,9 +150,7 @@ mod tests {
!results.is_empty(),
"Expected at least one retrieval result"
);
let top = &results[0];
assert!(
top.entity.name.contains("Rust"),
"Expected Rust entity to be ranked first"
@@ -610,16 +159,6 @@ mod tests {
!top.chunks.is_empty(),
"Expected Rust entity to include supporting chunks"
);
let chunk_texts: Vec<&str> = top
.chunks
.iter()
.map(|chunk| chunk.chunk.chunk.as_str())
.collect();
assert!(
chunk_texts.iter().any(|text| text.contains("Tokio")),
"Expected chunk discussing Tokio to be included"
);
}
#[tokio::test]
@@ -673,6 +212,7 @@ mod tests {
.await
.expect("Failed to store neighbor chunk");
let openai_client = Client::new();
let relationship = KnowledgeRelationship::new(
primary.id.clone(),
neighbor.id.clone(),
@@ -685,11 +225,13 @@ mod tests {
.await
.expect("Failed to store relationship");
let results = retrieve_entities_with_embedding(
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
test_embedding(),
"Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
)
.await
.expect("Hybrid retrieval failed");

View File

@@ -0,0 +1,61 @@
use serde::{Deserialize, Serialize};
/// Tunable parameters that govern each retrieval stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
pub entity_vector_take: usize,
pub chunk_vector_take: usize,
pub entity_fts_take: usize,
pub chunk_fts_take: usize,
pub score_threshold: f32,
pub fallback_min_results: usize,
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
}
impl Default for RetrievalTuning {
fn default() -> Self {
Self {
entity_vector_take: 15,
chunk_vector_take: 20,
entity_fts_take: 10,
chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 2800,
avg_chars_per_token: 4,
max_chunks_per_entity: 4,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
}
}
}
/// Wrapper containing tuning plus future flags for per-request overrides.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
pub tuning: RetrievalTuning,
}
impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self {
Self { tuning }
}
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
tuning: RetrievalTuning::default(),
}
}
}

View File

@@ -0,0 +1,100 @@
mod config;
mod stages;
mod state;
pub use config::{RetrievalConfig, RetrievalTuning};
use crate::RetrievedEntity;
use async_openai::Client;
use common::{error::AppError, storage::db::SurrealDbClient};
use tracing::info;
/// Drives the retrieval pipeline from embedding through final assembly.
pub async fn run_pipeline(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
let preview_len = input_preview_clean.chars().count();
info!(
%user_id,
input_chars,
preview_truncated = input_chars > preview_len,
preview = %input_preview_clean,
"Starting ingestion retrieval pipeline"
);
let mut ctx = stages::PipelineContext::new(
db_client,
openai_client,
input_text.to_owned(),
user_id.to_owned(),
config,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
Ok(results)
}
#[cfg(test)]
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let mut ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
Ok(results)
}
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
serde_json::json!(entities
.iter()
.map(|entry| {
serde_json::json!({
"KnowledgeEntity": {
"id": entry.entity.id,
"name": entry.entity.name,
"description": entry.entity.description,
"score": round_score(entry.score),
"chunks": entry.chunks.iter().map(|chunk| {
serde_json::json!({
"score": round_score(chunk.score),
"content": chunk.chunk.chunk
})
}).collect::<Vec<_>>()
}
})
})
.collect::<Vec<_>>())
}
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}

View File

@@ -0,0 +1,599 @@
use async_openai::Client;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
utils::embedding::generate_embedding,
};
use futures::{stream::FuturesUnordered, StreamExt};
use state_machines::core::GuardError;
use std::collections::{HashMap, HashSet};
use tracing::{debug, instrument, warn};
use crate::{
fts::find_items_by_fts,
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
},
vector::find_items_by_vector_similarity_with_embedding,
RetrievedChunk, RetrievedEntity,
};
use super::{
config::RetrievalConfig,
state::{
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
},
};
pub struct PipelineContext<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub input_text: String,
pub user_id: String,
pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>,
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
}
impl<'a> PipelineContext<'a> {
pub fn new(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
input_text: String,
user_id: String,
config: RetrievalConfig,
) -> Self {
Self {
db_client,
openai_client,
input_text,
user_id,
config,
query_embedding: None,
entity_candidates: HashMap::new(),
chunk_candidates: HashMap::new(),
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
}
}
#[cfg(test)]
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: String,
user_id: String,
config: RetrievalConfig,
) -> Self {
let mut ctx = Self::new(db_client, openai_client, input_text, user_id, config);
ctx.query_embedding = Some(query_embedding);
ctx
}
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
self.query_embedding.as_ref().ok_or_else(|| {
AppError::InternalError(
"query embedding missing before candidate collection".to_string(),
)
})
}
}
#[instrument(level = "trace", skip_all)]
pub async fn embed(
machine: HybridRetrievalMachine<(), Ready>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
let embedding_cached = ctx.query_embedding.is_some();
if embedding_cached {
debug!("Reusing cached query embedding for hybrid retrieval");
} else {
debug!("Generating query embedding for hybrid retrieval");
let embedding =
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
ctx.query_embedding = Some(embedding);
}
machine
.embed()
.map_err(|(_, guard)| map_guard_error("embed", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn collect_candidates(
machine: HybridRetrievalMachine<(), Embedded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
tuning.entity_vector_take,
embedding.clone(),
ctx.db_client,
"knowledge_entity",
&ctx.user_id,
),
find_items_by_vector_similarity_with_embedding(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
"text_chunk",
&ctx.user_id,
),
find_items_by_fts(
tuning.entity_fts_take,
&ctx.input_text,
ctx.db_client,
"knowledge_entity",
&ctx.user_id,
),
find_items_by_fts(
tuning.chunk_fts_take,
&ctx.input_text,
ctx.db_client,
"text_chunk",
&ctx.user_id
),
)?;
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts"
);
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
machine
.collect_candidates()
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn expand_graph(
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
debug!("Expanding candidates using graph relationships");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
if ctx.entity_candidates.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
if graph_seeds.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
futures.push(async move {
let neighbors = find_entities_by_relationship_by_id(
db,
&seed.id,
&user,
tuning.graph_neighbor_limit,
)
.await;
(seed, neighbors)
});
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn attach_chunks(
machine: HybridRetrievalMachine<(), GraphExpanded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
debug!("Attaching chunks to surviving entities");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
backfill_entities_from_chunks(
&mut ctx.entity_candidates,
&chunk_by_source,
ctx.db_client,
&ctx.user_id,
weights,
)
.await?;
boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
ctx.entity_candidates.values().cloned().collect();
sort_by_fused_desc(&mut entity_results);
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
.iter()
.filter(|candidate| candidate.fused >= tuning.score_threshold)
.cloned()
.collect();
if filtered_entities.len() < tuning.fallback_min_results {
filtered_entities = entity_results
.into_iter()
.take(tuning.fallback_min_results)
.collect();
}
ctx.filtered_entities = filtered_entities;
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
for chunk in chunk_results {
chunk_by_id.insert(chunk.item.id.clone(), chunk);
}
enrich_chunks_from_entities(
&mut chunk_by_id,
&ctx.filtered_entities,
ctx.db_client,
&ctx.user_id,
weights,
)
.await?;
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
ctx.chunk_values = chunk_values;
machine
.attach_chunks()
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
}
#[instrument(level = "trace", skip_all)]
pub fn assemble(
machine: HybridRetrievalMachine<(), ChunksAttached>,
ctx: &mut PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in ctx.chunk_values.drain(..) {
chunk_by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk);
}
for chunk_list in chunk_by_source.values_mut() {
sort_by_fused_desc(chunk_list);
}
let mut token_budget_remaining = tuning.token_budget_estimate;
let mut results = Vec::new();
for entity in &ctx.filtered_entities {
let mut selected_chunks = Vec::new();
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
let mut per_entity_count = 0;
candidates.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
for candidate in candidates.iter() {
if per_entity_count >= tuning.max_chunks_per_entity {
break;
}
let estimated_tokens =
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
if estimated_tokens > token_budget_remaining {
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
per_entity_count += 1;
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
score: candidate.fused,
});
}
}
results.push(RetrievedEntity {
entity: entity.item.clone(),
score: entity.fused,
chunks: selected_chunks,
});
if token_budget_remaining == 0 {
break;
}
}
machine
.assemble()
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
Ok(results)
}
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
AppError::InternalError(format!(
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
err.guard, err.event, err.kind
))
}
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
let raw_scores: Vec<f32> = results
.iter()
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
.collect();
let normalized = min_max_normalize(&raw_scores);
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
candidate.scores.fts = Some(normalized_score);
candidate.update_fused(0.0);
}
}
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
where
T: StoredObject,
{
for candidate in candidates.values_mut() {
let fused = fuse_scores(&candidate.scores, weights);
candidate.update_fused(fused);
}
}
fn group_chunks_by_source(
chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.values() {
by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk.clone());
}
by_source
}
async fn backfill_entities_from_chunks(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
let mut missing_sources = Vec::new();
for source_id in chunk_by_source.keys() {
if !entity_candidates
.values()
.any(|entity| entity.item.source_id == *source_id)
{
missing_sources.push(source_id.clone());
}
}
if missing_sources.is_empty() {
return Ok(());
}
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
missing_sources.clone(),
"knowledge_entity",
user_id,
db_client,
)
.await
.unwrap_or_default();
if related_entities.is_empty() {
warn!("expected related entities for missing chunk sources, but none were found");
}
for entity in related_entities {
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
entity_candidates.insert(entity.id.clone(), scored);
}
}
Ok(())
}
fn boost_entities_with_chunks(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
weights: FusionWeights,
) {
for entity in entity_candidates.values_mut() {
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
if best_chunk_score > 0.0 {
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
entity.scores.vector = Some(boosted);
let fused = fuse_scores(&entity.scores, weights);
entity.update_fused(fused);
}
}
}
}
async fn enrich_chunks_from_entities(
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
entities: &[Scored<KnowledgeEntity>],
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
source_ids.insert(entity.item.source_id.clone());
}
if source_ids.is_empty() {
return Ok(());
}
let chunks = find_entities_by_source_ids::<TextChunk>(
source_ids.into_iter().collect(),
"text_chunk",
user_id,
db_client,
)
.await?;
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
for entity in entities {
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
}
for chunk in chunks {
let entry = chunk_candidates
.entry(chunk.id.clone())
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
let entity_score = entity_score_lookup
.get(&chunk.source_id)
.copied()
.unwrap_or(0.0);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
}
Ok(())
}
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
let chars = text.chars().count().max(1);
(chars / avg_chars_per_token).max(1)
}
#[derive(Clone)]
struct GraphSeed {
id: String,
fused: f32,
}
fn seeds_from_candidates(
entity_candidates: &HashMap<String, Scored<KnowledgeEntity>>,
min_score: f32,
limit: usize,
) -> Vec<GraphSeed> {
let mut seeds: Vec<GraphSeed> = entity_candidates
.values()
.filter(|entity| entity.fused >= min_score)
.map(|entity| GraphSeed {
id: entity.item.id.clone(),
fused: entity.fused,
})
.collect();
seeds.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
if seeds.len() > limit {
seeds.truncate(limit);
}
seeds
}

View File

@@ -0,0 +1,25 @@
use state_machines::state_machine;
state_machine! {
name: HybridRetrievalMachine,
state: HybridRetrievalState,
initial: Ready,
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Completed, Failed],
events {
embed { transition: { from: Ready, to: Embedded } }
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
assemble { transition: { from: ChunksAttached, to: Completed } }
abort {
transition: { from: Ready, to: Failed }
transition: { from: CandidatesLoaded, to: Failed }
transition: { from: GraphExpanded, to: Failed }
transition: { from: ChunksAttached, to: Failed }
}
}
}
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
HybridRetrievalMachine::new(())
}