chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.

Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
Per Stark
2026-05-30 22:19:08 +02:00
parent c70141de35
commit 5c2d2e24d3
38 changed files with 1049 additions and 2614 deletions
Generated
-4
View File
@@ -5426,14 +5426,10 @@ dependencies = [
"anyhow", "anyhow",
"async-openai", "async-openai",
"async-trait", "async-trait",
"axum",
"common", "common",
"fastembed", "fastembed",
"futures",
"serde", "serde",
"serde_json", "serde_json",
"surrealdb",
"thiserror 1.0.69",
"tokio", "tokio",
"tracing", "tracing",
"uuid", "uuid",
@@ -164,6 +164,35 @@ impl KnowledgeEntity {
.take(0)?) .take(0)?)
} }
/// Fetch all knowledge entities owned by any of the provided source ids for a user.
///
/// Used by retrieval to resolve the entities that own a set of retrieved chunks.
pub async fn find_by_source_ids(
db: &SurrealDbClient,
source_ids: &[String],
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
if source_ids.is_empty() {
return Ok(Vec::new());
}
let entities: Vec<KnowledgeEntity> = db
.client
.query(
"SELECT * FROM type::table($table) \
WHERE source_id IN $sources AND user_id = $user_id",
)
.bind(("table", Self::table_name()))
.bind(("sources", source_ids.to_vec()))
.bind(("user_id", user_id.to_owned()))
.await
.map_err(AppError::Database)?
.take(0)
.map_err(AppError::Database)?;
Ok(entities)
}
pub async fn delete_by_source_id( pub async fn delete_by_source_id(
source_id: &str, source_id: &str,
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
+11 -119
View File
@@ -1,8 +1,7 @@
use config::{Config, ConfigError, Environment, File}; use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Serialize};
use std::{env, fmt, str::FromStr, sync::Once}; use std::{env, str::FromStr, sync::Once};
use thiserror::Error; use thiserror::Error;
use tracing::warn;
/// Error returned when parsing an embedding backend name. /// Error returned when parsing an embedding backend name.
#[derive(Debug, Error, PartialEq, Eq)] #[derive(Debug, Error, PartialEq, Eq)]
@@ -36,83 +35,6 @@ impl EmbeddingBackend {
} }
} }
/// Error returned when parsing a retrieval strategy name.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown retrieval strategy '{input}'")]
pub struct ParseRetrievalStrategyError {
/// The unrecognized input string.
pub input: String,
}
/// Selects which retrieval pipeline strategy to run for chat and search.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
/// Primary hybrid chunk retrieval for search/chat.
#[default]
Default,
/// Entity retrieval for suggesting relationships when creating manual entities.
RelationshipSuggestion,
/// Entity retrieval for context during content ingestion.
Ingestion,
/// Unified search returning both chunks and entities.
Search,
}
impl RetrievalStrategy {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Default => "default",
Self::RelationshipSuggestion => "relationship_suggestion",
Self::Ingestion => "ingestion",
Self::Search => "search",
}
}
}
impl FromStr for RetrievalStrategy {
type Err = ParseRetrievalStrategyError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value.to_ascii_lowercase().as_str() {
"default" => Ok(Self::Default),
"initial" | "revised" => {
warn!(
"retrieval strategy '{value}' is deprecated; use 'default' instead"
);
Ok(Self::Default)
}
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
"ingestion" => Ok(Self::Ingestion),
"search" => Ok(Self::Search),
other => Err(ParseRetrievalStrategyError {
input: other.to_string(),
}),
}
}
}
impl fmt::Display for RetrievalStrategy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
fn deserialize_optional_retrieval_strategy<'de, D>(
deserializer: D,
) -> Result<Option<RetrievalStrategy>, D::Error>
where
D: Deserializer<'de>,
{
let value = Option::<String>::deserialize(deserializer)?;
match value {
None => Ok(None),
Some(raw) if raw.trim().is_empty() => Ok(None),
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
}
}
impl FromStr for EmbeddingBackend { impl FromStr for EmbeddingBackend {
type Err = ParseEmbeddingBackendError; type Err = ParseEmbeddingBackendError;
@@ -195,8 +117,6 @@ pub struct AppConfig {
pub fastembed_show_download_progress: Option<bool>, pub fastembed_show_download_progress: Option<bool>,
#[serde(default)] #[serde(default)]
pub fastembed_max_length: Option<usize>, pub fastembed_max_length: Option<usize>,
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
pub retrieval_strategy: Option<RetrievalStrategy>,
#[serde(default)] #[serde(default)]
pub embedding_backend: EmbeddingBackend, pub embedding_backend: EmbeddingBackend,
#[serde(default = "default_ingest_max_body_bytes")] #[serde(default = "default_ingest_max_body_bytes")]
@@ -282,14 +202,6 @@ pub fn ensure_ort_path() {
}); });
} }
impl AppConfig {
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
#[must_use]
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
self.retrieval_strategy.unwrap_or_default()
}
}
impl Default for AppConfig { impl Default for AppConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
@@ -312,7 +224,6 @@ impl Default for AppConfig {
fastembed_cache_dir: None, fastembed_cache_dir: None,
fastembed_show_download_progress: None, fastembed_show_download_progress: None,
fastembed_max_length: None, fastembed_max_length: None,
retrieval_strategy: None,
embedding_backend: EmbeddingBackend::default(), embedding_backend: EmbeddingBackend::default(),
ingest_max_body_bytes: default_ingest_max_body_bytes(), ingest_max_body_bytes: default_ingest_max_body_bytes(),
ingest_max_files: default_ingest_max_files(), ingest_max_files: default_ingest_max_files(),
@@ -340,41 +251,22 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
mod tests { mod tests {
#![allow(clippy::expect_used)] #![allow(clippy::expect_used)]
use super::{ParseRetrievalStrategyError, RetrievalStrategy}; use super::EmbeddingBackend;
#[test] #[test]
fn retrieval_strategy_defaults_to_default() { fn embedding_backend_defaults_to_fastembed() {
assert_eq!( assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
RetrievalStrategy::default(),
RetrievalStrategy::Default
);
} }
#[test] #[test]
fn retrieval_strategy_serializes_snake_case() { fn embedding_backend_parses_aliases() {
assert_eq!( assert_eq!(
serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"), "openai".parse::<EmbeddingBackend>().expect("openai"),
"\"search\"" EmbeddingBackend::OpenAI
); );
}
#[test]
fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
assert_eq!( assert_eq!(
"initial".parse::<RetrievalStrategy>().expect("initial"), "fast".parse::<EmbeddingBackend>().expect("fast"),
RetrievalStrategy::Default EmbeddingBackend::FastEmbed
);
assert!(matches!(
"unknown".parse::<RetrievalStrategy>(),
Err(ParseRetrievalStrategyError { .. })
));
}
#[test]
fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
let config = super::AppConfig::default();
assert_eq!(
config.resolved_retrieval_strategy(),
RetrievalStrategy::Default
); );
} }
} }
-1
View File
@@ -24,7 +24,6 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
| `RUST_LOG` | Logging level | `info` | | `RUST_LOG` | Logging level | `info` |
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` | | `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` | | `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` | | `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` | | `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` | | `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
+6 -5
View File
@@ -27,13 +27,14 @@ The D3-based graph visualization shows entities as nodes and relationships as ed
## Hybrid Retrieval ## Hybrid Retrieval
Minne combines multiple retrieval strategies: Minne uses chunk-first hybrid retrieval over the knowledge base:
- **Vector similarity** — Semantic matching via embeddings - **Vector similarity** — Semantic matching via embeddings over text chunks
- **Full-text search** — Keyword matching with BM25 - **Full-text search** — Keyword matching with BM25 over the same chunk index
- **Graph traversal** — Following relationships between entities
Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance. The two ranked candidate lists are merged with Reciprocal Rank Fusion (RRF). When a caller needs knowledge entities (search, ingestion linking, relationship suggestion), entities are derived from the top retrieved chunks grouped by `source_id`.
Optional **reranking** can rescore the fused chunk list with a cross-encoder model; see below.
## Reranking (Optional) ## Reranking (Optional)
+7 -18
View File
@@ -5,7 +5,6 @@ use std::{
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use clap::{Args, Parser, ValueEnum}; use clap::{Args, Parser, ValueEnum};
use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind; use crate::datasets::DatasetKind;
@@ -55,10 +54,6 @@ pub struct RetrievalSettings {
#[arg(long)] #[arg(long)]
pub chunk_fts_take: Option<usize>, pub chunk_fts_take: Option<usize>,
/// Override average characters per token used for budgeting
#[arg(long)]
pub chunk_avg_chars_per_token: Option<usize>,
/// Override maximum chunks attached per entity /// Override maximum chunks attached per entity
#[arg(long)] #[arg(long)]
pub max_chunks_per_entity: Option<usize>, pub max_chunks_per_entity: Option<usize>,
@@ -71,41 +66,37 @@ pub struct RetrievalSettings {
#[arg(long, default_value_t = 4)] #[arg(long, default_value_t = 4)]
pub rerank_pool_size: usize, pub rerank_pool_size: usize,
/// Keep top-N entities after reranking /// Keep top-N chunks after reranking
#[arg(long, default_value_t = 10)] #[arg(long, default_value_t = 10)]
pub rerank_keep_top: usize, pub rerank_keep_top: usize,
/// Cap the number of chunks returned by retrieval (revised strategy) /// Cap the number of chunks returned by retrieval
#[arg(long, default_value_t = 5)] #[arg(long, default_value_t = 5)]
pub chunk_result_cap: usize, pub chunk_result_cap: usize,
/// Reciprocal rank fusion k value for revised chunk merging /// Reciprocal rank fusion k value for chunk merging
#[arg(long)] #[arg(long)]
pub chunk_rrf_k: Option<f32>, pub chunk_rrf_k: Option<f32>,
/// Weight for vector ranks in revised RRF /// Weight for vector ranks in RRF
#[arg(long)] #[arg(long)]
pub chunk_rrf_vector_weight: Option<f32>, pub chunk_rrf_vector_weight: Option<f32>,
/// Weight for chunk FTS ranks in revised RRF /// Weight for chunk FTS ranks in RRF
#[arg(long)] #[arg(long)]
pub chunk_rrf_fts_weight: Option<f32>, pub chunk_rrf_fts_weight: Option<f32>,
/// Include vector ranks in revised RRF (default: true) /// Include vector ranks in RRF (default: true)
#[arg(long)] #[arg(long)]
pub chunk_rrf_use_vector: Option<bool>, pub chunk_rrf_use_vector: Option<bool>,
/// Include chunk FTS ranks in revised RRF (default: true) /// Include chunk FTS ranks in RRF (default: true)
#[arg(long)] #[arg(long)]
pub chunk_rrf_use_fts: Option<bool>, pub chunk_rrf_use_fts: Option<bool>,
/// Require verified chunks (disable with --llm-mode) /// Require verified chunks (disable with --llm-mode)
#[arg(skip = true)] #[arg(skip = true)]
pub require_verified_chunks: bool, pub require_verified_chunks: bool,
/// Select the retrieval pipeline strategy
#[arg(long, default_value_t = RetrievalStrategy::Default)]
pub strategy: RetrievalStrategy,
} }
impl Default for RetrievalSettings { impl Default for RetrievalSettings {
@@ -113,7 +104,6 @@ impl Default for RetrievalSettings {
Self { Self {
chunk_vector_take: None, chunk_vector_take: None,
chunk_fts_take: None, chunk_fts_take: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None, max_chunks_per_entity: None,
rerank: false, rerank: false,
rerank_pool_size: 4, rerank_pool_size: 4,
@@ -125,7 +115,6 @@ impl Default for RetrievalSettings {
chunk_rrf_use_vector: None, chunk_rrf_use_vector: None,
chunk_rrf_use_fts: None, chunk_rrf_use_fts: None,
require_verified_chunks: true, require_verified_chunks: true,
strategy: RetrievalStrategy::Default,
} }
} }
} }
+18 -20
View File
@@ -51,8 +51,8 @@ pub fn mirror_perf_outputs(
pub fn print_console_summary(record: &EvaluationReport) { pub fn print_console_summary(record: &EvaluationReport) {
let perf = &record.performance; let perf = &record.performance;
println!( println!(
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})", "[perf] resolve_entities={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.strategy, record.retrieval.resolve_entities,
record.retrieval.concurrency, record.retrieval.concurrency,
record.retrieval.rerank_enabled, record.retrieval.rerank_enabled,
record.retrieval.rerank_pool_size, record.retrieval.rerank_pool_size,
@@ -63,16 +63,14 @@ pub fn print_console_summary(record: &EvaluationReport) {
perf.ingestion_ms, perf.ingestion_ms,
format_duration(perf.namespace_seed_ms), format_duration(perf.namespace_seed_ms),
); );
let stage = &perf.stage_latency; let stage_summary = perf
println!( .stage_latency
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}", .stages
stage.embed.avg, .iter()
stage.collect_candidates.avg, .map(|s| format!("{} {:.1}", s.stage, s.stats.avg))
stage.graph_expansion.avg, .collect::<Vec<_>>()
stage.chunk_attach.avg, .join(" | ");
stage.rerank.avg, println!("[perf] stage avg ms → {stage_summary}");
stage.assemble.avg,
);
let eval = &perf.evaluation_stages_ms; let eval = &perf.evaluation_stages_ms;
println!( println!(
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}", "[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
@@ -107,12 +105,13 @@ mod tests {
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown { fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
crate::eval::StageLatencyBreakdown { crate::eval::StageLatencyBreakdown {
embed: sample_latency(), stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
collect_candidates: sample_latency(), .into_iter()
graph_expansion: sample_latency(), .map(|stage| crate::eval::StageLatency {
chunk_attach: sample_latency(), stage: stage.to_string(),
rerank: sample_latency(), stats: sample_latency(),
assemble: sample_latency(), })
.collect(),
} }
} }
@@ -193,7 +192,7 @@ mod tests {
rerank_keep_top: 10, rerank_keep_top: 10,
concurrency: 2, concurrency: 2,
detailed_report: false, detailed_report: false,
retrieval_strategy: "initial".into(), resolve_entities: false,
chunk_result_cap: 5, chunk_result_cap: 5,
chunk_rrf_k: 60.0, chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0, chunk_rrf_vector_weight: 1.0,
@@ -206,7 +205,6 @@ mod tests {
ingest_chunk_overlap_tokens: 50, ingest_chunk_overlap_tokens: 50,
chunk_vector_take: 20, chunk_vector_take: 20,
chunk_fts_take: 20, chunk_fts_take: 20,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4, max_chunks_per_entity: 4,
cases: Vec::new(), cases: Vec::new(),
} }
+7 -23
View File
@@ -6,7 +6,7 @@ use futures::stream::{self, StreamExt};
use tracing::{debug, info}; use tracing::{debug, info};
use crate::eval::{ use crate::eval::{
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics, adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
CaseSummary, RetrievedSummary, CaseSummary, RetrievedSummary,
}; };
use retrieval_pipeline::{ use retrieval_pipeline::{
@@ -51,14 +51,8 @@ pub(crate) async fn run_queries(
None None
}; };
let mut retrieval_config = RetrievalConfig { let mut retrieval_config = RetrievalConfig::default();
strategy: config.retrieval.strategy,
..Default::default()
};
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top; retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
}
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1); retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
if let Some(value) = config.retrieval.chunk_vector_take { if let Some(value) = config.retrieval.chunk_vector_take {
retrieval_config.tuning.chunk_vector_take = value; retrieval_config.tuning.chunk_vector_take = value;
@@ -81,9 +75,6 @@ pub(crate) async fn run_queries(
if let Some(value) = config.retrieval.chunk_rrf_use_fts { if let Some(value) = config.retrieval.chunk_rrf_use_fts {
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into(); retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
} }
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value;
}
if let Some(value) = config.retrieval.max_chunks_per_entity { if let Some(value) = config.retrieval.max_chunks_per_entity {
retrieval_config.tuning.max_chunks_per_entity = value; retrieval_config.tuning.max_chunks_per_entity = value;
} }
@@ -187,7 +178,7 @@ pub(crate) async fn run_queries(
None => None, None => None,
}; };
let params = pipeline::StrategyParams { let params = pipeline::RetrievalParams {
db_client: &db, db_client: &db,
openai_client: &openai_client, openai_client: &openai_client,
embedding_provider: Some(&embedding_provider), embedding_provider: Some(&embedding_provider),
@@ -196,26 +187,19 @@ pub(crate) async fn run_queries(
config: (*retrieval_config).clone(), config: (*retrieval_config).clone(),
reranker, reranker,
}; };
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled { let (result_output, pipeline_diagnostics, stage_timings) = {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics( let outcome = pipeline::run_with_embedding_instrumented(
params, params,
query_embedding, query_embedding,
diagnostics_enabled,
) )
.await .await
.with_context(|| format!("running pipeline for question {question_id}"))?; .with_context(|| format!("running pipeline for question {question_id}"))?;
(outcome.results, outcome.diagnostics, outcome.stage_timings) (outcome.results, outcome.diagnostics, outcome.stage_timings)
} else {
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
params,
query_embedding,
)
.await
.with_context(|| format!("running pipeline for question {question_id}"))?;
(outcome.results, None, outcome.stage_timings)
}; };
let query_latency = query_start.elapsed().as_millis(); let query_latency = query_start.elapsed().as_millis();
let candidates = adapt_strategy_output(result_output); let candidates = adapt_retrieval_output(result_output);
let mut retrieved = Vec::new(); let mut retrieved = Vec::new();
let mut match_rank = None; let mut match_rank = None;
let answers_lower: Vec<String> = let answers_lower: Vec<String> =
+4 -2
View File
@@ -201,7 +201,10 @@ pub(crate) async fn summarize(
rerank_keep_top: config.retrieval.rerank_keep_top, rerank_keep_top: config.retrieval.rerank_keep_top,
concurrency: config.concurrency.max(1), concurrency: config.concurrency.max(1),
detailed_report: config.detailed_report, detailed_report: config.detailed_report,
retrieval_strategy: config.retrieval.strategy.to_string(), resolve_entities: ctx
.retrieval_config
.as_ref()
.is_some_and(|config| config.resolve_entities),
chunk_result_cap: config.retrieval.chunk_result_cap, chunk_result_cap: config.retrieval.chunk_result_cap,
chunk_rrf_k: active_tuning.chunk_rrf_k, chunk_rrf_k: active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight, chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
@@ -214,7 +217,6 @@ pub(crate) async fn summarize(
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens, ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
chunk_vector_take: active_tuning.chunk_vector_take, chunk_vector_take: active_tuning.chunk_vector_take,
chunk_fts_take: active_tuning.chunk_fts_take, chunk_fts_take: active_tuning.chunk_fts_take,
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
max_chunks_per_entity: active_tuning.max_chunks_per_entity, max_chunks_per_entity: active_tuning.max_chunks_per_entity,
cases: summaries, cases: summaries,
}); });
+54 -49
View File
@@ -85,7 +85,7 @@ pub struct RetrievalSection {
pub average_ndcg: f64, pub average_ndcg: f64,
pub latency: LatencyStats, pub latency: LatencyStats,
pub concurrency: usize, pub concurrency: usize,
pub strategy: String, pub resolve_entities: bool,
pub rerank_enabled: bool, pub rerank_enabled: bool,
pub rerank_pool_size: Option<usize>, pub rerank_pool_size: Option<usize>,
pub rerank_keep_top: usize, pub rerank_keep_top: usize,
@@ -226,7 +226,7 @@ impl EvaluationReport {
average_ndcg: summary.average_ndcg, average_ndcg: summary.average_ndcg,
latency: summary.latency_ms.clone(), latency: summary.latency_ms.clone(),
concurrency: summary.concurrency, concurrency: summary.concurrency,
strategy: summary.retrieval_strategy.clone(), resolve_entities: summary.resolve_entities,
rerank_enabled: summary.rerank_enabled, rerank_enabled: summary.rerank_enabled,
rerank_pool_size: summary.rerank_pool_size, rerank_pool_size: summary.rerank_pool_size,
rerank_keep_top: summary.rerank_keep_top, rerank_keep_top: summary.rerank_keep_top,
@@ -463,7 +463,12 @@ fn render_markdown(report: &EvaluationReport) -> String {
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap(); write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap(); write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap(); write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap();
write!(md, "| Strategy | `{}` |\\n", report.retrieval.strategy).unwrap(); write!(
md,
"| Resolve entities | {} |\\n",
bool_badge(report.retrieval.resolve_entities)
)
.unwrap();
write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap(); write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap();
if report.retrieval.rerank_enabled { if report.retrieval.rerank_enabled {
let pool = report let pool = report
@@ -510,28 +515,9 @@ fn render_markdown(report: &EvaluationReport) -> String {
md.push_str("\\n## Retrieval Stage Timings\\n\\n"); md.push_str("\\n## Retrieval Stage Timings\\n\\n");
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n"); md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed); for stage in &report.performance.stage_latency.stages {
write_stage_row( write_stage_row(&mut md, &prettify_stage(&stage.stage), &stage.stats);
&mut md, }
"Collect Candidates",
&report.performance.stage_latency.collect_candidates,
);
write_stage_row(
&mut md,
"Graph Expansion",
&report.performance.stage_latency.graph_expansion,
);
write_stage_row(
&mut md,
"Chunk Attach",
&report.performance.stage_latency.chunk_attach,
);
write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank);
write_stage_row(
&mut md,
"Assemble",
&report.performance.stage_latency.assemble,
);
if report.misses.is_empty() { if report.misses.is_empty() {
if report.detailed_report { if report.detailed_report {
@@ -623,6 +609,20 @@ fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
.unwrap(); .unwrap();
} }
/// Turn a stable stage label (e.g. `resolve_entities`) into a display title (`Resolve Entities`).
fn prettify_stage(label: &str) -> String {
label
.split('_')
.map(|word| {
let mut chars = word.chars();
chars.next().map_or_else(String::new, |first| {
first.to_uppercase().collect::<String>() + chars.as_str()
})
})
.collect::<Vec<_>>()
.join(" ")
}
fn bool_badge(value: bool) -> &'static str { fn bool_badge(value: bool) -> &'static str {
if value { if value {
"" ""
@@ -740,17 +740,6 @@ struct LegacyHistoryDelta {
latency_avg_ms: f64, latency_avg_ms: f64,
} }
fn default_stage_latency() -> StageLatencyBreakdown {
StageLatencyBreakdown {
embed: LatencyStats::default(),
collect_candidates: LatencyStats::default(),
graph_expansion: LatencyStats::default(),
chunk_attach: LatencyStats::default(),
rerank: LatencyStats::default(),
assemble: LatencyStats::default(),
}
}
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport { fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
let overview = OverviewSection { let overview = OverviewSection {
@@ -807,7 +796,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
average_ndcg: entry.average_ndcg, average_ndcg: entry.average_ndcg,
latency: entry.latency_ms, latency: entry.latency_ms,
concurrency: 0, concurrency: 0,
strategy: "unknown".into(), resolve_entities: false,
rerank_enabled: entry.rerank_enabled, rerank_enabled: entry.rerank_enabled,
rerank_pool_size: entry.rerank_pool_size, rerank_pool_size: entry.rerank_pool_size,
rerank_keep_top: entry.rerank_keep_top, rerank_keep_top: entry.rerank_keep_top,
@@ -840,7 +829,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
ingestion_ms: entry.ingestion_ms, ingestion_ms: entry.ingestion_ms,
namespace_seed_ms: entry.namespace_seed_ms, namespace_seed_ms: entry.namespace_seed_ms,
evaluation_stages_ms: EvaluationStageTimings::default(), evaluation_stages_ms: EvaluationStageTimings::default(),
stage_latency: default_stage_latency(), stage_latency: StageLatencyBreakdown::default(),
namespace_reused: false, namespace_reused: false,
ingestion_reused: entry.ingestion_reused, ingestion_reused: entry.ingestion_reused,
embeddings_reused: entry.ingestion_embeddings_reused, embeddings_reused: entry.ingestion_embeddings_reused,
@@ -915,7 +904,8 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
mod tests { mod tests {
use super::*; use super::*;
use crate::eval::{ use crate::eval::{
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatencyBreakdown, EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
StageLatencyBreakdown,
}; };
use chrono::Utc; use chrono::Utc;
use tempfile::tempdir; use tempfile::tempdir;
@@ -931,12 +921,28 @@ mod tests {
fn sample_stage_latency() -> StageLatencyBreakdown { fn sample_stage_latency() -> StageLatencyBreakdown {
StageLatencyBreakdown { StageLatencyBreakdown {
embed: latency(9.0), stages: vec![
collect_candidates: latency(10.0), StageLatency {
graph_expansion: latency(11.0), stage: "embed".to_string(),
chunk_attach: latency(12.0), stats: latency(9.0),
rerank: latency(13.0), },
assemble: latency(14.0), StageLatency {
stage: "search".to_string(),
stats: latency(10.0),
},
StageLatency {
stage: "rerank".to_string(),
stats: latency(13.0),
},
StageLatency {
stage: "resolve_entities".to_string(),
stats: latency(11.0),
},
StageLatency {
stage: "assemble".to_string(),
stats: latency(14.0),
},
],
} }
} }
@@ -1058,7 +1064,7 @@ mod tests {
rerank_keep_top: 5, rerank_keep_top: 5,
concurrency: 2, concurrency: 2,
detailed_report: true, detailed_report: true,
retrieval_strategy: "initial".into(), resolve_entities: false,
chunk_result_cap: 5, chunk_result_cap: 5,
chunk_rrf_k: 60.0, chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0, chunk_rrf_vector_weight: 1.0,
@@ -1071,7 +1077,6 @@ mod tests {
ingest_chunks_only: false, ingest_chunks_only: false,
chunk_vector_take: 50, chunk_vector_take: 50,
chunk_fts_take: 50, chunk_fts_take: 50,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4, max_chunks_per_entity: 4,
cases, cases,
} }
@@ -1097,7 +1102,7 @@ mod tests {
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)] #[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
#[test] #[test]
fn evaluations_history_captures_strategy_and_concurrency() { fn evaluations_history_captures_resolve_entities_and_concurrency() {
let tmp = tempdir().unwrap(); let tmp = tempdir().unwrap();
let summary = sample_summary(false); let summary = sample_summary(false);
@@ -1109,7 +1114,7 @@ mod tests {
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
let stored = &entries[0]; let stored = &entries[0];
assert_eq!(stored.retrieval.concurrency, summary.concurrency); assert_eq!(stored.retrieval.concurrency, summary.concurrency);
assert_eq!(stored.retrieval.strategy, summary.retrieval_strategy); assert_eq!(stored.retrieval.resolve_entities, summary.resolve_entities);
assert_eq!( assert_eq!(
stored.performance.evaluation_stages_ms.run_queries_ms, stored.performance.evaluation_stages_ms.run_queries_ms,
summary.perf.evaluation_stage_ms.run_queries_ms summary.perf.evaluation_stage_ms.run_queries_ms
+33 -38
View File
@@ -3,7 +3,7 @@ use std::collections::HashSet;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use common::storage::types::StoredObject; use common::storage::types::StoredObject;
use retrieval_pipeline::{ use retrieval_pipeline::{
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput, Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization; use unicode_normalization::UnicodeNormalization;
@@ -69,7 +69,7 @@ pub struct EvaluationSummary {
pub rerank_keep_top: usize, pub rerank_keep_top: usize,
pub concurrency: usize, pub concurrency: usize,
pub detailed_report: bool, pub detailed_report: bool,
pub retrieval_strategy: String, pub resolve_entities: bool,
pub chunk_result_cap: usize, pub chunk_result_cap: usize,
pub chunk_rrf_k: f32, pub chunk_rrf_k: f32,
pub chunk_rrf_vector_weight: f32, pub chunk_rrf_vector_weight: f32,
@@ -82,7 +82,6 @@ pub struct EvaluationSummary {
pub ingest_chunk_overlap_tokens: usize, pub ingest_chunk_overlap_tokens: usize,
pub chunk_vector_take: usize, pub chunk_vector_take: usize,
pub chunk_fts_take: usize, pub chunk_fts_take: usize,
pub chunk_avg_chars_per_token: usize,
pub max_chunks_per_entity: usize, pub max_chunks_per_entity: usize,
pub cases: Vec<CaseSummary>, pub cases: Vec<CaseSummary>,
} }
@@ -129,14 +128,20 @@ impl Default for LatencyStats {
} }
} }
/// Latency statistics for a single retrieval stage, keyed by the stage's stable label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageLatency {
pub stage: String,
pub stats: LatencyStats,
}
/// Per-stage retrieval latency, in canonical pipeline order.
///
/// The set of stages is driven entirely by [`StageKind::ALL`], so adding a retrieval stage
/// surfaces here automatically without changes to the evaluation harness.
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StageLatencyBreakdown { pub struct StageLatencyBreakdown {
pub embed: LatencyStats, pub stages: Vec<StageLatency>,
pub collect_candidates: LatencyStats,
pub graph_expansion: LatencyStats,
pub chunk_attach: LatencyStats,
pub rerank: LatencyStats,
pub assemble: LatencyStats,
} }
#[allow(clippy::struct_field_names)] #[allow(clippy::struct_field_names)]
@@ -232,13 +237,12 @@ fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidat
.collect() .collect()
} }
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> { pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidate> {
match output { match output {
StrategyOutput::Entities(entities) => candidates_from_entities(entities), RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks), RetrievalOutput::WithEntities { chunks, entities } => {
StrategyOutput::Search(search_result) => { let mut candidates = candidates_from_entities(entities);
let mut candidates = candidates_from_entities(search_result.entities); candidates.extend(candidates_from_chunks(chunks));
candidates.extend(candidates_from_chunks(search_result.chunks));
candidates.sort_by(|a, b| b.score.total_cmp(&a.score)); candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
candidates candidates
} }
@@ -262,7 +266,7 @@ pub struct CaseDiagnostics {
pub attached_chunk_ids: Vec<String>, pub attached_chunk_ids: Vec<String>,
pub retrieved: Vec<EntityDiagnostics>, pub retrieved: Vec<EntityDiagnostics>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub pipeline: Option<PipelineDiagnostics>, pub pipeline: Option<Diagnostics>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@@ -366,28 +370,19 @@ pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
LatencyStats { avg, p50, p95 } LatencyStats { avg, p50, p95 }
} }
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown { pub fn build_stage_latency_breakdown(samples: &[StageTimings]) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128> let stages = StageKind::ALL
where .iter()
F: Fn(&PipelineStageTimings) -> u128, .map(|kind| {
{ let latencies: Vec<u128> = samples.iter().map(|s| s.stage_ms(*kind)).collect();
samples.iter().map(selector).collect() StageLatency {
} stage: kind.label().to_string(),
stats: compute_latency_stats(&latencies),
}
})
.collect();
StageLatencyBreakdown { StageLatencyBreakdown { stages }
embed: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::embed_ms)),
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms()
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms()
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
entry.chunk_attach_ms()
})),
rerank: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::rerank_ms)),
assemble: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::assemble_ms)),
}
} }
#[allow( #[allow(
@@ -412,7 +407,7 @@ pub fn build_case_diagnostics(
expected_chunk_ids: &[String], expected_chunk_ids: &[String],
answers_lower: &[String], answers_lower: &[String],
candidates: &[EvaluationCandidate], candidates: &[EvaluationCandidate],
pipeline_stats: Option<PipelineDiagnostics>, pipeline_stats: Option<Diagnostics>,
) -> CaseDiagnostics { ) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect(); let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect();
let mut seen_chunks: HashSet<String> = HashSet::new(); let mut seen_chunks: HashSet<String> = HashSet::new();
-91
View File
@@ -44,7 +44,6 @@
--leading-snug: 1.375; --leading-snug: 1.375;
--leading-relaxed: 1.625; --leading-relaxed: 1.625;
--ease-out: cubic-bezier(0, 0, 0.2, 1); --ease-out: cubic-bezier(0, 0, 0.2, 1);
--ease-in-out: cubic-bezier(0.4, 0, 0.2, 1);
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite; --animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
--default-transition-duration: 150ms; --default-transition-duration: 150ms;
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); --default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
@@ -285,37 +284,6 @@
} }
} }
} }
.drawer-open {
> .drawer-side {
overflow-y: auto;
}
> .drawer-toggle {
display: none;
& ~ .drawer-side {
pointer-events: auto;
visibility: visible;
position: sticky;
display: block;
width: auto;
overscroll-behavior: auto;
opacity: 100%;
& > .drawer-overlay {
cursor: default;
background-color: transparent;
}
& > *:not(.drawer-overlay) {
translate: 0%;
[dir="rtl"] & {
translate: 0%;
}
}
}
&:checked ~ .drawer-side {
pointer-events: auto;
visibility: visible;
}
}
}
.drawer-toggle { .drawer-toggle {
position: fixed; position: fixed;
height: calc(0.25rem * 0); height: calc(0.25rem * 0);
@@ -1074,22 +1042,6 @@
grid-row-start: 1; grid-row-start: 1;
min-width: calc(0.25rem * 0); min-width: calc(0.25rem * 0);
} }
.chat-image {
grid-row: span 2 / span 2;
align-self: flex-end;
}
.chat-footer {
grid-row-start: 3;
display: flex;
gap: calc(0.25rem * 1);
font-size: 0.6875rem;
}
.chat-header {
grid-row-start: 1;
display: flex;
gap: calc(0.25rem * 1);
font-size: 0.6875rem;
}
.container { .container {
width: 100%; width: 100%;
@media (width >= 40rem) { @media (width >= 40rem) {
@@ -1796,9 +1748,6 @@
.w-10 { .w-10 {
width: calc(var(--spacing) * 10); width: calc(var(--spacing) * 10);
} }
.w-11 {
width: calc(var(--spacing) * 11);
}
.w-11\/12 { .w-11\/12 {
width: calc(11/12 * 100%); width: calc(11/12 * 100%);
} }
@@ -1862,9 +1811,6 @@
.flex-none { .flex-none {
flex: none; flex: none;
} }
.flex-shrink {
flex-shrink: 1;
}
.flex-shrink-0 { .flex-shrink-0 {
flex-shrink: 0; flex-shrink: 0;
} }
@@ -1877,13 +1823,6 @@
.grow { .grow {
flex-grow: 1; flex-grow: 1;
} }
.border-collapse {
border-collapse: collapse;
}
.-translate-y-1 {
--tw-translate-y: calc(var(--spacing) * -1);
translate: var(--tw-translate-x) var(--tw-translate-y);
}
.-translate-y-1\/2 { .-translate-y-1\/2 {
--tw-translate-y: calc(calc(1/2 * 100%) * -1); --tw-translate-y: calc(calc(1/2 * 100%) * -1);
translate: var(--tw-translate-x) var(--tw-translate-y); translate: var(--tw-translate-x) var(--tw-translate-y);
@@ -1956,9 +1895,6 @@
.justify-start { .justify-start {
justify-content: flex-start; justify-content: flex-start;
} }
.gap-0 {
gap: calc(var(--spacing) * 0);
}
.gap-0\.5 { .gap-0\.5 {
gap: calc(var(--spacing) * 0.5); gap: calc(var(--spacing) * 0.5);
} }
@@ -2115,9 +2051,6 @@
.bg-transparent { .bg-transparent {
background-color: transparent; background-color: transparent;
} }
.bg-warning {
background-color: var(--color-warning);
}
.bg-warning\/10 { .bg-warning\/10 {
background-color: var(--color-warning); background-color: var(--color-warning);
@supports (color: color-mix(in lab, red, red)) { @supports (color: color-mix(in lab, red, red)) {
@@ -2136,9 +2069,6 @@
.loading-spinner { .loading-spinner {
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E"); mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
} }
.mask-repeat {
mask-repeat: repeat;
}
.fill-current { .fill-current {
fill: currentcolor; fill: currentcolor;
} }
@@ -2169,9 +2099,6 @@
.p-8 { .p-8 {
padding: calc(var(--spacing) * 8); padding: calc(var(--spacing) * 8);
} }
.px-1 {
padding-inline: calc(var(--spacing) * 1);
}
.px-1\.5 { .px-1\.5 {
padding-inline: calc(var(--spacing) * 1.5); padding-inline: calc(var(--spacing) * 1.5);
} }
@@ -2326,9 +2253,6 @@
--tw-tracking: var(--tracking-widest); --tw-tracking: var(--tracking-widest);
letter-spacing: var(--tracking-widest); letter-spacing: var(--tracking-widest);
} }
.text-wrap {
text-wrap: wrap;
}
.break-words { .break-words {
overflow-wrap: break-word; overflow-wrap: break-word;
} }
@@ -2395,17 +2319,6 @@
.italic { .italic {
font-style: italic; font-style: italic;
} }
.underline {
text-decoration-line: underline;
}
.swap-active {
.swap-off {
opacity: 0%;
}
.swap-on {
opacity: 100%;
}
}
.opacity-0 { .opacity-0 {
opacity: 0%; opacity: 0%;
} }
@@ -2496,10 +2409,6 @@
--tw-duration: 300ms; --tw-duration: 300ms;
transition-duration: 300ms; transition-duration: 300ms;
} }
.ease-in-out {
--tw-ease: var(--ease-in-out);
transition-timing-function: var(--ease-in-out);
}
.ease-out { .ease-out {
--tw-ease: var(--ease-out); --tw-ease: var(--ease-out);
transition-timing-function: var(--ease-out); transition-timing-function: var(--ease-out);
+1 -8
View File
@@ -2,10 +2,7 @@ use common::storage::types::conversation::SidebarConversation;
use common::storage::{db::SurrealDbClient, store::StorageManager}; use common::storage::{db::SurrealDbClient, store::StorageManager};
use common::utils::embedding::EmbeddingProvider; use common::utils::embedding::EmbeddingProvider;
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine}; use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{ use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
create_template_engine, storage::db::ProvidesDb,
utils::config::{AppConfig, RetrievalStrategy},
};
use retrieval_pipeline::reranking::RerankerPool; use retrieval_pipeline::reranking::RerankerPool;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{ use std::sync::{
@@ -75,10 +72,6 @@ impl HtmlState {
} }
} }
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
self.config.resolved_retrieval_strategy()
}
pub async fn get_cached_conversation_archive( pub async fn get_cached_conversation_archive(
&self, &self,
user_id: &str, user_id: &str,
@@ -16,12 +16,9 @@ use futures::{
}; };
use json_stream_parser::JsonStreamParser; use json_stream_parser::JsonStreamParser;
use minijinja::Value; use minijinja::Value;
use retrieval_pipeline::{ use retrieval_pipeline::answer_retrieval::{
answer_retrieval::{ chunks_to_chat_context, create_chat_request, create_user_message_with_history,
chunks_to_chat_context, create_chat_request, create_user_message_with_history, LLMResponseFormat,
LLMResponseFormat,
},
retrieved_entities_to_json,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::from_str; use serde_json::from_str;
@@ -189,11 +186,7 @@ struct ReferenceData {
} }
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> { fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
response response.reference_ids()
.references
.iter()
.map(|reference| reference.reference.clone())
.collect()
} }
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
@@ -362,10 +355,9 @@ async fn prepare_chat_request(
None => None, None => None,
}; };
let strategy = state.retrieval_strategy(); let config = retrieval_pipeline::RetrievalConfig::default();
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
let retrieval_result = match retrieval_pipeline::retrieve_entities( let retrieval_result = match retrieval_pipeline::retrieve(
&state.db, &state.db,
&state.openai_client, &state.openai_client,
Some(&*state.embedding_provider), Some(&*state.embedding_provider),
@@ -387,12 +379,9 @@ async fn prepare_chat_request(
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result); let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
let context_json = match retrieval_result { let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks), retrieval_pipeline::RetrievalOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => { retrieval_pipeline::RetrievalOutput::WithEntities { chunks, .. } => {
retrieved_entities_to_json(&entities) chunks_to_chat_context(&chunks)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
chunks_to_chat_context(&search_result.chunks)
} }
}; };
let formatted_user_message = let formatted_user_message =
@@ -9,7 +9,7 @@ use common::{
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
}, },
}; };
use retrieval_pipeline::StrategyOutput; use retrieval_pipeline::RetrievalOutput;
use uuid::Uuid; use uuid::Uuid;
pub(crate) const MAX_REFERENCE_COUNT: usize = 10; pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
@@ -86,40 +86,29 @@ pub(crate) enum ReferenceLookupTarget {
} }
pub(crate) fn collect_reference_ids_from_retrieval( pub(crate) fn collect_reference_ids_from_retrieval(
retrieval_result: &StrategyOutput, retrieval_result: &RetrievalOutput,
) -> Vec<String> { ) -> Vec<String> {
let mut ids = Vec::new(); let mut ids = Vec::new();
let mut seen = HashSet::new(); let mut seen = HashSet::new();
let mut push_id = |id: String| {
if seen.insert(id.clone()) {
ids.push(id);
}
};
match retrieval_result { match retrieval_result {
StrategyOutput::Chunks(chunks) => { RetrievalOutput::Chunks(chunks) => {
for chunk in chunks { for chunk in chunks {
let id = chunk.chunk.id.clone(); push_id(chunk.chunk.id.clone());
if seen.insert(id.clone()) {
ids.push(id);
}
} }
} }
StrategyOutput::Entities(entities) => { RetrievalOutput::WithEntities { chunks, entities } => {
for chunk in chunks {
push_id(chunk.chunk.id.clone());
}
for entity in entities { for entity in entities {
let id = entity.entity.id.clone(); push_id(entity.entity.id.clone());
if seen.insert(id.clone()) {
ids.push(id);
}
}
}
StrategyOutput::Search(search) => {
for chunk in &search.chunks {
let id = chunk.chunk.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
for entity in &search.entities {
let id = entity.entity.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
} }
} }
} }
+1 -1
View File
@@ -13,7 +13,7 @@ use crate::{
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult, template_as_response, TemplateResponse, TemplateResult, ResponseResult,
}, },
}, },
utils::text_content_preview::truncate_text_contents, utils::text_content_preview::truncate_text_contents,
+5 -5
View File
@@ -32,7 +32,7 @@ use crate::{
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult, template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
}, },
}, },
utils::pagination::{paginate_items, Pagination}, utils::pagination::{paginate_items, Pagination},
@@ -284,9 +284,9 @@ pub async fn suggest_knowledge_relationships(
None => None, None => None,
}; };
let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion(); let config = retrieval_pipeline::RetrievalConfig::with_entities();
if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) =
retrieval_pipeline::retrieve_entities( retrieval_pipeline::retrieve(
&state.db, &state.db,
&state.openai_client, &state.openai_client,
Some(&*state.embedding_provider), Some(&*state.embedding_provider),
@@ -297,7 +297,7 @@ pub async fn suggest_knowledge_relationships(
) )
.await .await
{ {
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results { for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities {
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS { if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
break; break;
} }
@@ -12,7 +12,7 @@ use crate::html_state::HtmlState;
use crate::middlewares::{ use crate::middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult, template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
}, },
}; };
use common::storage::types::{ use common::storage::types::{
+22 -21
View File
@@ -4,7 +4,7 @@ use axum::{
extract::{Query, State}, extract::{Query, State},
}; };
use common::storage::types::{text_content::TextContent, user::User}; use common::storage::types::{text_content::TextContent, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput}; use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity};
use serde::{de, Deserialize, Deserializer, Serialize}; use serde::{de, Deserialize, Deserializer, Serialize};
use std::{fmt, str::FromStr}; use std::{fmt, str::FromStr};
@@ -108,35 +108,35 @@ async fn perform_search(
return Ok((Vec::new(), String::new())); return Ok((Vec::new(), String::new()));
} }
let config = RetrievalConfig::for_search(SearchTarget::Both); let config = RetrievalConfig::with_entities();
let reranker_lease = match &state.reranker_pool { let reranker_lease = match &state.reranker_pool {
Some(pool) => pool.checkout().await, Some(pool) => pool.checkout().await,
None => None, None => None,
}; };
let params = retrieval_pipeline::pipeline::StrategyParams { let result = retrieve(
db_client: &state.db, &state.db,
openai_client: &state.openai_client, &state.openai_client,
embedding_provider: Some(&state.embedding_provider), Some(&state.embedding_provider),
input_text: trimmed_query, trimmed_query,
user_id: &user.id, &user.id,
config, config,
reranker: reranker_lease, reranker_lease,
}; )
let result = retrieval_pipeline::pipeline::execute(params).await?; .await?;
let search_result = match result { let (chunks, entities) = match result {
StrategyOutput::Search(sr) => sr, RetrievalOutput::WithEntities { chunks, entities } => (chunks, entities),
_ => SearchResult::new(vec![], vec![]), RetrievalOutput::Chunks(chunks) => (chunks, Vec::new()),
}; };
let source_label_map = collect_source_label_map(state, user, &search_result).await?; let source_label_map = collect_source_label_map(state, user, &chunks, &entities).await?;
let mut combined_results: Vec<SearchResultForTemplate> = let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len())); Vec::with_capacity(chunks.len().saturating_add(entities.len()));
for chunk_result in search_result.chunks { for chunk_result in chunks {
let source_label = source_label_map let source_label = source_label_map
.get(&chunk_result.chunk.source_id) .get(&chunk_result.chunk.source_id)
.cloned() .cloned()
@@ -155,7 +155,7 @@ async fn perform_search(
}); });
} }
for entity_result in search_result.entities { for entity_result in entities {
let source_label = source_label_map let source_label = source_label_map
.get(&entity_result.entity.source_id) .get(&entity_result.entity.source_id)
.cloned() .cloned()
@@ -187,13 +187,14 @@ async fn perform_search(
async fn collect_source_label_map( async fn collect_source_label_map(
state: &HtmlState, state: &HtmlState,
user: &User, user: &User,
search_result: &SearchResult, chunks: &[RetrievedChunk],
entities: &[RetrievedEntity],
) -> Result<std::collections::HashMap<String, String>, HtmlError> { ) -> Result<std::collections::HashMap<String, String>, HtmlError> {
let mut source_ids = HashSet::new(); let mut source_ids = HashSet::new();
for chunk_result in &search_result.chunks { for chunk_result in chunks {
source_ids.insert(chunk_result.chunk.source_id.clone()); source_ids.insert(chunk_result.chunk.source_id.clone());
} }
for entity_result in &search_result.entities { for entity_result in entities {
source_ids.insert(entity_result.entity.source_id.clone()); source_ids.insert(entity_result.entity.source_id.clone());
} }
+17 -22
View File
@@ -183,10 +183,8 @@ impl PipelineServices for DefaultPipelineServices {
None => None, None => None,
}; };
let config = retrieval_pipeline::RetrievalConfig::for_search( let config = retrieval_pipeline::RetrievalConfig::with_entities();
retrieval_pipeline::SearchTarget::EntitiesOnly, match retrieval_pipeline::retrieve(
);
match retrieval_pipeline::retrieve_entities(
&self.db, &self.db,
&self.openai_client, &self.openai_client,
Some(&*self.embedding_provider), Some(&*self.embedding_provider),
@@ -197,19 +195,16 @@ impl PipelineServices for DefaultPipelineServices {
) )
.await .await
{ {
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities), Ok(retrieval_pipeline::RetrievalOutput::WithEntities { chunks, entities }) => {
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
let chunk_count = search.chunks.len();
let entities = search.entities;
tracing::debug!( tracing::debug!(
chunk_count, chunk_count = chunks.len(),
entity_count = entities.len(), entity_count = entities.len(),
"ingestion search results returned entities" "ingestion retrieval resolved entities from chunks"
); );
Ok(entities) Ok(entities)
} }
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError( Ok(retrieval_pipeline::RetrievalOutput::Chunks(_)) => Err(AppError::InternalError(
"Ingestion retrieval should return entities".into(), "Ingestion retrieval should resolve entities".into(),
)), )),
Err(e) => Err(e), Err(e) => Err(e),
} }
@@ -372,16 +367,16 @@ mod tests {
fn system_prompt_from_request( fn system_prompt_from_request(
request: &async_openai::types::CreateChatCompletionRequest, request: &async_openai::types::CreateChatCompletionRequest,
) -> String { ) -> anyhow::Result<String> {
let ChatCompletionRequestMessage::System(system) = &request.messages[0] else { let Some(ChatCompletionRequestMessage::System(system)) = request.messages.first() else {
panic!("expected first message to be system"); anyhow::bail!("expected first message to be system");
}; };
match &system.content { let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) =
async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) => { &system.content
text.clone() else {
} anyhow::bail!("unexpected system message content: {:?}", system.content);
other => panic!("unexpected system message content: {other:?}"), };
} Ok(text.clone())
} }
#[tokio::test] #[tokio::test]
@@ -425,7 +420,7 @@ mod tests {
.await .await
.context("prepare llm request")?; .context("prepare llm request")?;
assert_eq!(system_prompt_from_request(&request), SENTINEL); assert_eq!(system_prompt_from_request(&request)?, SENTINEL);
Ok(()) Ok(())
} }
} }
@@ -125,10 +125,10 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
}) })
.await??; .await??;
for (idx, png) in captures.iter().enumerate() { for (page_number, png) in page_numbers.iter().zip(captures.iter()) {
if let Err(err) = maybe_dump_debug_image(page_numbers[idx], png).await { if let Err(err) = maybe_dump_debug_image(*page_number, png).await {
warn!( warn!(
page = page_numbers[idx], page = page_number,
error = %err, error = %err,
"Failed to write debug screenshot to disk" "Failed to write debug screenshot to disk"
); );
+1
View File
@@ -95,6 +95,7 @@ pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<Shared
} }
#[cfg(test)] #[cfg(test)]
#[allow(dead_code)] // helpers are shared across binary test targets
pub(crate) mod tests { pub(crate) mod tests {
use std::path::Path; use std::path::Path;
+4 -1
View File
@@ -11,8 +11,9 @@ use html_router::{
use super::SharedServices; use super::SharedServices;
/// Builds the Minne API and HTML route subtrees without fixing the outer Axum state /// Builds the Minne API and HTML route subtrees without fixing the outer Axum state
/// type. SaaS consumers can merge additional routers and attach their own `AppState` /// type. `SaaS` consumers can merge additional routers and attach their own `AppState`
/// as long as it implements `FromRef` for `ApiState` and `HtmlState`. /// as long as it implements `FromRef` for `ApiState` and `HtmlState`.
#[allow(dead_code)] // used by server/main binaries, not worker
pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S> pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S>
where where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
@@ -24,6 +25,7 @@ where
.merge(html_routes(html_state)) .merge(html_routes(html_state))
} }
#[allow(dead_code)] // used by server/main binaries, not worker
pub fn build_api_state(services: &SharedServices) -> ApiState { pub fn build_api_state(services: &SharedServices) -> ApiState {
ApiState { ApiState {
db: Arc::clone(&services.db), db: Arc::clone(&services.db),
@@ -32,6 +34,7 @@ pub fn build_api_state(services: &SharedServices) -> ApiState {
} }
} }
#[allow(dead_code)] // used by server/main binaries, not worker
pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> { pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> {
let session_store = Arc::new( let session_store = Arc::new(
services services
+1
View File
@@ -75,6 +75,7 @@ struct AppState {
} }
#[cfg(test)] #[cfg(test)]
#[allow(clippy::expect_used)]
mod tests { mod tests {
use super::*; use super::*;
use axum::{ use axum::{
+4 -6
View File
@@ -10,16 +10,14 @@ workspace = true
[dependencies] [dependencies]
tokio = { workspace = true } tokio = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true } async-openai = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
uuid = { workspace = true }
fastembed = { workspace = true } fastembed = { workspace = true }
common = { path = "../common", features = ["test-utils"] } common = { path = "../common", features = ["test-utils"] }
[dev-dependencies]
anyhow = { workspace = true }
uuid = { workspace = true }
+43 -53
View File
@@ -1,61 +1,66 @@
//! Chat answer assembly: retrieval context formatting and structured LLM request/response types.
use async_openai::{ use async_openai::{
error::OpenAIError, error::OpenAIError,
types::{ types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse, CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormat, ResponseFormatJsonSchema, ResponseFormatJsonSchema,
}, },
}; };
use common::{ use common::storage::types::{
error::AppError, message::{format_history, Message},
storage::types::{ system_settings::SystemSettings,
message::{format_history, Message},
system_settings::SystemSettings,
},
}; };
use serde::Deserialize; use serde::Deserialize;
use serde_json::Value; use serde_json::{json, Value};
use super::answer_retrieval_helper::get_query_response_schema; /// JSON schema describing the structured chat answer (answer text + references).
fn get_query_response_schema() -> Value {
json!({
"type": "object",
"properties": {
"answer": { "type": "string" },
"references": {
"type": "array",
"items": {
"type": "object",
"properties": {
"reference": { "type": "string" },
},
"required": ["reference"],
"additionalProperties": false,
}
}
},
"required": ["answer", "references"],
"additionalProperties": false
})
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct Reference { pub struct Reference {
#[allow(dead_code)]
pub reference: String, pub reference: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct LLMResponseFormat { pub struct LLMResponseFormat {
pub answer: String, pub answer: String,
#[allow(dead_code)]
pub references: Vec<Reference>, pub references: Vec<Reference>,
} }
#[derive(Debug)] impl LLMResponseFormat {
pub struct Answer { pub fn reference_ids(&self) -> Vec<String> {
pub content: String, self.references
pub references: Vec<String>, .iter()
} .map(|entry| entry.reference.clone())
.collect()
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r"
Context Information:
==================
{entities_json}
User Question:
==================
{query}
"
)
}
/// Convert chunk-based retrieval results to JSON format for LLM context
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
} }
}
/// Convert chunk-based retrieval results to JSON format for LLM context.
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
use crate::round_score;
serde_json::json!(chunks serde_json::json!(chunks
.iter() .iter()
@@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
} }
pub fn create_user_message_with_history( pub fn create_user_message_with_history(
entities_json: &Value, context_json: &Value,
history: &[Message], history: &[Message],
query: &str, query: &str,
) -> String { ) -> String {
@@ -89,7 +94,7 @@ pub fn create_user_message_with_history(
{} {}
", ",
format_history(history), format_history(history),
entities_json, context_json,
query query
) )
} }
@@ -116,18 +121,3 @@ pub fn create_chat_request(
.response_format(response_format) .response_format(response_format)
.build() .build()
} }
pub fn process_llm_response(
response: &CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, Box<AppError>> {
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
.and_then(|content| {
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
})
})
}
@@ -1,23 +0,0 @@
use serde_json::{json, Value};
pub fn get_query_response_schema() -> Value {
json!({
"type": "object",
"properties": {
"answer": { "type": "string" },
"references": {
"type": "array",
"items": {
"type": "object",
"properties": {
"reference": { "type": "string" },
},
"required": ["reference"],
"additionalProperties": false,
}
}
},
"required": ["answer", "references"],
"additionalProperties": false
})
}
-228
View File
@@ -1,228 +0,0 @@
use std::collections::{HashMap, HashSet};
use surrealdb::{sql::Thing, Error};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
StoredObject,
},
};
/// Find entities related to the given entity via graph relationships.
///
/// Queries the `relates_to` edge table for all relationships involving the entity,
/// then fetches and returns the neighboring entities.
///
/// # Arguments
/// * `db` - Database client
/// * `entity_id` - ID of the entity to find neighbors for
/// * `user_id` - User ID for access control
/// * `limit` - Maximum number of neighbors to return
pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient,
entity_id: &str,
user_id: &str,
limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
let mut relationships_response = db
.query(
"
SELECT * FROM relates_to
WHERE metadata.user_id = $user_id
AND (in = type::thing('knowledge_entity', $entity_id)
OR out = type::thing('knowledge_entity', $entity_id))
",
)
.bind(("entity_id", entity_id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
if relationships.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_ids: Vec<String> = Vec::with_capacity(relationships.len());
let mut seen: HashSet<String> = HashSet::with_capacity(relationships.len());
for rel in relationships {
if rel.in_ == entity_id {
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
} else if rel.out == entity_id {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_);
}
} else {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_.clone());
}
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
}
}
neighbor_ids.retain(|id| id != entity_id);
if neighbor_ids.is_empty() {
return Ok(Vec::new());
}
if limit > 0 && neighbor_ids.len() > limit {
neighbor_ids.truncate(limit);
}
let thing_ids: Vec<Thing> = neighbor_ids
.iter()
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
.collect();
let mut neighbors_response = db
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", KnowledgeEntity::table_name().to_owned()))
.bind(("things", thing_ids))
.bind(("user_id", user_id.to_owned()))
.await?;
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
if neighbors.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
.into_iter()
.map(|entity| (entity.id.clone(), entity))
.collect();
let mut ordered = Vec::with_capacity(neighbor_ids.len());
for id in neighbor_ids {
if let Some(entity) = neighbor_map.remove(&id) {
ordered.push(entity);
}
if limit > 0 && ordered.len() >= limit {
break;
}
}
Ok(ordered)
}
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use uuid::Uuid;
#[tokio::test]
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
let central_entity = KnowledgeEntity::new(
"central_source".to_string(),
"Central Entity".to_string(),
"Central Description".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let related_entity1 = KnowledgeEntity::new(
"related_source1".to_string(),
"Related Entity 1".to_string(),
"Related Description 1".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let related_entity2 = KnowledgeEntity::new(
"related_source2".to_string(),
"Related Entity 2".to_string(),
"Related Description 2".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let unrelated_entity = KnowledgeEntity::new(
"unrelated_source".to_string(),
"Unrelated Entity".to_string(),
"Unrelated Description".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let central_entity = db
.store_item(central_entity.clone())
.await
.with_context(|| "Failed to store central entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
let related_entity1 = db
.store_item(related_entity1.clone())
.await
.with_context(|| "Failed to store related entity 1".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
let related_entity2 = db
.store_item(related_entity2.clone())
.await
.with_context(|| "Failed to store related entity 2".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
let _unrelated_entity = db
.store_item(unrelated_entity.clone())
.await
.with_context(|| "Failed to store unrelated entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
let source_id = "relationship_source".to_string();
let relationship1 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity1.id.clone(),
user_id.clone(),
source_id.clone(),
"references".to_string(),
);
let relationship2 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity2.id.clone(),
user_id.clone(),
source_id.clone(),
"contains".to_string(),
);
relationship1
.store_relationship(&db)
.await
.with_context(|| "Failed to store relationship 1".to_string())?;
relationship2
.store_relationship(&db)
.await
.with_context(|| "Failed to store relationship 2".to_string())?;
let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await
.with_context(|| "Failed to find entities by relationship".to_string())?;
assert!(
related_entities.len() >= 2,
"Should find related entities in both directions"
);
Ok(())
}
}
+63 -115
View File
@@ -1,10 +1,9 @@
pub mod answer_retrieval; pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod graph;
pub mod pipeline; pub mod pipeline;
pub mod reranking; pub mod reranking;
pub mod scoring;
pub(crate) mod scoring;
use common::{ use common::{
error::AppError, error::AppError,
@@ -16,39 +15,28 @@ use common::{
use reranking::RerankerLease; use reranking::RerankerLease;
use tracing::instrument; use tracing::instrument;
// Strategy output variants - defined before pipeline module /// Result of a retrieval run.
///
/// Chunk retrieval is always performed; entities are only present when the caller
/// requested entity resolution via [`RetrievalConfig::with_entities`].
#[derive(Debug)] #[derive(Debug)]
pub enum StrategyOutput { pub enum RetrievalOutput {
Entities(Vec<RetrievedEntity>),
Chunks(Vec<RetrievedChunk>), Chunks(Vec<RetrievedChunk>),
Search(SearchResult), WithEntities {
} chunks: Vec<RetrievedChunk>,
entities: Vec<RetrievedEntity>,
/// Unified search result containing both chunks and entities },
#[derive(Debug, Clone)]
pub struct SearchResult {
pub chunks: Vec<RetrievedChunk>,
pub entities: Vec<RetrievedEntity>,
}
impl SearchResult {
pub fn new(chunks: Vec<RetrievedChunk>, entities: Vec<RetrievedEntity>) -> Self {
Self { chunks, entities }
}
pub fn is_empty(&self) -> bool {
self.chunks.is_empty() && self.entities.is_empty()
}
} }
pub use pipeline::{ pub use pipeline::{
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig, retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind,
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget, StageTimings,
}; };
// Backward-compatible type aliases for external consumers /// Round a score to three decimal places for JSON output.
pub type PipelineDiagnostics = Diagnostics; pub(crate) fn round_score(value: f32) -> f64 {
pub type PipelineStageTimings = StageTimings; (f64::from(value) * 1000.0).round() / 1000.0
}
// Captures a supporting chunk plus its fused retrieval score for downstream prompts. // Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -57,7 +45,7 @@ pub struct RetrievedChunk {
pub score: f32, pub score: f32,
} }
// Final entity representation returned to callers, enriched with ranked chunks. // Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RetrievedEntity { pub struct RetrievedEntity {
pub entity: KnowledgeEntity, pub entity: KnowledgeEntity,
@@ -65,9 +53,9 @@ pub struct RetrievedEntity {
pub chunks: Vec<RetrievedChunk>, pub chunks: Vec<RetrievedChunk>,
} }
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text` /// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
#[instrument(skip_all, fields(user_id))] #[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities( pub async fn retrieve(
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
@@ -75,8 +63,8 @@ pub async fn retrieve_entities(
user_id: &str, user_id: &str,
config: RetrievalConfig, config: RetrievalConfig,
reranker: Option<RerankerLease>, reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> { ) -> Result<RetrievalOutput, AppError> {
let params = pipeline::StrategyParams { let params = pipeline::RetrievalParams {
db_client, db_client,
openai_client, openai_client,
embedding_provider, embedding_provider,
@@ -94,6 +82,7 @@ mod tests {
use anyhow::{self}; use anyhow::{self};
use async_openai::Client; use async_openai::Client;
use common::storage::indexes::ensure_runtime; use common::storage::indexes::ensure_runtime;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::system_settings::SystemSettings; use common::storage::types::system_settings::SystemSettings;
use uuid::Uuid; use uuid::Uuid;
@@ -133,7 +122,7 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> { async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = setup_test_db().await?;
let user_id = "test_user"; let user_id = "test_user";
let chunk = TextChunk::new( let chunk = TextChunk::new(
@@ -145,7 +134,7 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let openai_client = Client::new(); let openai_client = Client::new();
let params = pipeline::StrategyParams { let params = pipeline::RetrievalParams {
db_client: &db, db_client: &db,
openai_client: &openai_client, openai_client: &openai_client,
embedding_provider: None, embedding_provider: None,
@@ -154,12 +143,13 @@ mod tests {
config: RetrievalConfig::default(), config: RetrievalConfig::default(),
reranker: None, reranker: None,
}; };
let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) let results = pipeline::run_with_embedding(params, test_embedding()).await?;
.await?;
let chunks = match results { let chunks = match results {
StrategyOutput::Chunks(items) => items, RetrievalOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"), RetrievalOutput::WithEntities { .. } => {
anyhow::bail!("expected chunk results, got entities")
}
}; };
assert!(!chunks.is_empty(), "Expected at least one retrieval result"); assert!(!chunks.is_empty(), "Expected at least one retrieval result");
@@ -171,8 +161,7 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_default_strategy_returns_chunks_from_multiple_sources( async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> {
) -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = setup_test_db().await?;
let user_id = "multi_source_user"; let user_id = "multi_source_user";
@@ -191,7 +180,7 @@ mod tests {
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?; TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
let openai_client = Client::new(); let openai_client = Client::new();
let params = pipeline::StrategyParams { let params = pipeline::RetrievalParams {
db_client: &db, db_client: &db,
openai_client: &openai_client, openai_client: &openai_client,
embedding_provider: None, embedding_provider: None,
@@ -200,12 +189,13 @@ mod tests {
config: RetrievalConfig::default(), config: RetrievalConfig::default(),
reranker: None, reranker: None,
}; };
let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) let results = pipeline::run_with_embedding(params, test_embedding()).await?;
.await?;
let chunks = match results { let chunks = match results {
StrategyOutput::Chunks(items) => items, RetrievalOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"), RetrievalOutput::WithEntities { .. } => {
anyhow::bail!("expected chunk results, got entities")
}
}; };
assert!(chunks.len() >= 2, "Expected chunks from multiple sources"); assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
@@ -223,96 +213,54 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> { async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = setup_test_db().await?;
let user_id = "chunk_user"; let user_id = "entity_user";
let chunk_one = TextChunk::new(
"src_alpha".into(),
"Tokio tasks execute on worker threads managed by the runtime.".into(),
user_id.into(),
);
let chunk_two = TextChunk::new(
"src_beta".into(),
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
user_id.into(),
);
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new();
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "tokio runtime worker behavior",
user_id,
config,
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
};
assert!(
!chunks.is_empty(),
"Revised strategy should return chunk-only responses"
);
assert!(
chunks
.iter()
.any(|entry| entry.chunk.chunk.contains("Tokio")),
"Chunk results should contain relevant snippets"
);
Ok(())
}
#[tokio::test]
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "search_user";
let chunk = TextChunk::new( let chunk = TextChunk::new(
"search_src".into(), "entity_source".into(),
"Async Rust programming uses Tokio runtime for concurrent tasks.".into(), "Async Rust programming uses the Tokio runtime for concurrent tasks.".into(),
user_id.into(), user_id.into(),
); );
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both); let entity = KnowledgeEntity::new(
"entity_source".into(),
"Tokio Runtime".into(),
"Async runtime for Rust".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
db.store_item(entity).await?;
let openai_client = Client::new(); let openai_client = Client::new();
let params = pipeline::StrategyParams { let params = pipeline::RetrievalParams {
db_client: &db, db_client: &db,
openai_client: &openai_client, openai_client: &openai_client,
embedding_provider: None, embedding_provider: None,
input_text: "async rust programming", input_text: "async rust programming",
user_id, user_id,
config, config: RetrievalConfig::with_entities(),
reranker: None, reranker: None,
}; };
let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) let results = pipeline::run_with_embedding(params, test_embedding()).await?;
.await?;
let StrategyOutput::Search(search_result) = results else { let RetrievalOutput::WithEntities { chunks, entities } = results else {
anyhow::bail!("expected Search output"); anyhow::bail!("expected WithEntities output");
}; };
// Should return chunks (entities may be empty if none stored) assert!(!chunks.is_empty(), "Should return chunks");
assert!( assert!(
!search_result.chunks.is_empty(), entities.iter().any(|e| e.entity.name == "Tokio Runtime"),
"Search strategy should return chunks" "Should resolve the entity owning the retrieved chunk"
); );
assert!( assert!(
search_result entities
.chunks
.iter() .iter()
.any(|c| c.chunk.chunk.contains("Tokio")), .find(|e| e.entity.name == "Tokio Runtime")
"Search results should contain relevant chunks" .is_some_and(|e| !e.chunks.is_empty()),
"Resolved entity should carry its contributing chunks"
); );
Ok(()) Ok(())
} }
+25 -128
View File
@@ -1,22 +1,5 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::scoring::FusionWeights;
pub use common::utils::config::RetrievalStrategy;
/// Configures which result types to include in Search strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SearchTarget {
/// Return only text chunks
ChunksOnly,
/// Return only knowledge entities
EntitiesOnly,
/// Return both chunks and entities (default)
#[default]
Both,
}
/// Two-variant flag that serializes as a bool for backward compatibility. /// Two-variant flag that serializes as a bool for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BoolFlag { pub enum BoolFlag {
@@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag {
#[derive(Debug, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RetrievalTuningFlags { pub struct RetrievalTuningFlags {
pub rerank_scores_only: BoolFlag, pub rerank_scores_only: BoolFlag,
pub normalize_vector_scores: BoolFlag,
pub normalize_fts_scores: BoolFlag,
pub chunk_rrf_use_vector: BoolFlag, pub chunk_rrf_use_vector: BoolFlag,
pub chunk_rrf_use_fts: BoolFlag, pub chunk_rrf_use_fts: BoolFlag,
} }
impl RetrievalTuningFlags { impl RetrievalTuningFlags {
pub const fn rerank_scores_only(&self) -> bool { pub const fn rerank_scores_only(self) -> bool {
self.rerank_scores_only.as_bool() self.rerank_scores_only.as_bool()
} }
pub const fn normalize_vector_scores(&self) -> bool { pub const fn chunk_rrf_use_vector(self) -> bool {
self.normalize_vector_scores.as_bool()
}
pub const fn normalize_fts_scores(&self) -> bool {
self.normalize_fts_scores.as_bool()
}
pub const fn chunk_rrf_use_vector(&self) -> bool {
self.chunk_rrf_use_vector.as_bool() self.chunk_rrf_use_vector.as_bool()
} }
pub const fn chunk_rrf_use_fts(&self) -> bool { pub const fn chunk_rrf_use_fts(self) -> bool {
self.chunk_rrf_use_fts.as_bool() self.chunk_rrf_use_fts.as_bool()
} }
} }
@@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags {
fn default() -> Self { fn default() -> Self {
Self { Self {
rerank_scores_only: BoolFlag::Disabled, rerank_scores_only: BoolFlag::Disabled,
normalize_vector_scores: BoolFlag::Disabled,
normalize_fts_scores: BoolFlag::Enabled,
chunk_rrf_use_vector: BoolFlag::Enabled, chunk_rrf_use_vector: BoolFlag::Enabled,
chunk_rrf_use_fts: BoolFlag::Enabled, chunk_rrf_use_fts: BoolFlag::Enabled,
} }
} }
} }
/// Tunable parameters that govern each retrieval stage. /// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning { pub struct RetrievalTuning {
pub entity_vector_take: usize, /// Number of vector candidates to pull from the chunk embedding index.
pub chunk_vector_take: usize, pub chunk_vector_take: usize,
pub entity_fts_take: usize, /// Number of full-text candidates to pull from the chunk index.
pub chunk_fts_take: usize, pub chunk_fts_take: usize,
pub score_threshold: f32, /// Maximum chunks attached to each resolved entity.
pub fallback_min_results: usize,
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
pub max_chunks_per_entity: usize, pub max_chunks_per_entity: usize,
pub lexical_match_weight: f32, /// Blend weight applied when mixing reranker scores with fused scores.
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
pub rerank_blend_weight: f32, pub rerank_blend_weight: f32,
pub flags: RetrievalTuningFlags, /// Keep top-N candidates after reranking.
pub rerank_keep_top: usize, pub rerank_keep_top: usize,
/// Maximum number of chunks returned to callers.
pub chunk_result_cap: usize, pub chunk_result_cap: usize,
/// Optional fusion weights for hybrid search. If None, uses default weights. /// Reciprocal rank fusion k value for chunk merging.
pub fusion_weights: Option<FusionWeights>,
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
#[serde(default = "default_chunk_rrf_k")]
pub chunk_rrf_k: f32, pub chunk_rrf_k: f32,
/// Weight applied to vector ranks in RRF. /// Weight applied to vector ranks in RRF.
#[serde(default = "default_chunk_rrf_vector_weight")]
pub chunk_rrf_vector_weight: f32, pub chunk_rrf_vector_weight: f32,
/// Weight applied to chunk FTS ranks in RRF. /// Weight applied to chunk FTS ranks in RRF.
#[serde(default = "default_chunk_rrf_fts_weight")]
pub chunk_rrf_fts_weight: f32, pub chunk_rrf_fts_weight: f32,
pub flags: RetrievalTuningFlags,
} }
impl Default for RetrievalTuning { impl Default for RetrievalTuning {
fn default() -> Self { fn default() -> Self {
Self { Self {
entity_vector_take: 15,
chunk_vector_take: 20, chunk_vector_take: 20,
entity_fts_take: 10,
chunk_fts_take: 20, chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 10000,
avg_chars_per_token: 4,
max_chunks_per_entity: 4, max_chunks_per_entity: 4,
lexical_match_weight: 0.15,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65, rerank_blend_weight: 0.65,
flags: RetrievalTuningFlags::default(),
rerank_keep_top: 8, rerank_keep_top: 8,
chunk_result_cap: 5, chunk_result_cap: 5,
fusion_weights: None, chunk_rrf_k: 60.0,
chunk_rrf_k: default_chunk_rrf_k(), chunk_rrf_vector_weight: 1.0,
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(), chunk_rrf_fts_weight: 1.0,
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(), flags: RetrievalTuningFlags::default(),
} }
} }
} }
/// Wrapper containing tuning plus future flags for per-request overrides. /// Per-request retrieval configuration.
///
/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities`
/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved
/// chunks (search, ingestion linking, relationship suggestion).
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct RetrievalConfig { pub struct RetrievalConfig {
pub strategy: RetrievalStrategy,
pub tuning: RetrievalTuning, pub tuning: RetrievalTuning,
/// Target for Search strategy (chunks, entities, or both) pub resolve_entities: bool,
pub search_target: SearchTarget,
} }
impl RetrievalConfig { impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self { /// Chunk retrieval that also resolves the owning knowledge entities.
pub fn with_entities() -> Self {
Self { Self {
strategy: RetrievalStrategy::Default,
tuning,
search_target: SearchTarget::default(),
}
}
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
Self {
strategy,
tuning: RetrievalTuning::default(), tuning: RetrievalTuning::default(),
search_target: SearchTarget::default(), resolve_entities: true,
}
}
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
Self {
strategy,
tuning,
search_target: SearchTarget::default(),
}
}
/// Create config for chat retrieval with strategy selection support
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
Self::with_strategy(strategy)
}
/// Create config for relationship suggestion (entity-only retrieval)
pub fn for_relationship_suggestion() -> Self {
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
}
/// Create config for ingestion pipeline (entity-only retrieval)
pub fn for_ingestion() -> Self {
Self::with_strategy(RetrievalStrategy::Ingestion)
}
/// Create config for unified search (chunks and/or entities)
pub fn for_search(target: SearchTarget) -> Self {
Self {
strategy: RetrievalStrategy::Search,
tuning: RetrievalTuning::default(),
search_target: target,
} }
} }
} }
const fn default_chunk_rrf_k() -> f32 {
60.0
}
const fn default_chunk_rrf_vector_weight() -> f32 {
1.0
}
const fn default_chunk_rrf_fts_weight() -> f32 {
1.0
}
+107
View File
@@ -0,0 +1,107 @@
use async_openai::Client;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
utils::embedding::EmbeddingProvider,
};
use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity};
use super::{
config::RetrievalConfig,
diagnostics::{AssembleStats, Diagnostics, SearchStats},
StageKind, StageTimings, RetrievalParams,
};
/// Mutable working state threaded through every retrieval stage.
pub(crate) struct PipelineContext<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a EmbeddingProvider>,
pub input_text: String,
pub user_id: String,
pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<Diagnostics>,
pub entity_results: Vec<RetrievedEntity>,
pub chunk_results: Vec<RetrievedChunk>,
stage_timings: StageTimings,
}
impl<'a> PipelineContext<'a> {
pub fn new(params: RetrievalParams<'a>) -> Self {
Self {
db_client: params.db_client,
openai_client: params.openai_client,
embedding_provider: params.embedding_provider,
input_text: params.input_text.to_owned(),
user_id: params.user_id.to_owned(),
config: params.config,
query_embedding: None,
chunk_values: Vec::new(),
reranker: params.reranker,
diagnostics: None,
entity_results: Vec::new(),
chunk_results: Vec::new(),
stage_timings: StageTimings::default(),
}
}
pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec<f32>) -> Self {
let mut ctx = Self::new(params);
ctx.query_embedding = Some(query_embedding);
ctx
}
pub(crate) fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
self.query_embedding.as_ref().ok_or_else(|| {
Box::new(AppError::InternalError(
"query embedding missing before candidate search".to_string(),
))
})
}
pub fn enable_diagnostics(&mut self) {
if self.diagnostics.is_none() {
self.diagnostics = Some(Diagnostics::default());
}
}
pub fn diagnostics_enabled(&self) -> bool {
self.diagnostics.is_some()
}
pub(crate) fn record_search(&mut self, stats: SearchStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.search = Some(stats);
}
}
pub(crate) fn record_assemble(&mut self, stats: AssembleStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.assemble = Some(stats);
}
}
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
self.diagnostics.take()
}
pub fn take_stage_timings(&mut self) -> StageTimings {
std::mem::take(&mut self.stage_timings)
}
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
self.stage_timings.record(kind, duration);
}
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
std::mem::take(&mut self.entity_results)
}
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
std::mem::take(&mut self.chunk_results)
}
}
+3 -33
View File
@@ -1,51 +1,21 @@
use serde::Serialize; use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled. /// Captures instrumentation for the retrieval stages when diagnostics are enabled.
#[derive(Debug, Clone, Default, Serialize)] #[derive(Debug, Clone, Default, Serialize)]
pub struct Diagnostics { pub struct Diagnostics {
pub collect_candidates: Option<CollectCandidatesStats>, pub search: Option<SearchStats>,
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
pub assemble: Option<AssembleStats>, pub assemble: Option<AssembleStats>,
} }
#[derive(Debug, Clone, Default, Serialize)] #[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats { pub struct SearchStats {
pub vector_entity_candidates: usize,
pub vector_chunk_candidates: usize, pub vector_chunk_candidates: usize,
pub fts_entity_candidates: usize,
pub fts_chunk_candidates: usize, pub fts_chunk_candidates: usize,
pub vector_chunk_scores: Vec<f32>, pub vector_chunk_scores: Vec<f32>,
pub fts_chunk_scores: Vec<f32>, pub fts_chunk_scores: Vec<f32>,
} }
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
pub filtered_entity_count: usize,
pub fallback_min_results: usize,
pub chunk_sources_considered: usize,
pub chunk_candidates_before_enrichment: usize,
pub chunk_candidates_after_enrichment: usize,
pub top_chunk_scores: Vec<f32>,
}
#[derive(Debug, Clone, Default, Serialize)] #[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats { pub struct AssembleStats {
pub token_budget_start: usize,
pub token_budget_spent: usize,
pub token_budget_remaining: usize,
pub budget_exhausted: bool,
pub chunks_selected: usize, pub chunks_selected: usize,
pub chunks_skipped_due_budget: usize,
pub entity_count: usize,
pub entity_traces: Vec<EntityAssemblyTrace>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
pub entity_id: String,
pub source_id: String,
pub inspected_candidates: usize,
pub selected_chunk_ids: Vec<String>,
pub selected_chunk_scores: Vec<f32>,
pub skipped_due_budget: usize,
} }
+119 -209
View File
@@ -1,61 +1,68 @@
mod config; mod config;
mod context;
mod diagnostics; mod diagnostics;
mod stages; mod stages;
mod strategies;
pub use config::{ pub use config::RetrievalConfig;
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget, pub use diagnostics::Diagnostics;
};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
};
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput}; use crate::{round_score, RetrievalOutput, RetrievedEntity};
use async_openai::Client; use async_openai::Client;
use async_trait::async_trait; use async_trait::async_trait;
use common::{error::AppError, storage::db::SurrealDbClient}; use common::{error::AppError, storage::db::SurrealDbClient};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tracing::info; use tracing::info;
use stages::PipelineContext; use stages::{
use strategies::{ ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage,
DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver,
}; };
// Export StrategyOutput publicly from this module /// Identifies a retrieval stage for timing and instrumentation.
// (it's defined in lib.rs but we re-export it here) ///
/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation
// Stage type enum /// harness) iterate it generically so that adding a stage requires no changes outside this
/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StageKind { pub enum StageKind {
Embed, Embed,
CollectCandidates, Search,
GraphExpansion,
ChunkAttach,
Rerank, Rerank,
ResolveEntities,
Assemble, Assemble,
} }
// Pipeline stage trait impl StageKind {
/// Every stage kind in canonical pipeline order.
pub const ALL: [StageKind; 5] = [
StageKind::Embed,
StageKind::Search,
StageKind::Rerank,
StageKind::ResolveEntities,
StageKind::Assemble,
];
/// Stable, machine-friendly identifier for the stage (used as a metrics key).
pub const fn label(self) -> &'static str {
match self {
StageKind::Embed => "embed",
StageKind::Search => "search",
StageKind::Rerank => "rerank",
StageKind::ResolveEntities => "resolve_entities",
StageKind::Assemble => "assemble",
}
}
}
/// A single composable step in the retrieval pipeline.
#[async_trait] #[async_trait]
pub trait Stage: Send + Sync { pub(crate) trait Stage: Send + Sync {
fn kind(&self) -> StageKind; fn kind(&self) -> StageKind;
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>; async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>;
} }
// Type alias for boxed stages pub(crate) type BoxedStage = Box<dyn Stage>;
pub type BoxedStage = Box<dyn Stage>;
// Strategy driver trait /// Per-stage execution timings recorded during a run.
#[async_trait]
pub trait StrategyDriver: Send + Sync {
type Output;
fn stages(&self) -> Vec<BoxedStage>;
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
}
// Pipeline stage timings tracker
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub struct StageTimings { pub struct StageTimings {
timings: Vec<(StageKind, Duration)>, timings: Vec<(StageKind, Duration)>,
@@ -66,41 +73,13 @@ impl StageTimings {
self.timings.push((kind, duration)); self.timings.push((kind, duration));
} }
pub fn into_vec(self) -> Vec<(StageKind, Duration)> { /// Milliseconds recorded for `kind`, or `0` if the stage did not run.
self.timings pub fn stage_ms(&self, kind: StageKind) -> u128 {
}
// Helper methods to get duration for each stage type (for backward compatibility)
fn get_stage_ms(&self, kind: StageKind) -> u128 {
self.timings self.timings
.iter() .iter()
.find(|(k, _)| *k == kind) .find(|(k, _)| *k == kind)
.map_or(0, |(_, d)| d.as_millis()) .map_or(0, |(_, d)| d.as_millis())
} }
pub fn embed_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Embed)
}
pub fn collect_candidates_ms(&self) -> u128 {
self.get_stage_ms(StageKind::CollectCandidates)
}
pub fn graph_expansion_ms(&self) -> u128 {
self.get_stage_ms(StageKind::GraphExpansion)
}
pub fn chunk_attach_ms(&self) -> u128 {
self.get_stage_ms(StageKind::ChunkAttach)
}
pub fn rerank_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Rerank)
}
pub fn assemble_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Assemble)
}
} }
pub struct RunOutput<T> { pub struct RunOutput<T> {
@@ -109,7 +88,35 @@ pub struct RunOutput<T> {
pub stage_timings: StageTimings, pub stage_timings: StageTimings,
} }
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> { /// Inputs required to run a retrieval.
pub struct RetrievalParams<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
pub input_text: &'a str,
pub user_id: &'a str,
pub config: RetrievalConfig,
pub reranker: Option<crate::reranking::RerankerLease>,
}
fn build_stages(config: &RetrievalConfig) -> Vec<BoxedStage> {
let mut stages: Vec<BoxedStage> = vec![
Box::new(EmbedStage),
Box::new(ChunkSearchStage),
Box::new(ChunkRerankStage),
];
if config.resolve_entities {
stages.push(Box::new(ResolveEntitiesStage));
}
stages.push(Box::new(ChunkAssembleStage));
stages
}
async fn run(
params: RetrievalParams<'_>,
query_embedding: Option<Vec<f32>>,
capture_diagnostics: bool,
) -> Result<RunOutput<RetrievalOutput>, AppError> {
let input_chars = params.input_text.chars().count(); let input_chars = params.input_text.chars().count();
let input_preview: String = params.input_text.chars().take(120).collect(); let input_preview: String = params.input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " "); let input_preview_clean = input_preview.replace('\n', " ");
@@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppEr
input_chars, input_chars,
preview_truncated = input_chars > preview_len, preview_truncated = input_chars > preview_len,
preview = %input_preview_clean, preview = %input_preview_clean,
strategy = %params.config.strategy, resolve_entities = params.config.resolve_entities,
"Starting retrieval pipeline" "Starting retrieval pipeline"
); );
let strategy = params.config.strategy; let resolve_entities = params.config.resolve_entities;
let search_target = params.config.search_target; let mut ctx = match query_embedding {
Some(embedding) => context::PipelineContext::with_embedding(params, embedding),
None => context::PipelineContext::new(params),
};
match strategy { if capture_diagnostics {
RetrievalStrategy::Default => { ctx.enable_diagnostics();
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Search(run.results))
}
} }
for stage in build_stages(&ctx.config) {
let start = Instant::now();
stage.execute(&mut ctx).await?;
ctx.record_stage_duration(stage.kind(), start.elapsed());
}
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let chunks = ctx.take_chunk_results();
let results = if resolve_entities {
RetrievalOutput::WithEntities {
chunks,
entities: ctx.take_entity_results(),
}
} else {
RetrievalOutput::Chunks(chunks)
};
Ok(RunOutput {
results,
diagnostics,
stage_timings,
})
} }
pub async fn run_pipeline_with_embedding( /// Run the retrieval pipeline, generating the query embedding internally if needed.
params: StrategyParams<'_>, pub async fn execute(params: RetrievalParams<'_>) -> Result<RetrievalOutput, AppError> {
query_embedding: Vec<f32>, Ok(run(params, None, false).await?.results)
) -> Result<StrategyOutput, AppError> {
let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Search(run.results))
}
}
} }
pub async fn run_pipeline_with_embedding_with_metrics( /// Run the retrieval pipeline with a pre-computed query embedding.
params: StrategyParams<'_>, pub async fn run_with_embedding(
params: RetrievalParams<'_>,
query_embedding: Vec<f32>, query_embedding: Vec<f32>,
) -> Result<RunOutput<StrategyOutput>, AppError> { ) -> Result<RetrievalOutput, AppError> {
let strategy = params.config.strategy; Ok(run(params, Some(query_embedding), false).await?.results)
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
_ => Err(AppError::InternalError(
"Metrics not supported for this strategy".into(),
)),
}
} }
pub async fn run_pipeline_with_embedding_with_diagnostics( /// Run with a pre-computed embedding, returning results and per-stage timings.
params: StrategyParams<'_>, ///
/// When `capture_diagnostics` is true, pipeline search/assemble stats are included.
pub async fn run_with_embedding_instrumented(
params: RetrievalParams<'_>,
query_embedding: Vec<f32>, query_embedding: Vec<f32>,
) -> Result<RunOutput<StrategyOutput>, AppError> { capture_diagnostics: bool,
let strategy = params.config.strategy; ) -> Result<RunOutput<RetrievalOutput>, AppError> {
run(params, Some(query_embedding), capture_diagnostics).await
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
_ => Err(AppError::InternalError(
"Diagnostics not supported for this strategy".into(),
)),
}
} }
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value { pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
@@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
}) })
.collect::<Vec<_>>()) .collect::<Vec<_>>())
} }
pub struct StrategyParams<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
pub input_text: &'a str,
pub user_id: &'a str,
pub config: RetrievalConfig,
pub reranker: Option<RerankerLease>,
}
async fn execute_strategy<D: StrategyDriver>(
driver: D,
params: StrategyParams<'_>,
query_embedding: Option<Vec<f32>>,
capture_diagnostics: bool,
) -> Result<RunOutput<D::Output>, AppError> {
let ctx = match query_embedding {
Some(embedding) => PipelineContext::with_embedding(params, embedding),
None => PipelineContext::new(params),
};
run_with_driver(driver, ctx, capture_diagnostics).await
}
async fn run_with_driver<D: StrategyDriver>(
driver: D,
mut ctx: PipelineContext<'_>,
capture_diagnostics: bool,
) -> Result<RunOutput<D::Output>, AppError> {
if capture_diagnostics {
ctx.enable_diagnostics();
}
for stage in driver.stages() {
let start = Instant::now();
stage.execute(&mut ctx).await?;
ctx.record_stage_duration(stage.kind(), start.elapsed());
}
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
Ok(RunOutput {
results,
diagnostics,
stage_timings,
})
}
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}
+424
View File
@@ -0,0 +1,424 @@
use async_trait::async_trait;
use common::{
error::AppError,
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
utils::embedding::generate_embedding,
};
use fastembed::RerankResult;
use std::collections::HashMap;
use tracing::{debug, instrument, warn};
use crate::{
scoring::{
clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored,
},
RetrievedChunk, RetrievedEntity,
};
use super::{
config::RetrievalTuning,
context::PipelineContext,
diagnostics::{AssembleStats, SearchStats},
Stage, StageKind,
};
#[derive(Debug, Clone, Copy)]
pub struct EmbedStage;
#[async_trait]
impl Stage for EmbedStage {
fn kind(&self) -> StageKind {
StageKind::Embed
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
embed(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkSearchStage;
#[async_trait]
impl Stage for ChunkSearchStage {
fn kind(&self) -> StageKind {
StageKind::Search
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
search_chunks(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkRerankStage;
#[async_trait]
impl Stage for ChunkRerankStage {
fn kind(&self) -> StageKind {
StageKind::Rerank
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
rerank_chunks(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ResolveEntitiesStage;
#[async_trait]
impl Stage for ResolveEntitiesStage {
fn kind(&self) -> StageKind {
StageKind::ResolveEntities
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
resolve_entities(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkAssembleStage;
#[async_trait]
impl Stage for ChunkAssembleStage {
fn kind(&self) -> StageKind {
StageKind::Assemble
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
assemble_chunks(ctx)
}
}
#[instrument(level = "trace", skip_all)]
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.query_embedding.is_some() {
debug!("Reusing cached query embedding for hybrid retrieval");
} else {
debug!("Generating query embedding for hybrid retrieval");
let embedding = if let Some(provider) = ctx.embedding_provider {
provider.embed(&ctx.input_text).await?
} else {
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
};
ctx.query_embedding = Some(embedding);
}
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting chunk candidates via vector and FTS search");
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning;
let fts_take = tuning.chunk_fts_take;
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
let (vector_rows, fts_rows) = tokio::try_join!(
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
&ctx.user_id,
),
async {
if fts_enabled {
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
} else {
Ok(Vec::new())
}
}
)?;
let vector_candidates = vector_rows.len();
let fts_candidates = fts_rows.len();
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.collect();
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
.collect();
let mut fts_weight = tuning.chunk_rrf_fts_weight;
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
// For very short keyword queries, lean more on lexical ranking.
fts_weight *= 1.5;
}
let rrf_config = RrfConfig {
k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight,
use_vector: tuning.flags.chunk_rrf_use_vector(),
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
};
let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
debug!(
total_merged = chunks.len(),
vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(),
fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(),
both_signals = chunks
.iter()
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
.count(),
rrf_k = rrf_config.k,
"Merged chunk candidates with RRF"
);
if ctx.diagnostics_enabled() {
ctx.record_search(SearchStats {
vector_chunk_candidates: vector_candidates,
fts_chunk_candidates: fts_candidates,
vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)),
fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
});
}
ctx.chunk_values = chunks;
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.len() <= 1 {
return Ok(());
}
let Some(reranker) = ctx.reranker.as_ref() else {
debug!("No reranker lease provided; skipping chunk rerank stage");
return Ok(());
};
let documents =
build_chunk_rerank_documents(&ctx.chunk_values, ctx.config.tuning.rerank_keep_top.max(1));
if documents.len() <= 1 {
debug!("Skipping chunk reranking stage; insufficient chunk documents");
return Ok(());
}
match reranker.rerank(&ctx.input_text, documents).await {
Ok(results) if !results.is_empty() => {
apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results);
}
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
Err(err) => warn!(
error = %err,
"Chunk reranking failed; continuing with original ordering"
),
}
Ok(())
}
/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks.
///
/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped
/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk,
/// and the contributing chunks are attached.
#[instrument(level = "trace", skip_all)]
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.is_empty() {
return Ok(());
}
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
let mut source_order: Vec<String> = Vec::new();
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
let mut best_score: HashMap<String, f32> = HashMap::new();
for scored in &ctx.chunk_values {
let source = scored.item.source_id.clone();
let attached = chunks_by_source.entry(source.clone()).or_default();
if attached.is_empty() {
source_order.push(source.clone());
best_score.insert(source.clone(), scored.fused);
}
if attached.len() < max_chunks {
attached.push(RetrievedChunk {
chunk: scored.item.clone(),
score: scored.fused,
});
}
}
let entities =
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
let mut entities_by_source: HashMap<String, Vec<KnowledgeEntity>> = HashMap::new();
for entity in entities {
entities_by_source
.entry(entity.source_id.clone())
.or_default()
.push(entity);
}
let mut results = Vec::new();
for source in &source_order {
let Some(entities) = entities_by_source.remove(source) else {
continue;
};
let score = best_score.get(source).copied().unwrap_or(0.0);
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
for entity in entities {
results.push(RetrievedEntity {
entity,
score,
chunks: chunks.clone(),
});
}
}
debug!(
sources = source_order.len(),
entities = results.len(),
"Resolved entities from retrieved chunks"
);
ctx.entity_results = results;
Ok(())
}
#[instrument(level = "trace", skip_all)]
#[allow(clippy::result_large_err)]
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling chunk retrieval results");
let mut chunk_values = std::mem::take(&mut ctx.chunk_values);
// Limit how many chunks we return to keep context size reasonable.
let limit = ctx
.config
.tuning
.chunk_result_cap
.max(1)
.min(ctx.config.tuning.chunk_vector_take.max(1));
if chunk_values.len() > limit {
chunk_values.truncate(limit);
}
ctx.chunk_results = chunk_values
.into_iter()
.map(|chunk| RetrievedChunk {
chunk: chunk.item,
score: chunk.fused,
})
.collect();
if ctx.diagnostics_enabled() {
ctx.record_assemble(AssembleStats {
chunks_selected: ctx.chunk_results.len(),
});
}
Ok(())
}
const SCORE_SAMPLE_LIMIT: usize = 8;
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where
F: FnMut(&Scored<T>) -> f32,
{
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
}
fn normalize_fts_query(input: &str) -> (String, usize) {
const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];
let mut cleaned = String::with_capacity(input.len());
for ch in input.chars() {
if ch.is_alphanumeric() {
cleaned.extend(ch.to_lowercase());
} else if ch.is_whitespace() {
cleaned.push(' ');
}
}
let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3));
for token in cleaned.split_whitespace() {
if !STOPWORDS.contains(&token) && !token.is_empty() {
tokens.push(token.to_string());
}
}
let normalized = tokens.join(" ");
(normalized, tokens.len())
}
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
chunks
.iter()
.take(max_chunks)
.map(|chunk| {
format!(
"Source: {}\nChunk:\n{}",
chunk.item.source_id,
chunk.item.chunk.trim()
)
})
.collect()
}
fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>,
tuning: &RetrievalTuning,
results: Vec<RerankResult>,
) {
if results.is_empty() || chunks.is_empty() {
return;
}
let mut remaining: Vec<Option<Scored<TextChunk>>> =
std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = tuning.flags.rerank_scores_only();
let blend = if use_only {
1.0
} else {
clamp_unit(tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
let original = candidate.fused;
let blended = if use_only {
clamp_unit(normalized)
} else {
clamp_unit(original * (1.0 - blend) + normalized * blend)
};
candidate.update_fused(blended);
reranked.push(candidate);
}
} else {
warn!(
result_index = result.index,
"Chunk reranker returned out-of-range index; skipping"
);
}
if reranked.len() == remaining.len() {
break;
}
}
reranked.extend(remaining.into_iter().flatten());
let keep_top = tuning.rerank_keep_top;
if keep_top > 0 && reranked.len() > keep_top {
reranked.truncate(keep_top);
}
*chunks = reranked;
}
File diff suppressed because it is too large Load Diff
@@ -1,148 +0,0 @@
use super::{
stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
},
BoxedStage, StrategyDriver,
};
use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError;
pub struct DefaultStrategyDriver;
impl DefaultStrategyDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for DefaultStrategyDriver {
type Output = Vec<RetrievedChunk>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_chunk_results())
}
}
pub struct RelationshipSuggestionDriver;
impl RelationshipSuggestionDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for RelationshipSuggestionDriver {
type Output = Vec<RetrievedEntity>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
// Skip ChunkAttachStage
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
pub struct IngestionDriver;
impl IngestionDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for IngestionDriver {
type Output = Vec<RetrievedEntity>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
// Skip ChunkAttachStage
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
use super::config::SearchTarget;
use crate::SearchResult;
/// Search strategy driver that retrieves both chunks and entities
pub struct SearchStrategyDriver {
target: SearchTarget,
}
impl SearchStrategyDriver {
pub fn new(target: SearchTarget) -> Self {
Self { target }
}
}
impl StrategyDriver for SearchStrategyDriver {
type Output = SearchResult;
fn stages(&self) -> Vec<BoxedStage> {
match self.target {
SearchTarget::ChunksOnly => vec![
Box::new(EmbedStage),
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
],
SearchTarget::EntitiesOnly => vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
],
SearchTarget::Both => vec![
Box::new(EmbedStage),
// Chunk retrieval path
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
// Entity retrieval path (runs after chunk stages)
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
],
}
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
let chunks = match self.target {
SearchTarget::EntitiesOnly => Vec::new(),
_ => ctx.take_chunk_results(),
};
let entities = match self.target {
SearchTarget::ChunksOnly => Vec::new(),
_ => ctx.take_entity_results(),
};
Ok(SearchResult::new(chunks, entities))
}
}
@@ -97,8 +97,7 @@ impl RerankerPool {
fn default_pool_size() -> usize { fn default_pool_size() -> usize {
available_parallelism() available_parallelism()
.map(|value| value.get().min(2)) .map_or(2, |value| value.get().min(2))
.unwrap_or(2)
.max(1) .max(1)
} }
@@ -156,6 +155,7 @@ pub struct RerankerLease {
} }
impl RerankerLease { impl RerankerLease {
#[allow(clippy::result_large_err)]
pub async fn rerank( pub async fn rerank(
&self, &self,
query: &str, query: &str,
@@ -165,7 +165,9 @@ impl RerankerLease {
let engine = Arc::clone(&self.engine); let engine = Arc::clone(&self.engine);
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
let mut guard = engine.lock().expect("reranker engine mutex poisoned"); let mut guard = engine.lock().map_err(|_| {
AppError::InternalError("reranker engine mutex poisoned".into())
})?;
guard guard
.rerank(query, documents, false, None) .rerank(query, documents, false, None)
.map_err(|e| AppError::InternalError(e.to_string())) .map_err(|e| AppError::InternalError(e.to_string()))
+7 -120
View File
@@ -1,14 +1,12 @@
use std::{cmp::Ordering, collections::HashMap}; use std::{cmp::Ordering, collections::HashMap};
use common::storage::types::StoredObject; use common::storage::types::StoredObject;
use serde::{Deserialize, Serialize};
/// Holds optional subscores gathered from different retrieval signals. /// Holds optional subscores gathered from the vector and full-text retrieval signals.
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct Scores { pub struct Scores {
pub fts: Option<f32>, pub fts: Option<f32>,
pub vector: Option<f32>, pub vector: Option<f32>,
pub graph: Option<f32>,
} }
/// Generic wrapper combining an item with its accumulated retrieval scores. /// Generic wrapper combining an item with its accumulated retrieval scores.
@@ -40,40 +38,11 @@ impl<T> Scored<T> {
self self
} }
#[must_use]
pub const fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score);
self
}
pub const fn update_fused(&mut self, fused: f32) { pub const fn update_fused(&mut self, fused: f32) {
self.fused = fused; self.fused = fused;
} }
} }
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct FusionWeights {
pub vector: f32,
pub fts: f32,
pub graph: f32,
pub multi_bonus: f32,
}
impl Default for FusionWeights {
fn default() -> Self {
// Default weights favor vector search, which typically performs better
// FTS is used as a complement when there's good overlap
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
Self {
vector: 0.8,
fts: 0.2,
graph: 0.2,
multi_bonus: 0.3, // Increased to boost chunks with both signals
}
}
}
/// Configuration for reciprocal rank fusion. /// Configuration for reciprocal rank fusion.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct RrfConfig { pub struct RrfConfig {
@@ -84,29 +53,10 @@ pub struct RrfConfig {
pub use_fts: bool, pub use_fts: bool,
} }
impl Default for RrfConfig {
fn default() -> Self {
Self {
k: 60.0,
vector_weight: 1.0,
fts_weight: 1.0,
use_vector: true,
use_fts: true,
}
}
}
pub const fn clamp_unit(value: f32) -> f32 { pub const fn clamp_unit(value: f32) -> f32 {
value.clamp(0.0, 1.0) value.clamp(0.0, 1.0)
} }
pub fn distance_to_similarity(distance: f32) -> f32 {
if !distance.is_finite() {
return 0.0;
}
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
}
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> { pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() { if scores.is_empty() {
return Vec::new(); return Vec::new();
@@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
.collect() .collect()
} }
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
let vector = scores.vector.unwrap_or(0.0);
let fts = scores.fts.unwrap_or(0.0);
let graph = scores.graph.unwrap_or(0.0);
let mut fused = graph.mul_add(
weights.graph,
vector.mul_add(weights.vector, fts * weights.fts),
);
let signals_present = scores
.vector
.iter()
.chain(scores.fts.iter())
.chain(scores.graph.iter())
.count();
// Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
if signals_present >= 2 {
// For chunks with both vector and FTS, give a significant boost
// This helps identify the "golden chunk" that appears in both searches
if scores.vector.is_some() && scores.fts.is_some() {
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
fused *= 1.0 + weights.multi_bonus;
} else {
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
fused += weights.multi_bonus;
}
}
clamp_unit(fused)
}
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
target: &mut std::collections::HashMap<String, Scored<T>, S>,
incoming: Vec<Scored<T>>,
) where
T: StoredObject + Clone,
{
for scored in incoming {
let id = scored.item.id().to_owned();
target
.entry(id)
.and_modify(|existing| {
if let Some(score) = scored.scores.vector {
existing.scores.vector = Some(score);
}
if let Some(score) = scored.scores.fts {
existing.scores.fts = Some(score);
}
if let Some(score) = scored.scores.graph {
existing.scores.graph = Some(score);
}
})
.or_insert_with(|| Scored {
item: scored.item.clone(),
scores: scored.scores,
fused: scored.fused,
});
}
}
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>]) pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where where
T: StoredObject, T: StoredObject,
@@ -222,6 +109,10 @@ where
}); });
} }
/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion.
///
/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text
/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score.
pub fn reciprocal_rank_fusion<T>( pub fn reciprocal_rank_fusion<T>(
mut vector_ranked: Vec<Scored<T>>, mut vector_ranked: Vec<Scored<T>>,
mut fts_ranked: Vec<Scored<T>>, mut fts_ranked: Vec<Scored<T>>,
@@ -266,9 +157,7 @@ where
} }
} }
entry.item = candidate.item; entry.item = candidate.item;
let rank_f32: f32 = u16::try_from(rank) let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
.map(f32::from)
.unwrap_or(f32::MAX);
entry.fused += vector_weight / (k + rank_f32 + 1.0); entry.fused += vector_weight / (k + rank_f32 + 1.0);
} }
} }
@@ -296,9 +185,7 @@ where
} }
} }
entry.item = candidate.item; entry.item = candidate.item;
let rank_f32: f32 = u16::try_from(rank) let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
.map(f32::from)
.unwrap_or(f32::MAX);
entry.fused += fts_weight / (k + rank_f32 + 1.0); entry.fused += fts_weight / (k + rank_f32 + 1.0);
} }
} }