mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-23 17:28:34 +02:00
refactor: implemented state machines for retrieval pipeline, improved tracing
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1345,6 +1345,7 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"state-machines",
|
||||||
"surrealdb",
|
"surrealdb",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|||||||
@@ -21,3 +21,4 @@ async-openai = { workspace = true }
|
|||||||
uuid = { workspace = true }
|
uuid = { workspace = true }
|
||||||
|
|
||||||
common = { path = "../common", features = ["test-utils"] }
|
common = { path = "../common", features = ["test-utils"] }
|
||||||
|
state-machines = { workspace = true }
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ use common::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{retrieve_entities, RetrievedEntity};
|
use crate::{retrieve_entities, retrieved_entities_to_json};
|
||||||
|
|
||||||
use super::answer_retrieval_helper::get_query_response_schema;
|
use super::answer_retrieval_helper::get_query_response_schema;
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ pub async fn get_answer_with_references(
|
|||||||
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
|
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
|
||||||
let settings = SystemSettings::get_current(surreal_db_client).await?;
|
let settings = SystemSettings::get_current(surreal_db_client).await?;
|
||||||
|
|
||||||
let entities_json = format_entities_json(&entities);
|
let entities_json = retrieved_entities_to_json(&entities);
|
||||||
let user_message = create_user_message(&entities_json, query);
|
let user_message = create_user_message(&entities_json, query);
|
||||||
|
|
||||||
let request = create_chat_request(user_message, &settings)?;
|
let request = create_chat_request(user_message, &settings)?;
|
||||||
@@ -83,31 +83,6 @@ pub async fn get_answer_with_references(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value {
|
|
||||||
json!(entities
|
|
||||||
.iter()
|
|
||||||
.map(|entry| {
|
|
||||||
json!({
|
|
||||||
"KnowledgeEntity": {
|
|
||||||
"id": entry.entity.id,
|
|
||||||
"name": entry.entity.name,
|
|
||||||
"description": entry.entity.description,
|
|
||||||
"score": round_score(entry.score),
|
|
||||||
"chunks": entry.chunks.iter().map(|chunk| {
|
|
||||||
json!({
|
|
||||||
"score": round_score(chunk.score),
|
|
||||||
"content": chunk.chunk.chunk
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Widens an `f32` score to `f64` and rounds it to three decimal places.
fn round_score(value: f32) -> f64 {
    let scaled = f64::from(value) * 1000.0;
    scaled.round() / 1000.0
}
|
|
||||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||||
format!(
|
format!(
|
||||||
r"
|
r"
|
||||||
|
|||||||
@@ -2,44 +2,20 @@ pub mod answer_retrieval;
|
|||||||
pub mod answer_retrieval_helper;
|
pub mod answer_retrieval_helper;
|
||||||
pub mod fts;
|
pub mod fts;
|
||||||
pub mod graph;
|
pub mod graph;
|
||||||
|
pub mod pipeline;
|
||||||
pub mod scoring;
|
pub mod scoring;
|
||||||
pub mod vector;
|
pub mod vector;
|
||||||
|
|
||||||
use std::collections::{HashMap, HashSet};
|
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{
|
storage::{
|
||||||
db::SurrealDbClient,
|
db::SurrealDbClient,
|
||||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||||
},
|
},
|
||||||
utils::embedding::generate_embedding,
|
|
||||||
};
|
};
|
||||||
use futures::{stream::FuturesUnordered, StreamExt};
|
use tracing::instrument;
|
||||||
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
|
|
||||||
use scoring::{
|
|
||||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
|
||||||
FusionWeights, Scored,
|
|
||||||
};
|
|
||||||
use tracing::{debug, instrument, trace};
|
|
||||||
|
|
||||||
use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding};
|
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
|
||||||
|
|
||||||
// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping.

// --- First-pass recall -------------------------------------------------------

/// Limit passed to the vector similarity search over `knowledge_entity`.
const ENTITY_VECTOR_TAKE: usize = 15;
/// Limit passed to the vector similarity search over `text_chunk`.
const CHUNK_VECTOR_TAKE: usize = 20;
/// Limit passed to the full-text search over `knowledge_entity`.
const ENTITY_FTS_TAKE: usize = 10;
/// Limit passed to the full-text search over `text_chunk`.
const CHUNK_FTS_TAKE: usize = 20;

// --- Filtering ---------------------------------------------------------------

/// Entities whose fused score is below this value are filtered out.
const SCORE_THRESHOLD: f32 = 0.35;
/// When fewer entities than this survive the threshold, the top-ranked ones are taken regardless
/// of score, up to this count.
const FALLBACK_MIN_RESULTS: usize = 10;

// --- Answer shaping ----------------------------------------------------------

/// Estimated-token budget shared by all chunks attached to the final results.
const TOKEN_BUDGET_ESTIMATE: usize = 2800;
/// Divisor used by `estimate_tokens` to turn a character count into a token estimate.
const AVG_CHARS_PER_TOKEN: usize = 4;
/// Upper bound on the number of chunks attached to a single entity.
const MAX_CHUNKS_PER_ENTITY: usize = 4;

// --- Graph expansion ---------------------------------------------------------

/// Maximum number of seed entities whose relationships are walked.
const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5;
/// Limit passed to the relationship lookup for each seed.
const GRAPH_NEIGHBOR_LIMIT: usize = 6;
/// Multiplier applied to a seed's fused score to obtain its neighbours' graph score.
const GRAPH_SCORE_DECAY: f32 = 0.75;
/// Minimum fused score an entity needs to be used as a graph seed.
const GRAPH_SEED_MIN_SCORE: f32 = 0.4;
/// Multiplier applied to the graph score to obtain the vector score a neighbour inherits.
const GRAPH_VECTOR_INHERITANCE: f32 = 0.6;
|
|
||||||
|
|
||||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -56,435 +32,34 @@ pub struct RetrievedEntity {
|
|||||||
pub chunks: Vec<RetrievedChunk>,
|
pub chunks: Vec<RetrievedChunk>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Primary orchestrator for retrieving the KnowledgeEntities related to an input_text
|
||||||
#[instrument(skip_all, fields(user_id))]
|
#[instrument(skip_all, fields(user_id))]
|
||||||
pub async fn retrieve_entities(
|
pub async fn retrieve_entities(
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
query: &str,
|
input_text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||||
trace!("Generating query embedding for hybrid retrieval");
|
pipeline::run_pipeline(
|
||||||
let query_embedding = generate_embedding(openai_client, query, db_client).await?;
|
|
||||||
retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn retrieve_entities_with_embedding(
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
query_embedding: Vec<f32>,
|
|
||||||
query: &str,
|
|
||||||
user_id: &str,
|
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
|
||||||
// 1) Gather first-pass candidates from vector search and BM25.
|
|
||||||
let weights = FusionWeights::default();
|
|
||||||
|
|
||||||
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
|
||||||
find_items_by_vector_similarity_with_embedding(
|
|
||||||
ENTITY_VECTOR_TAKE,
|
|
||||||
query_embedding.clone(),
|
|
||||||
db_client,
|
|
||||||
"knowledge_entity",
|
|
||||||
user_id,
|
|
||||||
),
|
|
||||||
find_items_by_vector_similarity_with_embedding(
|
|
||||||
CHUNK_VECTOR_TAKE,
|
|
||||||
query_embedding,
|
|
||||||
db_client,
|
|
||||||
"text_chunk",
|
|
||||||
user_id,
|
|
||||||
),
|
|
||||||
find_items_by_fts(
|
|
||||||
ENTITY_FTS_TAKE,
|
|
||||||
query,
|
|
||||||
db_client,
|
|
||||||
"knowledge_entity",
|
|
||||||
user_id
|
|
||||||
),
|
|
||||||
find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
vector_entities = vector_entities.len(),
|
|
||||||
vector_chunks = vector_chunks.len(),
|
|
||||||
fts_entities = fts_entities.len(),
|
|
||||||
fts_chunks = fts_chunks.len(),
|
|
||||||
"Hybrid retrieval initial candidate counts"
|
|
||||||
);
|
|
||||||
|
|
||||||
normalize_fts_scores(&mut fts_entities);
|
|
||||||
normalize_fts_scores(&mut fts_chunks);
|
|
||||||
|
|
||||||
let mut entity_candidates: HashMap<String, Scored<KnowledgeEntity>> = HashMap::new();
|
|
||||||
let mut chunk_candidates: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
|
||||||
|
|
||||||
// Collate raw retrieval results so each ID accumulates all available signals.
|
|
||||||
merge_scored_by_id(&mut entity_candidates, vector_entities);
|
|
||||||
merge_scored_by_id(&mut entity_candidates, fts_entities);
|
|
||||||
merge_scored_by_id(&mut chunk_candidates, vector_chunks);
|
|
||||||
merge_scored_by_id(&mut chunk_candidates, fts_chunks);
|
|
||||||
|
|
||||||
// 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph.
|
|
||||||
apply_fusion(&mut entity_candidates, weights);
|
|
||||||
apply_fusion(&mut chunk_candidates, weights);
|
|
||||||
enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?;
|
|
||||||
|
|
||||||
// 3) Track high-signal chunk sources so we can backfill missing entities.
|
|
||||||
let chunk_by_source = group_chunks_by_source(&chunk_candidates);
|
|
||||||
let mut missing_sources = Vec::new();
|
|
||||||
|
|
||||||
for source_id in chunk_by_source.keys() {
|
|
||||||
if !entity_candidates
|
|
||||||
.values()
|
|
||||||
.any(|entity| entity.item.source_id == *source_id)
|
|
||||||
{
|
|
||||||
missing_sources.push(source_id.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !missing_sources.is_empty() {
|
|
||||||
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
|
||||||
missing_sources.clone(),
|
|
||||||
"knowledge_entity",
|
|
||||||
user_id,
|
|
||||||
db_client,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
for entity in related_entities {
|
|
||||||
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
|
|
||||||
let best_chunk_score = chunks
|
|
||||||
.iter()
|
|
||||||
.map(|chunk| chunk.fused)
|
|
||||||
.fold(0.0f32, f32::max);
|
|
||||||
|
|
||||||
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
|
|
||||||
let fused = fuse_scores(&scored.scores, weights);
|
|
||||||
scored.update_fused(fused);
|
|
||||||
entity_candidates.insert(entity.id.clone(), scored);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Boost entities with evidence from high scoring chunks.
|
|
||||||
for entity in entity_candidates.values_mut() {
|
|
||||||
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
|
|
||||||
let best_chunk_score = chunks
|
|
||||||
.iter()
|
|
||||||
.map(|chunk| chunk.fused)
|
|
||||||
.fold(0.0f32, f32::max);
|
|
||||||
|
|
||||||
if best_chunk_score > 0.0 {
|
|
||||||
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
|
|
||||||
entity.scores.vector = Some(boosted);
|
|
||||||
let fused = fuse_scores(&entity.scores, weights);
|
|
||||||
entity.update_fused(fused);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
|
|
||||||
entity_candidates.into_values().collect();
|
|
||||||
sort_by_fused_desc(&mut entity_results);
|
|
||||||
|
|
||||||
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
|
|
||||||
.iter()
|
|
||||||
.filter(|candidate| candidate.fused >= SCORE_THRESHOLD)
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if filtered_entities.len() < FALLBACK_MIN_RESULTS {
|
|
||||||
// Low recall scenarios still benefit from some context; take the top N regardless of score.
|
|
||||||
filtered_entities = entity_results
|
|
||||||
.into_iter()
|
|
||||||
.take(FALLBACK_MIN_RESULTS)
|
|
||||||
.collect();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4) Re-rank chunks and prepare for attachment to surviving entities.
|
|
||||||
let mut chunk_results: Vec<Scored<TextChunk>> = chunk_candidates.into_values().collect();
|
|
||||||
sort_by_fused_desc(&mut chunk_results);
|
|
||||||
|
|
||||||
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
|
||||||
for chunk in chunk_results {
|
|
||||||
chunk_by_id.insert(chunk.item.id.clone(), chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
enrich_chunks_from_entities(
|
|
||||||
&mut chunk_by_id,
|
|
||||||
&filtered_entities,
|
|
||||||
db_client,
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
input_text,
|
||||||
user_id,
|
user_id,
|
||||||
weights,
|
RetrievalConfig::default(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await
|
||||||
|
|
||||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
|
||||||
sort_by_fused_desc(&mut chunk_values);
|
|
||||||
|
|
||||||
Ok(assemble_results(filtered_entities, chunk_values))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Lightweight snapshot of a candidate entity chosen as a starting point for graph expansion.
///
/// Only the identifier and the fused score at selection time are kept, so the original score
/// stays available while the candidate map is being mutated.
#[derive(Clone)]
struct GraphSeed {
    /// Id of the seed entity.
    id: String,
    /// Fused score the entity had when it was selected as a seed.
    fused: f32,
}
|
|
||||||
|
|
||||||
async fn enrich_entities_from_graph(
|
|
||||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
user_id: &str,
|
|
||||||
weights: FusionWeights,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
if entity_candidates.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select a small frontier of high-confidence entities to seed the relationship walk.
|
|
||||||
let mut seeds: Vec<GraphSeed> = entity_candidates
|
|
||||||
.values()
|
|
||||||
.filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE)
|
|
||||||
.map(|entity| GraphSeed {
|
|
||||||
id: entity.item.id.clone(),
|
|
||||||
fused: entity.fused,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if seeds.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prioritise the strongest seeds so we explore the most grounded context first.
|
|
||||||
seeds.sort_by(|a, b| {
|
|
||||||
b.fused
|
|
||||||
.partial_cmp(&a.fused)
|
|
||||||
.unwrap_or(std::cmp::Ordering::Equal)
|
|
||||||
});
|
|
||||||
seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT);
|
|
||||||
|
|
||||||
let mut futures = FuturesUnordered::new();
|
|
||||||
for seed in seeds.clone() {
|
|
||||||
let user_id = user_id.to_owned();
|
|
||||||
futures.push(async move {
|
|
||||||
// Fetch neighbors concurrently to avoid serial graph round trips.
|
|
||||||
let neighbors = find_entities_by_relationship_by_id(
|
|
||||||
db_client,
|
|
||||||
&seed.id,
|
|
||||||
&user_id,
|
|
||||||
GRAPH_NEIGHBOR_LIMIT,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
(seed, neighbors)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
|
||||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
|
||||||
if neighbors.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fold neighbors back into the candidate map and let them inherit attenuated signal.
|
|
||||||
for neighbor in neighbors {
|
|
||||||
if neighbor.id == seed.id {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY);
|
|
||||||
let entry = entity_candidates
|
|
||||||
.entry(neighbor.id.clone())
|
|
||||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
|
||||||
|
|
||||||
entry.item = neighbor;
|
|
||||||
|
|
||||||
let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE);
|
|
||||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
|
||||||
if inherited_vector > vector_existing {
|
|
||||||
entry.scores.vector = Some(inherited_vector);
|
|
||||||
}
|
|
||||||
|
|
||||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
|
||||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
|
||||||
entry.scores.graph = Some(graph_score);
|
|
||||||
}
|
|
||||||
|
|
||||||
let fused = fuse_scores(&entry.scores, weights);
|
|
||||||
entry.update_fused(fused);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
|
||||||
// Scale BM25 outputs into [0,1] to keep fusion weights predictable.
|
|
||||||
let raw_scores: Vec<f32> = results
|
|
||||||
.iter()
|
|
||||||
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let normalized = min_max_normalize(&raw_scores);
|
|
||||||
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
|
|
||||||
candidate.scores.fts = Some(normalized_score);
|
|
||||||
candidate.update_fused(0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
|
||||||
where
|
|
||||||
T: StoredObject,
|
|
||||||
{
|
|
||||||
// Collapse individual signals into a single fused score used for ranking.
|
|
||||||
for candidate in candidates.values_mut() {
|
|
||||||
let fused = fuse_scores(&candidate.scores, weights);
|
|
||||||
candidate.update_fused(fused);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Buckets chunk candidates by the `source_id` of the chunk they wrap.
///
/// Every chunk is cloned into its bucket; the input map is left untouched. The order of chunks
/// inside a bucket follows the iteration order of the input map.
fn group_chunks_by_source(
    chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
    chunks.values().fold(
        HashMap::<String, Vec<Scored<TextChunk>>>::new(),
        |mut by_source, chunk| {
            by_source
                .entry(chunk.item.source_id.clone())
                .or_default()
                .push(chunk.clone());
            by_source
        },
    )
}
|
|
||||||
|
|
||||||
async fn enrich_chunks_from_entities(
|
|
||||||
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
|
|
||||||
entities: &[Scored<KnowledgeEntity>],
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
user_id: &str,
|
|
||||||
weights: FusionWeights,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
// Fetch additional chunks referenced by entities that survived the fusion stage.
|
|
||||||
let mut source_ids: HashSet<String> = HashSet::new();
|
|
||||||
for entity in entities {
|
|
||||||
source_ids.insert(entity.item.source_id.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
if source_ids.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunks = find_entities_by_source_ids::<TextChunk>(
|
|
||||||
source_ids.into_iter().collect(),
|
|
||||||
"text_chunk",
|
|
||||||
user_id,
|
|
||||||
db_client,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
|
|
||||||
// Cache fused scores per source so chunks inherit the strength of their parent entity.
|
|
||||||
for entity in entities {
|
|
||||||
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
|
|
||||||
}
|
|
||||||
|
|
||||||
for chunk in chunks {
|
|
||||||
// Ensure each chunk is represented so downstream selection sees the latest content.
|
|
||||||
let entry = chunk_candidates
|
|
||||||
.entry(chunk.id.clone())
|
|
||||||
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
|
|
||||||
|
|
||||||
let entity_score = entity_score_lookup
|
|
||||||
.get(&chunk.source_id)
|
|
||||||
.copied()
|
|
||||||
.unwrap_or(0.0);
|
|
||||||
|
|
||||||
// Lift chunk score toward the entity score so supporting evidence is prioritised.
|
|
||||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
|
||||||
let fused = fuse_scores(&entry.scores, weights);
|
|
||||||
entry.update_fused(fused);
|
|
||||||
entry.item = chunk;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn assemble_results(
|
|
||||||
entities: Vec<Scored<KnowledgeEntity>>,
|
|
||||||
mut chunks: Vec<Scored<TextChunk>>,
|
|
||||||
) -> Vec<RetrievedEntity> {
|
|
||||||
// Re-associate chunk candidates with their parent entity for ranked selection.
|
|
||||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
|
||||||
for chunk in chunks.drain(..) {
|
|
||||||
chunk_by_source
|
|
||||||
.entry(chunk.item.source_id.clone())
|
|
||||||
.or_default()
|
|
||||||
.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
for chunk_list in chunk_by_source.values_mut() {
|
|
||||||
sort_by_fused_desc(chunk_list);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE;
|
|
||||||
let mut results = Vec::new();
|
|
||||||
|
|
||||||
for entity in entities {
|
|
||||||
// Attach best chunks first while respecting per-entity and global token caps.
|
|
||||||
let mut selected_chunks = Vec::new();
|
|
||||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
|
||||||
let mut per_entity_count = 0;
|
|
||||||
candidates.sort_by(|a, b| {
|
|
||||||
b.fused
|
|
||||||
.partial_cmp(&a.fused)
|
|
||||||
.unwrap_or(std::cmp::Ordering::Equal)
|
|
||||||
});
|
|
||||||
|
|
||||||
for candidate in candidates.iter() {
|
|
||||||
if per_entity_count >= MAX_CHUNKS_PER_ENTITY {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let estimated_tokens = estimate_tokens(&candidate.item.chunk);
|
|
||||||
if estimated_tokens > token_budget_remaining {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
|
||||||
per_entity_count += 1;
|
|
||||||
|
|
||||||
selected_chunks.push(RetrievedChunk {
|
|
||||||
chunk: candidate.item.clone(),
|
|
||||||
score: candidate.fused,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
results.push(RetrievedEntity {
|
|
||||||
entity: entity.item.clone(),
|
|
||||||
score: entity.fused,
|
|
||||||
chunks: selected_chunks,
|
|
||||||
});
|
|
||||||
|
|
||||||
if token_budget_remaining == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
results
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_tokens(text: &str) -> usize {
|
|
||||||
// Simple heuristic to avoid calling a tokenizer in hot code paths.
|
|
||||||
let chars = text.chars().count().max(1);
|
|
||||||
(chars / AVG_CHARS_PER_TOKEN).max(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use async_openai::Client;
|
||||||
use common::storage::types::{
|
use common::storage::types::{
|
||||||
knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship,
|
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||||
|
knowledge_relationship::KnowledgeRelationship,
|
||||||
|
text_chunk::TextChunk,
|
||||||
};
|
};
|
||||||
|
use pipeline::RetrievalConfig;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
fn test_embedding() -> Vec<f32> {
|
fn test_embedding() -> Vec<f32> {
|
||||||
@@ -527,70 +102,46 @@ mod tests {
|
|||||||
COMMIT TRANSACTION;",
|
COMMIT TRANSACTION;",
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to redefine vector indexes for tests");
|
.expect("Failed to configure indices");
|
||||||
|
|
||||||
db
|
db
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn seed_test_data(db: &SurrealDbClient, user_id: &str) {
|
#[tokio::test]
|
||||||
let entity_relevant = KnowledgeEntity::new(
|
async fn test_retrieve_entities_with_embedding_basic_flow() {
|
||||||
"source_a".into(),
|
let db = setup_test_db().await;
|
||||||
"Rust Concurrency Patterns".into(),
|
let user_id = "test_user";
|
||||||
"Discussion about async concurrency in Rust.".into(),
|
let entity = KnowledgeEntity::new(
|
||||||
|
"source_1".into(),
|
||||||
|
"Rust async guide".into(),
|
||||||
|
"Detailed notes about async runtimes".into(),
|
||||||
KnowledgeEntityType::Document,
|
KnowledgeEntityType::Document,
|
||||||
None,
|
None,
|
||||||
entity_embedding_high(),
|
entity_embedding_high(),
|
||||||
user_id.into(),
|
user_id.into(),
|
||||||
);
|
);
|
||||||
let entity_irrelevant = KnowledgeEntity::new(
|
let chunk = TextChunk::new(
|
||||||
"source_b".into(),
|
entity.source_id.clone(),
|
||||||
"Python Tips".into(),
|
"Tokio uses cooperative scheduling for fairness.".into(),
|
||||||
"General Python programming tips.".into(),
|
|
||||||
KnowledgeEntityType::Document,
|
|
||||||
None,
|
|
||||||
entity_embedding_low(),
|
|
||||||
user_id.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
db.store_item(entity_relevant.clone())
|
|
||||||
.await
|
|
||||||
.expect("Failed to store relevant entity");
|
|
||||||
db.store_item(entity_irrelevant.clone())
|
|
||||||
.await
|
|
||||||
.expect("Failed to store irrelevant entity");
|
|
||||||
|
|
||||||
let chunk_primary = TextChunk::new(
|
|
||||||
entity_relevant.source_id.clone(),
|
|
||||||
"Tokio enables async concurrency with lightweight tasks.".into(),
|
|
||||||
chunk_embedding_primary(),
|
chunk_embedding_primary(),
|
||||||
user_id.into(),
|
user_id.into(),
|
||||||
);
|
);
|
||||||
let chunk_secondary = TextChunk::new(
|
|
||||||
entity_irrelevant.source_id.clone(),
|
|
||||||
"Python focuses on readability and dynamic typing.".into(),
|
|
||||||
chunk_embedding_secondary(),
|
|
||||||
user_id.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
db.store_item(chunk_primary)
|
db.store_item(entity.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store primary chunk");
|
.expect("Failed to store entity");
|
||||||
db.store_item(chunk_secondary)
|
db.store_item(chunk.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store secondary chunk");
|
.expect("Failed to store chunk");
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
let openai_client = Client::new();
|
||||||
async fn test_hybrid_retrieval_prioritises_relevant_entity() {
|
let results = pipeline::run_pipeline_with_embedding(
|
||||||
let db = setup_test_db().await;
|
|
||||||
let user_id = "user123";
|
|
||||||
seed_test_data(&db, user_id).await;
|
|
||||||
|
|
||||||
let results = retrieve_entities_with_embedding(
|
|
||||||
&db,
|
&db,
|
||||||
|
&openai_client,
|
||||||
test_embedding(),
|
test_embedding(),
|
||||||
"Rust concurrency async tasks",
|
"Rust concurrency async tasks",
|
||||||
user_id,
|
user_id,
|
||||||
|
RetrievalConfig::default(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Hybrid retrieval failed");
|
.expect("Hybrid retrieval failed");
|
||||||
@@ -599,9 +150,7 @@ mod tests {
|
|||||||
!results.is_empty(),
|
!results.is_empty(),
|
||||||
"Expected at least one retrieval result"
|
"Expected at least one retrieval result"
|
||||||
);
|
);
|
||||||
|
|
||||||
let top = &results[0];
|
let top = &results[0];
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
top.entity.name.contains("Rust"),
|
top.entity.name.contains("Rust"),
|
||||||
"Expected Rust entity to be ranked first"
|
"Expected Rust entity to be ranked first"
|
||||||
@@ -610,16 +159,6 @@ mod tests {
|
|||||||
!top.chunks.is_empty(),
|
!top.chunks.is_empty(),
|
||||||
"Expected Rust entity to include supporting chunks"
|
"Expected Rust entity to include supporting chunks"
|
||||||
);
|
);
|
||||||
|
|
||||||
let chunk_texts: Vec<&str> = top
|
|
||||||
.chunks
|
|
||||||
.iter()
|
|
||||||
.map(|chunk| chunk.chunk.chunk.as_str())
|
|
||||||
.collect();
|
|
||||||
assert!(
|
|
||||||
chunk_texts.iter().any(|text| text.contains("Tokio")),
|
|
||||||
"Expected chunk discussing Tokio to be included"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -673,6 +212,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("Failed to store neighbor chunk");
|
.expect("Failed to store neighbor chunk");
|
||||||
|
|
||||||
|
let openai_client = Client::new();
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
primary.id.clone(),
|
primary.id.clone(),
|
||||||
neighbor.id.clone(),
|
neighbor.id.clone(),
|
||||||
@@ -685,11 +225,13 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship");
|
.expect("Failed to store relationship");
|
||||||
|
|
||||||
let results = retrieve_entities_with_embedding(
|
let results = pipeline::run_pipeline_with_embedding(
|
||||||
&db,
|
&db,
|
||||||
|
&openai_client,
|
||||||
test_embedding(),
|
test_embedding(),
|
||||||
"Rust concurrency async tasks",
|
"Rust concurrency async tasks",
|
||||||
user_id,
|
user_id,
|
||||||
|
RetrievalConfig::default(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Hybrid retrieval failed");
|
.expect("Hybrid retrieval failed");
|
||||||
|
|||||||
61
composite-retrieval/src/pipeline/config.rs
Normal file
61
composite-retrieval/src/pipeline/config.rs
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Tunable parameters that govern each retrieval stage.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RetrievalTuning {
|
||||||
|
pub entity_vector_take: usize,
|
||||||
|
pub chunk_vector_take: usize,
|
||||||
|
pub entity_fts_take: usize,
|
||||||
|
pub chunk_fts_take: usize,
|
||||||
|
pub score_threshold: f32,
|
||||||
|
pub fallback_min_results: usize,
|
||||||
|
pub token_budget_estimate: usize,
|
||||||
|
pub avg_chars_per_token: usize,
|
||||||
|
pub max_chunks_per_entity: usize,
|
||||||
|
pub graph_traversal_seed_limit: usize,
|
||||||
|
pub graph_neighbor_limit: usize,
|
||||||
|
pub graph_score_decay: f32,
|
||||||
|
pub graph_seed_min_score: f32,
|
||||||
|
pub graph_vector_inheritance: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RetrievalTuning {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
entity_vector_take: 15,
|
||||||
|
chunk_vector_take: 20,
|
||||||
|
entity_fts_take: 10,
|
||||||
|
chunk_fts_take: 20,
|
||||||
|
score_threshold: 0.35,
|
||||||
|
fallback_min_results: 10,
|
||||||
|
token_budget_estimate: 2800,
|
||||||
|
avg_chars_per_token: 4,
|
||||||
|
max_chunks_per_entity: 4,
|
||||||
|
graph_traversal_seed_limit: 5,
|
||||||
|
graph_neighbor_limit: 6,
|
||||||
|
graph_score_decay: 0.75,
|
||||||
|
graph_seed_min_score: 0.4,
|
||||||
|
graph_vector_inheritance: 0.6,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetrievalConfig {
|
||||||
|
pub tuning: RetrievalTuning,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetrievalConfig {
|
||||||
|
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||||
|
Self { tuning }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RetrievalConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
tuning: RetrievalTuning::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
100
composite-retrieval/src/pipeline/mod.rs
Normal file
100
composite-retrieval/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
mod config;
|
||||||
|
mod stages;
|
||||||
|
mod state;
|
||||||
|
|
||||||
|
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||||
|
|
||||||
|
use crate::RetrievedEntity;
|
||||||
|
use async_openai::Client;
|
||||||
|
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
/// Drives the retrieval pipeline from embedding through final assembly.
|
||||||
|
pub async fn run_pipeline(
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||||
|
let machine = state::ready();
|
||||||
|
let input_chars = input_text.chars().count();
|
||||||
|
let input_preview: String = input_text.chars().take(120).collect();
|
||||||
|
let input_preview_clean = input_preview.replace('\n', " ");
|
||||||
|
let preview_len = input_preview_clean.chars().count();
|
||||||
|
info!(
|
||||||
|
%user_id,
|
||||||
|
input_chars,
|
||||||
|
preview_truncated = input_chars > preview_len,
|
||||||
|
preview = %input_preview_clean,
|
||||||
|
"Starting ingestion retrieval pipeline"
|
||||||
|
);
|
||||||
|
let mut ctx = stages::PipelineContext::new(
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
input_text.to_owned(),
|
||||||
|
user_id.to_owned(),
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
let machine = stages::embed(machine, &mut ctx).await?;
|
||||||
|
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||||
|
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||||
|
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||||
|
let results = stages::assemble(machine, &mut ctx)?;
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub async fn run_pipeline_with_embedding(
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query_embedding: Vec<f32>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||||
|
let machine = state::ready();
|
||||||
|
let mut ctx = stages::PipelineContext::with_embedding(
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
query_embedding,
|
||||||
|
input_text.to_owned(),
|
||||||
|
user_id.to_owned(),
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
let machine = stages::embed(machine, &mut ctx).await?;
|
||||||
|
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||||
|
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||||
|
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||||
|
let results = stages::assemble(machine, &mut ctx)?;
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
||||||
|
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||||
|
serde_json::json!(entities
|
||||||
|
.iter()
|
||||||
|
.map(|entry| {
|
||||||
|
serde_json::json!({
|
||||||
|
"KnowledgeEntity": {
|
||||||
|
"id": entry.entity.id,
|
||||||
|
"name": entry.entity.name,
|
||||||
|
"description": entry.entity.description,
|
||||||
|
"score": round_score(entry.score),
|
||||||
|
"chunks": entry.chunks.iter().map(|chunk| {
|
||||||
|
serde_json::json!({
|
||||||
|
"score": round_score(chunk.score),
|
||||||
|
"content": chunk.chunk.chunk
|
||||||
|
})
|
||||||
|
}).collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Widens a score to `f64` and rounds it to three decimal places.
fn round_score(value: f32) -> f64 {
    const SCALE: f64 = 1000.0;
    let scaled = f64::from(value) * SCALE;
    scaled.round() / SCALE
}
|
||||||
599
composite-retrieval/src/pipeline/stages/mod.rs
Normal file
599
composite-retrieval/src/pipeline/stages/mod.rs
Normal file
@@ -0,0 +1,599 @@
|
|||||||
|
use async_openai::Client;
|
||||||
|
use common::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{
|
||||||
|
db::SurrealDbClient,
|
||||||
|
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||||
|
},
|
||||||
|
utils::embedding::generate_embedding,
|
||||||
|
};
|
||||||
|
use futures::{stream::FuturesUnordered, StreamExt};
|
||||||
|
use state_machines::core::GuardError;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use tracing::{debug, instrument, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
fts::find_items_by_fts,
|
||||||
|
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||||
|
scoring::{
|
||||||
|
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||||
|
FusionWeights, Scored,
|
||||||
|
},
|
||||||
|
vector::find_items_by_vector_similarity_with_embedding,
|
||||||
|
RetrievedChunk, RetrievedEntity,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
config::RetrievalConfig,
|
||||||
|
state::{
|
||||||
|
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct PipelineContext<'a> {
|
||||||
|
pub db_client: &'a SurrealDbClient,
|
||||||
|
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||||
|
pub input_text: String,
|
||||||
|
pub user_id: String,
|
||||||
|
pub config: RetrievalConfig,
|
||||||
|
pub query_embedding: Option<Vec<f32>>,
|
||||||
|
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
|
||||||
|
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||||
|
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||||
|
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> PipelineContext<'a> {
|
||||||
|
pub fn new(
|
||||||
|
db_client: &'a SurrealDbClient,
|
||||||
|
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||||
|
input_text: String,
|
||||||
|
user_id: String,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
query_embedding: None,
|
||||||
|
entity_candidates: HashMap::new(),
|
||||||
|
chunk_candidates: HashMap::new(),
|
||||||
|
filtered_entities: Vec::new(),
|
||||||
|
chunk_values: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn with_embedding(
|
||||||
|
db_client: &'a SurrealDbClient,
|
||||||
|
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query_embedding: Vec<f32>,
|
||||||
|
input_text: String,
|
||||||
|
user_id: String,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
) -> Self {
|
||||||
|
let mut ctx = Self::new(db_client, openai_client, input_text, user_id, config);
|
||||||
|
ctx.query_embedding = Some(query_embedding);
|
||||||
|
ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
|
||||||
|
self.query_embedding.as_ref().ok_or_else(|| {
|
||||||
|
AppError::InternalError(
|
||||||
|
"query embedding missing before candidate collection".to_string(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn embed(
|
||||||
|
machine: HybridRetrievalMachine<(), Ready>,
|
||||||
|
ctx: &mut PipelineContext<'_>,
|
||||||
|
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
|
||||||
|
let embedding_cached = ctx.query_embedding.is_some();
|
||||||
|
if embedding_cached {
|
||||||
|
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||||
|
} else {
|
||||||
|
debug!("Generating query embedding for hybrid retrieval");
|
||||||
|
let embedding =
|
||||||
|
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
|
||||||
|
ctx.query_embedding = Some(embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
machine
|
||||||
|
.embed()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("embed", guard))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn collect_candidates(
|
||||||
|
machine: HybridRetrievalMachine<(), Embedded>,
|
||||||
|
ctx: &mut PipelineContext<'_>,
|
||||||
|
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
||||||
|
debug!("Collecting initial candidates via vector and FTS search");
|
||||||
|
let embedding = ctx.ensure_embedding()?.clone();
|
||||||
|
let tuning = &ctx.config.tuning;
|
||||||
|
|
||||||
|
let weights = FusionWeights::default();
|
||||||
|
|
||||||
|
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||||
|
find_items_by_vector_similarity_with_embedding(
|
||||||
|
tuning.entity_vector_take,
|
||||||
|
embedding.clone(),
|
||||||
|
ctx.db_client,
|
||||||
|
"knowledge_entity",
|
||||||
|
&ctx.user_id,
|
||||||
|
),
|
||||||
|
find_items_by_vector_similarity_with_embedding(
|
||||||
|
tuning.chunk_vector_take,
|
||||||
|
embedding,
|
||||||
|
ctx.db_client,
|
||||||
|
"text_chunk",
|
||||||
|
&ctx.user_id,
|
||||||
|
),
|
||||||
|
find_items_by_fts(
|
||||||
|
tuning.entity_fts_take,
|
||||||
|
&ctx.input_text,
|
||||||
|
ctx.db_client,
|
||||||
|
"knowledge_entity",
|
||||||
|
&ctx.user_id,
|
||||||
|
),
|
||||||
|
find_items_by_fts(
|
||||||
|
tuning.chunk_fts_take,
|
||||||
|
&ctx.input_text,
|
||||||
|
ctx.db_client,
|
||||||
|
"text_chunk",
|
||||||
|
&ctx.user_id
|
||||||
|
),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
vector_entities = vector_entities.len(),
|
||||||
|
vector_chunks = vector_chunks.len(),
|
||||||
|
fts_entities = fts_entities.len(),
|
||||||
|
fts_chunks = fts_chunks.len(),
|
||||||
|
"Hybrid retrieval initial candidate counts"
|
||||||
|
);
|
||||||
|
|
||||||
|
normalize_fts_scores(&mut fts_entities);
|
||||||
|
normalize_fts_scores(&mut fts_chunks);
|
||||||
|
|
||||||
|
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
|
||||||
|
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
|
||||||
|
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
|
||||||
|
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
|
||||||
|
|
||||||
|
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||||
|
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||||
|
|
||||||
|
machine
|
||||||
|
.collect_candidates()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn expand_graph(
|
||||||
|
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
||||||
|
ctx: &mut PipelineContext<'_>,
|
||||||
|
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
||||||
|
debug!("Expanding candidates using graph relationships");
|
||||||
|
let tuning = &ctx.config.tuning;
|
||||||
|
let weights = FusionWeights::default();
|
||||||
|
|
||||||
|
if ctx.entity_candidates.is_empty() {
|
||||||
|
return machine
|
||||||
|
.expand_graph()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||||
|
}
|
||||||
|
|
||||||
|
let graph_seeds = seeds_from_candidates(
|
||||||
|
&ctx.entity_candidates,
|
||||||
|
tuning.graph_seed_min_score,
|
||||||
|
tuning.graph_traversal_seed_limit,
|
||||||
|
);
|
||||||
|
|
||||||
|
if graph_seeds.is_empty() {
|
||||||
|
return machine
|
||||||
|
.expand_graph()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut futures = FuturesUnordered::new();
|
||||||
|
for seed in graph_seeds {
|
||||||
|
let db = ctx.db_client;
|
||||||
|
let user = ctx.user_id.clone();
|
||||||
|
futures.push(async move {
|
||||||
|
let neighbors = find_entities_by_relationship_by_id(
|
||||||
|
db,
|
||||||
|
&seed.id,
|
||||||
|
&user,
|
||||||
|
tuning.graph_neighbor_limit,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
(seed, neighbors)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||||
|
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||||
|
if neighbors.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for neighbor in neighbors {
|
||||||
|
if neighbor.id == seed.id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||||
|
let entry = ctx
|
||||||
|
.entity_candidates
|
||||||
|
.entry(neighbor.id.clone())
|
||||||
|
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||||
|
|
||||||
|
entry.item = neighbor;
|
||||||
|
|
||||||
|
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||||
|
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||||
|
if inherited_vector > vector_existing {
|
||||||
|
entry.scores.vector = Some(inherited_vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||||
|
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||||
|
entry.scores.graph = Some(graph_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
let fused = fuse_scores(&entry.scores, weights);
|
||||||
|
entry.update_fused(fused);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
machine
|
||||||
|
.expand_graph()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn attach_chunks(
|
||||||
|
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
||||||
|
ctx: &mut PipelineContext<'_>,
|
||||||
|
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
||||||
|
debug!("Attaching chunks to surviving entities");
|
||||||
|
let tuning = &ctx.config.tuning;
|
||||||
|
let weights = FusionWeights::default();
|
||||||
|
|
||||||
|
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
|
||||||
|
|
||||||
|
backfill_entities_from_chunks(
|
||||||
|
&mut ctx.entity_candidates,
|
||||||
|
&chunk_by_source,
|
||||||
|
ctx.db_client,
|
||||||
|
&ctx.user_id,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);
|
||||||
|
|
||||||
|
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
|
||||||
|
ctx.entity_candidates.values().cloned().collect();
|
||||||
|
sort_by_fused_desc(&mut entity_results);
|
||||||
|
|
||||||
|
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
|
||||||
|
.iter()
|
||||||
|
.filter(|candidate| candidate.fused >= tuning.score_threshold)
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if filtered_entities.len() < tuning.fallback_min_results {
|
||||||
|
filtered_entities = entity_results
|
||||||
|
.into_iter()
|
||||||
|
.take(tuning.fallback_min_results)
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.filtered_entities = filtered_entities;
|
||||||
|
|
||||||
|
let mut chunk_results: Vec<Scored<TextChunk>> =
|
||||||
|
ctx.chunk_candidates.values().cloned().collect();
|
||||||
|
sort_by_fused_desc(&mut chunk_results);
|
||||||
|
|
||||||
|
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||||
|
for chunk in chunk_results {
|
||||||
|
chunk_by_id.insert(chunk.item.id.clone(), chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
enrich_chunks_from_entities(
|
||||||
|
&mut chunk_by_id,
|
||||||
|
&ctx.filtered_entities,
|
||||||
|
ctx.db_client,
|
||||||
|
&ctx.user_id,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||||
|
sort_by_fused_desc(&mut chunk_values);
|
||||||
|
|
||||||
|
ctx.chunk_values = chunk_values;
|
||||||
|
|
||||||
|
machine
|
||||||
|
.attach_chunks()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub fn assemble(
|
||||||
|
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||||
|
ctx: &mut PipelineContext<'_>,
|
||||||
|
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||||
|
debug!("Assembling final retrieved entities");
|
||||||
|
let tuning = &ctx.config.tuning;
|
||||||
|
|
||||||
|
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||||
|
for chunk in ctx.chunk_values.drain(..) {
|
||||||
|
chunk_by_source
|
||||||
|
.entry(chunk.item.source_id.clone())
|
||||||
|
.or_default()
|
||||||
|
.push(chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
for chunk_list in chunk_by_source.values_mut() {
|
||||||
|
sort_by_fused_desc(chunk_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut token_budget_remaining = tuning.token_budget_estimate;
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
for entity in &ctx.filtered_entities {
|
||||||
|
let mut selected_chunks = Vec::new();
|
||||||
|
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||||
|
let mut per_entity_count = 0;
|
||||||
|
candidates.sort_by(|a, b| {
|
||||||
|
b.fused
|
||||||
|
.partial_cmp(&a.fused)
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
for candidate in candidates.iter() {
|
||||||
|
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let estimated_tokens =
|
||||||
|
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||||
|
if estimated_tokens > token_budget_remaining {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||||
|
per_entity_count += 1;
|
||||||
|
|
||||||
|
selected_chunks.push(RetrievedChunk {
|
||||||
|
chunk: candidate.item.clone(),
|
||||||
|
score: candidate.fused,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.push(RetrievedEntity {
|
||||||
|
entity: entity.item.clone(),
|
||||||
|
score: entity.fused,
|
||||||
|
chunks: selected_chunks,
|
||||||
|
});
|
||||||
|
|
||||||
|
if token_budget_remaining == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
machine
|
||||||
|
.assemble()
|
||||||
|
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
||||||
|
AppError::InternalError(format!(
|
||||||
|
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
||||||
|
err.guard, err.event, err.kind
|
||||||
|
))
|
||||||
|
}
|
||||||
|
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||||
|
let raw_scores: Vec<f32> = results
|
||||||
|
.iter()
|
||||||
|
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let normalized = min_max_normalize(&raw_scores);
|
||||||
|
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
|
||||||
|
candidate.scores.fts = Some(normalized_score);
|
||||||
|
candidate.update_fused(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||||
|
where
|
||||||
|
T: StoredObject,
|
||||||
|
{
|
||||||
|
for candidate in candidates.values_mut() {
|
||||||
|
let fused = fuse_scores(&candidate.scores, weights);
|
||||||
|
candidate.update_fused(fused);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Groups cloned chunk candidates by the `source_id` of their chunk.
fn group_chunks_by_source(
    chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
    chunks.values().fold(
        HashMap::new(),
        |mut grouped: HashMap<String, Vec<Scored<TextChunk>>>, scored_chunk| {
            grouped
                .entry(scored_chunk.item.source_id.clone())
                .or_default()
                .push(scored_chunk.clone());
            grouped
        },
    )
}
|
||||||
|
|
||||||
|
async fn backfill_entities_from_chunks(
|
||||||
|
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||||
|
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
|
weights: FusionWeights,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let mut missing_sources = Vec::new();
|
||||||
|
|
||||||
|
for source_id in chunk_by_source.keys() {
|
||||||
|
if !entity_candidates
|
||||||
|
.values()
|
||||||
|
.any(|entity| entity.item.source_id == *source_id)
|
||||||
|
{
|
||||||
|
missing_sources.push(source_id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if missing_sources.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||||
|
missing_sources.clone(),
|
||||||
|
"knowledge_entity",
|
||||||
|
user_id,
|
||||||
|
db_client,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
if related_entities.is_empty() {
|
||||||
|
warn!("expected related entities for missing chunk sources, but none were found");
|
||||||
|
}
|
||||||
|
|
||||||
|
for entity in related_entities {
|
||||||
|
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
|
||||||
|
let best_chunk_score = chunks
|
||||||
|
.iter()
|
||||||
|
.map(|chunk| chunk.fused)
|
||||||
|
.fold(0.0f32, f32::max);
|
||||||
|
|
||||||
|
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
|
||||||
|
let fused = fuse_scores(&scored.scores, weights);
|
||||||
|
scored.update_fused(fused);
|
||||||
|
entity_candidates.insert(entity.id.clone(), scored);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn boost_entities_with_chunks(
|
||||||
|
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||||
|
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
|
||||||
|
weights: FusionWeights,
|
||||||
|
) {
|
||||||
|
for entity in entity_candidates.values_mut() {
|
||||||
|
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
|
||||||
|
let best_chunk_score = chunks
|
||||||
|
.iter()
|
||||||
|
.map(|chunk| chunk.fused)
|
||||||
|
.fold(0.0f32, f32::max);
|
||||||
|
|
||||||
|
if best_chunk_score > 0.0 {
|
||||||
|
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
|
||||||
|
entity.scores.vector = Some(boosted);
|
||||||
|
let fused = fuse_scores(&entity.scores, weights);
|
||||||
|
entity.update_fused(fused);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn enrich_chunks_from_entities(
|
||||||
|
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
|
||||||
|
entities: &[Scored<KnowledgeEntity>],
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
|
weights: FusionWeights,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let mut source_ids: HashSet<String> = HashSet::new();
|
||||||
|
for entity in entities {
|
||||||
|
source_ids.insert(entity.item.source_id.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if source_ids.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunks = find_entities_by_source_ids::<TextChunk>(
|
||||||
|
source_ids.into_iter().collect(),
|
||||||
|
"text_chunk",
|
||||||
|
user_id,
|
||||||
|
db_client,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
|
||||||
|
for entity in entities {
|
||||||
|
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
|
||||||
|
}
|
||||||
|
|
||||||
|
for chunk in chunks {
|
||||||
|
let entry = chunk_candidates
|
||||||
|
.entry(chunk.id.clone())
|
||||||
|
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
|
||||||
|
|
||||||
|
let entity_score = entity_score_lookup
|
||||||
|
.get(&chunk.source_id)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
|
||||||
|
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||||
|
let fused = fuse_scores(&entry.scores, weights);
|
||||||
|
entry.update_fused(fused);
|
||||||
|
entry.item = chunk;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Roughly estimates how many tokens `text` occupies.
///
/// The estimate is the character count divided by `avg_chars_per_token`, rounded
/// down, and never less than one. An `avg_chars_per_token` of zero is treated as one
/// so a misconfigured tuning value cannot cause a division-by-zero panic.
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
    let chars = text.chars().count().max(1);
    let chars_per_token = avg_chars_per_token.max(1);
    (chars / chars_per_token).max(1)
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct GraphSeed {
|
||||||
|
id: String,
|
||||||
|
fused: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn seeds_from_candidates(
|
||||||
|
entity_candidates: &HashMap<String, Scored<KnowledgeEntity>>,
|
||||||
|
min_score: f32,
|
||||||
|
limit: usize,
|
||||||
|
) -> Vec<GraphSeed> {
|
||||||
|
let mut seeds: Vec<GraphSeed> = entity_candidates
|
||||||
|
.values()
|
||||||
|
.filter(|entity| entity.fused >= min_score)
|
||||||
|
.map(|entity| GraphSeed {
|
||||||
|
id: entity.item.id.clone(),
|
||||||
|
fused: entity.fused,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
seeds.sort_by(|a, b| {
|
||||||
|
b.fused
|
||||||
|
.partial_cmp(&a.fused)
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
if seeds.len() > limit {
|
||||||
|
seeds.truncate(limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
seeds
|
||||||
|
}
|
||||||
25
composite-retrieval/src/pipeline/state.rs
Normal file
25
composite-retrieval/src/pipeline/state.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
use state_machines::state_machine;
|
||||||
|
|
||||||
|
state_machine! {
|
||||||
|
name: HybridRetrievalMachine,
|
||||||
|
state: HybridRetrievalState,
|
||||||
|
initial: Ready,
|
||||||
|
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Completed, Failed],
|
||||||
|
events {
|
||||||
|
embed { transition: { from: Ready, to: Embedded } }
|
||||||
|
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
|
||||||
|
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
|
||||||
|
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
|
||||||
|
assemble { transition: { from: ChunksAttached, to: Completed } }
|
||||||
|
abort {
|
||||||
|
transition: { from: Ready, to: Failed }
|
||||||
|
transition: { from: CandidatesLoaded, to: Failed }
|
||||||
|
transition: { from: GraphExpanded, to: Failed }
|
||||||
|
transition: { from: ChunksAttached, to: Failed }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
|
||||||
|
HybridRetrievalMachine::new(())
|
||||||
|
}
|
||||||
@@ -24,7 +24,10 @@ pub async fn show_account_page(
|
|||||||
RequireUser(user): RequireUser,
|
RequireUser(user): RequireUser,
|
||||||
State(state): State<HtmlState>,
|
State(state): State<HtmlState>,
|
||||||
) -> Result<impl IntoResponse, HtmlError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
let timezones = TZ_VARIANTS.iter().map(std::string::ToString::to_string).collect();
|
let timezones = TZ_VARIANTS
|
||||||
|
.iter()
|
||||||
|
.map(std::string::ToString::to_string)
|
||||||
|
.collect();
|
||||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
@@ -102,7 +105,10 @@ pub async fn update_timezone(
|
|||||||
..user.clone()
|
..user.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
let timezones = TZ_VARIANTS.iter().map(std::string::ToString::to_string).collect();
|
let timezones = TZ_VARIANTS
|
||||||
|
.iter()
|
||||||
|
.map(std::string::ToString::to_string)
|
||||||
|
.collect();
|
||||||
|
|
||||||
// Render the API key section block
|
// Render the API key section block
|
||||||
Ok(TemplateResponse::new_partial(
|
Ok(TemplateResponse::new_partial(
|
||||||
|
|||||||
@@ -27,11 +27,15 @@ pub async fn show_signin_form(
|
|||||||
if auth.is_authenticated() {
|
if auth.is_authenticated() {
|
||||||
return Ok(TemplateResponse::redirect("/"));
|
return Ok(TemplateResponse::redirect("/"));
|
||||||
}
|
}
|
||||||
if boosted { Ok(TemplateResponse::new_partial(
|
if boosted {
|
||||||
"auth/signin_base.html",
|
Ok(TemplateResponse::new_partial(
|
||||||
"body",
|
"auth/signin_base.html",
|
||||||
(),
|
"body",
|
||||||
)) } else { Ok(TemplateResponse::new_template("auth/signin_base.html", ())) }
|
(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(TemplateResponse::new_template("auth/signin_base.html", ()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn authenticate_user(
|
pub async fn authenticate_user(
|
||||||
|
|||||||
@@ -29,11 +29,15 @@ pub async fn show_signup_form(
|
|||||||
return Ok(TemplateResponse::redirect("/"));
|
return Ok(TemplateResponse::redirect("/"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if boosted { Ok(TemplateResponse::new_partial(
|
if boosted {
|
||||||
"auth/signup_form.html",
|
Ok(TemplateResponse::new_partial(
|
||||||
"body",
|
"auth/signup_form.html",
|
||||||
(),
|
"body",
|
||||||
)) } else { Ok(TemplateResponse::new_template("auth/signup_form.html", ())) }
|
(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(TemplateResponse::new_template("auth/signup_form.html", ()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn process_signup_and_show_verification(
|
pub async fn process_signup_and_show_verification(
|
||||||
|
|||||||
@@ -9,11 +9,8 @@ use axum::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
use composite_retrieval::{
|
use composite_retrieval::{
|
||||||
answer_retrieval::{
|
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||||
create_chat_request, create_user_message_with_history, format_entities_json,
|
retrieve_entities, retrieved_entities_to_json,
|
||||||
LLMResponseFormat,
|
|
||||||
},
|
|
||||||
retrieve_entities,
|
|
||||||
};
|
};
|
||||||
use futures::{
|
use futures::{
|
||||||
stream::{self, once},
|
stream::{self, once},
|
||||||
@@ -136,7 +133,7 @@ pub async fn get_response_stream(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// 3. Create the OpenAI request
|
// 3. Create the OpenAI request
|
||||||
let entities_json = format_entities_json(&entities);
|
let entities_json = retrieved_entities_to_json(&entities);
|
||||||
let formatted_user_message =
|
let formatted_user_message =
|
||||||
create_user_message_with_history(&entities_json, &history, &user_message.content);
|
create_user_message_with_history(&entities_json, &history, &user_message.content);
|
||||||
let settings = match SystemSettings::get_current(&state.db).await {
|
let settings = match SystemSettings::get_current(&state.db).await {
|
||||||
@@ -260,7 +257,11 @@ pub async fn get_response_stream(
|
|||||||
.chain(stream::once(async move {
|
.chain(stream::once(async move {
|
||||||
if let Some(message) = rx_final.recv().await {
|
if let Some(message) = rx_final.recv().await {
|
||||||
// Don't send any event if references is empty
|
// Don't send any event if references is empty
|
||||||
if message.references.as_ref().is_some_and(std::vec::Vec::is_empty) {
|
if message
|
||||||
|
.references
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(std::vec::Vec::is_empty)
|
||||||
|
{
|
||||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,7 @@ use common::{
|
|||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||||
};
|
};
|
||||||
use composite_retrieval::{
|
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
|
||||||
answer_retrieval::format_entities_json, retrieve_entities, RetrievedEntity,
|
|
||||||
};
|
|
||||||
use tracing::{debug, info};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
types::llm_enrichment_result::LLMEnrichmentResult,
|
types::llm_enrichment_result::LLMEnrichmentResult,
|
||||||
@@ -42,11 +39,9 @@ impl IngestionEnricher {
|
|||||||
text: &str,
|
text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<LLMEnrichmentResult, AppError> {
|
) -> Result<LLMEnrichmentResult, AppError> {
|
||||||
info!("getting similar entitities");
|
|
||||||
let similar_entities = self
|
let similar_entities = self
|
||||||
.find_similar_entities(category, context, text, user_id)
|
.find_similar_entities(category, context, text, user_id)
|
||||||
.await?;
|
.await?;
|
||||||
info!("got similar entitities");
|
|
||||||
let llm_request = self
|
let llm_request = self
|
||||||
.prepare_llm_request(category, context, text, &similar_entities)
|
.prepare_llm_request(category, context, text, &similar_entities)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -60,9 +55,8 @@ impl IngestionEnricher {
|
|||||||
text: &str,
|
text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||||
let input_text = format!(
|
let input_text =
|
||||||
"content: {text}, category: {category}, user_context: {context:?}"
|
format!("content: {text}, category: {category}, user_context: {context:?}");
|
||||||
);
|
|
||||||
|
|
||||||
retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await
|
retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await
|
||||||
}
|
}
|
||||||
@@ -76,14 +70,12 @@ impl IngestionEnricher {
|
|||||||
) -> Result<CreateChatCompletionRequest, AppError> {
|
) -> Result<CreateChatCompletionRequest, AppError> {
|
||||||
let settings = SystemSettings::get_current(&self.db_client).await?;
|
let settings = SystemSettings::get_current(&self.db_client).await?;
|
||||||
|
|
||||||
let entities_json = format_entities_json(similar_entities);
|
let entities_json = retrieved_entities_to_json(similar_entities);
|
||||||
|
|
||||||
let user_message = format!(
|
let user_message = format!(
|
||||||
"Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}"
|
"Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}"
|
||||||
);
|
);
|
||||||
|
|
||||||
debug!("Prepared LLM request message: {}", user_message);
|
|
||||||
|
|
||||||
let response_format = ResponseFormat::JsonSchema {
|
let response_format = ResponseFormat::JsonSchema {
|
||||||
json_schema: ResponseFormatJsonSchema {
|
json_schema: ResponseFormatJsonSchema {
|
||||||
description: Some("Structured analysis of the submitted content".into()),
|
description: Some("Structured analysis of the submitted content".into()),
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::{sync::Arc, time::Instant};
|
|||||||
|
|
||||||
use text_splitter::TextSplitter;
|
use text_splitter::TextSplitter;
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
use tracing::{info, info_span, warn};
|
use tracing::{debug, info, info_span, warn};
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
@@ -67,6 +67,34 @@ impl IngestionPipeline {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let text_len = text_content.text.chars().count();
|
||||||
|
let preview: String = text_content.text.chars().take(120).collect();
|
||||||
|
let preview_clean = preview.replace("\n", " ");
|
||||||
|
let preview_len = preview_clean.chars().count();
|
||||||
|
let truncated = text_len > preview_len;
|
||||||
|
let context_len = text_content
|
||||||
|
.context
|
||||||
|
.as_ref()
|
||||||
|
.map(|c| c.chars().count())
|
||||||
|
.unwrap_or(0);
|
||||||
|
info!(
|
||||||
|
%task_id,
|
||||||
|
attempt,
|
||||||
|
user_id = %text_content.user_id,
|
||||||
|
category = %text_content.category,
|
||||||
|
text_chars = text_len,
|
||||||
|
context_chars = context_len,
|
||||||
|
attachments = text_content.file_info.is_some(),
|
||||||
|
"ingestion task input ready"
|
||||||
|
);
|
||||||
|
debug!(
|
||||||
|
%task_id,
|
||||||
|
attempt,
|
||||||
|
preview = %preview_clean,
|
||||||
|
preview_truncated = truncated,
|
||||||
|
"ingestion task input preview"
|
||||||
|
);
|
||||||
|
|
||||||
match self.process(&text_content).await {
|
match self.process(&text_content).await {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
processing_task.mark_succeeded(&self.db).await?;
|
processing_task.mark_succeeded(&self.db).await?;
|
||||||
|
|||||||
@@ -132,9 +132,7 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
|
|||||||
let mut captures = Vec::with_capacity(pages.len());
|
let mut captures = Vec::with_capacity(pages.len());
|
||||||
|
|
||||||
for (idx, page) in pages.iter().enumerate() {
|
for (idx, page) in pages.iter().enumerate() {
|
||||||
let target = format!(
|
let target = format!("{file_url}#page={page}&toolbar=0&statusbar=0&zoom=page-fit");
|
||||||
"{file_url}#page={page}&toolbar=0&statusbar=0&zoom=page-fit"
|
|
||||||
);
|
|
||||||
tab.navigate_to(&target)
|
tab.navigate_to(&target)
|
||||||
.map_err(|err| AppError::Processing(format!("Failed to navigate to PDF page: {err}")))?
|
.map_err(|err| AppError::Processing(format!("Failed to navigate to PDF page: {err}")))?
|
||||||
.wait_until_navigated()
|
.wait_until_navigated()
|
||||||
@@ -480,11 +478,7 @@ fn is_structural_line(line: &str) -> bool {
|
|||||||
|| line.starts_with('~')
|
|| line.starts_with('~')
|
||||||
|| line.starts_with("| ")
|
|| line.starts_with("| ")
|
||||||
|| line.starts_with("+-")
|
|| line.starts_with("+-")
|
||||||
|| lowered
|
|| lowered.chars().next().is_some_and(|c| c.is_ascii_digit()) && lowered.contains('.')
|
||||||
.chars()
|
|
||||||
.next()
|
|
||||||
.is_some_and(|c| c.is_ascii_digit())
|
|
||||||
&& lowered.contains('.')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn debug_dump_directory() -> Option<PathBuf> {
|
fn debug_dump_directory() -> Option<PathBuf> {
|
||||||
|
|||||||
Reference in New Issue
Block a user