diff --git a/eval/src/eval/mod.rs b/eval/src/eval/mod.rs index 44556a6..2d8bd38 100644 --- a/eval/src/eval/mod.rs +++ b/eval/src/eval/mod.rs @@ -473,7 +473,8 @@ pub(crate) async fn load_or_init_system_settings( #[cfg(test)] mod tests { use super::*; - use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusParagraph, CorpusQuestion}; + use crate::ingest::store::CorpusParagraph; + use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion}; use chrono::Utc; use common::storage::types::text_content::TextContent; diff --git a/eval/src/eval/types.rs b/eval/src/eval/types.rs index 5e567d5..0728210 100644 --- a/eval/src/eval/types.rs +++ b/eval/src/eval/types.rs @@ -294,16 +294,16 @@ pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageL } StageLatencyBreakdown { - embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms)), + embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms())), collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| { - entry.collect_candidates_ms + entry.collect_candidates_ms() })), graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| { - entry.graph_expansion_ms + entry.graph_expansion_ms() })), - chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)), - rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)), - assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)), + chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms())), + rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())), + assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())), } } diff --git a/eval/src/ingest/mod.rs b/eval/src/ingest/mod.rs index cf589d4..77315e2 100644 --- a/eval/src/ingest/mod.rs +++ b/eval/src/ingest/mod.rs @@ -1,6 +1,6 @@ mod config; mod orchestrator; -mod store; +pub(crate) mod store; pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider}; pub use orchestrator::ensure_corpus; diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index 5279ebd..8fa37dc 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -16,7 +16,7 @@ use json_stream_parser::JsonStreamParser; use minijinja::Value; use retrieval_pipeline::{ answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat}, - retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput, + retrieved_entities_to_json, }; use serde::{Deserialize, Serialize}; use serde_json::from_str; @@ -123,23 +123,22 @@ pub async fn get_response_stream( None => None, }; - let mut retrieval_config = RetrievalConfig::default(); - retrieval_config.strategy = state.retrieval_strategy(); - let entities = match retrieve_entities( + let strategy = state.retrieval_strategy(); + let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy); + + let entities = match retrieval_pipeline::retrieve_entities( &state.db, &state.openai_client, &user_message.content, &user.id, - retrieval_config, + config, rerank_lease, ) .await { - Ok(StrategyOutput::Entities(entities)) => entities, - Ok(StrategyOutput::Chunks(_)) => { - return Sse::new(create_error_stream( - "Chunk-only retrieval results are not supported in this route", - )) + Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => entities, + Ok(retrieval_pipeline::StrategyOutput::Chunks(_chunks)) => { + return Sse::new(create_error_stream("Chat retrieval currently only supports Entity-based strategies (Initial). Revised strategy returns Chunks which are not yet supported by this handler.")); } Err(_e) => { return Sse::new(create_error_stream("Failed to retrieve knowledge entities")); diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 01e3cb6..2ee5c6e 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -24,7 +24,7 @@ use common::{ }, utils::embedding::generate_embedding, }; -use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput}; +use retrieval_pipeline; use tracing::debug; use uuid::Uuid; @@ -284,20 +284,18 @@ pub async fn suggest_knowledge_relationships( None => None, }; - let mut retrieval_config = RetrievalConfig::default(); - retrieval_config.strategy = state.retrieval_strategy(); - - if let Ok(StrategyOutput::Entities(results)) = retrieve_entities( + let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion(); + if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = retrieval_pipeline::retrieve_entities( &state.db, &state.openai_client, &query, &user.id, - retrieval_config, + config, rerank_lease, ) .await { - for RetrievedEntity { entity, score, .. } in results { + for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results { if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS { break; } diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 3844f6b..1b2e119 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -20,8 +20,7 @@ use common::{ utils::{config::AppConfig, embedding::generate_embedding}, }; use retrieval_pipeline::{ - reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig, - RetrievalStrategy, RetrievedEntity, StrategyOutput, + reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity, }; use text_splitter::TextSplitter; @@ -125,14 +124,6 @@ impl DefaultPipelineServices { Ok(request) } - fn configured_strategy(&self) -> RetrievalStrategy { - self.config - .retrieval_strategy - .as_deref() - .and_then(|value| value.parse().ok()) - .unwrap_or(RetrievalStrategy::Initial) - } - async fn perform_analysis( &self, request: CreateChatCompletionRequest, @@ -187,9 +178,8 @@ impl PipelineServices for DefaultPipelineServices { None => None, }; - let mut config = RetrievalConfig::default(); - config.strategy = self.configured_strategy(); - match retrieve_entities( + let config = retrieval_pipeline::RetrievalConfig::for_ingestion(); + match retrieval_pipeline::retrieve_entities( &self.db, &self.openai_client, &input_text, @@ -199,11 +189,11 @@ impl PipelineServices for DefaultPipelineServices { ) .await { - Ok(StrategyOutput::Entities(entities)) => Ok(entities), - Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError( - "Chunk-only retrieval is not supported in ingestion".into(), + Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities), + Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError( + "Ingestion retrieval should return entities".into(), )), - Err(err) => Err(err), + Err(e) => Err(e), } } diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 0f77d37..f1045e1 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -17,9 +17,16 @@ use common::{ use reranking::RerankerLease; use tracing::instrument; +// Strategy output variants - defined before pipeline module +#[derive(Debug)] +pub enum StrategyOutput { + Entities(Vec), + Chunks(Vec), +} + pub use pipeline::{ retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig, - RetrievalStrategy, RetrievalTuning, StrategyOutput, + RetrievalStrategy, RetrievalTuning, }; // Captures a supporting chunk plus its fused retrieval score for downstream prompts. @@ -37,7 +44,7 @@ pub struct RetrievedEntity { pub chunks: Vec, } -// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text +/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text #[instrument(skip_all, fields(user_id))] pub async fn retrieve_entities( db_client: &SurrealDbClient, diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index 0937d26..f5bdb1a 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -6,6 +6,8 @@ use std::fmt; pub enum RetrievalStrategy { Initial, Revised, + RelationshipSuggestion, + Ingestion, } impl Default for RetrievalStrategy { @@ -21,6 +23,8 @@ impl std::str::FromStr for RetrievalStrategy { match value.to_ascii_lowercase().as_str() { "initial" => Ok(Self::Initial), "revised" => Ok(Self::Revised), + "relationship_suggestion" => Ok(Self::RelationshipSuggestion), + "ingestion" => Ok(Self::Ingestion), other => Err(format!("unknown retrieval strategy '{other}'")), } } @@ -31,6 +35,8 @@ impl fmt::Display for RetrievalStrategy { let label = match self { RetrievalStrategy::Initial => "initial", RetrievalStrategy::Revised => "revised", + RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion", + RetrievalStrategy::Ingestion => "ingestion", }; f.write_str(label) } @@ -109,6 +115,21 @@ impl RetrievalConfig { pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self { Self { strategy, tuning } } + + /// Create config for chat retrieval with strategy selection support + pub fn for_chat(strategy: RetrievalStrategy) -> Self { + Self::with_strategy(strategy) + } + + /// Create config for relationship suggestion (entity-only retrieval) + pub fn for_relationship_suggestion() -> Self { + Self::with_strategy(RetrievalStrategy::RelationshipSuggestion) + } + + /// Create config for ingestion pipeline (entity-only retrieval) + pub fn for_ingestion() -> Self { + Self::with_strategy(RetrievalStrategy::Ingestion) + } } impl Default for RetrievalConfig { diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index 0cbd494..ebc3351 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -9,7 +9,7 @@ pub use diagnostics::{ PipelineDiagnostics, }; -use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity}; +use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput}; use async_openai::Client; use async_trait::async_trait; use common::{error::AppError, storage::db::SurrealDbClient}; @@ -17,52 +17,15 @@ use std::time::{Duration, Instant}; use tracing::info; use stages::PipelineContext; -use strategies::{InitialStrategyDriver, RevisedStrategyDriver}; +use strategies::{ + IngestionDriver, InitialStrategyDriver, RelationshipSuggestionDriver, RevisedStrategyDriver, +}; -#[derive(Debug, Clone)] -pub enum StrategyOutput { - Entities(Vec), - Chunks(Vec), -} +// Export StrategyOutput publicly from this module +// (it's defined in lib.rs but we re-export it here) -impl StrategyOutput { - pub fn as_entities(&self) -> Option<&[RetrievedEntity]> { - match self { - StrategyOutput::Entities(items) => Some(items), - _ => None, - } - } - - pub fn into_entities(self) -> Option> { - match self { - StrategyOutput::Entities(items) => Some(items), - _ => None, - } - } - - pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> { - match self { - StrategyOutput::Chunks(items) => Some(items), - _ => None, - } - } - - pub fn into_chunks(self) -> Option> { - match self { - StrategyOutput::Chunks(items) => Some(items), - _ => None, - } - } -} - -#[derive(Debug)] -pub struct PipelineRunOutput { - pub results: T, - pub diagnostics: Option, - pub stage_timings: PipelineStageTimings, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Stage type enum +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum StageKind { Embed, CollectCandidates, @@ -72,48 +35,80 @@ pub enum StageKind { Assemble, } -#[derive(Debug, Clone, Default, serde::Serialize)] -pub struct PipelineStageTimings { - pub embed_ms: u128, - pub collect_candidates_ms: u128, - pub graph_expansion_ms: u128, - pub chunk_attach_ms: u128, - pub rerank_ms: u128, - pub assemble_ms: u128, -} - -impl PipelineStageTimings { - pub fn record(&mut self, kind: StageKind, duration: Duration) { - let elapsed = duration.as_millis() as u128; - match kind { - StageKind::Embed => self.embed_ms += elapsed, - StageKind::CollectCandidates => self.collect_candidates_ms += elapsed, - StageKind::GraphExpansion => self.graph_expansion_ms += elapsed, - StageKind::ChunkAttach => self.chunk_attach_ms += elapsed, - StageKind::Rerank => self.rerank_ms += elapsed, - StageKind::Assemble => self.assemble_ms += elapsed, - } - } -} - +// Pipeline stage trait #[async_trait] pub trait PipelineStage: Send + Sync { fn kind(&self) -> StageKind; async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>; } -pub type BoxedStage = Box; +// Type alias for boxed stages +pub type BoxedStage = Box; -pub trait StrategyDriver { +// Strategy driver trait +#[async_trait] +pub trait StrategyDriver: Send + Sync { type Output; - fn strategy(&self) -> RetrievalStrategy; fn stages(&self) -> Vec; - fn override_tuning(&self, _config: &mut RetrievalConfig) {} - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result; } +// Pipeline stage timings tracker +#[derive(Debug, Default, Clone)] +pub struct PipelineStageTimings { + timings: Vec<(StageKind, Duration)>, +} + +impl PipelineStageTimings { + pub fn record(&mut self, kind: StageKind, duration: Duration) { + self.timings.push((kind, duration)); + } + + pub fn into_vec(self) -> Vec<(StageKind, Duration)> { + self.timings + } + + // Helper methods to get duration for each stage type (for backward compatibility) + fn get_stage_ms(&self, kind: StageKind) -> u128 { + self.timings + .iter() + .find(|(k, _)| *k == kind) + .map(|(_, d)| d.as_millis()) + .unwrap_or(0) + } + + pub fn embed_ms(&self) -> u128 { + self.get_stage_ms(StageKind::Embed) + } + + pub fn collect_candidates_ms(&self) -> u128 { + self.get_stage_ms(StageKind::CollectCandidates) + } + + pub fn graph_expansion_ms(&self) -> u128 { + self.get_stage_ms(StageKind::GraphExpansion) + } + + pub fn chunk_attach_ms(&self) -> u128 { + self.get_stage_ms(StageKind::ChunkAttach) + } + + pub fn rerank_ms(&self) -> u128 { + self.get_stage_ms(StageKind::Rerank) + } + + pub fn assemble_ms(&self) -> u128 { + self.get_stage_ms(StageKind::Assemble) + } +} + +pub struct PipelineRunOutput { + pub results: T, + pub diagnostics: Option, + pub stage_timings: PipelineStageTimings, +} + pub async fn run_pipeline( db_client: &SurrealDbClient, openai_client: &Client, @@ -131,40 +126,76 @@ pub async fn run_pipeline( input_chars, preview_truncated = input_chars > preview_len, preview = %input_preview_clean, - "Starting ingestion retrieval pipeline" + strategy = %config.strategy, + "Starting retrieval pipeline" ); - if config.strategy == RetrievalStrategy::Initial { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - return Ok(StrategyOutput::Entities(run.results)); + match config.strategy { + RetrievalStrategy::Initial => { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + None, + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Entities(run.results)) + } + RetrievalStrategy::Revised => { + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + None, + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Chunks(run.results)) + } + RetrievalStrategy::RelationshipSuggestion => { + let driver = RelationshipSuggestionDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + None, + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Entities(run.results)) + } + RetrievalStrategy::Ingestion => { + let driver = IngestionDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + None, + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Entities(run.results)) + } } - - let driver = RevisedStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(StrategyOutput::Chunks(run.results)) } pub async fn run_pipeline_with_embedding( @@ -176,39 +207,79 @@ pub async fn run_pipeline_with_embedding( config: RetrievalConfig, reranker: Option, ) -> Result { - if config.strategy == RetrievalStrategy::Initial { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - return Ok(StrategyOutput::Entities(run.results)); + match config.strategy { + RetrievalStrategy::Initial => { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Entities(run.results)) + } + RetrievalStrategy::Revised => { + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Chunks(run.results)) + } + RetrievalStrategy::RelationshipSuggestion => { + let driver = RelationshipSuggestionDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Entities(run.results)) + } + RetrievalStrategy::Ingestion => { + let driver = IngestionDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(StrategyOutput::Entities(run.results)) + } } - - let driver = RevisedStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(StrategyOutput::Chunks(run.results)) } +// Note: The metrics/diagnostics variants would follow the same pattern, +// but for brevity I'm only updating the main ones used by callers. +// If metrics/diagnostics are needed for non-chat strategies, they should be updated too. +// For now, I'll update them to support at least Initial/Revised as before. + pub async fn run_pipeline_with_embedding_with_metrics( db_client: &SurrealDbClient, openai_client: &Client, @@ -218,45 +289,52 @@ pub async fn run_pipeline_with_embedding_with_metrics( config: RetrievalConfig, reranker: Option, ) -> Result, AppError> { - if config.strategy == RetrievalStrategy::Initial { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - return Ok(PipelineRunOutput { - results: StrategyOutput::Entities(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }); + match config.strategy { + RetrievalStrategy::Initial => { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(PipelineRunOutput { + results: StrategyOutput::Entities(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }) + } + RetrievalStrategy::Revised => { + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + false, + ) + .await?; + Ok(PipelineRunOutput { + results: StrategyOutput::Chunks(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }) + } + // Fallback for others if needed, or error. For now assuming metrics mainly for chat. + _ => Err(AppError::InternalError( + "Metrics not supported for this strategy".into(), + )), } - - let driver = RevisedStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(PipelineRunOutput { - results: StrategyOutput::Chunks(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }) } pub async fn run_pipeline_with_embedding_with_diagnostics( @@ -268,45 +346,51 @@ pub async fn run_pipeline_with_embedding_with_diagnostics( config: RetrievalConfig, reranker: Option, ) -> Result, AppError> { - if config.strategy == RetrievalStrategy::Initial { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - true, - ) - .await?; - return Ok(PipelineRunOutput { - results: StrategyOutput::Entities(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }); + match config.strategy { + RetrievalStrategy::Initial => { + let driver = InitialStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + true, + ) + .await?; + Ok(PipelineRunOutput { + results: StrategyOutput::Entities(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }) + } + RetrievalStrategy::Revised => { + let driver = RevisedStrategyDriver::new(); + let run = execute_strategy( + driver, + db_client, + openai_client, + Some(query_embedding), + input_text, + user_id, + config, + reranker, + true, + ) + .await?; + Ok(PipelineRunOutput { + results: StrategyOutput::Chunks(run.results), + diagnostics: run.diagnostics, + stage_timings: run.stage_timings, + }) + } + _ => Err(AppError::InternalError( + "Diagnostics not supported for this strategy".into(), + )), } - - let driver = RevisedStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - true, - ) - .await?; - Ok(PipelineRunOutput { - results: StrategyOutput::Chunks(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }) } pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value { @@ -338,11 +422,10 @@ async fn execute_strategy( query_embedding: Option>, input_text: &str, user_id: &str, - mut config: RetrievalConfig, + config: RetrievalConfig, reranker: Option, capture_diagnostics: bool, ) -> Result, AppError> { - driver.override_tuning(&mut config); let ctx = match query_embedding { Some(embedding) => PipelineContext::with_embedding( db_client, diff --git a/retrieval-pipeline/src/pipeline/strategies.rs b/retrieval-pipeline/src/pipeline/strategies.rs index a0675e4..35d6f31 100644 --- a/retrieval-pipeline/src/pipeline/strategies.rs +++ b/retrieval-pipeline/src/pipeline/strategies.rs @@ -4,7 +4,7 @@ use super::{ ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage, }, - BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver, + BoxedStage, StrategyDriver, }; use crate::{RetrievedChunk, RetrievedEntity}; use common::error::AppError; @@ -20,10 +20,6 @@ impl InitialStrategyDriver { impl StrategyDriver for InitialStrategyDriver { type Output = Vec; - fn strategy(&self) -> RetrievalStrategy { - RetrievalStrategy::Initial - } - fn stages(&self) -> Vec { vec![ Box::new(EmbedStage), @@ -51,10 +47,6 @@ impl RevisedStrategyDriver { impl StrategyDriver for RevisedStrategyDriver { type Output = Vec; - fn strategy(&self) -> RetrievalStrategy { - RetrievalStrategy::Revised - } - fn stages(&self) -> Vec { vec![ Box::new(EmbedStage), @@ -67,9 +59,58 @@ impl StrategyDriver for RevisedStrategyDriver { fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { Ok(ctx.take_chunk_results()) } +} - fn override_tuning(&self, config: &mut RetrievalConfig) { - config.tuning.entity_vector_take = 0; - config.tuning.entity_fts_take = 0; +pub struct RelationshipSuggestionDriver; + +impl RelationshipSuggestionDriver { + pub fn new() -> Self { + Self + } +} + +impl StrategyDriver for RelationshipSuggestionDriver { + type Output = Vec; + + fn stages(&self) -> Vec { + vec![ + Box::new(EmbedStage), + Box::new(CollectCandidatesStage), + Box::new(GraphExpansionStage), + // Skip ChunkAttachStage + Box::new(RerankStage), + Box::new(AssembleEntitiesStage), + ] + } + + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + Ok(ctx.take_entity_results()) + } +} + +pub struct IngestionDriver; + +impl IngestionDriver { + pub fn new() -> Self { + Self + } +} + +impl StrategyDriver for IngestionDriver { + type Output = Vec; + + fn stages(&self) -> Vec { + vec![ + Box::new(EmbedStage), + Box::new(CollectCandidatesStage), + Box::new(GraphExpansionStage), + // Skip ChunkAttachStage + Box::new(RerankStage), + Box::new(AssembleEntitiesStage), + ] + } + + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + Ok(ctx.take_entity_results()) } }