chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.

Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
Per Stark
2026-05-30 22:19:08 +02:00
parent c70141de35
commit 5c2d2e24d3
38 changed files with 1049 additions and 2614 deletions
+4 -6
View File
@@ -10,16 +10,14 @@ workspace = true
[dependencies]
tokio = { workspace = true }
serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true }
surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
fastembed = { workspace = true }
common = { path = "../common", features = ["test-utils"] }
[dev-dependencies]
anyhow = { workspace = true }
uuid = { workspace = true }
+43 -53
View File
@@ -1,61 +1,66 @@
//! Chat answer assembly: retrieval context formatting and structured LLM request/response types.
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
},
};
use common::{
error::AppError,
storage::types::{
message::{format_history, Message},
system_settings::SystemSettings,
},
use common::storage::types::{
message::{format_history, Message},
system_settings::SystemSettings,
};
use serde::Deserialize;
use serde_json::Value;
use serde_json::{json, Value};
use super::answer_retrieval_helper::get_query_response_schema;
/// Build the JSON schema that constrains the structured chat answer.
///
/// The schema describes an object with a string `answer` and a `references` array whose
/// items are objects holding a single string `reference`. Both levels mark every listed
/// property as required and reject additional properties.
fn get_query_response_schema() -> Value {
    // Shape of one element of the `references` array.
    let reference_item = json!({
        "type": "object",
        "properties": {
            "reference": { "type": "string" },
        },
        "required": ["reference"],
        "additionalProperties": false,
    });

    // Top-level properties of the answer object.
    let answer_properties = json!({
        "answer": { "type": "string" },
        "references": {
            "type": "array",
            "items": reference_item,
        },
    });

    json!({
        "type": "object",
        "properties": answer_properties,
        "required": ["answer", "references"],
        "additionalProperties": false,
    })
}
/// One entry of the `references` array in the structured answer.
#[derive(Debug, Deserialize)]
pub struct Reference {
// NOTE(review): `dead_code` is allowed here without a stated reason — confirm it is still needed.
#[allow(dead_code)]
/// The reference string carried by this entry.
pub reference: String,
}
/// Deserialized form of the structured answer: the answer text plus its references.
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
/// Answer text.
pub answer: String,
// NOTE(review): `dead_code` is allowed here without a stated reason — confirm it is still needed.
#[allow(dead_code)]
/// Reference entries that accompany the answer.
pub references: Vec<Reference>,
}
/// Answer text together with a flat list of reference strings.
// NOTE(review): presumably the caller-facing counterpart of the deserialized LLM response — confirm against callers.
#[derive(Debug)]
pub struct Answer {
/// Answer text.
pub content: String,
/// Reference strings associated with the answer.
pub references: Vec<String>,
}
/// Render the user prompt: a "Context Information" section containing `entities_json`
/// followed by a "User Question" section containing `query`.
///
/// `entities_json` is interpolated with its `Display` formatting.
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r"
Context Information:
==================
{entities_json}
User Question:
==================
{query}
"
)
}
/// Convert chunk-based retrieval results to JSON format for LLM context
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
impl LLMResponseFormat {
pub fn reference_ids(&self) -> Vec<String> {
self.references
.iter()
.map(|entry| entry.reference.clone())
.collect()
}
}
/// Convert chunk-based retrieval results to JSON format for LLM context.
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
use crate::round_score;
serde_json::json!(chunks
.iter()
@@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
}
pub fn create_user_message_with_history(
entities_json: &Value,
context_json: &Value,
history: &[Message],
query: &str,
) -> String {
@@ -89,7 +94,7 @@ pub fn create_user_message_with_history(
{}
",
format_history(history),
entities_json,
context_json,
query
)
}
@@ -116,18 +121,3 @@ pub fn create_chat_request(
.response_format(response_format)
.build()
}
/// Extract the structured answer from a chat completion response.
///
/// Takes the message content of the first choice and deserializes it as
/// [`LLMResponseFormat`].
///
/// # Errors
/// Returns [`AppError::LLMParsing`] when there is no first choice, when that choice has no
/// content, or when the content fails to deserialize into [`LLMResponseFormat`].
pub fn process_llm_response(
    response: &CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, Box<AppError>> {
    let first_content = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref());

    // Guard clause: nothing to parse without content on the first choice.
    let Some(content) = first_content else {
        return Err(Box::new(AppError::LLMParsing(
            "No content found in LLM response".into(),
        )));
    };

    serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
        Box::new(AppError::LLMParsing(format!(
            "Failed to parse LLM response into analysis: {e}"
        )))
    })
}
@@ -1,23 +0,0 @@
use serde_json::{json, Value};
/// Build the JSON schema for the structured chat answer.
///
/// The schema is an object with a string `answer` and a `references` array whose items are
/// objects with a single string `reference`; both levels require all listed properties and
/// disallow additional ones.
pub fn get_query_response_schema() -> Value {
json!({
"type": "object",
"properties": {
"answer": { "type": "string" },
"references": {
"type": "array",
"items": {
"type": "object",
"properties": {
"reference": { "type": "string" },
},
"required": ["reference"],
"additionalProperties": false,
}
}
},
"required": ["answer", "references"],
"additionalProperties": false
})
}
-228
View File
@@ -1,228 +0,0 @@
use std::collections::{HashMap, HashSet};
use surrealdb::{sql::Thing, Error};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
StoredObject,
},
};
/// Look up the knowledge entities connected to `entity_id` through `relates_to` edges.
///
/// Loads every `relates_to` row for `user_id` that has the entity on either end, collects
/// the ids found on the other end (deduplicated, in first-seen order), then fetches the
/// matching `KnowledgeEntity` rows for the same user. The result keeps the first-seen
/// order; ids without a matching row are dropped.
///
/// # Arguments
/// * `db` - Database client
/// * `entity_id` - ID of the entity to find neighbors for
/// * `user_id` - User ID; both the edge query and the entity query filter on it
/// * `limit` - Maximum number of neighbors to return; `0` disables the cap
///
/// # Errors
/// Propagates any error from executing the queries or decoding their results.
pub async fn find_entities_by_relationship_by_id(
    db: &SurrealDbClient,
    entity_id: &str,
    user_id: &str,
    limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
    let edge_query = "
SELECT * FROM relates_to
WHERE metadata.user_id = $user_id
AND (in = type::thing('knowledge_entity', $entity_id)
OR out = type::thing('knowledge_entity', $entity_id))
";
    let edges: Vec<KnowledgeRelationship> = db
        .query(edge_query)
        .bind(("entity_id", entity_id.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .await?
        .take(0)?;

    if edges.is_empty() {
        return Ok(Vec::new());
    }

    // Gather the opposite endpoint of each edge. `in` is considered before `out`, the
    // queried entity itself is never a neighbor, and each id is kept only once.
    let mut seen: HashSet<String> = HashSet::with_capacity(edges.len());
    let mut neighbor_ids: Vec<String> = Vec::with_capacity(edges.len());
    for edge in edges {
        for endpoint in [edge.in_, edge.out] {
            if endpoint != entity_id && seen.insert(endpoint.clone()) {
                neighbor_ids.push(endpoint);
            }
        }
    }

    if neighbor_ids.is_empty() {
        return Ok(Vec::new());
    }

    // A limit of zero means "no cap"; `truncate` is a no-op when already short enough.
    if limit > 0 {
        neighbor_ids.truncate(limit);
    }

    let table = KnowledgeEntity::table_name();
    let thing_ids: Vec<Thing> = neighbor_ids
        .iter()
        .map(|id| Thing::from((table, id.as_str())))
        .collect();

    let fetched: Vec<KnowledgeEntity> = db
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?
        .take(0)?;

    if fetched.is_empty() {
        return Ok(Vec::new());
    }

    // Re-order the fetched rows to follow the first-seen order of `neighbor_ids`.
    let mut by_id: HashMap<String, KnowledgeEntity> = fetched
        .into_iter()
        .map(|entity| (entity.id.clone(), entity))
        .collect();
    let ordered: Vec<KnowledgeEntity> = neighbor_ids
        .into_iter()
        .filter_map(|id| by_id.remove(&id))
        .collect();

    Ok(ordered)
}
// Exercises the relationship lookup against an in-memory SurrealDB instance.
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use uuid::Uuid;
/// Stores four entities and two relationships, then checks that the lookup for the
/// central entity returns at least the two linked entities.
#[tokio::test]
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
let namespace = "test_ns";
// A random database name keeps each test run separate from the others.
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
// Four entities for the same user: one central, two that get linked, one left unlinked.
let central_entity = KnowledgeEntity::new(
"central_source".to_string(),
"Central Entity".to_string(),
"Central Description".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let related_entity1 = KnowledgeEntity::new(
"related_source1".to_string(),
"Related Entity 1".to_string(),
"Related Description 1".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let related_entity2 = KnowledgeEntity::new(
"related_source2".to_string(),
"Related Entity 2".to_string(),
"Related Description 2".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
let unrelated_entity = KnowledgeEntity::new(
"unrelated_source".to_string(),
"Unrelated Entity".to_string(),
"Unrelated Description".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
// Persist every entity and keep the stored copies returned by the database.
let central_entity = db
.store_item(central_entity.clone())
.await
.with_context(|| "Failed to store central entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
let related_entity1 = db
.store_item(related_entity1.clone())
.await
.with_context(|| "Failed to store related entity 1".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
let related_entity2 = db
.store_item(related_entity2.clone())
.await
.with_context(|| "Failed to store related entity 2".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
let _unrelated_entity = db
.store_item(unrelated_entity.clone())
.await
.with_context(|| "Failed to store unrelated entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
let source_id = "relationship_source".to_string();
// Both relationships are created with the central entity as the first argument.
let relationship1 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity1.id.clone(),
user_id.clone(),
source_id.clone(),
"references".to_string(),
);
let relationship2 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity2.id.clone(),
user_id.clone(),
source_id.clone(),
"contains".to_string(),
);
relationship1
.store_relationship(&db)
.await
.with_context(|| "Failed to store relationship 1".to_string())?;
relationship2
.store_relationship(&db)
.await
.with_context(|| "Failed to store relationship 2".to_string())?;
// `usize::MAX` as the limit effectively leaves the result uncapped.
let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await
.with_context(|| "Failed to find entities by relationship".to_string())?;
// NOTE(review): both edges are created from the central entity, so the "both directions"
// wording is not exercised here; the `>= 2` bound also does not verify that the unlinked
// entity is excluded.
assert!(
related_entities.len() >= 2,
"Should find related entities in both directions"
);
Ok(())
}
}
+63 -115
View File
@@ -1,10 +1,9 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod graph;
pub mod pipeline;
pub mod reranking;
pub mod scoring;
pub(crate) mod scoring;
use common::{
error::AppError,
@@ -16,39 +15,28 @@ use common::{
use reranking::RerankerLease;
use tracing::instrument;
// Strategy output variants - defined before pipeline module
/// Result of a retrieval run.
///
/// Chunk retrieval is always performed; entities are only present when the caller
/// requested entity resolution via [`RetrievalConfig::with_entities`].
#[derive(Debug)]
pub enum StrategyOutput {
Entities(Vec<RetrievedEntity>),
pub enum RetrievalOutput {
Chunks(Vec<RetrievedChunk>),
Search(SearchResult),
}
/// Unified search result containing both chunks and entities
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Retrieved chunks.
pub chunks: Vec<RetrievedChunk>,
/// Retrieved entities.
pub entities: Vec<RetrievedEntity>,
}
impl SearchResult {
pub fn new(chunks: Vec<RetrievedChunk>, entities: Vec<RetrievedEntity>) -> Self {
Self { chunks, entities }
}
pub fn is_empty(&self) -> bool {
self.chunks.is_empty() && self.entities.is_empty()
}
WithEntities {
chunks: Vec<RetrievedChunk>,
entities: Vec<RetrievedEntity>,
},
}
pub use pipeline::{
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind,
StageTimings,
};
// Backward-compatible type aliases for external consumers
pub type PipelineDiagnostics = Diagnostics;
pub type PipelineStageTimings = StageTimings;
/// Round `value` to three decimal places, widening it to `f64` for JSON output.
pub(crate) fn round_score(value: f32) -> f64 {
    // Three decimal places == scale by 10^3, round to the nearest integer, scale back.
    const SCALE: f64 = 1000.0;
    let scaled = f64::from(value) * SCALE;
    scaled.round() / SCALE
}
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
@@ -57,7 +45,7 @@ pub struct RetrievedChunk {
pub score: f32,
}
// Final entity representation returned to callers, enriched with ranked chunks.
// Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
pub entity: KnowledgeEntity,
@@ -65,9 +53,9 @@ pub struct RetrievedEntity {
pub chunks: Vec<RetrievedChunk>,
}
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
pub async fn retrieve(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
@@ -75,8 +63,8 @@ pub async fn retrieve_entities(
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
let params = pipeline::StrategyParams {
) -> Result<RetrievalOutput, AppError> {
let params = pipeline::RetrievalParams {
db_client,
openai_client,
embedding_provider,
@@ -94,6 +82,7 @@ mod tests {
use anyhow::{self};
use async_openai::Client;
use common::storage::indexes::ensure_runtime;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::system_settings::SystemSettings;
use uuid::Uuid;
@@ -133,7 +122,7 @@ mod tests {
}
#[tokio::test]
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "test_user";
let chunk = TextChunk::new(
@@ -145,7 +134,7 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let openai_client = Client::new();
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
@@ -154,12 +143,13 @@ mod tests {
config: RetrievalConfig::default(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
RetrievalOutput::Chunks(items) => items,
RetrievalOutput::WithEntities { .. } => {
anyhow::bail!("expected chunk results, got entities")
}
};
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
@@ -171,8 +161,7 @@ mod tests {
}
#[tokio::test]
async fn test_default_strategy_returns_chunks_from_multiple_sources(
) -> anyhow::Result<()> {
async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "multi_source_user";
@@ -191,7 +180,7 @@ mod tests {
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
let openai_client = Client::new();
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
@@ -200,12 +189,13 @@ mod tests {
config: RetrievalConfig::default(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
RetrievalOutput::Chunks(items) => items,
RetrievalOutput::WithEntities { .. } => {
anyhow::bail!("expected chunk results, got entities")
}
};
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
@@ -223,96 +213,54 @@ mod tests {
}
#[tokio::test]
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "chunk_user";
let chunk_one = TextChunk::new(
"src_alpha".into(),
"Tokio tasks execute on worker threads managed by the runtime.".into(),
user_id.into(),
);
let chunk_two = TextChunk::new(
"src_beta".into(),
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
user_id.into(),
);
let user_id = "entity_user";
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new();
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "tokio runtime worker behavior",
user_id,
config,
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => anyhow::bail!("expected chunk results, got {other:?}"),
};
assert!(
!chunks.is_empty(),
"Revised strategy should return chunk-only responses"
);
assert!(
chunks
.iter()
.any(|entry| entry.chunk.chunk.contains("Tokio")),
"Chunk results should contain relevant snippets"
);
Ok(())
}
#[tokio::test]
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "search_user";
let chunk = TextChunk::new(
"search_src".into(),
"Async Rust programming uses Tokio runtime for concurrent tasks.".into(),
"entity_source".into(),
"Async Rust programming uses the Tokio runtime for concurrent tasks.".into(),
user_id.into(),
);
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
let entity = KnowledgeEntity::new(
"entity_source".into(),
"Tokio Runtime".into(),
"Async runtime for Rust".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
db.store_item(entity).await?;
let openai_client = Client::new();
let params = pipeline::StrategyParams {
let params = pipeline::RetrievalParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "async rust programming",
user_id,
config,
config: RetrievalConfig::with_entities(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
let StrategyOutput::Search(search_result) = results else {
anyhow::bail!("expected Search output");
let RetrievalOutput::WithEntities { chunks, entities } = results else {
anyhow::bail!("expected WithEntities output");
};
// Should return chunks (entities may be empty if none stored)
assert!(!chunks.is_empty(), "Should return chunks");
assert!(
!search_result.chunks.is_empty(),
"Search strategy should return chunks"
entities.iter().any(|e| e.entity.name == "Tokio Runtime"),
"Should resolve the entity owning the retrieved chunk"
);
assert!(
search_result
.chunks
entities
.iter()
.any(|c| c.chunk.chunk.contains("Tokio")),
"Search results should contain relevant chunks"
.find(|e| e.entity.name == "Tokio Runtime")
.is_some_and(|e| !e.chunks.is_empty()),
"Resolved entity should carry its contributing chunks"
);
Ok(())
}
+25 -128
View File
@@ -1,22 +1,5 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::scoring::FusionWeights;
pub use common::utils::config::RetrievalStrategy;
/// Configures which result types to include in Search strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SearchTarget {
/// Return only text chunks
ChunksOnly,
/// Return only knowledge entities
EntitiesOnly,
/// Return both chunks and entities (default)
#[default]
Both,
}
/// Two-variant flag that serializes as a bool for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BoolFlag {
@@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag {
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RetrievalTuningFlags {
pub rerank_scores_only: BoolFlag,
pub normalize_vector_scores: BoolFlag,
pub normalize_fts_scores: BoolFlag,
pub chunk_rrf_use_vector: BoolFlag,
pub chunk_rrf_use_fts: BoolFlag,
}
impl RetrievalTuningFlags {
pub const fn rerank_scores_only(&self) -> bool {
pub const fn rerank_scores_only(self) -> bool {
self.rerank_scores_only.as_bool()
}
pub const fn normalize_vector_scores(&self) -> bool {
self.normalize_vector_scores.as_bool()
}
pub const fn normalize_fts_scores(&self) -> bool {
self.normalize_fts_scores.as_bool()
}
pub const fn chunk_rrf_use_vector(&self) -> bool {
pub const fn chunk_rrf_use_vector(self) -> bool {
self.chunk_rrf_use_vector.as_bool()
}
pub const fn chunk_rrf_use_fts(&self) -> bool {
pub const fn chunk_rrf_use_fts(self) -> bool {
self.chunk_rrf_use_fts.as_bool()
}
}
@@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags {
fn default() -> Self {
Self {
rerank_scores_only: BoolFlag::Disabled,
normalize_vector_scores: BoolFlag::Disabled,
normalize_fts_scores: BoolFlag::Enabled,
chunk_rrf_use_vector: BoolFlag::Enabled,
chunk_rrf_use_fts: BoolFlag::Enabled,
}
}
}
/// Tunable parameters that govern each retrieval stage.
/// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
pub entity_vector_take: usize,
/// Number of vector candidates to pull from the chunk embedding index.
pub chunk_vector_take: usize,
pub entity_fts_take: usize,
/// Number of full-text candidates to pull from the chunk index.
pub chunk_fts_take: usize,
pub score_threshold: f32,
pub fallback_min_results: usize,
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
/// Maximum chunks attached to each resolved entity.
pub max_chunks_per_entity: usize,
pub lexical_match_weight: f32,
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
/// Blend weight applied when mixing reranker scores with fused scores.
pub rerank_blend_weight: f32,
pub flags: RetrievalTuningFlags,
/// Keep top-N candidates after reranking.
pub rerank_keep_top: usize,
/// Maximum number of chunks returned to callers.
pub chunk_result_cap: usize,
/// Optional fusion weights for hybrid search. If None, uses default weights.
pub fusion_weights: Option<FusionWeights>,
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
#[serde(default = "default_chunk_rrf_k")]
/// Reciprocal rank fusion k value for chunk merging.
pub chunk_rrf_k: f32,
/// Weight applied to vector ranks in RRF.
#[serde(default = "default_chunk_rrf_vector_weight")]
pub chunk_rrf_vector_weight: f32,
/// Weight applied to chunk FTS ranks in RRF.
#[serde(default = "default_chunk_rrf_fts_weight")]
pub chunk_rrf_fts_weight: f32,
pub flags: RetrievalTuningFlags,
}
impl Default for RetrievalTuning {
fn default() -> Self {
Self {
entity_vector_take: 15,
chunk_vector_take: 20,
entity_fts_take: 10,
chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 10000,
avg_chars_per_token: 4,
max_chunks_per_entity: 4,
lexical_match_weight: 0.15,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65,
flags: RetrievalTuningFlags::default(),
rerank_keep_top: 8,
chunk_result_cap: 5,
fusion_weights: None,
chunk_rrf_k: default_chunk_rrf_k(),
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
chunk_rrf_fts_weight: 1.0,
flags: RetrievalTuningFlags::default(),
}
}
}
/// Wrapper containing tuning plus future flags for per-request overrides.
/// Per-request retrieval configuration.
///
/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities`
/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved
/// chunks (search, ingestion linking, relationship suggestion).
#[derive(Debug, Clone, Default)]
pub struct RetrievalConfig {
pub strategy: RetrievalStrategy,
pub tuning: RetrievalTuning,
/// Target for Search strategy (chunks, entities, or both)
pub search_target: SearchTarget,
pub resolve_entities: bool,
}
impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self {
/// Chunk retrieval that also resolves the owning knowledge entities.
pub fn with_entities() -> Self {
Self {
strategy: RetrievalStrategy::Default,
tuning,
search_target: SearchTarget::default(),
}
}
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
Self {
strategy,
tuning: RetrievalTuning::default(),
search_target: SearchTarget::default(),
}
}
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
Self {
strategy,
tuning,
search_target: SearchTarget::default(),
}
}
/// Create config for chat retrieval with strategy selection support
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
Self::with_strategy(strategy)
}
/// Create config for relationship suggestion (entity-only retrieval)
pub fn for_relationship_suggestion() -> Self {
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
}
/// Create config for ingestion pipeline (entity-only retrieval)
pub fn for_ingestion() -> Self {
Self::with_strategy(RetrievalStrategy::Ingestion)
}
/// Create config for unified search (chunks and/or entities)
pub fn for_search(target: SearchTarget) -> Self {
Self {
strategy: RetrievalStrategy::Search,
tuning: RetrievalTuning::default(),
search_target: target,
resolve_entities: true,
}
}
}
/// Default reciprocal-rank-fusion `k` value used when none is supplied.
const fn default_chunk_rrf_k() -> f32 {
    const DEFAULT_CHUNK_RRF_K: f32 = 60.0;
    DEFAULT_CHUNK_RRF_K
}
/// Default weight applied to vector ranks in reciprocal rank fusion.
const fn default_chunk_rrf_vector_weight() -> f32 {
    const DEFAULT_VECTOR_WEIGHT: f32 = 1.0;
    DEFAULT_VECTOR_WEIGHT
}
/// Default weight applied to full-text-search ranks in reciprocal rank fusion.
const fn default_chunk_rrf_fts_weight() -> f32 {
    const DEFAULT_FTS_WEIGHT: f32 = 1.0;
    DEFAULT_FTS_WEIGHT
}
+107
View File
@@ -0,0 +1,107 @@
use async_openai::Client;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
utils::embedding::EmbeddingProvider,
};
use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity};
use super::{
config::RetrievalConfig,
diagnostics::{AssembleStats, Diagnostics, SearchStats},
StageKind, StageTimings, RetrievalParams,
};
/// Mutable working state threaded through every retrieval stage.
///
/// Holds the borrowed clients, the owned request inputs, and the intermediate and final
/// results for one retrieval run.
pub(crate) struct PipelineContext<'a> {
/// Database client borrowed from the caller's parameters.
pub db_client: &'a SurrealDbClient,
/// OpenAI client borrowed from the caller's parameters.
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
/// Optional embedding provider borrowed from the caller's parameters.
pub embedding_provider: Option<&'a EmbeddingProvider>,
/// Owned copy of the query text.
pub input_text: String,
/// Owned copy of the requesting user's id.
pub user_id: String,
/// Per-request retrieval configuration.
pub config: RetrievalConfig,
/// Query embedding; `None` until set (for example via `with_embedding`).
pub query_embedding: Option<Vec<f32>>,
// presumably the scored chunk candidates produced by the search stage — TODO confirm
pub chunk_values: Vec<Scored<TextChunk>>,
/// Optional reranker lease passed in by the caller.
pub reranker: Option<RerankerLease>,
/// Diagnostics collector; `Some` only after `enable_diagnostics` has been called.
pub diagnostics: Option<Diagnostics>,
/// Entity results, handed out by `take_entity_results`.
pub entity_results: Vec<RetrievedEntity>,
/// Chunk results, handed out by `take_chunk_results`.
pub chunk_results: Vec<RetrievedChunk>,
// Per-stage durations; private, written via `record_stage_duration`.
stage_timings: StageTimings,
}
impl<'a> PipelineContext<'a> {
    /// Create a context from the caller's parameters with empty working state.
    ///
    /// The query text and user id are copied into owned strings; no embedding is set,
    /// diagnostics are disabled, and every result buffer starts empty.
    pub fn new(params: RetrievalParams<'a>) -> Self {
        Self {
            // Borrowed collaborators.
            db_client: params.db_client,
            openai_client: params.openai_client,
            embedding_provider: params.embedding_provider,
            reranker: params.reranker,
            // Owned request inputs.
            input_text: params.input_text.to_owned(),
            user_id: params.user_id.to_owned(),
            config: params.config,
            // Working state, filled in as the run progresses.
            query_embedding: None,
            chunk_values: Vec::new(),
            entity_results: Vec::new(),
            chunk_results: Vec::new(),
            diagnostics: None,
            stage_timings: StageTimings::default(),
        }
    }

    /// Create a context that already carries the given query embedding.
    pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec<f32>) -> Self {
        let mut context = Self::new(params);
        context.query_embedding = Some(query_embedding);
        context
    }

    /// Borrow the query embedding.
    ///
    /// # Errors
    /// Returns [`AppError::InternalError`] when no embedding has been set.
    pub(crate) fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
        match &self.query_embedding {
            Some(embedding) => Ok(embedding),
            None => Err(Box::new(AppError::InternalError(
                "query embedding missing before candidate search".to_string(),
            ))),
        }
    }

    /// Turn diagnostics collection on; an existing collector is left untouched.
    pub fn enable_diagnostics(&mut self) {
        if !self.diagnostics_enabled() {
            self.diagnostics = Some(Diagnostics::default());
        }
    }

    /// Whether a diagnostics collector is currently present.
    pub fn diagnostics_enabled(&self) -> bool {
        self.diagnostics.is_some()
    }

    /// Store search statistics; ignored when diagnostics are disabled.
    pub(crate) fn record_search(&mut self, stats: SearchStats) {
        if let Some(collector) = &mut self.diagnostics {
            collector.search = Some(stats);
        }
    }

    /// Store assembly statistics; ignored when diagnostics are disabled.
    pub(crate) fn record_assemble(&mut self, stats: AssembleStats) {
        if let Some(collector) = &mut self.diagnostics {
            collector.assemble = Some(stats);
        }
    }

    /// Move the diagnostics collector out, leaving diagnostics disabled.
    pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
        self.diagnostics.take()
    }

    /// Move the recorded stage timings out, leaving a default value behind.
    pub fn take_stage_timings(&mut self) -> StageTimings {
        let timings = &mut self.stage_timings;
        std::mem::take(timings)
    }

    /// Record how long the stage identified by `kind` took.
    pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
        self.stage_timings.record(kind, duration);
    }

    /// Move the entity results out, leaving an empty list behind.
    pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
        let entities = &mut self.entity_results;
        std::mem::take(entities)
    }

    /// Move the chunk results out, leaving an empty list behind.
    pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
        let chunks = &mut self.chunk_results;
        std::mem::take(chunks)
    }
}
+3 -33
View File
@@ -1,51 +1,21 @@
use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
/// Captures instrumentation for the retrieval stages when diagnostics are enabled.
#[derive(Debug, Clone, Default, Serialize)]
pub struct Diagnostics {
pub collect_candidates: Option<CollectCandidatesStats>,
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
pub search: Option<SearchStats>,
pub assemble: Option<AssembleStats>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats {
pub vector_entity_candidates: usize,
pub struct SearchStats {
pub vector_chunk_candidates: usize,
pub fts_entity_candidates: usize,
pub fts_chunk_candidates: usize,
pub vector_chunk_scores: Vec<f32>,
pub fts_chunk_scores: Vec<f32>,
}
/// Serializable counters describing a chunk-enrichment step.
// NOTE(review): field meanings below are inferred from their names only — confirm against the code that fills them.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
pub filtered_entity_count: usize,
pub fallback_min_results: usize,
pub chunk_sources_considered: usize,
// presumably candidate counts before and after the enrichment step — TODO confirm
pub chunk_candidates_before_enrichment: usize,
pub chunk_candidates_after_enrichment: usize,
pub top_chunk_scores: Vec<f32>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats {
pub token_budget_start: usize,
pub token_budget_spent: usize,
pub token_budget_remaining: usize,
pub budget_exhausted: bool,
pub chunks_selected: usize,
pub chunks_skipped_due_budget: usize,
pub entity_count: usize,
pub entity_traces: Vec<EntityAssemblyTrace>,
}
/// Serializable per-entity record of an assembly step.
// NOTE(review): field meanings are inferred from their names only — confirm against the code that fills them.
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
pub entity_id: String,
pub source_id: String,
pub inspected_candidates: usize,
// presumably parallel lists (same length, same order) — TODO confirm
pub selected_chunk_ids: Vec<String>,
pub selected_chunk_scores: Vec<f32>,
pub skipped_due_budget: usize,
}
+119 -209
View File
@@ -1,61 +1,68 @@
mod config;
mod context;
mod diagnostics;
mod stages;
mod strategies;
pub use config::{
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
};
pub use config::RetrievalConfig;
pub use diagnostics::Diagnostics;
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
use crate::{round_score, RetrievalOutput, RetrievedEntity};
use async_openai::Client;
use async_trait::async_trait;
use common::{error::AppError, storage::db::SurrealDbClient};
use std::time::{Duration, Instant};
use tracing::info;
use stages::PipelineContext;
use strategies::{
DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver,
use stages::{
ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage,
};
// Export StrategyOutput publicly from this module
// (it's defined in lib.rs but we re-export it here)
// Stage type enum
/// Identifies a retrieval stage for timing and instrumentation.
///
/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation
/// harness) iterate it generically so that adding a stage requires no changes outside this
/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StageKind {
Embed,
CollectCandidates,
GraphExpansion,
ChunkAttach,
Search,
Rerank,
ResolveEntities,
Assemble,
}
// Pipeline stage trait
impl StageKind {
/// Every stage kind in canonical pipeline order.
pub const ALL: [StageKind; 5] = [
StageKind::Embed,
StageKind::Search,
StageKind::Rerank,
StageKind::ResolveEntities,
StageKind::Assemble,
];
/// Stable, machine-friendly identifier for the stage (used as a metrics key).
pub const fn label(self) -> &'static str {
match self {
StageKind::Embed => "embed",
StageKind::Search => "search",
StageKind::Rerank => "rerank",
StageKind::ResolveEntities => "resolve_entities",
StageKind::Assemble => "assemble",
}
}
}
/// A single composable step in the retrieval pipeline.
#[async_trait]
pub trait Stage: Send + Sync {
pub(crate) trait Stage: Send + Sync {
fn kind(&self) -> StageKind;
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>;
}
// Type alias for boxed stages
pub type BoxedStage = Box<dyn Stage>;
pub(crate) type BoxedStage = Box<dyn Stage>;
// Strategy driver trait
#[async_trait]
pub trait StrategyDriver: Send + Sync {
type Output;
fn stages(&self) -> Vec<BoxedStage>;
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
}
// Pipeline stage timings tracker
/// Per-stage execution timings recorded during a run.
#[derive(Debug, Default, Clone)]
pub struct StageTimings {
timings: Vec<(StageKind, Duration)>,
@@ -66,41 +73,13 @@ impl StageTimings {
self.timings.push((kind, duration));
}
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
self.timings
}
// Helper methods to get duration for each stage type (for backward compatibility)
fn get_stage_ms(&self, kind: StageKind) -> u128 {
/// Milliseconds recorded for `kind`, or `0` if the stage did not run.
pub fn stage_ms(&self, kind: StageKind) -> u128 {
self.timings
.iter()
.find(|(k, _)| *k == kind)
.map_or(0, |(_, d)| d.as_millis())
}
pub fn embed_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Embed)
}
pub fn collect_candidates_ms(&self) -> u128 {
self.get_stage_ms(StageKind::CollectCandidates)
}
pub fn graph_expansion_ms(&self) -> u128 {
self.get_stage_ms(StageKind::GraphExpansion)
}
pub fn chunk_attach_ms(&self) -> u128 {
self.get_stage_ms(StageKind::ChunkAttach)
}
pub fn rerank_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Rerank)
}
pub fn assemble_ms(&self) -> u128 {
self.get_stage_ms(StageKind::Assemble)
}
}
pub struct RunOutput<T> {
@@ -109,7 +88,35 @@ pub struct RunOutput<T> {
pub stage_timings: StageTimings,
}
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
/// Inputs required to run a retrieval.
///
/// Taken by value by the pipeline entry points, which hand it to
/// `PipelineContext::new` / `PipelineContext::with_embedding`.
/// NOTE(review): the field notes below assume the context exposes each field unchanged
/// under the same name — confirm in `context.rs`.
pub struct RetrievalParams<'a> {
// Database handle; the stages pass it to the chunk searches and the entity lookup.
pub db_client: &'a SurrealDbClient,
// OpenAI client; the embed stage uses it only when no `embedding_provider` is set.
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
// Optional embedding provider; when present the embed stage prefers it over OpenAI.
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
// Raw query text: embedded, normalized for FTS, and passed to the reranker.
pub input_text: &'a str,
// User id forwarded to every database query issued by the stages.
pub user_id: &'a str,
// Pipeline configuration (`resolve_entities` switch and tuning values).
pub config: RetrievalConfig,
// Optional reranker lease; the rerank stage is a no-op when this is `None`.
pub reranker: Option<crate::reranking::RerankerLease>,
}
/// Builds the ordered list of stages for one retrieval run.
///
/// Embed, chunk search and chunk rerank are always scheduled. Entity resolution is
/// inserted just before the final assemble stage only when `config.resolve_entities`
/// is set.
fn build_stages(config: &RetrievalConfig) -> Vec<BoxedStage> {
    // Five is the maximum number of stages this function can schedule.
    let mut pipeline: Vec<BoxedStage> = Vec::with_capacity(5);
    pipeline.push(Box::new(EmbedStage));
    pipeline.push(Box::new(ChunkSearchStage));
    pipeline.push(Box::new(ChunkRerankStage));
    if config.resolve_entities {
        pipeline.push(Box::new(ResolveEntitiesStage));
    }
    pipeline.push(Box::new(ChunkAssembleStage));
    pipeline
}
async fn run(
params: RetrievalParams<'_>,
query_embedding: Option<Vec<f32>>,
capture_diagnostics: bool,
) -> Result<RunOutput<RetrievalOutput>, AppError> {
let input_chars = params.input_text.chars().count();
let input_preview: String = params.input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
@@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppEr
input_chars,
preview_truncated = input_chars > preview_len,
preview = %input_preview_clean,
strategy = %params.config.strategy,
resolve_entities = params.config.resolve_entities,
"Starting retrieval pipeline"
);
let strategy = params.config.strategy;
let search_target = params.config.search_target;
let resolve_entities = params.config.resolve_entities;
let mut ctx = match query_embedding {
Some(embedding) => context::PipelineContext::with_embedding(params, embedding),
None => context::PipelineContext::new(params),
};
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Search(run.results))
}
if capture_diagnostics {
ctx.enable_diagnostics();
}
for stage in build_stages(&ctx.config) {
let start = Instant::now();
stage.execute(&mut ctx).await?;
ctx.record_stage_duration(stage.kind(), start.elapsed());
}
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let chunks = ctx.take_chunk_results();
let results = if resolve_entities {
RetrievalOutput::WithEntities {
chunks,
entities: ctx.take_entity_results(),
}
} else {
RetrievalOutput::Chunks(chunks)
};
Ok(RunOutput {
results,
diagnostics,
stage_timings,
})
}
pub async fn run_pipeline_with_embedding(
params: StrategyParams<'_>,
query_embedding: Vec<f32>,
) -> Result<StrategyOutput, AppError> {
let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Search(run.results))
}
}
/// Runs the retrieval pipeline for `params` and returns only the retrieval results.
///
/// No pre-computed query embedding is supplied and diagnostics capture is disabled;
/// timings and diagnostics produced by the run are discarded.
///
/// # Errors
/// Returns whatever error the underlying pipeline run reports.
pub async fn execute(params: RetrievalParams<'_>) -> Result<RetrievalOutput, AppError> {
    let output = run(params, None, false).await?;
    Ok(output.results)
}
pub async fn run_pipeline_with_embedding_with_metrics(
params: StrategyParams<'_>,
/// Run the retrieval pipeline with a pre-computed query embedding.
pub async fn run_with_embedding(
params: RetrievalParams<'_>,
query_embedding: Vec<f32>,
) -> Result<RunOutput<StrategyOutput>, AppError> {
let strategy = params.config.strategy;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
_ => Err(AppError::InternalError(
"Metrics not supported for this strategy".into(),
)),
}
) -> Result<RetrievalOutput, AppError> {
Ok(run(params, Some(query_embedding), false).await?.results)
}
pub async fn run_pipeline_with_embedding_with_diagnostics(
params: StrategyParams<'_>,
/// Run with a pre-computed embedding, returning results and per-stage timings.
///
/// When `capture_diagnostics` is true, pipeline search/assemble stats are included.
pub async fn run_with_embedding_instrumented(
params: RetrievalParams<'_>,
query_embedding: Vec<f32>,
) -> Result<RunOutput<StrategyOutput>, AppError> {
let strategy = params.config.strategy;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
_ => Err(AppError::InternalError(
"Diagnostics not supported for this strategy".into(),
)),
}
capture_diagnostics: bool,
) -> Result<RunOutput<RetrievalOutput>, AppError> {
run(params, Some(query_embedding), capture_diagnostics).await
}
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
@@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
})
.collect::<Vec<_>>())
}
/// Inputs for a strategy-driven pipeline run; consumed by `execute_strategy`, which
/// passes it to `PipelineContext::new` / `PipelineContext::with_embedding`.
///
/// NOTE(review): how the context uses each field is not visible in this block — the
/// short notes below follow the field types and names only; confirm in the context code.
pub struct StrategyParams<'a> {
// Borrowed SurrealDB client.
pub db_client: &'a SurrealDbClient,
// Borrowed OpenAI client.
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
// Optional embedding provider.
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
// Query text.
pub input_text: &'a str,
// Id of the user the retrieval runs for.
pub user_id: &'a str,
// Retrieval configuration (owned).
pub config: RetrievalConfig,
// Optional reranker lease (owned).
pub reranker: Option<RerankerLease>,
}
/// Builds a pipeline context from `params` and runs `driver` against it.
///
/// When `query_embedding` is `Some`, the context is seeded with it via
/// `PipelineContext::with_embedding`; otherwise `PipelineContext::new` is used.
/// `capture_diagnostics` is forwarded unchanged to `run_with_driver`.
async fn execute_strategy<D: StrategyDriver>(
    driver: D,
    params: StrategyParams<'_>,
    query_embedding: Option<Vec<f32>>,
    capture_diagnostics: bool,
) -> Result<RunOutput<D::Output>, AppError> {
    let pipeline_ctx = if let Some(precomputed) = query_embedding {
        PipelineContext::with_embedding(params, precomputed)
    } else {
        PipelineContext::new(params)
    };
    run_with_driver(driver, pipeline_ctx, capture_diagnostics).await
}
/// Executes every stage supplied by `driver` in order, timing each one, then lets the
/// driver turn the context into its output.
///
/// Diagnostics are enabled on the context first when `capture_diagnostics` is true.
/// The first failing stage aborts the run and its error is returned. Diagnostics and
/// stage timings are taken from the context before `finalize` is called.
async fn run_with_driver<D: StrategyDriver>(
    driver: D,
    mut ctx: PipelineContext<'_>,
    capture_diagnostics: bool,
) -> Result<RunOutput<D::Output>, AppError> {
    if capture_diagnostics {
        ctx.enable_diagnostics();
    }

    let scheduled = driver.stages();
    for stage in &scheduled {
        let started = Instant::now();
        stage.execute(&mut ctx).await?;
        ctx.record_stage_duration(stage.kind(), started.elapsed());
    }

    let diagnostics = ctx.take_diagnostics();
    let stage_timings = ctx.take_stage_timings();
    // `finalize` reports a boxed error; unbox it to match this function's error type.
    let results = match driver.finalize(&mut ctx) {
        Ok(output) => output,
        Err(boxed) => return Err(*boxed),
    };

    Ok(RunOutput {
        results,
        diagnostics,
        stage_timings,
    })
}
/// Widens `value` to `f64` and rounds it to three decimal places.
fn round_score(value: f32) -> f64 {
    const SCALE: f64 = 1000.0;
    let scaled = f64::from(value) * SCALE;
    scaled.round() / SCALE
}
+424
View File
@@ -0,0 +1,424 @@
use async_trait::async_trait;
use common::{
error::AppError,
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
utils::embedding::generate_embedding,
};
use fastembed::RerankResult;
use std::collections::HashMap;
use tracing::{debug, instrument, warn};
use crate::{
scoring::{
clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored,
},
RetrievedChunk, RetrievedEntity,
};
use super::{
config::RetrievalTuning,
context::PipelineContext,
diagnostics::{AssembleStats, SearchStats},
Stage, StageKind,
};
/// Pipeline stage wrapper around [`embed`]; reports itself as [`StageKind::Embed`].
#[derive(Debug, Clone, Copy)]
pub struct EmbedStage;
#[async_trait]
impl Stage for EmbedStage {
// Kind under which this stage's duration is recorded.
fn kind(&self) -> StageKind {
StageKind::Embed
}
// Delegates to the free function so the logic can also be called directly.
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
embed(ctx).await
}
}
/// Pipeline stage wrapper around [`search_chunks`]; reports itself as [`StageKind::Search`].
#[derive(Debug, Clone, Copy)]
pub struct ChunkSearchStage;
#[async_trait]
impl Stage for ChunkSearchStage {
// Kind under which this stage's duration is recorded.
fn kind(&self) -> StageKind {
StageKind::Search
}
// Delegates to the free function so the logic can also be called directly.
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
search_chunks(ctx).await
}
}
/// Pipeline stage wrapper around [`rerank_chunks`]; reports itself as [`StageKind::Rerank`].
#[derive(Debug, Clone, Copy)]
pub struct ChunkRerankStage;
#[async_trait]
impl Stage for ChunkRerankStage {
// Kind under which this stage's duration is recorded.
fn kind(&self) -> StageKind {
StageKind::Rerank
}
// Delegates to the free function so the logic can also be called directly.
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
rerank_chunks(ctx).await
}
}
/// Pipeline stage wrapper around [`resolve_entities`]; reports itself as
/// [`StageKind::ResolveEntities`].
#[derive(Debug, Clone, Copy)]
pub struct ResolveEntitiesStage;
#[async_trait]
impl Stage for ResolveEntitiesStage {
// Kind under which this stage's duration is recorded.
fn kind(&self) -> StageKind {
StageKind::ResolveEntities
}
// Delegates to the free function so the logic can also be called directly.
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
resolve_entities(ctx).await
}
}
/// Pipeline stage wrapper around [`assemble_chunks`]; reports itself as
/// [`StageKind::Assemble`].
#[derive(Debug, Clone, Copy)]
pub struct ChunkAssembleStage;
#[async_trait]
impl Stage for ChunkAssembleStage {
// Kind under which this stage's duration is recorded.
fn kind(&self) -> StageKind {
StageKind::Assemble
}
// `assemble_chunks` is synchronous, so there is nothing to await here.
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
assemble_chunks(ctx)
}
}
/// Ensures `ctx.query_embedding` is populated.
///
/// An embedding already present on the context is kept as-is. Otherwise the query text
/// is embedded with the configured embedding provider when there is one, falling back
/// to `generate_embedding` with the OpenAI client.
///
/// # Errors
/// Propagates the error of whichever embedding call was made.
#[instrument(level = "trace", skip_all)]
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    if ctx.query_embedding.is_some() {
        debug!("Reusing cached query embedding for hybrid retrieval");
        return Ok(());
    }

    debug!("Generating query embedding for hybrid retrieval");
    let fresh = match ctx.embedding_provider {
        Some(provider) => provider.embed(&ctx.input_text).await?,
        None => generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?,
    };
    ctx.query_embedding = Some(fresh);
    Ok(())
}
/// Collects chunk candidates from vector search and (optionally) full-text search,
/// fuses both rankings with reciprocal rank fusion and stores the merged list in
/// `ctx.chunk_values`.
///
/// When diagnostics are enabled, candidate counts and a sample of per-signal scores are
/// recorded on the context.
///
/// # Errors
/// Returns the error from `ensure_embedding` or from either search call.
#[instrument(level = "trace", skip_all)]
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting chunk candidates via vector and FTS search");
// `ensure_embedding` reports a boxed error; unbox it to match this function's error type.
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning;
let fts_take = tuning.chunk_fts_take;
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
// FTS runs only if the flag is on, it may return rows, and the normalized query is non-empty.
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
// Both searches are polled together; the first error aborts the stage.
let (vector_rows, fts_rows) = tokio::try_join!(
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
&ctx.user_id,
),
async {
if fts_enabled {
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
} else {
Ok(Vec::new())
}
}
)?;
let vector_candidates = vector_rows.len();
let fts_candidates = fts_rows.len();
// Wrap each row so it carries the score of the signal that produced it.
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.collect();
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
.into_iter()
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
.collect();
let mut fts_weight = tuning.chunk_rrf_fts_weight;
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
// For very short keyword queries, lean more on lexical ranking.
fts_weight *= 1.5;
}
let rrf_config = RrfConfig {
k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight,
use_vector: tuning.flags.chunk_rrf_use_vector(),
// The FTS signal is only fused when the flag is on and FTS actually returned rows.
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
};
let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
debug!(
total_merged = chunks.len(),
vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(),
fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(),
both_signals = chunks
.iter()
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
.count(),
rrf_k = rrf_config.k,
"Merged chunk candidates with RRF"
);
if ctx.diagnostics_enabled() {
// Samples are taken from the leading fused chunks; a missing signal is reported as 0.0.
ctx.record_search(SearchStats {
vector_chunk_candidates: vector_candidates,
fts_chunk_candidates: fts_candidates,
vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)),
fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
});
}
ctx.chunk_values = chunks;
Ok(())
}
/// Reorders `ctx.chunk_values` with the reranker lease, when one is available.
///
/// The stage is a no-op when there are fewer than two candidates, when no reranker lease
/// is present, or when fewer than two rerank documents can be built. A reranker failure
/// or an empty reranker response is logged and the original ordering is kept; this
/// function itself always returns `Ok(())`.
#[instrument(level = "trace", skip_all)]
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    // A single candidate (or none) cannot be reordered.
    if ctx.chunk_values.len() < 2 {
        return Ok(());
    }

    let reranker = match ctx.reranker.as_ref() {
        Some(lease) => lease,
        None => {
            debug!("No reranker lease provided; skipping chunk rerank stage");
            return Ok(());
        }
    };

    // Only the leading `rerank_keep_top` chunks (at least one) are sent to the reranker.
    let document_limit = ctx.config.tuning.rerank_keep_top.max(1);
    let documents = build_chunk_rerank_documents(&ctx.chunk_values, document_limit);
    if documents.len() < 2 {
        debug!("Skipping chunk reranking stage; insufficient chunk documents");
        return Ok(());
    }

    let outcome = reranker.rerank(&ctx.input_text, documents).await;
    match outcome {
        Err(err) => warn!(
            error = %err,
            "Chunk reranking failed; continuing with original ordering"
        ),
        Ok(results) if results.is_empty() => {
            debug!("Chunk reranker returned no results; retaining original order");
        }
        Ok(results) => {
            apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results);
        }
    }

    Ok(())
}
/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks.
///
/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped
/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk,
/// and the contributing chunks are attached.
#[instrument(level = "trace", skip_all)]
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.is_empty() {
return Ok(());
}
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
let mut source_order: Vec<String> = Vec::new();
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
let mut best_score: HashMap<String, f32> = HashMap::new();
for scored in &ctx.chunk_values {
let source = scored.item.source_id.clone();
let attached = chunks_by_source.entry(source.clone()).or_default();
if attached.is_empty() {
source_order.push(source.clone());
best_score.insert(source.clone(), scored.fused);
}
if attached.len() < max_chunks {
attached.push(RetrievedChunk {
chunk: scored.item.clone(),
score: scored.fused,
});
}
}
let entities =
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
let mut entities_by_source: HashMap<String, Vec<KnowledgeEntity>> = HashMap::new();
for entity in entities {
entities_by_source
.entry(entity.source_id.clone())
.or_default()
.push(entity);
}
let mut results = Vec::new();
for source in &source_order {
let Some(entities) = entities_by_source.remove(source) else {
continue;
};
let score = best_score.get(source).copied().unwrap_or(0.0);
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
for entity in entities {
results.push(RetrievedEntity {
entity,
score,
chunks: chunks.clone(),
});
}
}
debug!(
sources = source_order.len(),
entities = results.len(),
"Resolved entities from retrieved chunks"
);
ctx.entity_results = results;
Ok(())
}
/// Moves the scored chunks out of `ctx.chunk_values` into `ctx.chunk_results`.
///
/// At most `min(chunk_result_cap, chunk_vector_take)` chunks are kept (each bound is
/// treated as at least 1) so the returned context stays reasonably small. Each kept chunk
/// carries its fused score. When diagnostics are enabled, the number of selected chunks
/// is recorded. Always returns `Ok(())`.
#[instrument(level = "trace", skip_all)]
#[allow(clippy::result_large_err)]
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    debug!("Assembling chunk retrieval results");

    let result_cap = ctx.config.tuning.chunk_result_cap.max(1);
    let vector_take = ctx.config.tuning.chunk_vector_take.max(1);
    let limit = result_cap.min(vector_take);

    let scored_chunks = std::mem::take(&mut ctx.chunk_values);
    ctx.chunk_results = scored_chunks
        .into_iter()
        .take(limit)
        .map(|scored| RetrievedChunk {
            chunk: scored.item,
            score: scored.fused,
        })
        .collect();

    if ctx.diagnostics_enabled() {
        let chunks_selected = ctx.chunk_results.len();
        ctx.record_assemble(AssembleStats { chunks_selected });
    }

    Ok(())
}
/// Maximum number of scores captured per diagnostics sample.
const SCORE_SAMPLE_LIMIT: usize = 8;

/// Applies `extractor` to the first `SCORE_SAMPLE_LIMIT` items (or fewer if `items` is
/// shorter) and returns the extracted scores in order.
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
where
    F: FnMut(&Scored<T>) -> f32,
{
    let sample_len = items.len().min(SCORE_SAMPLE_LIMIT);
    let mut sampled = Vec::with_capacity(sample_len);
    for item in &items[..sample_len] {
        sampled.push(extractor(item));
    }
    sampled
}
/// Normalizes free text into a full-text-search query.
///
/// Alphanumeric characters are lowercased, whitespace becomes a single-space separator
/// and every other character is dropped (without inserting a separator). The resulting
/// tokens are filtered against a small English stopword list.
///
/// Returns the remaining tokens joined by single spaces together with their count.
fn normalize_fts_query(input: &str) -> (String, usize) {
    const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];

    let cleaned = input
        .chars()
        .fold(String::with_capacity(input.len()), |mut acc, ch| {
            if ch.is_alphanumeric() {
                acc.extend(ch.to_lowercase());
            } else if ch.is_whitespace() {
                acc.push(' ');
            }
            acc
        });

    let kept: Vec<&str> = cleaned
        .split_whitespace()
        .filter(|token| !STOPWORDS.contains(token))
        .collect();

    (kept.join(" "), kept.len())
}
/// Renders the first `max_chunks` chunks as reranker input documents.
///
/// Each document has the form `Source: <source_id>` followed by a `Chunk:` line and the
/// trimmed chunk text. The output order matches the order of `chunks`.
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
    let mut documents = Vec::with_capacity(chunks.len().min(max_chunks));
    for scored in chunks.iter().take(max_chunks) {
        let text = scored.item.chunk.trim();
        documents.push(format!(
            "Source: {}\nChunk:\n{}",
            scored.item.source_id, text
        ));
    }
    documents
}
/// Reorders `chunks` to follow the reranker `results` and updates their fused scores.
///
/// Reranked chunks come first, in the order of `results`. Chunks the reranker did not
/// return are appended afterwards in their original order with their fused score
/// untouched. When `tuning.rerank_keep_top` is greater than zero the list is finally
/// truncated to that length. Does nothing if either input is empty.
///
/// NOTE(review): `result.index` is used as a position in the incoming `chunks` order;
/// presumably it is the index of the document handed to the reranker — confirm against
/// the fastembed `RerankResult` documentation.
fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>,
tuning: &RetrievalTuning,
results: Vec<RerankResult>,
) {
if results.is_empty() || chunks.is_empty() {
return;
}
// Each chunk sits in an `Option` slot so it can be moved out at most once.
let mut remaining: Vec<Option<Scored<TextChunk>>> =
std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = tuning.flags.rerank_scores_only();
let blend = if use_only {
1.0
} else {
clamp_unit(tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
// A repeated index finds its slot already emptied and is silently ignored.
if let Some(mut candidate) = slot.take() {
let original = candidate.fused;
// Scores-only mode replaces the fused score; otherwise old and new are blended.
let blended = if use_only {
clamp_unit(normalized)
} else {
clamp_unit(original * (1.0 - blend) + normalized * blend)
};
candidate.update_fused(blended);
reranked.push(candidate);
}
} else {
warn!(
result_index = result.index,
"Chunk reranker returned out-of-range index; skipping"
);
}
// Stop early once every chunk has been placed.
if reranked.len() == remaining.len() {
break;
}
}
// Append the chunks the reranker never mentioned, keeping their original order.
reranked.extend(remaining.into_iter().flatten());
let keep_top = tuning.rerank_keep_top;
if keep_top > 0 && reranked.len() > keep_top {
reranked.truncate(keep_top);
}
*chunks = reranked;
}
File diff suppressed because it is too large Load Diff
@@ -1,148 +0,0 @@
use super::{
stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
},
BoxedStage, StrategyDriver,
};
use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError;
/// Strategy driver that yields retrieved chunks.
///
/// Runs `EmbedStage`, `ChunkVectorStage`, `ChunkRerankStage` and `ChunkAssembleStage`,
/// in that order, and returns the context's chunk results.
pub struct DefaultStrategyDriver;
impl DefaultStrategyDriver {
/// Creates the (stateless) driver.
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for DefaultStrategyDriver {
type Output = Vec<RetrievedChunk>;
// Stages are returned in execution order.
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(ChunkVectorStage),
Box::new(ChunkRerankStage),
Box::new(ChunkAssembleStage),
]
}
// Takes the assembled chunks out of the context; never fails.
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_chunk_results())
}
}
/// Strategy driver that yields retrieved entities.
///
/// Runs `EmbedStage`, `CollectCandidatesStage`, `GraphExpansionStage`, `RerankStage` and
/// `AssembleEntitiesStage`, in that order, and returns the context's entity results.
pub struct RelationshipSuggestionDriver;
impl RelationshipSuggestionDriver {
/// Creates the (stateless) driver.
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for RelationshipSuggestionDriver {
type Output = Vec<RetrievedEntity>;
// Stages are returned in execution order.
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
// Skip ChunkAttachStage
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
// Takes the assembled entities out of the context; never fails.
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
/// Strategy driver that yields retrieved entities.
///
/// Its stage list and `finalize` are identical to `RelationshipSuggestionDriver`:
/// `EmbedStage`, `CollectCandidatesStage`, `GraphExpansionStage`, `RerankStage`,
/// `AssembleEntitiesStage`, then the context's entity results are returned.
pub struct IngestionDriver;
impl IngestionDriver {
/// Creates the (stateless) driver.
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for IngestionDriver {
type Output = Vec<RetrievedEntity>;
// Stages are returned in execution order.
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
// Skip ChunkAttachStage
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
// Takes the assembled entities out of the context; never fails.
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
use super::config::SearchTarget;
use crate::SearchResult;
/// Search strategy driver that retrieves chunks, entities, or both, depending on the
/// configured [`SearchTarget`].
pub struct SearchStrategyDriver {
    target: SearchTarget,
}

impl SearchStrategyDriver {
    /// Creates a driver for the given search target.
    pub fn new(target: SearchTarget) -> Self {
        Self { target }
    }

    /// Whether the chunk retrieval path is part of this search.
    fn wants_chunks(&self) -> bool {
        !matches!(self.target, SearchTarget::EntitiesOnly)
    }

    /// Whether the entity retrieval path is part of this search.
    fn wants_entities(&self) -> bool {
        !matches!(self.target, SearchTarget::ChunksOnly)
    }
}

impl StrategyDriver for SearchStrategyDriver {
    type Output = SearchResult;

    /// Embed first, then the chunk path (if wanted), then the entity path (if wanted).
    fn stages(&self) -> Vec<BoxedStage> {
        let mut pipeline: Vec<BoxedStage> = vec![Box::new(EmbedStage)];
        if self.wants_chunks() {
            // Chunk retrieval path
            pipeline.push(Box::new(ChunkVectorStage));
            pipeline.push(Box::new(ChunkRerankStage));
            pipeline.push(Box::new(ChunkAssembleStage));
        }
        if self.wants_entities() {
            // Entity retrieval path (runs after chunk stages)
            pipeline.push(Box::new(CollectCandidatesStage));
            pipeline.push(Box::new(GraphExpansionStage));
            pipeline.push(Box::new(RerankStage));
            pipeline.push(Box::new(AssembleEntitiesStage));
        }
        pipeline
    }

    /// Collects the results of the paths that ran; a skipped path yields an empty list
    /// and its results are not taken from the context.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
        let chunks = if self.wants_chunks() {
            ctx.take_chunk_results()
        } else {
            Vec::new()
        };
        let entities = if self.wants_entities() {
            ctx.take_entity_results()
        } else {
            Vec::new()
        };
        Ok(SearchResult::new(chunks, entities))
    }
}
@@ -97,8 +97,7 @@ impl RerankerPool {
fn default_pool_size() -> usize {
available_parallelism()
.map(|value| value.get().min(2))
.unwrap_or(2)
.map_or(2, |value| value.get().min(2))
.max(1)
}
@@ -156,6 +155,7 @@ pub struct RerankerLease {
}
impl RerankerLease {
#[allow(clippy::result_large_err)]
pub async fn rerank(
&self,
query: &str,
@@ -165,7 +165,9 @@ impl RerankerLease {
let engine = Arc::clone(&self.engine);
tokio::task::spawn_blocking(move || {
let mut guard = engine.lock().expect("reranker engine mutex poisoned");
let mut guard = engine.lock().map_err(|_| {
AppError::InternalError("reranker engine mutex poisoned".into())
})?;
guard
.rerank(query, documents, false, None)
.map_err(|e| AppError::InternalError(e.to_string()))
+7 -120
View File
@@ -1,14 +1,12 @@
use std::{cmp::Ordering, collections::HashMap};
use common::storage::types::StoredObject;
use serde::{Deserialize, Serialize};
/// Holds optional subscores gathered from different retrieval signals.
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
pub fts: Option<f32>,
pub vector: Option<f32>,
pub graph: Option<f32>,
}
/// Generic wrapper combining an item with its accumulated retrieval scores.
@@ -40,40 +38,11 @@ impl<T> Scored<T> {
self
}
#[must_use]
pub const fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score);
self
}
pub const fn update_fused(&mut self, fused: f32) {
self.fused = fused;
}
}
/// Weights used for linear score fusion.
///
/// `vector`, `fts` and `graph` scale the corresponding subscores; `multi_bonus` is the
/// extra boost `fuse_scores` applies when at least two signals are present.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct FusionWeights {
pub vector: f32,
pub fts: f32,
pub graph: f32,
pub multi_bonus: f32,
}
impl Default for FusionWeights {
/// Returns the default weights: vector 0.8, fts 0.2, graph 0.2, multi_bonus 0.3.
fn default() -> Self {
// Default weights favor vector search, which typically performs better
// FTS is used as a complement when there's good overlap
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
Self {
vector: 0.8,
fts: 0.2,
graph: 0.2,
multi_bonus: 0.3, // Increased to boost chunks with both signals
}
}
}
/// Configuration for reciprocal rank fusion.
#[derive(Debug, Clone, Copy)]
pub struct RrfConfig {
@@ -84,29 +53,10 @@ pub struct RrfConfig {
pub use_fts: bool,
}
impl Default for RrfConfig {
/// Returns an RRF configuration with `k = 60`, equal weights of 1.0 for both signals,
/// and both the vector and the full-text signal enabled.
fn default() -> Self {
Self {
k: 60.0,
vector_weight: 1.0,
fts_weight: 1.0,
use_vector: true,
use_fts: true,
}
}
}
/// Restricts `value` to the closed interval `[0.0, 1.0]`.
///
/// Values below 0 become 0, values above 1 become 1, and `NaN` is returned unchanged.
pub const fn clamp_unit(value: f32) -> f32 {
    if value < 0.0 {
        0.0
    } else if value > 1.0 {
        1.0
    } else {
        value
    }
}
/// Converts a distance into a similarity in `[0.0, 1.0]` using `1 / (1 + distance)`.
///
/// Negative distances are treated as 0 (similarity 1.0); non-finite distances
/// (`NaN`, ±infinity) map to 0.0.
pub fn distance_to_similarity(distance: f32) -> f32 {
    if distance.is_finite() {
        let non_negative = distance.max(0.0);
        clamp_unit(1.0 / (1.0 + non_negative))
    } else {
        0.0
    }
}
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() {
return Vec::new();
@@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
.collect()
}
/// Linearly fuses the available subscores into a single value in `[0.0, 1.0]`.
///
/// Missing subscores count as 0. The base score is the weighted sum of the vector, FTS
/// and graph subscores. When at least two signals are present a bonus is applied:
/// multiplicative (`base * (1 + multi_bonus)`) if both vector and FTS are present, so a
/// chunk found by both searches is boosted in proportion to its score; additive
/// (`base + multi_bonus`) for any other combination. The result is clamped to the
/// unit range.
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
    let vector = scores.vector.unwrap_or(0.0);
    let fts = scores.fts.unwrap_or(0.0);
    let graph = scores.graph.unwrap_or(0.0);

    let vector_and_fts = vector.mul_add(weights.vector, fts * weights.fts);
    let base = graph.mul_add(weights.graph, vector_and_fts);

    let has_vector = scores.vector.is_some();
    let has_fts = scores.fts.is_some();
    let has_graph = scores.graph.is_some();
    let signals_present =
        usize::from(has_vector) + usize::from(has_fts) + usize::from(has_graph);

    let fused = if signals_present < 2 {
        base
    } else if has_vector && has_fts {
        base * (1.0 + weights.multi_bonus)
    } else {
        base + weights.multi_bonus
    };

    clamp_unit(fused)
}
/// Merges `incoming` scored items into `target`, keyed by each item's id.
///
/// For an id already present, only the subscores that `incoming` actually carries
/// overwrite the stored ones; the stored item and fused score are left untouched.
/// An unknown id is inserted with the incoming item, subscores and fused score.
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
    target: &mut std::collections::HashMap<String, Scored<T>, S>,
    incoming: Vec<Scored<T>>,
) where
    T: StoredObject + Clone,
{
    use std::collections::hash_map::Entry;

    for candidate in incoming {
        match target.entry(candidate.item.id().to_owned()) {
            Entry::Occupied(mut slot) => {
                let known = &mut slot.get_mut().scores;
                if candidate.scores.vector.is_some() {
                    known.vector = candidate.scores.vector;
                }
                if candidate.scores.fts.is_some() {
                    known.fts = candidate.scores.fts;
                }
                if candidate.scores.graph.is_some() {
                    known.graph = candidate.scores.graph;
                }
            }
            Entry::Vacant(slot) => {
                slot.insert(Scored {
                    item: candidate.item.clone(),
                    scores: candidate.scores,
                    fused: candidate.fused,
                });
            }
        }
    }
}
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
T: StoredObject,
@@ -222,6 +109,10 @@ where
});
}
/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion.
///
/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text
/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score.
pub fn reciprocal_rank_fusion<T>(
mut vector_ranked: Vec<Scored<T>>,
mut fts_ranked: Vec<Scored<T>>,
@@ -266,9 +157,7 @@ where
}
}
entry.item = candidate.item;
let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
entry.fused += vector_weight / (k + rank_f32 + 1.0);
}
}
@@ -296,9 +185,7 @@ where
}
}
entry.item = candidate.item;
let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
entry.fused += fts_weight / (k + rank_f32 + 1.0);
}
}