mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-31 03:40:38 +02:00
chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.
Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
Generated
-4
@@ -5426,14 +5426,10 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"common",
|
||||
"fastembed",
|
||||
"futures",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"surrealdb",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid",
|
||||
|
||||
@@ -164,6 +164,35 @@ impl KnowledgeEntity {
|
||||
.take(0)?)
|
||||
}
|
||||
|
||||
/// Fetch all knowledge entities owned by any of the provided source ids for a user.
|
||||
///
|
||||
/// Used by retrieval to resolve the entities that own a set of retrieved chunks.
|
||||
pub async fn find_by_source_ids(
|
||||
db: &SurrealDbClient,
|
||||
source_ids: &[String],
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
if source_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM type::table($table) \
|
||||
WHERE source_id IN $sources AND user_id = $user_id",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("sources", source_ids.to_vec()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.take(0)
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
|
||||
+11
-119
@@ -1,8 +1,7 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::{env, fmt, str::FromStr, sync::Once};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, str::FromStr, sync::Once};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
@@ -36,83 +35,6 @@ impl EmbeddingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Error returned when parsing a retrieval strategy name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown retrieval strategy '{input}'")]
|
||||
pub struct ParseRetrievalStrategyError {
|
||||
/// The unrecognized input string.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Selects which retrieval pipeline strategy to run for chat and search.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
/// Primary hybrid chunk retrieval for search/chat.
|
||||
#[default]
|
||||
Default,
|
||||
/// Entity retrieval for suggesting relationships when creating manual entities.
|
||||
RelationshipSuggestion,
|
||||
/// Entity retrieval for context during content ingestion.
|
||||
Ingestion,
|
||||
/// Unified search returning both chunks and entities.
|
||||
Search,
|
||||
}
|
||||
|
||||
impl RetrievalStrategy {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Default => "default",
|
||||
Self::RelationshipSuggestion => "relationship_suggestion",
|
||||
Self::Ingestion => "ingestion",
|
||||
Self::Search => "search",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for RetrievalStrategy {
|
||||
type Err = ParseRetrievalStrategyError;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"default" => Ok(Self::Default),
|
||||
"initial" | "revised" => {
|
||||
warn!(
|
||||
"retrieval strategy '{value}' is deprecated; use 'default' instead"
|
||||
);
|
||||
Ok(Self::Default)
|
||||
}
|
||||
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
||||
"ingestion" => Ok(Self::Ingestion),
|
||||
"search" => Ok(Self::Search),
|
||||
other => Err(ParseRetrievalStrategyError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RetrievalStrategy {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_optional_retrieval_strategy<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Option<RetrievalStrategy>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Option::<String>::deserialize(deserializer)?;
|
||||
match value {
|
||||
None => Ok(None),
|
||||
Some(raw) if raw.trim().is_empty() => Ok(None),
|
||||
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for EmbeddingBackend {
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
@@ -195,8 +117,6 @@ pub struct AppConfig {
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
|
||||
pub retrieval_strategy: Option<RetrievalStrategy>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
#[serde(default = "default_ingest_max_body_bytes")]
|
||||
@@ -282,14 +202,6 @@ pub fn ensure_ort_path() {
|
||||
});
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
|
||||
#[must_use]
|
||||
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
|
||||
self.retrieval_strategy.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -312,7 +224,6 @@ impl Default for AppConfig {
|
||||
fastembed_cache_dir: None,
|
||||
fastembed_show_download_progress: None,
|
||||
fastembed_max_length: None,
|
||||
retrieval_strategy: None,
|
||||
embedding_backend: EmbeddingBackend::default(),
|
||||
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
||||
ingest_max_files: default_ingest_max_files(),
|
||||
@@ -340,41 +251,22 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use super::{ParseRetrievalStrategyError, RetrievalStrategy};
|
||||
use super::EmbeddingBackend;
|
||||
|
||||
#[test]
|
||||
fn retrieval_strategy_defaults_to_default() {
|
||||
assert_eq!(
|
||||
RetrievalStrategy::default(),
|
||||
RetrievalStrategy::Default
|
||||
);
|
||||
fn embedding_backend_defaults_to_fastembed() {
|
||||
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retrieval_strategy_serializes_snake_case() {
|
||||
fn embedding_backend_parses_aliases() {
|
||||
assert_eq!(
|
||||
serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
|
||||
"\"search\""
|
||||
"openai".parse::<EmbeddingBackend>().expect("openai"),
|
||||
EmbeddingBackend::OpenAI
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
|
||||
assert_eq!(
|
||||
"initial".parse::<RetrievalStrategy>().expect("initial"),
|
||||
RetrievalStrategy::Default
|
||||
);
|
||||
assert!(matches!(
|
||||
"unknown".parse::<RetrievalStrategy>(),
|
||||
Err(ParseRetrievalStrategyError { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
|
||||
let config = super::AppConfig::default();
|
||||
assert_eq!(
|
||||
config.resolved_retrieval_strategy(),
|
||||
RetrievalStrategy::Default
|
||||
"fast".parse::<EmbeddingBackend>().expect("fast"),
|
||||
EmbeddingBackend::FastEmbed
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
|
||||
| `RUST_LOG` | Logging level | `info` |
|
||||
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
|
||||
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
|
||||
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
|
||||
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
|
||||
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
||||
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
||||
|
||||
+6
-5
@@ -27,13 +27,14 @@ The D3-based graph visualization shows entities as nodes and relationships as ed
|
||||
|
||||
## Hybrid Retrieval
|
||||
|
||||
Minne combines multiple retrieval strategies:
|
||||
Minne uses chunk-first hybrid retrieval over the knowledge base:
|
||||
|
||||
- **Vector similarity** — Semantic matching via embeddings
|
||||
- **Full-text search** — Keyword matching with BM25
|
||||
- **Graph traversal** — Following relationships between entities
|
||||
- **Vector similarity** — Semantic matching via embeddings over text chunks
|
||||
- **Full-text search** — Keyword matching with BM25 over the same chunk index
|
||||
|
||||
Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance.
|
||||
The two ranked candidate lists are merged with Reciprocal Rank Fusion (RRF). When a caller needs knowledge entities (search, ingestion linking, relationship suggestion), entities are derived from the top retrieved chunks grouped by `source_id`.
|
||||
|
||||
Optional **reranking** can rescore the fused chunk list with a cross-encoder model; see below.
|
||||
|
||||
## Reranking (Optional)
|
||||
|
||||
|
||||
+7
-18
@@ -5,7 +5,6 @@ use std::{
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use clap::{Args, Parser, ValueEnum};
|
||||
use retrieval_pipeline::RetrievalStrategy;
|
||||
|
||||
use crate::datasets::DatasetKind;
|
||||
|
||||
@@ -55,10 +54,6 @@ pub struct RetrievalSettings {
|
||||
#[arg(long)]
|
||||
pub chunk_fts_take: Option<usize>,
|
||||
|
||||
/// Override average characters per token used for budgeting
|
||||
#[arg(long)]
|
||||
pub chunk_avg_chars_per_token: Option<usize>,
|
||||
|
||||
/// Override maximum chunks attached per entity
|
||||
#[arg(long)]
|
||||
pub max_chunks_per_entity: Option<usize>,
|
||||
@@ -71,41 +66,37 @@ pub struct RetrievalSettings {
|
||||
#[arg(long, default_value_t = 4)]
|
||||
pub rerank_pool_size: usize,
|
||||
|
||||
/// Keep top-N entities after reranking
|
||||
/// Keep top-N chunks after reranking
|
||||
#[arg(long, default_value_t = 10)]
|
||||
pub rerank_keep_top: usize,
|
||||
|
||||
/// Cap the number of chunks returned by retrieval (revised strategy)
|
||||
/// Cap the number of chunks returned by retrieval
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub chunk_result_cap: usize,
|
||||
|
||||
/// Reciprocal rank fusion k value for revised chunk merging
|
||||
/// Reciprocal rank fusion k value for chunk merging
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_k: Option<f32>,
|
||||
|
||||
/// Weight for vector ranks in revised RRF
|
||||
/// Weight for vector ranks in RRF
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_vector_weight: Option<f32>,
|
||||
|
||||
/// Weight for chunk FTS ranks in revised RRF
|
||||
/// Weight for chunk FTS ranks in RRF
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_fts_weight: Option<f32>,
|
||||
|
||||
/// Include vector ranks in revised RRF (default: true)
|
||||
/// Include vector ranks in RRF (default: true)
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_use_vector: Option<bool>,
|
||||
|
||||
/// Include chunk FTS ranks in revised RRF (default: true)
|
||||
/// Include chunk FTS ranks in RRF (default: true)
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_use_fts: Option<bool>,
|
||||
|
||||
/// Require verified chunks (disable with --llm-mode)
|
||||
#[arg(skip = true)]
|
||||
pub require_verified_chunks: bool,
|
||||
|
||||
/// Select the retrieval pipeline strategy
|
||||
#[arg(long, default_value_t = RetrievalStrategy::Default)]
|
||||
pub strategy: RetrievalStrategy,
|
||||
}
|
||||
|
||||
impl Default for RetrievalSettings {
|
||||
@@ -113,7 +104,6 @@ impl Default for RetrievalSettings {
|
||||
Self {
|
||||
chunk_vector_take: None,
|
||||
chunk_fts_take: None,
|
||||
chunk_avg_chars_per_token: None,
|
||||
max_chunks_per_entity: None,
|
||||
rerank: false,
|
||||
rerank_pool_size: 4,
|
||||
@@ -125,7 +115,6 @@ impl Default for RetrievalSettings {
|
||||
chunk_rrf_use_vector: None,
|
||||
chunk_rrf_use_fts: None,
|
||||
require_verified_chunks: true,
|
||||
strategy: RetrievalStrategy::Default,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+18
-20
@@ -51,8 +51,8 @@ pub fn mirror_perf_outputs(
|
||||
pub fn print_console_summary(record: &EvaluationReport) {
|
||||
let perf = &record.performance;
|
||||
println!(
|
||||
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})",
|
||||
record.retrieval.strategy,
|
||||
"[perf] resolve_entities={} | concurrency={} | rerank={} (pool {:?}, keep {})",
|
||||
record.retrieval.resolve_entities,
|
||||
record.retrieval.concurrency,
|
||||
record.retrieval.rerank_enabled,
|
||||
record.retrieval.rerank_pool_size,
|
||||
@@ -63,16 +63,14 @@ pub fn print_console_summary(record: &EvaluationReport) {
|
||||
perf.ingestion_ms,
|
||||
format_duration(perf.namespace_seed_ms),
|
||||
);
|
||||
let stage = &perf.stage_latency;
|
||||
println!(
|
||||
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
|
||||
stage.embed.avg,
|
||||
stage.collect_candidates.avg,
|
||||
stage.graph_expansion.avg,
|
||||
stage.chunk_attach.avg,
|
||||
stage.rerank.avg,
|
||||
stage.assemble.avg,
|
||||
);
|
||||
let stage_summary = perf
|
||||
.stage_latency
|
||||
.stages
|
||||
.iter()
|
||||
.map(|s| format!("{} {:.1}", s.stage, s.stats.avg))
|
||||
.collect::<Vec<_>>()
|
||||
.join(" | ");
|
||||
println!("[perf] stage avg ms → {stage_summary}");
|
||||
let eval = &perf.evaluation_stages_ms;
|
||||
println!(
|
||||
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
|
||||
@@ -107,12 +105,13 @@ mod tests {
|
||||
|
||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
||||
crate::eval::StageLatencyBreakdown {
|
||||
embed: sample_latency(),
|
||||
collect_candidates: sample_latency(),
|
||||
graph_expansion: sample_latency(),
|
||||
chunk_attach: sample_latency(),
|
||||
rerank: sample_latency(),
|
||||
assemble: sample_latency(),
|
||||
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
|
||||
.into_iter()
|
||||
.map(|stage| crate::eval::StageLatency {
|
||||
stage: stage.to_string(),
|
||||
stats: sample_latency(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,7 +192,7 @@ mod tests {
|
||||
rerank_keep_top: 10,
|
||||
concurrency: 2,
|
||||
detailed_report: false,
|
||||
retrieval_strategy: "initial".into(),
|
||||
resolve_entities: false,
|
||||
chunk_result_cap: 5,
|
||||
chunk_rrf_k: 60.0,
|
||||
chunk_rrf_vector_weight: 1.0,
|
||||
@@ -206,7 +205,6 @@ mod tests {
|
||||
ingest_chunk_overlap_tokens: 50,
|
||||
chunk_vector_take: 20,
|
||||
chunk_fts_take: 20,
|
||||
chunk_avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
cases: Vec::new(),
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use futures::stream::{self, StreamExt};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::eval::{
|
||||
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
CaseSummary, RetrievedSummary,
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
@@ -51,14 +51,8 @@ pub(crate) async fn run_queries(
|
||||
None
|
||||
};
|
||||
|
||||
let mut retrieval_config = RetrievalConfig {
|
||||
strategy: config.retrieval.strategy,
|
||||
..Default::default()
|
||||
};
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
|
||||
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
|
||||
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
|
||||
}
|
||||
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
|
||||
if let Some(value) = config.retrieval.chunk_vector_take {
|
||||
retrieval_config.tuning.chunk_vector_take = value;
|
||||
@@ -81,9 +75,6 @@ pub(crate) async fn run_queries(
|
||||
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
|
||||
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
|
||||
}
|
||||
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
|
||||
retrieval_config.tuning.avg_chars_per_token = value;
|
||||
}
|
||||
if let Some(value) = config.retrieval.max_chunks_per_entity {
|
||||
retrieval_config.tuning.max_chunks_per_entity = value;
|
||||
}
|
||||
@@ -187,7 +178,7 @@ pub(crate) async fn run_queries(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: Some(&embedding_provider),
|
||||
@@ -196,26 +187,19 @@ pub(crate) async fn run_queries(
|
||||
config: (*retrieval_config).clone(),
|
||||
reranker,
|
||||
};
|
||||
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
|
||||
let (result_output, pipeline_diagnostics, stage_timings) = {
|
||||
let outcome = pipeline::run_with_embedding_instrumented(
|
||||
params,
|
||||
query_embedding,
|
||||
diagnostics_enabled,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("running pipeline for question {question_id}"))?;
|
||||
(outcome.results, outcome.diagnostics, outcome.stage_timings)
|
||||
} else {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
|
||||
params,
|
||||
query_embedding,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("running pipeline for question {question_id}"))?;
|
||||
(outcome.results, None, outcome.stage_timings)
|
||||
};
|
||||
let query_latency = query_start.elapsed().as_millis();
|
||||
|
||||
let candidates = adapt_strategy_output(result_output);
|
||||
let candidates = adapt_retrieval_output(result_output);
|
||||
let mut retrieved = Vec::new();
|
||||
let mut match_rank = None;
|
||||
let answers_lower: Vec<String> =
|
||||
|
||||
@@ -201,7 +201,10 @@ pub(crate) async fn summarize(
|
||||
rerank_keep_top: config.retrieval.rerank_keep_top,
|
||||
concurrency: config.concurrency.max(1),
|
||||
detailed_report: config.detailed_report,
|
||||
retrieval_strategy: config.retrieval.strategy.to_string(),
|
||||
resolve_entities: ctx
|
||||
.retrieval_config
|
||||
.as_ref()
|
||||
.is_some_and(|config| config.resolve_entities),
|
||||
chunk_result_cap: config.retrieval.chunk_result_cap,
|
||||
chunk_rrf_k: active_tuning.chunk_rrf_k,
|
||||
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
|
||||
@@ -214,7 +217,6 @@ pub(crate) async fn summarize(
|
||||
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
|
||||
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
||||
cases: summaries,
|
||||
});
|
||||
|
||||
+54
-49
@@ -85,7 +85,7 @@ pub struct RetrievalSection {
|
||||
pub average_ndcg: f64,
|
||||
pub latency: LatencyStats,
|
||||
pub concurrency: usize,
|
||||
pub strategy: String,
|
||||
pub resolve_entities: bool,
|
||||
pub rerank_enabled: bool,
|
||||
pub rerank_pool_size: Option<usize>,
|
||||
pub rerank_keep_top: usize,
|
||||
@@ -226,7 +226,7 @@ impl EvaluationReport {
|
||||
average_ndcg: summary.average_ndcg,
|
||||
latency: summary.latency_ms.clone(),
|
||||
concurrency: summary.concurrency,
|
||||
strategy: summary.retrieval_strategy.clone(),
|
||||
resolve_entities: summary.resolve_entities,
|
||||
rerank_enabled: summary.rerank_enabled,
|
||||
rerank_pool_size: summary.rerank_pool_size,
|
||||
rerank_keep_top: summary.rerank_keep_top,
|
||||
@@ -463,7 +463,12 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
||||
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
|
||||
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
|
||||
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap();
|
||||
write!(md, "| Strategy | `{}` |\\n", report.retrieval.strategy).unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Resolve entities | {} |\\n",
|
||||
bool_badge(report.retrieval.resolve_entities)
|
||||
)
|
||||
.unwrap();
|
||||
write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap();
|
||||
if report.retrieval.rerank_enabled {
|
||||
let pool = report
|
||||
@@ -510,28 +515,9 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
||||
|
||||
md.push_str("\\n## Retrieval Stage Timings\\n\\n");
|
||||
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
|
||||
write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed);
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Collect Candidates",
|
||||
&report.performance.stage_latency.collect_candidates,
|
||||
);
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Graph Expansion",
|
||||
&report.performance.stage_latency.graph_expansion,
|
||||
);
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Chunk Attach",
|
||||
&report.performance.stage_latency.chunk_attach,
|
||||
);
|
||||
write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank);
|
||||
write_stage_row(
|
||||
&mut md,
|
||||
"Assemble",
|
||||
&report.performance.stage_latency.assemble,
|
||||
);
|
||||
for stage in &report.performance.stage_latency.stages {
|
||||
write_stage_row(&mut md, &prettify_stage(&stage.stage), &stage.stats);
|
||||
}
|
||||
|
||||
if report.misses.is_empty() {
|
||||
if report.detailed_report {
|
||||
@@ -623,6 +609,20 @@ fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Turn a stable stage label (e.g. `resolve_entities`) into a display title (`Resolve Entities`).
|
||||
fn prettify_stage(label: &str) -> String {
|
||||
label
|
||||
.split('_')
|
||||
.map(|word| {
|
||||
let mut chars = word.chars();
|
||||
chars.next().map_or_else(String::new, |first| {
|
||||
first.to_uppercase().collect::<String>() + chars.as_str()
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
fn bool_badge(value: bool) -> &'static str {
|
||||
if value {
|
||||
"✅"
|
||||
@@ -740,17 +740,6 @@ struct LegacyHistoryDelta {
|
||||
latency_avg_ms: f64,
|
||||
}
|
||||
|
||||
fn default_stage_latency() -> StageLatencyBreakdown {
|
||||
StageLatencyBreakdown {
|
||||
embed: LatencyStats::default(),
|
||||
collect_candidates: LatencyStats::default(),
|
||||
graph_expansion: LatencyStats::default(),
|
||||
chunk_attach: LatencyStats::default(),
|
||||
rerank: LatencyStats::default(),
|
||||
assemble: LatencyStats::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
||||
let overview = OverviewSection {
|
||||
@@ -807,7 +796,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
||||
average_ndcg: entry.average_ndcg,
|
||||
latency: entry.latency_ms,
|
||||
concurrency: 0,
|
||||
strategy: "unknown".into(),
|
||||
resolve_entities: false,
|
||||
rerank_enabled: entry.rerank_enabled,
|
||||
rerank_pool_size: entry.rerank_pool_size,
|
||||
rerank_keep_top: entry.rerank_keep_top,
|
||||
@@ -840,7 +829,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
||||
ingestion_ms: entry.ingestion_ms,
|
||||
namespace_seed_ms: entry.namespace_seed_ms,
|
||||
evaluation_stages_ms: EvaluationStageTimings::default(),
|
||||
stage_latency: default_stage_latency(),
|
||||
stage_latency: StageLatencyBreakdown::default(),
|
||||
namespace_reused: false,
|
||||
ingestion_reused: entry.ingestion_reused,
|
||||
embeddings_reused: entry.ingestion_embeddings_reused,
|
||||
@@ -915,7 +904,8 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::eval::{
|
||||
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatencyBreakdown,
|
||||
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
|
||||
StageLatencyBreakdown,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tempfile::tempdir;
|
||||
@@ -931,12 +921,28 @@ mod tests {
|
||||
|
||||
fn sample_stage_latency() -> StageLatencyBreakdown {
|
||||
StageLatencyBreakdown {
|
||||
embed: latency(9.0),
|
||||
collect_candidates: latency(10.0),
|
||||
graph_expansion: latency(11.0),
|
||||
chunk_attach: latency(12.0),
|
||||
rerank: latency(13.0),
|
||||
assemble: latency(14.0),
|
||||
stages: vec![
|
||||
StageLatency {
|
||||
stage: "embed".to_string(),
|
||||
stats: latency(9.0),
|
||||
},
|
||||
StageLatency {
|
||||
stage: "search".to_string(),
|
||||
stats: latency(10.0),
|
||||
},
|
||||
StageLatency {
|
||||
stage: "rerank".to_string(),
|
||||
stats: latency(13.0),
|
||||
},
|
||||
StageLatency {
|
||||
stage: "resolve_entities".to_string(),
|
||||
stats: latency(11.0),
|
||||
},
|
||||
StageLatency {
|
||||
stage: "assemble".to_string(),
|
||||
stats: latency(14.0),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1058,7 +1064,7 @@ mod tests {
|
||||
rerank_keep_top: 5,
|
||||
concurrency: 2,
|
||||
detailed_report: true,
|
||||
retrieval_strategy: "initial".into(),
|
||||
resolve_entities: false,
|
||||
chunk_result_cap: 5,
|
||||
chunk_rrf_k: 60.0,
|
||||
chunk_rrf_vector_weight: 1.0,
|
||||
@@ -1071,7 +1077,6 @@ mod tests {
|
||||
ingest_chunks_only: false,
|
||||
chunk_vector_take: 50,
|
||||
chunk_fts_take: 50,
|
||||
chunk_avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
cases,
|
||||
}
|
||||
@@ -1097,7 +1102,7 @@ mod tests {
|
||||
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
#[test]
|
||||
fn evaluations_history_captures_strategy_and_concurrency() {
|
||||
fn evaluations_history_captures_resolve_entities_and_concurrency() {
|
||||
let tmp = tempdir().unwrap();
|
||||
let summary = sample_summary(false);
|
||||
|
||||
@@ -1109,7 +1114,7 @@ mod tests {
|
||||
assert_eq!(entries.len(), 1);
|
||||
let stored = &entries[0];
|
||||
assert_eq!(stored.retrieval.concurrency, summary.concurrency);
|
||||
assert_eq!(stored.retrieval.strategy, summary.retrieval_strategy);
|
||||
assert_eq!(stored.retrieval.resolve_entities, summary.resolve_entities);
|
||||
assert_eq!(
|
||||
stored.performance.evaluation_stages_ms.run_queries_ms,
|
||||
summary.perf.evaluation_stage_ms.run_queries_ms
|
||||
|
||||
+33
-38
@@ -3,7 +3,7 @@ use std::collections::HashSet;
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::storage::types::StoredObject;
|
||||
use retrieval_pipeline::{
|
||||
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
|
||||
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use unicode_normalization::UnicodeNormalization;
|
||||
@@ -69,7 +69,7 @@ pub struct EvaluationSummary {
|
||||
pub rerank_keep_top: usize,
|
||||
pub concurrency: usize,
|
||||
pub detailed_report: bool,
|
||||
pub retrieval_strategy: String,
|
||||
pub resolve_entities: bool,
|
||||
pub chunk_result_cap: usize,
|
||||
pub chunk_rrf_k: f32,
|
||||
pub chunk_rrf_vector_weight: f32,
|
||||
@@ -82,7 +82,6 @@ pub struct EvaluationSummary {
|
||||
pub ingest_chunk_overlap_tokens: usize,
|
||||
pub chunk_vector_take: usize,
|
||||
pub chunk_fts_take: usize,
|
||||
pub chunk_avg_chars_per_token: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub cases: Vec<CaseSummary>,
|
||||
}
|
||||
@@ -129,14 +128,20 @@ impl Default for LatencyStats {
|
||||
}
|
||||
}
|
||||
|
||||
/// Latency statistics for a single retrieval stage, keyed by the stage's stable label.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StageLatency {
|
||||
pub stage: String,
|
||||
pub stats: LatencyStats,
|
||||
}
|
||||
|
||||
/// Per-stage retrieval latency, in canonical pipeline order.
|
||||
///
|
||||
/// The set of stages is driven entirely by [`StageKind::ALL`], so adding a retrieval stage
|
||||
/// surfaces here automatically without changes to the evaluation harness.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct StageLatencyBreakdown {
|
||||
pub embed: LatencyStats,
|
||||
pub collect_candidates: LatencyStats,
|
||||
pub graph_expansion: LatencyStats,
|
||||
pub chunk_attach: LatencyStats,
|
||||
pub rerank: LatencyStats,
|
||||
pub assemble: LatencyStats,
|
||||
pub stages: Vec<StageLatency>,
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_field_names)]
|
||||
@@ -232,13 +237,12 @@ fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidat
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
|
||||
pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidate> {
|
||||
match output {
|
||||
StrategyOutput::Entities(entities) => candidates_from_entities(entities),
|
||||
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
|
||||
StrategyOutput::Search(search_result) => {
|
||||
let mut candidates = candidates_from_entities(search_result.entities);
|
||||
candidates.extend(candidates_from_chunks(search_result.chunks));
|
||||
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
|
||||
RetrievalOutput::WithEntities { chunks, entities } => {
|
||||
let mut candidates = candidates_from_entities(entities);
|
||||
candidates.extend(candidates_from_chunks(chunks));
|
||||
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||
candidates
|
||||
}
|
||||
@@ -262,7 +266,7 @@ pub struct CaseDiagnostics {
|
||||
pub attached_chunk_ids: Vec<String>,
|
||||
pub retrieved: Vec<EntityDiagnostics>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pipeline: Option<PipelineDiagnostics>,
|
||||
pub pipeline: Option<Diagnostics>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -366,28 +370,19 @@ pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
|
||||
LatencyStats { avg, p50, p95 }
|
||||
}
|
||||
|
||||
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
|
||||
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
|
||||
where
|
||||
F: Fn(&PipelineStageTimings) -> u128,
|
||||
{
|
||||
samples.iter().map(selector).collect()
|
||||
}
|
||||
pub fn build_stage_latency_breakdown(samples: &[StageTimings]) -> StageLatencyBreakdown {
|
||||
let stages = StageKind::ALL
|
||||
.iter()
|
||||
.map(|kind| {
|
||||
let latencies: Vec<u128> = samples.iter().map(|s| s.stage_ms(*kind)).collect();
|
||||
StageLatency {
|
||||
stage: kind.label().to_string(),
|
||||
stats: compute_latency_stats(&latencies),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
StageLatencyBreakdown {
|
||||
embed: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::embed_ms)),
|
||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.collect_candidates_ms()
|
||||
})),
|
||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.graph_expansion_ms()
|
||||
})),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.chunk_attach_ms()
|
||||
})),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::rerank_ms)),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::assemble_ms)),
|
||||
}
|
||||
StageLatencyBreakdown { stages }
|
||||
}
|
||||
|
||||
#[allow(
|
||||
@@ -412,7 +407,7 @@ pub fn build_case_diagnostics(
|
||||
expected_chunk_ids: &[String],
|
||||
answers_lower: &[String],
|
||||
candidates: &[EvaluationCandidate],
|
||||
pipeline_stats: Option<PipelineDiagnostics>,
|
||||
pipeline_stats: Option<Diagnostics>,
|
||||
) -> CaseDiagnostics {
|
||||
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect();
|
||||
let mut seen_chunks: HashSet<String> = HashSet::new();
|
||||
|
||||
@@ -44,7 +44,6 @@
|
||||
--leading-snug: 1.375;
|
||||
--leading-relaxed: 1.625;
|
||||
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
||||
--ease-in-out: cubic-bezier(0.4, 0, 0.2, 1);
|
||||
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
--default-transition-duration: 150ms;
|
||||
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
||||
@@ -285,37 +284,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.drawer-open {
|
||||
> .drawer-side {
|
||||
overflow-y: auto;
|
||||
}
|
||||
> .drawer-toggle {
|
||||
display: none;
|
||||
& ~ .drawer-side {
|
||||
pointer-events: auto;
|
||||
visibility: visible;
|
||||
position: sticky;
|
||||
display: block;
|
||||
width: auto;
|
||||
overscroll-behavior: auto;
|
||||
opacity: 100%;
|
||||
& > .drawer-overlay {
|
||||
cursor: default;
|
||||
background-color: transparent;
|
||||
}
|
||||
& > *:not(.drawer-overlay) {
|
||||
translate: 0%;
|
||||
[dir="rtl"] & {
|
||||
translate: 0%;
|
||||
}
|
||||
}
|
||||
}
|
||||
&:checked ~ .drawer-side {
|
||||
pointer-events: auto;
|
||||
visibility: visible;
|
||||
}
|
||||
}
|
||||
}
|
||||
.drawer-toggle {
|
||||
position: fixed;
|
||||
height: calc(0.25rem * 0);
|
||||
@@ -1074,22 +1042,6 @@
|
||||
grid-row-start: 1;
|
||||
min-width: calc(0.25rem * 0);
|
||||
}
|
||||
.chat-image {
|
||||
grid-row: span 2 / span 2;
|
||||
align-self: flex-end;
|
||||
}
|
||||
.chat-footer {
|
||||
grid-row-start: 3;
|
||||
display: flex;
|
||||
gap: calc(0.25rem * 1);
|
||||
font-size: 0.6875rem;
|
||||
}
|
||||
.chat-header {
|
||||
grid-row-start: 1;
|
||||
display: flex;
|
||||
gap: calc(0.25rem * 1);
|
||||
font-size: 0.6875rem;
|
||||
}
|
||||
.container {
|
||||
width: 100%;
|
||||
@media (width >= 40rem) {
|
||||
@@ -1796,9 +1748,6 @@
|
||||
.w-10 {
|
||||
width: calc(var(--spacing) * 10);
|
||||
}
|
||||
.w-11 {
|
||||
width: calc(var(--spacing) * 11);
|
||||
}
|
||||
.w-11\/12 {
|
||||
width: calc(11/12 * 100%);
|
||||
}
|
||||
@@ -1862,9 +1811,6 @@
|
||||
.flex-none {
|
||||
flex: none;
|
||||
}
|
||||
.flex-shrink {
|
||||
flex-shrink: 1;
|
||||
}
|
||||
.flex-shrink-0 {
|
||||
flex-shrink: 0;
|
||||
}
|
||||
@@ -1877,13 +1823,6 @@
|
||||
.grow {
|
||||
flex-grow: 1;
|
||||
}
|
||||
.border-collapse {
|
||||
border-collapse: collapse;
|
||||
}
|
||||
.-translate-y-1 {
|
||||
--tw-translate-y: calc(var(--spacing) * -1);
|
||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||
}
|
||||
.-translate-y-1\/2 {
|
||||
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
|
||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||
@@ -1956,9 +1895,6 @@
|
||||
.justify-start {
|
||||
justify-content: flex-start;
|
||||
}
|
||||
.gap-0 {
|
||||
gap: calc(var(--spacing) * 0);
|
||||
}
|
||||
.gap-0\.5 {
|
||||
gap: calc(var(--spacing) * 0.5);
|
||||
}
|
||||
@@ -2115,9 +2051,6 @@
|
||||
.bg-transparent {
|
||||
background-color: transparent;
|
||||
}
|
||||
.bg-warning {
|
||||
background-color: var(--color-warning);
|
||||
}
|
||||
.bg-warning\/10 {
|
||||
background-color: var(--color-warning);
|
||||
@supports (color: color-mix(in lab, red, red)) {
|
||||
@@ -2136,9 +2069,6 @@
|
||||
.loading-spinner {
|
||||
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
|
||||
}
|
||||
.mask-repeat {
|
||||
mask-repeat: repeat;
|
||||
}
|
||||
.fill-current {
|
||||
fill: currentcolor;
|
||||
}
|
||||
@@ -2169,9 +2099,6 @@
|
||||
.p-8 {
|
||||
padding: calc(var(--spacing) * 8);
|
||||
}
|
||||
.px-1 {
|
||||
padding-inline: calc(var(--spacing) * 1);
|
||||
}
|
||||
.px-1\.5 {
|
||||
padding-inline: calc(var(--spacing) * 1.5);
|
||||
}
|
||||
@@ -2326,9 +2253,6 @@
|
||||
--tw-tracking: var(--tracking-widest);
|
||||
letter-spacing: var(--tracking-widest);
|
||||
}
|
||||
.text-wrap {
|
||||
text-wrap: wrap;
|
||||
}
|
||||
.break-words {
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
@@ -2395,17 +2319,6 @@
|
||||
.italic {
|
||||
font-style: italic;
|
||||
}
|
||||
.underline {
|
||||
text-decoration-line: underline;
|
||||
}
|
||||
.swap-active {
|
||||
.swap-off {
|
||||
opacity: 0%;
|
||||
}
|
||||
.swap-on {
|
||||
opacity: 100%;
|
||||
}
|
||||
}
|
||||
.opacity-0 {
|
||||
opacity: 0%;
|
||||
}
|
||||
@@ -2496,10 +2409,6 @@
|
||||
--tw-duration: 300ms;
|
||||
transition-duration: 300ms;
|
||||
}
|
||||
.ease-in-out {
|
||||
--tw-ease: var(--ease-in-out);
|
||||
transition-timing-function: var(--ease-in-out);
|
||||
}
|
||||
.ease-out {
|
||||
--tw-ease: var(--ease-out);
|
||||
transition-timing-function: var(--ease-out);
|
||||
|
||||
@@ -2,10 +2,7 @@ use common::storage::types::conversation::SidebarConversation;
|
||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||
use common::{
|
||||
create_template_engine, storage::db::ProvidesDb,
|
||||
utils::config::{AppConfig, RetrievalStrategy},
|
||||
};
|
||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
@@ -75,10 +72,6 @@ impl HtmlState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
|
||||
self.config.resolved_retrieval_strategy()
|
||||
}
|
||||
|
||||
pub async fn get_cached_conversation_archive(
|
||||
&self,
|
||||
user_id: &str,
|
||||
|
||||
@@ -16,12 +16,9 @@ use futures::{
|
||||
};
|
||||
use json_stream_parser::JsonStreamParser;
|
||||
use minijinja::Value;
|
||||
use retrieval_pipeline::{
|
||||
answer_retrieval::{
|
||||
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
|
||||
LLMResponseFormat,
|
||||
},
|
||||
retrieved_entities_to_json,
|
||||
use retrieval_pipeline::answer_retrieval::{
|
||||
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
|
||||
LLMResponseFormat,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
@@ -189,11 +186,7 @@ struct ReferenceData {
|
||||
}
|
||||
|
||||
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||
response
|
||||
.references
|
||||
.iter()
|
||||
.map(|reference| reference.reference.clone())
|
||||
.collect()
|
||||
response.reference_ids()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
@@ -362,10 +355,9 @@ async fn prepare_chat_request(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let strategy = state.retrieval_strategy();
|
||||
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
|
||||
let config = retrieval_pipeline::RetrievalConfig::default();
|
||||
|
||||
let retrieval_result = match retrieval_pipeline::retrieve_entities(
|
||||
let retrieval_result = match retrieval_pipeline::retrieve(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&*state.embedding_provider),
|
||||
@@ -387,12 +379,9 @@ async fn prepare_chat_request(
|
||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||
|
||||
let context_json = match retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||
retrieved_entities_to_json(&entities)
|
||||
}
|
||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||
chunks_to_chat_context(&search_result.chunks)
|
||||
retrieval_pipeline::RetrievalOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
retrieval_pipeline::RetrievalOutput::WithEntities { chunks, .. } => {
|
||||
chunks_to_chat_context(&chunks)
|
||||
}
|
||||
};
|
||||
let formatted_user_message =
|
||||
|
||||
@@ -9,7 +9,7 @@ use common::{
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
};
|
||||
use retrieval_pipeline::StrategyOutput;
|
||||
use retrieval_pipeline::RetrievalOutput;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
||||
@@ -86,40 +86,29 @@ pub(crate) enum ReferenceLookupTarget {
|
||||
}
|
||||
|
||||
pub(crate) fn collect_reference_ids_from_retrieval(
|
||||
retrieval_result: &StrategyOutput,
|
||||
retrieval_result: &RetrievalOutput,
|
||||
) -> Vec<String> {
|
||||
let mut ids = Vec::new();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
let mut push_id = |id: String| {
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
};
|
||||
|
||||
match retrieval_result {
|
||||
StrategyOutput::Chunks(chunks) => {
|
||||
RetrievalOutput::Chunks(chunks) => {
|
||||
for chunk in chunks {
|
||||
let id = chunk.chunk.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
push_id(chunk.chunk.id.clone());
|
||||
}
|
||||
}
|
||||
StrategyOutput::Entities(entities) => {
|
||||
RetrievalOutput::WithEntities { chunks, entities } => {
|
||||
for chunk in chunks {
|
||||
push_id(chunk.chunk.id.clone());
|
||||
}
|
||||
for entity in entities {
|
||||
let id = entity.entity.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
StrategyOutput::Search(search) => {
|
||||
for chunk in &search.chunks {
|
||||
let id = chunk.chunk.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
for entity in &search.entities {
|
||||
let id = entity.entity.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
push_id(entity.entity.id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{
|
||||
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
||||
template_as_response, TemplateResponse, TemplateResult, ResponseResult,
|
||||
},
|
||||
},
|
||||
utils::text_content_preview::truncate_text_contents,
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::{
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{
|
||||
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
||||
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
||||
},
|
||||
},
|
||||
utils::pagination::{paginate_items, Pagination},
|
||||
@@ -284,9 +284,9 @@ pub async fn suggest_knowledge_relationships(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion();
|
||||
if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) =
|
||||
retrieval_pipeline::retrieve_entities(
|
||||
let config = retrieval_pipeline::RetrievalConfig::with_entities();
|
||||
if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) =
|
||||
retrieval_pipeline::retrieve(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&*state.embedding_provider),
|
||||
@@ -297,7 +297,7 @@ pub async fn suggest_knowledge_relationships(
|
||||
)
|
||||
.await
|
||||
{
|
||||
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results {
|
||||
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities {
|
||||
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::html_state::HtmlState;
|
||||
use crate::middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{
|
||||
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
||||
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
||||
},
|
||||
};
|
||||
use common::storage::types::{
|
||||
|
||||
@@ -4,7 +4,7 @@ use axum::{
|
||||
extract::{Query, State},
|
||||
};
|
||||
use common::storage::types::{text_content::TextContent, user::User};
|
||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
||||
use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity};
|
||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||
use std::{fmt, str::FromStr};
|
||||
|
||||
@@ -108,35 +108,35 @@ async fn perform_search(
|
||||
return Ok((Vec::new(), String::new()));
|
||||
}
|
||||
|
||||
let config = RetrievalConfig::for_search(SearchTarget::Both);
|
||||
let config = RetrievalConfig::with_entities();
|
||||
|
||||
let reranker_lease = match &state.reranker_pool {
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
let params = retrieval_pipeline::pipeline::StrategyParams {
|
||||
db_client: &state.db,
|
||||
openai_client: &state.openai_client,
|
||||
embedding_provider: Some(&state.embedding_provider),
|
||||
input_text: trimmed_query,
|
||||
user_id: &user.id,
|
||||
let result = retrieve(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&state.embedding_provider),
|
||||
trimmed_query,
|
||||
&user.id,
|
||||
config,
|
||||
reranker: reranker_lease,
|
||||
};
|
||||
let result = retrieval_pipeline::pipeline::execute(params).await?;
|
||||
reranker_lease,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let search_result = match result {
|
||||
StrategyOutput::Search(sr) => sr,
|
||||
_ => SearchResult::new(vec![], vec![]),
|
||||
let (chunks, entities) = match result {
|
||||
RetrievalOutput::WithEntities { chunks, entities } => (chunks, entities),
|
||||
RetrievalOutput::Chunks(chunks) => (chunks, Vec::new()),
|
||||
};
|
||||
|
||||
let source_label_map = collect_source_label_map(state, user, &search_result).await?;
|
||||
let source_label_map = collect_source_label_map(state, user, &chunks, &entities).await?;
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
|
||||
Vec::with_capacity(chunks.len().saturating_add(entities.len()));
|
||||
|
||||
for chunk_result in search_result.chunks {
|
||||
for chunk_result in chunks {
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
@@ -155,7 +155,7 @@ async fn perform_search(
|
||||
});
|
||||
}
|
||||
|
||||
for entity_result in search_result.entities {
|
||||
for entity_result in entities {
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
@@ -187,13 +187,14 @@ async fn perform_search(
|
||||
async fn collect_source_label_map(
|
||||
state: &HtmlState,
|
||||
user: &User,
|
||||
search_result: &SearchResult,
|
||||
chunks: &[RetrievedChunk],
|
||||
entities: &[RetrievedEntity],
|
||||
) -> Result<std::collections::HashMap<String, String>, HtmlError> {
|
||||
let mut source_ids = HashSet::new();
|
||||
for chunk_result in &search_result.chunks {
|
||||
for chunk_result in chunks {
|
||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||
}
|
||||
for entity_result in &search_result.entities {
|
||||
for entity_result in entities {
|
||||
source_ids.insert(entity_result.entity.source_id.clone());
|
||||
}
|
||||
|
||||
|
||||
@@ -183,10 +183,8 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
None => None,
|
||||
};
|
||||
|
||||
let config = retrieval_pipeline::RetrievalConfig::for_search(
|
||||
retrieval_pipeline::SearchTarget::EntitiesOnly,
|
||||
);
|
||||
match retrieval_pipeline::retrieve_entities(
|
||||
let config = retrieval_pipeline::RetrievalConfig::with_entities();
|
||||
match retrieval_pipeline::retrieve(
|
||||
&self.db,
|
||||
&self.openai_client,
|
||||
Some(&*self.embedding_provider),
|
||||
@@ -197,19 +195,16 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
|
||||
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
|
||||
let chunk_count = search.chunks.len();
|
||||
let entities = search.entities;
|
||||
Ok(retrieval_pipeline::RetrievalOutput::WithEntities { chunks, entities }) => {
|
||||
tracing::debug!(
|
||||
chunk_count,
|
||||
chunk_count = chunks.len(),
|
||||
entity_count = entities.len(),
|
||||
"ingestion search results returned entities"
|
||||
"ingestion retrieval resolved entities from chunks"
|
||||
);
|
||||
Ok(entities)
|
||||
}
|
||||
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||
"Ingestion retrieval should return entities".into(),
|
||||
Ok(retrieval_pipeline::RetrievalOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||
"Ingestion retrieval should resolve entities".into(),
|
||||
)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
@@ -372,16 +367,16 @@ mod tests {
|
||||
|
||||
fn system_prompt_from_request(
|
||||
request: &async_openai::types::CreateChatCompletionRequest,
|
||||
) -> String {
|
||||
let ChatCompletionRequestMessage::System(system) = &request.messages[0] else {
|
||||
panic!("expected first message to be system");
|
||||
) -> anyhow::Result<String> {
|
||||
let Some(ChatCompletionRequestMessage::System(system)) = request.messages.first() else {
|
||||
anyhow::bail!("expected first message to be system");
|
||||
};
|
||||
match &system.content {
|
||||
async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) => {
|
||||
text.clone()
|
||||
}
|
||||
other => panic!("unexpected system message content: {other:?}"),
|
||||
}
|
||||
let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) =
|
||||
&system.content
|
||||
else {
|
||||
anyhow::bail!("unexpected system message content: {:?}", system.content);
|
||||
};
|
||||
Ok(text.clone())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -425,7 +420,7 @@ mod tests {
|
||||
.await
|
||||
.context("prepare llm request")?;
|
||||
|
||||
assert_eq!(system_prompt_from_request(&request), SENTINEL);
|
||||
assert_eq!(system_prompt_from_request(&request)?, SENTINEL);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,10 +125,10 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
|
||||
})
|
||||
.await??;
|
||||
|
||||
for (idx, png) in captures.iter().enumerate() {
|
||||
if let Err(err) = maybe_dump_debug_image(page_numbers[idx], png).await {
|
||||
for (page_number, png) in page_numbers.iter().zip(captures.iter()) {
|
||||
if let Err(err) = maybe_dump_debug_image(*page_number, png).await {
|
||||
warn!(
|
||||
page = page_numbers[idx],
|
||||
page = page_number,
|
||||
error = %err,
|
||||
"Failed to write debug screenshot to disk"
|
||||
);
|
||||
|
||||
@@ -95,6 +95,7 @@ pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<Shared
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)] // helpers are shared across binary test targets
|
||||
pub(crate) mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
|
||||
@@ -11,8 +11,9 @@ use html_router::{
|
||||
use super::SharedServices;
|
||||
|
||||
/// Builds the Minne API and HTML route subtrees without fixing the outer Axum state
|
||||
/// type. SaaS consumers can merge additional routers and attach their own `AppState`
|
||||
/// type. `SaaS` consumers can merge additional routers and attach their own `AppState`
|
||||
/// as long as it implements `FromRef` for `ApiState` and `HtmlState`.
|
||||
#[allow(dead_code)] // used by server/main binaries, not worker
|
||||
pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
@@ -24,6 +25,7 @@ where
|
||||
.merge(html_routes(html_state))
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // used by server/main binaries, not worker
|
||||
pub fn build_api_state(services: &SharedServices) -> ApiState {
|
||||
ApiState {
|
||||
db: Arc::clone(&services.db),
|
||||
@@ -32,6 +34,7 @@ pub fn build_api_state(services: &SharedServices) -> ApiState {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // used by server/main binaries, not worker
|
||||
pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> {
|
||||
let session_store = Arc::new(
|
||||
services
|
||||
|
||||
@@ -75,6 +75,7 @@ struct AppState {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::expect_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::{
|
||||
|
||||
@@ -10,16 +10,14 @@ workspace = true
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
@@ -1,61 +1,66 @@
|
||||
//! Chat answer assembly: retrieval context formatting and structured LLM request/response types.
|
||||
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||
ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
use common::storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
/// JSON schema describing the structured chat answer (answer text + references).
|
||||
fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Reference {
|
||||
#[allow(dead_code)]
|
||||
pub reference: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LLMResponseFormat {
|
||||
pub answer: String,
|
||||
#[allow(dead_code)]
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r"
|
||||
Context Information:
|
||||
==================
|
||||
{entities_json}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{query}
|
||||
"
|
||||
)
|
||||
}
|
||||
|
||||
/// Convert chunk-based retrieval results to JSON format for LLM context
|
||||
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
impl LLMResponseFormat {
|
||||
pub fn reference_ids(&self) -> Vec<String> {
|
||||
self.references
|
||||
.iter()
|
||||
.map(|entry| entry.reference.clone())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert chunk-based retrieval results to JSON format for LLM context.
|
||||
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
use crate::round_score;
|
||||
|
||||
serde_json::json!(chunks
|
||||
.iter()
|
||||
@@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
}
|
||||
|
||||
pub fn create_user_message_with_history(
|
||||
entities_json: &Value,
|
||||
context_json: &Value,
|
||||
history: &[Message],
|
||||
query: &str,
|
||||
) -> String {
|
||||
@@ -89,7 +94,7 @@ pub fn create_user_message_with_history(
|
||||
{}
|
||||
",
|
||||
format_history(history),
|
||||
entities_json,
|
||||
context_json,
|
||||
query
|
||||
)
|
||||
}
|
||||
@@ -116,18 +121,3 @@ pub fn create_chat_request(
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub fn process_llm_response(
|
||||
response: &CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, Box<AppError>> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use surrealdb::{sql::Thing, Error};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
StoredObject,
|
||||
},
|
||||
};
|
||||
|
||||
/// Find entities related to the given entity via graph relationships.
|
||||
///
|
||||
/// Queries the `relates_to` edge table for all relationships involving the entity,
|
||||
/// then fetches and returns the neighboring entities.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db` - Database client
|
||||
/// * `entity_id` - ID of the entity to find neighbors for
|
||||
/// * `user_id` - User ID for access control
|
||||
/// * `limit` - Maximum number of neighbors to return
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let mut relationships_response = db
|
||||
.query(
|
||||
"
|
||||
SELECT * FROM relates_to
|
||||
WHERE metadata.user_id = $user_id
|
||||
AND (in = type::thing('knowledge_entity', $entity_id)
|
||||
OR out = type::thing('knowledge_entity', $entity_id))
|
||||
",
|
||||
)
|
||||
.bind(("entity_id", entity_id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
|
||||
if relationships.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_ids: Vec<String> = Vec::with_capacity(relationships.len());
|
||||
let mut seen: HashSet<String> = HashSet::with_capacity(relationships.len());
|
||||
for rel in relationships {
|
||||
if rel.in_ == entity_id {
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
} else if rel.out == entity_id {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_);
|
||||
}
|
||||
} else {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_.clone());
|
||||
}
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_ids.retain(|id| id != entity_id);
|
||||
|
||||
if neighbor_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
if limit > 0 && neighbor_ids.len() > limit {
|
||||
neighbor_ids.truncate(limit);
|
||||
}
|
||||
|
||||
let thing_ids: Vec<Thing> = neighbor_ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut neighbors_response = db
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", KnowledgeEntity::table_name().to_owned()))
|
||||
.bind(("things", thing_ids))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
|
||||
if neighbors.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
|
||||
.into_iter()
|
||||
.map(|entity| (entity.id.clone(), entity))
|
||||
.collect();
|
||||
|
||||
let mut ordered = Vec::with_capacity(neighbor_ids.len());
|
||||
for id in neighbor_ids {
|
||||
if let Some(entity) = neighbor_map.remove(&id) {
|
||||
ordered.push(entity);
|
||||
}
|
||||
if limit > 0 && ordered.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ordered)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let central_entity = KnowledgeEntity::new(
|
||||
"central_source".to_string(),
|
||||
"Central Entity".to_string(),
|
||||
"Central Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let related_entity1 = KnowledgeEntity::new(
|
||||
"related_source1".to_string(),
|
||||
"Related Entity 1".to_string(),
|
||||
"Related Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let related_entity2 = KnowledgeEntity::new(
|
||||
"related_source2".to_string(),
|
||||
"Related Entity 2".to_string(),
|
||||
"Related Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let unrelated_entity = KnowledgeEntity::new(
|
||||
"unrelated_source".to_string(),
|
||||
"Unrelated Entity".to_string(),
|
||||
"Unrelated Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let central_entity = db
|
||||
.store_item(central_entity.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store central entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
|
||||
let related_entity1 = db
|
||||
.store_item(related_entity1.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store related entity 1".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
|
||||
let related_entity2 = db
|
||||
.store_item(related_entity2.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store related entity 2".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
|
||||
let _unrelated_entity = db
|
||||
.store_item(unrelated_entity.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store unrelated entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
|
||||
|
||||
let source_id = "relationship_source".to_string();
|
||||
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity1.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity2.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||
|
||||
let related_entities =
|
||||
find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX)
|
||||
.await
|
||||
.with_context(|| "Failed to find entities by relationship".to_string())?;
|
||||
|
||||
assert!(
|
||||
related_entities.len() >= 2,
|
||||
"Should find related entities in both directions"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
+63
-115
@@ -1,10 +1,9 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
|
||||
pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod reranking;
|
||||
pub mod scoring;
|
||||
|
||||
pub(crate) mod scoring;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -16,39 +15,28 @@ use common::{
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
// Strategy output variants - defined before pipeline module
|
||||
/// Result of a retrieval run.
|
||||
///
|
||||
/// Chunk retrieval is always performed; entities are only present when the caller
|
||||
/// requested entity resolution via [`RetrievalConfig::with_entities`].
|
||||
#[derive(Debug)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
pub enum RetrievalOutput {
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
Search(SearchResult),
|
||||
}
|
||||
|
||||
/// Unified search result containing both chunks and entities
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchResult {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
pub entities: Vec<RetrievedEntity>,
|
||||
}
|
||||
|
||||
impl SearchResult {
|
||||
pub fn new(chunks: Vec<RetrievedChunk>, entities: Vec<RetrievedEntity>) -> Self {
|
||||
Self { chunks, entities }
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.chunks.is_empty() && self.entities.is_empty()
|
||||
}
|
||||
WithEntities {
|
||||
chunks: Vec<RetrievedChunk>,
|
||||
entities: Vec<RetrievedEntity>,
|
||||
},
|
||||
}
|
||||
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind,
|
||||
StageTimings,
|
||||
};
|
||||
|
||||
// Backward-compatible type aliases for external consumers
|
||||
pub type PipelineDiagnostics = Diagnostics;
|
||||
pub type PipelineStageTimings = StageTimings;
|
||||
/// Round a score to three decimal places for JSON output.
|
||||
pub(crate) fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -57,7 +45,7 @@ pub struct RetrievedChunk {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
// Final entity representation returned to callers, enriched with ranked chunks.
|
||||
// Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedEntity {
|
||||
pub entity: KnowledgeEntity,
|
||||
@@ -65,9 +53,9 @@ pub struct RetrievedEntity {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
|
||||
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
pub async fn retrieve(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
@@ -75,8 +63,8 @@ pub async fn retrieve_entities(
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let params = pipeline::StrategyParams {
|
||||
) -> Result<RetrievalOutput, AppError> {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
@@ -94,6 +82,7 @@ mod tests {
|
||||
use anyhow::{self};
|
||||
use async_openai::Client;
|
||||
use common::storage::indexes::ensure_runtime;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::system_settings::SystemSettings;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -133,7 +122,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
|
||||
async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "test_user";
|
||||
let chunk = TextChunk::new(
|
||||
@@ -145,7 +134,7 @@ mod tests {
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
@@ -154,12 +143,13 @@ mod tests {
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
RetrievalOutput::Chunks(items) => items,
|
||||
RetrievalOutput::WithEntities { .. } => {
|
||||
anyhow::bail!("expected chunk results, got entities")
|
||||
}
|
||||
};
|
||||
|
||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||
@@ -171,8 +161,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_returns_chunks_from_multiple_sources(
|
||||
) -> anyhow::Result<()> {
|
||||
async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "multi_source_user";
|
||||
|
||||
@@ -191,7 +180,7 @@ mod tests {
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
@@ -200,12 +189,13 @@ mod tests {
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
RetrievalOutput::Chunks(items) => items,
|
||||
RetrievalOutput::WithEntities { .. } => {
|
||||
anyhow::bail!("expected chunk results, got entities")
|
||||
}
|
||||
};
|
||||
|
||||
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
|
||||
@@ -223,96 +213,54 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
|
||||
async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "chunk_user";
|
||||
let chunk_one = TextChunk::new(
|
||||
"src_alpha".into(),
|
||||
"Tokio tasks execute on worker threads managed by the runtime.".into(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk_two = TextChunk::new(
|
||||
"src_beta".into(),
|
||||
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
|
||||
user_id.into(),
|
||||
);
|
||||
let user_id = "entity_user";
|
||||
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "tokio runtime worker behavior",
|
||||
user_id,
|
||||
config,
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
!chunks.is_empty(),
|
||||
"Revised strategy should return chunk-only responses"
|
||||
);
|
||||
assert!(
|
||||
chunks
|
||||
.iter()
|
||||
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
||||
"Chunk results should contain relevant snippets"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "search_user";
|
||||
let chunk = TextChunk::new(
|
||||
"search_src".into(),
|
||||
"Async Rust programming uses Tokio runtime for concurrent tasks.".into(),
|
||||
"entity_source".into(),
|
||||
"Async Rust programming uses the Tokio runtime for concurrent tasks.".into(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
|
||||
let entity = KnowledgeEntity::new(
|
||||
"entity_source".into(),
|
||||
"Tokio Runtime".into(),
|
||||
"Async runtime for Rust".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.into(),
|
||||
);
|
||||
db.store_item(entity).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "async rust programming",
|
||||
user_id,
|
||||
config,
|
||||
config: RetrievalConfig::with_entities(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||
|
||||
let StrategyOutput::Search(search_result) = results else {
|
||||
anyhow::bail!("expected Search output");
|
||||
let RetrievalOutput::WithEntities { chunks, entities } = results else {
|
||||
anyhow::bail!("expected WithEntities output");
|
||||
};
|
||||
|
||||
// Should return chunks (entities may be empty if none stored)
|
||||
assert!(!chunks.is_empty(), "Should return chunks");
|
||||
assert!(
|
||||
!search_result.chunks.is_empty(),
|
||||
"Search strategy should return chunks"
|
||||
entities.iter().any(|e| e.entity.name == "Tokio Runtime"),
|
||||
"Should resolve the entity owning the retrieved chunk"
|
||||
);
|
||||
assert!(
|
||||
search_result
|
||||
.chunks
|
||||
entities
|
||||
.iter()
|
||||
.any(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Search results should contain relevant chunks"
|
||||
.find(|e| e.entity.name == "Tokio Runtime")
|
||||
.is_some_and(|e| !e.chunks.is_empty()),
|
||||
"Resolved entity should carry its contributing chunks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,22 +1,5 @@
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use crate::scoring::FusionWeights;
|
||||
|
||||
pub use common::utils::config::RetrievalStrategy;
|
||||
|
||||
/// Configures which result types to include in Search strategy
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SearchTarget {
|
||||
/// Return only text chunks
|
||||
ChunksOnly,
|
||||
/// Return only knowledge entities
|
||||
EntitiesOnly,
|
||||
/// Return both chunks and entities (default)
|
||||
#[default]
|
||||
Both,
|
||||
}
|
||||
|
||||
/// Two-variant flag that serializes as a bool for backward compatibility.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum BoolFlag {
|
||||
@@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag {
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuningFlags {
|
||||
pub rerank_scores_only: BoolFlag,
|
||||
pub normalize_vector_scores: BoolFlag,
|
||||
pub normalize_fts_scores: BoolFlag,
|
||||
pub chunk_rrf_use_vector: BoolFlag,
|
||||
pub chunk_rrf_use_fts: BoolFlag,
|
||||
}
|
||||
|
||||
impl RetrievalTuningFlags {
|
||||
pub const fn rerank_scores_only(&self) -> bool {
|
||||
pub const fn rerank_scores_only(self) -> bool {
|
||||
self.rerank_scores_only.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_vector_scores(&self) -> bool {
|
||||
self.normalize_vector_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_fts_scores(&self) -> bool {
|
||||
self.normalize_fts_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_vector(&self) -> bool {
|
||||
pub const fn chunk_rrf_use_vector(self) -> bool {
|
||||
self.chunk_rrf_use_vector.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_fts(&self) -> bool {
|
||||
pub const fn chunk_rrf_use_fts(self) -> bool {
|
||||
self.chunk_rrf_use_fts.as_bool()
|
||||
}
|
||||
}
|
||||
@@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
rerank_scores_only: BoolFlag::Disabled,
|
||||
normalize_vector_scores: BoolFlag::Disabled,
|
||||
normalize_fts_scores: BoolFlag::Enabled,
|
||||
chunk_rrf_use_vector: BoolFlag::Enabled,
|
||||
chunk_rrf_use_fts: BoolFlag::Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
/// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
pub entity_vector_take: usize,
|
||||
/// Number of vector candidates to pull from the chunk embedding index.
|
||||
pub chunk_vector_take: usize,
|
||||
pub entity_fts_take: usize,
|
||||
/// Number of full-text candidates to pull from the chunk index.
|
||||
pub chunk_fts_take: usize,
|
||||
pub score_threshold: f32,
|
||||
pub fallback_min_results: usize,
|
||||
pub token_budget_estimate: usize,
|
||||
pub avg_chars_per_token: usize,
|
||||
/// Maximum chunks attached to each resolved entity.
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub lexical_match_weight: f32,
|
||||
pub graph_traversal_seed_limit: usize,
|
||||
pub graph_neighbor_limit: usize,
|
||||
pub graph_score_decay: f32,
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
/// Blend weight applied when mixing reranker scores with fused scores.
|
||||
pub rerank_blend_weight: f32,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
/// Keep top-N candidates after reranking.
|
||||
pub rerank_keep_top: usize,
|
||||
/// Maximum number of chunks returned to callers.
|
||||
pub chunk_result_cap: usize,
|
||||
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
||||
pub fusion_weights: Option<FusionWeights>,
|
||||
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
||||
#[serde(default = "default_chunk_rrf_k")]
|
||||
/// Reciprocal rank fusion k value for chunk merging.
|
||||
pub chunk_rrf_k: f32,
|
||||
/// Weight applied to vector ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_vector_weight")]
|
||||
pub chunk_rrf_vector_weight: f32,
|
||||
/// Weight applied to chunk FTS ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_fts_weight")]
|
||||
pub chunk_rrf_fts_weight: f32,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
entity_vector_take: 15,
|
||||
chunk_vector_take: 20,
|
||||
entity_fts_take: 10,
|
||||
chunk_fts_take: 20,
|
||||
score_threshold: 0.35,
|
||||
fallback_min_results: 10,
|
||||
token_budget_estimate: 10000,
|
||||
avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
lexical_match_weight: 0.15,
|
||||
graph_traversal_seed_limit: 5,
|
||||
graph_neighbor_limit: 6,
|
||||
graph_score_decay: 0.75,
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
rerank_keep_top: 8,
|
||||
chunk_result_cap: 5,
|
||||
fusion_weights: None,
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
||||
chunk_rrf_k: 60.0,
|
||||
chunk_rrf_vector_weight: 1.0,
|
||||
chunk_rrf_fts_weight: 1.0,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
/// Per-request retrieval configuration.
|
||||
///
|
||||
/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities`
|
||||
/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved
|
||||
/// chunks (search, ingestion linking, relationship suggestion).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
/// Target for Search strategy (chunks, entities, or both)
|
||||
pub search_target: SearchTarget,
|
||||
pub resolve_entities: bool,
|
||||
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||
/// Chunk retrieval that also resolves the owning knowledge entities.
|
||||
pub fn with_entities() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::Default,
|
||||
tuning,
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
tuning,
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for chat retrieval with strategy selection support
|
||||
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
|
||||
Self::with_strategy(strategy)
|
||||
}
|
||||
|
||||
/// Create config for relationship suggestion (entity-only retrieval)
|
||||
pub fn for_relationship_suggestion() -> Self {
|
||||
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
|
||||
}
|
||||
|
||||
/// Create config for ingestion pipeline (entity-only retrieval)
|
||||
pub fn for_ingestion() -> Self {
|
||||
Self::with_strategy(RetrievalStrategy::Ingestion)
|
||||
}
|
||||
|
||||
/// Create config for unified search (chunks and/or entities)
|
||||
pub fn for_search(target: SearchTarget) -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::Search,
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: target,
|
||||
resolve_entities: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
60.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_vector_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_fts_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity};
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
diagnostics::{AssembleStats, Diagnostics, SearchStats},
|
||||
StageKind, StageTimings, RetrievalParams,
|
||||
};
|
||||
|
||||
/// Mutable working state threaded through every retrieval stage.
|
||||
pub(crate) struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
pub query_embedding: Option<Vec<f32>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
pub chunk_results: Vec<RetrievedChunk>,
|
||||
stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(params: RetrievalParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
config: params.config,
|
||||
query_embedding: None,
|
||||
chunk_values: Vec::new(),
|
||||
reranker: params.reranker,
|
||||
diagnostics: None,
|
||||
entity_results: Vec::new(),
|
||||
chunk_results: Vec::new(),
|
||||
stage_timings: StageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec<f32>) -> Self {
|
||||
let mut ctx = Self::new(params);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
Box::new(AppError::InternalError(
|
||||
"query embedding missing before candidate search".to_string(),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(Diagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diagnostics_enabled(&self) -> bool {
|
||||
self.diagnostics.is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn record_search(&mut self, stats: SearchStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.search = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_assemble(&mut self, stats: AssembleStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.assemble = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> StageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
|
||||
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
|
||||
self.stage_timings.record(kind, duration);
|
||||
}
|
||||
|
||||
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
|
||||
std::mem::take(&mut self.entity_results)
|
||||
}
|
||||
|
||||
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
|
||||
std::mem::take(&mut self.chunk_results)
|
||||
}
|
||||
}
|
||||
@@ -1,51 +1,21 @@
|
||||
use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
||||
/// Captures instrumentation for the retrieval stages when diagnostics are enabled.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct Diagnostics {
|
||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
pub search: Option<SearchStats>,
|
||||
pub assemble: Option<AssembleStats>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct CollectCandidatesStats {
|
||||
pub vector_entity_candidates: usize,
|
||||
pub struct SearchStats {
|
||||
pub vector_chunk_candidates: usize,
|
||||
pub fts_entity_candidates: usize,
|
||||
pub fts_chunk_candidates: usize,
|
||||
pub vector_chunk_scores: Vec<f32>,
|
||||
pub fts_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct ChunkEnrichmentStats {
|
||||
pub filtered_entity_count: usize,
|
||||
pub fallback_min_results: usize,
|
||||
pub chunk_sources_considered: usize,
|
||||
pub chunk_candidates_before_enrichment: usize,
|
||||
pub chunk_candidates_after_enrichment: usize,
|
||||
pub top_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct AssembleStats {
|
||||
pub token_budget_start: usize,
|
||||
pub token_budget_spent: usize,
|
||||
pub token_budget_remaining: usize,
|
||||
pub budget_exhausted: bool,
|
||||
pub chunks_selected: usize,
|
||||
pub chunks_skipped_due_budget: usize,
|
||||
pub entity_count: usize,
|
||||
pub entity_traces: Vec<EntityAssemblyTrace>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct EntityAssemblyTrace {
|
||||
pub entity_id: String,
|
||||
pub source_id: String,
|
||||
pub inspected_candidates: usize,
|
||||
pub selected_chunk_ids: Vec<String>,
|
||||
pub selected_chunk_scores: Vec<f32>,
|
||||
pub skipped_due_budget: usize,
|
||||
}
|
||||
|
||||
@@ -1,61 +1,68 @@
|
||||
mod config;
|
||||
mod context;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{
|
||||
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
|
||||
};
|
||||
pub use config::RetrievalConfig;
|
||||
pub use diagnostics::Diagnostics;
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||
use crate::{round_score, RetrievalOutput, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use stages::PipelineContext;
|
||||
use strategies::{
|
||||
DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver,
|
||||
use stages::{
|
||||
ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage,
|
||||
};
|
||||
|
||||
// Export StrategyOutput publicly from this module
|
||||
// (it's defined in lib.rs but we re-export it here)
|
||||
|
||||
// Stage type enum
|
||||
/// Identifies a retrieval stage for timing and instrumentation.
|
||||
///
|
||||
/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation
|
||||
/// harness) iterate it generically so that adding a stage requires no changes outside this
|
||||
/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`].
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum StageKind {
|
||||
Embed,
|
||||
CollectCandidates,
|
||||
GraphExpansion,
|
||||
ChunkAttach,
|
||||
Search,
|
||||
Rerank,
|
||||
ResolveEntities,
|
||||
Assemble,
|
||||
}
|
||||
|
||||
// Pipeline stage trait
|
||||
impl StageKind {
|
||||
/// Every stage kind in canonical pipeline order.
|
||||
pub const ALL: [StageKind; 5] = [
|
||||
StageKind::Embed,
|
||||
StageKind::Search,
|
||||
StageKind::Rerank,
|
||||
StageKind::ResolveEntities,
|
||||
StageKind::Assemble,
|
||||
];
|
||||
|
||||
/// Stable, machine-friendly identifier for the stage (used as a metrics key).
|
||||
pub const fn label(self) -> &'static str {
|
||||
match self {
|
||||
StageKind::Embed => "embed",
|
||||
StageKind::Search => "search",
|
||||
StageKind::Rerank => "rerank",
|
||||
StageKind::ResolveEntities => "resolve_entities",
|
||||
StageKind::Assemble => "assemble",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single composable step in the retrieval pipeline.
|
||||
#[async_trait]
|
||||
pub trait Stage: Send + Sync {
|
||||
pub(crate) trait Stage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
// Type alias for boxed stages
|
||||
pub type BoxedStage = Box<dyn Stage>;
|
||||
pub(crate) type BoxedStage = Box<dyn Stage>;
|
||||
|
||||
// Strategy driver trait
|
||||
#[async_trait]
|
||||
pub trait StrategyDriver: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
|
||||
}
|
||||
|
||||
// Pipeline stage timings tracker
|
||||
/// Per-stage execution timings recorded during a run.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct StageTimings {
|
||||
timings: Vec<(StageKind, Duration)>,
|
||||
@@ -66,41 +73,13 @@ impl StageTimings {
|
||||
self.timings.push((kind, duration));
|
||||
}
|
||||
|
||||
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
|
||||
self.timings
|
||||
}
|
||||
|
||||
// Helper methods to get duration for each stage type (for backward compatibility)
|
||||
fn get_stage_ms(&self, kind: StageKind) -> u128 {
|
||||
/// Milliseconds recorded for `kind`, or `0` if the stage did not run.
|
||||
pub fn stage_ms(&self, kind: StageKind) -> u128 {
|
||||
self.timings
|
||||
.iter()
|
||||
.find(|(k, _)| *k == kind)
|
||||
.map_or(0, |(_, d)| d.as_millis())
|
||||
}
|
||||
|
||||
pub fn embed_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Embed)
|
||||
}
|
||||
|
||||
pub fn collect_candidates_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::CollectCandidates)
|
||||
}
|
||||
|
||||
pub fn graph_expansion_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::GraphExpansion)
|
||||
}
|
||||
|
||||
pub fn chunk_attach_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::ChunkAttach)
|
||||
}
|
||||
|
||||
pub fn rerank_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Rerank)
|
||||
}
|
||||
|
||||
pub fn assemble_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Assemble)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RunOutput<T> {
|
||||
@@ -109,7 +88,35 @@ pub struct RunOutput<T> {
|
||||
pub stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
|
||||
/// Inputs required to run a retrieval.
|
||||
pub struct RetrievalParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
pub reranker: Option<crate::reranking::RerankerLease>,
|
||||
}
|
||||
|
||||
fn build_stages(config: &RetrievalConfig) -> Vec<BoxedStage> {
|
||||
let mut stages: Vec<BoxedStage> = vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkSearchStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
];
|
||||
if config.resolve_entities {
|
||||
stages.push(Box::new(ResolveEntitiesStage));
|
||||
}
|
||||
stages.push(Box::new(ChunkAssembleStage));
|
||||
stages
|
||||
}
|
||||
|
||||
async fn run(
|
||||
params: RetrievalParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<RetrievalOutput>, AppError> {
|
||||
let input_chars = params.input_text.chars().count();
|
||||
let input_preview: String = params.input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
@@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppEr
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
strategy = %params.config.strategy,
|
||||
resolve_entities = params.config.resolve_entities,
|
||||
"Starting retrieval pipeline"
|
||||
);
|
||||
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
let resolve_entities = params.config.resolve_entities;
|
||||
let mut ctx = match query_embedding {
|
||||
Some(embedding) => context::PipelineContext::with_embedding(params, embedding),
|
||||
None => context::PipelineContext::new(params),
|
||||
};
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
for stage in build_stages(&ctx.config) {
|
||||
let start = Instant::now();
|
||||
stage.execute(&mut ctx).await?;
|
||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||
}
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let chunks = ctx.take_chunk_results();
|
||||
let results = if resolve_entities {
|
||||
RetrievalOutput::WithEntities {
|
||||
chunks,
|
||||
entities: ctx.take_entity_results(),
|
||||
}
|
||||
} else {
|
||||
RetrievalOutput::Chunks(chunks)
|
||||
};
|
||||
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
/// Run the retrieval pipeline, generating the query embedding internally if needed.
|
||||
pub async fn execute(params: RetrievalParams<'_>) -> Result<RetrievalOutput, AppError> {
|
||||
Ok(run(params, None, false).await?.results)
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
params: StrategyParams<'_>,
|
||||
/// Run the retrieval pipeline with a pre-computed query embedding.
|
||||
pub async fn run_with_embedding(
|
||||
params: RetrievalParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
_ => Err(AppError::InternalError(
|
||||
"Metrics not supported for this strategy".into(),
|
||||
)),
|
||||
}
|
||||
) -> Result<RetrievalOutput, AppError> {
|
||||
Ok(run(params, Some(query_embedding), false).await?.results)
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
params: StrategyParams<'_>,
|
||||
/// Run with a pre-computed embedding, returning results and per-stage timings.
|
||||
///
|
||||
/// When `capture_diagnostics` is true, pipeline search/assemble stats are included.
|
||||
pub async fn run_with_embedding_instrumented(
|
||||
params: RetrievalParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
_ => Err(AppError::InternalError(
|
||||
"Diagnostics not supported for this strategy".into(),
|
||||
)),
|
||||
}
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<RetrievalOutput>, AppError> {
|
||||
run(params, Some(query_embedding), capture_diagnostics).await
|
||||
}
|
||||
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
@@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub struct StrategyParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(params, embedding),
|
||||
None => PipelineContext::new(params),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
}
|
||||
|
||||
async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
for stage in driver.stages() {
|
||||
let start = Instant::now();
|
||||
stage.execute(&mut ctx).await?;
|
||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||
}
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
|
||||
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
use async_trait::async_trait;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
scoring::{
|
||||
clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored,
|
||||
},
|
||||
RetrievedChunk, RetrievedEntity,
|
||||
};
|
||||
|
||||
use super::{
|
||||
config::RetrievalTuning,
|
||||
context::PipelineContext,
|
||||
diagnostics::{AssembleStats, SearchStats},
|
||||
Stage, StageKind,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct EmbedStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for EmbedStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Embed
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
embed(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkSearchStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ChunkSearchStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Search
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
search_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkRerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ChunkRerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
rerank_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ResolveEntitiesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ResolveEntitiesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::ResolveEntities
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
resolve_entities(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkAssembleStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ChunkAssembleStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
assemble_chunks(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.query_embedding.is_some() {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await?
|
||||
} else {
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||
};
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting chunk candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let fts_take = tuning.chunk_fts_take;
|
||||
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
|
||||
|
||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||
TextChunk::vector_search(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
),
|
||||
async {
|
||||
if fts_enabled {
|
||||
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
|
||||
} else {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
)?;
|
||||
|
||||
let vector_candidates = vector_rows.len();
|
||||
let fts_candidates = fts_rows.len();
|
||||
|
||||
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||
.collect();
|
||||
|
||||
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
||||
.collect();
|
||||
|
||||
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
||||
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
|
||||
// For very short keyword queries, lean more on lexical ranking.
|
||||
fts_weight *= 1.5;
|
||||
}
|
||||
|
||||
let rrf_config = RrfConfig {
|
||||
k: tuning.chunk_rrf_k,
|
||||
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||
fts_weight,
|
||||
use_vector: tuning.flags.chunk_rrf_use_vector(),
|
||||
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
|
||||
};
|
||||
|
||||
let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||
|
||||
debug!(
|
||||
total_merged = chunks.len(),
|
||||
vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(),
|
||||
fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(),
|
||||
both_signals = chunks
|
||||
.iter()
|
||||
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||
.count(),
|
||||
rrf_k = rrf_config.k,
|
||||
"Merged chunk candidates with RRF"
|
||||
);
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_search(SearchStats {
|
||||
vector_chunk_candidates: vector_candidates,
|
||||
fts_chunk_candidates: fts_candidates,
|
||||
vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)),
|
||||
fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
|
||||
});
|
||||
}
|
||||
|
||||
ctx.chunk_values = chunks;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.chunk_values.len() <= 1 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(reranker) = ctx.reranker.as_ref() else {
|
||||
debug!("No reranker lease provided; skipping chunk rerank stage");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let documents =
|
||||
build_chunk_rerank_documents(&ctx.chunk_values, ctx.config.tuning.rerank_keep_top.max(1));
|
||||
if documents.len() <= 1 {
|
||||
debug!("Skipping chunk reranking stage; insufficient chunk documents");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match reranker.rerank(&ctx.input_text, documents).await {
|
||||
Ok(results) if !results.is_empty() => {
|
||||
apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results);
|
||||
}
|
||||
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
|
||||
Err(err) => warn!(
|
||||
error = %err,
|
||||
"Chunk reranking failed; continuing with original ordering"
|
||||
),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks.
|
||||
///
|
||||
/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped
|
||||
/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk,
|
||||
/// and the contributing chunks are attached.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.chunk_values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
|
||||
|
||||
let mut source_order: Vec<String> = Vec::new();
|
||||
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
|
||||
let mut best_score: HashMap<String, f32> = HashMap::new();
|
||||
|
||||
for scored in &ctx.chunk_values {
|
||||
let source = scored.item.source_id.clone();
|
||||
let attached = chunks_by_source.entry(source.clone()).or_default();
|
||||
if attached.is_empty() {
|
||||
source_order.push(source.clone());
|
||||
best_score.insert(source.clone(), scored.fused);
|
||||
}
|
||||
if attached.len() < max_chunks {
|
||||
attached.push(RetrievedChunk {
|
||||
chunk: scored.item.clone(),
|
||||
score: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let entities =
|
||||
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
||||
|
||||
let mut entities_by_source: HashMap<String, Vec<KnowledgeEntity>> = HashMap::new();
|
||||
for entity in entities {
|
||||
entities_by_source
|
||||
.entry(entity.source_id.clone())
|
||||
.or_default()
|
||||
.push(entity);
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for source in &source_order {
|
||||
let Some(entities) = entities_by_source.remove(source) else {
|
||||
continue;
|
||||
};
|
||||
let score = best_score.get(source).copied().unwrap_or(0.0);
|
||||
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
|
||||
for entity in entities {
|
||||
results.push(RetrievedEntity {
|
||||
entity,
|
||||
score,
|
||||
chunks: chunks.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
sources = source_order.len(),
|
||||
entities = results.len(),
|
||||
"Resolved entities from retrieved chunks"
|
||||
);
|
||||
|
||||
ctx.entity_results = results;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Assembling chunk retrieval results");
|
||||
let mut chunk_values = std::mem::take(&mut ctx.chunk_values);
|
||||
// Limit how many chunks we return to keep context size reasonable.
|
||||
let limit = ctx
|
||||
.config
|
||||
.tuning
|
||||
.chunk_result_cap
|
||||
.max(1)
|
||||
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
||||
|
||||
if chunk_values.len() > limit {
|
||||
chunk_values.truncate(limit);
|
||||
}
|
||||
|
||||
ctx.chunk_results = chunk_values
|
||||
.into_iter()
|
||||
.map(|chunk| RetrievedChunk {
|
||||
chunk: chunk.item,
|
||||
score: chunk.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_assemble(AssembleStats {
|
||||
chunks_selected: ctx.chunk_results.len(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
|
||||
}
|
||||
|
||||
fn normalize_fts_query(input: &str) -> (String, usize) {
|
||||
const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];
|
||||
let mut cleaned = String::with_capacity(input.len());
|
||||
for ch in input.chars() {
|
||||
if ch.is_alphanumeric() {
|
||||
cleaned.extend(ch.to_lowercase());
|
||||
} else if ch.is_whitespace() {
|
||||
cleaned.push(' ');
|
||||
}
|
||||
}
|
||||
let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3));
|
||||
for token in cleaned.split_whitespace() {
|
||||
if !STOPWORDS.contains(&token) && !token.is_empty() {
|
||||
tokens.push(token.to_string());
|
||||
}
|
||||
}
|
||||
let normalized = tokens.join(" ");
|
||||
(normalized, tokens.len())
|
||||
}
|
||||
|
||||
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
||||
chunks
|
||||
.iter()
|
||||
.take(max_chunks)
|
||||
.map(|chunk| {
|
||||
format!(
|
||||
"Source: {}\nChunk:\n{}",
|
||||
chunk.item.source_id,
|
||||
chunk.item.chunk.trim()
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn apply_chunk_rerank_results(
|
||||
chunks: &mut Vec<Scored<TextChunk>>,
|
||||
tuning: &RetrievalTuning,
|
||||
results: Vec<RerankResult>,
|
||||
) {
|
||||
if results.is_empty() || chunks.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<TextChunk>>> =
|
||||
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
clamp_unit(tuning.rerank_blend_weight)
|
||||
};
|
||||
|
||||
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
let original = candidate.fused;
|
||||
let blended = if use_only {
|
||||
clamp_unit(normalized)
|
||||
} else {
|
||||
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||
};
|
||||
candidate.update_fused(blended);
|
||||
reranked.push(candidate);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
result_index = result.index,
|
||||
"Chunk reranker returned out-of-range index; skipping"
|
||||
);
|
||||
}
|
||||
if reranked.len() == remaining.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
let keep_top = tuning.rerank_keep_top;
|
||||
if keep_top > 0 && reranked.len() > keep_top {
|
||||
reranked.truncate(keep_top);
|
||||
}
|
||||
|
||||
*chunks = reranked;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
||||
use super::{
|
||||
stages::{
|
||||
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
|
||||
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
|
||||
},
|
||||
BoxedStage, StrategyDriver,
|
||||
};
|
||||
use crate::{RetrievedChunk, RetrievedEntity};
|
||||
use common::error::AppError;
|
||||
|
||||
pub struct DefaultStrategyDriver;
|
||||
|
||||
impl DefaultStrategyDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for DefaultStrategyDriver {
|
||||
type Output = Vec<RetrievedChunk>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RelationshipSuggestionDriver;
|
||||
|
||||
impl RelationshipSuggestionDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for RelationshipSuggestionDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
// Skip ChunkAttachStage
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IngestionDriver;
|
||||
|
||||
impl IngestionDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for IngestionDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
// Skip ChunkAttachStage
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
use super::config::SearchTarget;
|
||||
use crate::SearchResult;
|
||||
|
||||
/// Search strategy driver that retrieves both chunks and entities
|
||||
pub struct SearchStrategyDriver {
|
||||
target: SearchTarget,
|
||||
}
|
||||
|
||||
impl SearchStrategyDriver {
|
||||
pub fn new(target: SearchTarget) -> Self {
|
||||
Self { target }
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for SearchStrategyDriver {
|
||||
type Output = SearchResult;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
match self.target {
|
||||
SearchTarget::ChunksOnly => vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
],
|
||||
SearchTarget::EntitiesOnly => vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
],
|
||||
SearchTarget::Both => vec![
|
||||
Box::new(EmbedStage),
|
||||
// Chunk retrieval path
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
// Entity retrieval path (runs after chunk stages)
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
let chunks = match self.target {
|
||||
SearchTarget::EntitiesOnly => Vec::new(),
|
||||
_ => ctx.take_chunk_results(),
|
||||
};
|
||||
let entities = match self.target {
|
||||
SearchTarget::ChunksOnly => Vec::new(),
|
||||
_ => ctx.take_entity_results(),
|
||||
};
|
||||
Ok(SearchResult::new(chunks, entities))
|
||||
}
|
||||
}
|
||||
@@ -97,8 +97,7 @@ impl RerankerPool {
|
||||
|
||||
fn default_pool_size() -> usize {
|
||||
available_parallelism()
|
||||
.map(|value| value.get().min(2))
|
||||
.unwrap_or(2)
|
||||
.map_or(2, |value| value.get().min(2))
|
||||
.max(1)
|
||||
}
|
||||
|
||||
@@ -156,6 +155,7 @@ pub struct RerankerLease {
|
||||
}
|
||||
|
||||
impl RerankerLease {
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub async fn rerank(
|
||||
&self,
|
||||
query: &str,
|
||||
@@ -165,7 +165,9 @@ impl RerankerLease {
|
||||
let engine = Arc::clone(&self.engine);
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut guard = engine.lock().expect("reranker engine mutex poisoned");
|
||||
let mut guard = engine.lock().map_err(|_| {
|
||||
AppError::InternalError("reranker engine mutex poisoned".into())
|
||||
})?;
|
||||
guard
|
||||
.rerank(query, documents, false, None)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
@@ -1,14 +1,12 @@
|
||||
use std::{cmp::Ordering, collections::HashMap};
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Holds optional subscores gathered from different retrieval signals.
|
||||
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Scores {
|
||||
pub fts: Option<f32>,
|
||||
pub vector: Option<f32>,
|
||||
pub graph: Option<f32>,
|
||||
}
|
||||
|
||||
/// Generic wrapper combining an item with its accumulated retrieval scores.
|
||||
@@ -40,40 +38,11 @@ impl<T> Scored<T> {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_graph_score(mut self, score: f32) -> Self {
|
||||
self.scores.graph = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub const fn update_fused(&mut self, fused: f32) {
|
||||
self.fused = fused;
|
||||
}
|
||||
}
|
||||
|
||||
/// Weights used for linear score fusion.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct FusionWeights {
|
||||
pub vector: f32,
|
||||
pub fts: f32,
|
||||
pub graph: f32,
|
||||
pub multi_bonus: f32,
|
||||
}
|
||||
|
||||
impl Default for FusionWeights {
|
||||
fn default() -> Self {
|
||||
// Default weights favor vector search, which typically performs better
|
||||
// FTS is used as a complement when there's good overlap
|
||||
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
|
||||
Self {
|
||||
vector: 0.8,
|
||||
fts: 0.2,
|
||||
graph: 0.2,
|
||||
multi_bonus: 0.3, // Increased to boost chunks with both signals
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for reciprocal rank fusion.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RrfConfig {
|
||||
@@ -84,29 +53,10 @@ pub struct RrfConfig {
|
||||
pub use_fts: bool,
|
||||
}
|
||||
|
||||
impl Default for RrfConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
k: 60.0,
|
||||
vector_weight: 1.0,
|
||||
fts_weight: 1.0,
|
||||
use_vector: true,
|
||||
use_fts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn clamp_unit(value: f32) -> f32 {
|
||||
value.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
pub fn distance_to_similarity(distance: f32) -> f32 {
|
||||
if !distance.is_finite() {
|
||||
return 0.0;
|
||||
}
|
||||
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
|
||||
}
|
||||
|
||||
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||
if scores.is_empty() {
|
||||
return Vec::new();
|
||||
@@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
let vector = scores.vector.unwrap_or(0.0);
|
||||
let fts = scores.fts.unwrap_or(0.0);
|
||||
let graph = scores.graph.unwrap_or(0.0);
|
||||
|
||||
let mut fused = graph.mul_add(
|
||||
weights.graph,
|
||||
vector.mul_add(weights.vector, fts * weights.fts),
|
||||
);
|
||||
|
||||
let signals_present = scores
|
||||
.vector
|
||||
.iter()
|
||||
.chain(scores.fts.iter())
|
||||
.chain(scores.graph.iter())
|
||||
.count();
|
||||
|
||||
// Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
|
||||
if signals_present >= 2 {
|
||||
// For chunks with both vector and FTS, give a significant boost
|
||||
// This helps identify the "golden chunk" that appears in both searches
|
||||
if scores.vector.is_some() && scores.fts.is_some() {
|
||||
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
|
||||
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
|
||||
fused *= 1.0 + weights.multi_bonus;
|
||||
} else {
|
||||
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
|
||||
fused += weights.multi_bonus;
|
||||
}
|
||||
}
|
||||
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>, S>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
{
|
||||
for scored in incoming {
|
||||
let id = scored.item.id().to_owned();
|
||||
target
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if let Some(score) = scored.scores.vector {
|
||||
existing.scores.vector = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.fts {
|
||||
existing.scores.fts = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.graph {
|
||||
existing.scores.graph = Some(score);
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| Scored {
|
||||
item: scored.item.clone(),
|
||||
scores: scored.scores,
|
||||
fused: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||
where
|
||||
T: StoredObject,
|
||||
@@ -222,6 +109,10 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion.
|
||||
///
|
||||
/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text
|
||||
/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score.
|
||||
pub fn reciprocal_rank_fusion<T>(
|
||||
mut vector_ranked: Vec<Scored<T>>,
|
||||
mut fts_ranked: Vec<Scored<T>>,
|
||||
@@ -266,9 +157,7 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
@@ -296,9 +185,7 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user