chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.

Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
Per Stark
2026-05-30 22:19:08 +02:00
parent c70141de35
commit 5c2d2e24d3
38 changed files with 1049 additions and 2614 deletions
Generated
-4
View File
@@ -5426,14 +5426,10 @@ dependencies = [
"anyhow",
"async-openai",
"async-trait",
"axum",
"common",
"fastembed",
"futures",
"serde",
"serde_json",
"surrealdb",
"thiserror 1.0.69",
"tokio",
"tracing",
"uuid",
@@ -164,6 +164,35 @@ impl KnowledgeEntity {
.take(0)?)
}
/// Fetch all knowledge entities owned by any of the provided source ids for a user.
///
/// Used by retrieval to resolve the entities that own a set of retrieved chunks.
pub async fn find_by_source_ids(
db: &SurrealDbClient,
source_ids: &[String],
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
if source_ids.is_empty() {
return Ok(Vec::new());
}
let entities: Vec<KnowledgeEntity> = db
.client
.query(
"SELECT * FROM type::table($table) \
WHERE source_id IN $sources AND user_id = $user_id",
)
.bind(("table", Self::table_name()))
.bind(("sources", source_ids.to_vec()))
.bind(("user_id", user_id.to_owned()))
.await
.map_err(AppError::Database)?
.take(0)
.map_err(AppError::Database)?;
Ok(entities)
}
pub async fn delete_by_source_id(
source_id: &str,
db_client: &SurrealDbClient,
+11 -119
View File
@@ -1,8 +1,7 @@
use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Deserializer, Serialize};
use std::{env, fmt, str::FromStr, sync::Once};
use serde::{Deserialize, Serialize};
use std::{env, str::FromStr, sync::Once};
use thiserror::Error;
use tracing::warn;
/// Error returned when parsing an embedding backend name.
#[derive(Debug, Error, PartialEq, Eq)]
@@ -36,83 +35,6 @@ impl EmbeddingBackend {
}
}
/// Error returned when parsing a retrieval strategy name.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown retrieval strategy '{input}'")]
pub struct ParseRetrievalStrategyError {
/// The unrecognized input string.
pub input: String,
}
/// Selects which retrieval pipeline strategy to run for chat and search.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
/// Primary hybrid chunk retrieval for search/chat.
#[default]
Default,
/// Entity retrieval for suggesting relationships when creating manual entities.
RelationshipSuggestion,
/// Entity retrieval for context during content ingestion.
Ingestion,
/// Unified search returning both chunks and entities.
Search,
}
impl RetrievalStrategy {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Default => "default",
Self::RelationshipSuggestion => "relationship_suggestion",
Self::Ingestion => "ingestion",
Self::Search => "search",
}
}
}
impl FromStr for RetrievalStrategy {
type Err = ParseRetrievalStrategyError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value.to_ascii_lowercase().as_str() {
"default" => Ok(Self::Default),
"initial" | "revised" => {
warn!(
"retrieval strategy '{value}' is deprecated; use 'default' instead"
);
Ok(Self::Default)
}
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
"ingestion" => Ok(Self::Ingestion),
"search" => Ok(Self::Search),
other => Err(ParseRetrievalStrategyError {
input: other.to_string(),
}),
}
}
}
impl fmt::Display for RetrievalStrategy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
fn deserialize_optional_retrieval_strategy<'de, D>(
deserializer: D,
) -> Result<Option<RetrievalStrategy>, D::Error>
where
D: Deserializer<'de>,
{
let value = Option::<String>::deserialize(deserializer)?;
match value {
None => Ok(None),
Some(raw) if raw.trim().is_empty() => Ok(None),
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
}
}
impl FromStr for EmbeddingBackend {
type Err = ParseEmbeddingBackendError;
@@ -195,8 +117,6 @@ pub struct AppConfig {
pub fastembed_show_download_progress: Option<bool>,
#[serde(default)]
pub fastembed_max_length: Option<usize>,
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
pub retrieval_strategy: Option<RetrievalStrategy>,
#[serde(default)]
pub embedding_backend: EmbeddingBackend,
#[serde(default = "default_ingest_max_body_bytes")]
@@ -282,14 +202,6 @@ pub fn ensure_ort_path() {
});
}
impl AppConfig {
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
#[must_use]
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
self.retrieval_strategy.unwrap_or_default()
}
}
impl Default for AppConfig {
fn default() -> Self {
Self {
@@ -312,7 +224,6 @@ impl Default for AppConfig {
fastembed_cache_dir: None,
fastembed_show_download_progress: None,
fastembed_max_length: None,
retrieval_strategy: None,
embedding_backend: EmbeddingBackend::default(),
ingest_max_body_bytes: default_ingest_max_body_bytes(),
ingest_max_files: default_ingest_max_files(),
@@ -340,41 +251,22 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
mod tests {
#![allow(clippy::expect_used)]
use super::{ParseRetrievalStrategyError, RetrievalStrategy};
use super::EmbeddingBackend;
#[test]
fn retrieval_strategy_defaults_to_default() {
assert_eq!(
RetrievalStrategy::default(),
RetrievalStrategy::Default
);
fn embedding_backend_defaults_to_fastembed() {
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
}
#[test]
fn retrieval_strategy_serializes_snake_case() {
fn embedding_backend_parses_aliases() {
assert_eq!(
serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
"\"search\""
"openai".parse::<EmbeddingBackend>().expect("openai"),
EmbeddingBackend::OpenAI
);
}
#[test]
fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
assert_eq!(
"initial".parse::<RetrievalStrategy>().expect("initial"),
RetrievalStrategy::Default
);
assert!(matches!(
"unknown".parse::<RetrievalStrategy>(),
Err(ParseRetrievalStrategyError { .. })
));
}
#[test]
fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
let config = super::AppConfig::default();
assert_eq!(
config.resolved_retrieval_strategy(),
RetrievalStrategy::Default
"fast".parse::<EmbeddingBackend>().expect("fast"),
EmbeddingBackend::FastEmbed
);
}
}
-1
View File
@@ -24,7 +24,6 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
| `RUST_LOG` | Logging level | `info` |
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
+6 -5
View File
@@ -27,13 +27,14 @@ The D3-based graph visualization shows entities as nodes and relationships as ed
## Hybrid Retrieval
Minne combines multiple retrieval strategies:
Minne uses chunk-first hybrid retrieval over the knowledge base:
- **Vector similarity** — Semantic matching via embeddings
- **Full-text search** — Keyword matching with BM25
- **Graph traversal** — Following relationships between entities
- **Vector similarity** — Semantic matching via embeddings over text chunks
- **Full-text search** — Keyword matching with BM25 over the same chunk index
Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance.
The two ranked candidate lists are merged with Reciprocal Rank Fusion (RRF). When a caller needs knowledge entities (search, ingestion linking, relationship suggestion), entities are derived from the top retrieved chunks grouped by `source_id`.
Optional **reranking** can rescore the fused chunk list with a cross-encoder model; see below.
## Reranking (Optional)
+7 -18
View File
@@ -5,7 +5,6 @@ use std::{
use anyhow::{anyhow, Context, Result};
use clap::{Args, Parser, ValueEnum};
use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind;
@@ -55,10 +54,6 @@ pub struct RetrievalSettings {
#[arg(long)]
pub chunk_fts_take: Option<usize>,
/// Override average characters per token used for budgeting
#[arg(long)]
pub chunk_avg_chars_per_token: Option<usize>,
/// Override maximum chunks attached per entity
#[arg(long)]
pub max_chunks_per_entity: Option<usize>,
@@ -71,41 +66,37 @@ pub struct RetrievalSettings {
#[arg(long, default_value_t = 4)]
pub rerank_pool_size: usize,
/// Keep top-N entities after reranking
/// Keep top-N chunks after reranking
#[arg(long, default_value_t = 10)]
pub rerank_keep_top: usize,
/// Cap the number of chunks returned by retrieval (revised strategy)
/// Cap the number of chunks returned by retrieval
#[arg(long, default_value_t = 5)]
pub chunk_result_cap: usize,
/// Reciprocal rank fusion k value for revised chunk merging
/// Reciprocal rank fusion k value for chunk merging
#[arg(long)]
pub chunk_rrf_k: Option<f32>,
/// Weight for vector ranks in revised RRF
/// Weight for vector ranks in RRF
#[arg(long)]
pub chunk_rrf_vector_weight: Option<f32>,
/// Weight for chunk FTS ranks in revised RRF
/// Weight for chunk FTS ranks in RRF
#[arg(long)]
pub chunk_rrf_fts_weight: Option<f32>,
/// Include vector ranks in revised RRF (default: true)
/// Include vector ranks in RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_vector: Option<bool>,
/// Include chunk FTS ranks in revised RRF (default: true)
/// Include chunk FTS ranks in RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_fts: Option<bool>,
/// Require verified chunks (disable with --llm-mode)
#[arg(skip = true)]
pub require_verified_chunks: bool,
/// Select the retrieval pipeline strategy
#[arg(long, default_value_t = RetrievalStrategy::Default)]
pub strategy: RetrievalStrategy,
}
impl Default for RetrievalSettings {
@@ -113,7 +104,6 @@ impl Default for RetrievalSettings {
Self {
chunk_vector_take: None,
chunk_fts_take: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None,
rerank: false,
rerank_pool_size: 4,
@@ -125,7 +115,6 @@ impl Default for RetrievalSettings {
chunk_rrf_use_vector: None,
chunk_rrf_use_fts: None,
require_verified_chunks: true,
strategy: RetrievalStrategy::Default,
}
}
}
+18 -20
View File
@@ -51,8 +51,8 @@ pub fn mirror_perf_outputs(
pub fn print_console_summary(record: &EvaluationReport) {
let perf = &record.performance;
println!(
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.strategy,
"[perf] resolve_entities={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.resolve_entities,
record.retrieval.concurrency,
record.retrieval.rerank_enabled,
record.retrieval.rerank_pool_size,
@@ -63,16 +63,14 @@ pub fn print_console_summary(record: &EvaluationReport) {
perf.ingestion_ms,
format_duration(perf.namespace_seed_ms),
);
let stage = &perf.stage_latency;
println!(
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
stage.embed.avg,
stage.collect_candidates.avg,
stage.graph_expansion.avg,
stage.chunk_attach.avg,
stage.rerank.avg,
stage.assemble.avg,
);
let stage_summary = perf
.stage_latency
.stages
.iter()
.map(|s| format!("{} {:.1}", s.stage, s.stats.avg))
.collect::<Vec<_>>()
.join(" | ");
println!("[perf] stage avg ms → {stage_summary}");
let eval = &perf.evaluation_stages_ms;
println!(
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
@@ -107,12 +105,13 @@ mod tests {
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
crate::eval::StageLatencyBreakdown {
embed: sample_latency(),
collect_candidates: sample_latency(),
graph_expansion: sample_latency(),
chunk_attach: sample_latency(),
rerank: sample_latency(),
assemble: sample_latency(),
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
.into_iter()
.map(|stage| crate::eval::StageLatency {
stage: stage.to_string(),
stats: sample_latency(),
})
.collect(),
}
}
@@ -193,7 +192,7 @@ mod tests {
rerank_keep_top: 10,
concurrency: 2,
detailed_report: false,
retrieval_strategy: "initial".into(),
resolve_entities: false,
chunk_result_cap: 5,
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
@@ -206,7 +205,6 @@ mod tests {
ingest_chunk_overlap_tokens: 50,
chunk_vector_take: 20,
chunk_fts_take: 20,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4,
cases: Vec::new(),
}
+7 -23
View File
@@ -6,7 +6,7 @@ use futures::stream::{self, StreamExt};
use tracing::{debug, info};
use crate::eval::{
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
CaseSummary, RetrievedSummary,
};
use retrieval_pipeline::{
@@ -51,14 +51,8 @@ pub(crate) async fn run_queries(
None
};
let mut retrieval_config = RetrievalConfig {
strategy: config.retrieval.strategy,
..Default::default()
};
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
}
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
if let Some(value) = config.retrieval.chunk_vector_take {
retrieval_config.tuning.chunk_vector_take = value;
@@ -81,9 +75,6 @@ pub(crate) async fn run_queries(
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
}
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value;
}
if let Some(value) = config.retrieval.max_chunks_per_entity {
retrieval_config.tuning.max_chunks_per_entity = value;
}
@@ -187,7 +178,7 @@ pub(crate) async fn run_queries(
None => None,
};
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: Some(&embedding_provider),
@@ -196,26 +187,19 @@ pub(crate) async fn run_queries(
config: (*retrieval_config).clone(),
reranker,
};
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
let (result_output, pipeline_diagnostics, stage_timings) = {
let outcome = pipeline::run_with_embedding_instrumented(
params,
query_embedding,
diagnostics_enabled,
)
.await
.with_context(|| format!("running pipeline for question {question_id}"))?;
(outcome.results, outcome.diagnostics, outcome.stage_timings)
} else {
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
params,
query_embedding,
)
.await
.with_context(|| format!("running pipeline for question {question_id}"))?;
(outcome.results, None, outcome.stage_timings)
};
let query_latency = query_start.elapsed().as_millis();
let candidates = adapt_strategy_output(result_output);
let candidates = adapt_retrieval_output(result_output);
let mut retrieved = Vec::new();
let mut match_rank = None;
let answers_lower: Vec<String> =
+4 -2
View File
@@ -201,7 +201,10 @@ pub(crate) async fn summarize(
rerank_keep_top: config.retrieval.rerank_keep_top,
concurrency: config.concurrency.max(1),
detailed_report: config.detailed_report,
retrieval_strategy: config.retrieval.strategy.to_string(),
resolve_entities: ctx
.retrieval_config
.as_ref()
.is_some_and(|config| config.resolve_entities),
chunk_result_cap: config.retrieval.chunk_result_cap,
chunk_rrf_k: active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
@@ -214,7 +217,6 @@ pub(crate) async fn summarize(
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
chunk_vector_take: active_tuning.chunk_vector_take,
chunk_fts_take: active_tuning.chunk_fts_take,
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
cases: summaries,
});
+54 -49
View File
@@ -85,7 +85,7 @@ pub struct RetrievalSection {
pub average_ndcg: f64,
pub latency: LatencyStats,
pub concurrency: usize,
pub strategy: String,
pub resolve_entities: bool,
pub rerank_enabled: bool,
pub rerank_pool_size: Option<usize>,
pub rerank_keep_top: usize,
@@ -226,7 +226,7 @@ impl EvaluationReport {
average_ndcg: summary.average_ndcg,
latency: summary.latency_ms.clone(),
concurrency: summary.concurrency,
strategy: summary.retrieval_strategy.clone(),
resolve_entities: summary.resolve_entities,
rerank_enabled: summary.rerank_enabled,
rerank_pool_size: summary.rerank_pool_size,
rerank_keep_top: summary.rerank_keep_top,
@@ -463,7 +463,12 @@ fn render_markdown(report: &EvaluationReport) -> String {
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap();
write!(md, "| Strategy | `{}` |\\n", report.retrieval.strategy).unwrap();
write!(
md,
"| Resolve entities | {} |\\n",
bool_badge(report.retrieval.resolve_entities)
)
.unwrap();
write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap();
if report.retrieval.rerank_enabled {
let pool = report
@@ -510,28 +515,9 @@ fn render_markdown(report: &EvaluationReport) -> String {
md.push_str("\\n## Retrieval Stage Timings\\n\\n");
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed);
write_stage_row(
&mut md,
"Collect Candidates",
&report.performance.stage_latency.collect_candidates,
);
write_stage_row(
&mut md,
"Graph Expansion",
&report.performance.stage_latency.graph_expansion,
);
write_stage_row(
&mut md,
"Chunk Attach",
&report.performance.stage_latency.chunk_attach,
);
write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank);
write_stage_row(
&mut md,
"Assemble",
&report.performance.stage_latency.assemble,
);
for stage in &report.performance.stage_latency.stages {
write_stage_row(&mut md, &prettify_stage(&stage.stage), &stage.stats);
}
if report.misses.is_empty() {
if report.detailed_report {
@@ -623,6 +609,20 @@ fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
.unwrap();
}
/// Turn a stable stage label (e.g. `resolve_entities`) into a display title (`Resolve Entities`).
fn prettify_stage(label: &str) -> String {
label
.split('_')
.map(|word| {
let mut chars = word.chars();
chars.next().map_or_else(String::new, |first| {
first.to_uppercase().collect::<String>() + chars.as_str()
})
})
.collect::<Vec<_>>()
.join(" ")
}
fn bool_badge(value: bool) -> &'static str {
if value {
""
@@ -740,17 +740,6 @@ struct LegacyHistoryDelta {
latency_avg_ms: f64,
}
fn default_stage_latency() -> StageLatencyBreakdown {
StageLatencyBreakdown {
embed: LatencyStats::default(),
collect_candidates: LatencyStats::default(),
graph_expansion: LatencyStats::default(),
chunk_attach: LatencyStats::default(),
rerank: LatencyStats::default(),
assemble: LatencyStats::default(),
}
}
#[allow(clippy::too_many_lines)]
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
let overview = OverviewSection {
@@ -807,7 +796,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
average_ndcg: entry.average_ndcg,
latency: entry.latency_ms,
concurrency: 0,
strategy: "unknown".into(),
resolve_entities: false,
rerank_enabled: entry.rerank_enabled,
rerank_pool_size: entry.rerank_pool_size,
rerank_keep_top: entry.rerank_keep_top,
@@ -840,7 +829,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
ingestion_ms: entry.ingestion_ms,
namespace_seed_ms: entry.namespace_seed_ms,
evaluation_stages_ms: EvaluationStageTimings::default(),
stage_latency: default_stage_latency(),
stage_latency: StageLatencyBreakdown::default(),
namespace_reused: false,
ingestion_reused: entry.ingestion_reused,
embeddings_reused: entry.ingestion_embeddings_reused,
@@ -915,7 +904,8 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
mod tests {
use super::*;
use crate::eval::{
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatencyBreakdown,
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
StageLatencyBreakdown,
};
use chrono::Utc;
use tempfile::tempdir;
@@ -931,12 +921,28 @@ mod tests {
fn sample_stage_latency() -> StageLatencyBreakdown {
StageLatencyBreakdown {
embed: latency(9.0),
collect_candidates: latency(10.0),
graph_expansion: latency(11.0),
chunk_attach: latency(12.0),
rerank: latency(13.0),
assemble: latency(14.0),
stages: vec![
StageLatency {
stage: "embed".to_string(),
stats: latency(9.0),
},
StageLatency {
stage: "search".to_string(),
stats: latency(10.0),
},
StageLatency {
stage: "rerank".to_string(),
stats: latency(13.0),
},
StageLatency {
stage: "resolve_entities".to_string(),
stats: latency(11.0),
},
StageLatency {
stage: "assemble".to_string(),
stats: latency(14.0),
},
],
}
}
@@ -1058,7 +1064,7 @@ mod tests {
rerank_keep_top: 5,
concurrency: 2,
detailed_report: true,
retrieval_strategy: "initial".into(),
resolve_entities: false,
chunk_result_cap: 5,
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
@@ -1071,7 +1077,6 @@ mod tests {
ingest_chunks_only: false,
chunk_vector_take: 50,
chunk_fts_take: 50,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4,
cases,
}
@@ -1097,7 +1102,7 @@ mod tests {
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
#[test]
fn evaluations_history_captures_strategy_and_concurrency() {
fn evaluations_history_captures_resolve_entities_and_concurrency() {
let tmp = tempdir().unwrap();
let summary = sample_summary(false);
@@ -1109,7 +1114,7 @@ mod tests {
assert_eq!(entries.len(), 1);
let stored = &entries[0];
assert_eq!(stored.retrieval.concurrency, summary.concurrency);
assert_eq!(stored.retrieval.strategy, summary.retrieval_strategy);
assert_eq!(stored.retrieval.resolve_entities, summary.resolve_entities);
assert_eq!(
stored.performance.evaluation_stages_ms.run_queries_ms,
summary.perf.evaluation_stage_ms.run_queries_ms
+33 -38
View File
@@ -3,7 +3,7 @@ use std::collections::HashSet;
use chrono::{DateTime, Utc};
use common::storage::types::StoredObject;
use retrieval_pipeline::{
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
};
use serde::{Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization;
@@ -69,7 +69,7 @@ pub struct EvaluationSummary {
pub rerank_keep_top: usize,
pub concurrency: usize,
pub detailed_report: bool,
pub retrieval_strategy: String,
pub resolve_entities: bool,
pub chunk_result_cap: usize,
pub chunk_rrf_k: f32,
pub chunk_rrf_vector_weight: f32,
@@ -82,7 +82,6 @@ pub struct EvaluationSummary {
pub ingest_chunk_overlap_tokens: usize,
pub chunk_vector_take: usize,
pub chunk_fts_take: usize,
pub chunk_avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub cases: Vec<CaseSummary>,
}
@@ -129,14 +128,20 @@ impl Default for LatencyStats {
}
}
/// Latency statistics for a single retrieval stage, keyed by the stage's stable label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageLatency {
pub stage: String,
pub stats: LatencyStats,
}
/// Per-stage retrieval latency, in canonical pipeline order.
///
/// The set of stages is driven entirely by [`StageKind::ALL`], so adding a retrieval stage
/// surfaces here automatically without changes to the evaluation harness.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StageLatencyBreakdown {
pub embed: LatencyStats,
pub collect_candidates: LatencyStats,
pub graph_expansion: LatencyStats,
pub chunk_attach: LatencyStats,
pub rerank: LatencyStats,
pub assemble: LatencyStats,
pub stages: Vec<StageLatency>,
}
#[allow(clippy::struct_field_names)]
@@ -232,13 +237,12 @@ fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidat
.collect()
}
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidate> {
match output {
StrategyOutput::Entities(entities) => candidates_from_entities(entities),
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
StrategyOutput::Search(search_result) => {
let mut candidates = candidates_from_entities(search_result.entities);
candidates.extend(candidates_from_chunks(search_result.chunks));
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
RetrievalOutput::WithEntities { chunks, entities } => {
let mut candidates = candidates_from_entities(entities);
candidates.extend(candidates_from_chunks(chunks));
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
candidates
}
@@ -262,7 +266,7 @@ pub struct CaseDiagnostics {
pub attached_chunk_ids: Vec<String>,
pub retrieved: Vec<EntityDiagnostics>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pipeline: Option<PipelineDiagnostics>,
pub pipeline: Option<Diagnostics>,
}
#[derive(Debug, Serialize)]
@@ -366,28 +370,19 @@ pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
LatencyStats { avg, p50, p95 }
}
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
where
F: Fn(&PipelineStageTimings) -> u128,
{
samples.iter().map(selector).collect()
}
pub fn build_stage_latency_breakdown(samples: &[StageTimings]) -> StageLatencyBreakdown {
let stages = StageKind::ALL
.iter()
.map(|kind| {
let latencies: Vec<u128> = samples.iter().map(|s| s.stage_ms(*kind)).collect();
StageLatency {
stage: kind.label().to_string(),
stats: compute_latency_stats(&latencies),
}
})
.collect();
StageLatencyBreakdown {
embed: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::embed_ms)),
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms()
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms()
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
entry.chunk_attach_ms()
})),
rerank: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::rerank_ms)),
assemble: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::assemble_ms)),
}
StageLatencyBreakdown { stages }
}
#[allow(
@@ -412,7 +407,7 @@ pub fn build_case_diagnostics(
expected_chunk_ids: &[String],
answers_lower: &[String],
candidates: &[EvaluationCandidate],
pipeline_stats: Option<PipelineDiagnostics>,
pipeline_stats: Option<Diagnostics>,
) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect();
let mut seen_chunks: HashSet<String> = HashSet::new();
-91
View File
@@ -44,7 +44,6 @@
--leading-snug: 1.375;
--leading-relaxed: 1.625;
--ease-out: cubic-bezier(0, 0, 0.2, 1);
--ease-in-out: cubic-bezier(0.4, 0, 0.2, 1);
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
--default-transition-duration: 150ms;
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
@@ -285,37 +284,6 @@
}
}
}
.drawer-open {
> .drawer-side {
overflow-y: auto;
}
> .drawer-toggle {
display: none;
& ~ .drawer-side {
pointer-events: auto;
visibility: visible;
position: sticky;
display: block;
width: auto;
overscroll-behavior: auto;
opacity: 100%;
& > .drawer-overlay {
cursor: default;
background-color: transparent;
}
& > *:not(.drawer-overlay) {
translate: 0%;
[dir="rtl"] & {
translate: 0%;
}
}
}
&:checked ~ .drawer-side {
pointer-events: auto;
visibility: visible;
}
}
}
.drawer-toggle {
position: fixed;
height: calc(0.25rem * 0);
@@ -1074,22 +1042,6 @@
grid-row-start: 1;
min-width: calc(0.25rem * 0);
}
.chat-image {
grid-row: span 2 / span 2;
align-self: flex-end;
}
.chat-footer {
grid-row-start: 3;
display: flex;
gap: calc(0.25rem * 1);
font-size: 0.6875rem;
}
.chat-header {
grid-row-start: 1;
display: flex;
gap: calc(0.25rem * 1);
font-size: 0.6875rem;
}
.container {
width: 100%;
@media (width >= 40rem) {
@@ -1796,9 +1748,6 @@
.w-10 {
width: calc(var(--spacing) * 10);
}
.w-11 {
width: calc(var(--spacing) * 11);
}
.w-11\/12 {
width: calc(11/12 * 100%);
}
@@ -1862,9 +1811,6 @@
.flex-none {
flex: none;
}
.flex-shrink {
flex-shrink: 1;
}
.flex-shrink-0 {
flex-shrink: 0;
}
@@ -1877,13 +1823,6 @@
.grow {
flex-grow: 1;
}
.border-collapse {
border-collapse: collapse;
}
.-translate-y-1 {
--tw-translate-y: calc(var(--spacing) * -1);
translate: var(--tw-translate-x) var(--tw-translate-y);
}
.-translate-y-1\/2 {
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
translate: var(--tw-translate-x) var(--tw-translate-y);
@@ -1956,9 +1895,6 @@
.justify-start {
justify-content: flex-start;
}
.gap-0 {
gap: calc(var(--spacing) * 0);
}
.gap-0\.5 {
gap: calc(var(--spacing) * 0.5);
}
@@ -2115,9 +2051,6 @@
.bg-transparent {
background-color: transparent;
}
.bg-warning {
background-color: var(--color-warning);
}
.bg-warning\/10 {
background-color: var(--color-warning);
@supports (color: color-mix(in lab, red, red)) {
@@ -2136,9 +2069,6 @@
.loading-spinner {
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
}
.mask-repeat {
mask-repeat: repeat;
}
.fill-current {
fill: currentcolor;
}
@@ -2169,9 +2099,6 @@
.p-8 {
padding: calc(var(--spacing) * 8);
}
.px-1 {
padding-inline: calc(var(--spacing) * 1);
}
.px-1\.5 {
padding-inline: calc(var(--spacing) * 1.5);
}
@@ -2326,9 +2253,6 @@
--tw-tracking: var(--tracking-widest);
letter-spacing: var(--tracking-widest);
}
.text-wrap {
text-wrap: wrap;
}
.break-words {
overflow-wrap: break-word;
}
@@ -2395,17 +2319,6 @@
.italic {
font-style: italic;
}
.underline {
text-decoration-line: underline;
}
.swap-active {
.swap-off {
opacity: 0%;
}
.swap-on {
opacity: 100%;
}
}
.opacity-0 {
opacity: 0%;
}
@@ -2496,10 +2409,6 @@
--tw-duration: 300ms;
transition-duration: 300ms;
}
.ease-in-out {
--tw-ease: var(--ease-in-out);
transition-timing-function: var(--ease-in-out);
}
.ease-out {
--tw-ease: var(--ease-out);
transition-timing-function: var(--ease-out);
+1 -8
View File
@@ -2,10 +2,7 @@ use common::storage::types::conversation::SidebarConversation;
use common::storage::{db::SurrealDbClient, store::StorageManager};
use common::utils::embedding::EmbeddingProvider;
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{
create_template_engine, storage::db::ProvidesDb,
utils::config::{AppConfig, RetrievalStrategy},
};
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
use retrieval_pipeline::reranking::RerankerPool;
use std::collections::HashMap;
use std::sync::{
@@ -75,10 +72,6 @@ impl HtmlState {
}
}
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
self.config.resolved_retrieval_strategy()
}
pub async fn get_cached_conversation_archive(
&self,
user_id: &str,
@@ -16,12 +16,9 @@ use futures::{
};
use json_stream_parser::JsonStreamParser;
use minijinja::Value;
use retrieval_pipeline::{
answer_retrieval::{
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
LLMResponseFormat,
},
retrieved_entities_to_json,
use retrieval_pipeline::answer_retrieval::{
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
LLMResponseFormat,
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
@@ -189,11 +186,7 @@ struct ReferenceData {
}
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
response
.references
.iter()
.map(|reference| reference.reference.clone())
.collect()
response.reference_ids()
}
#[allow(clippy::too_many_lines)]
@@ -362,10 +355,9 @@ async fn prepare_chat_request(
None => None,
};
let strategy = state.retrieval_strategy();
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
let config = retrieval_pipeline::RetrievalConfig::default();
let retrieval_result = match retrieval_pipeline::retrieve_entities(
let retrieval_result = match retrieval_pipeline::retrieve(
&state.db,
&state.openai_client,
Some(&*state.embedding_provider),
@@ -387,12 +379,9 @@ async fn prepare_chat_request(
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
chunks_to_chat_context(&search_result.chunks)
retrieval_pipeline::RetrievalOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::RetrievalOutput::WithEntities { chunks, .. } => {
chunks_to_chat_context(&chunks)
}
};
let formatted_user_message =
@@ -9,7 +9,7 @@ use common::{
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
};
use retrieval_pipeline::StrategyOutput;
use retrieval_pipeline::RetrievalOutput;
use uuid::Uuid;
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
@@ -86,40 +86,29 @@ pub(crate) enum ReferenceLookupTarget {
}
pub(crate) fn collect_reference_ids_from_retrieval(
retrieval_result: &StrategyOutput,
retrieval_result: &RetrievalOutput,
) -> Vec<String> {
let mut ids = Vec::new();
let mut seen = HashSet::new();
let mut push_id = |id: String| {
if seen.insert(id.clone()) {
ids.push(id);
}
};
match retrieval_result {
StrategyOutput::Chunks(chunks) => {
RetrievalOutput::Chunks(chunks) => {
for chunk in chunks {
let id = chunk.chunk.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
push_id(chunk.chunk.id.clone());
}
}
StrategyOutput::Entities(entities) => {
RetrievalOutput::WithEntities { chunks, entities } => {
for chunk in chunks {
push_id(chunk.chunk.id.clone());
}
for entity in entities {
let id = entity.entity.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
}
StrategyOutput::Search(search) => {
for chunk in &search.chunks {
let id = chunk.chunk.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
for entity in &search.entities {
let id = entity.entity.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
push_id(entity.entity.id.clone());
}
}
}
+1 -1
View File
@@ -13,7 +13,7 @@ use crate::{
middlewares::{
auth_middleware::RequireUser,
response_middleware::{
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
template_as_response, TemplateResponse, TemplateResult, ResponseResult,
},
},
utils::text_content_preview::truncate_text_contents,
+5 -5
View File
@@ -32,7 +32,7 @@ use crate::{
middlewares::{
auth_middleware::RequireUser,
response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
},
},
utils::pagination::{paginate_items, Pagination},
@@ -284,9 +284,9 @@ pub async fn suggest_knowledge_relationships(
None => None,
};
let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion();
if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) =
retrieval_pipeline::retrieve_entities(
let config = retrieval_pipeline::RetrievalConfig::with_entities();
if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) =
retrieval_pipeline::retrieve(
&state.db,
&state.openai_client,
Some(&*state.embedding_provider),
@@ -297,7 +297,7 @@ pub async fn suggest_knowledge_relationships(
)
.await
{
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results {
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities {
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
break;
}
@@ -12,7 +12,7 @@ use crate::html_state::HtmlState;
use crate::middlewares::{
auth_middleware::RequireUser,
response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
},
};
use common::storage::types::{
+22 -21
View File
@@ -4,7 +4,7 @@ use axum::{
extract::{Query, State},
};
use common::storage::types::{text_content::TextContent, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity};
use serde::{de, Deserialize, Deserializer, Serialize};
use std::{fmt, str::FromStr};
@@ -108,35 +108,35 @@ async fn perform_search(
return Ok((Vec::new(), String::new()));
}
let config = RetrievalConfig::for_search(SearchTarget::Both);
let config = RetrievalConfig::with_entities();
let reranker_lease = match &state.reranker_pool {
Some(pool) => pool.checkout().await,
None => None,
};
let params = retrieval_pipeline::pipeline::StrategyParams {
db_client: &state.db,
openai_client: &state.openai_client,
embedding_provider: Some(&state.embedding_provider),
input_text: trimmed_query,
user_id: &user.id,
let result = retrieve(
&state.db,
&state.openai_client,
Some(&state.embedding_provider),
trimmed_query,
&user.id,
config,
reranker: reranker_lease,
};
let result = retrieval_pipeline::pipeline::execute(params).await?;
reranker_lease,
)
.await?;
let search_result = match result {
StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]),
let (chunks, entities) = match result {
RetrievalOutput::WithEntities { chunks, entities } => (chunks, entities),
RetrievalOutput::Chunks(chunks) => (chunks, Vec::new()),
};
let source_label_map = collect_source_label_map(state, user, &search_result).await?;
let source_label_map = collect_source_label_map(state, user, &chunks, &entities).await?;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
Vec::with_capacity(chunks.len().saturating_add(entities.len()));
for chunk_result in search_result.chunks {
for chunk_result in chunks {
let source_label = source_label_map
.get(&chunk_result.chunk.source_id)
.cloned()
@@ -155,7 +155,7 @@ async fn perform_search(
});
}
for entity_result in search_result.entities {
for entity_result in entities {
let source_label = source_label_map
.get(&entity_result.entity.source_id)
.cloned()
@@ -187,13 +187,14 @@ async fn perform_search(
async fn collect_source_label_map(
state: &HtmlState,
user: &User,
search_result: &SearchResult,
chunks: &[RetrievedChunk],
entities: &[RetrievedEntity],
) -> Result<std::collections::HashMap<String, String>, HtmlError> {
let mut source_ids = HashSet::new();
for chunk_result in &search_result.chunks {
for chunk_result in chunks {
source_ids.insert(chunk_result.chunk.source_id.clone());
}
for entity_result in &search_result.entities {
for entity_result in entities {
source_ids.insert(entity_result.entity.source_id.clone());
}
+17 -22
View File
@@ -183,10 +183,8 @@ impl PipelineServices for DefaultPipelineServices {
None => None,
};
let config = retrieval_pipeline::RetrievalConfig::for_search(
retrieval_pipeline::SearchTarget::EntitiesOnly,
);
match retrieval_pipeline::retrieve_entities(
let config = retrieval_pipeline::RetrievalConfig::with_entities();
match retrieval_pipeline::retrieve(
&self.db,
&self.openai_client,
Some(&*self.embedding_provider),
@@ -197,19 +195,16 @@ impl PipelineServices for DefaultPipelineServices {
)
.await
{
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
let chunk_count = search.chunks.len();
let entities = search.entities;
Ok(retrieval_pipeline::RetrievalOutput::WithEntities { chunks, entities }) => {
tracing::debug!(
chunk_count,
chunk_count = chunks.len(),
entity_count = entities.len(),
"ingestion search results returned entities"
"ingestion retrieval resolved entities from chunks"
);
Ok(entities)
}
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
"Ingestion retrieval should return entities".into(),
Ok(retrieval_pipeline::RetrievalOutput::Chunks(_)) => Err(AppError::InternalError(
"Ingestion retrieval should resolve entities".into(),
)),
Err(e) => Err(e),
}
@@ -372,16 +367,16 @@ mod tests {
fn system_prompt_from_request(
request: &async_openai::types::CreateChatCompletionRequest,
) -> String {
let ChatCompletionRequestMessage::System(system) = &request.messages[0] else {
panic!("expected first message to be system");
) -> anyhow::Result<String> {
let Some(ChatCompletionRequestMessage::System(system)) = request.messages.first() else {
anyhow::bail!("expected first message to be system");
};
match &system.content {
async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) => {
text.clone()
}
other => panic!("unexpected system message content: {other:?}"),
}
let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) =
&system.content
else {
anyhow::bail!("unexpected system message content: {:?}", system.content);
};
Ok(text.clone())
}
#[tokio::test]
@@ -425,7 +420,7 @@ mod tests {
.await
.context("prepare llm request")?;
assert_eq!(system_prompt_from_request(&request), SENTINEL);
assert_eq!(system_prompt_from_request(&request)?, SENTINEL);
Ok(())
}
}
@@ -125,10 +125,10 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
})
.await??;
for (idx, png) in captures.iter().enumerate() {
if let Err(err) = maybe_dump_debug_image(page_numbers[idx], png).await {
for (page_number, png) in page_numbers.iter().zip(captures.iter()) {
if let Err(err) = maybe_dump_debug_image(*page_number, png).await {
warn!(
page = page_numbers[idx],
page = page_number,
error = %err,
"Failed to write debug screenshot to disk"
);
+1
View File
@@ -95,6 +95,7 @@ pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<Shared
}
#[cfg(test)]
#[allow(dead_code)] // helpers are shared across binary test targets
pub(crate) mod tests {
use std::path::Path;
+4 -1
View File
@@ -11,8 +11,9 @@ use html_router::{
use super::SharedServices;
/// Builds the Minne API and HTML route subtrees without fixing the outer Axum state
/// type. SaaS consumers can merge additional routers and attach their own `AppState`
/// type. `SaaS` consumers can merge additional routers and attach their own `AppState`
/// as long as it implements `FromRef` for `ApiState` and `HtmlState`.
#[allow(dead_code)] // used by server/main binaries, not worker
pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S>
where
S: Clone + Send + Sync + 'static,
@@ -24,6 +25,7 @@ where
.merge(html_routes(html_state))
}
#[allow(dead_code)] // used by server/main binaries, not worker
pub fn build_api_state(services: &SharedServices) -> ApiState {
ApiState {
db: Arc::clone(&services.db),
@@ -32,6 +34,7 @@ pub fn build_api_state(services: &SharedServices) -> ApiState {
}
}
#[allow(dead_code)] // used by server/main binaries, not worker
pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> {
let session_store = Arc::new(
services
+1
View File
@@ -75,6 +75,7 @@ struct AppState {
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use super::*;
use axum::{
+4 -6
View File
@@ -10,16 +10,14 @@ workspace = true
[dependencies]
tokio = { workspace = true }
serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true }
surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
fastembed = { workspace = true }
common = { path = "../common", features = ["test-utils"] }
[dev-dependencies]
anyhow = { workspace = true }
uuid = { workspace = true }
+43 -53
View File
@@ -1,61 +1,66 @@
//! Chat answer assembly: retrieval context formatting and structured LLM request/response types.
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
},
};
use common::{
error::AppError,
storage::types::{
message::{format_history, Message},
system_settings::SystemSettings,
},
use common::storage::types::{
message::{format_history, Message},
system_settings::SystemSettings,
};
use serde::Deserialize;
use serde_json::Value;
use serde_json::{json, Value};
use super::answer_retrieval_helper::get_query_response_schema;
/// JSON schema describing the structured chat answer (answer text + references).
fn get_query_response_schema() -> Value {
json!({
"type": "object",
"properties": {
"answer": { "type": "string" },
"references": {
"type": "array",
"items": {
"type": "object",
"properties": {
"reference": { "type": "string" },
},
"required": ["reference"],
"additionalProperties": false,
}
}
},
"required": ["answer", "references"],
"additionalProperties": false
})
}
#[derive(Debug, Deserialize)]
pub struct Reference {
#[allow(dead_code)]
pub reference: String,
}
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
pub answer: String,
#[allow(dead_code)]
pub references: Vec<Reference>,
}
#[derive(Debug)]
pub struct Answer {
pub content: String,
pub references: Vec<String>,
}
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r"
Context Information:
==================
{entities_json}
User Question:
==================
{query}
"
)
}
/// Convert chunk-based retrieval results to JSON format for LLM context
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
impl LLMResponseFormat {
pub fn reference_ids(&self) -> Vec<String> {
self.references
.iter()
.map(|entry| entry.reference.clone())
.collect()
}
}
/// Convert chunk-based retrieval results to JSON format for LLM context.
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
use crate::round_score;
serde_json::json!(chunks
.iter()
@@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
}
pub fn create_user_message_with_history(
entities_json: &Value,
context_json: &Value,
history: &[Message],
query: &str,
) -> String {
@@ -89,7 +94,7 @@ pub fn create_user_message_with_history(
{}
",
format_history(history),
entities_json,
context_json,
query
)
}
@@ -116,18 +121,3 @@ pub fn create_chat_request(
.response_format(response_format)
.build()
}
pub fn process_llm_response(
response: &CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, Box<AppError>> {
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
.and_then(|content| {
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
})
})
}
@@ -1,23 +0,0 @@
use serde_json::{json, Value};
pub fn get_query_response_schema() -> Value {
json!({
"type": "object",
"properties": {
"answer": { "type": "string" },
"references": {
"type": "array",
"items": {
"type": "object",
"properties": {
"reference": { "type": "string" },
},
"required": ["reference"],
"additionalProperties": false,
}
}
},
"required": ["answer", "references"],
"additionalProperties": false
})
}
-228
View File
@@ -1,228 +0,0 @@
use std::collections::{HashMap, HashSet};
use surrealdb::{sql::Thing, Error};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
StoredObject,
},
};
/// Find entities related to the given entity via graph relationships.
///
/// Queries the `relates_to` edge table for all relationships involving the entity,
/// then fetches and returns the neighboring entities.
///
/// # Arguments
/// * `db` - Database client
/// * `entity_id` - ID of the entity to find neighbors for
/// * `user_id` - User ID for access control
/// * `limit` - Maximum number of neighbors to return
pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient,
entity_id: &str,
user_id: &str,
limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
let mut relationships_response = db
.query(
"
SELECT * FROM relates_to
WHERE metadata.user_id = $user_id
AND (in = type::thing('knowledge_entity', $entity_id)
OR out = type::thing('knowledge_entity', $entity_id))
",
)
.bind(("entity_id", entity_id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
if relationships.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_ids: Vec<String> = Vec::with_capacity(relationships.len());
let mut seen: HashSet<String> = HashSet::with_capacity(relationships.len());
for rel in relationships {
if rel.in_ == entity_id {
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
} else if rel.out == entity_id {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_);
}
} else {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_.clone());
}
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
}
}
neighbor_ids.retain(|id| id != entity_id);
if neighbor_ids.is_empty() {
return Ok(Vec::new());
}
if limit > 0 && neighbor_ids.len() > limit {
neighbor_ids.truncate(limit);
}
let thing_ids: Vec<Thing> = neighbor_ids
.iter()
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
.collect();
let mut neighbors_response = db
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", KnowledgeEntity::table_name().to_owned()))
.bind(("things", thing_ids))
.bind(("user_id", user_id.to_owned()))
.await?;
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
if neighbors.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
.into_iter()
.map(|entity| (entity.id.clone(), entity))
.collect();
let mut ordered = Vec::with_capacity(neighbor_ids.len());
for id in neighbor_ids {
if let Some(entity) = neighbor_map.remove(&id) {
ordered.push(entity);
}
if limit > 0 && ordered.len() >= limit {
break;
}
}
Ok(ordered)
}
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use uuid::Uuid;
#[tokio::test]
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
let central_entity = KnowledgeEntity::new(
"central_source".to_string(),
"Central Entity".to_string(),
"Central Description".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let related_entity1 = KnowledgeEntity::new(
"related_source1".to_string(),
"Related Entity 1".to_string(),
"Related Description 1".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let related_entity2 = KnowledgeEntity::new(
"related_source2".to_string(),
"Related Entity 2".to_string(),
"Related Description 2".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let unrelated_entity = KnowledgeEntity::new(
"unrelated_source".to_string(),
"Unrelated Entity".to_string(),
"Unrelated Description".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let central_entity = db
.store_item(central_entity.clone())
.await
.with_context(|| "Failed to store central entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
let related_entity1 = db
.store_item(related_entity1.clone())
.await
.with_context(|| "Failed to store related entity 1".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
let related_entity2 = db
.store_item(related_entity2.clone())
.await
.with_context(|| "Failed to store related entity 2".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
let _unrelated_entity = db
.store_item(unrelated_entity.clone())
.await
.with_context(|| "Failed to store unrelated entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
let source_id = "relationship_source".to_string();
let relationship1 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity1.id.clone(),
user_id.clone(),
source_id.clone(),
"references".to_string(),
);
let relationship2 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity2.id.clone(),
user_id.clone(),
source_id.clone(),
"contains".to_string(),
);
relationship1
.store_relationship(&db)
.await
.with_context(|| "Failed to store relationship 1".to_string())?;
relationship2
.store_relationship(&db)
.await
.with_context(|| "Failed to store relationship 2".to_string())?;
let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await
.with_context(|| "Failed to find entities by relationship".to_string())?;
assert!(
related_entities.len() >= 2,
"Should find related entities in both directions"
);
Ok(())
}
}
+63 -115
View File
@@ -1,10 +1,9 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod graph;
pub mod pipeline;
pub mod reranking;
pub mod scoring;
pub(crate) mod scoring;
use common::{
error::AppError,
@@ -16,39 +15,28 @@ use common::{
use reranking::RerankerLease;
use tracing::instrument;
// Strategy output variants - defined before pipeline module
/// Result of a retrieval run.
///
/// Chunk retrieval is always performed; entities are only present when the caller
/// requested entity resolution via [`RetrievalConfig::with_entities`].
#[derive(Debug)]
pub enum StrategyOutput {
Entities(Vec<RetrievedEntity>),
pub enum RetrievalOutput {
Chunks(Vec<RetrievedChunk>),
Search(SearchResult),
}
/// Unified search result containing both chunks and entities
#[derive(Debug, Clone)]
pub struct SearchResult {
pub chunks: Vec<RetrievedChunk>,
pub entities: Vec<RetrievedEntity>,
}
impl SearchResult {
pub fn new(chunks: Vec<RetrievedChunk>, entities: Vec<RetrievedEntity>) -> Self {
Self { chunks, entities }
}
pub fn is_empty(&self) -> bool {
self.chunks.is_empty() && self.entities.is_empty()
}
WithEntities {
chunks: Vec<RetrievedChunk>,
entities: Vec<RetrievedEntity>,
},
}
pub use pipeline::{
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind,
StageTimings,
};
// Backward-compatible type aliases for external consumers
pub type PipelineDiagnostics = Diagnostics;
pub type PipelineStageTimings = StageTimings;
/// Round a score to three decimal places for JSON output.
pub(crate) fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
@@ -57,7 +45,7 @@ pub struct RetrievedChunk {
pub score: f32,
}
// Final entity representation returned to callers, enriched with ranked chunks.
// Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
pub entity: KnowledgeEntity,
@@ -65,9 +53,9 @@ pub struct RetrievedEntity {
pub chunks: Vec<RetrievedChunk>,
}
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
pub async fn retrieve(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
@@ -75,8 +63,8 @@ pub async fn retrieve_entities(
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
let params = pipeline::StrategyParams {
) -> Result<RetrievalOutput, AppError> {
let params = pipeline::RetrievalParams {
db_client,
openai_client,
embedding_provider,
@@ -94,6 +82,7 @@ mod tests {
use anyhow::{self};
use async_openai::Client;
use common::storage::indexes::ensure_runtime;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::system_settings::SystemSettings;
use uuid::Uuid;
@@ -133,7 +122,7 @@ mod tests {
}
#[tokio::test]
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "test_user";
let chunk = TextChunk::new(
@@ -145,7 +134,7 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let openai_client = Client::new();
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
@@ -154,12 +143,13 @@ mod tests {
config: RetrievalConfig::default(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
RetrievalOutput::Chunks(items) => items,
RetrievalOutput::WithEntities { .. } => {
anyhow::bail!("expected chunk results, got entities")
}
};
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
@@ -171,8 +161,7 @@ mod tests {
}
#[tokio::test]
async fn test_default_strategy_returns_chunks_from_multiple_sources(
) -> anyhow::Result<()> {
async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "multi_source_user";
@@ -191,7 +180,7 @@ mod tests {
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
let openai_client = Client::new();
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
@@ -200,12 +189,13 @@ mod tests {
config: RetrievalConfig::default(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
RetrievalOutput::Chunks(items) => items,
RetrievalOutput::WithEntities { .. } => {
anyhow::bail!("expected chunk results, got entities")
}
};
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
@@ -223,96 +213,54 @@ mod tests {
}
#[tokio::test]
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "chunk_user";
let chunk_one = TextChunk::new(
"src_alpha".into(),
"Tokio tasks execute on worker threads managed by the runtime.".into(),
user_id.into(),
);
let chunk_two = TextChunk::new(
"src_beta".into(),
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
user_id.into(),
);
let user_id = "entity_user";
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new();
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "tokio runtime worker behavior",
user_id,
config,
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
};
assert!(
!chunks.is_empty(),
"Revised strategy should return chunk-only responses"
);
assert!(
chunks
.iter()
.any(|entry| entry.chunk.chunk.contains("Tokio")),
"Chunk results should contain relevant snippets"
);
Ok(())
}
#[tokio::test]
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "search_user";
let chunk = TextChunk::new(
"search_src".into(),
"Async Rust programming uses Tokio runtime for concurrent tasks.".into(),
"entity_source".into(),
"Async Rust programming uses the Tokio runtime for concurrent tasks.".into(),
user_id.into(),
);
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
let entity = KnowledgeEntity::new(
"entity_source".into(),
"Tokio Runtime".into(),
"Async runtime for Rust".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
db.store_item(entity).await?;
let openai_client = Client::new();
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "async rust programming",
user_id,
config,
config: RetrievalConfig::with_entities(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
let StrategyOutput::Search(search_result) = results else {
anyhow::bail!("expected Search output");
let RetrievalOutput::WithEntities { chunks, entities } = results else {
anyhow::bail!("expected WithEntities output");
};
// Should return chunks (entities may be empty if none stored)
assert!(!chunks.is_empty(), "Should return chunks");
assert!(
!search_result.chunks.is_empty(),
"Search strategy should return chunks"
entities.iter().any(|e| e.entity.name == "Tokio Runtime"),
"Should resolve the entity owning the retrieved chunk"
);
assert!(
search_result
.chunks
entities
.iter()
.any(|c| c.chunk.chunk.contains("Tokio")),
"Search results should contain relevant chunks"
.find(|e| e.entity.name == "Tokio Runtime")
.is_some_and(|e| !e.chunks.is_empty()),
"Resolved entity should carry its contributing chunks"
);
Ok(())
}
+25 -128
View File
@@ -1,22 +1,5 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::scoring::FusionWeights;
pub use common::utils::config::RetrievalStrategy;
/// Configures which result types to include in Search strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SearchTarget {
/// Return only text chunks
ChunksOnly,
/// Return only knowledge entities
EntitiesOnly,
/// Return both chunks and entities (default)
#[default]
Both,
}
/// Two-variant flag that serializes as a bool for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BoolFlag {
@@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag {
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RetrievalTuningFlags {
pub rerank_scores_only: BoolFlag,
pub normalize_vector_scores: BoolFlag,
pub normalize_fts_scores: BoolFlag,
pub chunk_rrf_use_vector: BoolFlag,
pub chunk_rrf_use_fts: BoolFlag,
}
impl RetrievalTuningFlags {
pub const fn rerank_scores_only(&self) -> bool {
pub const fn rerank_scores_only(self) -> bool {
self.rerank_scores_only.as_bool()
}
pub const fn normalize_vector_scores(&self) -> bool {
self.normalize_vector_scores.as_bool()
}
pub const fn normalize_fts_scores(&self) -> bool {
self.normalize_fts_scores.as_bool()
}
pub const fn chunk_rrf_use_vector(&self) -> bool {
pub const fn chunk_rrf_use_vector(self) -> bool {
self.chunk_rrf_use_vector.as_bool()
}
pub const fn chunk_rrf_use_fts(&self) -> bool {
pub const fn chunk_rrf_use_fts(self) -> bool {
self.chunk_rrf_use_fts.as_bool()
}
}
@@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags {
fn default() -> Self {
Self {
rerank_scores_only: BoolFlag::Disabled,
normalize_vector_scores: BoolFlag::Disabled,
normalize_fts_scores: BoolFlag::Enabled,
chunk_rrf_use_vector: BoolFlag::Enabled,
chunk_rrf_use_fts: BoolFlag::Enabled,
}
}
}
/// Tunable parameters that govern each retrieval stage.
/// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
pub entity_vector_take: usize,
/// Number of vector candidates to pull from the chunk embedding index.
pub chunk_vector_take: usize,
pub entity_fts_take: usize,
/// Number of full-text candidates to pull from the chunk index.
pub chunk_fts_take: usize,
pub score_threshold: f32,
pub fallback_min_results: usize,
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
/// Maximum chunks attached to each resolved entity.
pub max_chunks_per_entity: usize,
pub lexical_match_weight: f32,
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
/// Blend weight applied when mixing reranker scores with fused scores.
pub rerank_blend_weight: f32,
pub flags: RetrievalTuningFlags,
/// Keep top-N candidates after reranking.
pub rerank_keep_top: usize,
/// Maximum number of chunks returned to callers.
pub chunk_result_cap: usize,
/// Optional fusion weights for hybrid search. If None, uses default weights.
pub fusion_weights: Option<FusionWeights>,
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
#[serde(default = "default_chunk_rrf_k")]
/// Reciprocal rank fusion k value for chunk merging.
pub chunk_rrf_k: f32,
/// Weight applied to vector ranks in RRF.
#[serde(default = "default_chunk_rrf_vector_weight")]
pub chunk_rrf_vector_weight: f32,
/// Weight applied to chunk FTS ranks in RRF.
#[serde(default = "default_chunk_rrf_fts_weight")]
pub chunk_rrf_fts_weight: f32,
pub flags: RetrievalTuningFlags,
}
impl Default for RetrievalTuning {
fn default() -> Self {
Self {
entity_vector_take: 15,
chunk_vector_take: 20,
entity_fts_take: 10,
chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 10000,
avg_chars_per_token: 4,
max_chunks_per_entity: 4,
lexical_match_weight: 0.15,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65,
flags: RetrievalTuningFlags::default(),
rerank_keep_top: 8,
chunk_result_cap: 5,
fusion_weights: None,
chunk_rrf_k: default_chunk_rrf_k(),
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
chunk_rrf_fts_weight: 1.0,
flags: RetrievalTuningFlags::default(),
}
}
}
/// Wrapper containing tuning plus future flags for per-request overrides.
/// Per-request retrieval configuration.
///
/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities`
/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved
/// chunks (search, ingestion linking, relationship suggestion).
#[derive(Debug, Clone, Default)]
pub struct RetrievalConfig {
pub strategy: RetrievalStrategy,
pub tuning: RetrievalTuning,
/// Target for Search strategy (chunks, entities, or both)
pub search_target: SearchTarget,
pub resolve_entities: bool,
}
impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self {
/// Chunk retrieval that also resolves the owning knowledge entities.
pub fn with_entities() -> Self {
Self {
strategy: RetrievalStrategy::Default,
tuning,
search_target: SearchTarget::default(),
}
}
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
Self {
strategy,
tuning: RetrievalTuning::default(),
search_target: SearchTarget::default(),
}
}
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
Self {
strategy,
tuning,
search_target: SearchTarget::default(),
}
}
/// Create config for chat retrieval with strategy selection support
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
Self::with_strategy(strategy)
}
/// Create config for relationship suggestion (entity-only retrieval)
pub fn for_relationship_suggestion() -> Self {
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
}
/// Create config for ingestion pipeline (entity-only retrieval)
pub fn for_ingestion() -> Self {
Self::with_strategy(RetrievalStrategy::Ingestion)
}
/// Create config for unified search (chunks and/or entities)
pub fn for_search(target: SearchTarget) -> Self {
Self {
strategy: RetrievalStrategy::Search,
tuning: RetrievalTuning::default(),
search_target: target,
resolve_entities: true,
}
}
}
const fn default_chunk_rrf_k() -> f32 {
60.0
}
const fn default_chunk_rrf_vector_weight() -> f32 {
1.0
}
const fn default_chunk_rrf_fts_weight() -> f32 {
1.0
}
+107
View File
@@ -0,0 +1,107 @@
use async_openai::Client;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
utils::embedding::EmbeddingProvider,
};
use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity};
use super::{
config::RetrievalConfig,
diagnostics::{AssembleStats, Diagnostics, SearchStats},
StageKind, StageTimings, RetrievalParams,
};
/// Mutable working state threaded through every retrieval stage.
pub(crate) struct PipelineContext<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a EmbeddingProvider>,
pub input_text: String,
pub user_id: String,
pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<Diagnostics>,
pub entity_results: Vec<RetrievedEntity>,
pub chunk_results: Vec<RetrievedChunk>,
stage_timings: StageTimings,
}
impl<'a> PipelineContext<'a> {
pub fn new(params: RetrievalParams<'a>) -> Self {
Self {
db_client: params.db_client,
openai_client: params.openai_client,
embedding_provider: params.embedding_provider,
input_text: params.input_text.to_owned(),
user_id: params.user_id.to_owned(),
config: params.config,
query_embedding: None,
chunk_values: Vec::new(),
reranker: params.reranker,
diagnostics: None,
entity_results: Vec::new(),
chunk_results: Vec::new(),
stage_timings: StageTimings::default(),
}
}
pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec<f32>) -> Self {
let mut ctx = Self::new(params);
ctx.query_embedding = Some(query_embedding);
ctx
}
pub(crate) fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
self.query_embedding.as_ref().ok_or_else(|| {
Box::new(AppError::InternalError(
"query embedding missing before candidate search".to_string(),
))
})
}
pub fn enable_diagnostics(&mut self) {
if self.diagnostics.is_none() {
self.diagnostics = Some(Diagnostics::default());
}
}
pub fn diagnostics_enabled(&self) -> bool {
self.diagnostics.is_some()
}
pub(crate) fn record_search(&mut self, stats: SearchStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.search = Some(stats);
}
}
pub(crate) fn record_assemble(&mut self, stats: AssembleStats) {
if let Some(diag) = self.diagnostics.as_mut() {
diag.assemble = Some(stats);
}
}
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
self.diagnostics.take()
}
pub fn take_stage_timings(&mut self) -> StageTimings {
std::mem::take(&mut self.stage_timings)
}
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
self.stage_timings.record(kind, duration);
}
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
std::mem::take(&mut self.entity_results)
}
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
std::mem::take(&mut self.chunk_results)
}
}
+3 -33
View File
@@ -1,51 +1,21 @@
use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
/// Captures instrumentation for the retrieval stages when diagnostics are enabled.
#[derive(Debug, Clone, Default, Serialize)]
pub struct Diagnostics {
pub collect_candidates: Option<CollectCandidatesStats>,
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
pub search: Option<SearchStats>,
pub assemble: Option<AssembleStats>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats {
pub vector_entity_candidates: usize,
pub struct SearchStats {
pub vector_chunk_candidates: usize,
pub fts_entity_candidates: usize,
pub fts_chunk_candidates: usize,
pub vector_chunk_scores: Vec<f32>,
pub fts_chunk_scores: Vec<f32>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
pub filtered_entity_count: usize,
pub fallback_min_results: usize,
pub chunk_sources_considered: usize,
pub chunk_candidates_before_enrichment: usize,
pub chunk_candidates_after_enrichment: usize,
pub top_chunk_scores: Vec<f32>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats {
pub token_budget_start: usize,
pub token_budget_spent: usize,
pub token_budget_remaining: usize,
pub budget_exhausted: bool,
pub chunks_selected: usize,
pub chunks_skipped_due_budget: usize,
pub entity_count: usize,
pub entity_traces: Vec<EntityAssemblyTrace>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
pub entity_id: String,
pub source_id: String,
pub inspected_candidates: usize,
pub selected_chunk_ids: Vec<String>,
pub selected_chunk_scores: Vec<f32>,
pub skipped_due_budget: usize,
}
+119 -209
View File
@@ -1,61 +1,68 @@
mod config;
mod context;
mod diagnostics;
mod stages;
mod strategies;
pub use config::{
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
};
pub use config::RetrievalConfig;
pub use diagnostics::Diagnostics;
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
use crate::{round_score, RetrievalOutput, RetrievedEntity};
use async_openai::Client;
use async_trait::async_trait;
use common::{error::AppError, storage::db::SurrealDbClient};
use std::time::{Duration, Instant};
use tracing::info;
use stages::PipelineContext;
use strategies::{
DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver,
use stages::{
ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage,
};
// Export StrategyOutput publicly from this module
// (it's defined in lib.rs but we re-export it here)
// Stage type enum
/// Identifies a retrieval stage for timing and instrumentation.
///
/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation
/// harness) iterate it generically so that adding a stage requires no changes outside this
/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StageKind {
Embed,
CollectCandidates,
GraphExpansion,
ChunkAttach,
Search,
Rerank,
ResolveEntities,
Assemble,
}
// Pipeline stage trait
impl StageKind {
/// Every stage kind in canonical pipeline order.
pub const ALL: [StageKind; 5] = [
StageKind::Embed,
StageKind::Search,
StageKind::Rerank,
StageKind::ResolveEntities,
StageKind::Assemble,
];
/// Stable, machine-friendly identifier for the stage (used as a metrics key).
pub const fn label(self) -> &'static str {
match self {
StageKind::Embed => "embed",
StageKind::Search => "search",
StageKind::Rerank => "rerank",
StageKind::ResolveEntities => "resolve_entities",
StageKind::Assemble => "assemble",
}
}
}
/// A single composable step in the retrieval pipeline.
#[async_trait]
pub trait Stage: Send + Sync {
pub(crate) trait Stage: Send + Sync {
fn kind(&self) -> StageKind;
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>;
}
// Type alias for boxed stages
pub type BoxedStage = Box<dyn Stage>;
pub(crate) type BoxedStage = Box<dyn Stage>;
// Strategy driver trait
#[async_trait]
pub trait StrategyDriver: Send + Sync {
type Output;
fn stages(&self) -> Vec<BoxedStage>;
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
}
// Pipeline stage timings tracker
/// Per-stage execution timings recorded during a run.
#[derive(Debug, Default, Clone)]
pub struct StageTimings {
timings: Vec<(StageKind, Duration)>,
@@ -66,41 +73,13 @@ impl StageTimings {
self.timings.push((kind, duration));
}
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
self.timings
}
// Helper methods to get duration for each stage type (for backward compatibility)
fn get_stage_ms(&self, kind: StageKind) -> u128 {
/// Milliseconds recorded for `kind`, or `0` if the stage did not run.
pub fn stage_ms(&self, kind: StageKind) -> u128 {
self.timings
.iter()
.find(|(k, _)| *k == kind)
.map_or(0, |(_, d)| d.as_millis())
}
pub fn embed_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Embed)
}
pub fn collect_candidates_ms(&self) -> u128 {
self.get_stage_ms(StageKind::CollectCandidates)
}
pub fn graph_expansion_ms(&self) -> u128 {
self.get_stage_ms(StageKind::GraphExpansion)
}
pub fn chunk_attach_ms(&self) -> u128 {
self.get_stage_ms(StageKind::ChunkAttach)
}
pub fn rerank_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Rerank)
}
pub fn assemble_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Assemble)
}
}
pub struct RunOutput<T> {
@@ -109,7 +88,35 @@ pub struct RunOutput<T> {
pub stage_timings: StageTimings,
}
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
/// Inputs required to run a retrieval.
pub struct RetrievalParams<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
pub input_text: &'a str,
pub user_id: &'a str,
pub config: RetrievalConfig,
pub reranker: Option<crate::reranking::RerankerLease>,
}
fn build_stages(config: &RetrievalConfig) -> Vec<BoxedStage> {
let mut stages: Vec<BoxedStage> = vec![
Box::new(EmbedStage),
Box::new(ChunkSearchStage),
Box::new(ChunkRerankStage),
];
if config.resolve_entities {
stages.push(Box::new(ResolveEntitiesStage));
}
stages.push(Box::new(ChunkAssembleStage));
stages
}
async fn run(
params: RetrievalParams<'_>,
query_embedding: Option<Vec<f32>>,
capture_diagnostics: bool,
) -> Result<RunOutput<RetrievalOutput>, AppError> {
let input_chars = params.input_text.chars().count();
let input_preview: String = params.input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
@@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppEr
input_chars,
preview_truncated = input_chars > preview_len,
preview = %input_preview_clean,
strategy = %params.config.strategy,
resolve_entities = params.config.resolve_entities,
"Starting retrieval pipeline"
);
let strategy = params.config.strategy;
let search_target = params.config.search_target;
let resolve_entities = params.config.resolve_entities;
let mut ctx = match query_embedding {
Some(embedding) => context::PipelineContext::with_embedding(params, embedding),
None => context::PipelineContext::new(params),
};
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Search(run.results))
}
if capture_diagnostics {
ctx.enable_diagnostics();
}
for stage in build_stages(&ctx.config) {
let start = Instant::now();
stage.execute(&mut ctx).await?;
ctx.record_stage_duration(stage.kind(), start.elapsed());
}
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let chunks = ctx.take_chunk_results();
let results = if resolve_entities {
RetrievalOutput::WithEntities {
chunks,
entities: ctx.take_entity_results(),
}
} else {
RetrievalOutput::Chunks(chunks)
};
Ok(RunOutput {
results,
diagnostics,
stage_timings,
})
}
pub async fn run_pipeline_with_embedding(
params: StrategyParams<'_>,
query_embedding: Vec<f32>,
) -> Result<StrategyOutput, AppError> {
let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Search(run.results))
}
}
/// Run the retrieval pipeline, generating the query embedding internally if needed.
pub async fn execute(params: RetrievalParams<'_>) -> Result<RetrievalOutput, AppError> {
Ok(run(params, None, false).await?.results)
}
pub async fn run_pipeline_with_embedding_with_metrics(
params: StrategyParams<'_>,
/// Run the retrieval pipeline with a pre-computed query embedding.
pub async fn run_with_embedding(
params: RetrievalParams<'_>,
query_embedding: Vec<f32>,
) -> Result<RunOutput<StrategyOutput>, AppError> {
let strategy = params.config.strategy;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
_ => Err(AppError::InternalError(
"Metrics not supported for this strategy".into(),
)),
}
) -> Result<RetrievalOutput, AppError> {
Ok(run(params, Some(query_embedding), false).await?.results)
}
pub async fn run_pipeline_with_embedding_with_diagnostics(
params: StrategyParams<'_>,
/// Run with a pre-computed embedding, returning results and per-stage timings.
///
/// When `capture_diagnostics` is true, pipeline search/assemble stats are included.
pub async fn run_with_embedding_instrumented(
params: RetrievalParams<'_>,
query_embedding: Vec<f32>,
) -> Result<RunOutput<StrategyOutput>, AppError> {
let strategy = params.config.strategy;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
_ => Err(AppError::InternalError(
"Diagnostics not supported for this strategy".into(),
)),
}
capture_diagnostics: bool,
) -> Result<RunOutput<RetrievalOutput>, AppError> {
run(params, Some(query_embedding), capture_diagnostics).await
}
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
@@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
})
.collect::<Vec<_>>())
}
pub struct StrategyParams<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
pub input_text: &'a str,
pub user_id: &'a str,
pub config: RetrievalConfig,
pub reranker: Option<RerankerLease>,
}
async fn execute_strategy<D: StrategyDriver>(
driver: D,
params: StrategyParams<'_>,
query_embedding: Option<Vec<f32>>,
capture_diagnostics: bool,
) -> Result<RunOutput<D::Output>, AppError> {
let ctx = match query_embedding {
Some(embedding) => PipelineContext::with_embedding(params, embedding),
None => PipelineContext::new(params),
};
run_with_driver(driver, ctx, capture_diagnostics).await
}
async fn run_with_driver<D: StrategyDriver>(
driver: D,
mut ctx: PipelineContext<'_>,
capture_diagnostics: bool,
) -> Result<RunOutput<D::Output>, AppError> {
if capture_diagnostics {
ctx.enable_diagnostics();
}
for stage in driver.stages() {
let start = Instant::now();
stage.execute(&mut ctx).await?;
ctx.record_stage_duration(stage.kind(), start.elapsed());
}
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
Ok(RunOutput {
results,
diagnostics,
stage_timings,
})
}
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}
+424
View File
@@ -0,0 +1,424 @@
use async_trait::async_trait;
use common::{
error::AppError,
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
utils::embedding::generate_embedding,
};
use fastembed::RerankResult;
use std::collections::HashMap;
use tracing::{debug, instrument, warn};
use crate::{
scoring::{
clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored,
},
RetrievedChunk, RetrievedEntity,
};
use super::{
config::RetrievalTuning,
context::PipelineContext,
diagnostics::{AssembleStats, SearchStats},
Stage, StageKind,
};
#[derive(Debug, Clone, Copy)]
pub struct EmbedStage;
#[async_trait]
impl Stage for EmbedStage {
fn kind(&self) -> StageKind {
StageKind::Embed
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
embed(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkSearchStage;
#[async_trait]
impl Stage for ChunkSearchStage {
fn kind(&self) -> StageKind {
StageKind::Search
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
search_chunks(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkRerankStage;
#[async_trait]
impl Stage for ChunkRerankStage {
fn kind(&self) -> StageKind {
StageKind::Rerank
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
rerank_chunks(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ResolveEntitiesStage;
#[async_trait]
impl Stage for ResolveEntitiesStage {
fn kind(&self) -> StageKind {
StageKind::ResolveEntities
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
resolve_entities(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkAssembleStage;
#[async_trait]
impl Stage for ChunkAssembleStage {
fn kind(&self) -> StageKind {
StageKind::Assemble
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
assemble_chunks(ctx)
}
}
#[instrument(level = "trace", skip_all)]
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.query_embedding.is_some() {
debug!("Reusing cached query embedding for hybrid retrieval");
} else {
debug!("Generating query embedding for hybrid retrieval");
let embedding = if let Some(provider) = ctx.embedding_provider {
provider.embed(&ctx.input_text).await?
} else {
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
};
ctx.query_embedding = Some(embedding);
}
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting chunk candidates via vector and FTS search");
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning;
let fts_take = tuning.chunk_fts_take;
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
let (vector_rows, fts_rows) = tokio::try_join!(
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
&ctx.user_id,
),
async {
if fts_enabled {
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
} else {
Ok(Vec::new())
}
}
)?;
let vector_candidates = vector_rows.len();
let fts_candidates = fts_rows.len();
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.collect();
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
.collect();
let mut fts_weight = tuning.chunk_rrf_fts_weight;
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
// For very short keyword queries, lean more on lexical ranking.
fts_weight *= 1.5;
}
let rrf_config = RrfConfig {
k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight,
use_vector: tuning.flags.chunk_rrf_use_vector(),
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
};
let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
debug!(
total_merged = chunks.len(),
vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(),
fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(),
both_signals = chunks
.iter()
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
.count(),
rrf_k = rrf_config.k,
"Merged chunk candidates with RRF"
);
if ctx.diagnostics_enabled() {
ctx.record_search(SearchStats {
vector_chunk_candidates: vector_candidates,
fts_chunk_candidates: fts_candidates,
vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)),
fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
});
}
ctx.chunk_values = chunks;
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.len() <= 1 {
return Ok(());
}
let Some(reranker) = ctx.reranker.as_ref() else {
debug!("No reranker lease provided; skipping chunk rerank stage");
return Ok(());
};
let documents =
build_chunk_rerank_documents(&ctx.chunk_values, ctx.config.tuning.rerank_keep_top.max(1));
if documents.len() <= 1 {
debug!("Skipping chunk reranking stage; insufficient chunk documents");
return Ok(());
}
match reranker.rerank(&ctx.input_text, documents).await {
Ok(results) if !results.is_empty() => {
apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results);
}
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
Err(err) => warn!(
error = %err,
"Chunk reranking failed; continuing with original ordering"
),
}
Ok(())
}
/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks.
///
/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped
/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk,
/// and the contributing chunks are attached.
#[instrument(level = "trace", skip_all)]
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.is_empty() {
return Ok(());
}
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
let mut source_order: Vec<String> = Vec::new();
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
let mut best_score: HashMap<String, f32> = HashMap::new();
for scored in &ctx.chunk_values {
let source = scored.item.source_id.clone();
let attached = chunks_by_source.entry(source.clone()).or_default();
if attached.is_empty() {
source_order.push(source.clone());
best_score.insert(source.clone(), scored.fused);
}
if attached.len() < max_chunks {
attached.push(RetrievedChunk {
chunk: scored.item.clone(),
score: scored.fused,
});
}
}
let entities =
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
let mut entities_by_source: HashMap<String, Vec<KnowledgeEntity>> = HashMap::new();
for entity in entities {
entities_by_source
.entry(entity.source_id.clone())
.or_default()
.push(entity);
}
let mut results = Vec::new();
for source in &source_order {
let Some(entities) = entities_by_source.remove(source) else {
continue;
};
let score = best_score.get(source).copied().unwrap_or(0.0);
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
for entity in entities {
results.push(RetrievedEntity {
entity,
score,
chunks: chunks.clone(),
});
}
}
debug!(
sources = source_order.len(),
entities = results.len(),
"Resolved entities from retrieved chunks"
);
ctx.entity_results = results;
Ok(())
}
#[instrument(level = "trace", skip_all)]
#[allow(clippy::result_large_err)]
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling chunk retrieval results");
let mut chunk_values = std::mem::take(&mut ctx.chunk_values);
// Limit how many chunks we return to keep context size reasonable.
let limit = ctx
.config
.tuning
.chunk_result_cap
.max(1)
.min(ctx.config.tuning.chunk_vector_take.max(1));
if chunk_values.len() > limit {
chunk_values.truncate(limit);
}
ctx.chunk_results = chunk_values
.into_iter()
.map(|chunk| RetrievedChunk {
chunk: chunk.item,
score: chunk.fused,
})
.collect();
if ctx.diagnostics_enabled() {
ctx.record_assemble(AssembleStats {
chunks_selected: ctx.chunk_results.len(),
});
}
Ok(())
}
const SCORE_SAMPLE_LIMIT: usize = 8;
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where
F: FnMut(&Scored<T>) -> f32,
{
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
}
fn normalize_fts_query(input: &str) -> (String, usize) {
const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];
let mut cleaned = String::with_capacity(input.len());
for ch in input.chars() {
if ch.is_alphanumeric() {
cleaned.extend(ch.to_lowercase());
} else if ch.is_whitespace() {
cleaned.push(' ');
}
}
let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3));
for token in cleaned.split_whitespace() {
if !STOPWORDS.contains(&token) && !token.is_empty() {
tokens.push(token.to_string());
}
}
let normalized = tokens.join(" ");
(normalized, tokens.len())
}
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
chunks
.iter()
.take(max_chunks)
.map(|chunk| {
format!(
"Source: {}\nChunk:\n{}",
chunk.item.source_id,
chunk.item.chunk.trim()
)
})
.collect()
}
fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>,
tuning: &RetrievalTuning,
results: Vec<RerankResult>,
) {
if results.is_empty() || chunks.is_empty() {
return;
}
let mut remaining: Vec<Option<Scored<TextChunk>>> =
std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = tuning.flags.rerank_scores_only();
let blend = if use_only {
1.0
} else {
clamp_unit(tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
let original = candidate.fused;
let blended = if use_only {
clamp_unit(normalized)
} else {
clamp_unit(original * (1.0 - blend) + normalized * blend)
};
candidate.update_fused(blended);
reranked.push(candidate);
}
} else {
warn!(
result_index = result.index,
"Chunk reranker returned out-of-range index; skipping"
);
}
if reranked.len() == remaining.len() {
break;
}
}
reranked.extend(remaining.into_iter().flatten());
let keep_top = tuning.rerank_keep_top;
if keep_top > 0 && reranked.len() > keep_top {
reranked.truncate(keep_top);
}
*chunks = reranked;
}
File diff suppressed because it is too large Load Diff
@@ -1,148 +0,0 @@
use super::{
stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
},
BoxedStage, StrategyDriver,
};
use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError;
pub struct DefaultStrategyDriver;
impl DefaultStrategyDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for DefaultStrategyDriver {
type Output = Vec<RetrievedChunk>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_chunk_results())
}
}
pub struct RelationshipSuggestionDriver;
impl RelationshipSuggestionDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for RelationshipSuggestionDriver {
type Output = Vec<RetrievedEntity>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
// Skip ChunkAttachStage
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
pub struct IngestionDriver;
impl IngestionDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for IngestionDriver {
type Output = Vec<RetrievedEntity>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
// Skip ChunkAttachStage
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
use super::config::SearchTarget;
use crate::SearchResult;
/// Search strategy driver that retrieves both chunks and entities
pub struct SearchStrategyDriver {
target: SearchTarget,
}
impl SearchStrategyDriver {
pub fn new(target: SearchTarget) -> Self {
Self { target }
}
}
impl StrategyDriver for SearchStrategyDriver {
type Output = SearchResult;
fn stages(&self) -> Vec<BoxedStage> {
match self.target {
SearchTarget::ChunksOnly => vec![
Box::new(EmbedStage),
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
],
SearchTarget::EntitiesOnly => vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
],
SearchTarget::Both => vec![
Box::new(EmbedStage),
// Chunk retrieval path
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
// Entity retrieval path (runs after chunk stages)
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
],
}
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
let chunks = match self.target {
SearchTarget::EntitiesOnly => Vec::new(),
_ => ctx.take_chunk_results(),
};
let entities = match self.target {
SearchTarget::ChunksOnly => Vec::new(),
_ => ctx.take_entity_results(),
};
Ok(SearchResult::new(chunks, entities))
}
}
@@ -97,8 +97,7 @@ impl RerankerPool {
fn default_pool_size() -> usize {
available_parallelism()
.map(|value| value.get().min(2))
.unwrap_or(2)
.map_or(2, |value| value.get().min(2))
.max(1)
}
@@ -156,6 +155,7 @@ pub struct RerankerLease {
}
impl RerankerLease {
#[allow(clippy::result_large_err)]
pub async fn rerank(
&self,
query: &str,
@@ -165,7 +165,9 @@ impl RerankerLease {
let engine = Arc::clone(&self.engine);
tokio::task::spawn_blocking(move || {
let mut guard = engine.lock().expect("reranker engine mutex poisoned");
let mut guard = engine.lock().map_err(|_| {
AppError::InternalError("reranker engine mutex poisoned".into())
})?;
guard
.rerank(query, documents, false, None)
.map_err(|e| AppError::InternalError(e.to_string()))
+7 -120
View File
@@ -1,14 +1,12 @@
use std::{cmp::Ordering, collections::HashMap};
use common::storage::types::StoredObject;
use serde::{Deserialize, Serialize};
/// Holds optional subscores gathered from different retrieval signals.
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
pub fts: Option<f32>,
pub vector: Option<f32>,
pub graph: Option<f32>,
}
/// Generic wrapper combining an item with its accumulated retrieval scores.
@@ -40,40 +38,11 @@ impl<T> Scored<T> {
self
}
#[must_use]
pub const fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score);
self
}
pub const fn update_fused(&mut self, fused: f32) {
self.fused = fused;
}
}
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct FusionWeights {
pub vector: f32,
pub fts: f32,
pub graph: f32,
pub multi_bonus: f32,
}
impl Default for FusionWeights {
fn default() -> Self {
// Default weights favor vector search, which typically performs better
// FTS is used as a complement when there's good overlap
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
Self {
vector: 0.8,
fts: 0.2,
graph: 0.2,
multi_bonus: 0.3, // Increased to boost chunks with both signals
}
}
}
/// Configuration for reciprocal rank fusion.
#[derive(Debug, Clone, Copy)]
pub struct RrfConfig {
@@ -84,29 +53,10 @@ pub struct RrfConfig {
pub use_fts: bool,
}
impl Default for RrfConfig {
fn default() -> Self {
Self {
k: 60.0,
vector_weight: 1.0,
fts_weight: 1.0,
use_vector: true,
use_fts: true,
}
}
}
pub const fn clamp_unit(value: f32) -> f32 {
value.clamp(0.0, 1.0)
}
pub fn distance_to_similarity(distance: f32) -> f32 {
if !distance.is_finite() {
return 0.0;
}
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
}
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() {
return Vec::new();
@@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
.collect()
}
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
let vector = scores.vector.unwrap_or(0.0);
let fts = scores.fts.unwrap_or(0.0);
let graph = scores.graph.unwrap_or(0.0);
let mut fused = graph.mul_add(
weights.graph,
vector.mul_add(weights.vector, fts * weights.fts),
);
let signals_present = scores
.vector
.iter()
.chain(scores.fts.iter())
.chain(scores.graph.iter())
.count();
// Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
if signals_present >= 2 {
// For chunks with both vector and FTS, give a significant boost
// This helps identify the "golden chunk" that appears in both searches
if scores.vector.is_some() && scores.fts.is_some() {
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
fused *= 1.0 + weights.multi_bonus;
} else {
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
fused += weights.multi_bonus;
}
}
clamp_unit(fused)
}
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
target: &mut std::collections::HashMap<String, Scored<T>, S>,
incoming: Vec<Scored<T>>,
) where
T: StoredObject + Clone,
{
for scored in incoming {
let id = scored.item.id().to_owned();
target
.entry(id)
.and_modify(|existing| {
if let Some(score) = scored.scores.vector {
existing.scores.vector = Some(score);
}
if let Some(score) = scored.scores.fts {
existing.scores.fts = Some(score);
}
if let Some(score) = scored.scores.graph {
existing.scores.graph = Some(score);
}
})
.or_insert_with(|| Scored {
item: scored.item.clone(),
scores: scored.scores,
fused: scored.fused,
});
}
}
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
T: StoredObject,
@@ -222,6 +109,10 @@ where
});
}
/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion.
///
/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text
/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score.
pub fn reciprocal_rank_fusion<T>(
mut vector_ranked: Vec<Scored<T>>,
mut fts_ranked: Vec<Scored<T>>,
@@ -266,9 +157,7 @@ where
}
}
entry.item = candidate.item;
let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
entry.fused += vector_weight / (k + rank_f32 + 1.0);
}
}
@@ -296,9 +185,7 @@ where
}
}
entry.item = candidate.item;
let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
entry.fused += fts_weight / (k + rank_f32 + 1.0);
}
}