mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-11 03:36:53 +02:00
refactor: implemented state machines for retrieval pipeline, improved tracing
This commit is contained in:
@@ -17,9 +17,9 @@ use common::{
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{retrieve_entities, RetrievedEntity};
|
||||
use crate::{retrieve_entities, retrieved_entities_to_json};
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
|
||||
@@ -65,7 +65,7 @@ pub async fn get_answer_with_references(
|
||||
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
|
||||
let settings = SystemSettings::get_current(surreal_db_client).await?;
|
||||
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let entities_json = retrieved_entities_to_json(&entities);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message, &settings)?;
|
||||
@@ -83,31 +83,6 @@ pub async fn get_answer_with_references(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value {
|
||||
json!(entities
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entry.entity.id,
|
||||
"name": entry.entity.name,
|
||||
"description": entry.entity.description,
|
||||
"score": round_score(entry.score),
|
||||
"chunks": entry.chunks.iter().map(|chunk| {
|
||||
json!({
|
||||
"score": round_score(chunk.score),
|
||||
"content": chunk.chunk.chunk
|
||||
})
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r"
|
||||
|
||||
@@ -2,44 +2,20 @@ pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
pub mod fts;
|
||||
pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod scoring;
|
||||
pub mod vector;
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
|
||||
use scoring::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
};
|
||||
use tracing::{debug, instrument, trace};
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding};
|
||||
|
||||
// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping.
|
||||
const ENTITY_VECTOR_TAKE: usize = 15;
|
||||
const CHUNK_VECTOR_TAKE: usize = 20;
|
||||
const ENTITY_FTS_TAKE: usize = 10;
|
||||
const CHUNK_FTS_TAKE: usize = 20;
|
||||
const SCORE_THRESHOLD: f32 = 0.35;
|
||||
const FALLBACK_MIN_RESULTS: usize = 10;
|
||||
const TOKEN_BUDGET_ESTIMATE: usize = 2800;
|
||||
const AVG_CHARS_PER_TOKEN: usize = 4;
|
||||
const MAX_CHUNKS_PER_ENTITY: usize = 4;
|
||||
const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5;
|
||||
const GRAPH_NEIGHBOR_LIMIT: usize = 6;
|
||||
const GRAPH_SCORE_DECAY: f32 = 0.75;
|
||||
const GRAPH_SEED_MIN_SCORE: f32 = 0.4;
|
||||
const GRAPH_VECTOR_INHERITANCE: f32 = 0.6;
|
||||
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -56,435 +32,34 @@ pub struct RetrievedEntity {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
trace!("Generating query embedding for hybrid retrieval");
|
||||
let query_embedding = generate_embedding(openai_client, query, db_client).await?;
|
||||
retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await
|
||||
}
|
||||
|
||||
pub(crate) async fn retrieve_entities_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
query_embedding: Vec<f32>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
// 1) Gather first-pass candidates from vector search and BM25.
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
ENTITY_VECTOR_TAKE,
|
||||
query_embedding.clone(),
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
CHUNK_VECTOR_TAKE,
|
||||
query_embedding,
|
||||
db_client,
|
||||
"text_chunk",
|
||||
user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
ENTITY_FTS_TAKE,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
user_id
|
||||
),
|
||||
find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id),
|
||||
)?;
|
||||
|
||||
debug!(
|
||||
vector_entities = vector_entities.len(),
|
||||
vector_chunks = vector_chunks.len(),
|
||||
fts_entities = fts_entities.len(),
|
||||
fts_chunks = fts_chunks.len(),
|
||||
"Hybrid retrieval initial candidate counts"
|
||||
);
|
||||
|
||||
normalize_fts_scores(&mut fts_entities);
|
||||
normalize_fts_scores(&mut fts_chunks);
|
||||
|
||||
let mut entity_candidates: HashMap<String, Scored<KnowledgeEntity>> = HashMap::new();
|
||||
let mut chunk_candidates: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||
|
||||
// Collate raw retrieval results so each ID accumulates all available signals.
|
||||
merge_scored_by_id(&mut entity_candidates, vector_entities);
|
||||
merge_scored_by_id(&mut entity_candidates, fts_entities);
|
||||
merge_scored_by_id(&mut chunk_candidates, vector_chunks);
|
||||
merge_scored_by_id(&mut chunk_candidates, fts_chunks);
|
||||
|
||||
// 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph.
|
||||
apply_fusion(&mut entity_candidates, weights);
|
||||
apply_fusion(&mut chunk_candidates, weights);
|
||||
enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?;
|
||||
|
||||
// 3) Track high-signal chunk sources so we can backfill missing entities.
|
||||
let chunk_by_source = group_chunks_by_source(&chunk_candidates);
|
||||
let mut missing_sources = Vec::new();
|
||||
|
||||
for source_id in chunk_by_source.keys() {
|
||||
if !entity_candidates
|
||||
.values()
|
||||
.any(|entity| entity.item.source_id == *source_id)
|
||||
{
|
||||
missing_sources.push(source_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if !missing_sources.is_empty() {
|
||||
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
missing_sources.clone(),
|
||||
"knowledge_entity",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
for entity in related_entities {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
|
||||
let fused = fuse_scores(&scored.scores, weights);
|
||||
scored.update_fused(fused);
|
||||
entity_candidates.insert(entity.id.clone(), scored);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Boost entities with evidence from high scoring chunks.
|
||||
for entity in entity_candidates.values_mut() {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
if best_chunk_score > 0.0 {
|
||||
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
|
||||
entity.scores.vector = Some(boosted);
|
||||
let fused = fuse_scores(&entity.scores, weights);
|
||||
entity.update_fused(fused);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
|
||||
entity_candidates.into_values().collect();
|
||||
sort_by_fused_desc(&mut entity_results);
|
||||
|
||||
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
|
||||
.iter()
|
||||
.filter(|candidate| candidate.fused >= SCORE_THRESHOLD)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if filtered_entities.len() < FALLBACK_MIN_RESULTS {
|
||||
// Low recall scenarios still benefit from some context; take the top N regardless of score.
|
||||
filtered_entities = entity_results
|
||||
.into_iter()
|
||||
.take(FALLBACK_MIN_RESULTS)
|
||||
.collect();
|
||||
}
|
||||
|
||||
// 4) Re-rank chunks and prepare for attachment to surviving entities.
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> = chunk_candidates.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
|
||||
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||
for chunk in chunk_results {
|
||||
chunk_by_id.insert(chunk.item.id.clone(), chunk);
|
||||
}
|
||||
|
||||
enrich_chunks_from_entities(
|
||||
&mut chunk_by_id,
|
||||
&filtered_entities,
|
||||
pipeline::run_pipeline(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
weights,
|
||||
RetrievalConfig::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_values);
|
||||
|
||||
Ok(assemble_results(filtered_entities, chunk_values))
|
||||
}
|
||||
|
||||
// Minimal record used while seeding graph expansion so we can retain the original fused score.
|
||||
#[derive(Clone)]
|
||||
struct GraphSeed {
|
||||
id: String,
|
||||
fused: f32,
|
||||
}
|
||||
|
||||
async fn enrich_entities_from_graph(
|
||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
if entity_candidates.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Select a small frontier of high-confidence entities to seed the relationship walk.
|
||||
let mut seeds: Vec<GraphSeed> = entity_candidates
|
||||
.values()
|
||||
.filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE)
|
||||
.map(|entity| GraphSeed {
|
||||
id: entity.item.id.clone(),
|
||||
fused: entity.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if seeds.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Prioritise the strongest seeds so we explore the most grounded context first.
|
||||
seeds.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT);
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in seeds.clone() {
|
||||
let user_id = user_id.to_owned();
|
||||
futures.push(async move {
|
||||
// Fetch neighbors concurrently to avoid serial graph round trips.
|
||||
let neighbors = find_entities_by_relationship_by_id(
|
||||
db_client,
|
||||
&seed.id,
|
||||
&user_id,
|
||||
GRAPH_NEIGHBOR_LIMIT,
|
||||
)
|
||||
.await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fold neighbors back into the candidate map and let them inherit attenuated signal.
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY);
|
||||
let entry = entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||
// Scale BM25 outputs into [0,1] to keep fusion weights predictable.
|
||||
let raw_scores: Vec<f32> = results
|
||||
.iter()
|
||||
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
let normalized = min_max_normalize(&raw_scores);
|
||||
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
|
||||
candidate.scores.fts = Some(normalized_score);
|
||||
candidate.update_fused(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
// Collapse individual signals into a single fused score used for ranking.
|
||||
for candidate in candidates.values_mut() {
|
||||
let fused = fuse_scores(&candidate.scores, weights);
|
||||
candidate.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
fn group_chunks_by_source(
|
||||
chunks: &HashMap<String, Scored<TextChunk>>,
|
||||
) -> HashMap<String, Vec<Scored<TextChunk>>> {
|
||||
// Preserve chunk candidates keyed by their originating source entity.
|
||||
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
|
||||
for chunk in chunks.values() {
|
||||
by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk.clone());
|
||||
}
|
||||
by_source
|
||||
}
|
||||
|
||||
async fn enrich_chunks_from_entities(
|
||||
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
|
||||
entities: &[Scored<KnowledgeEntity>],
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
// Fetch additional chunks referenced by entities that survived the fusion stage.
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
source_ids.insert(entity.item.source_id.clone());
|
||||
}
|
||||
|
||||
if source_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunks = find_entities_by_source_ids::<TextChunk>(
|
||||
source_ids.into_iter().collect(),
|
||||
"text_chunk",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
|
||||
// Cache fused scores per source so chunks inherit the strength of their parent entity.
|
||||
for entity in entities {
|
||||
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
|
||||
}
|
||||
|
||||
for chunk in chunks {
|
||||
// Ensure each chunk is represented so downstream selection sees the latest content.
|
||||
let entry = chunk_candidates
|
||||
.entry(chunk.id.clone())
|
||||
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
|
||||
|
||||
let entity_score = entity_score_lookup
|
||||
.get(&chunk.source_id)
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Lift chunk score toward the entity score so supporting evidence is prioritised.
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assemble_results(
|
||||
entities: Vec<Scored<KnowledgeEntity>>,
|
||||
mut chunks: Vec<Scored<TextChunk>>,
|
||||
) -> Vec<RetrievedEntity> {
|
||||
// Re-associate chunk candidates with their parent entity for ranked selection.
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in chunks.drain(..) {
|
||||
chunk_by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk);
|
||||
}
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
sort_by_fused_desc(chunk_list);
|
||||
}
|
||||
|
||||
let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for entity in entities {
|
||||
// Attach best chunks first while respecting per-entity and global token caps.
|
||||
let mut selected_chunks = Vec::new();
|
||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||
let mut per_entity_count = 0;
|
||||
candidates.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if per_entity_count >= MAX_CHUNKS_PER_ENTITY {
|
||||
break;
|
||||
}
|
||||
let estimated_tokens = estimate_tokens(&candidate.item.chunk);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
per_entity_count += 1;
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
score: candidate.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results.push(RetrievedEntity {
|
||||
entity: entity.item.clone(),
|
||||
score: entity.fused,
|
||||
chunks: selected_chunks,
|
||||
});
|
||||
|
||||
if token_budget_remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
fn estimate_tokens(text: &str) -> usize {
|
||||
// Simple heuristic to avoid calling a tokenizer in hot code paths.
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / AVG_CHARS_PER_TOKEN).max(1)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_openai::Client;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship,
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
};
|
||||
use pipeline::RetrievalConfig;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
@@ -527,70 +102,46 @@ mod tests {
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.expect("Failed to redefine vector indexes for tests");
|
||||
.expect("Failed to configure indices");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
async fn seed_test_data(db: &SurrealDbClient, user_id: &str) {
|
||||
let entity_relevant = KnowledgeEntity::new(
|
||||
"source_a".into(),
|
||||
"Rust Concurrency Patterns".into(),
|
||||
"Discussion about async concurrency in Rust.".into(),
|
||||
#[tokio::test]
|
||||
async fn test_retrieve_entities_with_embedding_basic_flow() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "test_user";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_1".into(),
|
||||
"Rust async guide".into(),
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_high(),
|
||||
user_id.into(),
|
||||
);
|
||||
let entity_irrelevant = KnowledgeEntity::new(
|
||||
"source_b".into(),
|
||||
"Python Tips".into(),
|
||||
"General Python programming tips.".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_low(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity_relevant.clone())
|
||||
.await
|
||||
.expect("Failed to store relevant entity");
|
||||
db.store_item(entity_irrelevant.clone())
|
||||
.await
|
||||
.expect("Failed to store irrelevant entity");
|
||||
|
||||
let chunk_primary = TextChunk::new(
|
||||
entity_relevant.source_id.clone(),
|
||||
"Tokio enables async concurrency with lightweight tasks.".into(),
|
||||
let chunk = TextChunk::new(
|
||||
entity.source_id.clone(),
|
||||
"Tokio uses cooperative scheduling for fairness.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk_secondary = TextChunk::new(
|
||||
entity_irrelevant.source_id.clone(),
|
||||
"Python focuses on readability and dynamic typing.".into(),
|
||||
chunk_embedding_secondary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(chunk_primary)
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("Failed to store primary chunk");
|
||||
db.store_item(chunk_secondary)
|
||||
.expect("Failed to store entity");
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store secondary chunk");
|
||||
}
|
||||
.expect("Failed to store chunk");
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_retrieval_prioritises_relevant_entity() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user123";
|
||||
seed_test_data(&db, user_id).await;
|
||||
|
||||
let results = retrieve_entities_with_embedding(
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
@@ -599,9 +150,7 @@ mod tests {
|
||||
!results.is_empty(),
|
||||
"Expected at least one retrieval result"
|
||||
);
|
||||
|
||||
let top = &results[0];
|
||||
|
||||
assert!(
|
||||
top.entity.name.contains("Rust"),
|
||||
"Expected Rust entity to be ranked first"
|
||||
@@ -610,16 +159,6 @@ mod tests {
|
||||
!top.chunks.is_empty(),
|
||||
"Expected Rust entity to include supporting chunks"
|
||||
);
|
||||
|
||||
let chunk_texts: Vec<&str> = top
|
||||
.chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.chunk.chunk.as_str())
|
||||
.collect();
|
||||
assert!(
|
||||
chunk_texts.iter().any(|text| text.contains("Tokio")),
|
||||
"Expected chunk discussing Tokio to be included"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -673,6 +212,7 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to store neighbor chunk");
|
||||
|
||||
let openai_client = Client::new();
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
primary.id.clone(),
|
||||
neighbor.id.clone(),
|
||||
@@ -685,11 +225,13 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
let results = retrieve_entities_with_embedding(
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
61
composite-retrieval/src/pipeline/config.rs
Normal file
61
composite-retrieval/src/pipeline/config.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
pub entity_vector_take: usize,
|
||||
pub chunk_vector_take: usize,
|
||||
pub entity_fts_take: usize,
|
||||
pub chunk_fts_take: usize,
|
||||
pub score_threshold: f32,
|
||||
pub fallback_min_results: usize,
|
||||
pub token_budget_estimate: usize,
|
||||
pub avg_chars_per_token: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub graph_traversal_seed_limit: usize,
|
||||
pub graph_neighbor_limit: usize,
|
||||
pub graph_score_decay: f32,
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
entity_vector_take: 15,
|
||||
chunk_vector_take: 20,
|
||||
entity_fts_take: 10,
|
||||
chunk_fts_take: 20,
|
||||
score_threshold: 0.35,
|
||||
fallback_min_results: 10,
|
||||
token_budget_estimate: 2800,
|
||||
avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
graph_traversal_seed_limit: 5,
|
||||
graph_neighbor_limit: 6,
|
||||
graph_score_decay: 0.75,
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalConfig {
|
||||
pub tuning: RetrievalTuning,
|
||||
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||
Self { tuning }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tuning: RetrievalTuning::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
100
composite-retrieval/src/pipeline/mod.rs
Normal file
100
composite-retrieval/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
mod config;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
|
||||
use crate::RetrievedEntity;
|
||||
use async_openai::Client;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use tracing::info;
|
||||
|
||||
/// Drives the retrieval pipeline from embedding through final assembly.
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
let mut ctx = stages::PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let mut ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
serde_json::json!(entities
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
serde_json::json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entry.entity.id,
|
||||
"name": entry.entity.name,
|
||||
"description": entry.entity.description,
|
||||
"score": round_score(entry.score),
|
||||
"chunks": entry.chunks.iter().map(|chunk| {
|
||||
serde_json::json!({
|
||||
"score": round_score(chunk.score),
|
||||
"content": chunk.chunk.chunk
|
||||
})
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
599
composite-retrieval/src/pipeline/stages/mod.rs
Normal file
599
composite-retrieval/src/pipeline/stages/mod.rs
Normal file
@@ -0,0 +1,599 @@
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use state_machines::core::GuardError;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
fts::find_items_by_fts,
|
||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||
scoring::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
},
|
||||
vector::find_items_by_vector_similarity_with_embedding,
|
||||
RetrievedChunk, RetrievedEntity,
|
||||
};
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
state::{
|
||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
pub query_embedding: Option<Vec<f32>>,
|
||||
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
|
||||
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
query_embedding: None,
|
||||
entity_candidates: HashMap::new(),
|
||||
chunk_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
) -> Self {
|
||||
let mut ctx = Self::new(db_client, openai_client, input_text, user_id, config);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
AppError::InternalError(
|
||||
"query embedding missing before candidate collection".to_string(),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn embed(
|
||||
machine: HybridRetrievalMachine<(), Ready>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
|
||||
let embedding_cached = ctx.query_embedding.is_some();
|
||||
if embedding_cached {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding =
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
machine
|
||||
.embed()
|
||||
.map_err(|(_, guard)| map_guard_error("embed", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_candidates(
|
||||
machine: HybridRetrievalMachine<(), Embedded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
tuning.entity_vector_take,
|
||||
embedding.clone(),
|
||||
ctx.db_client,
|
||||
"knowledge_entity",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
"text_chunk",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
tuning.entity_fts_take,
|
||||
&ctx.input_text,
|
||||
ctx.db_client,
|
||||
"knowledge_entity",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
tuning.chunk_fts_take,
|
||||
&ctx.input_text,
|
||||
ctx.db_client,
|
||||
"text_chunk",
|
||||
&ctx.user_id
|
||||
),
|
||||
)?;
|
||||
|
||||
debug!(
|
||||
vector_entities = vector_entities.len(),
|
||||
vector_chunks = vector_chunks.len(),
|
||||
fts_entities = fts_entities.len(),
|
||||
fts_chunks = fts_chunks.len(),
|
||||
"Hybrid retrieval initial candidate counts"
|
||||
);
|
||||
|
||||
normalize_fts_scores(&mut fts_entities);
|
||||
normalize_fts_scores(&mut fts_chunks);
|
||||
|
||||
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
|
||||
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
|
||||
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
|
||||
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
|
||||
|
||||
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||
|
||||
machine
|
||||
.collect_candidates()
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn expand_graph(
|
||||
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
||||
debug!("Expanding candidates using graph relationships");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
|
||||
if graph_seeds.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
futures.push(async move {
|
||||
let neighbors = find_entities_by_relationship_by_id(
|
||||
db,
|
||||
&seed.id,
|
||||
&user,
|
||||
tuning.graph_neighbor_limit,
|
||||
)
|
||||
.await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn attach_chunks(
|
||||
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
||||
debug!("Attaching chunks to surviving entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
|
||||
|
||||
backfill_entities_from_chunks(
|
||||
&mut ctx.entity_candidates,
|
||||
&chunk_by_source,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
)
|
||||
.await?;
|
||||
|
||||
boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);
|
||||
|
||||
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
|
||||
ctx.entity_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut entity_results);
|
||||
|
||||
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
|
||||
.iter()
|
||||
.filter(|candidate| candidate.fused >= tuning.score_threshold)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if filtered_entities.len() < tuning.fallback_min_results {
|
||||
filtered_entities = entity_results
|
||||
.into_iter()
|
||||
.take(tuning.fallback_min_results)
|
||||
.collect();
|
||||
}
|
||||
|
||||
ctx.filtered_entities = filtered_entities;
|
||||
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> =
|
||||
ctx.chunk_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
|
||||
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||
for chunk in chunk_results {
|
||||
chunk_by_id.insert(chunk.item.id.clone(), chunk);
|
||||
}
|
||||
|
||||
enrich_chunks_from_entities(
|
||||
&mut chunk_by_id,
|
||||
&ctx.filtered_entities,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_values);
|
||||
|
||||
ctx.chunk_values = chunk_values;
|
||||
|
||||
machine
|
||||
.attach_chunks()
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble(
|
||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in ctx.chunk_values.drain(..) {
|
||||
chunk_by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk);
|
||||
}
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
sort_by_fused_desc(chunk_list);
|
||||
}
|
||||
|
||||
let mut token_budget_remaining = tuning.token_budget_estimate;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for entity in &ctx.filtered_entities {
|
||||
let mut selected_chunks = Vec::new();
|
||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||
let mut per_entity_count = 0;
|
||||
candidates.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
}
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
per_entity_count += 1;
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
score: candidate.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results.push(RetrievedEntity {
|
||||
entity: entity.item.clone(),
|
||||
score: entity.fused,
|
||||
chunks: selected_chunks,
|
||||
});
|
||||
|
||||
if token_budget_remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.assemble()
|
||||
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
||||
AppError::InternalError(format!(
|
||||
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
||||
err.guard, err.event, err.kind
|
||||
))
|
||||
}
|
||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||
let raw_scores: Vec<f32> = results
|
||||
.iter()
|
||||
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
let normalized = min_max_normalize(&raw_scores);
|
||||
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
|
||||
candidate.scores.fts = Some(normalized_score);
|
||||
candidate.update_fused(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
for candidate in candidates.values_mut() {
|
||||
let fused = fuse_scores(&candidate.scores, weights);
|
||||
candidate.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
fn group_chunks_by_source(
|
||||
chunks: &HashMap<String, Scored<TextChunk>>,
|
||||
) -> HashMap<String, Vec<Scored<TextChunk>>> {
|
||||
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
|
||||
for chunk in chunks.values() {
|
||||
by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk.clone());
|
||||
}
|
||||
by_source
|
||||
}
|
||||
|
||||
async fn backfill_entities_from_chunks(
|
||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
let mut missing_sources = Vec::new();
|
||||
|
||||
for source_id in chunk_by_source.keys() {
|
||||
if !entity_candidates
|
||||
.values()
|
||||
.any(|entity| entity.item.source_id == *source_id)
|
||||
{
|
||||
missing_sources.push(source_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if missing_sources.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
missing_sources.clone(),
|
||||
"knowledge_entity",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if related_entities.is_empty() {
|
||||
warn!("expected related entities for missing chunk sources, but none were found");
|
||||
}
|
||||
|
||||
for entity in related_entities {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
|
||||
let fused = fuse_scores(&scored.scores, weights);
|
||||
scored.update_fused(fused);
|
||||
entity_candidates.insert(entity.id.clone(), scored);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn boost_entities_with_chunks(
|
||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
|
||||
weights: FusionWeights,
|
||||
) {
|
||||
for entity in entity_candidates.values_mut() {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
if best_chunk_score > 0.0 {
|
||||
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
|
||||
entity.scores.vector = Some(boosted);
|
||||
let fused = fuse_scores(&entity.scores, weights);
|
||||
entity.update_fused(fused);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enrich_chunks_from_entities(
|
||||
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
|
||||
entities: &[Scored<KnowledgeEntity>],
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
source_ids.insert(entity.item.source_id.clone());
|
||||
}
|
||||
|
||||
if source_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunks = find_entities_by_source_ids::<TextChunk>(
|
||||
source_ids.into_iter().collect(),
|
||||
"text_chunk",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
|
||||
for entity in entities {
|
||||
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
|
||||
}
|
||||
|
||||
for chunk in chunks {
|
||||
let entry = chunk_candidates
|
||||
.entry(chunk.id.clone())
|
||||
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
|
||||
|
||||
let entity_score = entity_score_lookup
|
||||
.get(&chunk.source_id)
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct GraphSeed {
|
||||
id: String,
|
||||
fused: f32,
|
||||
}
|
||||
|
||||
fn seeds_from_candidates(
|
||||
entity_candidates: &HashMap<String, Scored<KnowledgeEntity>>,
|
||||
min_score: f32,
|
||||
limit: usize,
|
||||
) -> Vec<GraphSeed> {
|
||||
let mut seeds: Vec<GraphSeed> = entity_candidates
|
||||
.values()
|
||||
.filter(|entity| entity.fused >= min_score)
|
||||
.map(|entity| GraphSeed {
|
||||
id: entity.item.id.clone(),
|
||||
fused: entity.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
seeds.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
if seeds.len() > limit {
|
||||
seeds.truncate(limit);
|
||||
}
|
||||
|
||||
seeds
|
||||
}
|
||||
25
composite-retrieval/src/pipeline/state.rs
Normal file
25
composite-retrieval/src/pipeline/state.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: HybridRetrievalMachine,
|
||||
state: HybridRetrievalState,
|
||||
initial: Ready,
|
||||
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Completed, Failed],
|
||||
events {
|
||||
embed { transition: { from: Ready, to: Embedded } }
|
||||
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
|
||||
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
|
||||
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
|
||||
assemble { transition: { from: ChunksAttached, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: CandidatesLoaded, to: Failed }
|
||||
transition: { from: GraphExpanded, to: Failed }
|
||||
transition: { from: ChunksAttached, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
|
||||
HybridRetrievalMachine::new(())
|
||||
}
|
||||
Reference in New Issue
Block a user