mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-31 03:40:38 +02:00
chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.
Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
Generated
-4
@@ -5426,14 +5426,10 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-openai",
|
"async-openai",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
|
||||||
"common",
|
"common",
|
||||||
"fastembed",
|
"fastembed",
|
||||||
"futures",
|
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"surrealdb",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"uuid",
|
"uuid",
|
||||||
|
|||||||
@@ -164,6 +164,35 @@ impl KnowledgeEntity {
|
|||||||
.take(0)?)
|
.take(0)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetch all knowledge entities owned by any of the provided source ids for a user.
|
||||||
|
///
|
||||||
|
/// Used by retrieval to resolve the entities that own a set of retrieved chunks.
|
||||||
|
pub async fn find_by_source_ids(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
source_ids: &[String],
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||||
|
if source_ids.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let entities: Vec<KnowledgeEntity> = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"SELECT * FROM type::table($table) \
|
||||||
|
WHERE source_id IN $sources AND user_id = $user_id",
|
||||||
|
)
|
||||||
|
.bind(("table", Self::table_name()))
|
||||||
|
.bind(("sources", source_ids.to_vec()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::Database)?
|
||||||
|
.take(0)
|
||||||
|
.map_err(AppError::Database)?;
|
||||||
|
|
||||||
|
Ok(entities)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn delete_by_source_id(
|
pub async fn delete_by_source_id(
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
|
|||||||
+11
-119
@@ -1,8 +1,7 @@
|
|||||||
use config::{Config, ConfigError, Environment, File};
|
use config::{Config, ConfigError, Environment, File};
|
||||||
use serde::{Deserialize, Deserializer, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{env, fmt, str::FromStr, sync::Once};
|
use std::{env, str::FromStr, sync::Once};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
/// Error returned when parsing an embedding backend name.
|
/// Error returned when parsing an embedding backend name.
|
||||||
#[derive(Debug, Error, PartialEq, Eq)]
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
@@ -36,83 +35,6 @@ impl EmbeddingBackend {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error returned when parsing a retrieval strategy name.
|
|
||||||
#[derive(Debug, Error, PartialEq, Eq)]
|
|
||||||
#[error("unknown retrieval strategy '{input}'")]
|
|
||||||
pub struct ParseRetrievalStrategyError {
|
|
||||||
/// The unrecognized input string.
|
|
||||||
pub input: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Selects which retrieval pipeline strategy to run for chat and search.
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum RetrievalStrategy {
|
|
||||||
/// Primary hybrid chunk retrieval for search/chat.
|
|
||||||
#[default]
|
|
||||||
Default,
|
|
||||||
/// Entity retrieval for suggesting relationships when creating manual entities.
|
|
||||||
RelationshipSuggestion,
|
|
||||||
/// Entity retrieval for context during content ingestion.
|
|
||||||
Ingestion,
|
|
||||||
/// Unified search returning both chunks and entities.
|
|
||||||
Search,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RetrievalStrategy {
|
|
||||||
#[must_use]
|
|
||||||
pub fn as_str(self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Default => "default",
|
|
||||||
Self::RelationshipSuggestion => "relationship_suggestion",
|
|
||||||
Self::Ingestion => "ingestion",
|
|
||||||
Self::Search => "search",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromStr for RetrievalStrategy {
|
|
||||||
type Err = ParseRetrievalStrategyError;
|
|
||||||
|
|
||||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
|
||||||
match value.to_ascii_lowercase().as_str() {
|
|
||||||
"default" => Ok(Self::Default),
|
|
||||||
"initial" | "revised" => {
|
|
||||||
warn!(
|
|
||||||
"retrieval strategy '{value}' is deprecated; use 'default' instead"
|
|
||||||
);
|
|
||||||
Ok(Self::Default)
|
|
||||||
}
|
|
||||||
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
|
||||||
"ingestion" => Ok(Self::Ingestion),
|
|
||||||
"search" => Ok(Self::Search),
|
|
||||||
other => Err(ParseRetrievalStrategyError {
|
|
||||||
input: other.to_string(),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for RetrievalStrategy {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
f.write_str(self.as_str())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn deserialize_optional_retrieval_strategy<'de, D>(
|
|
||||||
deserializer: D,
|
|
||||||
) -> Result<Option<RetrievalStrategy>, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let value = Option::<String>::deserialize(deserializer)?;
|
|
||||||
match value {
|
|
||||||
None => Ok(None),
|
|
||||||
Some(raw) if raw.trim().is_empty() => Ok(None),
|
|
||||||
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromStr for EmbeddingBackend {
|
impl FromStr for EmbeddingBackend {
|
||||||
type Err = ParseEmbeddingBackendError;
|
type Err = ParseEmbeddingBackendError;
|
||||||
|
|
||||||
@@ -195,8 +117,6 @@ pub struct AppConfig {
|
|||||||
pub fastembed_show_download_progress: Option<bool>,
|
pub fastembed_show_download_progress: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub fastembed_max_length: Option<usize>,
|
pub fastembed_max_length: Option<usize>,
|
||||||
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
|
|
||||||
pub retrieval_strategy: Option<RetrievalStrategy>,
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub embedding_backend: EmbeddingBackend,
|
pub embedding_backend: EmbeddingBackend,
|
||||||
#[serde(default = "default_ingest_max_body_bytes")]
|
#[serde(default = "default_ingest_max_body_bytes")]
|
||||||
@@ -282,14 +202,6 @@ pub fn ensure_ort_path() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppConfig {
|
|
||||||
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
|
|
||||||
#[must_use]
|
|
||||||
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
|
|
||||||
self.retrieval_strategy.unwrap_or_default()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for AppConfig {
|
impl Default for AppConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -312,7 +224,6 @@ impl Default for AppConfig {
|
|||||||
fastembed_cache_dir: None,
|
fastembed_cache_dir: None,
|
||||||
fastembed_show_download_progress: None,
|
fastembed_show_download_progress: None,
|
||||||
fastembed_max_length: None,
|
fastembed_max_length: None,
|
||||||
retrieval_strategy: None,
|
|
||||||
embedding_backend: EmbeddingBackend::default(),
|
embedding_backend: EmbeddingBackend::default(),
|
||||||
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
||||||
ingest_max_files: default_ingest_max_files(),
|
ingest_max_files: default_ingest_max_files(),
|
||||||
@@ -340,41 +251,22 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
#![allow(clippy::expect_used)]
|
#![allow(clippy::expect_used)]
|
||||||
|
|
||||||
use super::{ParseRetrievalStrategyError, RetrievalStrategy};
|
use super::EmbeddingBackend;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn retrieval_strategy_defaults_to_default() {
|
fn embedding_backend_defaults_to_fastembed() {
|
||||||
assert_eq!(
|
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
|
||||||
RetrievalStrategy::default(),
|
|
||||||
RetrievalStrategy::Default
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn retrieval_strategy_serializes_snake_case() {
|
fn embedding_backend_parses_aliases() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
|
"openai".parse::<EmbeddingBackend>().expect("openai"),
|
||||||
"\"search\""
|
EmbeddingBackend::OpenAI
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
"initial".parse::<RetrievalStrategy>().expect("initial"),
|
"fast".parse::<EmbeddingBackend>().expect("fast"),
|
||||||
RetrievalStrategy::Default
|
EmbeddingBackend::FastEmbed
|
||||||
);
|
|
||||||
assert!(matches!(
|
|
||||||
"unknown".parse::<RetrievalStrategy>(),
|
|
||||||
Err(ParseRetrievalStrategyError { .. })
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
|
|
||||||
let config = super::AppConfig::default();
|
|
||||||
assert_eq!(
|
|
||||||
config.resolved_retrieval_strategy(),
|
|
||||||
RetrievalStrategy::Default
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
|
|||||||
| `RUST_LOG` | Logging level | `info` |
|
| `RUST_LOG` | Logging level | `info` |
|
||||||
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
|
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
|
||||||
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
|
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
|
||||||
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
|
|
||||||
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
|
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
|
||||||
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
||||||
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
||||||
|
|||||||
+6
-5
@@ -27,13 +27,14 @@ The D3-based graph visualization shows entities as nodes and relationships as ed
|
|||||||
|
|
||||||
## Hybrid Retrieval
|
## Hybrid Retrieval
|
||||||
|
|
||||||
Minne combines multiple retrieval strategies:
|
Minne uses chunk-first hybrid retrieval over the knowledge base:
|
||||||
|
|
||||||
- **Vector similarity** — Semantic matching via embeddings
|
- **Vector similarity** — Semantic matching via embeddings over text chunks
|
||||||
- **Full-text search** — Keyword matching with BM25
|
- **Full-text search** — Keyword matching with BM25 over the same chunk index
|
||||||
- **Graph traversal** — Following relationships between entities
|
|
||||||
|
|
||||||
Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance.
|
The two ranked candidate lists are merged with Reciprocal Rank Fusion (RRF). When a caller needs knowledge entities (search, ingestion linking, relationship suggestion), entities are derived from the top retrieved chunks grouped by `source_id`.
|
||||||
|
|
||||||
|
Optional **reranking** can rescore the fused chunk list with a cross-encoder model; see below.
|
||||||
|
|
||||||
## Reranking (Optional)
|
## Reranking (Optional)
|
||||||
|
|
||||||
|
|||||||
+7
-18
@@ -5,7 +5,6 @@ use std::{
|
|||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use clap::{Args, Parser, ValueEnum};
|
use clap::{Args, Parser, ValueEnum};
|
||||||
use retrieval_pipeline::RetrievalStrategy;
|
|
||||||
|
|
||||||
use crate::datasets::DatasetKind;
|
use crate::datasets::DatasetKind;
|
||||||
|
|
||||||
@@ -55,10 +54,6 @@ pub struct RetrievalSettings {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_fts_take: Option<usize>,
|
pub chunk_fts_take: Option<usize>,
|
||||||
|
|
||||||
/// Override average characters per token used for budgeting
|
|
||||||
#[arg(long)]
|
|
||||||
pub chunk_avg_chars_per_token: Option<usize>,
|
|
||||||
|
|
||||||
/// Override maximum chunks attached per entity
|
/// Override maximum chunks attached per entity
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub max_chunks_per_entity: Option<usize>,
|
pub max_chunks_per_entity: Option<usize>,
|
||||||
@@ -71,41 +66,37 @@ pub struct RetrievalSettings {
|
|||||||
#[arg(long, default_value_t = 4)]
|
#[arg(long, default_value_t = 4)]
|
||||||
pub rerank_pool_size: usize,
|
pub rerank_pool_size: usize,
|
||||||
|
|
||||||
/// Keep top-N entities after reranking
|
/// Keep top-N chunks after reranking
|
||||||
#[arg(long, default_value_t = 10)]
|
#[arg(long, default_value_t = 10)]
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
|
|
||||||
/// Cap the number of chunks returned by retrieval (revised strategy)
|
/// Cap the number of chunks returned by retrieval
|
||||||
#[arg(long, default_value_t = 5)]
|
#[arg(long, default_value_t = 5)]
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
|
||||||
/// Reciprocal rank fusion k value for revised chunk merging
|
/// Reciprocal rank fusion k value for chunk merging
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_k: Option<f32>,
|
pub chunk_rrf_k: Option<f32>,
|
||||||
|
|
||||||
/// Weight for vector ranks in revised RRF
|
/// Weight for vector ranks in RRF
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_vector_weight: Option<f32>,
|
pub chunk_rrf_vector_weight: Option<f32>,
|
||||||
|
|
||||||
/// Weight for chunk FTS ranks in revised RRF
|
/// Weight for chunk FTS ranks in RRF
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_fts_weight: Option<f32>,
|
pub chunk_rrf_fts_weight: Option<f32>,
|
||||||
|
|
||||||
/// Include vector ranks in revised RRF (default: true)
|
/// Include vector ranks in RRF (default: true)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_use_vector: Option<bool>,
|
pub chunk_rrf_use_vector: Option<bool>,
|
||||||
|
|
||||||
/// Include chunk FTS ranks in revised RRF (default: true)
|
/// Include chunk FTS ranks in RRF (default: true)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_use_fts: Option<bool>,
|
pub chunk_rrf_use_fts: Option<bool>,
|
||||||
|
|
||||||
/// Require verified chunks (disable with --llm-mode)
|
/// Require verified chunks (disable with --llm-mode)
|
||||||
#[arg(skip = true)]
|
#[arg(skip = true)]
|
||||||
pub require_verified_chunks: bool,
|
pub require_verified_chunks: bool,
|
||||||
|
|
||||||
/// Select the retrieval pipeline strategy
|
|
||||||
#[arg(long, default_value_t = RetrievalStrategy::Default)]
|
|
||||||
pub strategy: RetrievalStrategy,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalSettings {
|
impl Default for RetrievalSettings {
|
||||||
@@ -113,7 +104,6 @@ impl Default for RetrievalSettings {
|
|||||||
Self {
|
Self {
|
||||||
chunk_vector_take: None,
|
chunk_vector_take: None,
|
||||||
chunk_fts_take: None,
|
chunk_fts_take: None,
|
||||||
chunk_avg_chars_per_token: None,
|
|
||||||
max_chunks_per_entity: None,
|
max_chunks_per_entity: None,
|
||||||
rerank: false,
|
rerank: false,
|
||||||
rerank_pool_size: 4,
|
rerank_pool_size: 4,
|
||||||
@@ -125,7 +115,6 @@ impl Default for RetrievalSettings {
|
|||||||
chunk_rrf_use_vector: None,
|
chunk_rrf_use_vector: None,
|
||||||
chunk_rrf_use_fts: None,
|
chunk_rrf_use_fts: None,
|
||||||
require_verified_chunks: true,
|
require_verified_chunks: true,
|
||||||
strategy: RetrievalStrategy::Default,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+18
-20
@@ -51,8 +51,8 @@ pub fn mirror_perf_outputs(
|
|||||||
pub fn print_console_summary(record: &EvaluationReport) {
|
pub fn print_console_summary(record: &EvaluationReport) {
|
||||||
let perf = &record.performance;
|
let perf = &record.performance;
|
||||||
println!(
|
println!(
|
||||||
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})",
|
"[perf] resolve_entities={} | concurrency={} | rerank={} (pool {:?}, keep {})",
|
||||||
record.retrieval.strategy,
|
record.retrieval.resolve_entities,
|
||||||
record.retrieval.concurrency,
|
record.retrieval.concurrency,
|
||||||
record.retrieval.rerank_enabled,
|
record.retrieval.rerank_enabled,
|
||||||
record.retrieval.rerank_pool_size,
|
record.retrieval.rerank_pool_size,
|
||||||
@@ -63,16 +63,14 @@ pub fn print_console_summary(record: &EvaluationReport) {
|
|||||||
perf.ingestion_ms,
|
perf.ingestion_ms,
|
||||||
format_duration(perf.namespace_seed_ms),
|
format_duration(perf.namespace_seed_ms),
|
||||||
);
|
);
|
||||||
let stage = &perf.stage_latency;
|
let stage_summary = perf
|
||||||
println!(
|
.stage_latency
|
||||||
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
|
.stages
|
||||||
stage.embed.avg,
|
.iter()
|
||||||
stage.collect_candidates.avg,
|
.map(|s| format!("{} {:.1}", s.stage, s.stats.avg))
|
||||||
stage.graph_expansion.avg,
|
.collect::<Vec<_>>()
|
||||||
stage.chunk_attach.avg,
|
.join(" | ");
|
||||||
stage.rerank.avg,
|
println!("[perf] stage avg ms → {stage_summary}");
|
||||||
stage.assemble.avg,
|
|
||||||
);
|
|
||||||
let eval = &perf.evaluation_stages_ms;
|
let eval = &perf.evaluation_stages_ms;
|
||||||
println!(
|
println!(
|
||||||
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
|
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
|
||||||
@@ -107,12 +105,13 @@ mod tests {
|
|||||||
|
|
||||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
||||||
crate::eval::StageLatencyBreakdown {
|
crate::eval::StageLatencyBreakdown {
|
||||||
embed: sample_latency(),
|
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
|
||||||
collect_candidates: sample_latency(),
|
.into_iter()
|
||||||
graph_expansion: sample_latency(),
|
.map(|stage| crate::eval::StageLatency {
|
||||||
chunk_attach: sample_latency(),
|
stage: stage.to_string(),
|
||||||
rerank: sample_latency(),
|
stats: sample_latency(),
|
||||||
assemble: sample_latency(),
|
})
|
||||||
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +192,7 @@ mod tests {
|
|||||||
rerank_keep_top: 10,
|
rerank_keep_top: 10,
|
||||||
concurrency: 2,
|
concurrency: 2,
|
||||||
detailed_report: false,
|
detailed_report: false,
|
||||||
retrieval_strategy: "initial".into(),
|
resolve_entities: false,
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
chunk_rrf_k: 60.0,
|
chunk_rrf_k: 60.0,
|
||||||
chunk_rrf_vector_weight: 1.0,
|
chunk_rrf_vector_weight: 1.0,
|
||||||
@@ -206,7 +205,6 @@ mod tests {
|
|||||||
ingest_chunk_overlap_tokens: 50,
|
ingest_chunk_overlap_tokens: 50,
|
||||||
chunk_vector_take: 20,
|
chunk_vector_take: 20,
|
||||||
chunk_fts_take: 20,
|
chunk_fts_take: 20,
|
||||||
chunk_avg_chars_per_token: 4,
|
|
||||||
max_chunks_per_entity: 4,
|
max_chunks_per_entity: 4,
|
||||||
cases: Vec::new(),
|
cases: Vec::new(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use futures::stream::{self, StreamExt};
|
|||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use crate::eval::{
|
use crate::eval::{
|
||||||
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||||
CaseSummary, RetrievedSummary,
|
CaseSummary, RetrievedSummary,
|
||||||
};
|
};
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
@@ -51,14 +51,8 @@ pub(crate) async fn run_queries(
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut retrieval_config = RetrievalConfig {
|
let mut retrieval_config = RetrievalConfig::default();
|
||||||
strategy: config.retrieval.strategy,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
|
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
|
||||||
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
|
|
||||||
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
|
|
||||||
}
|
|
||||||
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
|
retrieval_config.tuning.chunk_result_cap = config.retrieval.chunk_result_cap.max(1);
|
||||||
if let Some(value) = config.retrieval.chunk_vector_take {
|
if let Some(value) = config.retrieval.chunk_vector_take {
|
||||||
retrieval_config.tuning.chunk_vector_take = value;
|
retrieval_config.tuning.chunk_vector_take = value;
|
||||||
@@ -81,9 +75,6 @@ pub(crate) async fn run_queries(
|
|||||||
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
|
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
|
||||||
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
|
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
|
||||||
}
|
}
|
||||||
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
|
|
||||||
retrieval_config.tuning.avg_chars_per_token = value;
|
|
||||||
}
|
|
||||||
if let Some(value) = config.retrieval.max_chunks_per_entity {
|
if let Some(value) = config.retrieval.max_chunks_per_entity {
|
||||||
retrieval_config.tuning.max_chunks_per_entity = value;
|
retrieval_config.tuning.max_chunks_per_entity = value;
|
||||||
}
|
}
|
||||||
@@ -187,7 +178,7 @@ pub(crate) async fn run_queries(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = pipeline::StrategyParams {
|
let params = pipeline::RetrievalParams {
|
||||||
db_client: &db,
|
db_client: &db,
|
||||||
openai_client: &openai_client,
|
openai_client: &openai_client,
|
||||||
embedding_provider: Some(&embedding_provider),
|
embedding_provider: Some(&embedding_provider),
|
||||||
@@ -196,26 +187,19 @@ pub(crate) async fn run_queries(
|
|||||||
config: (*retrieval_config).clone(),
|
config: (*retrieval_config).clone(),
|
||||||
reranker,
|
reranker,
|
||||||
};
|
};
|
||||||
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
|
let (result_output, pipeline_diagnostics, stage_timings) = {
|
||||||
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
|
let outcome = pipeline::run_with_embedding_instrumented(
|
||||||
params,
|
params,
|
||||||
query_embedding,
|
query_embedding,
|
||||||
|
diagnostics_enabled,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("running pipeline for question {question_id}"))?;
|
.with_context(|| format!("running pipeline for question {question_id}"))?;
|
||||||
(outcome.results, outcome.diagnostics, outcome.stage_timings)
|
(outcome.results, outcome.diagnostics, outcome.stage_timings)
|
||||||
} else {
|
|
||||||
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
|
|
||||||
params,
|
|
||||||
query_embedding,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("running pipeline for question {question_id}"))?;
|
|
||||||
(outcome.results, None, outcome.stage_timings)
|
|
||||||
};
|
};
|
||||||
let query_latency = query_start.elapsed().as_millis();
|
let query_latency = query_start.elapsed().as_millis();
|
||||||
|
|
||||||
let candidates = adapt_strategy_output(result_output);
|
let candidates = adapt_retrieval_output(result_output);
|
||||||
let mut retrieved = Vec::new();
|
let mut retrieved = Vec::new();
|
||||||
let mut match_rank = None;
|
let mut match_rank = None;
|
||||||
let answers_lower: Vec<String> =
|
let answers_lower: Vec<String> =
|
||||||
|
|||||||
@@ -201,7 +201,10 @@ pub(crate) async fn summarize(
|
|||||||
rerank_keep_top: config.retrieval.rerank_keep_top,
|
rerank_keep_top: config.retrieval.rerank_keep_top,
|
||||||
concurrency: config.concurrency.max(1),
|
concurrency: config.concurrency.max(1),
|
||||||
detailed_report: config.detailed_report,
|
detailed_report: config.detailed_report,
|
||||||
retrieval_strategy: config.retrieval.strategy.to_string(),
|
resolve_entities: ctx
|
||||||
|
.retrieval_config
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|config| config.resolve_entities),
|
||||||
chunk_result_cap: config.retrieval.chunk_result_cap,
|
chunk_result_cap: config.retrieval.chunk_result_cap,
|
||||||
chunk_rrf_k: active_tuning.chunk_rrf_k,
|
chunk_rrf_k: active_tuning.chunk_rrf_k,
|
||||||
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
|
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
|
||||||
@@ -214,7 +217,6 @@ pub(crate) async fn summarize(
|
|||||||
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||||
chunk_avg_chars_per_token: active_tuning.avg_chars_per_token,
|
|
||||||
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
||||||
cases: summaries,
|
cases: summaries,
|
||||||
});
|
});
|
||||||
|
|||||||
+54
-49
@@ -85,7 +85,7 @@ pub struct RetrievalSection {
|
|||||||
pub average_ndcg: f64,
|
pub average_ndcg: f64,
|
||||||
pub latency: LatencyStats,
|
pub latency: LatencyStats,
|
||||||
pub concurrency: usize,
|
pub concurrency: usize,
|
||||||
pub strategy: String,
|
pub resolve_entities: bool,
|
||||||
pub rerank_enabled: bool,
|
pub rerank_enabled: bool,
|
||||||
pub rerank_pool_size: Option<usize>,
|
pub rerank_pool_size: Option<usize>,
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
@@ -226,7 +226,7 @@ impl EvaluationReport {
|
|||||||
average_ndcg: summary.average_ndcg,
|
average_ndcg: summary.average_ndcg,
|
||||||
latency: summary.latency_ms.clone(),
|
latency: summary.latency_ms.clone(),
|
||||||
concurrency: summary.concurrency,
|
concurrency: summary.concurrency,
|
||||||
strategy: summary.retrieval_strategy.clone(),
|
resolve_entities: summary.resolve_entities,
|
||||||
rerank_enabled: summary.rerank_enabled,
|
rerank_enabled: summary.rerank_enabled,
|
||||||
rerank_pool_size: summary.rerank_pool_size,
|
rerank_pool_size: summary.rerank_pool_size,
|
||||||
rerank_keep_top: summary.rerank_keep_top,
|
rerank_keep_top: summary.rerank_keep_top,
|
||||||
@@ -463,7 +463,12 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
|||||||
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
|
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
|
||||||
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
|
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
|
||||||
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap();
|
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap();
|
||||||
write!(md, "| Strategy | `{}` |\\n", report.retrieval.strategy).unwrap();
|
write!(
|
||||||
|
md,
|
||||||
|
"| Resolve entities | {} |\\n",
|
||||||
|
bool_badge(report.retrieval.resolve_entities)
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap();
|
write!(md, "| Concurrency | {} |\\n", report.retrieval.concurrency).unwrap();
|
||||||
if report.retrieval.rerank_enabled {
|
if report.retrieval.rerank_enabled {
|
||||||
let pool = report
|
let pool = report
|
||||||
@@ -510,28 +515,9 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
|||||||
|
|
||||||
md.push_str("\\n## Retrieval Stage Timings\\n\\n");
|
md.push_str("\\n## Retrieval Stage Timings\\n\\n");
|
||||||
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
|
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
|
||||||
write_stage_row(&mut md, "Embed", &report.performance.stage_latency.embed);
|
for stage in &report.performance.stage_latency.stages {
|
||||||
write_stage_row(
|
write_stage_row(&mut md, &prettify_stage(&stage.stage), &stage.stats);
|
||||||
&mut md,
|
}
|
||||||
"Collect Candidates",
|
|
||||||
&report.performance.stage_latency.collect_candidates,
|
|
||||||
);
|
|
||||||
write_stage_row(
|
|
||||||
&mut md,
|
|
||||||
"Graph Expansion",
|
|
||||||
&report.performance.stage_latency.graph_expansion,
|
|
||||||
);
|
|
||||||
write_stage_row(
|
|
||||||
&mut md,
|
|
||||||
"Chunk Attach",
|
|
||||||
&report.performance.stage_latency.chunk_attach,
|
|
||||||
);
|
|
||||||
write_stage_row(&mut md, "Rerank", &report.performance.stage_latency.rerank);
|
|
||||||
write_stage_row(
|
|
||||||
&mut md,
|
|
||||||
"Assemble",
|
|
||||||
&report.performance.stage_latency.assemble,
|
|
||||||
);
|
|
||||||
|
|
||||||
if report.misses.is_empty() {
|
if report.misses.is_empty() {
|
||||||
if report.detailed_report {
|
if report.detailed_report {
|
||||||
@@ -623,6 +609,20 @@ fn write_stage_row(buf: &mut String, label: &str, stats: &LatencyStats) {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Turn a stable stage label (e.g. `resolve_entities`) into a display title (`Resolve Entities`).
|
||||||
|
fn prettify_stage(label: &str) -> String {
|
||||||
|
label
|
||||||
|
.split('_')
|
||||||
|
.map(|word| {
|
||||||
|
let mut chars = word.chars();
|
||||||
|
chars.next().map_or_else(String::new, |first| {
|
||||||
|
first.to_uppercase().collect::<String>() + chars.as_str()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ")
|
||||||
|
}
|
||||||
|
|
||||||
fn bool_badge(value: bool) -> &'static str {
|
fn bool_badge(value: bool) -> &'static str {
|
||||||
if value {
|
if value {
|
||||||
"✅"
|
"✅"
|
||||||
@@ -740,17 +740,6 @@ struct LegacyHistoryDelta {
|
|||||||
latency_avg_ms: f64,
|
latency_avg_ms: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_stage_latency() -> StageLatencyBreakdown {
|
|
||||||
StageLatencyBreakdown {
|
|
||||||
embed: LatencyStats::default(),
|
|
||||||
collect_candidates: LatencyStats::default(),
|
|
||||||
graph_expansion: LatencyStats::default(),
|
|
||||||
chunk_attach: LatencyStats::default(),
|
|
||||||
rerank: LatencyStats::default(),
|
|
||||||
assemble: LatencyStats::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
||||||
let overview = OverviewSection {
|
let overview = OverviewSection {
|
||||||
@@ -807,7 +796,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
|||||||
average_ndcg: entry.average_ndcg,
|
average_ndcg: entry.average_ndcg,
|
||||||
latency: entry.latency_ms,
|
latency: entry.latency_ms,
|
||||||
concurrency: 0,
|
concurrency: 0,
|
||||||
strategy: "unknown".into(),
|
resolve_entities: false,
|
||||||
rerank_enabled: entry.rerank_enabled,
|
rerank_enabled: entry.rerank_enabled,
|
||||||
rerank_pool_size: entry.rerank_pool_size,
|
rerank_pool_size: entry.rerank_pool_size,
|
||||||
rerank_keep_top: entry.rerank_keep_top,
|
rerank_keep_top: entry.rerank_keep_top,
|
||||||
@@ -840,7 +829,7 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
|||||||
ingestion_ms: entry.ingestion_ms,
|
ingestion_ms: entry.ingestion_ms,
|
||||||
namespace_seed_ms: entry.namespace_seed_ms,
|
namespace_seed_ms: entry.namespace_seed_ms,
|
||||||
evaluation_stages_ms: EvaluationStageTimings::default(),
|
evaluation_stages_ms: EvaluationStageTimings::default(),
|
||||||
stage_latency: default_stage_latency(),
|
stage_latency: StageLatencyBreakdown::default(),
|
||||||
namespace_reused: false,
|
namespace_reused: false,
|
||||||
ingestion_reused: entry.ingestion_reused,
|
ingestion_reused: entry.ingestion_reused,
|
||||||
embeddings_reused: entry.ingestion_embeddings_reused,
|
embeddings_reused: entry.ingestion_embeddings_reused,
|
||||||
@@ -915,7 +904,8 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::eval::{
|
use crate::eval::{
|
||||||
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatencyBreakdown,
|
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
|
||||||
|
StageLatencyBreakdown,
|
||||||
};
|
};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
@@ -931,12 +921,28 @@ mod tests {
|
|||||||
|
|
||||||
fn sample_stage_latency() -> StageLatencyBreakdown {
|
fn sample_stage_latency() -> StageLatencyBreakdown {
|
||||||
StageLatencyBreakdown {
|
StageLatencyBreakdown {
|
||||||
embed: latency(9.0),
|
stages: vec![
|
||||||
collect_candidates: latency(10.0),
|
StageLatency {
|
||||||
graph_expansion: latency(11.0),
|
stage: "embed".to_string(),
|
||||||
chunk_attach: latency(12.0),
|
stats: latency(9.0),
|
||||||
rerank: latency(13.0),
|
},
|
||||||
assemble: latency(14.0),
|
StageLatency {
|
||||||
|
stage: "search".to_string(),
|
||||||
|
stats: latency(10.0),
|
||||||
|
},
|
||||||
|
StageLatency {
|
||||||
|
stage: "rerank".to_string(),
|
||||||
|
stats: latency(13.0),
|
||||||
|
},
|
||||||
|
StageLatency {
|
||||||
|
stage: "resolve_entities".to_string(),
|
||||||
|
stats: latency(11.0),
|
||||||
|
},
|
||||||
|
StageLatency {
|
||||||
|
stage: "assemble".to_string(),
|
||||||
|
stats: latency(14.0),
|
||||||
|
},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1058,7 +1064,7 @@ mod tests {
|
|||||||
rerank_keep_top: 5,
|
rerank_keep_top: 5,
|
||||||
concurrency: 2,
|
concurrency: 2,
|
||||||
detailed_report: true,
|
detailed_report: true,
|
||||||
retrieval_strategy: "initial".into(),
|
resolve_entities: false,
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
chunk_rrf_k: 60.0,
|
chunk_rrf_k: 60.0,
|
||||||
chunk_rrf_vector_weight: 1.0,
|
chunk_rrf_vector_weight: 1.0,
|
||||||
@@ -1071,7 +1077,6 @@ mod tests {
|
|||||||
ingest_chunks_only: false,
|
ingest_chunks_only: false,
|
||||||
chunk_vector_take: 50,
|
chunk_vector_take: 50,
|
||||||
chunk_fts_take: 50,
|
chunk_fts_take: 50,
|
||||||
chunk_avg_chars_per_token: 4,
|
|
||||||
max_chunks_per_entity: 4,
|
max_chunks_per_entity: 4,
|
||||||
cases,
|
cases,
|
||||||
}
|
}
|
||||||
@@ -1097,7 +1102,7 @@ mod tests {
|
|||||||
|
|
||||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||||
#[test]
|
#[test]
|
||||||
fn evaluations_history_captures_strategy_and_concurrency() {
|
fn evaluations_history_captures_resolve_entities_and_concurrency() {
|
||||||
let tmp = tempdir().unwrap();
|
let tmp = tempdir().unwrap();
|
||||||
let summary = sample_summary(false);
|
let summary = sample_summary(false);
|
||||||
|
|
||||||
@@ -1109,7 +1114,7 @@ mod tests {
|
|||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
let stored = &entries[0];
|
let stored = &entries[0];
|
||||||
assert_eq!(stored.retrieval.concurrency, summary.concurrency);
|
assert_eq!(stored.retrieval.concurrency, summary.concurrency);
|
||||||
assert_eq!(stored.retrieval.strategy, summary.retrieval_strategy);
|
assert_eq!(stored.retrieval.resolve_entities, summary.resolve_entities);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
stored.performance.evaluation_stages_ms.run_queries_ms,
|
stored.performance.evaluation_stages_ms.run_queries_ms,
|
||||||
summary.perf.evaluation_stage_ms.run_queries_ms
|
summary.perf.evaluation_stage_ms.run_queries_ms
|
||||||
|
|||||||
+33
-38
@@ -3,7 +3,7 @@ use std::collections::HashSet;
|
|||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::StoredObject;
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
|
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use unicode_normalization::UnicodeNormalization;
|
use unicode_normalization::UnicodeNormalization;
|
||||||
@@ -69,7 +69,7 @@ pub struct EvaluationSummary {
|
|||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
pub concurrency: usize,
|
pub concurrency: usize,
|
||||||
pub detailed_report: bool,
|
pub detailed_report: bool,
|
||||||
pub retrieval_strategy: String,
|
pub resolve_entities: bool,
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
pub chunk_rrf_k: f32,
|
pub chunk_rrf_k: f32,
|
||||||
pub chunk_rrf_vector_weight: f32,
|
pub chunk_rrf_vector_weight: f32,
|
||||||
@@ -82,7 +82,6 @@ pub struct EvaluationSummary {
|
|||||||
pub ingest_chunk_overlap_tokens: usize,
|
pub ingest_chunk_overlap_tokens: usize,
|
||||||
pub chunk_vector_take: usize,
|
pub chunk_vector_take: usize,
|
||||||
pub chunk_fts_take: usize,
|
pub chunk_fts_take: usize,
|
||||||
pub chunk_avg_chars_per_token: usize,
|
|
||||||
pub max_chunks_per_entity: usize,
|
pub max_chunks_per_entity: usize,
|
||||||
pub cases: Vec<CaseSummary>,
|
pub cases: Vec<CaseSummary>,
|
||||||
}
|
}
|
||||||
@@ -129,14 +128,20 @@ impl Default for LatencyStats {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Latency statistics for a single retrieval stage, keyed by the stage's stable label.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StageLatency {
|
||||||
|
pub stage: String,
|
||||||
|
pub stats: LatencyStats,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-stage retrieval latency, in canonical pipeline order.
|
||||||
|
///
|
||||||
|
/// The set of stages is driven entirely by [`StageKind::ALL`], so adding a retrieval stage
|
||||||
|
/// surfaces here automatically without changes to the evaluation harness.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
pub struct StageLatencyBreakdown {
|
pub struct StageLatencyBreakdown {
|
||||||
pub embed: LatencyStats,
|
pub stages: Vec<StageLatency>,
|
||||||
pub collect_candidates: LatencyStats,
|
|
||||||
pub graph_expansion: LatencyStats,
|
|
||||||
pub chunk_attach: LatencyStats,
|
|
||||||
pub rerank: LatencyStats,
|
|
||||||
pub assemble: LatencyStats,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::struct_field_names)]
|
#[allow(clippy::struct_field_names)]
|
||||||
@@ -232,13 +237,12 @@ fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidat
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
|
pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidate> {
|
||||||
match output {
|
match output {
|
||||||
StrategyOutput::Entities(entities) => candidates_from_entities(entities),
|
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
|
||||||
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
|
RetrievalOutput::WithEntities { chunks, entities } => {
|
||||||
StrategyOutput::Search(search_result) => {
|
let mut candidates = candidates_from_entities(entities);
|
||||||
let mut candidates = candidates_from_entities(search_result.entities);
|
candidates.extend(candidates_from_chunks(chunks));
|
||||||
candidates.extend(candidates_from_chunks(search_result.chunks));
|
|
||||||
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
|
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||||
candidates
|
candidates
|
||||||
}
|
}
|
||||||
@@ -262,7 +266,7 @@ pub struct CaseDiagnostics {
|
|||||||
pub attached_chunk_ids: Vec<String>,
|
pub attached_chunk_ids: Vec<String>,
|
||||||
pub retrieved: Vec<EntityDiagnostics>,
|
pub retrieved: Vec<EntityDiagnostics>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub pipeline: Option<PipelineDiagnostics>,
|
pub pipeline: Option<Diagnostics>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -366,28 +370,19 @@ pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
|
|||||||
LatencyStats { avg, p50, p95 }
|
LatencyStats { avg, p50, p95 }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
|
pub fn build_stage_latency_breakdown(samples: &[StageTimings]) -> StageLatencyBreakdown {
|
||||||
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
|
let stages = StageKind::ALL
|
||||||
where
|
.iter()
|
||||||
F: Fn(&PipelineStageTimings) -> u128,
|
.map(|kind| {
|
||||||
{
|
let latencies: Vec<u128> = samples.iter().map(|s| s.stage_ms(*kind)).collect();
|
||||||
samples.iter().map(selector).collect()
|
StageLatency {
|
||||||
}
|
stage: kind.label().to_string(),
|
||||||
|
stats: compute_latency_stats(&latencies),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
StageLatencyBreakdown {
|
StageLatencyBreakdown { stages }
|
||||||
embed: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::embed_ms)),
|
|
||||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
|
||||||
entry.collect_candidates_ms()
|
|
||||||
})),
|
|
||||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
|
||||||
entry.graph_expansion_ms()
|
|
||||||
})),
|
|
||||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
|
|
||||||
entry.chunk_attach_ms()
|
|
||||||
})),
|
|
||||||
rerank: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::rerank_ms)),
|
|
||||||
assemble: compute_latency_stats(&collect_stage(samples, retrieval_pipeline::StageTimings::assemble_ms)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(
|
#[allow(
|
||||||
@@ -412,7 +407,7 @@ pub fn build_case_diagnostics(
|
|||||||
expected_chunk_ids: &[String],
|
expected_chunk_ids: &[String],
|
||||||
answers_lower: &[String],
|
answers_lower: &[String],
|
||||||
candidates: &[EvaluationCandidate],
|
candidates: &[EvaluationCandidate],
|
||||||
pipeline_stats: Option<PipelineDiagnostics>,
|
pipeline_stats: Option<Diagnostics>,
|
||||||
) -> CaseDiagnostics {
|
) -> CaseDiagnostics {
|
||||||
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect();
|
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect();
|
||||||
let mut seen_chunks: HashSet<String> = HashSet::new();
|
let mut seen_chunks: HashSet<String> = HashSet::new();
|
||||||
|
|||||||
@@ -44,7 +44,6 @@
|
|||||||
--leading-snug: 1.375;
|
--leading-snug: 1.375;
|
||||||
--leading-relaxed: 1.625;
|
--leading-relaxed: 1.625;
|
||||||
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
||||||
--ease-in-out: cubic-bezier(0.4, 0, 0.2, 1);
|
|
||||||
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||||
--default-transition-duration: 150ms;
|
--default-transition-duration: 150ms;
|
||||||
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
||||||
@@ -285,37 +284,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.drawer-open {
|
|
||||||
> .drawer-side {
|
|
||||||
overflow-y: auto;
|
|
||||||
}
|
|
||||||
> .drawer-toggle {
|
|
||||||
display: none;
|
|
||||||
& ~ .drawer-side {
|
|
||||||
pointer-events: auto;
|
|
||||||
visibility: visible;
|
|
||||||
position: sticky;
|
|
||||||
display: block;
|
|
||||||
width: auto;
|
|
||||||
overscroll-behavior: auto;
|
|
||||||
opacity: 100%;
|
|
||||||
& > .drawer-overlay {
|
|
||||||
cursor: default;
|
|
||||||
background-color: transparent;
|
|
||||||
}
|
|
||||||
& > *:not(.drawer-overlay) {
|
|
||||||
translate: 0%;
|
|
||||||
[dir="rtl"] & {
|
|
||||||
translate: 0%;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
&:checked ~ .drawer-side {
|
|
||||||
pointer-events: auto;
|
|
||||||
visibility: visible;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.drawer-toggle {
|
.drawer-toggle {
|
||||||
position: fixed;
|
position: fixed;
|
||||||
height: calc(0.25rem * 0);
|
height: calc(0.25rem * 0);
|
||||||
@@ -1074,22 +1042,6 @@
|
|||||||
grid-row-start: 1;
|
grid-row-start: 1;
|
||||||
min-width: calc(0.25rem * 0);
|
min-width: calc(0.25rem * 0);
|
||||||
}
|
}
|
||||||
.chat-image {
|
|
||||||
grid-row: span 2 / span 2;
|
|
||||||
align-self: flex-end;
|
|
||||||
}
|
|
||||||
.chat-footer {
|
|
||||||
grid-row-start: 3;
|
|
||||||
display: flex;
|
|
||||||
gap: calc(0.25rem * 1);
|
|
||||||
font-size: 0.6875rem;
|
|
||||||
}
|
|
||||||
.chat-header {
|
|
||||||
grid-row-start: 1;
|
|
||||||
display: flex;
|
|
||||||
gap: calc(0.25rem * 1);
|
|
||||||
font-size: 0.6875rem;
|
|
||||||
}
|
|
||||||
.container {
|
.container {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
@media (width >= 40rem) {
|
@media (width >= 40rem) {
|
||||||
@@ -1796,9 +1748,6 @@
|
|||||||
.w-10 {
|
.w-10 {
|
||||||
width: calc(var(--spacing) * 10);
|
width: calc(var(--spacing) * 10);
|
||||||
}
|
}
|
||||||
.w-11 {
|
|
||||||
width: calc(var(--spacing) * 11);
|
|
||||||
}
|
|
||||||
.w-11\/12 {
|
.w-11\/12 {
|
||||||
width: calc(11/12 * 100%);
|
width: calc(11/12 * 100%);
|
||||||
}
|
}
|
||||||
@@ -1862,9 +1811,6 @@
|
|||||||
.flex-none {
|
.flex-none {
|
||||||
flex: none;
|
flex: none;
|
||||||
}
|
}
|
||||||
.flex-shrink {
|
|
||||||
flex-shrink: 1;
|
|
||||||
}
|
|
||||||
.flex-shrink-0 {
|
.flex-shrink-0 {
|
||||||
flex-shrink: 0;
|
flex-shrink: 0;
|
||||||
}
|
}
|
||||||
@@ -1877,13 +1823,6 @@
|
|||||||
.grow {
|
.grow {
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
}
|
}
|
||||||
.border-collapse {
|
|
||||||
border-collapse: collapse;
|
|
||||||
}
|
|
||||||
.-translate-y-1 {
|
|
||||||
--tw-translate-y: calc(var(--spacing) * -1);
|
|
||||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
|
||||||
}
|
|
||||||
.-translate-y-1\/2 {
|
.-translate-y-1\/2 {
|
||||||
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
|
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
|
||||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||||
@@ -1956,9 +1895,6 @@
|
|||||||
.justify-start {
|
.justify-start {
|
||||||
justify-content: flex-start;
|
justify-content: flex-start;
|
||||||
}
|
}
|
||||||
.gap-0 {
|
|
||||||
gap: calc(var(--spacing) * 0);
|
|
||||||
}
|
|
||||||
.gap-0\.5 {
|
.gap-0\.5 {
|
||||||
gap: calc(var(--spacing) * 0.5);
|
gap: calc(var(--spacing) * 0.5);
|
||||||
}
|
}
|
||||||
@@ -2115,9 +2051,6 @@
|
|||||||
.bg-transparent {
|
.bg-transparent {
|
||||||
background-color: transparent;
|
background-color: transparent;
|
||||||
}
|
}
|
||||||
.bg-warning {
|
|
||||||
background-color: var(--color-warning);
|
|
||||||
}
|
|
||||||
.bg-warning\/10 {
|
.bg-warning\/10 {
|
||||||
background-color: var(--color-warning);
|
background-color: var(--color-warning);
|
||||||
@supports (color: color-mix(in lab, red, red)) {
|
@supports (color: color-mix(in lab, red, red)) {
|
||||||
@@ -2136,9 +2069,6 @@
|
|||||||
.loading-spinner {
|
.loading-spinner {
|
||||||
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
|
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
|
||||||
}
|
}
|
||||||
.mask-repeat {
|
|
||||||
mask-repeat: repeat;
|
|
||||||
}
|
|
||||||
.fill-current {
|
.fill-current {
|
||||||
fill: currentcolor;
|
fill: currentcolor;
|
||||||
}
|
}
|
||||||
@@ -2169,9 +2099,6 @@
|
|||||||
.p-8 {
|
.p-8 {
|
||||||
padding: calc(var(--spacing) * 8);
|
padding: calc(var(--spacing) * 8);
|
||||||
}
|
}
|
||||||
.px-1 {
|
|
||||||
padding-inline: calc(var(--spacing) * 1);
|
|
||||||
}
|
|
||||||
.px-1\.5 {
|
.px-1\.5 {
|
||||||
padding-inline: calc(var(--spacing) * 1.5);
|
padding-inline: calc(var(--spacing) * 1.5);
|
||||||
}
|
}
|
||||||
@@ -2326,9 +2253,6 @@
|
|||||||
--tw-tracking: var(--tracking-widest);
|
--tw-tracking: var(--tracking-widest);
|
||||||
letter-spacing: var(--tracking-widest);
|
letter-spacing: var(--tracking-widest);
|
||||||
}
|
}
|
||||||
.text-wrap {
|
|
||||||
text-wrap: wrap;
|
|
||||||
}
|
|
||||||
.break-words {
|
.break-words {
|
||||||
overflow-wrap: break-word;
|
overflow-wrap: break-word;
|
||||||
}
|
}
|
||||||
@@ -2395,17 +2319,6 @@
|
|||||||
.italic {
|
.italic {
|
||||||
font-style: italic;
|
font-style: italic;
|
||||||
}
|
}
|
||||||
.underline {
|
|
||||||
text-decoration-line: underline;
|
|
||||||
}
|
|
||||||
.swap-active {
|
|
||||||
.swap-off {
|
|
||||||
opacity: 0%;
|
|
||||||
}
|
|
||||||
.swap-on {
|
|
||||||
opacity: 100%;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.opacity-0 {
|
.opacity-0 {
|
||||||
opacity: 0%;
|
opacity: 0%;
|
||||||
}
|
}
|
||||||
@@ -2496,10 +2409,6 @@
|
|||||||
--tw-duration: 300ms;
|
--tw-duration: 300ms;
|
||||||
transition-duration: 300ms;
|
transition-duration: 300ms;
|
||||||
}
|
}
|
||||||
.ease-in-out {
|
|
||||||
--tw-ease: var(--ease-in-out);
|
|
||||||
transition-timing-function: var(--ease-in-out);
|
|
||||||
}
|
|
||||||
.ease-out {
|
.ease-out {
|
||||||
--tw-ease: var(--ease-out);
|
--tw-ease: var(--ease-out);
|
||||||
transition-timing-function: var(--ease-out);
|
transition-timing-function: var(--ease-out);
|
||||||
|
|||||||
@@ -2,10 +2,7 @@ use common::storage::types::conversation::SidebarConversation;
|
|||||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||||
use common::utils::embedding::EmbeddingProvider;
|
use common::utils::embedding::EmbeddingProvider;
|
||||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||||
use common::{
|
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||||
create_template_engine, storage::db::ProvidesDb,
|
|
||||||
utils::config::{AppConfig, RetrievalStrategy},
|
|
||||||
};
|
|
||||||
use retrieval_pipeline::reranking::RerankerPool;
|
use retrieval_pipeline::reranking::RerankerPool;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
@@ -75,10 +72,6 @@ impl HtmlState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
|
|
||||||
self.config.resolved_retrieval_strategy()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_cached_conversation_archive(
|
pub async fn get_cached_conversation_archive(
|
||||||
&self,
|
&self,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
|
|||||||
@@ -16,12 +16,9 @@ use futures::{
|
|||||||
};
|
};
|
||||||
use json_stream_parser::JsonStreamParser;
|
use json_stream_parser::JsonStreamParser;
|
||||||
use minijinja::Value;
|
use minijinja::Value;
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::answer_retrieval::{
|
||||||
answer_retrieval::{
|
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
|
||||||
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
|
LLMResponseFormat,
|
||||||
LLMResponseFormat,
|
|
||||||
},
|
|
||||||
retrieved_entities_to_json,
|
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
@@ -189,11 +186,7 @@ struct ReferenceData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||||
response
|
response.reference_ids()
|
||||||
.references
|
|
||||||
.iter()
|
|
||||||
.map(|reference| reference.reference.clone())
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
@@ -362,10 +355,9 @@ async fn prepare_chat_request(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let strategy = state.retrieval_strategy();
|
let config = retrieval_pipeline::RetrievalConfig::default();
|
||||||
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
|
|
||||||
|
|
||||||
let retrieval_result = match retrieval_pipeline::retrieve_entities(
|
let retrieval_result = match retrieval_pipeline::retrieve(
|
||||||
&state.db,
|
&state.db,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
Some(&*state.embedding_provider),
|
Some(&*state.embedding_provider),
|
||||||
@@ -387,12 +379,9 @@ async fn prepare_chat_request(
|
|||||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||||
|
|
||||||
let context_json = match retrieval_result {
|
let context_json = match retrieval_result {
|
||||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
retrieval_pipeline::RetrievalOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
retrieval_pipeline::RetrievalOutput::WithEntities { chunks, .. } => {
|
||||||
retrieved_entities_to_json(&entities)
|
chunks_to_chat_context(&chunks)
|
||||||
}
|
|
||||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
|
||||||
chunks_to_chat_context(&search_result.chunks)
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let formatted_user_message =
|
let formatted_user_message =
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use common::{
|
|||||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use retrieval_pipeline::StrategyOutput;
|
use retrieval_pipeline::RetrievalOutput;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
||||||
@@ -86,40 +86,29 @@ pub(crate) enum ReferenceLookupTarget {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn collect_reference_ids_from_retrieval(
|
pub(crate) fn collect_reference_ids_from_retrieval(
|
||||||
retrieval_result: &StrategyOutput,
|
retrieval_result: &RetrievalOutput,
|
||||||
) -> Vec<String> {
|
) -> Vec<String> {
|
||||||
let mut ids = Vec::new();
|
let mut ids = Vec::new();
|
||||||
let mut seen = HashSet::new();
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
|
let mut push_id = |id: String| {
|
||||||
|
if seen.insert(id.clone()) {
|
||||||
|
ids.push(id);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
match retrieval_result {
|
match retrieval_result {
|
||||||
StrategyOutput::Chunks(chunks) => {
|
RetrievalOutput::Chunks(chunks) => {
|
||||||
for chunk in chunks {
|
for chunk in chunks {
|
||||||
let id = chunk.chunk.id.clone();
|
push_id(chunk.chunk.id.clone());
|
||||||
if seen.insert(id.clone()) {
|
|
||||||
ids.push(id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StrategyOutput::Entities(entities) => {
|
RetrievalOutput::WithEntities { chunks, entities } => {
|
||||||
|
for chunk in chunks {
|
||||||
|
push_id(chunk.chunk.id.clone());
|
||||||
|
}
|
||||||
for entity in entities {
|
for entity in entities {
|
||||||
let id = entity.entity.id.clone();
|
push_id(entity.entity.id.clone());
|
||||||
if seen.insert(id.clone()) {
|
|
||||||
ids.push(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StrategyOutput::Search(search) => {
|
|
||||||
for chunk in &search.chunks {
|
|
||||||
let id = chunk.chunk.id.clone();
|
|
||||||
if seen.insert(id.clone()) {
|
|
||||||
ids.push(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for entity in &search.entities {
|
|
||||||
let id = entity.entity.id.clone();
|
|
||||||
if seen.insert(id.clone()) {
|
|
||||||
ids.push(id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::{
|
|||||||
middlewares::{
|
middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
template_as_response, TemplateResponse, TemplateResult, ResponseResult,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
utils::text_content_preview::truncate_text_contents,
|
utils::text_content_preview::truncate_text_contents,
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ use crate::{
|
|||||||
middlewares::{
|
middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
utils::pagination::{paginate_items, Pagination},
|
utils::pagination::{paginate_items, Pagination},
|
||||||
@@ -284,9 +284,9 @@ pub async fn suggest_knowledge_relationships(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion();
|
let config = retrieval_pipeline::RetrievalConfig::with_entities();
|
||||||
if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) =
|
if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) =
|
||||||
retrieval_pipeline::retrieve_entities(
|
retrieval_pipeline::retrieve(
|
||||||
&state.db,
|
&state.db,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
Some(&*state.embedding_provider),
|
Some(&*state.embedding_provider),
|
||||||
@@ -297,7 +297,7 @@ pub async fn suggest_knowledge_relationships(
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results {
|
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities {
|
||||||
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use crate::html_state::HtmlState;
|
|||||||
use crate::middlewares::{
|
use crate::middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use common::storage::types::{
|
use common::storage::types::{
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use axum::{
|
|||||||
extract::{Query, State},
|
extract::{Query, State},
|
||||||
};
|
};
|
||||||
use common::storage::types::{text_content::TextContent, user::User};
|
use common::storage::types::{text_content::TextContent, user::User};
|
||||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity};
|
||||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||||
use std::{fmt, str::FromStr};
|
use std::{fmt, str::FromStr};
|
||||||
|
|
||||||
@@ -108,35 +108,35 @@ async fn perform_search(
|
|||||||
return Ok((Vec::new(), String::new()));
|
return Ok((Vec::new(), String::new()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let config = RetrievalConfig::for_search(SearchTarget::Both);
|
let config = RetrievalConfig::with_entities();
|
||||||
|
|
||||||
let reranker_lease = match &state.reranker_pool {
|
let reranker_lease = match &state.reranker_pool {
|
||||||
Some(pool) => pool.checkout().await,
|
Some(pool) => pool.checkout().await,
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = retrieval_pipeline::pipeline::StrategyParams {
|
let result = retrieve(
|
||||||
db_client: &state.db,
|
&state.db,
|
||||||
openai_client: &state.openai_client,
|
&state.openai_client,
|
||||||
embedding_provider: Some(&state.embedding_provider),
|
Some(&state.embedding_provider),
|
||||||
input_text: trimmed_query,
|
trimmed_query,
|
||||||
user_id: &user.id,
|
&user.id,
|
||||||
config,
|
config,
|
||||||
reranker: reranker_lease,
|
reranker_lease,
|
||||||
};
|
)
|
||||||
let result = retrieval_pipeline::pipeline::execute(params).await?;
|
.await?;
|
||||||
|
|
||||||
let search_result = match result {
|
let (chunks, entities) = match result {
|
||||||
StrategyOutput::Search(sr) => sr,
|
RetrievalOutput::WithEntities { chunks, entities } => (chunks, entities),
|
||||||
_ => SearchResult::new(vec![], vec![]),
|
RetrievalOutput::Chunks(chunks) => (chunks, Vec::new()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let source_label_map = collect_source_label_map(state, user, &search_result).await?;
|
let source_label_map = collect_source_label_map(state, user, &chunks, &entities).await?;
|
||||||
|
|
||||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||||
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
|
Vec::with_capacity(chunks.len().saturating_add(entities.len()));
|
||||||
|
|
||||||
for chunk_result in search_result.chunks {
|
for chunk_result in chunks {
|
||||||
let source_label = source_label_map
|
let source_label = source_label_map
|
||||||
.get(&chunk_result.chunk.source_id)
|
.get(&chunk_result.chunk.source_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
@@ -155,7 +155,7 @@ async fn perform_search(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for entity_result in search_result.entities {
|
for entity_result in entities {
|
||||||
let source_label = source_label_map
|
let source_label = source_label_map
|
||||||
.get(&entity_result.entity.source_id)
|
.get(&entity_result.entity.source_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
@@ -187,13 +187,14 @@ async fn perform_search(
|
|||||||
async fn collect_source_label_map(
|
async fn collect_source_label_map(
|
||||||
state: &HtmlState,
|
state: &HtmlState,
|
||||||
user: &User,
|
user: &User,
|
||||||
search_result: &SearchResult,
|
chunks: &[RetrievedChunk],
|
||||||
|
entities: &[RetrievedEntity],
|
||||||
) -> Result<std::collections::HashMap<String, String>, HtmlError> {
|
) -> Result<std::collections::HashMap<String, String>, HtmlError> {
|
||||||
let mut source_ids = HashSet::new();
|
let mut source_ids = HashSet::new();
|
||||||
for chunk_result in &search_result.chunks {
|
for chunk_result in chunks {
|
||||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||||
}
|
}
|
||||||
for entity_result in &search_result.entities {
|
for entity_result in entities {
|
||||||
source_ids.insert(entity_result.entity.source_id.clone());
|
source_ids.insert(entity_result.entity.source_id.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -183,10 +183,8 @@ impl PipelineServices for DefaultPipelineServices {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = retrieval_pipeline::RetrievalConfig::for_search(
|
let config = retrieval_pipeline::RetrievalConfig::with_entities();
|
||||||
retrieval_pipeline::SearchTarget::EntitiesOnly,
|
match retrieval_pipeline::retrieve(
|
||||||
);
|
|
||||||
match retrieval_pipeline::retrieve_entities(
|
|
||||||
&self.db,
|
&self.db,
|
||||||
&self.openai_client,
|
&self.openai_client,
|
||||||
Some(&*self.embedding_provider),
|
Some(&*self.embedding_provider),
|
||||||
@@ -197,19 +195,16 @@ impl PipelineServices for DefaultPipelineServices {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
|
Ok(retrieval_pipeline::RetrievalOutput::WithEntities { chunks, entities }) => {
|
||||||
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
|
|
||||||
let chunk_count = search.chunks.len();
|
|
||||||
let entities = search.entities;
|
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
chunk_count,
|
chunk_count = chunks.len(),
|
||||||
entity_count = entities.len(),
|
entity_count = entities.len(),
|
||||||
"ingestion search results returned entities"
|
"ingestion retrieval resolved entities from chunks"
|
||||||
);
|
);
|
||||||
Ok(entities)
|
Ok(entities)
|
||||||
}
|
}
|
||||||
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
Ok(retrieval_pipeline::RetrievalOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||||
"Ingestion retrieval should return entities".into(),
|
"Ingestion retrieval should resolve entities".into(),
|
||||||
)),
|
)),
|
||||||
Err(e) => Err(e),
|
Err(e) => Err(e),
|
||||||
}
|
}
|
||||||
@@ -372,16 +367,16 @@ mod tests {
|
|||||||
|
|
||||||
fn system_prompt_from_request(
|
fn system_prompt_from_request(
|
||||||
request: &async_openai::types::CreateChatCompletionRequest,
|
request: &async_openai::types::CreateChatCompletionRequest,
|
||||||
) -> String {
|
) -> anyhow::Result<String> {
|
||||||
let ChatCompletionRequestMessage::System(system) = &request.messages[0] else {
|
let Some(ChatCompletionRequestMessage::System(system)) = request.messages.first() else {
|
||||||
panic!("expected first message to be system");
|
anyhow::bail!("expected first message to be system");
|
||||||
};
|
};
|
||||||
match &system.content {
|
let async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) =
|
||||||
async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) => {
|
&system.content
|
||||||
text.clone()
|
else {
|
||||||
}
|
anyhow::bail!("unexpected system message content: {:?}", system.content);
|
||||||
other => panic!("unexpected system message content: {other:?}"),
|
};
|
||||||
}
|
Ok(text.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -425,7 +420,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.context("prepare llm request")?;
|
.context("prepare llm request")?;
|
||||||
|
|
||||||
assert_eq!(system_prompt_from_request(&request), SENTINEL);
|
assert_eq!(system_prompt_from_request(&request)?, SENTINEL);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,10 +125,10 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
|
|||||||
})
|
})
|
||||||
.await??;
|
.await??;
|
||||||
|
|
||||||
for (idx, png) in captures.iter().enumerate() {
|
for (page_number, png) in page_numbers.iter().zip(captures.iter()) {
|
||||||
if let Err(err) = maybe_dump_debug_image(page_numbers[idx], png).await {
|
if let Err(err) = maybe_dump_debug_image(*page_number, png).await {
|
||||||
warn!(
|
warn!(
|
||||||
page = page_numbers[idx],
|
page = page_number,
|
||||||
error = %err,
|
error = %err,
|
||||||
"Failed to write debug screenshot to disk"
|
"Failed to write debug screenshot to disk"
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<Shared
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
#[allow(dead_code)] // helpers are shared across binary test targets
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ use html_router::{
|
|||||||
use super::SharedServices;
|
use super::SharedServices;
|
||||||
|
|
||||||
/// Builds the Minne API and HTML route subtrees without fixing the outer Axum state
|
/// Builds the Minne API and HTML route subtrees without fixing the outer Axum state
|
||||||
/// type. SaaS consumers can merge additional routers and attach their own `AppState`
|
/// type. `SaaS` consumers can merge additional routers and attach their own `AppState`
|
||||||
/// as long as it implements `FromRef` for `ApiState` and `HtmlState`.
|
/// as long as it implements `FromRef` for `ApiState` and `HtmlState`.
|
||||||
|
#[allow(dead_code)] // used by server/main binaries, not worker
|
||||||
pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S>
|
pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S>
|
||||||
where
|
where
|
||||||
S: Clone + Send + Sync + 'static,
|
S: Clone + Send + Sync + 'static,
|
||||||
@@ -24,6 +25,7 @@ where
|
|||||||
.merge(html_routes(html_state))
|
.merge(html_routes(html_state))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)] // used by server/main binaries, not worker
|
||||||
pub fn build_api_state(services: &SharedServices) -> ApiState {
|
pub fn build_api_state(services: &SharedServices) -> ApiState {
|
||||||
ApiState {
|
ApiState {
|
||||||
db: Arc::clone(&services.db),
|
db: Arc::clone(&services.db),
|
||||||
@@ -32,6 +34,7 @@ pub fn build_api_state(services: &SharedServices) -> ApiState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)] // used by server/main binaries, not worker
|
||||||
pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> {
|
pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> {
|
||||||
let session_store = Arc::new(
|
let session_store = Arc::new(
|
||||||
services
|
services
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ struct AppState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
#[allow(clippy::expect_used)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|||||||
@@ -10,16 +10,14 @@ workspace = true
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
axum = { workspace = true }
|
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
anyhow = { workspace = true }
|
|
||||||
thiserror = { workspace = true }
|
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
surrealdb = { workspace = true }
|
|
||||||
futures = { workspace = true }
|
|
||||||
async-openai = { workspace = true }
|
async-openai = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
uuid = { workspace = true }
|
|
||||||
fastembed = { workspace = true }
|
fastembed = { workspace = true }
|
||||||
|
|
||||||
common = { path = "../common", features = ["test-utils"] }
|
common = { path = "../common", features = ["test-utils"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
anyhow = { workspace = true }
|
||||||
|
uuid = { workspace = true }
|
||||||
|
|||||||
@@ -1,61 +1,66 @@
|
|||||||
|
//! Chat answer assembly: retrieval context formatting and structured LLM request/response types.
|
||||||
|
|
||||||
use async_openai::{
|
use async_openai::{
|
||||||
error::OpenAIError,
|
error::OpenAIError,
|
||||||
types::{
|
types::{
|
||||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||||
ResponseFormat, ResponseFormatJsonSchema,
|
ResponseFormatJsonSchema,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use common::{
|
use common::storage::types::{
|
||||||
error::AppError,
|
message::{format_history, Message},
|
||||||
storage::types::{
|
system_settings::SystemSettings,
|
||||||
message::{format_history, Message},
|
|
||||||
system_settings::SystemSettings,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
use super::answer_retrieval_helper::get_query_response_schema;
|
/// JSON schema describing the structured chat answer (answer text + references).
|
||||||
|
fn get_query_response_schema() -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"answer": { "type": "string" },
|
||||||
|
"references": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"reference": { "type": "string" },
|
||||||
|
},
|
||||||
|
"required": ["reference"],
|
||||||
|
"additionalProperties": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["answer", "references"],
|
||||||
|
"additionalProperties": false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct Reference {
|
pub struct Reference {
|
||||||
#[allow(dead_code)]
|
|
||||||
pub reference: String,
|
pub reference: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct LLMResponseFormat {
|
pub struct LLMResponseFormat {
|
||||||
pub answer: String,
|
pub answer: String,
|
||||||
#[allow(dead_code)]
|
|
||||||
pub references: Vec<Reference>,
|
pub references: Vec<Reference>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
impl LLMResponseFormat {
|
||||||
pub struct Answer {
|
pub fn reference_ids(&self) -> Vec<String> {
|
||||||
pub content: String,
|
self.references
|
||||||
pub references: Vec<String>,
|
.iter()
|
||||||
}
|
.map(|entry| entry.reference.clone())
|
||||||
|
.collect()
|
||||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
|
||||||
format!(
|
|
||||||
r"
|
|
||||||
Context Information:
|
|
||||||
==================
|
|
||||||
{entities_json}
|
|
||||||
|
|
||||||
User Question:
|
|
||||||
==================
|
|
||||||
{query}
|
|
||||||
"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert chunk-based retrieval results to JSON format for LLM context
|
|
||||||
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
|
||||||
fn round_score(value: f32) -> f64 {
|
|
||||||
(f64::from(value) * 1000.0).round() / 1000.0
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert chunk-based retrieval results to JSON format for LLM context.
|
||||||
|
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||||
|
use crate::round_score;
|
||||||
|
|
||||||
serde_json::json!(chunks
|
serde_json::json!(chunks
|
||||||
.iter()
|
.iter()
|
||||||
@@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_user_message_with_history(
|
pub fn create_user_message_with_history(
|
||||||
entities_json: &Value,
|
context_json: &Value,
|
||||||
history: &[Message],
|
history: &[Message],
|
||||||
query: &str,
|
query: &str,
|
||||||
) -> String {
|
) -> String {
|
||||||
@@ -89,7 +94,7 @@ pub fn create_user_message_with_history(
|
|||||||
{}
|
{}
|
||||||
",
|
",
|
||||||
format_history(history),
|
format_history(history),
|
||||||
entities_json,
|
context_json,
|
||||||
query
|
query
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -116,18 +121,3 @@ pub fn create_chat_request(
|
|||||||
.response_format(response_format)
|
.response_format(response_format)
|
||||||
.build()
|
.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn process_llm_response(
|
|
||||||
response: &CreateChatCompletionResponse,
|
|
||||||
) -> Result<LLMResponseFormat, Box<AppError>> {
|
|
||||||
response
|
|
||||||
.choices
|
|
||||||
.first()
|
|
||||||
.and_then(|choice| choice.message.content.as_ref())
|
|
||||||
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
|
|
||||||
.and_then(|content| {
|
|
||||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
|
||||||
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
use serde_json::{json, Value};
|
|
||||||
|
|
||||||
pub fn get_query_response_schema() -> Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"answer": { "type": "string" },
|
|
||||||
"references": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"reference": { "type": "string" },
|
|
||||||
},
|
|
||||||
"required": ["reference"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["answer", "references"],
|
|
||||||
"additionalProperties": false
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,228 +0,0 @@
|
|||||||
use std::collections::{HashMap, HashSet};
|
|
||||||
|
|
||||||
use surrealdb::{sql::Thing, Error};
|
|
||||||
|
|
||||||
use common::storage::{
|
|
||||||
db::SurrealDbClient,
|
|
||||||
types::{
|
|
||||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
|
||||||
StoredObject,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Find entities related to the given entity via graph relationships.
|
|
||||||
///
|
|
||||||
/// Queries the `relates_to` edge table for all relationships involving the entity,
|
|
||||||
/// then fetches and returns the neighboring entities.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `db` - Database client
|
|
||||||
/// * `entity_id` - ID of the entity to find neighbors for
|
|
||||||
/// * `user_id` - User ID for access control
|
|
||||||
/// * `limit` - Maximum number of neighbors to return
|
|
||||||
pub async fn find_entities_by_relationship_by_id(
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
entity_id: &str,
|
|
||||||
user_id: &str,
|
|
||||||
limit: usize,
|
|
||||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
|
||||||
let mut relationships_response = db
|
|
||||||
.query(
|
|
||||||
"
|
|
||||||
SELECT * FROM relates_to
|
|
||||||
WHERE metadata.user_id = $user_id
|
|
||||||
AND (in = type::thing('knowledge_entity', $entity_id)
|
|
||||||
OR out = type::thing('knowledge_entity', $entity_id))
|
|
||||||
",
|
|
||||||
)
|
|
||||||
.bind(("entity_id", entity_id.to_owned()))
|
|
||||||
.bind(("user_id", user_id.to_owned()))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
|
|
||||||
if relationships.is_empty() {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut neighbor_ids: Vec<String> = Vec::with_capacity(relationships.len());
|
|
||||||
let mut seen: HashSet<String> = HashSet::with_capacity(relationships.len());
|
|
||||||
for rel in relationships {
|
|
||||||
if rel.in_ == entity_id {
|
|
||||||
if seen.insert(rel.out.clone()) {
|
|
||||||
neighbor_ids.push(rel.out);
|
|
||||||
}
|
|
||||||
} else if rel.out == entity_id {
|
|
||||||
if seen.insert(rel.in_.clone()) {
|
|
||||||
neighbor_ids.push(rel.in_);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if seen.insert(rel.in_.clone()) {
|
|
||||||
neighbor_ids.push(rel.in_.clone());
|
|
||||||
}
|
|
||||||
if seen.insert(rel.out.clone()) {
|
|
||||||
neighbor_ids.push(rel.out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
neighbor_ids.retain(|id| id != entity_id);
|
|
||||||
|
|
||||||
if neighbor_ids.is_empty() {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
if limit > 0 && neighbor_ids.len() > limit {
|
|
||||||
neighbor_ids.truncate(limit);
|
|
||||||
}
|
|
||||||
|
|
||||||
let thing_ids: Vec<Thing> = neighbor_ids
|
|
||||||
.iter()
|
|
||||||
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut neighbors_response = db
|
|
||||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
|
||||||
.bind(("table", KnowledgeEntity::table_name().to_owned()))
|
|
||||||
.bind(("things", thing_ids))
|
|
||||||
.bind(("user_id", user_id.to_owned()))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
|
|
||||||
if neighbors.is_empty() {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
|
|
||||||
.into_iter()
|
|
||||||
.map(|entity| (entity.id.clone(), entity))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut ordered = Vec::with_capacity(neighbor_ids.len());
|
|
||||||
for id in neighbor_ids {
|
|
||||||
if let Some(entity) = neighbor_map.remove(&id) {
|
|
||||||
ordered.push(entity);
|
|
||||||
}
|
|
||||||
if limit > 0 && ordered.len() >= limit {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ordered)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use anyhow::{self, Context};
|
|
||||||
use super::*;
|
|
||||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
|
||||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let entity_type = KnowledgeEntityType::Document;
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
|
|
||||||
let central_entity = KnowledgeEntity::new(
|
|
||||||
"central_source".to_string(),
|
|
||||||
"Central Entity".to_string(),
|
|
||||||
"Central Description".to_string(),
|
|
||||||
entity_type.clone(),
|
|
||||||
None,
|
|
||||||
user_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let related_entity1 = KnowledgeEntity::new(
|
|
||||||
"related_source1".to_string(),
|
|
||||||
"Related Entity 1".to_string(),
|
|
||||||
"Related Description 1".to_string(),
|
|
||||||
entity_type.clone(),
|
|
||||||
None,
|
|
||||||
user_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let related_entity2 = KnowledgeEntity::new(
|
|
||||||
"related_source2".to_string(),
|
|
||||||
"Related Entity 2".to_string(),
|
|
||||||
"Related Description 2".to_string(),
|
|
||||||
entity_type.clone(),
|
|
||||||
None,
|
|
||||||
user_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let unrelated_entity = KnowledgeEntity::new(
|
|
||||||
"unrelated_source".to_string(),
|
|
||||||
"Unrelated Entity".to_string(),
|
|
||||||
"Unrelated Description".to_string(),
|
|
||||||
entity_type.clone(),
|
|
||||||
None,
|
|
||||||
user_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let central_entity = db
|
|
||||||
.store_item(central_entity.clone())
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store central entity".to_string())?
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
|
|
||||||
let related_entity1 = db
|
|
||||||
.store_item(related_entity1.clone())
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store related entity 1".to_string())?
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
|
|
||||||
let related_entity2 = db
|
|
||||||
.store_item(related_entity2.clone())
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store related entity 2".to_string())?
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
|
|
||||||
let _unrelated_entity = db
|
|
||||||
.store_item(unrelated_entity.clone())
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store unrelated entity".to_string())?
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
|
|
||||||
|
|
||||||
let source_id = "relationship_source".to_string();
|
|
||||||
|
|
||||||
let relationship1 = KnowledgeRelationship::new(
|
|
||||||
central_entity.id.clone(),
|
|
||||||
related_entity1.id.clone(),
|
|
||||||
user_id.clone(),
|
|
||||||
source_id.clone(),
|
|
||||||
"references".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let relationship2 = KnowledgeRelationship::new(
|
|
||||||
central_entity.id.clone(),
|
|
||||||
related_entity2.id.clone(),
|
|
||||||
user_id.clone(),
|
|
||||||
source_id.clone(),
|
|
||||||
"contains".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
relationship1
|
|
||||||
.store_relationship(&db)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
|
||||||
relationship2
|
|
||||||
.store_relationship(&db)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
|
||||||
|
|
||||||
let related_entities =
|
|
||||||
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to find entities by relationship".to_string())?;
|
|
||||||
|
|
||||||
assert!(
|
|
||||||
related_entities.len() >= 2,
|
|
||||||
"Should find related entities in both directions"
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+63
-115
@@ -1,10 +1,9 @@
|
|||||||
pub mod answer_retrieval;
|
pub mod answer_retrieval;
|
||||||
pub mod answer_retrieval_helper;
|
|
||||||
|
|
||||||
pub mod graph;
|
|
||||||
pub mod pipeline;
|
pub mod pipeline;
|
||||||
pub mod reranking;
|
pub mod reranking;
|
||||||
pub mod scoring;
|
|
||||||
|
pub(crate) mod scoring;
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
@@ -16,39 +15,28 @@ use common::{
|
|||||||
use reranking::RerankerLease;
|
use reranking::RerankerLease;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
// Strategy output variants - defined before pipeline module
|
/// Result of a retrieval run.
|
||||||
|
///
|
||||||
|
/// Chunk retrieval is always performed; entities are only present when the caller
|
||||||
|
/// requested entity resolution via [`RetrievalConfig::with_entities`].
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum StrategyOutput {
|
pub enum RetrievalOutput {
|
||||||
Entities(Vec<RetrievedEntity>),
|
|
||||||
Chunks(Vec<RetrievedChunk>),
|
Chunks(Vec<RetrievedChunk>),
|
||||||
Search(SearchResult),
|
WithEntities {
|
||||||
}
|
chunks: Vec<RetrievedChunk>,
|
||||||
|
entities: Vec<RetrievedEntity>,
|
||||||
/// Unified search result containing both chunks and entities
|
},
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SearchResult {
|
|
||||||
pub chunks: Vec<RetrievedChunk>,
|
|
||||||
pub entities: Vec<RetrievedEntity>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SearchResult {
|
|
||||||
pub fn new(chunks: Vec<RetrievedChunk>, entities: Vec<RetrievedEntity>) -> Self {
|
|
||||||
Self { chunks, entities }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.chunks.is_empty() && self.entities.is_empty()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub use pipeline::{
|
pub use pipeline::{
|
||||||
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
|
retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind,
|
||||||
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
StageTimings,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Backward-compatible type aliases for external consumers
|
/// Round a score to three decimal places for JSON output.
|
||||||
pub type PipelineDiagnostics = Diagnostics;
|
pub(crate) fn round_score(value: f32) -> f64 {
|
||||||
pub type PipelineStageTimings = StageTimings;
|
(f64::from(value) * 1000.0).round() / 1000.0
|
||||||
|
}
|
||||||
|
|
||||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -57,7 +45,7 @@ pub struct RetrievedChunk {
|
|||||||
pub score: f32,
|
pub score: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final entity representation returned to callers, enriched with ranked chunks.
|
// Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RetrievedEntity {
|
pub struct RetrievedEntity {
|
||||||
pub entity: KnowledgeEntity,
|
pub entity: KnowledgeEntity,
|
||||||
@@ -65,9 +53,9 @@ pub struct RetrievedEntity {
|
|||||||
pub chunks: Vec<RetrievedChunk>,
|
pub chunks: Vec<RetrievedChunk>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
|
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
|
||||||
#[instrument(skip_all, fields(user_id))]
|
#[instrument(skip_all, fields(user_id))]
|
||||||
pub async fn retrieve_entities(
|
pub async fn retrieve(
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||||
@@ -75,8 +63,8 @@ pub async fn retrieve_entities(
|
|||||||
user_id: &str,
|
user_id: &str,
|
||||||
config: RetrievalConfig,
|
config: RetrievalConfig,
|
||||||
reranker: Option<RerankerLease>,
|
reranker: Option<RerankerLease>,
|
||||||
) -> Result<StrategyOutput, AppError> {
|
) -> Result<RetrievalOutput, AppError> {
|
||||||
let params = pipeline::StrategyParams {
|
let params = pipeline::RetrievalParams {
|
||||||
db_client,
|
db_client,
|
||||||
openai_client,
|
openai_client,
|
||||||
embedding_provider,
|
embedding_provider,
|
||||||
@@ -94,6 +82,7 @@ mod tests {
|
|||||||
use anyhow::{self};
|
use anyhow::{self};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use common::storage::indexes::ensure_runtime;
|
use common::storage::indexes::ensure_runtime;
|
||||||
|
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
use common::storage::types::system_settings::SystemSettings;
|
use common::storage::types::system_settings::SystemSettings;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -133,7 +122,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
|
async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await?;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let chunk = TextChunk::new(
|
let chunk = TextChunk::new(
|
||||||
@@ -145,7 +134,7 @@ mod tests {
|
|||||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||||
|
|
||||||
let openai_client = Client::new();
|
let openai_client = Client::new();
|
||||||
let params = pipeline::StrategyParams {
|
let params = pipeline::RetrievalParams {
|
||||||
db_client: &db,
|
db_client: &db,
|
||||||
openai_client: &openai_client,
|
openai_client: &openai_client,
|
||||||
embedding_provider: None,
|
embedding_provider: None,
|
||||||
@@ -154,12 +143,13 @@ mod tests {
|
|||||||
config: RetrievalConfig::default(),
|
config: RetrievalConfig::default(),
|
||||||
reranker: None,
|
reranker: None,
|
||||||
};
|
};
|
||||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
let chunks = match results {
|
let chunks = match results {
|
||||||
StrategyOutput::Chunks(items) => items,
|
RetrievalOutput::Chunks(items) => items,
|
||||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
RetrievalOutput::WithEntities { .. } => {
|
||||||
|
anyhow::bail!("expected chunk results, got entities")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||||
@@ -171,8 +161,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_default_strategy_returns_chunks_from_multiple_sources(
|
async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> {
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let db = setup_test_db().await?;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "multi_source_user";
|
let user_id = "multi_source_user";
|
||||||
|
|
||||||
@@ -191,7 +180,7 @@ mod tests {
|
|||||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||||
|
|
||||||
let openai_client = Client::new();
|
let openai_client = Client::new();
|
||||||
let params = pipeline::StrategyParams {
|
let params = pipeline::RetrievalParams {
|
||||||
db_client: &db,
|
db_client: &db,
|
||||||
openai_client: &openai_client,
|
openai_client: &openai_client,
|
||||||
embedding_provider: None,
|
embedding_provider: None,
|
||||||
@@ -200,12 +189,13 @@ mod tests {
|
|||||||
config: RetrievalConfig::default(),
|
config: RetrievalConfig::default(),
|
||||||
reranker: None,
|
reranker: None,
|
||||||
};
|
};
|
||||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
let chunks = match results {
|
let chunks = match results {
|
||||||
StrategyOutput::Chunks(items) => items,
|
RetrievalOutput::Chunks(items) => items,
|
||||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
RetrievalOutput::WithEntities { .. } => {
|
||||||
|
anyhow::bail!("expected chunk results, got entities")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
|
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
|
||||||
@@ -223,96 +213,54 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
|
async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await?;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "chunk_user";
|
let user_id = "entity_user";
|
||||||
let chunk_one = TextChunk::new(
|
|
||||||
"src_alpha".into(),
|
|
||||||
"Tokio tasks execute on worker threads managed by the runtime.".into(),
|
|
||||||
user_id.into(),
|
|
||||||
);
|
|
||||||
let chunk_two = TextChunk::new(
|
|
||||||
"src_beta".into(),
|
|
||||||
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
|
|
||||||
user_id.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
|
|
||||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
|
|
||||||
|
|
||||||
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
|
|
||||||
let openai_client = Client::new();
|
|
||||||
let params = pipeline::StrategyParams {
|
|
||||||
db_client: &db,
|
|
||||||
openai_client: &openai_client,
|
|
||||||
embedding_provider: None,
|
|
||||||
input_text: "tokio runtime worker behavior",
|
|
||||||
user_id,
|
|
||||||
config,
|
|
||||||
reranker: None,
|
|
||||||
};
|
|
||||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let chunks = match results {
|
|
||||||
StrategyOutput::Chunks(items) => items,
|
|
||||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
|
||||||
};
|
|
||||||
|
|
||||||
assert!(
|
|
||||||
!chunks.is_empty(),
|
|
||||||
"Revised strategy should return chunk-only responses"
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
chunks
|
|
||||||
.iter()
|
|
||||||
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
|
||||||
"Chunk results should contain relevant snippets"
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
|
|
||||||
let db = setup_test_db().await?;
|
|
||||||
let user_id = "search_user";
|
|
||||||
let chunk = TextChunk::new(
|
let chunk = TextChunk::new(
|
||||||
"search_src".into(),
|
"entity_source".into(),
|
||||||
"Async Rust programming uses Tokio runtime for concurrent tasks.".into(),
|
"Async Rust programming uses the Tokio runtime for concurrent tasks.".into(),
|
||||||
user_id.into(),
|
user_id.into(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||||
|
|
||||||
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
|
let entity = KnowledgeEntity::new(
|
||||||
|
"entity_source".into(),
|
||||||
|
"Tokio Runtime".into(),
|
||||||
|
"Async runtime for Rust".into(),
|
||||||
|
KnowledgeEntityType::Document,
|
||||||
|
None,
|
||||||
|
user_id.into(),
|
||||||
|
);
|
||||||
|
db.store_item(entity).await?;
|
||||||
|
|
||||||
let openai_client = Client::new();
|
let openai_client = Client::new();
|
||||||
let params = pipeline::StrategyParams {
|
let params = pipeline::RetrievalParams {
|
||||||
db_client: &db,
|
db_client: &db,
|
||||||
openai_client: &openai_client,
|
openai_client: &openai_client,
|
||||||
embedding_provider: None,
|
embedding_provider: None,
|
||||||
input_text: "async rust programming",
|
input_text: "async rust programming",
|
||||||
user_id,
|
user_id,
|
||||||
config,
|
config: RetrievalConfig::with_entities(),
|
||||||
reranker: None,
|
reranker: None,
|
||||||
};
|
};
|
||||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
let StrategyOutput::Search(search_result) = results else {
|
let RetrievalOutput::WithEntities { chunks, entities } = results else {
|
||||||
anyhow::bail!("expected Search output");
|
anyhow::bail!("expected WithEntities output");
|
||||||
};
|
};
|
||||||
|
|
||||||
// Should return chunks (entities may be empty if none stored)
|
assert!(!chunks.is_empty(), "Should return chunks");
|
||||||
assert!(
|
assert!(
|
||||||
!search_result.chunks.is_empty(),
|
entities.iter().any(|e| e.entity.name == "Tokio Runtime"),
|
||||||
"Search strategy should return chunks"
|
"Should resolve the entity owning the retrieved chunk"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
search_result
|
entities
|
||||||
.chunks
|
|
||||||
.iter()
|
.iter()
|
||||||
.any(|c| c.chunk.chunk.contains("Tokio")),
|
.find(|e| e.entity.name == "Tokio Runtime")
|
||||||
"Search results should contain relevant chunks"
|
.is_some_and(|e| !e.chunks.is_empty()),
|
||||||
|
"Resolved entity should carry its contributing chunks"
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,5 @@
|
|||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
|
||||||
use crate::scoring::FusionWeights;
|
|
||||||
|
|
||||||
pub use common::utils::config::RetrievalStrategy;
|
|
||||||
|
|
||||||
/// Configures which result types to include in Search strategy
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum SearchTarget {
|
|
||||||
/// Return only text chunks
|
|
||||||
ChunksOnly,
|
|
||||||
/// Return only knowledge entities
|
|
||||||
EntitiesOnly,
|
|
||||||
/// Return both chunks and entities (default)
|
|
||||||
#[default]
|
|
||||||
Both,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Two-variant flag that serializes as a bool for backward compatibility.
|
/// Two-variant flag that serializes as a bool for backward compatibility.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||||
pub enum BoolFlag {
|
pub enum BoolFlag {
|
||||||
@@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag {
|
|||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||||
pub struct RetrievalTuningFlags {
|
pub struct RetrievalTuningFlags {
|
||||||
pub rerank_scores_only: BoolFlag,
|
pub rerank_scores_only: BoolFlag,
|
||||||
pub normalize_vector_scores: BoolFlag,
|
|
||||||
pub normalize_fts_scores: BoolFlag,
|
|
||||||
pub chunk_rrf_use_vector: BoolFlag,
|
pub chunk_rrf_use_vector: BoolFlag,
|
||||||
pub chunk_rrf_use_fts: BoolFlag,
|
pub chunk_rrf_use_fts: BoolFlag,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RetrievalTuningFlags {
|
impl RetrievalTuningFlags {
|
||||||
pub const fn rerank_scores_only(&self) -> bool {
|
pub const fn rerank_scores_only(self) -> bool {
|
||||||
self.rerank_scores_only.as_bool()
|
self.rerank_scores_only.as_bool()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn normalize_vector_scores(&self) -> bool {
|
pub const fn chunk_rrf_use_vector(self) -> bool {
|
||||||
self.normalize_vector_scores.as_bool()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn normalize_fts_scores(&self) -> bool {
|
|
||||||
self.normalize_fts_scores.as_bool()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn chunk_rrf_use_vector(&self) -> bool {
|
|
||||||
self.chunk_rrf_use_vector.as_bool()
|
self.chunk_rrf_use_vector.as_bool()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn chunk_rrf_use_fts(&self) -> bool {
|
pub const fn chunk_rrf_use_fts(self) -> bool {
|
||||||
self.chunk_rrf_use_fts.as_bool()
|
self.chunk_rrf_use_fts.as_bool()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
rerank_scores_only: BoolFlag::Disabled,
|
rerank_scores_only: BoolFlag::Disabled,
|
||||||
normalize_vector_scores: BoolFlag::Disabled,
|
|
||||||
normalize_fts_scores: BoolFlag::Enabled,
|
|
||||||
chunk_rrf_use_vector: BoolFlag::Enabled,
|
chunk_rrf_use_vector: BoolFlag::Enabled,
|
||||||
chunk_rrf_use_fts: BoolFlag::Enabled,
|
chunk_rrf_use_fts: BoolFlag::Enabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tunable parameters that govern each retrieval stage.
|
/// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct RetrievalTuning {
|
pub struct RetrievalTuning {
|
||||||
pub entity_vector_take: usize,
|
/// Number of vector candidates to pull from the chunk embedding index.
|
||||||
pub chunk_vector_take: usize,
|
pub chunk_vector_take: usize,
|
||||||
pub entity_fts_take: usize,
|
/// Number of full-text candidates to pull from the chunk index.
|
||||||
pub chunk_fts_take: usize,
|
pub chunk_fts_take: usize,
|
||||||
pub score_threshold: f32,
|
/// Maximum chunks attached to each resolved entity.
|
||||||
pub fallback_min_results: usize,
|
|
||||||
pub token_budget_estimate: usize,
|
|
||||||
pub avg_chars_per_token: usize,
|
|
||||||
pub max_chunks_per_entity: usize,
|
pub max_chunks_per_entity: usize,
|
||||||
pub lexical_match_weight: f32,
|
/// Blend weight applied when mixing reranker scores with fused scores.
|
||||||
pub graph_traversal_seed_limit: usize,
|
|
||||||
pub graph_neighbor_limit: usize,
|
|
||||||
pub graph_score_decay: f32,
|
|
||||||
pub graph_seed_min_score: f32,
|
|
||||||
pub graph_vector_inheritance: f32,
|
|
||||||
pub rerank_blend_weight: f32,
|
pub rerank_blend_weight: f32,
|
||||||
pub flags: RetrievalTuningFlags,
|
/// Keep top-N candidates after reranking.
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
|
/// Maximum number of chunks returned to callers.
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
/// Reciprocal rank fusion k value for chunk merging.
|
||||||
pub fusion_weights: Option<FusionWeights>,
|
|
||||||
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
|
||||||
#[serde(default = "default_chunk_rrf_k")]
|
|
||||||
pub chunk_rrf_k: f32,
|
pub chunk_rrf_k: f32,
|
||||||
/// Weight applied to vector ranks in RRF.
|
/// Weight applied to vector ranks in RRF.
|
||||||
#[serde(default = "default_chunk_rrf_vector_weight")]
|
|
||||||
pub chunk_rrf_vector_weight: f32,
|
pub chunk_rrf_vector_weight: f32,
|
||||||
/// Weight applied to chunk FTS ranks in RRF.
|
/// Weight applied to chunk FTS ranks in RRF.
|
||||||
#[serde(default = "default_chunk_rrf_fts_weight")]
|
|
||||||
pub chunk_rrf_fts_weight: f32,
|
pub chunk_rrf_fts_weight: f32,
|
||||||
|
pub flags: RetrievalTuningFlags,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalTuning {
|
impl Default for RetrievalTuning {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
entity_vector_take: 15,
|
|
||||||
chunk_vector_take: 20,
|
chunk_vector_take: 20,
|
||||||
entity_fts_take: 10,
|
|
||||||
chunk_fts_take: 20,
|
chunk_fts_take: 20,
|
||||||
score_threshold: 0.35,
|
|
||||||
fallback_min_results: 10,
|
|
||||||
token_budget_estimate: 10000,
|
|
||||||
avg_chars_per_token: 4,
|
|
||||||
max_chunks_per_entity: 4,
|
max_chunks_per_entity: 4,
|
||||||
lexical_match_weight: 0.15,
|
|
||||||
graph_traversal_seed_limit: 5,
|
|
||||||
graph_neighbor_limit: 6,
|
|
||||||
graph_score_decay: 0.75,
|
|
||||||
graph_seed_min_score: 0.4,
|
|
||||||
graph_vector_inheritance: 0.6,
|
|
||||||
rerank_blend_weight: 0.65,
|
rerank_blend_weight: 0.65,
|
||||||
flags: RetrievalTuningFlags::default(),
|
|
||||||
rerank_keep_top: 8,
|
rerank_keep_top: 8,
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
fusion_weights: None,
|
chunk_rrf_k: 60.0,
|
||||||
chunk_rrf_k: default_chunk_rrf_k(),
|
chunk_rrf_vector_weight: 1.0,
|
||||||
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
chunk_rrf_fts_weight: 1.0,
|
||||||
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
flags: RetrievalTuningFlags::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
/// Per-request retrieval configuration.
|
||||||
|
///
|
||||||
|
/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities`
|
||||||
|
/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved
|
||||||
|
/// chunks (search, ingestion linking, relationship suggestion).
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct RetrievalConfig {
|
pub struct RetrievalConfig {
|
||||||
pub strategy: RetrievalStrategy,
|
|
||||||
pub tuning: RetrievalTuning,
|
pub tuning: RetrievalTuning,
|
||||||
/// Target for Search strategy (chunks, entities, or both)
|
pub resolve_entities: bool,
|
||||||
pub search_target: SearchTarget,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RetrievalConfig {
|
impl RetrievalConfig {
|
||||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
/// Chunk retrieval that also resolves the owning knowledge entities.
|
||||||
|
pub fn with_entities() -> Self {
|
||||||
Self {
|
Self {
|
||||||
strategy: RetrievalStrategy::Default,
|
|
||||||
tuning,
|
|
||||||
search_target: SearchTarget::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
|
|
||||||
Self {
|
|
||||||
strategy,
|
|
||||||
tuning: RetrievalTuning::default(),
|
tuning: RetrievalTuning::default(),
|
||||||
search_target: SearchTarget::default(),
|
resolve_entities: true,
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
|
||||||
Self {
|
|
||||||
strategy,
|
|
||||||
tuning,
|
|
||||||
search_target: SearchTarget::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create config for chat retrieval with strategy selection support
|
|
||||||
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
|
|
||||||
Self::with_strategy(strategy)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create config for relationship suggestion (entity-only retrieval)
|
|
||||||
pub fn for_relationship_suggestion() -> Self {
|
|
||||||
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create config for ingestion pipeline (entity-only retrieval)
|
|
||||||
pub fn for_ingestion() -> Self {
|
|
||||||
Self::with_strategy(RetrievalStrategy::Ingestion)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create config for unified search (chunks and/or entities)
|
|
||||||
pub fn for_search(target: SearchTarget) -> Self {
|
|
||||||
Self {
|
|
||||||
strategy: RetrievalStrategy::Search,
|
|
||||||
tuning: RetrievalTuning::default(),
|
|
||||||
search_target: target,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const fn default_chunk_rrf_k() -> f32 {
|
|
||||||
60.0
|
|
||||||
}
|
|
||||||
|
|
||||||
const fn default_chunk_rrf_vector_weight() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
const fn default_chunk_rrf_fts_weight() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,107 @@
|
|||||||
|
use async_openai::Client;
|
||||||
|
use common::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
|
||||||
|
utils::embedding::EmbeddingProvider,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
config::RetrievalConfig,
|
||||||
|
diagnostics::{AssembleStats, Diagnostics, SearchStats},
|
||||||
|
StageKind, StageTimings, RetrievalParams,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Mutable working state threaded through every retrieval stage.
|
||||||
|
pub(crate) struct PipelineContext<'a> {
|
||||||
|
pub db_client: &'a SurrealDbClient,
|
||||||
|
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||||
|
pub embedding_provider: Option<&'a EmbeddingProvider>,
|
||||||
|
pub input_text: String,
|
||||||
|
pub user_id: String,
|
||||||
|
pub config: RetrievalConfig,
|
||||||
|
pub query_embedding: Option<Vec<f32>>,
|
||||||
|
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||||
|
pub reranker: Option<RerankerLease>,
|
||||||
|
pub diagnostics: Option<Diagnostics>,
|
||||||
|
pub entity_results: Vec<RetrievedEntity>,
|
||||||
|
pub chunk_results: Vec<RetrievedChunk>,
|
||||||
|
stage_timings: StageTimings,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> PipelineContext<'a> {
|
||||||
|
pub fn new(params: RetrievalParams<'a>) -> Self {
|
||||||
|
Self {
|
||||||
|
db_client: params.db_client,
|
||||||
|
openai_client: params.openai_client,
|
||||||
|
embedding_provider: params.embedding_provider,
|
||||||
|
input_text: params.input_text.to_owned(),
|
||||||
|
user_id: params.user_id.to_owned(),
|
||||||
|
config: params.config,
|
||||||
|
query_embedding: None,
|
||||||
|
chunk_values: Vec::new(),
|
||||||
|
reranker: params.reranker,
|
||||||
|
diagnostics: None,
|
||||||
|
entity_results: Vec::new(),
|
||||||
|
chunk_results: Vec::new(),
|
||||||
|
stage_timings: StageTimings::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec<f32>) -> Self {
|
||||||
|
let mut ctx = Self::new(params);
|
||||||
|
ctx.query_embedding = Some(query_embedding);
|
||||||
|
ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
|
||||||
|
self.query_embedding.as_ref().ok_or_else(|| {
|
||||||
|
Box::new(AppError::InternalError(
|
||||||
|
"query embedding missing before candidate search".to_string(),
|
||||||
|
))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enable_diagnostics(&mut self) {
|
||||||
|
if self.diagnostics.is_none() {
|
||||||
|
self.diagnostics = Some(Diagnostics::default());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn diagnostics_enabled(&self) -> bool {
|
||||||
|
self.diagnostics.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn record_search(&mut self, stats: SearchStats) {
|
||||||
|
if let Some(diag) = self.diagnostics.as_mut() {
|
||||||
|
diag.search = Some(stats);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn record_assemble(&mut self, stats: AssembleStats) {
|
||||||
|
if let Some(diag) = self.diagnostics.as_mut() {
|
||||||
|
diag.assemble = Some(stats);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
|
||||||
|
self.diagnostics.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_stage_timings(&mut self) -> StageTimings {
|
||||||
|
std::mem::take(&mut self.stage_timings)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
|
||||||
|
self.stage_timings.record(kind, duration);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
|
||||||
|
std::mem::take(&mut self.entity_results)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
|
||||||
|
std::mem::take(&mut self.chunk_results)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,51 +1,21 @@
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
/// Captures instrumentation for the retrieval stages when diagnostics are enabled.
|
||||||
#[derive(Debug, Clone, Default, Serialize)]
|
#[derive(Debug, Clone, Default, Serialize)]
|
||||||
pub struct Diagnostics {
|
pub struct Diagnostics {
|
||||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
pub search: Option<SearchStats>,
|
||||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
|
||||||
pub assemble: Option<AssembleStats>,
|
pub assemble: Option<AssembleStats>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize)]
|
#[derive(Debug, Clone, Default, Serialize)]
|
||||||
pub struct CollectCandidatesStats {
|
pub struct SearchStats {
|
||||||
pub vector_entity_candidates: usize,
|
|
||||||
pub vector_chunk_candidates: usize,
|
pub vector_chunk_candidates: usize,
|
||||||
pub fts_entity_candidates: usize,
|
|
||||||
pub fts_chunk_candidates: usize,
|
pub fts_chunk_candidates: usize,
|
||||||
pub vector_chunk_scores: Vec<f32>,
|
pub vector_chunk_scores: Vec<f32>,
|
||||||
pub fts_chunk_scores: Vec<f32>,
|
pub fts_chunk_scores: Vec<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize)]
|
|
||||||
pub struct ChunkEnrichmentStats {
|
|
||||||
pub filtered_entity_count: usize,
|
|
||||||
pub fallback_min_results: usize,
|
|
||||||
pub chunk_sources_considered: usize,
|
|
||||||
pub chunk_candidates_before_enrichment: usize,
|
|
||||||
pub chunk_candidates_after_enrichment: usize,
|
|
||||||
pub top_chunk_scores: Vec<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize)]
|
#[derive(Debug, Clone, Default, Serialize)]
|
||||||
pub struct AssembleStats {
|
pub struct AssembleStats {
|
||||||
pub token_budget_start: usize,
|
|
||||||
pub token_budget_spent: usize,
|
|
||||||
pub token_budget_remaining: usize,
|
|
||||||
pub budget_exhausted: bool,
|
|
||||||
pub chunks_selected: usize,
|
pub chunks_selected: usize,
|
||||||
pub chunks_skipped_due_budget: usize,
|
|
||||||
pub entity_count: usize,
|
|
||||||
pub entity_traces: Vec<EntityAssemblyTrace>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize)]
|
|
||||||
pub struct EntityAssemblyTrace {
|
|
||||||
pub entity_id: String,
|
|
||||||
pub source_id: String,
|
|
||||||
pub inspected_candidates: usize,
|
|
||||||
pub selected_chunk_ids: Vec<String>,
|
|
||||||
pub selected_chunk_scores: Vec<f32>,
|
|
||||||
pub skipped_due_budget: usize,
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,61 +1,68 @@
|
|||||||
mod config;
|
mod config;
|
||||||
|
mod context;
|
||||||
mod diagnostics;
|
mod diagnostics;
|
||||||
mod stages;
|
mod stages;
|
||||||
mod strategies;
|
|
||||||
|
|
||||||
pub use config::{
|
pub use config::RetrievalConfig;
|
||||||
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
pub use diagnostics::Diagnostics;
|
||||||
};
|
|
||||||
pub use diagnostics::{
|
|
||||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
use crate::{round_score, RetrievalOutput, RetrievedEntity};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use stages::PipelineContext;
|
use stages::{
|
||||||
use strategies::{
|
ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage,
|
||||||
DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Export StrategyOutput publicly from this module
|
/// Identifies a retrieval stage for timing and instrumentation.
|
||||||
// (it's defined in lib.rs but we re-export it here)
|
///
|
||||||
|
/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation
|
||||||
// Stage type enum
|
/// harness) iterate it generically so that adding a stage requires no changes outside this
|
||||||
|
/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`].
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum StageKind {
|
pub enum StageKind {
|
||||||
Embed,
|
Embed,
|
||||||
CollectCandidates,
|
Search,
|
||||||
GraphExpansion,
|
|
||||||
ChunkAttach,
|
|
||||||
Rerank,
|
Rerank,
|
||||||
|
ResolveEntities,
|
||||||
Assemble,
|
Assemble,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pipeline stage trait
|
impl StageKind {
|
||||||
|
/// Every stage kind in canonical pipeline order.
|
||||||
|
pub const ALL: [StageKind; 5] = [
|
||||||
|
StageKind::Embed,
|
||||||
|
StageKind::Search,
|
||||||
|
StageKind::Rerank,
|
||||||
|
StageKind::ResolveEntities,
|
||||||
|
StageKind::Assemble,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Stable, machine-friendly identifier for the stage (used as a metrics key).
|
||||||
|
pub const fn label(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
StageKind::Embed => "embed",
|
||||||
|
StageKind::Search => "search",
|
||||||
|
StageKind::Rerank => "rerank",
|
||||||
|
StageKind::ResolveEntities => "resolve_entities",
|
||||||
|
StageKind::Assemble => "assemble",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single composable step in the retrieval pipeline.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Stage: Send + Sync {
|
pub(crate) trait Stage: Send + Sync {
|
||||||
fn kind(&self) -> StageKind;
|
fn kind(&self) -> StageKind;
|
||||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type alias for boxed stages
|
pub(crate) type BoxedStage = Box<dyn Stage>;
|
||||||
pub type BoxedStage = Box<dyn Stage>;
|
|
||||||
|
|
||||||
// Strategy driver trait
|
/// Per-stage execution timings recorded during a run.
|
||||||
#[async_trait]
|
|
||||||
pub trait StrategyDriver: Send + Sync {
|
|
||||||
type Output;
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage>;
|
|
||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pipeline stage timings tracker
|
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Default, Clone)]
|
||||||
pub struct StageTimings {
|
pub struct StageTimings {
|
||||||
timings: Vec<(StageKind, Duration)>,
|
timings: Vec<(StageKind, Duration)>,
|
||||||
@@ -66,41 +73,13 @@ impl StageTimings {
|
|||||||
self.timings.push((kind, duration));
|
self.timings.push((kind, duration));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
|
/// Milliseconds recorded for `kind`, or `0` if the stage did not run.
|
||||||
self.timings
|
pub fn stage_ms(&self, kind: StageKind) -> u128 {
|
||||||
}
|
|
||||||
|
|
||||||
// Helper methods to get duration for each stage type (for backward compatibility)
|
|
||||||
fn get_stage_ms(&self, kind: StageKind) -> u128 {
|
|
||||||
self.timings
|
self.timings
|
||||||
.iter()
|
.iter()
|
||||||
.find(|(k, _)| *k == kind)
|
.find(|(k, _)| *k == kind)
|
||||||
.map_or(0, |(_, d)| d.as_millis())
|
.map_or(0, |(_, d)| d.as_millis())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_ms(&self) -> u128 {
|
|
||||||
self.get_stage_ms(StageKind::Embed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn collect_candidates_ms(&self) -> u128 {
|
|
||||||
self.get_stage_ms(StageKind::CollectCandidates)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn graph_expansion_ms(&self) -> u128 {
|
|
||||||
self.get_stage_ms(StageKind::GraphExpansion)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn chunk_attach_ms(&self) -> u128 {
|
|
||||||
self.get_stage_ms(StageKind::ChunkAttach)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rerank_ms(&self) -> u128 {
|
|
||||||
self.get_stage_ms(StageKind::Rerank)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn assemble_ms(&self) -> u128 {
|
|
||||||
self.get_stage_ms(StageKind::Assemble)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RunOutput<T> {
|
pub struct RunOutput<T> {
|
||||||
@@ -109,7 +88,35 @@ pub struct RunOutput<T> {
|
|||||||
pub stage_timings: StageTimings,
|
pub stage_timings: StageTimings,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
|
/// Inputs required to run a retrieval.
|
||||||
|
pub struct RetrievalParams<'a> {
|
||||||
|
pub db_client: &'a SurrealDbClient,
|
||||||
|
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||||
|
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||||
|
pub input_text: &'a str,
|
||||||
|
pub user_id: &'a str,
|
||||||
|
pub config: RetrievalConfig,
|
||||||
|
pub reranker: Option<crate::reranking::RerankerLease>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_stages(config: &RetrievalConfig) -> Vec<BoxedStage> {
|
||||||
|
let mut stages: Vec<BoxedStage> = vec![
|
||||||
|
Box::new(EmbedStage),
|
||||||
|
Box::new(ChunkSearchStage),
|
||||||
|
Box::new(ChunkRerankStage),
|
||||||
|
];
|
||||||
|
if config.resolve_entities {
|
||||||
|
stages.push(Box::new(ResolveEntitiesStage));
|
||||||
|
}
|
||||||
|
stages.push(Box::new(ChunkAssembleStage));
|
||||||
|
stages
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
params: RetrievalParams<'_>,
|
||||||
|
query_embedding: Option<Vec<f32>>,
|
||||||
|
capture_diagnostics: bool,
|
||||||
|
) -> Result<RunOutput<RetrievalOutput>, AppError> {
|
||||||
let input_chars = params.input_text.chars().count();
|
let input_chars = params.input_text.chars().count();
|
||||||
let input_preview: String = params.input_text.chars().take(120).collect();
|
let input_preview: String = params.input_text.chars().take(120).collect();
|
||||||
let input_preview_clean = input_preview.replace('\n', " ");
|
let input_preview_clean = input_preview.replace('\n', " ");
|
||||||
@@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppEr
|
|||||||
input_chars,
|
input_chars,
|
||||||
preview_truncated = input_chars > preview_len,
|
preview_truncated = input_chars > preview_len,
|
||||||
preview = %input_preview_clean,
|
preview = %input_preview_clean,
|
||||||
strategy = %params.config.strategy,
|
resolve_entities = params.config.resolve_entities,
|
||||||
"Starting retrieval pipeline"
|
"Starting retrieval pipeline"
|
||||||
);
|
);
|
||||||
|
|
||||||
let strategy = params.config.strategy;
|
let resolve_entities = params.config.resolve_entities;
|
||||||
let search_target = params.config.search_target;
|
let mut ctx = match query_embedding {
|
||||||
|
Some(embedding) => context::PipelineContext::with_embedding(params, embedding),
|
||||||
|
None => context::PipelineContext::new(params),
|
||||||
|
};
|
||||||
|
|
||||||
match strategy {
|
if capture_diagnostics {
|
||||||
RetrievalStrategy::Default => {
|
ctx.enable_diagnostics();
|
||||||
let driver = DefaultStrategyDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, None, false).await?;
|
|
||||||
Ok(StrategyOutput::Chunks(run.results))
|
|
||||||
}
|
|
||||||
RetrievalStrategy::RelationshipSuggestion => {
|
|
||||||
let driver = RelationshipSuggestionDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, None, false).await?;
|
|
||||||
Ok(StrategyOutput::Entities(run.results))
|
|
||||||
}
|
|
||||||
RetrievalStrategy::Ingestion => {
|
|
||||||
let driver = IngestionDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, None, false).await?;
|
|
||||||
Ok(StrategyOutput::Entities(run.results))
|
|
||||||
}
|
|
||||||
RetrievalStrategy::Search => {
|
|
||||||
let driver = SearchStrategyDriver::new(search_target);
|
|
||||||
let run = execute_strategy(driver, params, None, false).await?;
|
|
||||||
Ok(StrategyOutput::Search(run.results))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for stage in build_stages(&ctx.config) {
|
||||||
|
let start = Instant::now();
|
||||||
|
stage.execute(&mut ctx).await?;
|
||||||
|
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||||
|
}
|
||||||
|
|
||||||
|
let diagnostics = ctx.take_diagnostics();
|
||||||
|
let stage_timings = ctx.take_stage_timings();
|
||||||
|
let chunks = ctx.take_chunk_results();
|
||||||
|
let results = if resolve_entities {
|
||||||
|
RetrievalOutput::WithEntities {
|
||||||
|
chunks,
|
||||||
|
entities: ctx.take_entity_results(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RetrievalOutput::Chunks(chunks)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(RunOutput {
|
||||||
|
results,
|
||||||
|
diagnostics,
|
||||||
|
stage_timings,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding(
|
/// Run the retrieval pipeline, generating the query embedding internally if needed.
|
||||||
params: StrategyParams<'_>,
|
pub async fn execute(params: RetrievalParams<'_>) -> Result<RetrievalOutput, AppError> {
|
||||||
query_embedding: Vec<f32>,
|
Ok(run(params, None, false).await?.results)
|
||||||
) -> Result<StrategyOutput, AppError> {
|
|
||||||
let strategy = params.config.strategy;
|
|
||||||
let search_target = params.config.search_target;
|
|
||||||
|
|
||||||
match strategy {
|
|
||||||
RetrievalStrategy::Default => {
|
|
||||||
let driver = DefaultStrategyDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
|
||||||
Ok(StrategyOutput::Chunks(run.results))
|
|
||||||
}
|
|
||||||
RetrievalStrategy::RelationshipSuggestion => {
|
|
||||||
let driver = RelationshipSuggestionDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
|
||||||
Ok(StrategyOutput::Entities(run.results))
|
|
||||||
}
|
|
||||||
RetrievalStrategy::Ingestion => {
|
|
||||||
let driver = IngestionDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
|
||||||
Ok(StrategyOutput::Entities(run.results))
|
|
||||||
}
|
|
||||||
RetrievalStrategy::Search => {
|
|
||||||
let driver = SearchStrategyDriver::new(search_target);
|
|
||||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
|
||||||
Ok(StrategyOutput::Search(run.results))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
/// Run the retrieval pipeline with a pre-computed query embedding.
|
||||||
params: StrategyParams<'_>,
|
pub async fn run_with_embedding(
|
||||||
|
params: RetrievalParams<'_>,
|
||||||
query_embedding: Vec<f32>,
|
query_embedding: Vec<f32>,
|
||||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
) -> Result<RetrievalOutput, AppError> {
|
||||||
let strategy = params.config.strategy;
|
Ok(run(params, Some(query_embedding), false).await?.results)
|
||||||
|
|
||||||
match strategy {
|
|
||||||
RetrievalStrategy::Default => {
|
|
||||||
let driver = DefaultStrategyDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
|
||||||
Ok(RunOutput {
|
|
||||||
results: StrategyOutput::Chunks(run.results),
|
|
||||||
diagnostics: run.diagnostics,
|
|
||||||
stage_timings: run.stage_timings,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
_ => Err(AppError::InternalError(
|
|
||||||
"Metrics not supported for this strategy".into(),
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
/// Run with a pre-computed embedding, returning results and per-stage timings.
|
||||||
params: StrategyParams<'_>,
|
///
|
||||||
|
/// When `capture_diagnostics` is true, pipeline search/assemble stats are included.
|
||||||
|
pub async fn run_with_embedding_instrumented(
|
||||||
|
params: RetrievalParams<'_>,
|
||||||
query_embedding: Vec<f32>,
|
query_embedding: Vec<f32>,
|
||||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
capture_diagnostics: bool,
|
||||||
let strategy = params.config.strategy;
|
) -> Result<RunOutput<RetrievalOutput>, AppError> {
|
||||||
|
run(params, Some(query_embedding), capture_diagnostics).await
|
||||||
match strategy {
|
|
||||||
RetrievalStrategy::Default => {
|
|
||||||
let driver = DefaultStrategyDriver::new();
|
|
||||||
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
|
|
||||||
Ok(RunOutput {
|
|
||||||
results: StrategyOutput::Chunks(run.results),
|
|
||||||
diagnostics: run.diagnostics,
|
|
||||||
stage_timings: run.stage_timings,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
_ => Err(AppError::InternalError(
|
|
||||||
"Diagnostics not supported for this strategy".into(),
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||||
@@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
|||||||
})
|
})
|
||||||
.collect::<Vec<_>>())
|
.collect::<Vec<_>>())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inputs for a strategy-driven retrieval run.
///
/// `execute_strategy` hands the value by move to `PipelineContext::new` or
/// `PipelineContext::with_embedding`; borrowed members live for `'a`.
pub struct StrategyParams<'a> {
|
|
||||||
/// Database client the pipeline runs against.
pub db_client: &'a SurrealDbClient,
|
|
||||||
/// OpenAI client. Presumably used for embedding generation — confirm in `PipelineContext`.
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
|
||||||
/// Optional embedding provider; `None` is allowed.
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
|
||||||
/// Raw query text supplied by the caller.
pub input_text: &'a str,
|
|
||||||
/// Identifier of the user the retrieval runs for.
pub user_id: &'a str,
|
|
||||||
/// Retrieval configuration; callers read `strategy` and `search_target` from it.
pub config: RetrievalConfig,
|
|
||||||
/// Optional reranker lease; `None` is allowed.
pub reranker: Option<RerankerLease>,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute_strategy<D: StrategyDriver>(
|
|
||||||
driver: D,
|
|
||||||
params: StrategyParams<'_>,
|
|
||||||
query_embedding: Option<Vec<f32>>,
|
|
||||||
capture_diagnostics: bool,
|
|
||||||
) -> Result<RunOutput<D::Output>, AppError> {
|
|
||||||
let ctx = match query_embedding {
|
|
||||||
Some(embedding) => PipelineContext::with_embedding(params, embedding),
|
|
||||||
None => PipelineContext::new(params),
|
|
||||||
};
|
|
||||||
|
|
||||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_with_driver<D: StrategyDriver>(
|
|
||||||
driver: D,
|
|
||||||
mut ctx: PipelineContext<'_>,
|
|
||||||
capture_diagnostics: bool,
|
|
||||||
) -> Result<RunOutput<D::Output>, AppError> {
|
|
||||||
if capture_diagnostics {
|
|
||||||
ctx.enable_diagnostics();
|
|
||||||
}
|
|
||||||
|
|
||||||
for stage in driver.stages() {
|
|
||||||
let start = Instant::now();
|
|
||||||
stage.execute(&mut ctx).await?;
|
|
||||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
|
||||||
}
|
|
||||||
|
|
||||||
let diagnostics = ctx.take_diagnostics();
|
|
||||||
let stage_timings = ctx.take_stage_timings();
|
|
||||||
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
|
|
||||||
|
|
||||||
Ok(RunOutput {
|
|
||||||
results,
|
|
||||||
diagnostics,
|
|
||||||
stage_timings,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Widen a score to `f64` and round it to three decimal places.
fn round_score(value: f32) -> f64 {
    const SCALE: f64 = 1000.0;

    let widened = f64::from(value);
    (widened * SCALE).round() / SCALE
}
|
|
||||||
|
|||||||
@@ -0,0 +1,424 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use common::{
|
||||||
|
error::AppError,
|
||||||
|
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||||
|
utils::embedding::generate_embedding,
|
||||||
|
};
|
||||||
|
use fastembed::RerankResult;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tracing::{debug, instrument, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
scoring::{
|
||||||
|
clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored,
|
||||||
|
},
|
||||||
|
RetrievedChunk, RetrievedEntity,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
config::RetrievalTuning,
|
||||||
|
context::PipelineContext,
|
||||||
|
diagnostics::{AssembleStats, SearchStats},
|
||||||
|
Stage, StageKind,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct EmbedStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Stage for EmbedStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Embed
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
embed(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkSearchStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Stage for ChunkSearchStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Search
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
search_chunks(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkRerankStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Stage for ChunkRerankStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Rerank
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
rerank_chunks(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ResolveEntitiesStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Stage for ResolveEntitiesStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::ResolveEntities
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
resolve_entities(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkAssembleStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Stage for ChunkAssembleStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Assemble
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
assemble_chunks(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
if ctx.query_embedding.is_some() {
|
||||||
|
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||||
|
} else {
|
||||||
|
debug!("Generating query embedding for hybrid retrieval");
|
||||||
|
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||||
|
provider.embed(&ctx.input_text).await?
|
||||||
|
} else {
|
||||||
|
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||||
|
};
|
||||||
|
ctx.query_embedding = Some(embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
debug!("Collecting chunk candidates via vector and FTS search");
|
||||||
|
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||||
|
let tuning = &ctx.config.tuning;
|
||||||
|
let fts_take = tuning.chunk_fts_take;
|
||||||
|
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||||
|
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
|
||||||
|
|
||||||
|
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||||
|
TextChunk::vector_search(
|
||||||
|
tuning.chunk_vector_take,
|
||||||
|
embedding,
|
||||||
|
ctx.db_client,
|
||||||
|
&ctx.user_id,
|
||||||
|
),
|
||||||
|
async {
|
||||||
|
if fts_enabled {
|
||||||
|
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
|
||||||
|
} else {
|
||||||
|
Ok(Vec::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let vector_candidates = vector_rows.len();
|
||||||
|
let fts_candidates = fts_rows.len();
|
||||||
|
|
||||||
|
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
||||||
|
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
|
||||||
|
// For very short keyword queries, lean more on lexical ranking.
|
||||||
|
fts_weight *= 1.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
let rrf_config = RrfConfig {
|
||||||
|
k: tuning.chunk_rrf_k,
|
||||||
|
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||||
|
fts_weight,
|
||||||
|
use_vector: tuning.flags.chunk_rrf_use_vector(),
|
||||||
|
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
total_merged = chunks.len(),
|
||||||
|
vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(),
|
||||||
|
fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(),
|
||||||
|
both_signals = chunks
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||||
|
.count(),
|
||||||
|
rrf_k = rrf_config.k,
|
||||||
|
"Merged chunk candidates with RRF"
|
||||||
|
);
|
||||||
|
|
||||||
|
if ctx.diagnostics_enabled() {
|
||||||
|
ctx.record_search(SearchStats {
|
||||||
|
vector_chunk_candidates: vector_candidates,
|
||||||
|
fts_chunk_candidates: fts_candidates,
|
||||||
|
vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)),
|
||||||
|
fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.chunk_values = chunks;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
if ctx.chunk_values.len() <= 1 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(reranker) = ctx.reranker.as_ref() else {
|
||||||
|
debug!("No reranker lease provided; skipping chunk rerank stage");
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let documents =
|
||||||
|
build_chunk_rerank_documents(&ctx.chunk_values, ctx.config.tuning.rerank_keep_top.max(1));
|
||||||
|
if documents.len() <= 1 {
|
||||||
|
debug!("Skipping chunk reranking stage; insufficient chunk documents");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match reranker.rerank(&ctx.input_text, documents).await {
|
||||||
|
Ok(results) if !results.is_empty() => {
|
||||||
|
apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results);
|
||||||
|
}
|
||||||
|
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
|
||||||
|
Err(err) => warn!(
|
||||||
|
error = %err,
|
||||||
|
"Chunk reranking failed; continuing with original ordering"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks.
|
||||||
|
///
|
||||||
|
/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped
|
||||||
|
/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk,
|
||||||
|
/// and the contributing chunks are attached.
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
if ctx.chunk_values.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
|
||||||
|
|
||||||
|
let mut source_order: Vec<String> = Vec::new();
|
||||||
|
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
|
||||||
|
let mut best_score: HashMap<String, f32> = HashMap::new();
|
||||||
|
|
||||||
|
for scored in &ctx.chunk_values {
|
||||||
|
let source = scored.item.source_id.clone();
|
||||||
|
let attached = chunks_by_source.entry(source.clone()).or_default();
|
||||||
|
if attached.is_empty() {
|
||||||
|
source_order.push(source.clone());
|
||||||
|
best_score.insert(source.clone(), scored.fused);
|
||||||
|
}
|
||||||
|
if attached.len() < max_chunks {
|
||||||
|
attached.push(RetrievedChunk {
|
||||||
|
chunk: scored.item.clone(),
|
||||||
|
score: scored.fused,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let entities =
|
||||||
|
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
||||||
|
|
||||||
|
let mut entities_by_source: HashMap<String, Vec<KnowledgeEntity>> = HashMap::new();
|
||||||
|
for entity in entities {
|
||||||
|
entities_by_source
|
||||||
|
.entry(entity.source_id.clone())
|
||||||
|
.or_default()
|
||||||
|
.push(entity);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for source in &source_order {
|
||||||
|
let Some(entities) = entities_by_source.remove(source) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let score = best_score.get(source).copied().unwrap_or(0.0);
|
||||||
|
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
|
||||||
|
for entity in entities {
|
||||||
|
results.push(RetrievedEntity {
|
||||||
|
entity,
|
||||||
|
score,
|
||||||
|
chunks: chunks.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
sources = source_order.len(),
|
||||||
|
entities = results.len(),
|
||||||
|
"Resolved entities from retrieved chunks"
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.entity_results = results;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
|
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
debug!("Assembling chunk retrieval results");
|
||||||
|
let mut chunk_values = std::mem::take(&mut ctx.chunk_values);
|
||||||
|
// Limit how many chunks we return to keep context size reasonable.
|
||||||
|
let limit = ctx
|
||||||
|
.config
|
||||||
|
.tuning
|
||||||
|
.chunk_result_cap
|
||||||
|
.max(1)
|
||||||
|
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
||||||
|
|
||||||
|
if chunk_values.len() > limit {
|
||||||
|
chunk_values.truncate(limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.chunk_results = chunk_values
|
||||||
|
.into_iter()
|
||||||
|
.map(|chunk| RetrievedChunk {
|
||||||
|
chunk: chunk.item,
|
||||||
|
score: chunk.fused,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if ctx.diagnostics_enabled() {
|
||||||
|
ctx.record_assemble(AssembleStats {
|
||||||
|
chunks_selected: ctx.chunk_results.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||||
|
|
||||||
|
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||||
|
where
|
||||||
|
F: FnMut(&Scored<T>) -> f32,
|
||||||
|
{
|
||||||
|
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalize free text into a full-text-search query.
///
/// Alphanumeric characters are lowercased, whitespace becomes a single separator,
/// and every other character is dropped. Characters on either side of dropped
/// punctuation are therefore joined (`"a-b"` becomes `"ab"`). A small fixed list of
/// English stopwords is then removed.
///
/// Returns the remaining tokens joined by single spaces, and the token count.
fn normalize_fts_query(input: &str) -> (String, usize) {
    const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];

    let mut cleaned = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            c if c.is_alphanumeric() => cleaned.extend(c.to_lowercase()),
            c if c.is_whitespace() => cleaned.push(' '),
            _ => {}
        }
    }

    // `split_whitespace` never yields empty tokens, so only stopwords need filtering.
    let tokens: Vec<&str> = cleaned
        .split_whitespace()
        .filter(|token| !STOPWORDS.contains(token))
        .collect();

    (tokens.join(" "), tokens.len())
}
|
||||||
|
|
||||||
|
/// Render the first `max_chunks` candidates as reranker input documents.
///
/// Each document contains the chunk's `source_id` followed by the trimmed chunk text.
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
    let window = &chunks[..chunks.len().min(max_chunks)];

    let mut documents = Vec::with_capacity(window.len());
    for candidate in window {
        let source = &candidate.item.source_id;
        let body = candidate.item.chunk.trim();
        documents.push(format!("Source: {}\nChunk:\n{}", source, body));
    }
    documents
}
|
||||||
|
|
||||||
|
fn apply_chunk_rerank_results(
|
||||||
|
chunks: &mut Vec<Scored<TextChunk>>,
|
||||||
|
tuning: &RetrievalTuning,
|
||||||
|
results: Vec<RerankResult>,
|
||||||
|
) {
|
||||||
|
if results.is_empty() || chunks.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut remaining: Vec<Option<Scored<TextChunk>>> =
|
||||||
|
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||||
|
|
||||||
|
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||||
|
let normalized_scores = min_max_normalize(&raw_scores);
|
||||||
|
|
||||||
|
let use_only = tuning.flags.rerank_scores_only();
|
||||||
|
let blend = if use_only {
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
clamp_unit(tuning.rerank_blend_weight)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
|
||||||
|
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||||
|
if let Some(slot) = remaining.get_mut(result.index) {
|
||||||
|
if let Some(mut candidate) = slot.take() {
|
||||||
|
let original = candidate.fused;
|
||||||
|
let blended = if use_only {
|
||||||
|
clamp_unit(normalized)
|
||||||
|
} else {
|
||||||
|
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||||
|
};
|
||||||
|
candidate.update_fused(blended);
|
||||||
|
reranked.push(candidate);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
result_index = result.index,
|
||||||
|
"Chunk reranker returned out-of-range index; skipping"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if reranked.len() == remaining.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reranked.extend(remaining.into_iter().flatten());
|
||||||
|
|
||||||
|
let keep_top = tuning.rerank_keep_top;
|
||||||
|
if keep_top > 0 && reranked.len() > keep_top {
|
||||||
|
reranked.truncate(keep_top);
|
||||||
|
}
|
||||||
|
|
||||||
|
*chunks = reranked;
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
|||||||
use super::{
|
|
||||||
stages::{
|
|
||||||
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
|
|
||||||
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
|
|
||||||
},
|
|
||||||
BoxedStage, StrategyDriver,
|
|
||||||
};
|
|
||||||
use crate::{RetrievedChunk, RetrievedEntity};
|
|
||||||
use common::error::AppError;
|
|
||||||
|
|
||||||
pub struct DefaultStrategyDriver;
|
|
||||||
|
|
||||||
impl DefaultStrategyDriver {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StrategyDriver for DefaultStrategyDriver {
|
|
||||||
type Output = Vec<RetrievedChunk>;
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage> {
|
|
||||||
vec![
|
|
||||||
Box::new(EmbedStage),
|
|
||||||
Box::new(ChunkVectorStage),
|
|
||||||
Box::new(ChunkRerankStage),
|
|
||||||
Box::new(ChunkAssembleStage),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
|
||||||
Ok(ctx.take_chunk_results())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct RelationshipSuggestionDriver;
|
|
||||||
|
|
||||||
impl RelationshipSuggestionDriver {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StrategyDriver for RelationshipSuggestionDriver {
|
|
||||||
type Output = Vec<RetrievedEntity>;
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage> {
|
|
||||||
vec![
|
|
||||||
Box::new(EmbedStage),
|
|
||||||
Box::new(CollectCandidatesStage),
|
|
||||||
Box::new(GraphExpansionStage),
|
|
||||||
// Skip ChunkAttachStage
|
|
||||||
Box::new(RerankStage),
|
|
||||||
Box::new(AssembleEntitiesStage),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
|
||||||
Ok(ctx.take_entity_results())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct IngestionDriver;
|
|
||||||
|
|
||||||
impl IngestionDriver {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StrategyDriver for IngestionDriver {
|
|
||||||
type Output = Vec<RetrievedEntity>;
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage> {
|
|
||||||
vec![
|
|
||||||
Box::new(EmbedStage),
|
|
||||||
Box::new(CollectCandidatesStage),
|
|
||||||
Box::new(GraphExpansionStage),
|
|
||||||
// Skip ChunkAttachStage
|
|
||||||
Box::new(RerankStage),
|
|
||||||
Box::new(AssembleEntitiesStage),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
|
||||||
Ok(ctx.take_entity_results())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use super::config::SearchTarget;
|
|
||||||
use crate::SearchResult;
|
|
||||||
|
|
||||||
/// Search strategy driver that retrieves both chunks and entities
|
|
||||||
pub struct SearchStrategyDriver {
|
|
||||||
target: SearchTarget,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SearchStrategyDriver {
|
|
||||||
pub fn new(target: SearchTarget) -> Self {
|
|
||||||
Self { target }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StrategyDriver for SearchStrategyDriver {
|
|
||||||
type Output = SearchResult;
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage> {
|
|
||||||
match self.target {
|
|
||||||
SearchTarget::ChunksOnly => vec![
|
|
||||||
Box::new(EmbedStage),
|
|
||||||
Box::new(ChunkVectorStage),
|
|
||||||
Box::new(ChunkRerankStage),
|
|
||||||
Box::new(ChunkAssembleStage),
|
|
||||||
],
|
|
||||||
SearchTarget::EntitiesOnly => vec![
|
|
||||||
Box::new(EmbedStage),
|
|
||||||
Box::new(CollectCandidatesStage),
|
|
||||||
Box::new(GraphExpansionStage),
|
|
||||||
Box::new(RerankStage),
|
|
||||||
Box::new(AssembleEntitiesStage),
|
|
||||||
],
|
|
||||||
SearchTarget::Both => vec![
|
|
||||||
Box::new(EmbedStage),
|
|
||||||
// Chunk retrieval path
|
|
||||||
Box::new(ChunkVectorStage),
|
|
||||||
Box::new(ChunkRerankStage),
|
|
||||||
Box::new(ChunkAssembleStage),
|
|
||||||
// Entity retrieval path (runs after chunk stages)
|
|
||||||
Box::new(CollectCandidatesStage),
|
|
||||||
Box::new(GraphExpansionStage),
|
|
||||||
Box::new(RerankStage),
|
|
||||||
Box::new(AssembleEntitiesStage),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
|
||||||
let chunks = match self.target {
|
|
||||||
SearchTarget::EntitiesOnly => Vec::new(),
|
|
||||||
_ => ctx.take_chunk_results(),
|
|
||||||
};
|
|
||||||
let entities = match self.target {
|
|
||||||
SearchTarget::ChunksOnly => Vec::new(),
|
|
||||||
_ => ctx.take_entity_results(),
|
|
||||||
};
|
|
||||||
Ok(SearchResult::new(chunks, entities))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -97,8 +97,7 @@ impl RerankerPool {
|
|||||||
|
|
||||||
fn default_pool_size() -> usize {
|
fn default_pool_size() -> usize {
|
||||||
available_parallelism()
|
available_parallelism()
|
||||||
.map(|value| value.get().min(2))
|
.map_or(2, |value| value.get().min(2))
|
||||||
.unwrap_or(2)
|
|
||||||
.max(1)
|
.max(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,6 +155,7 @@ pub struct RerankerLease {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl RerankerLease {
|
impl RerankerLease {
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
pub async fn rerank(
|
pub async fn rerank(
|
||||||
&self,
|
&self,
|
||||||
query: &str,
|
query: &str,
|
||||||
@@ -165,7 +165,9 @@ impl RerankerLease {
|
|||||||
let engine = Arc::clone(&self.engine);
|
let engine = Arc::clone(&self.engine);
|
||||||
|
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
let mut guard = engine.lock().expect("reranker engine mutex poisoned");
|
let mut guard = engine.lock().map_err(|_| {
|
||||||
|
AppError::InternalError("reranker engine mutex poisoned".into())
|
||||||
|
})?;
|
||||||
guard
|
guard
|
||||||
.rerank(query, documents, false, None)
|
.rerank(query, documents, false, None)
|
||||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||||
@@ -1,14 +1,12 @@
|
|||||||
use std::{cmp::Ordering, collections::HashMap};
|
use std::{cmp::Ordering, collections::HashMap};
|
||||||
|
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::StoredObject;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
/// Holds optional subscores gathered from different retrieval signals.
|
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct Scores {
|
pub struct Scores {
|
||||||
pub fts: Option<f32>,
|
pub fts: Option<f32>,
|
||||||
pub vector: Option<f32>,
|
pub vector: Option<f32>,
|
||||||
pub graph: Option<f32>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic wrapper combining an item with its accumulated retrieval scores.
|
/// Generic wrapper combining an item with its accumulated retrieval scores.
|
||||||
@@ -40,40 +38,11 @@ impl<T> Scored<T> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub const fn with_graph_score(mut self, score: f32) -> Self {
|
|
||||||
self.scores.graph = Some(score);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn update_fused(&mut self, fused: f32) {
|
pub const fn update_fused(&mut self, fused: f32) {
|
||||||
self.fused = fused;
|
self.fused = fused;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Weights used for linear score fusion.
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
|
||||||
pub struct FusionWeights {
|
|
||||||
pub vector: f32,
|
|
||||||
pub fts: f32,
|
|
||||||
pub graph: f32,
|
|
||||||
pub multi_bonus: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for FusionWeights {
|
|
||||||
fn default() -> Self {
|
|
||||||
// Default weights favor vector search, which typically performs better
|
|
||||||
// FTS is used as a complement when there's good overlap
|
|
||||||
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
|
|
||||||
Self {
|
|
||||||
vector: 0.8,
|
|
||||||
fts: 0.2,
|
|
||||||
graph: 0.2,
|
|
||||||
multi_bonus: 0.3, // Increased to boost chunks with both signals
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Configuration for reciprocal rank fusion.
|
/// Configuration for reciprocal rank fusion.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct RrfConfig {
|
pub struct RrfConfig {
|
||||||
@@ -84,29 +53,10 @@ pub struct RrfConfig {
|
|||||||
pub use_fts: bool,
|
pub use_fts: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RrfConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
k: 60.0,
|
|
||||||
vector_weight: 1.0,
|
|
||||||
fts_weight: 1.0,
|
|
||||||
use_vector: true,
|
|
||||||
use_fts: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn clamp_unit(value: f32) -> f32 {
|
pub const fn clamp_unit(value: f32) -> f32 {
|
||||||
value.clamp(0.0, 1.0)
|
value.clamp(0.0, 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn distance_to_similarity(distance: f32) -> f32 {
|
|
||||||
if !distance.is_finite() {
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||||
if scores.is_empty() {
|
if scores.is_empty() {
|
||||||
return Vec::new();
|
return Vec::new();
|
||||||
@@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
|
||||||
let vector = scores.vector.unwrap_or(0.0);
|
|
||||||
let fts = scores.fts.unwrap_or(0.0);
|
|
||||||
let graph = scores.graph.unwrap_or(0.0);
|
|
||||||
|
|
||||||
let mut fused = graph.mul_add(
|
|
||||||
weights.graph,
|
|
||||||
vector.mul_add(weights.vector, fts * weights.fts),
|
|
||||||
);
|
|
||||||
|
|
||||||
let signals_present = scores
|
|
||||||
.vector
|
|
||||||
.iter()
|
|
||||||
.chain(scores.fts.iter())
|
|
||||||
.chain(scores.graph.iter())
|
|
||||||
.count();
|
|
||||||
|
|
||||||
// Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
|
|
||||||
if signals_present >= 2 {
|
|
||||||
// For chunks with both vector and FTS, give a significant boost
|
|
||||||
// This helps identify the "golden chunk" that appears in both searches
|
|
||||||
if scores.vector.is_some() && scores.fts.is_some() {
|
|
||||||
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
|
|
||||||
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
|
|
||||||
fused *= 1.0 + weights.multi_bonus;
|
|
||||||
} else {
|
|
||||||
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
|
|
||||||
fused += weights.multi_bonus;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clamp_unit(fused)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
|
|
||||||
target: &mut std::collections::HashMap<String, Scored<T>, S>,
|
|
||||||
incoming: Vec<Scored<T>>,
|
|
||||||
) where
|
|
||||||
T: StoredObject + Clone,
|
|
||||||
{
|
|
||||||
for scored in incoming {
|
|
||||||
let id = scored.item.id().to_owned();
|
|
||||||
target
|
|
||||||
.entry(id)
|
|
||||||
.and_modify(|existing| {
|
|
||||||
if let Some(score) = scored.scores.vector {
|
|
||||||
existing.scores.vector = Some(score);
|
|
||||||
}
|
|
||||||
if let Some(score) = scored.scores.fts {
|
|
||||||
existing.scores.fts = Some(score);
|
|
||||||
}
|
|
||||||
if let Some(score) = scored.scores.graph {
|
|
||||||
existing.scores.graph = Some(score);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.or_insert_with(|| Scored {
|
|
||||||
item: scored.item.clone(),
|
|
||||||
scores: scored.scores,
|
|
||||||
fused: scored.fused,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||||
where
|
where
|
||||||
T: StoredObject,
|
T: StoredObject,
|
||||||
@@ -222,6 +109,10 @@ where
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion.
|
||||||
|
///
|
||||||
|
/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text
|
||||||
|
/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score.
|
||||||
pub fn reciprocal_rank_fusion<T>(
|
pub fn reciprocal_rank_fusion<T>(
|
||||||
mut vector_ranked: Vec<Scored<T>>,
|
mut vector_ranked: Vec<Scored<T>>,
|
||||||
mut fts_ranked: Vec<Scored<T>>,
|
mut fts_ranked: Vec<Scored<T>>,
|
||||||
@@ -266,9 +157,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
entry.item = candidate.item;
|
entry.item = candidate.item;
|
||||||
let rank_f32: f32 = u16::try_from(rank)
|
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||||
.map(f32::from)
|
|
||||||
.unwrap_or(f32::MAX);
|
|
||||||
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -296,9 +185,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
entry.item = candidate.item;
|
entry.item = candidate.item;
|
||||||
let rank_f32: f32 = u16::try_from(rank)
|
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||||
.map(f32::from)
|
|
||||||
.unwrap_or(f32::MAX);
|
|
||||||
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user