diff --git a/Cargo.lock b/Cargo.lock index 1a0e354..7a08320 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1345,6 +1345,7 @@ dependencies = [ "futures", "serde", "serde_json", + "state-machines", "surrealdb", "thiserror 1.0.69", "tokio", diff --git a/composite-retrieval/Cargo.toml b/composite-retrieval/Cargo.toml index 54eb07d..8ec509c 100644 --- a/composite-retrieval/Cargo.toml +++ b/composite-retrieval/Cargo.toml @@ -21,3 +21,4 @@ async-openai = { workspace = true } uuid = { workspace = true } common = { path = "../common", features = ["test-utils"] } +state-machines = { workspace = true } diff --git a/composite-retrieval/src/answer_retrieval.rs b/composite-retrieval/src/answer_retrieval.rs index 9c61149..c8c9993 100644 --- a/composite-retrieval/src/answer_retrieval.rs +++ b/composite-retrieval/src/answer_retrieval.rs @@ -17,9 +17,9 @@ use common::{ }, }; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::Value; -use crate::{retrieve_entities, RetrievedEntity}; +use crate::{retrieve_entities, retrieved_entities_to_json}; use super::answer_retrieval_helper::get_query_response_schema; @@ -65,7 +65,7 @@ pub async fn get_answer_with_references( let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?; let settings = SystemSettings::get_current(surreal_db_client).await?; - let entities_json = format_entities_json(&entities); + let entities_json = retrieved_entities_to_json(&entities); let user_message = create_user_message(&entities_json, query); let request = create_chat_request(user_message, &settings)?; @@ -83,31 +83,6 @@ pub async fn get_answer_with_references( }) } -pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value { - json!(entities - .iter() - .map(|entry| { - json!({ - "KnowledgeEntity": { - "id": entry.entity.id, - "name": entry.entity.name, - "description": entry.entity.description, - "score": round_score(entry.score), - "chunks": entry.chunks.iter().map(|chunk| { - json!({ - "score": round_score(chunk.score), - "content": chunk.chunk.chunk - }) - }).collect::>() - } - }) - }) - .collect::>()) -} - -fn round_score(value: f32) -> f64 { - (f64::from(value) * 1000.0).round() / 1000.0 -} pub fn create_user_message(entities_json: &Value, query: &str) -> String { format!( r" diff --git a/composite-retrieval/src/lib.rs b/composite-retrieval/src/lib.rs index d5fcc2b..207d48a 100644 --- a/composite-retrieval/src/lib.rs +++ b/composite-retrieval/src/lib.rs @@ -2,44 +2,20 @@ pub mod answer_retrieval; pub mod answer_retrieval_helper; pub mod fts; pub mod graph; +pub mod pipeline; pub mod scoring; pub mod vector; -use std::collections::{HashMap, HashSet}; - use common::{ error::AppError, storage::{ db::SurrealDbClient, - types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, + types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, }, - utils::embedding::generate_embedding, }; -use futures::{stream::FuturesUnordered, StreamExt}; -use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}; -use scoring::{ - clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc, - FusionWeights, Scored, -}; -use tracing::{debug, instrument, trace}; +use tracing::instrument; -use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding}; - -// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping. -const ENTITY_VECTOR_TAKE: usize = 15; -const CHUNK_VECTOR_TAKE: usize = 20; -const ENTITY_FTS_TAKE: usize = 10; -const CHUNK_FTS_TAKE: usize = 20; -const SCORE_THRESHOLD: f32 = 0.35; -const FALLBACK_MIN_RESULTS: usize = 10; -const TOKEN_BUDGET_ESTIMATE: usize = 2800; -const AVG_CHARS_PER_TOKEN: usize = 4; -const MAX_CHUNKS_PER_ENTITY: usize = 4; -const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5; -const GRAPH_NEIGHBOR_LIMIT: usize = 6; -const GRAPH_SCORE_DECAY: f32 = 0.75; -const GRAPH_SEED_MIN_SCORE: f32 = 0.4; -const GRAPH_VECTOR_INHERITANCE: f32 = 0.6; +pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning}; // Captures a supporting chunk plus its fused retrieval score for downstream prompts. #[derive(Debug, Clone)] @@ -56,435 +32,34 @@ pub struct RetrievedEntity { pub chunks: Vec, } +// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text #[instrument(skip_all, fields(user_id))] pub async fn retrieve_entities( db_client: &SurrealDbClient, openai_client: &async_openai::Client, - query: &str, + input_text: &str, user_id: &str, ) -> Result, AppError> { - trace!("Generating query embedding for hybrid retrieval"); - let query_embedding = generate_embedding(openai_client, query, db_client).await?; - retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await -} - -pub(crate) async fn retrieve_entities_with_embedding( - db_client: &SurrealDbClient, - query_embedding: Vec, - query: &str, - user_id: &str, -) -> Result, AppError> { - // 1) Gather first-pass candidates from vector search and BM25. - let weights = FusionWeights::default(); - - let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!( - find_items_by_vector_similarity_with_embedding( - ENTITY_VECTOR_TAKE, - query_embedding.clone(), - db_client, - "knowledge_entity", - user_id, - ), - find_items_by_vector_similarity_with_embedding( - CHUNK_VECTOR_TAKE, - query_embedding, - db_client, - "text_chunk", - user_id, - ), - find_items_by_fts( - ENTITY_FTS_TAKE, - query, - db_client, - "knowledge_entity", - user_id - ), - find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id), - )?; - - debug!( - vector_entities = vector_entities.len(), - vector_chunks = vector_chunks.len(), - fts_entities = fts_entities.len(), - fts_chunks = fts_chunks.len(), - "Hybrid retrieval initial candidate counts" - ); - - normalize_fts_scores(&mut fts_entities); - normalize_fts_scores(&mut fts_chunks); - - let mut entity_candidates: HashMap> = HashMap::new(); - let mut chunk_candidates: HashMap> = HashMap::new(); - - // Collate raw retrieval results so each ID accumulates all available signals. - merge_scored_by_id(&mut entity_candidates, vector_entities); - merge_scored_by_id(&mut entity_candidates, fts_entities); - merge_scored_by_id(&mut chunk_candidates, vector_chunks); - merge_scored_by_id(&mut chunk_candidates, fts_chunks); - - // 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph. - apply_fusion(&mut entity_candidates, weights); - apply_fusion(&mut chunk_candidates, weights); - enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?; - - // 3) Track high-signal chunk sources so we can backfill missing entities. - let chunk_by_source = group_chunks_by_source(&chunk_candidates); - let mut missing_sources = Vec::new(); - - for source_id in chunk_by_source.keys() { - if !entity_candidates - .values() - .any(|entity| entity.item.source_id == *source_id) - { - missing_sources.push(source_id.clone()); - } - } - - if !missing_sources.is_empty() { - let related_entities: Vec = find_entities_by_source_ids( - missing_sources.clone(), - "knowledge_entity", - user_id, - db_client, - ) - .await - .unwrap_or_default(); - - for entity in related_entities { - if let Some(chunks) = chunk_by_source.get(&entity.source_id) { - let best_chunk_score = chunks - .iter() - .map(|chunk| chunk.fused) - .fold(0.0f32, f32::max); - - let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score); - let fused = fuse_scores(&scored.scores, weights); - scored.update_fused(fused); - entity_candidates.insert(entity.id.clone(), scored); - } - } - } - - // Boost entities with evidence from high scoring chunks. - for entity in entity_candidates.values_mut() { - if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) { - let best_chunk_score = chunks - .iter() - .map(|chunk| chunk.fused) - .fold(0.0f32, f32::max); - - if best_chunk_score > 0.0 { - let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score); - entity.scores.vector = Some(boosted); - let fused = fuse_scores(&entity.scores, weights); - entity.update_fused(fused); - } - } - } - - let mut entity_results: Vec> = - entity_candidates.into_values().collect(); - sort_by_fused_desc(&mut entity_results); - - let mut filtered_entities: Vec> = entity_results - .iter() - .filter(|candidate| candidate.fused >= SCORE_THRESHOLD) - .cloned() - .collect(); - - if filtered_entities.len() < FALLBACK_MIN_RESULTS { - // Low recall scenarios still benefit from some context; take the top N regardless of score. - filtered_entities = entity_results - .into_iter() - .take(FALLBACK_MIN_RESULTS) - .collect(); - } - - // 4) Re-rank chunks and prepare for attachment to surviving entities. - let mut chunk_results: Vec> = chunk_candidates.into_values().collect(); - sort_by_fused_desc(&mut chunk_results); - - let mut chunk_by_id: HashMap> = HashMap::new(); - for chunk in chunk_results { - chunk_by_id.insert(chunk.item.id.clone(), chunk); - } - - enrich_chunks_from_entities( - &mut chunk_by_id, - &filtered_entities, + pipeline::run_pipeline( db_client, + openai_client, + input_text, user_id, - weights, + RetrievalConfig::default(), ) - .await?; - - let mut chunk_values: Vec> = chunk_by_id.into_values().collect(); - sort_by_fused_desc(&mut chunk_values); - - Ok(assemble_results(filtered_entities, chunk_values)) -} - -// Minimal record used while seeding graph expansion so we can retain the original fused score. -#[derive(Clone)] -struct GraphSeed { - id: String, - fused: f32, -} - -async fn enrich_entities_from_graph( - entity_candidates: &mut HashMap>, - db_client: &SurrealDbClient, - user_id: &str, - weights: FusionWeights, -) -> Result<(), AppError> { - if entity_candidates.is_empty() { - return Ok(()); - } - - // Select a small frontier of high-confidence entities to seed the relationship walk. - let mut seeds: Vec = entity_candidates - .values() - .filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE) - .map(|entity| GraphSeed { - id: entity.item.id.clone(), - fused: entity.fused, - }) - .collect(); - - if seeds.is_empty() { - return Ok(()); - } - - // Prioritise the strongest seeds so we explore the most grounded context first. - seeds.sort_by(|a, b| { - b.fused - .partial_cmp(&a.fused) - .unwrap_or(std::cmp::Ordering::Equal) - }); - seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT); - - let mut futures = FuturesUnordered::new(); - for seed in seeds.clone() { - let user_id = user_id.to_owned(); - futures.push(async move { - // Fetch neighbors concurrently to avoid serial graph round trips. - let neighbors = find_entities_by_relationship_by_id( - db_client, - &seed.id, - &user_id, - GRAPH_NEIGHBOR_LIMIT, - ) - .await; - (seed, neighbors) - }); - } - - while let Some((seed, neighbors_result)) = futures.next().await { - let neighbors = neighbors_result.map_err(AppError::from)?; - if neighbors.is_empty() { - continue; - } - - // Fold neighbors back into the candidate map and let them inherit attenuated signal. - for neighbor in neighbors { - if neighbor.id == seed.id { - continue; - } - - let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY); - let entry = entity_candidates - .entry(neighbor.id.clone()) - .or_insert_with(|| Scored::new(neighbor.clone())); - - entry.item = neighbor; - - let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE); - let vector_existing = entry.scores.vector.unwrap_or(0.0); - if inherited_vector > vector_existing { - entry.scores.vector = Some(inherited_vector); - } - - let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); - if graph_score > existing_graph || entry.scores.graph.is_none() { - entry.scores.graph = Some(graph_score); - } - - let fused = fuse_scores(&entry.scores, weights); - entry.update_fused(fused); - } - } - - Ok(()) -} - -fn normalize_fts_scores(results: &mut [Scored]) { - // Scale BM25 outputs into [0,1] to keep fusion weights predictable. - let raw_scores: Vec = results - .iter() - .map(|candidate| candidate.scores.fts.unwrap_or(0.0)) - .collect(); - - let normalized = min_max_normalize(&raw_scores); - for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) { - candidate.scores.fts = Some(normalized_score); - candidate.update_fused(0.0); - } -} - -fn apply_fusion(candidates: &mut HashMap>, weights: FusionWeights) -where - T: StoredObject, -{ - // Collapse individual signals into a single fused score used for ranking. - for candidate in candidates.values_mut() { - let fused = fuse_scores(&candidate.scores, weights); - candidate.update_fused(fused); - } -} - -fn group_chunks_by_source( - chunks: &HashMap>, -) -> HashMap>> { - // Preserve chunk candidates keyed by their originating source entity. - let mut by_source: HashMap>> = HashMap::new(); - - for chunk in chunks.values() { - by_source - .entry(chunk.item.source_id.clone()) - .or_default() - .push(chunk.clone()); - } - by_source -} - -async fn enrich_chunks_from_entities( - chunk_candidates: &mut HashMap>, - entities: &[Scored], - db_client: &SurrealDbClient, - user_id: &str, - weights: FusionWeights, -) -> Result<(), AppError> { - // Fetch additional chunks referenced by entities that survived the fusion stage. - let mut source_ids: HashSet = HashSet::new(); - for entity in entities { - source_ids.insert(entity.item.source_id.clone()); - } - - if source_ids.is_empty() { - return Ok(()); - } - - let chunks = find_entities_by_source_ids::( - source_ids.into_iter().collect(), - "text_chunk", - user_id, - db_client, - ) - .await?; - - let mut entity_score_lookup: HashMap = HashMap::new(); - // Cache fused scores per source so chunks inherit the strength of their parent entity. - for entity in entities { - entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused); - } - - for chunk in chunks { - // Ensure each chunk is represented so downstream selection sees the latest content. - let entry = chunk_candidates - .entry(chunk.id.clone()) - .or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0)); - - let entity_score = entity_score_lookup - .get(&chunk.source_id) - .copied() - .unwrap_or(0.0); - - // Lift chunk score toward the entity score so supporting evidence is prioritised. - entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8)); - let fused = fuse_scores(&entry.scores, weights); - entry.update_fused(fused); - entry.item = chunk; - } - - Ok(()) -} - -fn assemble_results( - entities: Vec>, - mut chunks: Vec>, -) -> Vec { - // Re-associate chunk candidates with their parent entity for ranked selection. - let mut chunk_by_source: HashMap>> = HashMap::new(); - for chunk in chunks.drain(..) { - chunk_by_source - .entry(chunk.item.source_id.clone()) - .or_default() - .push(chunk); - } - - for chunk_list in chunk_by_source.values_mut() { - sort_by_fused_desc(chunk_list); - } - - let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE; - let mut results = Vec::new(); - - for entity in entities { - // Attach best chunks first while respecting per-entity and global token caps. - let mut selected_chunks = Vec::new(); - if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) { - let mut per_entity_count = 0; - candidates.sort_by(|a, b| { - b.fused - .partial_cmp(&a.fused) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - for candidate in candidates.iter() { - if per_entity_count >= MAX_CHUNKS_PER_ENTITY { - break; - } - let estimated_tokens = estimate_tokens(&candidate.item.chunk); - if estimated_tokens > token_budget_remaining { - continue; - } - - token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); - per_entity_count += 1; - - selected_chunks.push(RetrievedChunk { - chunk: candidate.item.clone(), - score: candidate.fused, - }); - } - } - - results.push(RetrievedEntity { - entity: entity.item.clone(), - score: entity.fused, - chunks: selected_chunks, - }); - - if token_budget_remaining == 0 { - break; - } - } - - results -} - -fn estimate_tokens(text: &str) -> usize { - // Simple heuristic to avoid calling a tokenizer in hot code paths. - let chars = text.chars().count().max(1); - (chars / AVG_CHARS_PER_TOKEN).max(1) + .await } #[cfg(test)] mod tests { use super::*; + use async_openai::Client; use common::storage::types::{ - knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship, + knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, + knowledge_relationship::KnowledgeRelationship, + text_chunk::TextChunk, }; + use pipeline::RetrievalConfig; use uuid::Uuid; fn test_embedding() -> Vec { @@ -527,70 +102,46 @@ mod tests { COMMIT TRANSACTION;", ) .await - .expect("Failed to redefine vector indexes for tests"); + .expect("Failed to configure indices"); db } - async fn seed_test_data(db: &SurrealDbClient, user_id: &str) { - let entity_relevant = KnowledgeEntity::new( - "source_a".into(), - "Rust Concurrency Patterns".into(), - "Discussion about async concurrency in Rust.".into(), + #[tokio::test] + async fn test_retrieve_entities_with_embedding_basic_flow() { + let db = setup_test_db().await; + let user_id = "test_user"; + let entity = KnowledgeEntity::new( + "source_1".into(), + "Rust async guide".into(), + "Detailed notes about async runtimes".into(), KnowledgeEntityType::Document, None, entity_embedding_high(), user_id.into(), ); - let entity_irrelevant = KnowledgeEntity::new( - "source_b".into(), - "Python Tips".into(), - "General Python programming tips.".into(), - KnowledgeEntityType::Document, - None, - entity_embedding_low(), - user_id.into(), - ); - - db.store_item(entity_relevant.clone()) - .await - .expect("Failed to store relevant entity"); - db.store_item(entity_irrelevant.clone()) - .await - .expect("Failed to store irrelevant entity"); - - let chunk_primary = TextChunk::new( - entity_relevant.source_id.clone(), - "Tokio enables async concurrency with lightweight tasks.".into(), + let chunk = TextChunk::new( + entity.source_id.clone(), + "Tokio uses cooperative scheduling for fairness.".into(), chunk_embedding_primary(), user_id.into(), ); - let chunk_secondary = TextChunk::new( - entity_irrelevant.source_id.clone(), - "Python focuses on readability and dynamic typing.".into(), - chunk_embedding_secondary(), - user_id.into(), - ); - db.store_item(chunk_primary) + db.store_item(entity.clone()) .await - .expect("Failed to store primary chunk"); - db.store_item(chunk_secondary) + .expect("Failed to store entity"); + db.store_item(chunk.clone()) .await - .expect("Failed to store secondary chunk"); - } + .expect("Failed to store chunk"); - #[tokio::test] - async fn test_hybrid_retrieval_prioritises_relevant_entity() { - let db = setup_test_db().await; - let user_id = "user123"; - seed_test_data(&db, user_id).await; - - let results = retrieve_entities_with_embedding( + let openai_client = Client::new(); + let results = pipeline::run_pipeline_with_embedding( &db, + &openai_client, test_embedding(), "Rust concurrency async tasks", user_id, + RetrievalConfig::default(), ) .await .expect("Hybrid retrieval failed"); @@ -599,9 +150,7 @@ mod tests { !results.is_empty(), "Expected at least one retrieval result" ); - let top = &results[0]; - assert!( top.entity.name.contains("Rust"), "Expected Rust entity to be ranked first" @@ -610,16 +159,6 @@ mod tests { !top.chunks.is_empty(), "Expected Rust entity to include supporting chunks" ); - - let chunk_texts: Vec<&str> = top - .chunks - .iter() - .map(|chunk| chunk.chunk.chunk.as_str()) - .collect(); - assert!( - chunk_texts.iter().any(|text| text.contains("Tokio")), - "Expected chunk discussing Tokio to be included" - ); } #[tokio::test] @@ -673,6 +212,7 @@ mod tests { .await .expect("Failed to store neighbor chunk"); + let openai_client = Client::new(); let relationship = KnowledgeRelationship::new( primary.id.clone(), neighbor.id.clone(), @@ -685,11 +225,13 @@ mod tests { .await .expect("Failed to store relationship"); - let results = retrieve_entities_with_embedding( + let results = pipeline::run_pipeline_with_embedding( &db, + &openai_client, test_embedding(), "Rust concurrency async tasks", user_id, + RetrievalConfig::default(), ) .await .expect("Hybrid retrieval failed"); diff --git a/composite-retrieval/src/pipeline/config.rs b/composite-retrieval/src/pipeline/config.rs new file mode 100644 index 0000000..d461440 --- /dev/null +++ b/composite-retrieval/src/pipeline/config.rs @@ -0,0 +1,61 @@ +use serde::{Deserialize, Serialize}; + +/// Tunable parameters that govern each retrieval stage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetrievalTuning { + pub entity_vector_take: usize, + pub chunk_vector_take: usize, + pub entity_fts_take: usize, + pub chunk_fts_take: usize, + pub score_threshold: f32, + pub fallback_min_results: usize, + pub token_budget_estimate: usize, + pub avg_chars_per_token: usize, + pub max_chunks_per_entity: usize, + pub graph_traversal_seed_limit: usize, + pub graph_neighbor_limit: usize, + pub graph_score_decay: f32, + pub graph_seed_min_score: f32, + pub graph_vector_inheritance: f32, +} + +impl Default for RetrievalTuning { + fn default() -> Self { + Self { + entity_vector_take: 15, + chunk_vector_take: 20, + entity_fts_take: 10, + chunk_fts_take: 20, + score_threshold: 0.35, + fallback_min_results: 10, + token_budget_estimate: 2800, + avg_chars_per_token: 4, + max_chunks_per_entity: 4, + graph_traversal_seed_limit: 5, + graph_neighbor_limit: 6, + graph_score_decay: 0.75, + graph_seed_min_score: 0.4, + graph_vector_inheritance: 0.6, + } + } +} + +/// Wrapper containing tuning plus future flags for per-request overrides. +#[derive(Debug, Clone)] +pub struct RetrievalConfig { + pub tuning: RetrievalTuning, +} + +impl RetrievalConfig { + pub fn new(tuning: RetrievalTuning) -> Self { + Self { tuning } + } +} + +impl Default for RetrievalConfig { + fn default() -> Self { + Self { + tuning: RetrievalTuning::default(), + } + } +} diff --git a/composite-retrieval/src/pipeline/mod.rs b/composite-retrieval/src/pipeline/mod.rs new file mode 100644 index 0000000..15fc3b3 --- /dev/null +++ b/composite-retrieval/src/pipeline/mod.rs @@ -0,0 +1,100 @@ +mod config; +mod stages; +mod state; + +pub use config::{RetrievalConfig, RetrievalTuning}; + +use crate::RetrievedEntity; +use async_openai::Client; +use common::{error::AppError, storage::db::SurrealDbClient}; +use tracing::info; + +/// Drives the retrieval pipeline from embedding through final assembly. +pub async fn run_pipeline( + db_client: &SurrealDbClient, + openai_client: &Client, + input_text: &str, + user_id: &str, + config: RetrievalConfig, +) -> Result, AppError> { + let machine = state::ready(); + let input_chars = input_text.chars().count(); + let input_preview: String = input_text.chars().take(120).collect(); + let input_preview_clean = input_preview.replace('\n', " "); + let preview_len = input_preview_clean.chars().count(); + info!( + %user_id, + input_chars, + preview_truncated = input_chars > preview_len, + preview = %input_preview_clean, + "Starting ingestion retrieval pipeline" + ); + let mut ctx = stages::PipelineContext::new( + db_client, + openai_client, + input_text.to_owned(), + user_id.to_owned(), + config, + ); + let machine = stages::embed(machine, &mut ctx).await?; + let machine = stages::collect_candidates(machine, &mut ctx).await?; + let machine = stages::expand_graph(machine, &mut ctx).await?; + let machine = stages::attach_chunks(machine, &mut ctx).await?; + let results = stages::assemble(machine, &mut ctx)?; + + Ok(results) +} + +#[cfg(test)] +pub async fn run_pipeline_with_embedding( + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Vec, + input_text: &str, + user_id: &str, + config: RetrievalConfig, +) -> Result, AppError> { + let machine = state::ready(); + let mut ctx = stages::PipelineContext::with_embedding( + db_client, + openai_client, + query_embedding, + input_text.to_owned(), + user_id.to_owned(), + config, + ); + let machine = stages::embed(machine, &mut ctx).await?; + let machine = stages::collect_candidates(machine, &mut ctx).await?; + let machine = stages::expand_graph(machine, &mut ctx).await?; + let machine = stages::attach_chunks(machine, &mut ctx).await?; + let results = stages::assemble(machine, &mut ctx)?; + + Ok(results) +} + +/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON. +pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value { + serde_json::json!(entities + .iter() + .map(|entry| { + serde_json::json!({ + "KnowledgeEntity": { + "id": entry.entity.id, + "name": entry.entity.name, + "description": entry.entity.description, + "score": round_score(entry.score), + "chunks": entry.chunks.iter().map(|chunk| { + serde_json::json!({ + "score": round_score(chunk.score), + "content": chunk.chunk.chunk + }) + }).collect::>() + } + }) + }) + .collect::>()) +} + +fn round_score(value: f32) -> f64 { + (f64::from(value) * 1000.0).round() / 1000.0 +} diff --git a/composite-retrieval/src/pipeline/stages/mod.rs b/composite-retrieval/src/pipeline/stages/mod.rs new file mode 100644 index 0000000..3d33f26 --- /dev/null +++ b/composite-retrieval/src/pipeline/stages/mod.rs @@ -0,0 +1,599 @@ +use async_openai::Client; +use common::{ + error::AppError, + storage::{ + db::SurrealDbClient, + types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, + }, + utils::embedding::generate_embedding, +}; +use futures::{stream::FuturesUnordered, StreamExt}; +use state_machines::core::GuardError; +use std::collections::{HashMap, HashSet}; +use tracing::{debug, instrument, warn}; + +use crate::{ + fts::find_items_by_fts, + graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, + scoring::{ + clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc, + FusionWeights, Scored, + }, + vector::find_items_by_vector_similarity_with_embedding, + RetrievedChunk, RetrievedEntity, +}; + +use super::{ + config::RetrievalConfig, + state::{ + CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready, + }, +}; + +pub struct PipelineContext<'a> { + pub db_client: &'a SurrealDbClient, + pub openai_client: &'a Client, + pub input_text: String, + pub user_id: String, + pub config: RetrievalConfig, + pub query_embedding: Option>, + pub entity_candidates: HashMap>, + pub chunk_candidates: HashMap>, + pub filtered_entities: Vec>, + pub chunk_values: Vec>, +} + +impl<'a> PipelineContext<'a> { + pub fn new( + db_client: &'a SurrealDbClient, + openai_client: &'a Client, + input_text: String, + user_id: String, + config: RetrievalConfig, + ) -> Self { + Self { + db_client, + openai_client, + input_text, + user_id, + config, + query_embedding: None, + entity_candidates: HashMap::new(), + chunk_candidates: HashMap::new(), + filtered_entities: Vec::new(), + chunk_values: Vec::new(), + } + } + + #[cfg(test)] + pub fn with_embedding( + db_client: &'a SurrealDbClient, + openai_client: &'a Client, + query_embedding: Vec, + input_text: String, + user_id: String, + config: RetrievalConfig, + ) -> Self { + let mut ctx = Self::new(db_client, openai_client, input_text, user_id, config); + ctx.query_embedding = Some(query_embedding); + ctx + } + + fn ensure_embedding(&self) -> Result<&Vec, AppError> { + self.query_embedding.as_ref().ok_or_else(|| { + AppError::InternalError( + "query embedding missing before candidate collection".to_string(), + ) + }) + } +} + +#[instrument(level = "trace", skip_all)] +pub async fn embed( + machine: HybridRetrievalMachine<(), Ready>, + ctx: &mut PipelineContext<'_>, +) -> Result, AppError> { + let embedding_cached = ctx.query_embedding.is_some(); + if embedding_cached { + debug!("Reusing cached query embedding for hybrid retrieval"); + } else { + debug!("Generating query embedding for hybrid retrieval"); + let embedding = + generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?; + ctx.query_embedding = Some(embedding); + } + + machine + .embed() + .map_err(|(_, guard)| map_guard_error("embed", guard)) +} + +#[instrument(level = "trace", skip_all)] +pub async fn collect_candidates( + machine: HybridRetrievalMachine<(), Embedded>, + ctx: &mut PipelineContext<'_>, +) -> Result, AppError> { + debug!("Collecting initial candidates via vector and FTS search"); + let embedding = ctx.ensure_embedding()?.clone(); + let tuning = &ctx.config.tuning; + + let weights = FusionWeights::default(); + + let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!( + find_items_by_vector_similarity_with_embedding( + tuning.entity_vector_take, + embedding.clone(), + ctx.db_client, + "knowledge_entity", + &ctx.user_id, + ), + find_items_by_vector_similarity_with_embedding( + tuning.chunk_vector_take, + embedding, + ctx.db_client, + "text_chunk", + &ctx.user_id, + ), + find_items_by_fts( + tuning.entity_fts_take, + &ctx.input_text, + ctx.db_client, + "knowledge_entity", + &ctx.user_id, + ), + find_items_by_fts( + tuning.chunk_fts_take, + &ctx.input_text, + ctx.db_client, + "text_chunk", + &ctx.user_id + ), + )?; + + debug!( + vector_entities = vector_entities.len(), + vector_chunks = vector_chunks.len(), + fts_entities = fts_entities.len(), + fts_chunks = fts_chunks.len(), + "Hybrid retrieval initial candidate counts" + ); + + normalize_fts_scores(&mut fts_entities); + normalize_fts_scores(&mut fts_chunks); + + merge_scored_by_id(&mut ctx.entity_candidates, vector_entities); + merge_scored_by_id(&mut ctx.entity_candidates, fts_entities); + merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks); + merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks); + + apply_fusion(&mut ctx.entity_candidates, weights); + apply_fusion(&mut ctx.chunk_candidates, weights); + + machine + .collect_candidates() + .map_err(|(_, guard)| map_guard_error("collect_candidates", guard)) +} + +#[instrument(level = "trace", skip_all)] +pub async fn expand_graph( + machine: HybridRetrievalMachine<(), CandidatesLoaded>, + ctx: &mut PipelineContext<'_>, +) -> Result, AppError> { + debug!("Expanding candidates using graph relationships"); + let tuning = &ctx.config.tuning; + let weights = FusionWeights::default(); + + if ctx.entity_candidates.is_empty() { + return machine + .expand_graph() + .map_err(|(_, guard)| map_guard_error("expand_graph", guard)); + } + + let graph_seeds = seeds_from_candidates( + &ctx.entity_candidates, + tuning.graph_seed_min_score, + tuning.graph_traversal_seed_limit, + ); + + if graph_seeds.is_empty() { + return machine + .expand_graph() + .map_err(|(_, guard)| map_guard_error("expand_graph", guard)); + } + + let mut futures = FuturesUnordered::new(); + for seed in graph_seeds { + let db = ctx.db_client; + let user = ctx.user_id.clone(); + futures.push(async move { + let neighbors = find_entities_by_relationship_by_id( + db, + &seed.id, + &user, + tuning.graph_neighbor_limit, + ) + .await; + (seed, neighbors) + }); + } + + while let Some((seed, neighbors_result)) = futures.next().await { + let neighbors = neighbors_result.map_err(AppError::from)?; + if neighbors.is_empty() { + continue; + } + + for neighbor in neighbors { + if neighbor.id == seed.id { + continue; + } + + let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); + let entry = ctx + .entity_candidates + .entry(neighbor.id.clone()) + .or_insert_with(|| Scored::new(neighbor.clone())); + + entry.item = neighbor; + + let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance); + let vector_existing = entry.scores.vector.unwrap_or(0.0); + if inherited_vector > vector_existing { + entry.scores.vector = Some(inherited_vector); + } + + let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); + if graph_score > existing_graph || entry.scores.graph.is_none() { + entry.scores.graph = Some(graph_score); + } + + let fused = fuse_scores(&entry.scores, weights); + entry.update_fused(fused); + } + } + + machine + .expand_graph() + .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) +} + +#[instrument(level = "trace", skip_all)] +pub async fn attach_chunks( + machine: HybridRetrievalMachine<(), GraphExpanded>, + ctx: &mut PipelineContext<'_>, +) -> Result, AppError> { + debug!("Attaching chunks to surviving entities"); + let tuning = &ctx.config.tuning; + let weights = FusionWeights::default(); + + let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates); + + backfill_entities_from_chunks( + &mut ctx.entity_candidates, + &chunk_by_source, + ctx.db_client, + &ctx.user_id, + weights, + ) + .await?; + + boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights); + + let mut entity_results: Vec> = + ctx.entity_candidates.values().cloned().collect(); + sort_by_fused_desc(&mut entity_results); + + let mut filtered_entities: Vec> = entity_results + .iter() + .filter(|candidate| candidate.fused >= tuning.score_threshold) + .cloned() + .collect(); + + if filtered_entities.len() < tuning.fallback_min_results { + filtered_entities = entity_results + .into_iter() + .take(tuning.fallback_min_results) + .collect(); + } + + ctx.filtered_entities = filtered_entities; + + let mut chunk_results: Vec> = + ctx.chunk_candidates.values().cloned().collect(); + sort_by_fused_desc(&mut chunk_results); + + let mut chunk_by_id: HashMap> = HashMap::new(); + for chunk in chunk_results { + chunk_by_id.insert(chunk.item.id.clone(), chunk); + } + + enrich_chunks_from_entities( + &mut chunk_by_id, + &ctx.filtered_entities, + ctx.db_client, + &ctx.user_id, + weights, + ) + .await?; + + let mut chunk_values: Vec> = chunk_by_id.into_values().collect(); + sort_by_fused_desc(&mut chunk_values); + + ctx.chunk_values = chunk_values; + + machine + .attach_chunks() + .map_err(|(_, guard)| map_guard_error("attach_chunks", guard)) +} + +#[instrument(level = "trace", skip_all)] +pub fn assemble( + machine: HybridRetrievalMachine<(), ChunksAttached>, + ctx: &mut PipelineContext<'_>, +) -> Result, AppError> { + debug!("Assembling final retrieved entities"); + let tuning = &ctx.config.tuning; + + let mut chunk_by_source: HashMap>> = HashMap::new(); + for chunk in ctx.chunk_values.drain(..) { + chunk_by_source + .entry(chunk.item.source_id.clone()) + .or_default() + .push(chunk); + } + + for chunk_list in chunk_by_source.values_mut() { + sort_by_fused_desc(chunk_list); + } + + let mut token_budget_remaining = tuning.token_budget_estimate; + let mut results = Vec::new(); + + for entity in &ctx.filtered_entities { + let mut selected_chunks = Vec::new(); + if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) { + let mut per_entity_count = 0; + candidates.sort_by(|a, b| { + b.fused + .partial_cmp(&a.fused) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + for candidate in candidates.iter() { + if per_entity_count >= tuning.max_chunks_per_entity { + break; + } + let estimated_tokens = + estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token); + if estimated_tokens > token_budget_remaining { + continue; + } + + token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); + per_entity_count += 1; + + selected_chunks.push(RetrievedChunk { + chunk: candidate.item.clone(), + score: candidate.fused, + }); + } + } + + results.push(RetrievedEntity { + entity: entity.item.clone(), + score: entity.fused, + chunks: selected_chunks, + }); + + if token_budget_remaining == 0 { + break; + } + } + + machine + .assemble() + .map_err(|(_, guard)| map_guard_error("assemble", guard))?; + Ok(results) +} + +fn map_guard_error(stage: &'static str, err: GuardError) -> AppError { + AppError::InternalError(format!( + "state machine guard '{stage}' failed: guard={}, event={}, kind={:?}", + err.guard, err.event, err.kind + )) +} +fn normalize_fts_scores(results: &mut [Scored]) { + let raw_scores: Vec = results + .iter() + .map(|candidate| candidate.scores.fts.unwrap_or(0.0)) + .collect(); + + let normalized = min_max_normalize(&raw_scores); + for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) { + candidate.scores.fts = Some(normalized_score); + candidate.update_fused(0.0); + } +} + +fn apply_fusion(candidates: &mut HashMap>, weights: FusionWeights) +where + T: StoredObject, +{ + for candidate in candidates.values_mut() { + let fused = fuse_scores(&candidate.scores, weights); + candidate.update_fused(fused); + } +} + +fn group_chunks_by_source( + chunks: &HashMap>, +) -> HashMap>> { + let mut by_source: HashMap>> = HashMap::new(); + + for chunk in chunks.values() { + by_source + .entry(chunk.item.source_id.clone()) + .or_default() + .push(chunk.clone()); + } + by_source +} + +async fn backfill_entities_from_chunks( + entity_candidates: &mut HashMap>, + chunk_by_source: &HashMap>>, + db_client: &SurrealDbClient, + user_id: &str, + weights: FusionWeights, +) -> Result<(), AppError> { + let mut missing_sources = Vec::new(); + + for source_id in chunk_by_source.keys() { + if !entity_candidates + .values() + .any(|entity| entity.item.source_id == *source_id) + { + missing_sources.push(source_id.clone()); + } + } + + if missing_sources.is_empty() { + return Ok(()); + } + + let related_entities: Vec = find_entities_by_source_ids( + missing_sources.clone(), + "knowledge_entity", + user_id, + db_client, + ) + .await + .unwrap_or_default(); + + if related_entities.is_empty() { + warn!("expected related entities for missing chunk sources, but none were found"); + } + + for entity in related_entities { + if let Some(chunks) = chunk_by_source.get(&entity.source_id) { + let best_chunk_score = chunks + .iter() + .map(|chunk| chunk.fused) + .fold(0.0f32, f32::max); + + let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score); + let fused = fuse_scores(&scored.scores, weights); + scored.update_fused(fused); + entity_candidates.insert(entity.id.clone(), scored); + } + } + + Ok(()) +} + +fn boost_entities_with_chunks( + entity_candidates: &mut HashMap>, + chunk_by_source: &HashMap>>, + weights: FusionWeights, +) { + for entity in entity_candidates.values_mut() { + if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) { + let best_chunk_score = chunks + .iter() + .map(|chunk| chunk.fused) + .fold(0.0f32, f32::max); + + if best_chunk_score > 0.0 { + let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score); + entity.scores.vector = Some(boosted); + let fused = fuse_scores(&entity.scores, weights); + entity.update_fused(fused); + } + } + } +} + +async fn enrich_chunks_from_entities( + chunk_candidates: &mut HashMap>, + entities: &[Scored], + db_client: &SurrealDbClient, + user_id: &str, + weights: FusionWeights, +) -> Result<(), AppError> { + let mut source_ids: HashSet = HashSet::new(); + for entity in entities { + source_ids.insert(entity.item.source_id.clone()); + } + + if source_ids.is_empty() { + return Ok(()); + } + + let chunks = find_entities_by_source_ids::( + source_ids.into_iter().collect(), + "text_chunk", + user_id, + db_client, + ) + .await?; + + let mut entity_score_lookup: HashMap = HashMap::new(); + for entity in entities { + entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused); + } + + for chunk in chunks { + let entry = chunk_candidates + .entry(chunk.id.clone()) + .or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0)); + + let entity_score = entity_score_lookup + .get(&chunk.source_id) + .copied() + .unwrap_or(0.0); + + entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8)); + let fused = fuse_scores(&entry.scores, weights); + entry.update_fused(fused); + entry.item = chunk; + } + + Ok(()) +} + +fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize { + let chars = text.chars().count().max(1); + (chars / avg_chars_per_token).max(1) +} + +#[derive(Clone)] +struct GraphSeed { + id: String, + fused: f32, +} + +fn seeds_from_candidates( + entity_candidates: &HashMap>, + min_score: f32, + limit: usize, +) -> Vec { + let mut seeds: Vec = entity_candidates + .values() + .filter(|entity| entity.fused >= min_score) + .map(|entity| GraphSeed { + id: entity.item.id.clone(), + fused: entity.fused, + }) + .collect(); + + seeds.sort_by(|a, b| { + b.fused + .partial_cmp(&a.fused) + .unwrap_or(std::cmp::Ordering::Equal) + }); + if seeds.len() > limit { + seeds.truncate(limit); + } + + seeds +} diff --git a/composite-retrieval/src/pipeline/state.rs b/composite-retrieval/src/pipeline/state.rs new file mode 100644 index 0000000..fb605c8 --- /dev/null +++ b/composite-retrieval/src/pipeline/state.rs @@ -0,0 +1,25 @@ +use state_machines::state_machine; + +state_machine! { + name: HybridRetrievalMachine, + state: HybridRetrievalState, + initial: Ready, + states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Completed, Failed], + events { + embed { transition: { from: Ready, to: Embedded } } + collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } } + expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } } + attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } } + assemble { transition: { from: ChunksAttached, to: Completed } } + abort { + transition: { from: Ready, to: Failed } + transition: { from: CandidatesLoaded, to: Failed } + transition: { from: GraphExpanded, to: Failed } + transition: { from: ChunksAttached, to: Failed } + } + } +} + +pub fn ready() -> HybridRetrievalMachine<(), Ready> { + HybridRetrievalMachine::new(()) +} diff --git a/html-router/src/routes/account/handlers.rs b/html-router/src/routes/account/handlers.rs index 34f6586..7a246c1 100644 --- a/html-router/src/routes/account/handlers.rs +++ b/html-router/src/routes/account/handlers.rs @@ -24,7 +24,10 @@ pub async fn show_account_page( RequireUser(user): RequireUser, State(state): State, ) -> Result { - let timezones = TZ_VARIANTS.iter().map(std::string::ToString::to_string).collect(); + let timezones = TZ_VARIANTS + .iter() + .map(std::string::ToString::to_string) + .collect(); let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?; Ok(TemplateResponse::new_template( @@ -102,7 +105,10 @@ pub async fn update_timezone( ..user.clone() }; - let timezones = TZ_VARIANTS.iter().map(std::string::ToString::to_string).collect(); + let timezones = TZ_VARIANTS + .iter() + .map(std::string::ToString::to_string) + .collect(); // Render the API key section block Ok(TemplateResponse::new_partial( diff --git a/html-router/src/routes/auth/signin.rs b/html-router/src/routes/auth/signin.rs index 0b3df71..edb7237 100644 --- a/html-router/src/routes/auth/signin.rs +++ b/html-router/src/routes/auth/signin.rs @@ -27,11 +27,15 @@ pub async fn show_signin_form( if auth.is_authenticated() { return Ok(TemplateResponse::redirect("/")); } - if boosted { Ok(TemplateResponse::new_partial( - "auth/signin_base.html", - "body", - (), - )) } else { Ok(TemplateResponse::new_template("auth/signin_base.html", ())) } + if boosted { + Ok(TemplateResponse::new_partial( + "auth/signin_base.html", + "body", + (), + )) + } else { + Ok(TemplateResponse::new_template("auth/signin_base.html", ())) + } } pub async fn authenticate_user( diff --git a/html-router/src/routes/auth/signup.rs b/html-router/src/routes/auth/signup.rs index 88e305f..5b5f5db 100644 --- a/html-router/src/routes/auth/signup.rs +++ b/html-router/src/routes/auth/signup.rs @@ -29,11 +29,15 @@ pub async fn show_signup_form( return Ok(TemplateResponse::redirect("/")); } - if boosted { Ok(TemplateResponse::new_partial( - "auth/signup_form.html", - "body", - (), - )) } else { Ok(TemplateResponse::new_template("auth/signup_form.html", ())) } + if boosted { + Ok(TemplateResponse::new_partial( + "auth/signup_form.html", + "body", + (), + )) + } else { + Ok(TemplateResponse::new_template("auth/signup_form.html", ())) + } } pub async fn process_signup_and_show_verification( diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index a9848f3..b6b05d1 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -9,11 +9,8 @@ use axum::{ }, }; use composite_retrieval::{ - answer_retrieval::{ - create_chat_request, create_user_message_with_history, format_entities_json, - LLMResponseFormat, - }, - retrieve_entities, + answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat}, + retrieve_entities, retrieved_entities_to_json, }; use futures::{ stream::{self, once}, @@ -136,7 +133,7 @@ pub async fn get_response_stream( }; // 3. Create the OpenAI request - let entities_json = format_entities_json(&entities); + let entities_json = retrieved_entities_to_json(&entities); let formatted_user_message = create_user_message_with_history(&entities_json, &history, &user_message.content); let settings = match SystemSettings::get_current(&state.db).await { @@ -260,7 +257,11 @@ pub async fn get_response_stream( .chain(stream::once(async move { if let Some(message) = rx_final.recv().await { // Don't send any event if references is empty - if message.references.as_ref().is_some_and(std::vec::Vec::is_empty) { + if message + .references + .as_ref() + .is_some_and(std::vec::Vec::is_empty) + { return Ok(Event::default().event("empty")); // This event won't be sent } diff --git a/ingestion-pipeline/src/enricher.rs b/ingestion-pipeline/src/enricher.rs index b2e2652..6a71cd4 100644 --- a/ingestion-pipeline/src/enricher.rs +++ b/ingestion-pipeline/src/enricher.rs @@ -9,10 +9,7 @@ use common::{ error::AppError, storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, }; -use composite_retrieval::{ - answer_retrieval::format_entities_json, retrieve_entities, RetrievedEntity, -}; -use tracing::{debug, info}; +use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity}; use crate::{ types::llm_enrichment_result::LLMEnrichmentResult, @@ -42,11 +39,9 @@ impl IngestionEnricher { text: &str, user_id: &str, ) -> Result { - info!("getting similar entitities"); let similar_entities = self .find_similar_entities(category, context, text, user_id) .await?; - info!("got similar entitities"); let llm_request = self .prepare_llm_request(category, context, text, &similar_entities) .await?; @@ -60,9 +55,8 @@ impl IngestionEnricher { text: &str, user_id: &str, ) -> Result, AppError> { - let input_text = format!( - "content: {text}, category: {category}, user_context: {context:?}" - ); + let input_text = + format!("content: {text}, category: {category}, user_context: {context:?}"); retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await } @@ -76,14 +70,12 @@ impl IngestionEnricher { ) -> Result { let settings = SystemSettings::get_current(&self.db_client).await?; - let entities_json = format_entities_json(similar_entities); + let entities_json = retrieved_entities_to_json(similar_entities); let user_message = format!( "Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}" ); - debug!("Prepared LLM request message: {}", user_message); - let response_format = ResponseFormat::JsonSchema { json_schema: ResponseFormatJsonSchema { description: Some("Structured analysis of the submitted content".into()), diff --git a/ingestion-pipeline/src/pipeline.rs b/ingestion-pipeline/src/pipeline.rs index 8c5e9ad..497a040 100644 --- a/ingestion-pipeline/src/pipeline.rs +++ b/ingestion-pipeline/src/pipeline.rs @@ -2,7 +2,7 @@ use std::{sync::Arc, time::Instant}; use text_splitter::TextSplitter; use tokio::time::{sleep, Duration}; -use tracing::{info, info_span, warn}; +use tracing::{debug, info, info_span, warn}; use common::{ error::AppError, @@ -67,6 +67,34 @@ impl IngestionPipeline { ) .await?; + let text_len = text_content.text.chars().count(); + let preview: String = text_content.text.chars().take(120).collect(); + let preview_clean = preview.replace("\n", " "); + let preview_len = preview_clean.chars().count(); + let truncated = text_len > preview_len; + let context_len = text_content + .context + .as_ref() + .map(|c| c.chars().count()) + .unwrap_or(0); + info!( + %task_id, + attempt, + user_id = %text_content.user_id, + category = %text_content.category, + text_chars = text_len, + context_chars = context_len, + attachments = text_content.file_info.is_some(), + "ingestion task input ready" + ); + debug!( + %task_id, + attempt, + preview = %preview_clean, + preview_truncated = truncated, + "ingestion task input preview" + ); + match self.process(&text_content).await { Ok(()) => { processing_task.mark_succeeded(&self.db).await?; diff --git a/ingestion-pipeline/src/utils/pdf_ingestion.rs b/ingestion-pipeline/src/utils/pdf_ingestion.rs index efe09e6..c4b7a27 100644 --- a/ingestion-pipeline/src/utils/pdf_ingestion.rs +++ b/ingestion-pipeline/src/utils/pdf_ingestion.rs @@ -132,9 +132,7 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result let mut captures = Vec::with_capacity(pages.len()); for (idx, page) in pages.iter().enumerate() { - let target = format!( - "{file_url}#page={page}&toolbar=0&statusbar=0&zoom=page-fit" - ); + let target = format!("{file_url}#page={page}&toolbar=0&statusbar=0&zoom=page-fit"); tab.navigate_to(&target) .map_err(|err| AppError::Processing(format!("Failed to navigate to PDF page: {err}")))? .wait_until_navigated() @@ -480,11 +478,7 @@ fn is_structural_line(line: &str) -> bool { || line.starts_with('~') || line.starts_with("| ") || line.starts_with("+-") - || lowered - .chars() - .next() - .is_some_and(|c| c.is_ascii_digit()) - && lowered.contains('.') + || lowered.chars().next().is_some_and(|c| c.is_ascii_digit()) && lowered.contains('.') } fn debug_dump_directory() -> Option {