diff --git a/Cargo.lock b/Cargo.lock index 2fdb359..bb68cf7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5426,14 +5426,10 @@ dependencies = [ "anyhow", "async-openai", "async-trait", - "axum", "common", "fastembed", - "futures", "serde", "serde_json", - "surrealdb", - "thiserror 1.0.69", "tokio", "tracing", "uuid", diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 1d4c830..3a59e95 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -164,6 +164,35 @@ impl KnowledgeEntity { .take(0)?) } + /// Fetch all knowledge entities owned by any of the provided source ids for a user. + /// + /// Used by retrieval to resolve the entities that own a set of retrieved chunks. + pub async fn find_by_source_ids( + db: &SurrealDbClient, + source_ids: &[String], + user_id: &str, + ) -> Result, AppError> { + if source_ids.is_empty() { + return Ok(Vec::new()); + } + + let entities: Vec = db + .client + .query( + "SELECT * FROM type::table($table) \ + WHERE source_id IN $sources AND user_id = $user_id", + ) + .bind(("table", Self::table_name())) + .bind(("sources", source_ids.to_vec())) + .bind(("user_id", user_id.to_owned())) + .await + .map_err(AppError::Database)? + .take(0) + .map_err(AppError::Database)?; + + Ok(entities) + } + pub async fn delete_by_source_id( source_id: &str, db_client: &SurrealDbClient, diff --git a/common/src/utils/config.rs b/common/src/utils/config.rs index 920eb32..ed306f4 100644 --- a/common/src/utils/config.rs +++ b/common/src/utils/config.rs @@ -1,8 +1,7 @@ use config::{Config, ConfigError, Environment, File}; -use serde::{Deserialize, Deserializer, Serialize}; -use std::{env, fmt, str::FromStr, sync::Once}; +use serde::{Deserialize, Serialize}; +use std::{env, str::FromStr, sync::Once}; use thiserror::Error; -use tracing::warn; /// Error returned when parsing an embedding backend name. #[derive(Debug, Error, PartialEq, Eq)] @@ -36,83 +35,6 @@ impl EmbeddingBackend { } } -/// Error returned when parsing a retrieval strategy name. -#[derive(Debug, Error, PartialEq, Eq)] -#[error("unknown retrieval strategy '{input}'")] -pub struct ParseRetrievalStrategyError { - /// The unrecognized input string. - pub input: String, -} - -/// Selects which retrieval pipeline strategy to run for chat and search. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] -#[serde(rename_all = "snake_case")] -pub enum RetrievalStrategy { - /// Primary hybrid chunk retrieval for search/chat. - #[default] - Default, - /// Entity retrieval for suggesting relationships when creating manual entities. - RelationshipSuggestion, - /// Entity retrieval for context during content ingestion. - Ingestion, - /// Unified search returning both chunks and entities. - Search, -} - -impl RetrievalStrategy { - #[must_use] - pub fn as_str(self) -> &'static str { - match self { - Self::Default => "default", - Self::RelationshipSuggestion => "relationship_suggestion", - Self::Ingestion => "ingestion", - Self::Search => "search", - } - } -} - -impl FromStr for RetrievalStrategy { - type Err = ParseRetrievalStrategyError; - - fn from_str(value: &str) -> Result { - match value.to_ascii_lowercase().as_str() { - "default" => Ok(Self::Default), - "initial" | "revised" => { - warn!( - "retrieval strategy '{value}' is deprecated; use 'default' instead" - ); - Ok(Self::Default) - } - "relationship_suggestion" => Ok(Self::RelationshipSuggestion), - "ingestion" => Ok(Self::Ingestion), - "search" => Ok(Self::Search), - other => Err(ParseRetrievalStrategyError { - input: other.to_string(), - }), - } - } -} - -impl fmt::Display for RetrievalStrategy { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.as_str()) - } -} - -fn deserialize_optional_retrieval_strategy<'de, D>( - deserializer: D, -) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - let value = Option::::deserialize(deserializer)?; - match value { - None => Ok(None), - Some(raw) if raw.trim().is_empty() => Ok(None), - Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom), - } -} - impl FromStr for EmbeddingBackend { type Err = ParseEmbeddingBackendError; @@ -195,8 +117,6 @@ pub struct AppConfig { pub fastembed_show_download_progress: Option, #[serde(default)] pub fastembed_max_length: Option, - #[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")] - pub retrieval_strategy: Option, #[serde(default)] pub embedding_backend: EmbeddingBackend, #[serde(default = "default_ingest_max_body_bytes")] @@ -282,14 +202,6 @@ pub fn ensure_ort_path() { }); } -impl AppConfig { - /// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset. - #[must_use] - pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy { - self.retrieval_strategy.unwrap_or_default() - } -} - impl Default for AppConfig { fn default() -> Self { Self { @@ -312,7 +224,6 @@ impl Default for AppConfig { fastembed_cache_dir: None, fastembed_show_download_progress: None, fastembed_max_length: None, - retrieval_strategy: None, embedding_backend: EmbeddingBackend::default(), ingest_max_body_bytes: default_ingest_max_body_bytes(), ingest_max_files: default_ingest_max_files(), @@ -340,41 +251,22 @@ pub fn get_config() -> Result { mod tests { #![allow(clippy::expect_used)] - use super::{ParseRetrievalStrategyError, RetrievalStrategy}; + use super::EmbeddingBackend; + #[test] - fn retrieval_strategy_defaults_to_default() { - assert_eq!( - RetrievalStrategy::default(), - RetrievalStrategy::Default - ); + fn embedding_backend_defaults_to_fastembed() { + assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed); } #[test] - fn retrieval_strategy_serializes_snake_case() { + fn embedding_backend_parses_aliases() { assert_eq!( - serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"), - "\"search\"" + "openai".parse::().expect("openai"), + EmbeddingBackend::OpenAI ); - } - - #[test] - fn retrieval_strategy_from_str_accepts_deprecated_aliases() { assert_eq!( - "initial".parse::().expect("initial"), - RetrievalStrategy::Default - ); - assert!(matches!( - "unknown".parse::(), - Err(ParseRetrievalStrategyError { .. }) - )); - } - - #[test] - fn app_config_resolved_retrieval_strategy_uses_default_when_unset() { - let config = super::AppConfig::default(); - assert_eq!( - config.resolved_retrieval_strategy(), - RetrievalStrategy::Default + "fast".parse::().expect("fast"), + EmbeddingBackend::FastEmbed ); } } diff --git a/docs/configuration.md b/docs/configuration.md index 2018469..559059b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -24,7 +24,6 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir | `RUST_LOG` | Logging level | `info` | | `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` | | `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` | -| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - | | `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` | | `FASTEMBED_CACHE_DIR` | Model cache directory | `/fastembed` | | `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` | diff --git a/docs/features.md b/docs/features.md index e113e18..9af4163 100644 --- a/docs/features.md +++ b/docs/features.md @@ -27,13 +27,14 @@ The D3-based graph visualization shows entities as nodes and relationships as ed ## Hybrid Retrieval -Minne combines multiple retrieval strategies: +Minne uses chunk-first hybrid retrieval over the knowledge base: -- **Vector similarity** — Semantic matching via embeddings -- **Full-text search** — Keyword matching with BM25 -- **Graph traversal** — Following relationships between entities +- **Vector similarity** — Semantic matching via embeddings over text chunks +- **Full-text search** — Keyword matching with BM25 over the same chunk index -Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance. +The two ranked candidate lists are merged with Reciprocal Rank Fusion (RRF). When a caller needs knowledge entities (search, ingestion linking, relationship suggestion), entities are derived from the top retrieved chunks grouped by `source_id`. + +Optional **reranking** can rescore the fused chunk list with a cross-encoder model; see below. ## Reranking (Optional) diff --git a/evaluations/src/args.rs b/evaluations/src/args.rs index ead1cfa..1600305 100644 --- a/evaluations/src/args.rs +++ b/evaluations/src/args.rs @@ -5,7 +5,6 @@ use std::{ use anyhow::{anyhow, Context, Result}; use clap::{Args, Parser, ValueEnum}; -use retrieval_pipeline::RetrievalStrategy; use crate::datasets::DatasetKind; @@ -55,10 +54,6 @@ pub struct RetrievalSettings { #[arg(long)] pub chunk_fts_take: Option, - /// Override average characters per token used for budgeting - #[arg(long)] - pub chunk_avg_chars_per_token: Option, - /// Override maximum chunks attached per entity #[arg(long)] pub max_chunks_per_entity: Option, @@ -71,41 +66,37 @@ pub struct RetrievalSettings { #[arg(long, default_value_t = 4)] pub rerank_pool_size: usize, - /// Keep top-N entities after reranking + /// Keep top-N chunks after reranking #[arg(long, default_value_t = 10)] pub rerank_keep_top: usize, - /// Cap the number of chunks returned by retrieval (revised strategy) + /// Cap the number of chunks returned by retrieval #[arg(long, default_value_t = 5)] pub chunk_result_cap: usize, - /// Reciprocal rank fusion k value for revised chunk merging + /// Reciprocal rank fusion k value for chunk merging #[arg(long)] pub chunk_rrf_k: Option, - /// Weight for vector ranks in revised RRF + /// Weight for vector ranks in RRF #[arg(long)] pub chunk_rrf_vector_weight: Option, - /// Weight for chunk FTS ranks in revised RRF + /// Weight for chunk FTS ranks in RRF #[arg(long)] pub chunk_rrf_fts_weight: Option, - /// Include vector ranks in revised RRF (default: true) + /// Include vector ranks in RRF (default: true) #[arg(long)] pub chunk_rrf_use_vector: Option, - /// Include chunk FTS ranks in revised RRF (default: true) + /// Include chunk FTS ranks in RRF (default: true) #[arg(long)] pub chunk_rrf_use_fts: Option, /// Require verified chunks (disable with --llm-mode) #[arg(skip = true)] pub require_verified_chunks: bool, - - /// Select the retrieval pipeline strategy - #[arg(long, default_value_t = RetrievalStrategy::Default)] - pub strategy: RetrievalStrategy, } impl Default for RetrievalSettings { @@ -113,7 +104,6 @@ impl Default for RetrievalSettings { Self { chunk_vector_take: None, chunk_fts_take: None, - chunk_avg_chars_per_token: None, max_chunks_per_entity: None, rerank: false, rerank_pool_size: 4, @@ -125,7 +115,6 @@ impl Default for RetrievalSettings { chunk_rrf_use_vector: None, chunk_rrf_use_fts: None, require_verified_chunks: true, - strategy: RetrievalStrategy::Default, } } } diff --git a/evaluations/src/perf.rs b/evaluations/src/perf.rs index cc38935..dc29036 100644 --- a/evaluations/src/perf.rs +++ b/evaluations/src/perf.rs @@ -51,8 +51,8 @@ pub fn mirror_perf_outputs( pub fn print_console_summary(record: &EvaluationReport) { let perf = &record.performance; println!( - "[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})", - record.retrieval.strategy, + "[perf] resolve_entities={} | concurrency={} | rerank={} (pool {:?}, keep {})", + record.retrieval.resolve_entities, record.retrieval.concurrency, record.retrieval.rerank_enabled, record.retrieval.rerank_pool_size, @@ -63,16 +63,14 @@ pub fn print_console_summary(record: &EvaluationReport) { perf.ingestion_ms, format_duration(perf.namespace_seed_ms), ); - let stage = &perf.stage_latency; - println!( - "[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}", - stage.embed.avg, - stage.collect_candidates.avg, - stage.graph_expansion.avg, - stage.chunk_attach.avg, - stage.rerank.avg, - stage.assemble.avg, - ); + let stage_summary = perf + .stage_latency + .stages + .iter() + .map(|s| format!("{} {:.1}", s.stage, s.stats.avg)) + .collect::>() + .join(" | "); + println!("[perf] stage avg ms → {stage_summary}"); let eval = &perf.evaluation_stages_ms; println!( "[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}", @@ -107,12 +105,13 @@ mod tests { fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown { crate::eval::StageLatencyBreakdown { - embed: sample_latency(), - collect_candidates: sample_latency(), - graph_expansion: sample_latency(), - chunk_attach: sample_latency(), - rerank: sample_latency(), - assemble: sample_latency(), + stages: ["embed", "search", "rerank", "resolve_entities", "assemble"] + .into_iter() + .map(|stage| crate::eval::StageLatency { + stage: stage.to_string(), + stats: sample_latency(), + }) + .collect(), } } @@ -193,7 +192,7 @@ mod tests { rerank_keep_top: 10, concurrency: 2, detailed_report: false, - retrieval_strategy: "initial".into(), + resolve_entities: false, chunk_result_cap: 5, chunk_rrf_k: 60.0, chunk_rrf_vector_weight: 1.0, @@ -206,7 +205,6 @@ mod tests { ingest_chunk_overlap_tokens: 50, chunk_vector_take: 20, chunk_fts_take: 20, - chunk_avg_chars_per_token: 4, max_chunks_per_entity: 4, cases: Vec::new(), } diff --git a/evaluations/src/pipeline/stages/run_queries.rs b/evaluations/src/pipeline/stages/run_queries.rs index 079b1b9..e97587b 100644 --- a/evaluations/src/pipeline/stages/run_queries.rs +++ b/evaluations/src/pipeline/stages/run_queries.rs @@ -6,7 +6,7 @@ use futures::stream::{self, StreamExt}; use tracing::{debug, info}; use crate::eval::{ - adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics, + adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics, CaseSummary, RetrievedSummary, }; use retrieval_pipeline::{ @@ -51,14 +51,8 @@ pub(crate) async fn run_queries( None }; - let mut retrieval_config = RetrievalConfig { - strategy: config.retrieval.strategy, - ..Default::default() - }; + let mut retrieval_config = RetrievalConfig::default(); retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top; - if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top { - retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top; - } retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1); if let Some(value) = config.retrieval.chunk_vector_take { retrieval_config.tuning.chunk_vector_take = value; @@ -81,9 +75,6 @@ pub(crate) async fn run_queries( if let Some(value) = config.retrieval.chunk_rrf_use_fts { retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into(); } - if let Some(value) = config.retrieval.chunk_avg_chars_per_token { - retrieval_config.tuning.avg_chars_per_token = value; - } if let Some(value) = config.retrieval.max_chunks_per_entity { retrieval_config.tuning.max_chunks_per_entity = value; } @@ -187,7 +178,7 @@ pub(crate) async fn run_queries( None => None, }; - let params = pipeline::StrategyParams { + let params = pipeline::RetrievalParams { db_client: &db, openai_client: &openai_client, embedding_provider: Some(&embedding_provider), @@ -196,26 +187,19 @@ pub(crate) async fn run_queries( config: (*retrieval_config).clone(), reranker, }; - let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled { - let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics( + let (result_output, pipeline_diagnostics, stage_timings) = { + let outcome = pipeline::run_with_embedding_instrumented( params, query_embedding, + diagnostics_enabled, ) .await .with_context(|| format!("running pipeline for question {question_id}"))?; (outcome.results, outcome.diagnostics, outcome.stage_timings) - } else { - let outcome = pipeline::run_pipeline_with_embedding_with_metrics( - params, - query_embedding, - ) - .await - .with_context(|| format!("running pipeline for question {question_id}"))?; - (outcome.results, None, outcome.stage_timings) }; let query_latency = query_start.elapsed().as_millis(); - let candidates = adapt_strategy_output(result_output); + let candidates = adapt_retrieval_output(result_output); let mut retrieved = Vec::new(); let mut match_rank = None; let answers_lower: Vec = diff --git a/evaluations/src/pipeline/stages/summarize.rs b/evaluations/src/pipeline/stages/summarize.rs index c7391d4..ac187ec 100644 --- a/evaluations/src/pipeline/stages/summarize.rs +++ b/evaluations/src/pipeline/stages/summarize.rs @@ -201,7 +201,10 @@ pub(crate) async fn summarize( rerank_keep_top: config.retrieval.rerank_keep_top, concurrency: config.concurrency.max(1), detailed_report: config.detailed_report, - retrieval_strategy: config.retrieval.strategy.to_string(), + resolve_entities: ctx + .retrieval_config + .as_ref() + .is_some_and(|config| config.resolve_entities), chunk_result_cap: config.retrieval.chunk_result_cap, chunk_rrf_k: active_tuning.chunk_rrf_k, chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight, @@ -214,7 +217,6 @@ pub(crate) async fn summarize( ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens, chunk_vector_take: active_tuning.chunk_vector_take, chunk_fts_take: active_tuning.chunk_fts_take, - chunk_avg_chars_per_token: active_tuning.avg_chars_per_token, max_chunks_per_entity: active_tuning.max_chunks_per_entity, cases: summaries, }); diff --git a/evaluations/src/report.rs b/evaluations/src/report.rs index d4608f3..dd74b78 100644 --- a/evaluations/src/report.rs +++ b/evaluations/src/report.rs @@ -85,7 +85,7 @@ pub struct RetrievalSection { pub average_ndcg: f64, pub latency: LatencyStats, pub concurrency: usize, - pub strategy: String, + pub resolve_entities: bool, pub rerank_enabled: bool, pub rerank_pool_size: Option, pub rerank_keep_top: usize, @@ -226,7 +226,7 @@ impl EvaluationReport { average_ndcg: summary.average_ndcg, latency: summary.latency_ms.clone(), concurrency: summary.concurrency, - strategy: summary.retrieval_strategy.clone(), + resolve_entities: summary.resolve_entities, rerank_enabled: summary.rerank_enabled, rerank_pool_size: summary.rerank_pool_size, rerank_keep_top: summary.rerank_keep_top, @@ -463,7 +463,12 @@ fn render_markdown(report: &EvaluationReport) -> String { write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap(); write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap(); write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap(); - write!(md, "| Strategy | `{}` |\\n", report.retrieval.strategy).unwrap(); + write!( + md, + "| Resolve entities | {} |\\n", + bool_badge(report.retrieval.resolve_entities) + ) + .unwrap(); write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap(); if report.retrieval.rerank_enabled { let pool = report @@ -510,28 +515,9 @@ fn render_markdown(report: &EvaluationReport) -> String { md.push_str("\\n## Retrieval Stage Timings\\n\\n"); md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n"); - write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed); - write_stage_row( - &mut md, - "Collect Candidates", - &report.performance.stage_latency.collect_candidates, - ); - write_stage_row( - &mut md, - "Graph Expansion", - &report.performance.stage_latency.graph_expansion, - ); - write_stage_row( - &mut md, - "Chunk Attach", - &report.performance.stage_latency.chunk_attach, - ); - write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank); - write_stage_row( - &mut md, - "Assemble", - &report.performance.stage_latency.assemble, - ); + for stage in &report.performance.stage_latency.stages { + write_stage_row(&mut md, &prettify_stage(&stage.stage), &stage.stats); + } if report.misses.is_empty() { if report.detailed_report { @@ -623,6 +609,20 @@ fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) { .unwrap(); } +/// Turn a stable stage label (e.g. `resolve_entities`) into a display title (`Resolve Entities`). +fn prettify_stage(label: &str) -> String { + label + .split('_') + .map(|word| { + let mut chars = word.chars(); + chars.next().map_or_else(String::new, |first| { + first.to_uppercase().collect::() + chars.as_str() + }) + }) + .collect::>() + .join(" ") +} + fn bool_badge(value: bool) -> &'static str { if value { "✅" @@ -740,17 +740,6 @@ struct LegacyHistoryDelta { latency_avg_ms: f64, } -fn default_stage_latency() -> StageLatencyBreakdown { - StageLatencyBreakdown { - embed: LatencyStats::default(), - collect_candidates: LatencyStats::default(), - graph_expansion: LatencyStats::default(), - chunk_attach: LatencyStats::default(), - rerank: LatencyStats::default(), - assemble: LatencyStats::default(), - } -} - #[allow(clippy::too_many_lines)] fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport { let overview = OverviewSection { @@ -807,7 +796,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport { average_ndcg: entry.average_ndcg, latency: entry.latency_ms, concurrency: 0, - strategy: "unknown".into(), + resolve_entities: false, rerank_enabled: entry.rerank_enabled, rerank_pool_size: entry.rerank_pool_size, rerank_keep_top: entry.rerank_keep_top, @@ -840,7 +829,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport { ingestion_ms: entry.ingestion_ms, namespace_seed_ms: entry.namespace_seed_ms, evaluation_stages_ms: EvaluationStageTimings::default(), - stage_latency: default_stage_latency(), + stage_latency: StageLatencyBreakdown::default(), namespace_reused: false, ingestion_reused: entry.ingestion_reused, embeddings_reused: entry.ingestion_embeddings_reused, @@ -915,7 +904,8 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result StageLatencyBreakdown { StageLatencyBreakdown { - embed: latency(9.0), - collect_candidates: latency(10.0), - graph_expansion: latency(11.0), - chunk_attach: latency(12.0), - rerank: latency(13.0), - assemble: latency(14.0), + stages: vec![ + StageLatency { + stage: "embed".to_string(), + stats: latency(9.0), + }, + StageLatency { + stage: "search".to_string(), + stats: latency(10.0), + }, + StageLatency { + stage: "rerank".to_string(), + stats: latency(13.0), + }, + StageLatency { + stage: "resolve_entities".to_string(), + stats: latency(11.0), + }, + StageLatency { + stage: "assemble".to_string(), + stats: latency(14.0), + }, + ], } } @@ -1058,7 +1064,7 @@ mod tests { rerank_keep_top: 5, concurrency: 2, detailed_report: true, - retrieval_strategy: "initial".into(), + resolve_entities: false, chunk_result_cap: 5, chunk_rrf_k: 60.0, chunk_rrf_vector_weight: 1.0, @@ -1071,7 +1077,6 @@ mod tests { ingest_chunks_only: false, chunk_vector_take: 50, chunk_fts_take: 50, - chunk_avg_chars_per_token: 4, max_chunks_per_entity: 4, cases, } @@ -1097,7 +1102,7 @@ mod tests { #[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)] #[test] - fn evaluations_history_captures_strategy_and_concurrency() { + fn evaluations_history_captures_resolve_entities_and_concurrency() { let tmp = tempdir().unwrap(); let summary = sample_summary(false); @@ -1109,7 +1114,7 @@ mod tests { assert_eq!(entries.len(), 1); let stored = &entries[0]; assert_eq!(stored.retrieval.concurrency, summary.concurrency); - assert_eq!(stored.retrieval.strategy, summary.retrieval_strategy); + assert_eq!(stored.retrieval.resolve_entities, summary.resolve_entities); assert_eq!( stored.performance.evaluation_stages_ms.run_queries_ms, summary.perf.evaluation_stage_ms.run_queries_ms diff --git a/evaluations/src/types.rs b/evaluations/src/types.rs index 971a990..5102ff0 100644 --- a/evaluations/src/types.rs +++ b/evaluations/src/types.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use chrono::{DateTime, Utc}; use common::storage::types::StoredObject; use retrieval_pipeline::{ - PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput, + Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings, }; use serde::{Deserialize, Serialize}; use unicode_normalization::UnicodeNormalization; @@ -69,7 +69,7 @@ pub struct EvaluationSummary { pub rerank_keep_top: usize, pub concurrency: usize, pub detailed_report: bool, - pub retrieval_strategy: String, + pub resolve_entities: bool, pub chunk_result_cap: usize, pub chunk_rrf_k: f32, pub chunk_rrf_vector_weight: f32, @@ -82,7 +82,6 @@ pub struct EvaluationSummary { pub ingest_chunk_overlap_tokens: usize, pub chunk_vector_take: usize, pub chunk_fts_take: usize, - pub chunk_avg_chars_per_token: usize, pub max_chunks_per_entity: usize, pub cases: Vec, } @@ -129,14 +128,20 @@ impl Default for LatencyStats { } } +/// Latency statistics for a single retrieval stage, keyed by the stage's stable label. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StageLatency { + pub stage: String, + pub stats: LatencyStats, +} + +/// Per-stage retrieval latency, in canonical pipeline order. +/// +/// The set of stages is driven entirely by [`StageKind::ALL`], so adding a retrieval stage +/// surfaces here automatically without changes to the evaluation harness. #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StageLatencyBreakdown { - pub embed: LatencyStats, - pub collect_candidates: LatencyStats, - pub graph_expansion: LatencyStats, - pub chunk_attach: LatencyStats, - pub rerank: LatencyStats, - pub assemble: LatencyStats, + pub stages: Vec, } #[allow(clippy::struct_field_names)] @@ -232,13 +237,12 @@ fn candidates_from_chunks(chunks: Vec) -> Vec Vec { +pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec { match output { - StrategyOutput::Entities(entities) => candidates_from_entities(entities), - StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks), - StrategyOutput::Search(search_result) => { - let mut candidates = candidates_from_entities(search_result.entities); - candidates.extend(candidates_from_chunks(search_result.chunks)); + RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks), + RetrievalOutput::WithEntities { chunks, entities } => { + let mut candidates = candidates_from_entities(entities); + candidates.extend(candidates_from_chunks(chunks)); candidates.sort_by(|a, b| b.score.total_cmp(&a.score)); candidates } @@ -262,7 +266,7 @@ pub struct CaseDiagnostics { pub attached_chunk_ids: Vec, pub retrieved: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub pipeline: Option, + pub pipeline: Option, } #[derive(Debug, Serialize)] @@ -366,28 +370,19 @@ pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats { LatencyStats { avg, p50, p95 } } -pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown { - fn collect_stage(samples: &[PipelineStageTimings], selector: F) -> Vec - where - F: Fn(&PipelineStageTimings) -> u128, - { - samples.iter().map(selector).collect() - } +pub fn build_stage_latency_breakdown(samples: &[StageTimings]) -> StageLatencyBreakdown { + let stages = StageKind::ALL + .iter() + .map(|kind| { + let latencies: Vec = samples.iter().map(|s| s.stage_ms(*kind)).collect(); + StageLatency { + stage: kind.label().to_string(), + stats: compute_latency_stats(&latencies), + } + }) + .collect(); - StageLatencyBreakdown { - embed: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::embed_ms)), - collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| { - entry.collect_candidates_ms() - })), - graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| { - entry.graph_expansion_ms() - })), - chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| { - entry.chunk_attach_ms() - })), - rerank: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::rerank_ms)), - assemble: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::assemble_ms)), - } + StageLatencyBreakdown { stages } } #[allow( @@ -412,7 +407,7 @@ pub fn build_case_diagnostics( expected_chunk_ids: &[String], answers_lower: &[String], candidates: &[EvaluationCandidate], - pipeline_stats: Option, + pipeline_stats: Option, ) -> CaseDiagnostics { let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect(); let mut seen_chunks: HashSet = HashSet::new(); diff --git a/html-router/assets/style.css b/html-router/assets/style.css index 3b94d07..3431eb7 100644 --- a/html-router/assets/style.css +++ b/html-router/assets/style.css @@ -44,7 +44,6 @@ --leading-snug: 1.375; --leading-relaxed: 1.625; --ease-out: cubic-bezier(0, 0, 0.2, 1); - --ease-in-out: cubic-bezier(0.4, 0, 0.2, 1); --animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite; --default-transition-duration: 150ms; --default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); @@ -285,37 +284,6 @@ } } } - .drawer-open { - > .drawer-side { - overflow-y: auto; - } - > .drawer-toggle { - display: none; - & ~ .drawer-side { - pointer-events: auto; - visibility: visible; - position: sticky; - display: block; - width: auto; - overscroll-behavior: auto; - opacity: 100%; - & > .drawer-overlay { - cursor: default; - background-color: transparent; - } - & > *:not(.drawer-overlay) { - translate: 0%; - [dir="rtl"] & { - translate: 0%; - } - } - } - &:checked ~ .drawer-side { - pointer-events: auto; - visibility: visible; - } - } - } .drawer-toggle { position: fixed; height: calc(0.25rem * 0); @@ -1074,22 +1042,6 @@ grid-row-start: 1; min-width: calc(0.25rem * 0); } - .chat-image { - grid-row: span 2 / span 2; - align-self: flex-end; - } - .chat-footer { - grid-row-start: 3; - display: flex; - gap: calc(0.25rem * 1); - font-size: 0.6875rem; - } - .chat-header { - grid-row-start: 1; - display: flex; - gap: calc(0.25rem * 1); - font-size: 0.6875rem; - } .container { width: 100%; @media (width >= 40rem) { @@ -1796,9 +1748,6 @@ .w-10 { width: calc(var(--spacing) * 10); } - .w-11 { - width: calc(var(--spacing) * 11); - } .w-11\/12 { width: calc(11/12 * 100%); } @@ -1862,9 +1811,6 @@ .flex-none { flex: none; } - .flex-shrink { - flex-shrink: 1; - } .flex-shrink-0 { flex-shrink: 0; } @@ -1877,13 +1823,6 @@ .grow { flex-grow: 1; } - .border-collapse { - border-collapse: collapse; - } - .-translate-y-1 { - --tw-translate-y: calc(var(--spacing) * -1); - translate: var(--tw-translate-x) var(--tw-translate-y); - } .-translate-y-1\/2 { --tw-translate-y: calc(calc(1/2 * 100%) * -1); translate: var(--tw-translate-x) var(--tw-translate-y); @@ -1956,9 +1895,6 @@ .justify-start { justify-content: flex-start; } - .gap-0 { - gap: calc(var(--spacing) * 0); - } .gap-0\.5 { gap: calc(var(--spacing) * 0.5); } @@ -2115,9 +2051,6 @@ .bg-transparent { background-color: transparent; } - .bg-warning { - background-color: var(--color-warning); - } .bg-warning\/10 { background-color: var(--color-warning); @supports (color: color-mix(in lab, red, red)) { @@ -2136,9 +2069,6 @@ .loading-spinner { mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E"); } - .mask-repeat { - mask-repeat: repeat; - } .fill-current { fill: currentcolor; } @@ -2169,9 +2099,6 @@ .p-8 { padding: calc(var(--spacing) * 8); } - .px-1 { - padding-inline: calc(var(--spacing) * 1); - } .px-1\.5 { padding-inline: calc(var(--spacing) * 1.5); } @@ -2326,9 +2253,6 @@ --tw-tracking: var(--tracking-widest); letter-spacing: var(--tracking-widest); } - .text-wrap { - text-wrap: wrap; - } .break-words { overflow-wrap: break-word; } @@ -2395,17 +2319,6 @@ .italic { font-style: italic; } - .underline { - text-decoration-line: underline; - } - .swap-active { - .swap-off { - opacity: 0%; - } - .swap-on { - opacity: 100%; - } - } .opacity-0 { opacity: 0%; } @@ -2496,10 +2409,6 @@ --tw-duration: 300ms; transition-duration: 300ms; } - .ease-in-out { - --tw-ease: var(--ease-in-out); - transition-timing-function: var(--ease-in-out); - } .ease-out { --tw-ease: var(--ease-out); transition-timing-function: var(--ease-out); diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index fdc8e13..ce8ffb2 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -2,10 +2,7 @@ use common::storage::types::conversation::SidebarConversation; use common::storage::{db::SurrealDbClient, store::StorageManager}; use common::utils::embedding::EmbeddingProvider; use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine}; -use common::{ - create_template_engine, storage::db::ProvidesDb, - utils::config::{AppConfig, RetrievalStrategy}, -}; +use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig}; use retrieval_pipeline::reranking::RerankerPool; use std::collections::HashMap; use std::sync::{ @@ -75,10 +72,6 @@ impl HtmlState { } } - pub fn retrieval_strategy(&self) -> RetrievalStrategy { - self.config.resolved_retrieval_strategy() - } - pub async fn get_cached_conversation_archive( &self, user_id: &str, diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index 99f04d9..44cd413 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -16,12 +16,9 @@ use futures::{ }; use json_stream_parser::JsonStreamParser; use minijinja::Value; -use retrieval_pipeline::{ - answer_retrieval::{ - chunks_to_chat_context, create_chat_request, create_user_message_with_history, - LLMResponseFormat, - }, - retrieved_entities_to_json, +use retrieval_pipeline::answer_retrieval::{ + chunks_to_chat_context, create_chat_request, create_user_message_with_history, + LLMResponseFormat, }; use serde::{Deserialize, Serialize}; use serde_json::from_str; @@ -189,11 +186,7 @@ struct ReferenceData { } fn extract_reference_strings(response: &LLMResponseFormat) -> Vec { - response - .references - .iter() - .map(|reference| reference.reference.clone()) - .collect() + response.reference_ids() } #[allow(clippy::too_many_lines)] @@ -362,10 +355,9 @@ async fn prepare_chat_request( None => None, }; - let strategy = state.retrieval_strategy(); - let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy); + let config = retrieval_pipeline::RetrievalConfig::default(); - let retrieval_result = match retrieval_pipeline::retrieve_entities( + let retrieval_result = match retrieval_pipeline::retrieve( &state.db, &state.openai_client, Some(&*state.embedding_provider), @@ -387,12 +379,9 @@ async fn prepare_chat_request( let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result); let context_json = match retrieval_result { - retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks), - retrieval_pipeline::StrategyOutput::Entities(entities) => { - retrieved_entities_to_json(&entities) - } - retrieval_pipeline::StrategyOutput::Search(search_result) => { - chunks_to_chat_context(&search_result.chunks) + retrieval_pipeline::RetrievalOutput::Chunks(chunks) => chunks_to_chat_context(&chunks), + retrieval_pipeline::RetrievalOutput::WithEntities { chunks, .. } => { + chunks_to_chat_context(&chunks) } }; let formatted_user_message = diff --git a/html-router/src/routes/chat/reference_validation.rs b/html-router/src/routes/chat/reference_validation.rs index c88fabd..f69456e 100644 --- a/html-router/src/routes/chat/reference_validation.rs +++ b/html-router/src/routes/chat/reference_validation.rs @@ -9,7 +9,7 @@ use common::{ types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, }, }; -use retrieval_pipeline::StrategyOutput; +use retrieval_pipeline::RetrievalOutput; use uuid::Uuid; pub(crate) const MAX_REFERENCE_COUNT: usize = 10; @@ -86,40 +86,29 @@ pub(crate) enum ReferenceLookupTarget { } pub(crate) fn collect_reference_ids_from_retrieval( - retrieval_result: &StrategyOutput, + retrieval_result: &RetrievalOutput, ) -> Vec { let mut ids = Vec::new(); let mut seen = HashSet::new(); + let mut push_id = |id: String| { + if seen.insert(id.clone()) { + ids.push(id); + } + }; + match retrieval_result { - StrategyOutput::Chunks(chunks) => { + RetrievalOutput::Chunks(chunks) => { for chunk in chunks { - let id = chunk.chunk.id.clone(); - if seen.insert(id.clone()) { - ids.push(id); - } + push_id(chunk.chunk.id.clone()); } } - StrategyOutput::Entities(entities) => { + RetrievalOutput::WithEntities { chunks, entities } => { + for chunk in chunks { + push_id(chunk.chunk.id.clone()); + } for entity in entities { - let id = entity.entity.id.clone(); - if seen.insert(id.clone()) { - ids.push(id); - } - } - } - StrategyOutput::Search(search) => { - for chunk in &search.chunks { - let id = chunk.chunk.id.clone(); - if seen.insert(id.clone()) { - ids.push(id); - } - } - for entity in &search.entities { - let id = entity.entity.id.clone(); - if seen.insert(id.clone()) { - ids.push(id); - } + push_id(entity.entity.id.clone()); } } } diff --git a/html-router/src/routes/index/handlers.rs b/html-router/src/routes/index/handlers.rs index 2ec9b2b..0bca199 100644 --- a/html-router/src/routes/index/handlers.rs +++ b/html-router/src/routes/index/handlers.rs @@ -13,7 +13,7 @@ use crate::{ middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult, + template_as_response, TemplateResponse, TemplateResult, ResponseResult, }, }, utils::text_content_preview::truncate_text_contents, diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index d288249..30707c7 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -32,7 +32,7 @@ use crate::{ middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult, + template_with_headers, TemplateResponse, TemplateResult, ResponseResult, }, }, utils::pagination::{paginate_items, Pagination}, @@ -284,9 +284,9 @@ pub async fn suggest_knowledge_relationships( None => None, }; - let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion(); - if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = - retrieval_pipeline::retrieve_entities( + let config = retrieval_pipeline::RetrievalConfig::with_entities(); + if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) = + retrieval_pipeline::retrieve( &state.db, &state.openai_client, Some(&*state.embedding_provider), @@ -297,7 +297,7 @@ pub async fn suggest_knowledge_relationships( ) .await { - for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results { + for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities { if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS { break; } diff --git a/html-router/src/routes/scratchpad/handlers.rs b/html-router/src/routes/scratchpad/handlers.rs index 991f248..d556570 100644 --- a/html-router/src/routes/scratchpad/handlers.rs +++ b/html-router/src/routes/scratchpad/handlers.rs @@ -12,7 +12,7 @@ use crate::html_state::HtmlState; use crate::middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult, + template_with_headers, TemplateResponse, TemplateResult, ResponseResult, }, }; use common::storage::types::{ diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs index 8716936..826060f 100644 --- a/html-router/src/routes/search/handlers.rs +++ b/html-router/src/routes/search/handlers.rs @@ -4,7 +4,7 @@ use axum::{ extract::{Query, State}, }; use common::storage::types::{text_content::TextContent, user::User}; -use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput}; +use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity}; use serde::{de, Deserialize, Deserializer, Serialize}; use std::{fmt, str::FromStr}; @@ -108,35 +108,35 @@ async fn perform_search( return Ok((Vec::new(), String::new())); } - let config = RetrievalConfig::for_search(SearchTarget::Both); + let config = RetrievalConfig::with_entities(); let reranker_lease = match &state.reranker_pool { Some(pool) => pool.checkout().await, None => None, }; - let params = retrieval_pipeline::pipeline::StrategyParams { - db_client: &state.db, - openai_client: &state.openai_client, - embedding_provider: Some(&state.embedding_provider), - input_text: trimmed_query, - user_id: &user.id, + let result = retrieve( + &state.db, + &state.openai_client, + Some(&state.embedding_provider), + trimmed_query, + &user.id, config, - reranker: reranker_lease, - }; - let result = retrieval_pipeline::pipeline::execute(params).await?; + reranker_lease, + ) + .await?; - let search_result = match result { - StrategyOutput::Search(sr) => sr, - _ => SearchResult::new(vec![], vec![]), + let (chunks, entities) = match result { + RetrievalOutput::WithEntities { chunks, entities } => (chunks, entities), + RetrievalOutput::Chunks(chunks) => (chunks, Vec::new()), }; - let source_label_map = collect_source_label_map(state, user, &search_result).await?; + let source_label_map = collect_source_label_map(state, user, &chunks, &entities).await?; let mut combined_results: Vec = - Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len())); + Vec::with_capacity(chunks.len().saturating_add(entities.len())); - for chunk_result in search_result.chunks { + for chunk_result in chunks { let source_label = source_label_map .get(&chunk_result.chunk.source_id) .cloned() @@ -155,7 +155,7 @@ async fn perform_search( }); } - for entity_result in search_result.entities { + for entity_result in entities { let source_label = source_label_map .get(&entity_result.entity.source_id) .cloned() @@ -187,13 +187,14 @@ async fn perform_search( async fn collect_source_label_map( state: &HtmlState, user: &User, - search_result: &SearchResult, + chunks: &[RetrievedChunk], + entities: &[RetrievedEntity], ) -> Result, HtmlError> { let mut source_ids = HashSet::new(); - for chunk_result in &search_result.chunks { + for chunk_result in chunks { source_ids.insert(chunk_result.chunk.source_id.clone()); } - for entity_result in &search_result.entities { + for entity_result in entities { source_ids.insert(entity_result.entity.source_id.clone()); } diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 72f2698..13fb02e 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -183,10 +183,8 @@ impl PipelineServices for DefaultPipelineServices { None => None, }; - let config = retrieval_pipeline::RetrievalConfig::for_search( - retrieval_pipeline::SearchTarget::EntitiesOnly, - ); - match retrieval_pipeline::retrieve_entities( + let config = retrieval_pipeline::RetrievalConfig::with_entities(); + match retrieval_pipeline::retrieve( &self.db, &self.openai_client, Some(&*self.embedding_provider), @@ -197,19 +195,16 @@ impl PipelineServices for DefaultPipelineServices { ) .await { - Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities), - Ok(retrieval_pipeline::StrategyOutput::Search(search)) => { - let chunk_count = search.chunks.len(); - let entities = search.entities; + Ok(retrieval_pipeline::RetrievalOutput::WithEntities { chunks, entities }) => { tracing::debug!( - chunk_count, + chunk_count = chunks.len(), entity_count = entities.len(), - "ingestion search results returned entities" + "ingestion retrieval resolved entities from chunks" ); Ok(entities) } - Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError( - "Ingestion retrieval should return entities".into(), + Ok(retrieval_pipeline::RetrievalOutput::Chunks(_)) => Err(AppError::InternalError( + "Ingestion retrieval should resolve entities".into(), )), Err(e) => Err(e), } @@ -372,16 +367,16 @@ mod tests { fn system_prompt_from_request( request: &async_openai::types::CreateChatCompletionRequest, - ) -> String { - let ChatCompletionRequestMessage::System(system) = &request.messages[0] else { - panic!("expected first message to be system"); + ) -> anyhow::Result { + let Some(ChatCompletionRequestMessage::System(system)) = request.messages.first() else { + anyhow::bail!("expected first message to be system"); }; - match &system.content { - async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) => { - text.clone() - } - other => panic!("unexpected system message content: {other:?}"), - } + let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) = + &system.content + else { + anyhow::bail!("unexpected system message content: {:?}", system.content); + }; + Ok(text.clone()) } #[tokio::test] @@ -425,7 +420,7 @@ mod tests { .await .context("prepare llm request")?; - assert_eq!(system_prompt_from_request(&request), SENTINEL); + assert_eq!(system_prompt_from_request(&request)?, SENTINEL); Ok(()) } } diff --git a/ingestion-pipeline/src/utils/pdf_ingestion.rs b/ingestion-pipeline/src/utils/pdf_ingestion.rs index 4f84130..4b63c83 100644 --- a/ingestion-pipeline/src/utils/pdf_ingestion.rs +++ b/ingestion-pipeline/src/utils/pdf_ingestion.rs @@ -125,10 +125,10 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result }) .await??; - for (idx, png) in captures.iter().enumerate() { - if let Err(err) = maybe_dump_debug_image(page_numbers[idx], png).await { + for (page_number, png) in page_numbers.iter().zip(captures.iter()) { + if let Err(err) = maybe_dump_debug_image(*page_number, png).await { warn!( - page = page_numbers[idx], + page = page_number, error = %err, "Failed to write debug screenshot to disk" ); diff --git a/main/src/bootstrap/mod.rs b/main/src/bootstrap/mod.rs index 13ee0b5..1925091 100644 --- a/main/src/bootstrap/mod.rs +++ b/main/src/bootstrap/mod.rs @@ -95,6 +95,7 @@ pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result(api_state: &ApiState, html_state: &HtmlState) -> Router where S: Clone + Send + Sync + 'static, @@ -24,6 +25,7 @@ where .merge(html_routes(html_state)) } +#[allow(dead_code)] // used by server/main binaries, not worker pub fn build_api_state(services: &SharedServices) -> ApiState { ApiState { db: Arc::clone(&services.db), @@ -32,6 +34,7 @@ pub fn build_api_state(services: &SharedServices) -> ApiState { } } +#[allow(dead_code)] // used by server/main binaries, not worker pub async fn build_html_state(services: &SharedServices) -> anyhow::Result { let session_store = Arc::new( services diff --git a/main/src/main.rs b/main/src/main.rs index 6ec9582..34f21db 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -75,6 +75,7 @@ struct AppState { } #[cfg(test)] +#[allow(clippy::expect_used)] mod tests { use super::*; use axum::{ diff --git a/retrieval-pipeline/Cargo.toml b/retrieval-pipeline/Cargo.toml index c2fd692..05bcdf8 100644 --- a/retrieval-pipeline/Cargo.toml +++ b/retrieval-pipeline/Cargo.toml @@ -10,16 +10,14 @@ workspace = true [dependencies] tokio = { workspace = true } serde = { workspace = true } -axum = { workspace = true } tracing = { workspace = true } -anyhow = { workspace = true } -thiserror = { workspace = true } serde_json = { workspace = true } -surrealdb = { workspace = true } -futures = { workspace = true } async-openai = { workspace = true } async-trait = { workspace = true } -uuid = { workspace = true } fastembed = { workspace = true } common = { path = "../common", features = ["test-utils"] } + +[dev-dependencies] +anyhow = { workspace = true } +uuid = { workspace = true } diff --git a/retrieval-pipeline/src/answer_retrieval.rs b/retrieval-pipeline/src/answer_retrieval.rs index 7eb7fc3..2fba43c 100644 --- a/retrieval-pipeline/src/answer_retrieval.rs +++ b/retrieval-pipeline/src/answer_retrieval.rs @@ -1,61 +1,66 @@ +//! Chat answer assembly: retrieval context formatting and structured LLM request/response types. + use async_openai::{ error::OpenAIError, types::{ ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, - CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse, - ResponseFormat, ResponseFormatJsonSchema, + CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, + ResponseFormatJsonSchema, }, }; -use common::{ - error::AppError, - storage::types::{ - message::{format_history, Message}, - system_settings::SystemSettings, - }, +use common::storage::types::{ + message::{format_history, Message}, + system_settings::SystemSettings, }; use serde::Deserialize; -use serde_json::Value; +use serde_json::{json, Value}; -use super::answer_retrieval_helper::get_query_response_schema; +/// JSON schema describing the structured chat answer (answer text + references). +fn get_query_response_schema() -> Value { + json!({ + "type": "object", + "properties": { + "answer": { "type": "string" }, + "references": { + "type": "array", + "items": { + "type": "object", + "properties": { + "reference": { "type": "string" }, + }, + "required": ["reference"], + "additionalProperties": false, + } + } + }, + "required": ["answer", "references"], + "additionalProperties": false + }) +} #[derive(Debug, Deserialize)] pub struct Reference { - #[allow(dead_code)] pub reference: String, } #[derive(Debug, Deserialize)] pub struct LLMResponseFormat { pub answer: String, - #[allow(dead_code)] pub references: Vec, } -#[derive(Debug)] -pub struct Answer { - pub content: String, - pub references: Vec, -} - -pub fn create_user_message(entities_json: &Value, query: &str) -> String { - format!( - r" - Context Information: - ================== - {entities_json} - - User Question: - ================== - {query} - " - ) -} - -/// Convert chunk-based retrieval results to JSON format for LLM context -pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value { - fn round_score(value: f32) -> f64 { - (f64::from(value) * 1000.0).round() / 1000.0 +impl LLMResponseFormat { + pub fn reference_ids(&self) -> Vec { + self.references + .iter() + .map(|entry| entry.reference.clone()) + .collect() } +} + +/// Convert chunk-based retrieval results to JSON format for LLM context. +pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value { + use crate::round_score; serde_json::json!(chunks .iter() @@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value { } pub fn create_user_message_with_history( - entities_json: &Value, + context_json: &Value, history: &[Message], query: &str, ) -> String { @@ -89,7 +94,7 @@ pub fn create_user_message_with_history( {} ", format_history(history), - entities_json, + context_json, query ) } @@ -116,18 +121,3 @@ pub fn create_chat_request( .response_format(response_format) .build() } - -pub fn process_llm_response( - response: &CreateChatCompletionResponse, -) -> Result> { - response - .choices - .first() - .and_then(|choice| choice.message.content.as_ref()) - .ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into()))) - .and_then(|content| { - serde_json::from_str::(content).map_err(|e| { - Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))) - }) - }) -} diff --git a/retrieval-pipeline/src/answer_retrieval_helper.rs b/retrieval-pipeline/src/answer_retrieval_helper.rs deleted file mode 100644 index 66b75a6..0000000 --- a/retrieval-pipeline/src/answer_retrieval_helper.rs +++ /dev/null @@ -1,23 +0,0 @@ -use serde_json::{json, Value}; - -pub fn get_query_response_schema() -> Value { - json!({ - "type": "object", - "properties": { - "answer": { "type": "string" }, - "references": { - "type": "array", - "items": { - "type": "object", - "properties": { - "reference": { "type": "string" }, - }, - "required": ["reference"], - "additionalProperties": false, - } - } - }, - "required": ["answer", "references"], - "additionalProperties": false - }) -} diff --git a/retrieval-pipeline/src/graph.rs b/retrieval-pipeline/src/graph.rs deleted file mode 100644 index a326c38..0000000 --- a/retrieval-pipeline/src/graph.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use surrealdb::{sql::Thing, Error}; - -use common::storage::{ - db::SurrealDbClient, - types::{ - knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, - StoredObject, - }, -}; - -/// Find entities related to the given entity via graph relationships. -/// -/// Queries the `relates_to` edge table for all relationships involving the entity, -/// then fetches and returns the neighboring entities. -/// -/// # Arguments -/// * `db` - Database client -/// * `entity_id` - ID of the entity to find neighbors for -/// * `user_id` - User ID for access control -/// * `limit` - Maximum number of neighbors to return -pub async fn find_entities_by_relationship_by_id( - db: &SurrealDbClient, - entity_id: &str, - user_id: &str, - limit: usize, -) -> Result, Error> { - let mut relationships_response = db - .query( - " - SELECT * FROM relates_to - WHERE metadata.user_id = $user_id - AND (in = type::thing('knowledge_entity', $entity_id) - OR out = type::thing('knowledge_entity', $entity_id)) - ", - ) - .bind(("entity_id", entity_id.to_owned())) - .bind(("user_id", user_id.to_owned())) - .await?; - - let relationships: Vec = relationships_response.take(0)?; - if relationships.is_empty() { - return Ok(Vec::new()); - } - - let mut neighbor_ids: Vec = Vec::with_capacity(relationships.len()); - let mut seen: HashSet = HashSet::with_capacity(relationships.len()); - for rel in relationships { - if rel.in_ == entity_id { - if seen.insert(rel.out.clone()) { - neighbor_ids.push(rel.out); - } - } else if rel.out == entity_id { - if seen.insert(rel.in_.clone()) { - neighbor_ids.push(rel.in_); - } - } else { - if seen.insert(rel.in_.clone()) { - neighbor_ids.push(rel.in_.clone()); - } - if seen.insert(rel.out.clone()) { - neighbor_ids.push(rel.out); - } - } - } - - neighbor_ids.retain(|id| id != entity_id); - - if neighbor_ids.is_empty() { - return Ok(Vec::new()); - } - - if limit > 0 && neighbor_ids.len() > limit { - neighbor_ids.truncate(limit); - } - - let thing_ids: Vec = neighbor_ids - .iter() - .map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str()))) - .collect(); - - let mut neighbors_response = db - .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id") - .bind(("table", KnowledgeEntity::table_name().to_owned())) - .bind(("things", thing_ids)) - .bind(("user_id", user_id.to_owned())) - .await?; - - let neighbors: Vec = neighbors_response.take(0)?; - if neighbors.is_empty() { - return Ok(Vec::new()); - } - - let mut neighbor_map: HashMap = neighbors - .into_iter() - .map(|entity| (entity.id.clone(), entity)) - .collect(); - - let mut ordered = Vec::with_capacity(neighbor_ids.len()); - for id in neighbor_ids { - if let Some(entity) = neighbor_map.remove(&id) { - ordered.push(entity); - } - if limit > 0 && ordered.len() >= limit { - break; - } - } - - Ok(ordered) -} - -#[cfg(test)] -mod tests { - use anyhow::{self, Context}; - use super::*; - use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; - use common::storage::types::knowledge_relationship::KnowledgeRelationship; - use uuid::Uuid; - - #[tokio::test] - async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> { - let namespace = "test_ns"; - let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - - let entity_type = KnowledgeEntityType::Document; - let user_id = "user123".to_string(); - - let central_entity = KnowledgeEntity::new( - "central_source".to_string(), - "Central Entity".to_string(), - "Central Description".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - let related_entity1 = KnowledgeEntity::new( - "related_source1".to_string(), - "Related Entity 1".to_string(), - "Related Description 1".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - let related_entity2 = KnowledgeEntity::new( - "related_source2".to_string(), - "Related Entity 2".to_string(), - "Related Description 2".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - let unrelated_entity = KnowledgeEntity::new( - "unrelated_source".to_string(), - "Unrelated Entity".to_string(), - "Unrelated Description".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - let central_entity = db - .store_item(central_entity.clone()) - .await - .with_context(|| "Failed to store central entity".to_string())? - .ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?; - let related_entity1 = db - .store_item(related_entity1.clone()) - .await - .with_context(|| "Failed to store related entity 1".to_string())? - .ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?; - let related_entity2 = db - .store_item(related_entity2.clone()) - .await - .with_context(|| "Failed to store related entity 2".to_string())? - .ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?; - let _unrelated_entity = db - .store_item(unrelated_entity.clone()) - .await - .with_context(|| "Failed to store unrelated entity".to_string())? - .ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?; - - let source_id = "relationship_source".to_string(); - - let relationship1 = KnowledgeRelationship::new( - central_entity.id.clone(), - related_entity1.id.clone(), - user_id.clone(), - source_id.clone(), - "references".to_string(), - ); - - let relationship2 = KnowledgeRelationship::new( - central_entity.id.clone(), - related_entity2.id.clone(), - user_id.clone(), - source_id.clone(), - "contains".to_string(), - ); - - relationship1 - .store_relationship(&db) - .await - .with_context(|| "Failed to store relationship 1".to_string())?; - relationship2 - .store_relationship(&db) - .await - .with_context(|| "Failed to store relationship 2".to_string())?; - - let related_entities = - find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX) - .await - .with_context(|| "Failed to find entities by relationship".to_string())?; - - assert!( - related_entities.len() >= 2, - "Should find related entities in both directions" - ); - - Ok(()) - } -} diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 15bc7af..6332d18 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -1,10 +1,9 @@ pub mod answer_retrieval; -pub mod answer_retrieval_helper; -pub mod graph; pub mod pipeline; pub mod reranking; -pub mod scoring; + +pub(crate) mod scoring; use common::{ error::AppError, @@ -16,39 +15,28 @@ use common::{ use reranking::RerankerLease; use tracing::instrument; -// Strategy output variants - defined before pipeline module +/// Result of a retrieval run. +/// +/// Chunk retrieval is always performed; entities are only present when the caller +/// requested entity resolution via [`RetrievalConfig::with_entities`]. #[derive(Debug)] -pub enum StrategyOutput { - Entities(Vec), +pub enum RetrievalOutput { Chunks(Vec), - Search(SearchResult), -} - -/// Unified search result containing both chunks and entities -#[derive(Debug, Clone)] -pub struct SearchResult { - pub chunks: Vec, - pub entities: Vec, -} - -impl SearchResult { - pub fn new(chunks: Vec, entities: Vec) -> Self { - Self { chunks, entities } - } - - pub fn is_empty(&self) -> bool { - self.chunks.is_empty() && self.entities.is_empty() - } + WithEntities { + chunks: Vec, + entities: Vec, + }, } pub use pipeline::{ - retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig, - RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget, + retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind, + StageTimings, }; -// Backward-compatible type aliases for external consumers -pub type PipelineDiagnostics = Diagnostics; -pub type PipelineStageTimings = StageTimings; +/// Round a score to three decimal places for JSON output. +pub(crate) fn round_score(value: f32) -> f64 { + (f64::from(value) * 1000.0).round() / 1000.0 +} // Captures a supporting chunk plus its fused retrieval score for downstream prompts. #[derive(Debug, Clone)] @@ -57,7 +45,7 @@ pub struct RetrievedChunk { pub score: f32, } -// Final entity representation returned to callers, enriched with ranked chunks. +// Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks. #[derive(Debug, Clone)] pub struct RetrievedEntity { pub entity: KnowledgeEntity, @@ -65,9 +53,9 @@ pub struct RetrievedEntity { pub chunks: Vec, } -/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text` +/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities. #[instrument(skip_all, fields(user_id))] -pub async fn retrieve_entities( +pub async fn retrieve( db_client: &SurrealDbClient, openai_client: &async_openai::Client, embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, @@ -75,8 +63,8 @@ pub async fn retrieve_entities( user_id: &str, config: RetrievalConfig, reranker: Option, -) -> Result { - let params = pipeline::StrategyParams { +) -> Result { + let params = pipeline::RetrievalParams { db_client, openai_client, embedding_provider, @@ -94,6 +82,7 @@ mod tests { use anyhow::{self}; use async_openai::Client; use common::storage::indexes::ensure_runtime; + use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use common::storage::types::system_settings::SystemSettings; use uuid::Uuid; @@ -133,7 +122,7 @@ mod tests { } #[tokio::test] - async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> { + async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> { let db = setup_test_db().await?; let user_id = "test_user"; let chunk = TextChunk::new( @@ -145,7 +134,7 @@ mod tests { TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; let openai_client = Client::new(); - let params = pipeline::StrategyParams { + let params = pipeline::RetrievalParams { db_client: &db, openai_client: &openai_client, embedding_provider: None, @@ -154,12 +143,13 @@ mod tests { config: RetrievalConfig::default(), reranker: None, }; - let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) - .await?; + let results = pipeline::run_with_embedding(params, test_embedding()).await?; let chunks = match results { - StrategyOutput::Chunks(items) => items, - other => anyhow::bail!("expected chunk results, got {other:?}"), + RetrievalOutput::Chunks(items) => items, + RetrievalOutput::WithEntities { .. } => { + anyhow::bail!("expected chunk results, got entities") + } }; assert!(!chunks.is_empty(), "Expected at least one retrieval result"); @@ -171,8 +161,7 @@ mod tests { } #[tokio::test] - async fn test_default_strategy_returns_chunks_from_multiple_sources( - ) -> anyhow::Result<()> { + async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> { let db = setup_test_db().await?; let user_id = "multi_source_user"; @@ -191,7 +180,7 @@ mod tests { TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?; let openai_client = Client::new(); - let params = pipeline::StrategyParams { + let params = pipeline::RetrievalParams { db_client: &db, openai_client: &openai_client, embedding_provider: None, @@ -200,12 +189,13 @@ mod tests { config: RetrievalConfig::default(), reranker: None, }; - let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) - .await?; + let results = pipeline::run_with_embedding(params, test_embedding()).await?; let chunks = match results { - StrategyOutput::Chunks(items) => items, - other => anyhow::bail!("expected chunk results, got {other:?}"), + RetrievalOutput::Chunks(items) => items, + RetrievalOutput::WithEntities { .. } => { + anyhow::bail!("expected chunk results, got entities") + } }; assert!(chunks.len() >= 2, "Expected chunks from multiple sources"); @@ -223,96 +213,54 @@ mod tests { } #[tokio::test] - async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> { + async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> { let db = setup_test_db().await?; - let user_id = "chunk_user"; - let chunk_one = TextChunk::new( - "src_alpha".into(), - "Tokio tasks execute on worker threads managed by the runtime.".into(), - user_id.into(), - ); - let chunk_two = TextChunk::new( - "src_beta".into(), - "Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(), - user_id.into(), - ); + let user_id = "entity_user"; - TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?; - TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?; - - let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default); - let openai_client = Client::new(); - let params = pipeline::StrategyParams { - db_client: &db, - openai_client: &openai_client, - embedding_provider: None, - input_text: "tokio runtime worker behavior", - user_id, - config, - reranker: None, - }; - let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) - .await?; - - let chunks = match results { - StrategyOutput::Chunks(items) => items, - other => anyhow::bail!("expected chunk results, got {other:?}"), - }; - - assert!( - !chunks.is_empty(), - "Revised strategy should return chunk-only responses" - ); - assert!( - chunks - .iter() - .any(|entry| entry.chunk.chunk.contains("Tokio")), - "Chunk results should contain relevant snippets" - ); - Ok(()) - } - - #[tokio::test] - async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> { - let db = setup_test_db().await?; - let user_id = "search_user"; let chunk = TextChunk::new( - "search_src".into(), - "Async Rust programming uses Tokio runtime for concurrent tasks.".into(), + "entity_source".into(), + "Async Rust programming uses the Tokio runtime for concurrent tasks.".into(), user_id.into(), ); - TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; - let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both); + let entity = KnowledgeEntity::new( + "entity_source".into(), + "Tokio Runtime".into(), + "Async runtime for Rust".into(), + KnowledgeEntityType::Document, + None, + user_id.into(), + ); + db.store_item(entity).await?; + let openai_client = Client::new(); - let params = pipeline::StrategyParams { + let params = pipeline::RetrievalParams { db_client: &db, openai_client: &openai_client, embedding_provider: None, input_text: "async rust programming", user_id, - config, + config: RetrievalConfig::with_entities(), reranker: None, }; - let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) - .await?; + let results = pipeline::run_with_embedding(params, test_embedding()).await?; - let StrategyOutput::Search(search_result) = results else { - anyhow::bail!("expected Search output"); + let RetrievalOutput::WithEntities { chunks, entities } = results else { + anyhow::bail!("expected WithEntities output"); }; - // Should return chunks (entities may be empty if none stored) + assert!(!chunks.is_empty(), "Should return chunks"); assert!( - !search_result.chunks.is_empty(), - "Search strategy should return chunks" + entities.iter().any(|e| e.entity.name == "Tokio Runtime"), + "Should resolve the entity owning the retrieved chunk" ); assert!( - search_result - .chunks + entities .iter() - .any(|c| c.chunk.chunk.contains("Tokio")), - "Search results should contain relevant chunks" + .find(|e| e.entity.name == "Tokio Runtime") + .is_some_and(|e| !e.chunks.is_empty()), + "Resolved entity should carry its contributing chunks" ); Ok(()) } diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index 1c77d39..b047b1c 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -1,22 +1,5 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::scoring::FusionWeights; - -pub use common::utils::config::RetrievalStrategy; - -/// Configures which result types to include in Search strategy -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] -#[serde(rename_all = "snake_case")] -pub enum SearchTarget { - /// Return only text chunks - ChunksOnly, - /// Return only knowledge entities - EntitiesOnly, - /// Return both chunks and entities (default) - #[default] - Both, -} - /// Two-variant flag that serializes as a bool for backward compatibility. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum BoolFlag { @@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag { #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct RetrievalTuningFlags { pub rerank_scores_only: BoolFlag, - pub normalize_vector_scores: BoolFlag, - pub normalize_fts_scores: BoolFlag, pub chunk_rrf_use_vector: BoolFlag, pub chunk_rrf_use_fts: BoolFlag, } impl RetrievalTuningFlags { - pub const fn rerank_scores_only(&self) -> bool { + pub const fn rerank_scores_only(self) -> bool { self.rerank_scores_only.as_bool() } - pub const fn normalize_vector_scores(&self) -> bool { - self.normalize_vector_scores.as_bool() - } - - pub const fn normalize_fts_scores(&self) -> bool { - self.normalize_fts_scores.as_bool() - } - - pub const fn chunk_rrf_use_vector(&self) -> bool { + pub const fn chunk_rrf_use_vector(self) -> bool { self.chunk_rrf_use_vector.as_bool() } - pub const fn chunk_rrf_use_fts(&self) -> bool { + pub const fn chunk_rrf_use_fts(self) -> bool { self.chunk_rrf_use_fts.as_bool() } } @@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags { fn default() -> Self { Self { rerank_scores_only: BoolFlag::Disabled, - normalize_vector_scores: BoolFlag::Disabled, - normalize_fts_scores: BoolFlag::Enabled, chunk_rrf_use_vector: BoolFlag::Enabled, chunk_rrf_use_fts: BoolFlag::Enabled, } } } -/// Tunable parameters that govern each retrieval stage. +/// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RetrievalTuning { - pub entity_vector_take: usize, + /// Number of vector candidates to pull from the chunk embedding index. pub chunk_vector_take: usize, - pub entity_fts_take: usize, + /// Number of full-text candidates to pull from the chunk index. pub chunk_fts_take: usize, - pub score_threshold: f32, - pub fallback_min_results: usize, - pub token_budget_estimate: usize, - pub avg_chars_per_token: usize, + /// Maximum chunks attached to each resolved entity. pub max_chunks_per_entity: usize, - pub lexical_match_weight: f32, - pub graph_traversal_seed_limit: usize, - pub graph_neighbor_limit: usize, - pub graph_score_decay: f32, - pub graph_seed_min_score: f32, - pub graph_vector_inheritance: f32, + /// Blend weight applied when mixing reranker scores with fused scores. pub rerank_blend_weight: f32, - pub flags: RetrievalTuningFlags, + /// Keep top-N candidates after reranking. pub rerank_keep_top: usize, + /// Maximum number of chunks returned to callers. pub chunk_result_cap: usize, - /// Optional fusion weights for hybrid search. If None, uses default weights. - pub fusion_weights: Option, - /// Reciprocal rank fusion k value for chunk merging in Revised strategy. - #[serde(default = "default_chunk_rrf_k")] + /// Reciprocal rank fusion k value for chunk merging. pub chunk_rrf_k: f32, /// Weight applied to vector ranks in RRF. - #[serde(default = "default_chunk_rrf_vector_weight")] pub chunk_rrf_vector_weight: f32, /// Weight applied to chunk FTS ranks in RRF. - #[serde(default = "default_chunk_rrf_fts_weight")] pub chunk_rrf_fts_weight: f32, + pub flags: RetrievalTuningFlags, } impl Default for RetrievalTuning { fn default() -> Self { Self { - entity_vector_take: 15, chunk_vector_take: 20, - entity_fts_take: 10, chunk_fts_take: 20, - score_threshold: 0.35, - fallback_min_results: 10, - token_budget_estimate: 10000, - avg_chars_per_token: 4, max_chunks_per_entity: 4, - lexical_match_weight: 0.15, - graph_traversal_seed_limit: 5, - graph_neighbor_limit: 6, - graph_score_decay: 0.75, - graph_seed_min_score: 0.4, - graph_vector_inheritance: 0.6, rerank_blend_weight: 0.65, - flags: RetrievalTuningFlags::default(), rerank_keep_top: 8, chunk_result_cap: 5, - fusion_weights: None, - chunk_rrf_k: default_chunk_rrf_k(), - chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(), - chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(), + chunk_rrf_k: 60.0, + chunk_rrf_vector_weight: 1.0, + chunk_rrf_fts_weight: 1.0, + flags: RetrievalTuningFlags::default(), } } } -/// Wrapper containing tuning plus future flags for per-request overrides. +/// Per-request retrieval configuration. +/// +/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities` +/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved +/// chunks (search, ingestion linking, relationship suggestion). #[derive(Debug, Clone, Default)] pub struct RetrievalConfig { - pub strategy: RetrievalStrategy, pub tuning: RetrievalTuning, - /// Target for Search strategy (chunks, entities, or both) - pub search_target: SearchTarget, + pub resolve_entities: bool, } impl RetrievalConfig { - pub fn new(tuning: RetrievalTuning) -> Self { + /// Chunk retrieval that also resolves the owning knowledge entities. + pub fn with_entities() -> Self { Self { - strategy: RetrievalStrategy::Default, - tuning, - search_target: SearchTarget::default(), - } - } - - pub fn with_strategy(strategy: RetrievalStrategy) -> Self { - Self { - strategy, tuning: RetrievalTuning::default(), - search_target: SearchTarget::default(), - } - } - - pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self { - Self { - strategy, - tuning, - search_target: SearchTarget::default(), - } - } - - /// Create config for chat retrieval with strategy selection support - pub fn for_chat(strategy: RetrievalStrategy) -> Self { - Self::with_strategy(strategy) - } - - /// Create config for relationship suggestion (entity-only retrieval) - pub fn for_relationship_suggestion() -> Self { - Self::with_strategy(RetrievalStrategy::RelationshipSuggestion) - } - - /// Create config for ingestion pipeline (entity-only retrieval) - pub fn for_ingestion() -> Self { - Self::with_strategy(RetrievalStrategy::Ingestion) - } - - /// Create config for unified search (chunks and/or entities) - pub fn for_search(target: SearchTarget) -> Self { - Self { - strategy: RetrievalStrategy::Search, - tuning: RetrievalTuning::default(), - search_target: target, + resolve_entities: true, } } } - -const fn default_chunk_rrf_k() -> f32 { - 60.0 -} - -const fn default_chunk_rrf_vector_weight() -> f32 { - 1.0 -} - -const fn default_chunk_rrf_fts_weight() -> f32 { - 1.0 -} diff --git a/retrieval-pipeline/src/pipeline/context.rs b/retrieval-pipeline/src/pipeline/context.rs new file mode 100644 index 0000000..a234af8 --- /dev/null +++ b/retrieval-pipeline/src/pipeline/context.rs @@ -0,0 +1,107 @@ +use async_openai::Client; +use common::{ + error::AppError, + storage::{db::SurrealDbClient, types::text_chunk::TextChunk}, + utils::embedding::EmbeddingProvider, +}; + +use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity}; + +use super::{ + config::RetrievalConfig, + diagnostics::{AssembleStats, Diagnostics, SearchStats}, + StageKind, StageTimings, RetrievalParams, +}; + +/// Mutable working state threaded through every retrieval stage. +pub(crate) struct PipelineContext<'a> { + pub db_client: &'a SurrealDbClient, + pub openai_client: &'a Client, + pub embedding_provider: Option<&'a EmbeddingProvider>, + pub input_text: String, + pub user_id: String, + pub config: RetrievalConfig, + pub query_embedding: Option>, + pub chunk_values: Vec>, + pub reranker: Option, + pub diagnostics: Option, + pub entity_results: Vec, + pub chunk_results: Vec, + stage_timings: StageTimings, +} + +impl<'a> PipelineContext<'a> { + pub fn new(params: RetrievalParams<'a>) -> Self { + Self { + db_client: params.db_client, + openai_client: params.openai_client, + embedding_provider: params.embedding_provider, + input_text: params.input_text.to_owned(), + user_id: params.user_id.to_owned(), + config: params.config, + query_embedding: None, + chunk_values: Vec::new(), + reranker: params.reranker, + diagnostics: None, + entity_results: Vec::new(), + chunk_results: Vec::new(), + stage_timings: StageTimings::default(), + } + } + + pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec) -> Self { + let mut ctx = Self::new(params); + ctx.query_embedding = Some(query_embedding); + ctx + } + + pub(crate) fn ensure_embedding(&self) -> Result<&Vec, Box> { + self.query_embedding.as_ref().ok_or_else(|| { + Box::new(AppError::InternalError( + "query embedding missing before candidate search".to_string(), + )) + }) + } + + pub fn enable_diagnostics(&mut self) { + if self.diagnostics.is_none() { + self.diagnostics = Some(Diagnostics::default()); + } + } + + pub fn diagnostics_enabled(&self) -> bool { + self.diagnostics.is_some() + } + + pub(crate) fn record_search(&mut self, stats: SearchStats) { + if let Some(diag) = self.diagnostics.as_mut() { + diag.search = Some(stats); + } + } + + pub(crate) fn record_assemble(&mut self, stats: AssembleStats) { + if let Some(diag) = self.diagnostics.as_mut() { + diag.assemble = Some(stats); + } + } + + pub fn take_diagnostics(&mut self) -> Option { + self.diagnostics.take() + } + + pub fn take_stage_timings(&mut self) -> StageTimings { + std::mem::take(&mut self.stage_timings) + } + + pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) { + self.stage_timings.record(kind, duration); + } + + pub fn take_entity_results(&mut self) -> Vec { + std::mem::take(&mut self.entity_results) + } + + pub fn take_chunk_results(&mut self) -> Vec { + std::mem::take(&mut self.chunk_results) + } +} diff --git a/retrieval-pipeline/src/pipeline/diagnostics.rs b/retrieval-pipeline/src/pipeline/diagnostics.rs index 8fd3495..374e130 100644 --- a/retrieval-pipeline/src/pipeline/diagnostics.rs +++ b/retrieval-pipeline/src/pipeline/diagnostics.rs @@ -1,51 +1,21 @@ use serde::Serialize; -/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled. +/// Captures instrumentation for the retrieval stages when diagnostics are enabled. #[derive(Debug, Clone, Default, Serialize)] pub struct Diagnostics { - pub collect_candidates: Option, - pub enrich_chunks_from_entities: Option, + pub search: Option, pub assemble: Option, } #[derive(Debug, Clone, Default, Serialize)] -pub struct CollectCandidatesStats { - pub vector_entity_candidates: usize, +pub struct SearchStats { pub vector_chunk_candidates: usize, - pub fts_entity_candidates: usize, pub fts_chunk_candidates: usize, pub vector_chunk_scores: Vec, pub fts_chunk_scores: Vec, } -#[derive(Debug, Clone, Default, Serialize)] -pub struct ChunkEnrichmentStats { - pub filtered_entity_count: usize, - pub fallback_min_results: usize, - pub chunk_sources_considered: usize, - pub chunk_candidates_before_enrichment: usize, - pub chunk_candidates_after_enrichment: usize, - pub top_chunk_scores: Vec, -} - #[derive(Debug, Clone, Default, Serialize)] pub struct AssembleStats { - pub token_budget_start: usize, - pub token_budget_spent: usize, - pub token_budget_remaining: usize, - pub budget_exhausted: bool, pub chunks_selected: usize, - pub chunks_skipped_due_budget: usize, - pub entity_count: usize, - pub entity_traces: Vec, -} - -#[derive(Debug, Clone, Default, Serialize)] -pub struct EntityAssemblyTrace { - pub entity_id: String, - pub source_id: String, - pub inspected_candidates: usize, - pub selected_chunk_ids: Vec, - pub selected_chunk_scores: Vec, - pub skipped_due_budget: usize, } diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index 738dd81..238d0d5 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -1,61 +1,68 @@ mod config; +mod context; mod diagnostics; mod stages; -mod strategies; -pub use config::{ - RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget, -}; -pub use diagnostics::{ - AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics, -}; +pub use config::RetrievalConfig; +pub use diagnostics::Diagnostics; -use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput}; +use crate::{round_score, RetrievalOutput, RetrievedEntity}; use async_openai::Client; use async_trait::async_trait; use common::{error::AppError, storage::db::SurrealDbClient}; use std::time::{Duration, Instant}; use tracing::info; -use stages::PipelineContext; -use strategies::{ - DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver, +use stages::{ + ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage, }; -// Export StrategyOutput publicly from this module -// (it's defined in lib.rs but we re-export it here) - -// Stage type enum +/// Identifies a retrieval stage for timing and instrumentation. +/// +/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation +/// harness) iterate it generically so that adding a stage requires no changes outside this +/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum StageKind { Embed, - CollectCandidates, - GraphExpansion, - ChunkAttach, + Search, Rerank, + ResolveEntities, Assemble, } -// Pipeline stage trait +impl StageKind { + /// Every stage kind in canonical pipeline order. + pub const ALL: [StageKind; 5] = [ + StageKind::Embed, + StageKind::Search, + StageKind::Rerank, + StageKind::ResolveEntities, + StageKind::Assemble, + ]; + + /// Stable, machine-friendly identifier for the stage (used as a metrics key). + pub const fn label(self) -> &'static str { + match self { + StageKind::Embed => "embed", + StageKind::Search => "search", + StageKind::Rerank => "rerank", + StageKind::ResolveEntities => "resolve_entities", + StageKind::Assemble => "assemble", + } + } +} + +/// A single composable step in the retrieval pipeline. #[async_trait] -pub trait Stage: Send + Sync { +pub(crate) trait Stage: Send + Sync { fn kind(&self) -> StageKind; - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>; + async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>; } -// Type alias for boxed stages -pub type BoxedStage = Box; +pub(crate) type BoxedStage = Box; -// Strategy driver trait -#[async_trait] -pub trait StrategyDriver: Send + Sync { - type Output; - - fn stages(&self) -> Vec; - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result>; -} - -// Pipeline stage timings tracker +/// Per-stage execution timings recorded during a run. #[derive(Debug, Default, Clone)] pub struct StageTimings { timings: Vec<(StageKind, Duration)>, @@ -66,41 +73,13 @@ impl StageTimings { self.timings.push((kind, duration)); } - pub fn into_vec(self) -> Vec<(StageKind, Duration)> { - self.timings - } - - // Helper methods to get duration for each stage type (for backward compatibility) - fn get_stage_ms(&self, kind: StageKind) -> u128 { + /// Milliseconds recorded for `kind`, or `0` if the stage did not run. + pub fn stage_ms(&self, kind: StageKind) -> u128 { self.timings .iter() .find(|(k, _)| *k == kind) .map_or(0, |(_, d)| d.as_millis()) } - - pub fn embed_ms(&self) -> u128 { - self.get_stage_ms(StageKind::Embed) - } - - pub fn collect_candidates_ms(&self) -> u128 { - self.get_stage_ms(StageKind::CollectCandidates) - } - - pub fn graph_expansion_ms(&self) -> u128 { - self.get_stage_ms(StageKind::GraphExpansion) - } - - pub fn chunk_attach_ms(&self) -> u128 { - self.get_stage_ms(StageKind::ChunkAttach) - } - - pub fn rerank_ms(&self) -> u128 { - self.get_stage_ms(StageKind::Rerank) - } - - pub fn assemble_ms(&self) -> u128 { - self.get_stage_ms(StageKind::Assemble) - } } pub struct RunOutput { @@ -109,7 +88,35 @@ pub struct RunOutput { pub stage_timings: StageTimings, } -pub async fn execute(params: StrategyParams<'_>) -> Result { +/// Inputs required to run a retrieval. +pub struct RetrievalParams<'a> { + pub db_client: &'a SurrealDbClient, + pub openai_client: &'a Client, + pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>, + pub input_text: &'a str, + pub user_id: &'a str, + pub config: RetrievalConfig, + pub reranker: Option, +} + +fn build_stages(config: &RetrievalConfig) -> Vec { + let mut stages: Vec = vec![ + Box::new(EmbedStage), + Box::new(ChunkSearchStage), + Box::new(ChunkRerankStage), + ]; + if config.resolve_entities { + stages.push(Box::new(ResolveEntitiesStage)); + } + stages.push(Box::new(ChunkAssembleStage)); + stages +} + +async fn run( + params: RetrievalParams<'_>, + query_embedding: Option>, + capture_diagnostics: bool, +) -> Result, AppError> { let input_chars = params.input_text.chars().count(); let input_preview: String = params.input_text.chars().take(120).collect(); let input_preview_clean = input_preview.replace('\n', " "); @@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result preview_len, preview = %input_preview_clean, - strategy = %params.config.strategy, + resolve_entities = params.config.resolve_entities, "Starting retrieval pipeline" ); - let strategy = params.config.strategy; - let search_target = params.config.search_target; + let resolve_entities = params.config.resolve_entities; + let mut ctx = match query_embedding { + Some(embedding) => context::PipelineContext::with_embedding(params, embedding), + None => context::PipelineContext::new(params), + }; - match strategy { - RetrievalStrategy::Default => { - let driver = DefaultStrategyDriver::new(); - let run = execute_strategy(driver, params, None, false).await?; - Ok(StrategyOutput::Chunks(run.results)) - } - RetrievalStrategy::RelationshipSuggestion => { - let driver = RelationshipSuggestionDriver::new(); - let run = execute_strategy(driver, params, None, false).await?; - Ok(StrategyOutput::Entities(run.results)) - } - RetrievalStrategy::Ingestion => { - let driver = IngestionDriver::new(); - let run = execute_strategy(driver, params, None, false).await?; - Ok(StrategyOutput::Entities(run.results)) - } - RetrievalStrategy::Search => { - let driver = SearchStrategyDriver::new(search_target); - let run = execute_strategy(driver, params, None, false).await?; - Ok(StrategyOutput::Search(run.results)) - } + if capture_diagnostics { + ctx.enable_diagnostics(); } + + for stage in build_stages(&ctx.config) { + let start = Instant::now(); + stage.execute(&mut ctx).await?; + ctx.record_stage_duration(stage.kind(), start.elapsed()); + } + + let diagnostics = ctx.take_diagnostics(); + let stage_timings = ctx.take_stage_timings(); + let chunks = ctx.take_chunk_results(); + let results = if resolve_entities { + RetrievalOutput::WithEntities { + chunks, + entities: ctx.take_entity_results(), + } + } else { + RetrievalOutput::Chunks(chunks) + }; + + Ok(RunOutput { + results, + diagnostics, + stage_timings, + }) } -pub async fn run_pipeline_with_embedding( - params: StrategyParams<'_>, - query_embedding: Vec, -) -> Result { - let strategy = params.config.strategy; - let search_target = params.config.search_target; - - match strategy { - RetrievalStrategy::Default => { - let driver = DefaultStrategyDriver::new(); - let run = execute_strategy(driver, params, Some(query_embedding), false).await?; - Ok(StrategyOutput::Chunks(run.results)) - } - RetrievalStrategy::RelationshipSuggestion => { - let driver = RelationshipSuggestionDriver::new(); - let run = execute_strategy(driver, params, Some(query_embedding), false).await?; - Ok(StrategyOutput::Entities(run.results)) - } - RetrievalStrategy::Ingestion => { - let driver = IngestionDriver::new(); - let run = execute_strategy(driver, params, Some(query_embedding), false).await?; - Ok(StrategyOutput::Entities(run.results)) - } - RetrievalStrategy::Search => { - let driver = SearchStrategyDriver::new(search_target); - let run = execute_strategy(driver, params, Some(query_embedding), false).await?; - Ok(StrategyOutput::Search(run.results)) - } - } +/// Run the retrieval pipeline, generating the query embedding internally if needed. +pub async fn execute(params: RetrievalParams<'_>) -> Result { + Ok(run(params, None, false).await?.results) } -pub async fn run_pipeline_with_embedding_with_metrics( - params: StrategyParams<'_>, +/// Run the retrieval pipeline with a pre-computed query embedding. +pub async fn run_with_embedding( + params: RetrievalParams<'_>, query_embedding: Vec, -) -> Result, AppError> { - let strategy = params.config.strategy; - - match strategy { - RetrievalStrategy::Default => { - let driver = DefaultStrategyDriver::new(); - let run = execute_strategy(driver, params, Some(query_embedding), false).await?; - Ok(RunOutput { - results: StrategyOutput::Chunks(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }) - } - _ => Err(AppError::InternalError( - "Metrics not supported for this strategy".into(), - )), - } +) -> Result { + Ok(run(params, Some(query_embedding), false).await?.results) } -pub async fn run_pipeline_with_embedding_with_diagnostics( - params: StrategyParams<'_>, +/// Run with a pre-computed embedding, returning results and per-stage timings. +/// +/// When `capture_diagnostics` is true, pipeline search/assemble stats are included. +pub async fn run_with_embedding_instrumented( + params: RetrievalParams<'_>, query_embedding: Vec, -) -> Result, AppError> { - let strategy = params.config.strategy; - - match strategy { - RetrievalStrategy::Default => { - let driver = DefaultStrategyDriver::new(); - let run = execute_strategy(driver, params, Some(query_embedding), true).await?; - Ok(RunOutput { - results: StrategyOutput::Chunks(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }) - } - _ => Err(AppError::InternalError( - "Diagnostics not supported for this strategy".into(), - )), - } + capture_diagnostics: bool, +) -> Result, AppError> { + run(params, Some(query_embedding), capture_diagnostics).await } pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value { @@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V }) .collect::>()) } - -pub struct StrategyParams<'a> { - pub db_client: &'a SurrealDbClient, - pub openai_client: &'a Client, - pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>, - pub input_text: &'a str, - pub user_id: &'a str, - pub config: RetrievalConfig, - pub reranker: Option, -} - -async fn execute_strategy( - driver: D, - params: StrategyParams<'_>, - query_embedding: Option>, - capture_diagnostics: bool, -) -> Result, AppError> { - let ctx = match query_embedding { - Some(embedding) => PipelineContext::with_embedding(params, embedding), - None => PipelineContext::new(params), - }; - - run_with_driver(driver, ctx, capture_diagnostics).await -} - -async fn run_with_driver( - driver: D, - mut ctx: PipelineContext<'_>, - capture_diagnostics: bool, -) -> Result, AppError> { - if capture_diagnostics { - ctx.enable_diagnostics(); - } - - for stage in driver.stages() { - let start = Instant::now(); - stage.execute(&mut ctx).await?; - ctx.record_stage_duration(stage.kind(), start.elapsed()); - } - - let diagnostics = ctx.take_diagnostics(); - let stage_timings = ctx.take_stage_timings(); - let results = driver.finalize(&mut ctx).map_err(|e| *e)?; - - Ok(RunOutput { - results, - diagnostics, - stage_timings, - }) -} - -fn round_score(value: f32) -> f64 { - (f64::from(value) * 1000.0).round() / 1000.0 -} diff --git a/retrieval-pipeline/src/pipeline/stages.rs b/retrieval-pipeline/src/pipeline/stages.rs new file mode 100644 index 0000000..7de0bad --- /dev/null +++ b/retrieval-pipeline/src/pipeline/stages.rs @@ -0,0 +1,424 @@ +use async_trait::async_trait; +use common::{ + error::AppError, + storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, + utils::embedding::generate_embedding, +}; +use fastembed::RerankResult; +use std::collections::HashMap; +use tracing::{debug, instrument, warn}; + +use crate::{ + scoring::{ + clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored, + }, + RetrievedChunk, RetrievedEntity, +}; + +use super::{ + config::RetrievalTuning, + context::PipelineContext, + diagnostics::{AssembleStats, SearchStats}, + Stage, StageKind, +}; + +#[derive(Debug, Clone, Copy)] +pub struct EmbedStage; + +#[async_trait] +impl Stage for EmbedStage { + fn kind(&self) -> StageKind { + StageKind::Embed + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + embed(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkSearchStage; + +#[async_trait] +impl Stage for ChunkSearchStage { + fn kind(&self) -> StageKind { + StageKind::Search + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + search_chunks(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkRerankStage; + +#[async_trait] +impl Stage for ChunkRerankStage { + fn kind(&self) -> StageKind { + StageKind::Rerank + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + rerank_chunks(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ResolveEntitiesStage; + +#[async_trait] +impl Stage for ResolveEntitiesStage { + fn kind(&self) -> StageKind { + StageKind::ResolveEntities + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + resolve_entities(ctx).await + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ChunkAssembleStage; + +#[async_trait] +impl Stage for ChunkAssembleStage { + fn kind(&self) -> StageKind { + StageKind::Assemble + } + + async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + assemble_chunks(ctx) + } +} + +#[instrument(level = "trace", skip_all)] +pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + if ctx.query_embedding.is_some() { + debug!("Reusing cached query embedding for hybrid retrieval"); + } else { + debug!("Generating query embedding for hybrid retrieval"); + let embedding = if let Some(provider) = ctx.embedding_provider { + provider.embed(&ctx.input_text).await? + } else { + generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await? + }; + ctx.query_embedding = Some(embedding); + } + + Ok(()) +} + +#[instrument(level = "trace", skip_all)] +pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + debug!("Collecting chunk candidates via vector and FTS search"); + let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); + let tuning = &ctx.config.tuning; + let fts_take = tuning.chunk_fts_take; + let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text); + let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty(); + + let (vector_rows, fts_rows) = tokio::try_join!( + TextChunk::vector_search( + tuning.chunk_vector_take, + embedding, + ctx.db_client, + &ctx.user_id, + ), + async { + if fts_enabled { + TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await + } else { + Ok(Vec::new()) + } + } + )?; + + let vector_candidates = vector_rows.len(); + let fts_candidates = fts_rows.len(); + + let vector_scored: Vec> = vector_rows + .into_iter() + .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) + .collect(); + + let fts_scored: Vec> = fts_rows + .into_iter() + .map(|row| Scored::new(row.chunk).with_fts_score(row.score)) + .collect(); + + let mut fts_weight = tuning.chunk_rrf_fts_weight; + if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 { + // For very short keyword queries, lean more on lexical ranking. + fts_weight *= 1.5; + } + + let rrf_config = RrfConfig { + k: tuning.chunk_rrf_k, + vector_weight: tuning.chunk_rrf_vector_weight, + fts_weight, + use_vector: tuning.flags.chunk_rrf_use_vector(), + use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0, + }; + + let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config); + + debug!( + total_merged = chunks.len(), + vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(), + fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(), + both_signals = chunks + .iter() + .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) + .count(), + rrf_k = rrf_config.k, + "Merged chunk candidates with RRF" + ); + + if ctx.diagnostics_enabled() { + ctx.record_search(SearchStats { + vector_chunk_candidates: vector_candidates, + fts_chunk_candidates: fts_candidates, + vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)), + fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)), + }); + } + + ctx.chunk_values = chunks; + + Ok(()) +} + +#[instrument(level = "trace", skip_all)] +pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + if ctx.chunk_values.len() <= 1 { + return Ok(()); + } + + let Some(reranker) = ctx.reranker.as_ref() else { + debug!("No reranker lease provided; skipping chunk rerank stage"); + return Ok(()); + }; + + let documents = + build_chunk_rerank_documents(&ctx.chunk_values, ctx.config.tuning.rerank_keep_top.max(1)); + if documents.len() <= 1 { + debug!("Skipping chunk reranking stage; insufficient chunk documents"); + return Ok(()); + } + + match reranker.rerank(&ctx.input_text, documents).await { + Ok(results) if !results.is_empty() => { + apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results); + } + Ok(_) => debug!("Chunk reranker returned no results; retaining original order"), + Err(err) => warn!( + error = %err, + "Chunk reranking failed; continuing with original ordering" + ), + } + + Ok(()) +} + +/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks. +/// +/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped +/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk, +/// and the contributing chunks are attached. +#[instrument(level = "trace", skip_all)] +pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + if ctx.chunk_values.is_empty() { + return Ok(()); + } + + let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1); + + let mut source_order: Vec = Vec::new(); + let mut chunks_by_source: HashMap> = HashMap::new(); + let mut best_score: HashMap = HashMap::new(); + + for scored in &ctx.chunk_values { + let source = scored.item.source_id.clone(); + let attached = chunks_by_source.entry(source.clone()).or_default(); + if attached.is_empty() { + source_order.push(source.clone()); + best_score.insert(source.clone(), scored.fused); + } + if attached.len() < max_chunks { + attached.push(RetrievedChunk { + chunk: scored.item.clone(), + score: scored.fused, + }); + } + } + + let entities = + KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?; + + let mut entities_by_source: HashMap> = HashMap::new(); + for entity in entities { + entities_by_source + .entry(entity.source_id.clone()) + .or_default() + .push(entity); + } + + let mut results = Vec::new(); + for source in &source_order { + let Some(entities) = entities_by_source.remove(source) else { + continue; + }; + let score = best_score.get(source).copied().unwrap_or(0.0); + let chunks = chunks_by_source.get(source).cloned().unwrap_or_default(); + for entity in entities { + results.push(RetrievedEntity { + entity, + score, + chunks: chunks.clone(), + }); + } + } + + debug!( + sources = source_order.len(), + entities = results.len(), + "Resolved entities from retrieved chunks" + ); + + ctx.entity_results = results; + Ok(()) +} + +#[instrument(level = "trace", skip_all)] +#[allow(clippy::result_large_err)] +pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { + debug!("Assembling chunk retrieval results"); + let mut chunk_values = std::mem::take(&mut ctx.chunk_values); + // Limit how many chunks we return to keep context size reasonable. + let limit = ctx + .config + .tuning + .chunk_result_cap + .max(1) + .min(ctx.config.tuning.chunk_vector_take.max(1)); + + if chunk_values.len() > limit { + chunk_values.truncate(limit); + } + + ctx.chunk_results = chunk_values + .into_iter() + .map(|chunk| RetrievedChunk { + chunk: chunk.item, + score: chunk.fused, + }) + .collect(); + + if ctx.diagnostics_enabled() { + ctx.record_assemble(AssembleStats { + chunks_selected: ctx.chunk_results.len(), + }); + } + + Ok(()) +} + +const SCORE_SAMPLE_LIMIT: usize = 8; + +fn sample_scores(items: &[Scored], extractor: F) -> Vec +where + F: FnMut(&Scored) -> f32, +{ + items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect() +} + +fn normalize_fts_query(input: &str) -> (String, usize) { + const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"]; + let mut cleaned = String::with_capacity(input.len()); + for ch in input.chars() { + if ch.is_alphanumeric() { + cleaned.extend(ch.to_lowercase()); + } else if ch.is_whitespace() { + cleaned.push(' '); + } + } + let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3)); + for token in cleaned.split_whitespace() { + if !STOPWORDS.contains(&token) && !token.is_empty() { + tokens.push(token.to_string()); + } + } + let normalized = tokens.join(" "); + (normalized, tokens.len()) +} + +fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) -> Vec { + chunks + .iter() + .take(max_chunks) + .map(|chunk| { + format!( + "Source: {}\nChunk:\n{}", + chunk.item.source_id, + chunk.item.chunk.trim() + ) + }) + .collect() +} + +fn apply_chunk_rerank_results( + chunks: &mut Vec>, + tuning: &RetrievalTuning, + results: Vec, +) { + if results.is_empty() || chunks.is_empty() { + return; + } + + let mut remaining: Vec>> = + std::mem::take(chunks).into_iter().map(Some).collect(); + + let raw_scores: Vec = results.iter().map(|r| r.score).collect(); + let normalized_scores = min_max_normalize(&raw_scores); + + let use_only = tuning.flags.rerank_scores_only(); + let blend = if use_only { + 1.0 + } else { + clamp_unit(tuning.rerank_blend_weight) + }; + + let mut reranked: Vec> = Vec::with_capacity(remaining.len()); + for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { + if let Some(slot) = remaining.get_mut(result.index) { + if let Some(mut candidate) = slot.take() { + let original = candidate.fused; + let blended = if use_only { + clamp_unit(normalized) + } else { + clamp_unit(original * (1.0 - blend) + normalized * blend) + }; + candidate.update_fused(blended); + reranked.push(candidate); + } + } else { + warn!( + result_index = result.index, + "Chunk reranker returned out-of-range index; skipping" + ); + } + if reranked.len() == remaining.len() { + break; + } + } + + reranked.extend(remaining.into_iter().flatten()); + + let keep_top = tuning.rerank_keep_top; + if keep_top > 0 && reranked.len() > keep_top { + reranked.truncate(keep_top); + } + + *chunks = reranked; +} diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs deleted file mode 100644 index 791d3ea..0000000 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ /dev/null @@ -1,1070 +0,0 @@ -use async_openai::Client; -use async_trait::async_trait; -use common::{ - error::AppError, - storage::{ - db::SurrealDbClient, - types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, - }, - utils::{embedding::generate_embedding, embedding::EmbeddingProvider}, -}; -use fastembed::RerankResult; -use futures::{stream::FuturesUnordered, StreamExt}; -use std::{cmp::Ordering, collections::HashMap}; -use tracing::{debug, instrument, warn}; - -use crate::{ - graph::find_entities_by_relationship_by_id, - reranking::RerankerLease, - scoring::{ - clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion, - sort_by_fused_desc, FusionWeights, RrfConfig, Scored, - }, - RetrievedChunk, RetrievedEntity, -}; - -use super::{ - config::{RetrievalConfig, RetrievalTuning}, - diagnostics::{ - AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, - Diagnostics, - }, - StageTimings, Stage, StageKind, StrategyParams, -}; - -pub struct PipelineContext<'a> { - pub db_client: &'a SurrealDbClient, - pub openai_client: &'a Client, - pub embedding_provider: Option<&'a EmbeddingProvider>, - pub input_text: String, - pub user_id: String, - pub config: RetrievalConfig, - pub query_embedding: Option>, - pub entity_candidates: HashMap>, - pub filtered_entities: Vec>, - pub chunk_values: Vec>, - pub revised_chunk_values: Vec>, - pub reranker: Option, - pub diagnostics: Option, - pub entity_results: Vec, - pub chunk_results: Vec, - stage_timings: StageTimings, -} - -impl<'a> PipelineContext<'a> { - pub fn new(params: StrategyParams<'a>) -> Self { - Self { - db_client: params.db_client, - openai_client: params.openai_client, - embedding_provider: params.embedding_provider, - input_text: params.input_text.to_owned(), - user_id: params.user_id.to_owned(), - config: params.config, - query_embedding: None, - entity_candidates: HashMap::new(), - filtered_entities: Vec::new(), - chunk_values: Vec::new(), - revised_chunk_values: Vec::new(), - reranker: params.reranker, - diagnostics: None, - entity_results: Vec::new(), - chunk_results: Vec::new(), - stage_timings: StageTimings::default(), - } - } - - pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec) -> Self { - let mut ctx = Self::new(params); - ctx.query_embedding = Some(query_embedding); - ctx - } - - fn ensure_embedding(&self) -> Result<&Vec, Box> { - self.query_embedding.as_ref().ok_or_else(|| { - Box::new(AppError::InternalError( - "query embedding missing before candidate collection".to_string(), - )) - }) - } - - pub fn enable_diagnostics(&mut self) { - if self.diagnostics.is_none() { - self.diagnostics = Some(Diagnostics::default()); - } - } - - pub fn diagnostics_enabled(&self) -> bool { - self.diagnostics.is_some() - } - - pub fn record_collect_candidates(&mut self, stats: CollectCandidatesStats) { - if let Some(diag) = self.diagnostics.as_mut() { - diag.collect_candidates = Some(stats); - } - } - - pub fn record_chunk_enrichment(&mut self, stats: ChunkEnrichmentStats) { - if let Some(diag) = self.diagnostics.as_mut() { - diag.enrich_chunks_from_entities = Some(stats); - } - } - - pub fn record_assemble(&mut self, stats: AssembleStats) { - if let Some(diag) = self.diagnostics.as_mut() { - diag.assemble = Some(stats); - } - } - - pub fn take_diagnostics(&mut self) -> Option { - self.diagnostics.take() - } - - pub fn take_stage_timings(&mut self) -> StageTimings { - std::mem::take(&mut self.stage_timings) - } - - pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) { - self.stage_timings.record(kind, duration); - } - - pub fn take_entity_results(&mut self) -> Vec { - std::mem::take(&mut self.entity_results) - } - - pub fn take_chunk_results(&mut self) -> Vec { - std::mem::take(&mut self.chunk_results) - } -} - -#[derive(Debug, Clone, Copy)] -pub struct EmbedStage; - -#[async_trait] -impl Stage for EmbedStage { - fn kind(&self) -> StageKind { - StageKind::Embed - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - embed(ctx).await - } -} - -#[derive(Debug, Clone, Copy)] -pub struct CollectCandidatesStage; - -#[async_trait] -impl Stage for CollectCandidatesStage { - fn kind(&self) -> StageKind { - StageKind::CollectCandidates - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - collect_candidates(ctx).await - } -} - -#[derive(Debug, Clone, Copy)] -pub struct GraphExpansionStage; - -#[async_trait] -impl Stage for GraphExpansionStage { - fn kind(&self) -> StageKind { - StageKind::GraphExpansion - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - expand_graph(ctx).await - } -} - -#[derive(Debug, Clone, Copy)] -pub struct RerankStage; - -#[async_trait] -impl Stage for RerankStage { - fn kind(&self) -> StageKind { - StageKind::Rerank - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - rerank(ctx).await - } -} - -#[derive(Debug, Clone, Copy)] -pub struct AssembleEntitiesStage; - -#[async_trait] -impl Stage for AssembleEntitiesStage { - fn kind(&self) -> StageKind { - StageKind::Assemble - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - assemble(ctx) - } -} - -#[derive(Debug, Clone, Copy)] -pub struct ChunkVectorStage; - -#[async_trait] -impl Stage for ChunkVectorStage { - fn kind(&self) -> StageKind { - StageKind::CollectCandidates - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - collect_vector_chunks(ctx).await - } -} - -#[derive(Debug, Clone, Copy)] -pub struct ChunkRerankStage; - -#[async_trait] -impl Stage for ChunkRerankStage { - fn kind(&self) -> StageKind { - StageKind::Rerank - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - rerank_chunks(ctx).await - } -} - -#[derive(Debug, Clone, Copy)] -pub struct ChunkAssembleStage; - -#[async_trait] -impl Stage for ChunkAssembleStage { - fn kind(&self) -> StageKind { - StageKind::Assemble - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - assemble_chunks(ctx) - } -} - -#[instrument(level = "trace", skip_all)] -pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - let embedding_cached = ctx.query_embedding.is_some(); - if embedding_cached { - debug!("Reusing cached query embedding for hybrid retrieval"); - } else { - debug!("Generating query embedding for hybrid retrieval"); - let embedding = if let Some(provider) = ctx.embedding_provider { - provider.embed(&ctx.input_text).await? - } else { - generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await? - }; - ctx.query_embedding = Some(embedding); - } - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - debug!("Collecting initial candidates via vector and FTS search"); - let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); - let tuning = &ctx.config.tuning; - - let weights = FusionWeights::default(); - - let (vector_entity_results, fts_entity_results) = tokio::try_join!( - KnowledgeEntity::vector_search( - tuning.entity_vector_take, - embedding, - ctx.db_client, - &ctx.user_id, - ), - KnowledgeEntity::search( - ctx.db_client, - &ctx.input_text, - &ctx.user_id, - tuning.entity_fts_take, - ) - )?; - - #[allow(clippy::useless_conversion)] - let vector_entities: Vec> = vector_entity_results - .into_iter() - .map(|row| Scored::new(row.entity).with_vector_score(row.score)) - .collect(); - - let mut fts_entities: Vec> = fts_entity_results - .into_iter() - .map(|res| { - let entity = KnowledgeEntity { - id: res.id, - created_at: res.created_at, - updated_at: res.updated_at, - source_id: res.source_id, - name: res.name, - description: res.description, - entity_type: res.entity_type, - metadata: res.metadata, - user_id: res.user_id, - }; - Scored::new(entity).with_fts_score(res.score) - }) - .collect(); - - debug!( - vector_entities = vector_entities.len(), - fts_entities = fts_entities.len(), - "Hybrid retrieval initial candidate counts" - ); - - if ctx.diagnostics_enabled() { - ctx.record_collect_candidates(CollectCandidatesStats { - vector_entity_candidates: vector_entities.len(), - vector_chunk_candidates: 0, - fts_entity_candidates: fts_entities.len(), - fts_chunk_candidates: 0, - vector_chunk_scores: Vec::new(), - fts_chunk_scores: Vec::new(), - }); - } - - normalize_fts_scores(&mut fts_entities); - - merge_scored_by_id(&mut ctx.entity_candidates, vector_entities); - merge_scored_by_id(&mut ctx.entity_candidates, fts_entities); - - apply_fusion(&mut ctx.entity_candidates, weights); - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - debug!("Expanding candidates using graph relationships"); - let tuning = &ctx.config.tuning; - let weights = FusionWeights::default(); - - if ctx.entity_candidates.is_empty() { - return Ok(()); - } - - let graph_seeds = seeds_from_candidates( - &ctx.entity_candidates, - tuning.graph_seed_min_score, - tuning.graph_traversal_seed_limit, - ); - - if graph_seeds.is_empty() { - return Ok(()); - } - - let mut futures = FuturesUnordered::new(); - for seed in graph_seeds { - let db = ctx.db_client; - let user = ctx.user_id.clone(); - let limit = tuning.graph_neighbor_limit; - futures.push(async move { - let neighbors = find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await; - (seed, neighbors) - }); - } - - while let Some((seed, neighbors_result)) = futures.next().await { - let neighbors = neighbors_result.map_err(AppError::from)?; - if neighbors.is_empty() { - continue; - } - - for neighbor in neighbors { - let neighbor_id = neighbor.id.clone(); - if neighbor_id == seed.id { - continue; - } - - let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); - let entry = ctx - .entity_candidates - .entry(neighbor_id.clone()) - .or_insert_with(|| Scored::new(neighbor.clone())); - - entry.item = neighbor; - - let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance); - let vector_existing = entry.scores.vector.unwrap_or(0.0); - if inherited_vector > vector_existing { - entry.scores.vector = Some(inherited_vector); - } - - let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); - if graph_score > existing_graph || entry.scores.graph.is_none() { - entry.scores.graph = Some(graph_score); - } - - let fused = fuse_scores(&entry.scores, weights); - entry.update_fused(fused); - } - } - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - let mut applied = false; - - if let Some(reranker) = ctx.reranker.as_ref() { - if ctx.filtered_entities.len() > 1 { - let documents = build_rerank_documents(ctx, ctx.config.tuning.max_chunks_per_entity); - - if documents.len() > 1 { - match reranker.rerank(&ctx.input_text, documents).await { - Ok(results) if !results.is_empty() => { - apply_rerank_results(ctx, results); - applied = true; - } - Ok(_) => { - debug!("Reranker returned no results; retaining original ordering"); - } - Err(err) => { - warn!( - error = %err, - "Reranking failed; continuing with original ordering" - ); - } - } - } else { - debug!( - document_count = documents.len(), - "Skipping reranking stage; insufficient document context" - ); - } - } else { - debug!("Skipping reranking stage; less than two entities available"); - } - } else { - debug!("No reranker lease provided; skipping reranking stage"); - } - - if applied { - debug!("Applied reranking adjustments to candidate ordering"); - } - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - debug!("Collecting vector chunk candidates for revised strategy"); - let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); - let tuning = &ctx.config.tuning; - let fts_take = tuning.chunk_fts_take; - let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text); - let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty(); - - let (vector_rows, fts_rows) = tokio::try_join!( - TextChunk::vector_search( - tuning.chunk_vector_take, - embedding, - ctx.db_client, - &ctx.user_id, - ), - async { - if fts_enabled { - TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await - } else { - Ok(Vec::new()) - } - } - )?; - - let vector_candidates = vector_rows.len(); - let fts_candidates = fts_rows.len(); - - let vector_scored: Vec> = vector_rows - .into_iter() - .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) - .collect(); - - let fts_scored: Vec> = fts_rows - .into_iter() - .map(|row| Scored::new(row.chunk).with_fts_score(row.score)) - .collect(); - - let mut fts_weight = tuning.chunk_rrf_fts_weight; - if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 { - // For very short keyword queries, lean more on lexical ranking. - fts_weight *= 1.5; - } - - let rrf_config = RrfConfig { - k: tuning.chunk_rrf_k, - vector_weight: tuning.chunk_rrf_vector_weight, - fts_weight, - use_vector: tuning.flags.chunk_rrf_use_vector(), - use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0, - }; - - let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config); - - debug!( - total_merged = vector_chunks.len(), - vector_only = vector_chunks - .iter() - .filter(|c| c.scores.fts.is_none()) - .count(), - fts_only = vector_chunks - .iter() - .filter(|c| c.scores.vector.is_none()) - .count(), - both_signals = vector_chunks - .iter() - .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) - .count(), - rrf_k = rrf_config.k, - rrf_vector_weight = rrf_config.vector_weight, - rrf_fts_weight = rrf_config.fts_weight, - "Merged chunk candidates with RRF" - ); - - // let fts_only_count = vector_chunks - // .iter() - // .filter(|c| c.scores.vector.is_none()) - // .count(); - // let both_count = vector_chunks - // .iter() - // .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) - // .count(); - - debug!( - top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::>(), - "Fused scores after RRF ordering" - ); - - if ctx.diagnostics_enabled() { - ctx.record_collect_candidates(CollectCandidatesStats { - vector_entity_candidates: 0, - vector_chunk_candidates: vector_candidates, - fts_entity_candidates: 0, - fts_chunk_candidates: fts_candidates, - vector_chunk_scores: sample_scores(&vector_chunks, |chunk| { - chunk.scores.vector.unwrap_or(0.0) - }), - fts_chunk_scores: sample_scores(&vector_chunks, |chunk| { - chunk.scores.fts.unwrap_or(0.0) - }), - }); - } - - sort_by_fused_desc(&mut vector_chunks); - ctx.revised_chunk_values = vector_chunks; - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - if ctx.revised_chunk_values.len() <= 1 { - return Ok(()); - } - - let Some(reranker) = ctx.reranker.as_ref() else { - debug!("No reranker lease provided; skipping chunk rerank stage"); - return Ok(()); - }; - - let documents = build_chunk_rerank_documents( - &ctx.revised_chunk_values, - ctx.config.tuning.rerank_keep_top.max(1), - ); - if documents.len() <= 1 { - debug!("Skipping chunk reranking stage; insufficient chunk documents"); - return Ok(()); - } - - match reranker.rerank(&ctx.input_text, documents).await { - Ok(results) if !results.is_empty() => { - apply_chunk_rerank_results(&mut ctx.revised_chunk_values, &ctx.config.tuning, results); - } - Ok(_) => debug!("Chunk reranker returned no results; retaining original order"), - Err(err) => warn!( - error = %err, - "Chunk reranking failed; continuing with original ordering" - ), - } - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - debug!("Assembling chunk-only retrieval results"); - let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values); - // Limit how many chunks we return to keep context size reasonable. - let limit = ctx - .config - .tuning - .chunk_result_cap - .max(1) - .min(ctx.config.tuning.chunk_vector_take.max(1)); - - if chunk_values.len() > limit { - chunk_values.truncate(limit); - } - - ctx.chunk_results = chunk_values - .into_iter() - .map(|chunk| RetrievedChunk { - chunk: chunk.item, - score: chunk.fused, - }) - .collect(); - - if ctx.diagnostics_enabled() { - ctx.record_assemble(AssembleStats { - token_budget_start: ctx.config.tuning.token_budget_estimate, - token_budget_spent: 0, - token_budget_remaining: ctx.config.tuning.token_budget_estimate, - budget_exhausted: false, - chunks_selected: ctx.chunk_results.len(), - chunks_skipped_due_budget: 0, - entity_count: 0, - entity_traces: Vec::new(), - }); - } - - Ok(()) -} - -#[instrument(level = "trace", skip_all)] -pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - debug!("Assembling final retrieved entities"); - let tuning = &ctx.config.tuning; - let question_terms = extract_keywords(&ctx.input_text); - - let mut chunk_by_source: HashMap>> = - HashMap::with_capacity(ctx.chunk_values.len()); - for chunk in ctx.chunk_values.drain(..) { - chunk_by_source - .entry(chunk.item.source_id.clone()) - .or_default() - .push(chunk); - } - - for chunk_list in chunk_by_source.values_mut() { - chunk_list.sort_by(|a, b| { - // No base-table embeddings; order by fused score only. - b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal) - }); - } - - let mut token_budget_remaining = tuning.token_budget_estimate; - let mut results = Vec::with_capacity(ctx.filtered_entities.len()); - let diagnostics_enabled = ctx.diagnostics_enabled(); - let mut per_entity_traces = if diagnostics_enabled { - Vec::with_capacity(ctx.filtered_entities.len()) - } else { - Vec::new() - }; - let mut chunks_skipped_due_budget = 0usize; - let mut chunks_selected = 0usize; - let mut tokens_spent = 0usize; - - for entity in &ctx.filtered_entities { - let mut selected_chunks = Vec::with_capacity(tuning.max_chunks_per_entity); - let mut entity_trace = if diagnostics_enabled { - Some(EntityAssemblyTrace { - entity_id: entity.item.id.clone(), - source_id: entity.item.source_id.clone(), - inspected_candidates: 0, - selected_chunk_ids: Vec::new(), - selected_chunk_scores: Vec::new(), - skipped_due_budget: 0, - }) - } else { - None - }; - if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) { - rank_chunks_by_combined_score(candidates, &question_terms, tuning.lexical_match_weight); - let mut per_entity_count = 0; - for candidate in candidates.iter() { - if let Some(trace) = entity_trace.as_mut() { - trace.inspected_candidates = trace.inspected_candidates.saturating_add(1); - } - if per_entity_count >= tuning.max_chunks_per_entity { - break; - } - let estimated_tokens = - estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token); - if estimated_tokens > token_budget_remaining { - chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1); - if let Some(trace) = entity_trace.as_mut() { - trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1); - } - continue; - } - - token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); - tokens_spent = tokens_spent.saturating_add(estimated_tokens); - per_entity_count = per_entity_count.saturating_add(1); - chunks_selected = chunks_selected.saturating_add(1); - - selected_chunks.push(RetrievedChunk { - chunk: candidate.item.clone(), - score: candidate.fused, - }); - if let Some(trace) = entity_trace.as_mut() { - trace.selected_chunk_ids.push(candidate.item.id.clone()); - trace.selected_chunk_scores.push(candidate.fused); - } - } - } - - results.push(RetrievedEntity { - entity: entity.item.clone(), - score: entity.fused, - chunks: selected_chunks, - }); - - if let Some(trace) = entity_trace { - per_entity_traces.push(trace); - } - - if token_budget_remaining == 0 { - break; - } - } - - if diagnostics_enabled { - ctx.record_assemble(AssembleStats { - token_budget_start: tuning.token_budget_estimate, - token_budget_spent: tokens_spent, - token_budget_remaining, - budget_exhausted: token_budget_remaining == 0, - chunks_selected, - chunks_skipped_due_budget, - entity_count: ctx.filtered_entities.len(), - entity_traces: per_entity_traces, - }); - } - - ctx.entity_results = results; - Ok(()) -} - -const SCORE_SAMPLE_LIMIT: usize = 8; - -fn sample_scores(items: &[Scored], extractor: F) -> Vec -where - F: FnMut(&Scored) -> f32, -{ - items - .iter() - .take(SCORE_SAMPLE_LIMIT) - .map(extractor) - .collect() -} - -fn normalize_fts_scores(results: &mut [Scored]) { - let raw_scores: Vec = results - .iter() - .map(|candidate| candidate.scores.fts.unwrap_or(0.0)) - .collect(); - - let normalized = min_max_normalize(&raw_scores); - for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) { - candidate.scores.fts = Some(normalized_score); - candidate.update_fused(0.0); - } -} - -fn normalize_fts_query(input: &str) -> (String, usize) { - const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"]; - let mut cleaned = String::with_capacity(input.len()); - for ch in input.chars() { - if ch.is_alphanumeric() { - cleaned.extend(ch.to_lowercase()); - } else if ch.is_whitespace() { - cleaned.push(' '); - } - } - let mut tokens = Vec::with_capacity(cleaned.len() / 3 + 1); - for token in cleaned.split_whitespace() { - if !STOPWORDS.contains(&token) && !token.is_empty() { - tokens.push(token.to_string()); - } - } - let normalized = tokens.join(" "); - (normalized, tokens.len()) -} - -fn apply_fusion(candidates: &mut HashMap>, weights: FusionWeights) -where - T: StoredObject, -{ - for candidate in candidates.values_mut() { - let fused = fuse_scores(&candidate.scores, weights); - candidate.update_fused(fused); - } -} - -fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec { - if ctx.filtered_entities.is_empty() { - return Vec::new(); - } - - let mut chunk_by_source: HashMap<&str, Vec<&Scored>> = - HashMap::with_capacity(ctx.chunk_values.len()); - for chunk in &ctx.chunk_values { - chunk_by_source - .entry(chunk.item.source_id.as_str()) - .or_default() - .push(chunk); - } - - ctx.filtered_entities - .iter() - .map(|entity| { - let mut doc = format!( - "Name: {}\nType: {:?}\nDescription: {}\n", - entity.item.name, entity.item.entity_type, entity.item.description - ); - - if let Some(chunks) = chunk_by_source.get(entity.item.source_id.as_str()) { - let mut chunk_refs = chunks.clone(); - chunk_refs.sort_by(|a, b| { - b.fused - .partial_cmp(&a.fused) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let mut header_added = false; - for chunk in chunk_refs.into_iter().take(max_chunks_per_entity.max(1)) { - let snippet = chunk.item.chunk.trim(); - if snippet.is_empty() { - continue; - } - if !header_added { - doc.push_str("Chunks:\n"); - header_added = true; - } - doc.push_str("- "); - doc.push_str(snippet); - doc.push('\n'); - } - } - - doc - }) - .collect() -} - -fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) -> Vec { - chunks - .iter() - .take(max_chunks) - .map(|chunk| { - format!( - "Source: {}\nChunk:\n{}", - chunk.item.source_id, - chunk.item.chunk.trim() - ) - }) - .collect() -} - -fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec) { - if results.is_empty() || ctx.filtered_entities.is_empty() { - return; - } - - let mut remaining: Vec>> = - std::mem::take(&mut ctx.filtered_entities) - .into_iter() - .map(Some) - .collect(); - - let raw_scores: Vec = results.iter().map(|r| r.score).collect(); - let normalized_scores = min_max_normalize(&raw_scores); - - let use_only = ctx.config.tuning.flags.rerank_scores_only(); - let blend = if use_only { - 1.0 - } else { - clamp_unit(ctx.config.tuning.rerank_blend_weight) - }; - let mut reranked: Vec> = Vec::with_capacity(remaining.len()); - for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { - if let Some(slot) = remaining.get_mut(result.index) { - if let Some(mut candidate) = slot.take() { - let original = candidate.fused; - let blended = if use_only { - clamp_unit(normalized) - } else { - clamp_unit(original * (1.0 - blend) + normalized * blend) - }; - candidate.update_fused(blended); - reranked.push(candidate); - } - } else { - warn!( - result_index = result.index, - "Reranker returned out-of-range index; skipping" - ); - } - if reranked.len() == remaining.len() { - break; - } - } - - reranked.extend(remaining.into_iter().flatten()); - - ctx.filtered_entities = reranked; - let keep_top = ctx.config.tuning.rerank_keep_top; - if keep_top > 0 && ctx.filtered_entities.len() > keep_top { - ctx.filtered_entities.truncate(keep_top); - } -} - -fn apply_chunk_rerank_results( - chunks: &mut Vec>, - tuning: &RetrievalTuning, - results: Vec, -) { - if results.is_empty() || chunks.is_empty() { - return; - } - - let mut remaining: Vec>> = - std::mem::take(chunks).into_iter().map(Some).collect(); - - let raw_scores: Vec = results.iter().map(|r| r.score).collect(); - let normalized_scores = min_max_normalize(&raw_scores); - - let use_only = tuning.flags.rerank_scores_only(); - let blend = if use_only { - 1.0 - } else { - clamp_unit(tuning.rerank_blend_weight) - }; - - let mut reranked: Vec> = Vec::with_capacity(remaining.len()); - for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { - if let Some(slot) = remaining.get_mut(result.index) { - if let Some(mut candidate) = slot.take() { - let original = candidate.fused; - let blended = if use_only { - clamp_unit(normalized) - } else { - clamp_unit(original * (1.0 - blend) + normalized * blend) - }; - candidate.update_fused(blended); - reranked.push(candidate); - } - } else { - warn!( - result_index = result.index, - "Chunk reranker returned out-of-range index; skipping" - ); - } - if reranked.len() == remaining.len() { - break; - } - } - - reranked.extend(remaining.into_iter().flatten()); - - let keep_top = tuning.rerank_keep_top; - if keep_top > 0 && reranked.len() > keep_top { - reranked.truncate(keep_top); - } - - *chunks = reranked; -} - -fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize { - let chars = text.chars().count().max(1); - chars.checked_div(avg_chars_per_token).map_or(1, |v| v.max(1)) -} - -fn rank_chunks_by_combined_score( - candidates: &mut [Scored], - question_terms: &[String], - lexical_weight: f32, -) { - if lexical_weight > 0.0 && !question_terms.is_empty() { - for candidate in candidates.iter_mut() { - let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk); - let combined = clamp_unit(candidate.fused + lexical_weight * lexical); - candidate.update_fused(combined); - } - } - candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)); -} - -fn extract_keywords(text: &str) -> Vec { - let mut terms = Vec::with_capacity((text.len() / 3).max(4)); - for raw in text.split(|c: char| !c.is_alphanumeric()) { - let term = raw.trim().to_ascii_lowercase(); - if term.len() >= 3 { - terms.push(term); - } - } - terms.sort(); - terms.dedup(); - terms -} - -fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 { - if terms.is_empty() { - return 0.0; - } - let lower = haystack.to_ascii_lowercase(); - let mut matches: u32 = 0; - for term in terms { - if lower.contains(term) { - matches = matches.saturating_add(1); - } - } - let total = u32::try_from(terms.len()).unwrap_or(u32::MAX); - if total == 0 { - return 0.0; - } - let num = matches.min(total); - let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX); - let den_f32 = u16::try_from(total).map(f32::from).unwrap_or(f32::MAX); - num_f32 / den_f32 -} - -#[derive(Clone)] -struct GraphSeed { - id: String, - fused: f32, -} - -fn seeds_from_candidates( - entity_candidates: &HashMap>, - min_score: f32, - limit: usize, -) -> Vec { - let mut seeds: Vec = entity_candidates - .values() - .filter(|entity| entity.fused >= min_score) - .map(|entity| GraphSeed { - id: entity.item.id.clone(), - fused: entity.fused, - }) - .collect(); - - seeds.sort_by(|a, b| { - b.fused - .partial_cmp(&a.fused) - .unwrap_or(std::cmp::Ordering::Equal) - }); - if seeds.len() > limit { - seeds.truncate(limit); - } - - seeds -} diff --git a/retrieval-pipeline/src/pipeline/strategies.rs b/retrieval-pipeline/src/pipeline/strategies.rs deleted file mode 100644 index 0b1e6bb..0000000 --- a/retrieval-pipeline/src/pipeline/strategies.rs +++ /dev/null @@ -1,148 +0,0 @@ -use super::{ - stages::{ - AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage, - CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage, - }, - BoxedStage, StrategyDriver, -}; -use crate::{RetrievedChunk, RetrievedEntity}; -use common::error::AppError; - -pub struct DefaultStrategyDriver; - -impl DefaultStrategyDriver { - pub fn new() -> Self { - Self - } -} - -impl StrategyDriver for DefaultStrategyDriver { - type Output = Vec; - - fn stages(&self) -> Vec { - vec![ - Box::new(EmbedStage), - Box::new(ChunkVectorStage), - Box::new(ChunkRerankStage), - Box::new(ChunkAssembleStage), - ] - } - - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { - Ok(ctx.take_chunk_results()) - } -} - -pub struct RelationshipSuggestionDriver; - -impl RelationshipSuggestionDriver { - pub fn new() -> Self { - Self - } -} - -impl StrategyDriver for RelationshipSuggestionDriver { - type Output = Vec; - - fn stages(&self) -> Vec { - vec![ - Box::new(EmbedStage), - Box::new(CollectCandidatesStage), - Box::new(GraphExpansionStage), - // Skip ChunkAttachStage - Box::new(RerankStage), - Box::new(AssembleEntitiesStage), - ] - } - - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { - Ok(ctx.take_entity_results()) - } -} - -pub struct IngestionDriver; - -impl IngestionDriver { - pub fn new() -> Self { - Self - } -} - -impl StrategyDriver for IngestionDriver { - type Output = Vec; - - fn stages(&self) -> Vec { - vec![ - Box::new(EmbedStage), - Box::new(CollectCandidatesStage), - Box::new(GraphExpansionStage), - // Skip ChunkAttachStage - Box::new(RerankStage), - Box::new(AssembleEntitiesStage), - ] - } - - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { - Ok(ctx.take_entity_results()) - } -} - -use super::config::SearchTarget; -use crate::SearchResult; - -/// Search strategy driver that retrieves both chunks and entities -pub struct SearchStrategyDriver { - target: SearchTarget, -} - -impl SearchStrategyDriver { - pub fn new(target: SearchTarget) -> Self { - Self { target } - } -} - -impl StrategyDriver for SearchStrategyDriver { - type Output = SearchResult; - - fn stages(&self) -> Vec { - match self.target { - SearchTarget::ChunksOnly => vec![ - Box::new(EmbedStage), - Box::new(ChunkVectorStage), - Box::new(ChunkRerankStage), - Box::new(ChunkAssembleStage), - ], - SearchTarget::EntitiesOnly => vec![ - Box::new(EmbedStage), - Box::new(CollectCandidatesStage), - Box::new(GraphExpansionStage), - Box::new(RerankStage), - Box::new(AssembleEntitiesStage), - ], - SearchTarget::Both => vec![ - Box::new(EmbedStage), - // Chunk retrieval path - Box::new(ChunkVectorStage), - Box::new(ChunkRerankStage), - Box::new(ChunkAssembleStage), - // Entity retrieval path (runs after chunk stages) - Box::new(CollectCandidatesStage), - Box::new(GraphExpansionStage), - Box::new(RerankStage), - Box::new(AssembleEntitiesStage), - ], - } - } - - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { - let chunks = match self.target { - SearchTarget::EntitiesOnly => Vec::new(), - _ => ctx.take_chunk_results(), - }; - let entities = match self.target { - SearchTarget::ChunksOnly => Vec::new(), - _ => ctx.take_entity_results(), - }; - Ok(SearchResult::new(chunks, entities)) - } -} diff --git a/retrieval-pipeline/src/reranking/mod.rs b/retrieval-pipeline/src/reranking.rs similarity index 95% rename from retrieval-pipeline/src/reranking/mod.rs rename to retrieval-pipeline/src/reranking.rs index b706774..823e03e 100644 --- a/retrieval-pipeline/src/reranking/mod.rs +++ b/retrieval-pipeline/src/reranking.rs @@ -97,8 +97,7 @@ impl RerankerPool { fn default_pool_size() -> usize { available_parallelism() - .map(|value| value.get().min(2)) - .unwrap_or(2) + .map_or(2, |value| value.get().min(2)) .max(1) } @@ -156,6 +155,7 @@ pub struct RerankerLease { } impl RerankerLease { + #[allow(clippy::result_large_err)] pub async fn rerank( &self, query: &str, @@ -165,7 +165,9 @@ impl RerankerLease { let engine = Arc::clone(&self.engine); tokio::task::spawn_blocking(move || { - let mut guard = engine.lock().expect("reranker engine mutex poisoned"); + let mut guard = engine.lock().map_err(|_| { + AppError::InternalError("reranker engine mutex poisoned".into()) + })?; guard .rerank(query, documents, false, None) .map_err(|e| AppError::InternalError(e.to_string())) diff --git a/retrieval-pipeline/src/scoring.rs b/retrieval-pipeline/src/scoring.rs index b29da7d..1c51e2b 100644 --- a/retrieval-pipeline/src/scoring.rs +++ b/retrieval-pipeline/src/scoring.rs @@ -1,14 +1,12 @@ use std::{cmp::Ordering, collections::HashMap}; use common::storage::types::StoredObject; -use serde::{Deserialize, Serialize}; -/// Holds optional subscores gathered from different retrieval signals. +/// Holds optional subscores gathered from the vector and full-text retrieval signals. #[derive(Debug, Clone, Copy, Default)] pub struct Scores { pub fts: Option, pub vector: Option, - pub graph: Option, } /// Generic wrapper combining an item with its accumulated retrieval scores. @@ -40,40 +38,11 @@ impl Scored { self } - #[must_use] - pub const fn with_graph_score(mut self, score: f32) -> Self { - self.scores.graph = Some(score); - self - } - pub const fn update_fused(&mut self, fused: f32) { self.fused = fused; } } -/// Weights used for linear score fusion. -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct FusionWeights { - pub vector: f32, - pub fts: f32, - pub graph: f32, - pub multi_bonus: f32, -} - -impl Default for FusionWeights { - fn default() -> Self { - // Default weights favor vector search, which typically performs better - // FTS is used as a complement when there's good overlap - // Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk") - Self { - vector: 0.8, - fts: 0.2, - graph: 0.2, - multi_bonus: 0.3, // Increased to boost chunks with both signals - } - } -} - /// Configuration for reciprocal rank fusion. #[derive(Debug, Clone, Copy)] pub struct RrfConfig { @@ -84,29 +53,10 @@ pub struct RrfConfig { pub use_fts: bool, } -impl Default for RrfConfig { - fn default() -> Self { - Self { - k: 60.0, - vector_weight: 1.0, - fts_weight: 1.0, - use_vector: true, - use_fts: true, - } - } -} - pub const fn clamp_unit(value: f32) -> f32 { value.clamp(0.0, 1.0) } -pub fn distance_to_similarity(distance: f32) -> f32 { - if !distance.is_finite() { - return 0.0; - } - clamp_unit(1.0 / (1.0 + distance.max(0.0))) -} - pub fn min_max_normalize(scores: &[f32]) -> Vec { if scores.is_empty() { return Vec::new(); @@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec { .collect() } -pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 { - let vector = scores.vector.unwrap_or(0.0); - let fts = scores.fts.unwrap_or(0.0); - let graph = scores.graph.unwrap_or(0.0); - - let mut fused = graph.mul_add( - weights.graph, - vector.mul_add(weights.vector, fts * weights.fts), - ); - - let signals_present = scores - .vector - .iter() - .chain(scores.fts.iter()) - .chain(scores.graph.iter()) - .count(); - - // Boost chunks with multiple signals (especially vector + FTS, the "golden chunk") - if signals_present >= 2 { - // For chunks with both vector and FTS, give a significant boost - // This helps identify the "golden chunk" that appears in both searches - if scores.vector.is_some() && scores.fts.is_some() { - // Multiplicative boost: multiply by (1 + bonus) to scale with the base score - // This ensures high-scoring golden chunks get boosted more than low-scoring ones - fused *= 1.0 + weights.multi_bonus; - } else { - // For other multi-signal combinations (e.g., vector + graph), use additive bonus - fused += weights.multi_bonus; - } - } - - clamp_unit(fused) -} - -pub fn merge_scored_by_id( - target: &mut std::collections::HashMap, S>, - incoming: Vec>, -) where - T: StoredObject + Clone, -{ - for scored in incoming { - let id = scored.item.id().to_owned(); - target - .entry(id) - .and_modify(|existing| { - if let Some(score) = scored.scores.vector { - existing.scores.vector = Some(score); - } - if let Some(score) = scored.scores.fts { - existing.scores.fts = Some(score); - } - if let Some(score) = scored.scores.graph { - existing.scores.graph = Some(score); - } - }) - .or_insert_with(|| Scored { - item: scored.item.clone(), - scores: scored.scores, - fused: scored.fused, - }); - } -} - pub fn sort_by_fused_desc(items: &mut [Scored]) where T: StoredObject, @@ -222,6 +109,10 @@ where }); } +/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion. +/// +/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text +/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score. pub fn reciprocal_rank_fusion( mut vector_ranked: Vec>, mut fts_ranked: Vec>, @@ -266,9 +157,7 @@ where } } entry.item = candidate.item; - let rank_f32: f32 = u16::try_from(rank) - .map(f32::from) - .unwrap_or(f32::MAX); + let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); entry.fused += vector_weight / (k + rank_f32 + 1.0); } } @@ -296,9 +185,7 @@ where } } entry.item = candidate.item; - let rank_f32: f32 = u16::try_from(rank) - .map(f32::from) - .unwrap_or(f32::MAX); + let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); entry.fused += fts_weight / (k + rank_f32 + 1.0); } }