diff --git a/Cargo.lock b/Cargo.lock index a298b77..870d08e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1457,26 +1457,6 @@ dependencies = [ "static_assertions", ] -[[package]] -name = "composite-retrieval" -version = "0.1.0" -dependencies = [ - "anyhow", - "async-openai", - "axum", - "common", - "fastembed", - "futures", - "serde", - "serde_json", - "state-machines", - "surrealdb", - "thiserror 1.0.69", - "tokio", - "tracing", - "uuid", -] - [[package]] name = "compression-codecs" version = "0.4.30" @@ -2197,7 +2177,6 @@ dependencies = [ "async-trait", "chrono", "common", - "composite-retrieval", "criterion", "fastembed", "futures", @@ -2205,6 +2184,7 @@ dependencies = [ "object_store 0.11.2", "once_cell", "rand 0.8.5", + "retrieval-pipeline", "serde", "serde_json", "serde_yaml", @@ -2880,7 +2860,6 @@ dependencies = [ "chrono", "chrono-tz", "common", - "composite-retrieval", "futures", "include_dir", "json-stream-parser", @@ -2888,6 +2867,7 @@ dependencies = [ "minijinja-autoreload", "minijinja-contrib", "minijinja-embed", + "retrieval-pipeline", "serde", "serde_json", "surrealdb", @@ -3342,13 +3322,13 @@ dependencies = [ "bytes", "chrono", "common", - "composite-retrieval", "dom_smoothie", "futures", "headless_chrome", "lopdf 0.32.0", "pdf-extract", "reqwest", + "retrieval-pipeline", "serde", "serde_json", "state-machines", @@ -3802,10 +3782,10 @@ dependencies = [ "async-openai", "axum", "common", - "composite-retrieval", "futures", "html-router", "ingestion-pipeline", + "retrieval-pipeline", "serde", "serde_json", "surrealdb", @@ -5475,6 +5455,26 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "retrieval-pipeline" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-openai", + "async-trait", + "axum", + "common", + "fastembed", + "futures", + "serde", + "serde_json", + "surrealdb", + "thiserror 1.0.69", + "tokio", + "tracing", + "uuid", +] + [[package]] name = "revision" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index febcc3a..a4eaec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ "api-router", "html-router", "ingestion-pipeline", - "composite-retrieval", + "retrieval-pipeline", "json-stream-parser", "eval" ] diff --git a/common/src/utils/config.rs b/common/src/utils/config.rs index 5129a75..01b08ce 100644 --- a/common/src/utils/config.rs +++ b/common/src/utils/config.rs @@ -54,6 +54,8 @@ pub struct AppConfig { pub fastembed_show_download_progress: Option, #[serde(default)] pub fastembed_max_length: Option, + #[serde(default)] + pub retrieval_strategy: Option, } fn default_data_dir() -> String { @@ -117,6 +119,7 @@ impl Default for AppConfig { fastembed_cache_dir: None, fastembed_show_download_progress: None, fastembed_max_length: None, + retrieval_strategy: None, } } } diff --git a/composite-retrieval/src/pipeline/mod.rs b/composite-retrieval/src/pipeline/mod.rs deleted file mode 100644 index fc6993f..0000000 --- a/composite-retrieval/src/pipeline/mod.rs +++ /dev/null @@ -1,212 +0,0 @@ -mod config; -mod diagnostics; -mod stages; -mod state; - -pub use config::{RetrievalConfig, RetrievalTuning}; -pub use diagnostics::{ - AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, - PipelineDiagnostics, -}; - -use crate::{reranking::RerankerLease, RetrievedEntity}; -use async_openai::Client; -use common::{error::AppError, storage::db::SurrealDbClient}; -use tracing::info; - -#[derive(Debug)] -pub struct PipelineRunOutput { - pub results: Vec, - pub diagnostics: Option, - pub stage_timings: PipelineStageTimings, -} - -#[derive(Debug, Clone, Default, serde::Serialize)] -pub struct PipelineStageTimings { - pub collect_candidates_ms: u128, - pub graph_expansion_ms: u128, - pub chunk_attach_ms: u128, - pub rerank_ms: u128, - pub assemble_ms: u128, -} - -impl PipelineStageTimings { - fn record_collect_candidates(&mut self, duration: std::time::Duration) { - self.collect_candidates_ms += duration.as_millis() as u128; - } - - fn record_graph_expansion(&mut self, duration: std::time::Duration) { - self.graph_expansion_ms += duration.as_millis() as u128; - } - - fn record_chunk_attach(&mut self, duration: std::time::Duration) { - self.chunk_attach_ms += duration.as_millis() as u128; - } - - fn record_rerank(&mut self, duration: std::time::Duration) { - self.rerank_ms += duration.as_millis() as u128; - } - - fn record_assemble(&mut self, duration: std::time::Duration) { - self.assemble_ms += duration.as_millis() as u128; - } -} - -/// Drives the retrieval pipeline from embedding through final assembly. -pub async fn run_pipeline( - db_client: &SurrealDbClient, - openai_client: &Client, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result, AppError> { - let input_chars = input_text.chars().count(); - let input_preview: String = input_text.chars().take(120).collect(); - let input_preview_clean = input_preview.replace('\n', " "); - let preview_len = input_preview_clean.chars().count(); - info!( - %user_id, - input_chars, - preview_truncated = input_chars > preview_len, - preview = %input_preview_clean, - "Starting ingestion retrieval pipeline" - ); - let ctx = stages::PipelineContext::new( - db_client, - openai_client, - input_text.to_owned(), - user_id.to_owned(), - config, - reranker, - ); - let outcome = run_pipeline_internal(ctx, false).await?; - - Ok(outcome.results) -} - -pub async fn run_pipeline_with_embedding( - db_client: &SurrealDbClient, - openai_client: &Client, - query_embedding: Vec, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result, AppError> { - let ctx = stages::PipelineContext::with_embedding( - db_client, - openai_client, - query_embedding, - input_text.to_owned(), - user_id.to_owned(), - config, - reranker, - ); - let outcome = run_pipeline_internal(ctx, false).await?; - - Ok(outcome.results) -} - -/// Runs the pipeline with a precomputed embedding and returns stage metrics. -pub async fn run_pipeline_with_embedding_with_metrics( - db_client: &SurrealDbClient, - openai_client: &Client, - query_embedding: Vec, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result { - let ctx = stages::PipelineContext::with_embedding( - db_client, - openai_client, - query_embedding, - input_text.to_owned(), - user_id.to_owned(), - config, - reranker, - ); - - run_pipeline_internal(ctx, false).await -} - -pub async fn run_pipeline_with_embedding_with_diagnostics( - db_client: &SurrealDbClient, - openai_client: &Client, - query_embedding: Vec, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result { - let ctx = stages::PipelineContext::with_embedding( - db_client, - openai_client, - query_embedding, - input_text.to_owned(), - user_id.to_owned(), - config, - reranker, - ); - - run_pipeline_internal(ctx, true).await -} - -/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON. -pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value { - serde_json::json!(entities - .iter() - .map(|entry| { - serde_json::json!({ - "KnowledgeEntity": { - "id": entry.entity.id, - "name": entry.entity.name, - "description": entry.entity.description, - "score": round_score(entry.score), - "chunks": entry.chunks.iter().map(|chunk| { - serde_json::json!({ - "score": round_score(chunk.score), - "content": chunk.chunk.chunk - }) - }).collect::>() - } - }) - }) - .collect::>()) -} - -async fn run_pipeline_internal( - mut ctx: stages::PipelineContext<'_>, - capture_diagnostics: bool, -) -> Result { - if capture_diagnostics { - ctx.enable_diagnostics(); - } - - let results = drive_pipeline(&mut ctx).await?; - let diagnostics = ctx.take_diagnostics(); - - Ok(PipelineRunOutput { - results, - diagnostics, - stage_timings: ctx.take_stage_timings(), - }) -} - -async fn drive_pipeline( - ctx: &mut stages::PipelineContext<'_>, -) -> Result, AppError> { - let machine = state::ready(); - let machine = stages::embed(machine, ctx).await?; - let machine = stages::collect_candidates(machine, ctx).await?; - let machine = stages::expand_graph(machine, ctx).await?; - let machine = stages::attach_chunks(machine, ctx).await?; - let machine = stages::rerank(machine, ctx).await?; - let results = stages::assemble(machine, ctx)?; - Ok(results) -} - -fn round_score(value: f32) -> f64 { - (f64::from(value) * 1000.0).round() / 1000.0 -} diff --git a/composite-retrieval/src/pipeline/state.rs b/composite-retrieval/src/pipeline/state.rs deleted file mode 100644 index 91d3803..0000000 --- a/composite-retrieval/src/pipeline/state.rs +++ /dev/null @@ -1,27 +0,0 @@ -use state_machines::state_machine; - -state_machine! { - name: HybridRetrievalMachine, - state: HybridRetrievalState, - initial: Ready, - states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed], - events { - embed { transition: { from: Ready, to: Embedded } } - collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } } - expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } } - attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } } - rerank { transition: { from: ChunksAttached, to: Reranked } } - assemble { transition: { from: Reranked, to: Completed } } - abort { - transition: { from: Ready, to: Failed } - transition: { from: CandidatesLoaded, to: Failed } - transition: { from: GraphExpanded, to: Failed } - transition: { from: ChunksAttached, to: Failed } - transition: { from: Reranked, to: Failed } - } - } -} - -pub fn ready() -> HybridRetrievalMachine<(), Ready> { - HybridRetrievalMachine::new(()) -} diff --git a/html-router/Cargo.toml b/html-router/Cargo.toml index 2503826..696147e 100644 --- a/html-router/Cargo.toml +++ b/html-router/Cargo.toml @@ -38,7 +38,7 @@ url = { workspace = true } uuid = { workspace = true } common = { path = "../common" } -composite-retrieval = { path = "../composite-retrieval" } +retrieval-pipeline = { path = "../retrieval-pipeline" } json-stream-parser = { path = "../json-stream-parser" } [build-dependencies] diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index a372098..c488738 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -1,7 +1,7 @@ use common::storage::{db::SurrealDbClient, store::StorageManager}; use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine}; use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig}; -use composite_retrieval::reranking::RerankerPool; +use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy}; use std::sync::Arc; use tracing::debug; @@ -40,6 +40,14 @@ impl HtmlState { reranker_pool, }) } + + pub fn retrieval_strategy(&self) -> RetrievalStrategy { + self.config + .retrieval_strategy + .as_deref() + .and_then(|value| value.parse().ok()) + .unwrap_or(RetrievalStrategy::Initial) + } } impl ProvidesDb for HtmlState { fn db(&self) -> &Arc { diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index 5e6f08b..5279ebd 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -8,16 +8,16 @@ use axum::{ Sse, }, }; -use composite_retrieval::{ - answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat}, - retrieve_entities, retrieved_entities_to_json, -}; use futures::{ stream::{self, once}, Stream, StreamExt, TryStreamExt, }; use json_stream_parser::JsonStreamParser; use minijinja::Value; +use retrieval_pipeline::{ + answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat}, + retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput, +}; use serde::{Deserialize, Serialize}; use serde_json::from_str; use tokio::sync::{mpsc::channel, Mutex}; @@ -123,16 +123,24 @@ pub async fn get_response_stream( None => None, }; + let mut retrieval_config = RetrievalConfig::default(); + retrieval_config.strategy = state.retrieval_strategy(); let entities = match retrieve_entities( &state.db, &state.openai_client, &user_message.content, &user.id, + retrieval_config, rerank_lease, ) .await { - Ok(entities) => entities, + Ok(StrategyOutput::Entities(entities)) => entities, + Ok(StrategyOutput::Chunks(_)) => { + return Sse::new(create_error_stream( + "Chunk-only retrieval results are not supported in this route", + )) + } Err(_e) => { return Sse::new(create_error_stream("Failed to retrieve knowledge entities")); } diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index b52cefa..01e3cb6 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -24,7 +24,7 @@ use common::{ }, utils::embedding::generate_embedding, }; -use composite_retrieval::{retrieve_entities, RetrievedEntity}; +use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput}; use tracing::debug; use uuid::Uuid; @@ -284,11 +284,15 @@ pub async fn suggest_knowledge_relationships( None => None, }; - if let Ok(results) = retrieve_entities( + let mut retrieval_config = RetrievalConfig::default(); + retrieval_config.strategy = state.retrieval_strategy(); + + if let Ok(StrategyOutput::Entities(results)) = retrieve_entities( &state.db, &state.openai_client, &query, &user.id, + retrieval_config, rerank_lease, ) .await diff --git a/ingestion-pipeline/Cargo.toml b/ingestion-pipeline/Cargo.toml index 124288a..1aea309 100644 --- a/ingestion-pipeline/Cargo.toml +++ b/ingestion-pipeline/Cargo.toml @@ -32,7 +32,7 @@ lopdf = "0.32" bytes = { workspace = true } common = { path = "../common" } -composite-retrieval = { path = "../composite-retrieval" } +retrieval-pipeline = { path = "../retrieval-pipeline" } async-trait = { workspace = true } state-machines = { workspace = true } [features] diff --git a/ingestion-pipeline/src/pipeline/context.rs b/ingestion-pipeline/src/pipeline/context.rs index 26ee9a1..8ccc3b5 100644 --- a/ingestion-pipeline/src/pipeline/context.rs +++ b/ingestion-pipeline/src/pipeline/context.rs @@ -11,7 +11,7 @@ use common::{ }, }, }; -use composite_retrieval::RetrievedEntity; +use retrieval_pipeline::RetrievedEntity; use tracing::error; use super::enrichment_result::LLMEnrichmentResult; diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs index 6355446..7b8becd 100644 --- a/ingestion-pipeline/src/pipeline/mod.rs +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -28,7 +28,7 @@ use common::{ }, utils::config::AppConfig, }; -use composite_retrieval::reranking::RerankerPool; +use retrieval_pipeline::reranking::RerankerPool; use tracing::{debug, info, warn}; use self::{ diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 719e463..3844f6b 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -19,8 +19,9 @@ use common::{ }, utils::{config::AppConfig, embedding::generate_embedding}, }; -use composite_retrieval::{ - reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievedEntity, +use retrieval_pipeline::{ + reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig, + RetrievalStrategy, RetrievedEntity, StrategyOutput, }; use text_splitter::TextSplitter; @@ -124,6 +125,14 @@ impl DefaultPipelineServices { Ok(request) } + fn configured_strategy(&self) -> RetrievalStrategy { + self.config + .retrieval_strategy + .as_deref() + .and_then(|value| value.parse().ok()) + .unwrap_or(RetrievalStrategy::Initial) + } + async fn perform_analysis( &self, request: CreateChatCompletionRequest, @@ -178,14 +187,24 @@ impl PipelineServices for DefaultPipelineServices { None => None, }; - retrieve_entities( + let mut config = RetrievalConfig::default(); + config.strategy = self.configured_strategy(); + match retrieve_entities( &self.db, &self.openai_client, &input_text, &content.user_id, + config, rerank_lease, ) .await + { + Ok(StrategyOutput::Entities(entities)) => Ok(entities), + Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError( + "Chunk-only retrieval is not supported in ingestion".into(), + )), + Err(err) => Err(err), + } } async fn run_enrichment( diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 41b2c67..cb6c206 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -16,7 +16,7 @@ use common::{ }, }, }; -use composite_retrieval::{RetrievedChunk, RetrievedEntity}; +use retrieval_pipeline::{RetrievedChunk, RetrievedEntity}; use tokio::sync::Mutex; use uuid::Uuid; diff --git a/main/Cargo.toml b/main/Cargo.toml index 87a0fb6..3673897 100644 --- a/main/Cargo.toml +++ b/main/Cargo.toml @@ -25,7 +25,7 @@ ingestion-pipeline = { path = "../ingestion-pipeline" } api-router = { path = "../api-router" } html-router = { path = "../html-router" } common = { path = "../common" } -composite-retrieval = { path = "../composite-retrieval" } +retrieval-pipeline = { path = "../retrieval-pipeline" } [dev-dependencies] tower = "0.5" diff --git a/main/src/main.rs b/main/src/main.rs index b2bd31c..21078d2 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -3,9 +3,9 @@ use axum::{extract::FromRef, Router}; use common::{ storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config, }; -use composite_retrieval::reranking::RerankerPool; use html_router::{html_routes, html_state::HtmlState}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; +use retrieval_pipeline::reranking::RerankerPool; use std::sync::Arc; use tracing::{error, info}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; diff --git a/main/src/server.rs b/main/src/server.rs index c5fe83f..c3fdd78 100644 --- a/main/src/server.rs +++ b/main/src/server.rs @@ -5,8 +5,8 @@ use axum::{extract::FromRef, Router}; use common::{ storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config, }; -use composite_retrieval::reranking::RerankerPool; use html_router::{html_routes, html_state::HtmlState}; +use retrieval_pipeline::reranking::RerankerPool; use tracing::info; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; diff --git a/main/src/worker.rs b/main/src/worker.rs index d56e9f6..dde9621 100644 --- a/main/src/worker.rs +++ b/main/src/worker.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use common::{ storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config, }; -use composite_retrieval::reranking::RerankerPool; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; +use retrieval_pipeline::reranking::RerankerPool; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; #[tokio::main] diff --git a/composite-retrieval/Cargo.toml b/retrieval-pipeline/Cargo.toml similarity index 89% rename from composite-retrieval/Cargo.toml rename to retrieval-pipeline/Cargo.toml index 1ef08a1..ad792ce 100644 --- a/composite-retrieval/Cargo.toml +++ b/retrieval-pipeline/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "composite-retrieval" +name = "retrieval-pipeline" version = "0.1.0" edition = "2021" license = "AGPL-3.0-or-later" @@ -18,8 +18,8 @@ serde_json = { workspace = true } surrealdb = { workspace = true } futures = { workspace = true } async-openai = { workspace = true } +async-trait = { workspace = true } uuid = { workspace = true } fastembed = { workspace = true } common = { path = "../common", features = ["test-utils"] } -state-machines = { workspace = true } diff --git a/composite-retrieval/src/answer_retrieval.rs b/retrieval-pipeline/src/answer_retrieval.rs similarity index 100% rename from composite-retrieval/src/answer_retrieval.rs rename to retrieval-pipeline/src/answer_retrieval.rs diff --git a/composite-retrieval/src/answer_retrieval_helper.rs b/retrieval-pipeline/src/answer_retrieval_helper.rs similarity index 100% rename from composite-retrieval/src/answer_retrieval_helper.rs rename to retrieval-pipeline/src/answer_retrieval_helper.rs diff --git a/composite-retrieval/src/fts.rs b/retrieval-pipeline/src/fts.rs similarity index 100% rename from composite-retrieval/src/fts.rs rename to retrieval-pipeline/src/fts.rs diff --git a/composite-retrieval/src/graph.rs b/retrieval-pipeline/src/graph.rs similarity index 100% rename from composite-retrieval/src/graph.rs rename to retrieval-pipeline/src/graph.rs diff --git a/composite-retrieval/src/lib.rs b/retrieval-pipeline/src/lib.rs similarity index 75% rename from composite-retrieval/src/lib.rs rename to retrieval-pipeline/src/lib.rs index cc2e8ab..0f77d37 100644 --- a/composite-retrieval/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -17,7 +17,10 @@ use common::{ use reranking::RerankerLease; use tracing::instrument; -pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning}; +pub use pipeline::{ + retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig, + RetrievalStrategy, RetrievalTuning, StrategyOutput, +}; // Captures a supporting chunk plus its fused retrieval score for downstream prompts. #[derive(Debug, Clone)] @@ -41,14 +44,15 @@ pub async fn retrieve_entities( openai_client: &async_openai::Client, input_text: &str, user_id: &str, + config: RetrievalConfig, reranker: Option, -) -> Result, AppError> { +) -> Result { pipeline::run_pipeline( db_client, openai_client, input_text, user_id, - RetrievalConfig::default(), + config, reranker, ) .await @@ -63,7 +67,7 @@ mod tests { knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, }; - use pipeline::RetrievalConfig; + use pipeline::{RetrievalConfig, RetrievalStrategy}; use uuid::Uuid; fn test_embedding() -> Vec { @@ -151,11 +155,16 @@ mod tests { .await .expect("Hybrid retrieval failed"); + let entities = match results { + StrategyOutput::Entities(items) => items, + other => panic!("expected entity results, got {:?}", other), + }; + assert!( - !results.is_empty(), + !entities.is_empty(), "Expected at least one retrieval result" ); - let top = &results[0]; + let top = &entities[0]; assert!( top.entity.name.contains("Rust"), "Expected Rust entity to be ranked first" @@ -242,8 +251,13 @@ mod tests { .await .expect("Hybrid retrieval failed"); + let entities = match results { + StrategyOutput::Entities(items) => items, + other => panic!("expected entity results, got {:?}", other), + }; + let mut neighbor_entry = None; - for entity in &results { + for entity in &entities { if entity.entity.id == neighbor.id { neighbor_entry = Some(entity.clone()); } @@ -264,4 +278,59 @@ mod tests { "Neighbor entity should surface its own supporting chunks" ); } + + #[tokio::test] + async fn test_revised_strategy_returns_chunks() { + let db = setup_test_db().await; + let user_id = "chunk_user"; + let chunk_one = TextChunk::new( + "src_alpha".into(), + "Tokio tasks execute on worker threads managed by the runtime.".into(), + chunk_embedding_primary(), + user_id.into(), + ); + let chunk_two = TextChunk::new( + "src_beta".into(), + "Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(), + chunk_embedding_secondary(), + user_id.into(), + ); + + db.store_item(chunk_one.clone()) + .await + .expect("Failed to store chunk one"); + db.store_item(chunk_two.clone()) + .await + .expect("Failed to store chunk two"); + + let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised); + let openai_client = Client::new(); + let results = pipeline::run_pipeline_with_embedding( + &db, + &openai_client, + test_embedding(), + "tokio runtime worker behavior", + user_id, + config, + None, + ) + .await + .expect("Revised retrieval failed"); + + let chunks = match results { + StrategyOutput::Chunks(items) => items, + other => panic!("expected chunk output, got {:?}", other), + }; + + assert!( + !chunks.is_empty(), + "Revised strategy should return chunk-only responses" + ); + assert!( + chunks + .iter() + .any(|entry| entry.chunk.chunk.contains("Tokio")), + "Chunk results should contain relevant snippets" + ); + } } diff --git a/composite-retrieval/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs similarity index 58% rename from composite-retrieval/src/pipeline/config.rs rename to retrieval-pipeline/src/pipeline/config.rs index 446cf0f..0937d26 100644 --- a/composite-retrieval/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -1,4 +1,40 @@ use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RetrievalStrategy { + Initial, + Revised, +} + +impl Default for RetrievalStrategy { + fn default() -> Self { + Self::Initial + } +} + +impl std::str::FromStr for RetrievalStrategy { + type Err = String; + + fn from_str(value: &str) -> Result { + match value.to_ascii_lowercase().as_str() { + "initial" => Ok(Self::Initial), + "revised" => Ok(Self::Revised), + other => Err(format!("unknown retrieval strategy '{other}'")), + } + } +} + +impl fmt::Display for RetrievalStrategy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let label = match self { + RetrievalStrategy::Initial => "initial", + RetrievalStrategy::Revised => "revised", + }; + f.write_str(label) + } +} /// Tunable parameters that govern each retrieval stage. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -51,18 +87,34 @@ impl Default for RetrievalTuning { /// Wrapper containing tuning plus future flags for per-request overrides. #[derive(Debug, Clone)] pub struct RetrievalConfig { + pub strategy: RetrievalStrategy, pub tuning: RetrievalTuning, } impl RetrievalConfig { pub fn new(tuning: RetrievalTuning) -> Self { - Self { tuning } + Self { + strategy: RetrievalStrategy::Initial, + tuning, + } + } + + pub fn with_strategy(strategy: RetrievalStrategy) -> Self { + Self { + strategy, + tuning: RetrievalTuning::default(), + } + } + + pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self { + Self { strategy, tuning } } } impl Default for RetrievalConfig { fn default() -> Self { Self { + strategy: RetrievalStrategy::default(), tuning: RetrievalTuning::default(), } } diff --git a/composite-retrieval/src/pipeline/diagnostics.rs b/retrieval-pipeline/src/pipeline/diagnostics.rs similarity index 100% rename from composite-retrieval/src/pipeline/diagnostics.rs rename to retrieval-pipeline/src/pipeline/diagnostics.rs diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs new file mode 100644 index 0000000..0cbd494 --- /dev/null +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -0,0 +1,397 @@ +mod config; +mod diagnostics; +mod stages; +mod strategies; + +pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning}; +pub use diagnostics::{ + AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, + PipelineDiagnostics, +}; + +use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity}; +use async_openai::Client; +use async_trait::async_trait; +use common::{error::AppError, storage::db::SurrealDbClient}; +use std::time::{Duration, Instant}; +use tracing::info; + +use stages::PipelineContext; +use strategies::{InitialStrategyDriver, RevisedStrategyDriver}; + +#[derive(Debug, Clone)] +pub enum StrategyOutput { + Entities(Vec), + Chunks(Vec), +} + +impl StrategyOutput { + pub fn as_entities(&self) -> Option<&[RetrievedEntity]> { + match self { + StrategyOutput::Entities(items) => Some(items), + _ => None, + } + } + + pub fn into_entities(self) -> Option> { + match self { + StrategyOutput::Entities(items) => Some(items), + _ => None, + } + } + + pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> { + match self { + StrategyOutput::Chunks(items) => Some(items), + _ => None, + } + } + + pub fn into_chunks(self) -> Option> { + match self { + StrategyOutput::Chunks(items) => Some(items), + _ => None, + } + } +} + +#[derive(Debug)] +pub struct PipelineRunOutput { + pub results: T, + pub diagnostics: Option, + pub stage_timings: PipelineStageTimings, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StageKind { + Embed, + CollectCandidates, + GraphExpansion, + ChunkAttach, + Rerank, + Assemble, +} + +#[derive(Debug, Clone, Default, serde::Serialize)] +pub struct PipelineStageTimings { + pub embed_ms: u128, + pub collect_candidates_ms: u128, + pub graph_expansion_ms: u128, + pub chunk_attach_ms: u128, + pub rerank_ms: u128, + pub assemble_ms: u128, +} + +impl PipelineStageTimings { + pub fn record(&mut self, kind: StageKind, duration: Duration) { + let elapsed = duration.as_millis() as u128; + match kind { + StageKind::Embed => self.embed_ms += elapsed, + StageKind::CollectCandidates => self.collect_candidates_ms += elapsed, + StageKind::GraphExpansion => self.graph_expansion_ms += elapsed, + StageKind::ChunkAttach => self.chunk_attach_ms += elapsed, + StageKind::Rerank => self.rerank_ms += elapsed, + StageKind::Assemble => self.assemble_ms += elapsed, + } + } +} + +#[async_trait] +pub trait PipelineStage: Send + Sync { + fn kind(&self) -> StageKind; + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>; +} + +pub type BoxedStage = Box; + +pub trait StrategyDriver { + type Output; + + fn strategy(&self) -> RetrievalStrategy; + fn stages(&self) -> Vec; + fn override_tuning(&self, _config: &mut RetrievalConfig) {} + + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result; +} + +pub async fn run_pipeline( + db_client: &SurrealDbClient, + openai_client: &Client, + input_text: &str, + user_id: &str, + config: RetrievalConfig, + reranker: Option, +) -> Result { + let input_chars = input_text.chars().count(); + let input_preview: String = input_text.chars().take(120).collect(); + let input_preview_clean = input_preview.replace('\n', " "); + let preview_len = input_preview_clean.chars().count(); + info!( + %user_id, + input_chars, + preview_truncated = input_chars > preview_len, + preview = %input_preview_clean, + "Starting ingestion retrieval pipeline" + ); + + if config.strategy == RetrievalStrategy::Initial { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + None, + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + return Ok(StrategyOutput::Entities(run.results)); + } + + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + None, + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Chunks(run.results)) +} + +pub async fn run_pipeline_with_embedding( + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Vec, + input_text: &str, + user_id: &str, + config: RetrievalConfig, + reranker: Option, +) -> Result { + if config.strategy == RetrievalStrategy::Initial { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + return Ok(StrategyOutput::Entities(run.results)); + } + + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Chunks(run.results)) +} + +pub async fn run_pipeline_with_embedding_with_metrics( + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Vec, + input_text: &str, + user_id: &str, + config: RetrievalConfig, + reranker: Option, +) -> Result, AppError> { + if config.strategy == RetrievalStrategy::Initial { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + return Ok(PipelineRunOutput { + results: StrategyOutput::Entities(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }); + } + + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(PipelineRunOutput { + results: StrategyOutput::Chunks(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }) +} + +pub async fn run_pipeline_with_embedding_with_diagnostics( + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Vec, + input_text: &str, + user_id: &str, + config: RetrievalConfig, + reranker: Option, +) -> Result, AppError> { + if config.strategy == RetrievalStrategy::Initial { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + true, + ) + .await?; + return Ok(PipelineRunOutput { + results: StrategyOutput::Entities(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }); + } + + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + true, + ) + .await?; + Ok(PipelineRunOutput { + results: StrategyOutput::Chunks(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }) +} + +pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value { + serde_json::json!(entities + .iter() + .map(|entry| { + serde_json::json!({ + "KnowledgeEntity": { + "id": entry.entity.id, + "name": entry.entity.name, + "description": entry.entity.description, + "score": round_score(entry.score), + "chunks": entry.chunks.iter().map(|chunk| { + serde_json::json!({ + "score": round_score(chunk.score), + "content": chunk.chunk.chunk + }) + }).collect::>() + } + }) + }) + .collect::>()) +} + +async fn execute_strategy( + driver: D, + db_client: &SurrealDbClient, + openai_client: &Client, + query_embedding: Option>, + input_text: &str, + user_id: &str, + mut config: RetrievalConfig, + reranker: Option, + capture_diagnostics: bool, +) -> Result, AppError> { + driver.override_tuning(&mut config); + let ctx = match query_embedding { + Some(embedding) => PipelineContext::with_embedding( + db_client, + openai_client, + embedding, + input_text.to_owned(), + user_id.to_owned(), + config, + reranker, + ), + None => PipelineContext::new( + db_client, + openai_client, + input_text.to_owned(), + user_id.to_owned(), + config, + reranker, + ), + }; + + run_with_driver(driver, ctx, capture_diagnostics).await +} + +async fn run_with_driver( + driver: D, + mut ctx: PipelineContext<'_>, + capture_diagnostics: bool, +) -> Result, AppError> { + if capture_diagnostics { + ctx.enable_diagnostics(); + } + + for stage in driver.stages() { + let start = Instant::now(); + stage.execute(&mut ctx).await?; + ctx.record_stage_duration(stage.kind(), start.elapsed()); + } + + let diagnostics = ctx.take_diagnostics(); + let stage_timings = ctx.take_stage_timings(); + let results = driver.finalize(&mut ctx)?; + + Ok(PipelineRunOutput { + results, + diagnostics, + stage_timings, + }) +} + +fn round_score(value: f32) -> f64 { + (f64::from(value) * 1000.0).round() / 1000.0 +} diff --git a/composite-retrieval/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs similarity index 67% rename from composite-retrieval/src/pipeline/stages/mod.rs rename to retrieval-pipeline/src/pipeline/stages/mod.rs index d7944e3..6f94d39 100644 --- a/composite-retrieval/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -1,4 +1,5 @@ use async_openai::Client; +use async_trait::async_trait; use common::{ error::AppError, storage::{ @@ -9,11 +10,9 @@ use common::{ }; use fastembed::RerankResult; use futures::{stream::FuturesUnordered, StreamExt}; -use state_machines::core::GuardError; use std::{ cmp::Ordering, collections::{HashMap, HashSet}, - time::Instant, }; use tracing::{debug, instrument, warn}; @@ -25,21 +24,20 @@ use crate::{ clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc, FusionWeights, Scored, }, - vector::find_items_by_vector_similarity_with_embedding, + vector::{ + find_chunk_snippets_by_vector_similarity_with_embedding, + find_items_by_vector_similarity_with_embedding, ChunkSnippet, + }, RetrievedChunk, RetrievedEntity, }; use super::{ - config::RetrievalConfig, + config::{RetrievalConfig, RetrievalTuning}, diagnostics::{ AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, PipelineDiagnostics, }, - state::{ - CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready, - Reranked, - }, - PipelineStageTimings, + PipelineStage, PipelineStageTimings, StageKind, }; pub struct PipelineContext<'a> { @@ -53,8 +51,11 @@ pub struct PipelineContext<'a> { pub chunk_candidates: HashMap>, pub filtered_entities: Vec>, pub chunk_values: Vec>, + pub revised_chunk_values: Vec>, pub reranker: Option, pub diagnostics: Option, + pub entity_results: Vec, + pub chunk_results: Vec, stage_timings: PipelineStageTimings, } @@ -78,8 +79,11 @@ impl<'a> PipelineContext<'a> { chunk_candidates: HashMap::new(), filtered_entities: Vec::new(), chunk_values: Vec::new(), + revised_chunk_values: Vec::new(), reranker, diagnostics: None, + entity_results: Vec::new(), + chunk_results: Vec::new(), stage_timings: PipelineStageTimings::default(), } } @@ -145,36 +149,151 @@ impl<'a> PipelineContext<'a> { self.diagnostics.take() } - pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) { - self.stage_timings.record_collect_candidates(duration); - } - - pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) { - self.stage_timings.record_graph_expansion(duration); - } - - pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) { - self.stage_timings.record_chunk_attach(duration); - } - - pub fn record_rerank_timing(&mut self, duration: std::time::Duration) { - self.stage_timings.record_rerank(duration); - } - - pub fn record_assemble_timing(&mut self, duration: std::time::Duration) { - self.stage_timings.record_assemble(duration); - } - pub fn take_stage_timings(&mut self) -> PipelineStageTimings { std::mem::take(&mut self.stage_timings) } + + pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) { + self.stage_timings.record(kind, duration); + } + + pub fn take_entity_results(&mut self) -> Vec { + std::mem::take(&mut self.entity_results) + } + + pub fn take_chunk_results(&mut self) -> Vec { + std::mem::take(&mut self.chunk_results) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct EmbedStage; + +#[async_trait] +impl PipelineStage for EmbedStage { + fn kind(&self) -> StageKind { + StageKind::Embed + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + embed(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CollectCandidatesStage; + +#[async_trait] +impl PipelineStage for CollectCandidatesStage { + fn kind(&self) -> StageKind { + StageKind::CollectCandidates + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + collect_candidates(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct GraphExpansionStage; + +#[async_trait] +impl PipelineStage for GraphExpansionStage { + fn kind(&self) -> StageKind { + StageKind::GraphExpansion + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + expand_graph(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkAttachStage; + +#[async_trait] +impl PipelineStage for ChunkAttachStage { + fn kind(&self) -> StageKind { + StageKind::ChunkAttach + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + attach_chunks(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct RerankStage; + +#[async_trait] +impl PipelineStage for RerankStage { + fn kind(&self) -> StageKind { + StageKind::Rerank + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + rerank(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AssembleEntitiesStage; + +#[async_trait] +impl PipelineStage for AssembleEntitiesStage { + fn kind(&self) -> StageKind { + StageKind::Assemble + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + assemble(ctx) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkVectorStage; + +#[async_trait] +impl PipelineStage for ChunkVectorStage { + fn kind(&self) -> StageKind { + StageKind::CollectCandidates + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + collect_vector_chunks(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkRerankStage; + +#[async_trait] +impl PipelineStage for ChunkRerankStage { + fn kind(&self) -> StageKind { + StageKind::Rerank + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + rerank_chunks(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkAssembleStage; + +#[async_trait] +impl PipelineStage for ChunkAssembleStage { + fn kind(&self) -> StageKind { + StageKind::Assemble + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + assemble_chunks(ctx) + } } #[instrument(level = "trace", skip_all)] -pub async fn embed( - machine: HybridRetrievalMachine<(), Ready>, - ctx: &mut PipelineContext<'_>, -) -> Result, AppError> { +pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { let embedding_cached = ctx.query_embedding.is_some(); if embedding_cached { debug!("Reusing cached query embedding for hybrid retrieval"); @@ -185,17 +304,11 @@ pub async fn embed( ctx.query_embedding = Some(embedding); } - machine - .embed() - .map_err(|(_, guard)| map_guard_error("embed", guard)) + Ok(()) } #[instrument(level = "trace", skip_all)] -pub async fn collect_candidates( - machine: HybridRetrievalMachine<(), Embedded>, - ctx: &mut PipelineContext<'_>, -) -> Result, AppError> { - let stage_start = Instant::now(); +pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Collecting initial candidates via vector and FTS search"); let embedding = ctx.ensure_embedding()?.clone(); let tuning = &ctx.config.tuning; @@ -265,104 +378,80 @@ pub async fn collect_candidates( apply_fusion(&mut ctx.entity_candidates, weights); apply_fusion(&mut ctx.chunk_candidates, weights); - let next = machine - .collect_candidates() - .map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?; - ctx.record_collect_candidates_timing(stage_start.elapsed()); - Ok(next) + Ok(()) } #[instrument(level = "trace", skip_all)] -pub async fn expand_graph( - machine: HybridRetrievalMachine<(), CandidatesLoaded>, - ctx: &mut PipelineContext<'_>, -) -> Result, AppError> { - let stage_start = Instant::now(); +pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Expanding candidates using graph relationships"); - let next = { - let tuning = &ctx.config.tuning; - let weights = FusionWeights::default(); + let tuning = &ctx.config.tuning; + let weights = FusionWeights::default(); - if ctx.entity_candidates.is_empty() { - machine - .expand_graph() - .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) - } else { - let graph_seeds = seeds_from_candidates( - &ctx.entity_candidates, - tuning.graph_seed_min_score, - tuning.graph_traversal_seed_limit, - ); + if ctx.entity_candidates.is_empty() { + return Ok(()); + } - if graph_seeds.is_empty() { - machine - .expand_graph() - .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) - } else { - let mut futures = FuturesUnordered::new(); - for seed in graph_seeds { - let db = ctx.db_client; - let user = ctx.user_id.clone(); - let limit = tuning.graph_neighbor_limit; - futures.push(async move { - let neighbors = - find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await; - (seed, neighbors) - }); - } + let graph_seeds = seeds_from_candidates( + &ctx.entity_candidates, + tuning.graph_seed_min_score, + tuning.graph_traversal_seed_limit, + ); - while let Some((seed, neighbors_result)) = futures.next().await { - let neighbors = neighbors_result.map_err(AppError::from)?; - if neighbors.is_empty() { - continue; - } + if graph_seeds.is_empty() { + return Ok(()); + } - for neighbor in neighbors { - if neighbor.id == seed.id { - continue; - } + let mut futures = FuturesUnordered::new(); + for seed in graph_seeds { + let db = ctx.db_client; + let user = ctx.user_id.clone(); + let limit = tuning.graph_neighbor_limit; + futures.push(async move { + let neighbors = find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await; + (seed, neighbors) + }); + } - let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); - let entry = ctx - .entity_candidates - .entry(neighbor.id.clone()) - .or_insert_with(|| Scored::new(neighbor.clone())); - - entry.item = neighbor; - - let inherited_vector = - clamp_unit(graph_score * tuning.graph_vector_inheritance); - let vector_existing = entry.scores.vector.unwrap_or(0.0); - if inherited_vector > vector_existing { - entry.scores.vector = Some(inherited_vector); - } - - let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); - if graph_score > existing_graph || entry.scores.graph.is_none() { - entry.scores.graph = Some(graph_score); - } - - let fused = fuse_scores(&entry.scores, weights); - entry.update_fused(fused); - } - } - - machine - .expand_graph() - .map_err(|(_, guard)| map_guard_error("expand_graph", guard)) - } + while let Some((seed, neighbors_result)) = futures.next().await { + let neighbors = neighbors_result.map_err(AppError::from)?; + if neighbors.is_empty() { + continue; } - }?; - ctx.record_graph_expansion_timing(stage_start.elapsed()); - Ok(next) + + for neighbor in neighbors { + if neighbor.id == seed.id { + continue; + } + + let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); + let entry = ctx + .entity_candidates + .entry(neighbor.id.clone()) + .or_insert_with(|| Scored::new(neighbor.clone())); + + entry.item = neighbor; + + let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance); + let vector_existing = entry.scores.vector.unwrap_or(0.0); + if inherited_vector > vector_existing { + entry.scores.vector = Some(inherited_vector); + } + + let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); + if graph_score > existing_graph || entry.scores.graph.is_none() { + entry.scores.graph = Some(graph_score); + } + + let fused = fuse_scores(&entry.scores, weights); + entry.update_fused(fused); + } + } + + Ok(()) } #[instrument(level = "trace", skip_all)] -pub async fn attach_chunks( - machine: HybridRetrievalMachine<(), GraphExpanded>, - ctx: &mut PipelineContext<'_>, -) -> Result, AppError> { - let stage_start = Instant::now(); +pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Attaching chunks to surviving entities"); let tuning = &ctx.config.tuning; let weights = FusionWeights::default(); @@ -438,19 +527,11 @@ pub async fn attach_chunks( ctx.chunk_values = chunk_values; - let next = machine - .attach_chunks() - .map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?; - ctx.record_chunk_attach_timing(stage_start.elapsed()); - Ok(next) + Ok(()) } #[instrument(level = "trace", skip_all)] -pub async fn rerank( - machine: HybridRetrievalMachine<(), ChunksAttached>, - ctx: &mut PipelineContext<'_>, -) -> Result, AppError> { - let stage_start = Instant::now(); +pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { let mut applied = false; if let Some(reranker) = ctx.reranker.as_ref() { @@ -490,19 +571,124 @@ pub async fn rerank( debug!("Applied reranking adjustments to candidate ordering"); } - let next = machine - .rerank() - .map_err(|(_, guard)| map_guard_error("rerank", guard))?; - ctx.record_rerank_timing(stage_start.elapsed()); - Ok(next) + Ok(()) } #[instrument(level = "trace", skip_all)] -pub fn assemble( - machine: HybridRetrievalMachine<(), Reranked>, - ctx: &mut PipelineContext<'_>, -) -> Result, AppError> { - let stage_start = Instant::now(); +pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + debug!("Collecting vector chunk candidates for revised strategy"); + let embedding = ctx.ensure_embedding()?.clone(); + let tuning = &ctx.config.tuning; + let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding( + tuning.chunk_vector_take, + embedding, + ctx.db_client, + &ctx.user_id, + ) + .await?; + + if ctx.diagnostics_enabled() { + ctx.record_collect_candidates(CollectCandidatesStats { + vector_entity_candidates: 0, + vector_chunk_candidates: vector_chunks.len(), + fts_entity_candidates: 0, + fts_chunk_candidates: 0, + vector_chunk_scores: sample_scores(&vector_chunks, |chunk| { + chunk.scores.vector.unwrap_or(0.0) + }), + fts_chunk_scores: Vec::new(), + }); + } + + vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)); + ctx.revised_chunk_values = vector_chunks; + + Ok(()) +} + +#[instrument(level = "trace", skip_all)] +pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + if ctx.revised_chunk_values.len() <= 1 { + return Ok(()); + } + + let Some(reranker) = ctx.reranker.as_ref() else { + debug!("No reranker lease provided; skipping chunk rerank stage"); + return Ok(()); + }; + + let documents = build_snippet_rerank_documents( + &ctx.revised_chunk_values, + ctx.config.tuning.rerank_keep_top.max(1), + ); + if documents.len() <= 1 { + debug!("Skipping chunk reranking stage; insufficient chunk documents"); + return Ok(()); + } + + match reranker.rerank(&ctx.input_text, documents).await { + Ok(results) if !results.is_empty() => { + apply_snippet_rerank_results( + &mut ctx.revised_chunk_values, + &ctx.config.tuning, + results, + ); + } + Ok(_) => debug!("Chunk reranker returned no results; retaining original order"), + Err(err) => warn!( + error = %err, + "Chunk reranking failed; continuing with original ordering" + ), + } + + Ok(()) +} + +#[instrument(level = "trace", skip_all)] +pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + debug!("Assembling chunk-only retrieval results"); + let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values); + let question_terms = extract_keywords(&ctx.input_text); + rank_snippet_chunks_by_combined_score( + &mut chunk_values, + &question_terms, + ctx.config.tuning.lexical_match_weight, + ); + + let limit = ctx.config.tuning.chunk_vector_take.max(1); + if chunk_values.len() > limit { + chunk_values.truncate(limit); + } + + ctx.chunk_results = chunk_values + .into_iter() + .map(|chunk| { + let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id); + RetrievedChunk { + chunk: text_chunk, + score: chunk.fused, + } + }) + .collect(); + + if ctx.diagnostics_enabled() { + ctx.record_assemble(AssembleStats { + token_budget_start: ctx.config.tuning.token_budget_estimate, + token_budget_spent: 0, + token_budget_remaining: ctx.config.tuning.token_budget_estimate, + budget_exhausted: false, + chunks_selected: ctx.chunk_results.len(), + chunks_skipped_due_budget: 0, + entity_count: 0, + entity_traces: Vec::new(), + }); + } + + Ok(()) +} + +#[instrument(level = "trace", skip_all)] +pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Assembling final retrieved entities"); let tuning = &ctx.config.tuning; let query_embedding = ctx.ensure_embedding()?.clone(); @@ -610,11 +796,8 @@ pub fn assemble( }); } - machine - .assemble() - .map_err(|(_, guard)| map_guard_error("assemble", guard))?; - ctx.record_assemble_timing(stage_start.elapsed()); - Ok(results) + ctx.entity_results = results; + Ok(()) } const SCORE_SAMPLE_LIMIT: usize = 8; @@ -630,12 +813,6 @@ where .collect() } -fn map_guard_error(stage: &'static str, err: GuardError) -> AppError { - AppError::InternalError(format!( - "state machine guard '{stage}' failed: guard={}, event={}, kind={:?}", - err.guard, err.event, err.kind - )) -} fn normalize_fts_scores(results: &mut [Scored]) { let raw_scores: Vec = results .iter() @@ -873,6 +1050,23 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz .collect() } +fn build_snippet_rerank_documents( + chunks: &[Scored], + max_chunks: usize, +) -> Vec { + chunks + .iter() + .take(max_chunks) + .map(|chunk| { + format!( + "Source: {}\nChunk:\n{}", + chunk.item.source_id, + chunk.item.chunk.trim() + ) + }) + .collect() +} + fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec) { if results.is_empty() || ctx.filtered_entities.is_empty() { return; @@ -930,6 +1124,66 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec>, + tuning: &RetrievalTuning, + results: Vec, +) { + if results.is_empty() || chunks.is_empty() { + return; + } + + let mut remaining: Vec>> = + std::mem::take(chunks).into_iter().map(Some).collect(); + + let raw_scores: Vec = results.iter().map(|r| r.score).collect(); + let normalized_scores = min_max_normalize(&raw_scores); + + let use_only = tuning.rerank_scores_only; + let blend = if use_only { + 1.0 + } else { + clamp_unit(tuning.rerank_blend_weight) + }; + + let mut reranked: Vec> = Vec::with_capacity(remaining.len()); + for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { + if let Some(slot) = remaining.get_mut(result.index) { + if let Some(mut candidate) = slot.take() { + let original = candidate.fused; + let blended = if use_only { + clamp_unit(normalized) + } else { + clamp_unit(original * (1.0 - blend) + normalized * blend) + }; + candidate.update_fused(blended); + reranked.push(candidate); + } + } else { + warn!( + result_index = result.index, + "Chunk reranker returned out-of-range index; skipping" + ); + } + if reranked.len() == remaining.len() { + break; + } + } + + for slot in remaining.into_iter() { + if let Some(candidate) = slot { + reranked.push(candidate); + } + } + + let keep_top = tuning.rerank_keep_top; + if keep_top > 0 && reranked.len() > keep_top { + reranked.truncate(keep_top); + } + + *chunks = reranked; +} + fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize { let chars = text.chars().count().max(1); (chars / avg_chars_per_token).max(1) @@ -963,6 +1217,32 @@ fn extract_keywords(text: &str) -> Vec { terms } +fn rank_snippet_chunks_by_combined_score( + candidates: &mut [Scored], + question_terms: &[String], + lexical_weight: f32, +) { + if lexical_weight > 0.0 && !question_terms.is_empty() { + for candidate in candidates.iter_mut() { + let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk); + let combined = clamp_unit(candidate.fused + lexical_weight * lexical); + candidate.update_fused(combined); + } + } + candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)); +} + +fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk { + let mut chunk = TextChunk::new( + snippet.source_id.clone(), + snippet.chunk, + Vec::new(), + user_id.to_owned(), + ); + chunk.id = snippet.id; + chunk +} + fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 { if terms.is_empty() { return 0.0; diff --git a/retrieval-pipeline/src/pipeline/strategies.rs b/retrieval-pipeline/src/pipeline/strategies.rs new file mode 100644 index 0000000..a0675e4 --- /dev/null +++ b/retrieval-pipeline/src/pipeline/strategies.rs @@ -0,0 +1,75 @@ +use super::{ + stages::{ + AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage, + ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, + RerankStage, + }, + BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver, +}; +use crate::{RetrievedChunk, RetrievedEntity}; +use common::error::AppError; + +pub struct InitialStrategyDriver; + +impl InitialStrategyDriver { + pub fn new() -> Self { + Self + } +} + +impl StrategyDriver for InitialStrategyDriver { + type Output = Vec; + + fn strategy(&self) -> RetrievalStrategy { + RetrievalStrategy::Initial + } + + fn stages(&self) -> Vec { + vec![ + Box::new(EmbedStage), + Box::new(CollectCandidatesStage), + Box::new(GraphExpansionStage), + Box::new(ChunkAttachStage), + Box::new(RerankStage), + Box::new(AssembleEntitiesStage), + ] + } + + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + Ok(ctx.take_entity_results()) + } +} + +pub struct RevisedStrategyDriver; + +impl RevisedStrategyDriver { + pub fn new() -> Self { + Self + } +} + +impl StrategyDriver for RevisedStrategyDriver { + type Output = Vec; + + fn strategy(&self) -> RetrievalStrategy { + RetrievalStrategy::Revised + } + + fn stages(&self) -> Vec { + vec![ + Box::new(EmbedStage), + Box::new(ChunkVectorStage), + Box::new(ChunkRerankStage), + Box::new(ChunkAssembleStage), + ] + } + + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + Ok(ctx.take_chunk_results()) + } + + fn override_tuning(&self, config: &mut RetrievalConfig) { + config.tuning.entity_vector_take = 0; + config.tuning.entity_fts_take = 0; + } +} diff --git a/composite-retrieval/src/reranking/mod.rs b/retrieval-pipeline/src/reranking/mod.rs similarity index 100% rename from composite-retrieval/src/reranking/mod.rs rename to retrieval-pipeline/src/reranking/mod.rs diff --git a/composite-retrieval/src/scoring.rs b/retrieval-pipeline/src/scoring.rs similarity index 100% rename from composite-retrieval/src/scoring.rs rename to retrieval-pipeline/src/scoring.rs diff --git a/composite-retrieval/src/vector.rs b/retrieval-pipeline/src/vector.rs similarity index 73% rename from composite-retrieval/src/vector.rs rename to retrieval-pipeline/src/vector.rs index 229ec12..94a3514 100644 --- a/composite-retrieval/src/vector.rs +++ b/retrieval-pipeline/src/vector.rs @@ -1,9 +1,11 @@ use std::collections::HashMap; -use common::storage::types::file_info::deserialize_flexible_id; use common::{ error::AppError, - storage::{db::SurrealDbClient, types::StoredObject}, + storage::{ + db::SurrealDbClient, + types::{file_info::deserialize_flexible_id, StoredObject}, + }, utils::embedding::generate_embedding, }; use serde::Deserialize; @@ -156,3 +158,61 @@ where Ok(scored) } + +#[derive(Debug, Clone, Deserialize)] +pub struct ChunkSnippet { + pub id: String, + pub source_id: String, + pub chunk: String, +} + +#[derive(Debug, Deserialize)] +struct ChunkDistanceRow { + distance: Option, + #[serde(deserialize_with = "deserialize_flexible_id")] + pub id: String, + pub source_id: String, + pub chunk: String, +} + +pub async fn find_chunk_snippets_by_vector_similarity_with_embedding( + take: usize, + query_embedding: Vec, + db_client: &SurrealDbClient, + user_id: &str, +) -> Result>, AppError> { + let embedding_literal = serde_json::to_string(&query_embedding) + .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?; + + let closest_query = format!( + "SELECT id, source_id, chunk, vector::distance::knn() AS distance \ + FROM text_chunk \ + WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \ + LIMIT $limit", + take = take, + embedding = embedding_literal + ); + + let mut response = db_client + .query(closest_query) + .bind(("user_id", user_id.to_owned())) + .bind(("limit", take as i64)) + .await?; + + let rows: Vec = response.take(0)?; + + let mut scored = Vec::with_capacity(rows.len()); + for row in rows { + let similarity = row.distance.map(distance_to_similarity).unwrap_or_default(); + scored.push( + Scored::new(ChunkSnippet { + id: row.id, + source_id: row.source_id, + chunk: row.chunk, + }) + .with_vector_score(similarity), + ); + } + + Ok(scored) +}