retrieval-pipeline: v0

This commit is contained in:
Per Stark
2025-11-18 21:20:27 +01:00
parent 6b7befbd04
commit f535df7e61
32 changed files with 1189 additions and 453 deletions

48
Cargo.lock generated
View File

@@ -1457,26 +1457,6 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "composite-retrieval"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"axum",
"common",
"fastembed",
"futures",
"serde",
"serde_json",
"state-machines",
"surrealdb",
"thiserror 1.0.69",
"tokio",
"tracing",
"uuid",
]
[[package]]
name = "compression-codecs"
version = "0.4.30"
@@ -2197,7 +2177,6 @@ dependencies = [
"async-trait",
"chrono",
"common",
"composite-retrieval",
"criterion",
"fastembed",
"futures",
@@ -2205,6 +2184,7 @@ dependencies = [
"object_store 0.11.2",
"once_cell",
"rand 0.8.5",
"retrieval-pipeline",
"serde",
"serde_json",
"serde_yaml",
@@ -2880,7 +2860,6 @@ dependencies = [
"chrono",
"chrono-tz",
"common",
"composite-retrieval",
"futures",
"include_dir",
"json-stream-parser",
@@ -2888,6 +2867,7 @@ dependencies = [
"minijinja-autoreload",
"minijinja-contrib",
"minijinja-embed",
"retrieval-pipeline",
"serde",
"serde_json",
"surrealdb",
@@ -3342,13 +3322,13 @@ dependencies = [
"bytes",
"chrono",
"common",
"composite-retrieval",
"dom_smoothie",
"futures",
"headless_chrome",
"lopdf 0.32.0",
"pdf-extract",
"reqwest",
"retrieval-pipeline",
"serde",
"serde_json",
"state-machines",
@@ -3802,10 +3782,10 @@ dependencies = [
"async-openai",
"axum",
"common",
"composite-retrieval",
"futures",
"html-router",
"ingestion-pipeline",
"retrieval-pipeline",
"serde",
"serde_json",
"surrealdb",
@@ -5475,6 +5455,26 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "retrieval-pipeline"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-trait",
"axum",
"common",
"fastembed",
"futures",
"serde",
"serde_json",
"surrealdb",
"thiserror 1.0.69",
"tokio",
"tracing",
"uuid",
]
[[package]]
name = "revision"
version = "0.10.0"

View File

@@ -5,7 +5,7 @@ members = [
"api-router",
"html-router",
"ingestion-pipeline",
"composite-retrieval",
"retrieval-pipeline",
"json-stream-parser",
"eval"
]

View File

@@ -54,6 +54,8 @@ pub struct AppConfig {
pub fastembed_show_download_progress: Option<bool>,
#[serde(default)]
pub fastembed_max_length: Option<usize>,
#[serde(default)]
pub retrieval_strategy: Option<String>,
}
fn default_data_dir() -> String {
@@ -117,6 +119,7 @@ impl Default for AppConfig {
fastembed_cache_dir: None,
fastembed_show_download_progress: None,
fastembed_max_length: None,
retrieval_strategy: None,
}
}
}

View File

@@ -1,212 +0,0 @@
mod config;
mod diagnostics;
mod stages;
mod state;
pub use config::{RetrievalConfig, RetrievalTuning};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
};
use crate::{reranking::RerankerLease, RetrievedEntity};
use async_openai::Client;
use common::{error::AppError, storage::db::SurrealDbClient};
use tracing::info;
/// Bundle returned by the metric/diagnostic entry points: the retrieval
/// results plus optional diagnostics and the per-stage timing breakdown.
#[derive(Debug)]
pub struct PipelineRunOutput {
    // Final ranked entities produced by the pipeline.
    pub results: Vec<RetrievedEntity>,
    // Populated only when diagnostics capture was enabled for the run.
    pub diagnostics: Option<PipelineDiagnostics>,
    // Wall-clock milliseconds accumulated per pipeline stage.
    pub stage_timings: PipelineStageTimings,
}
/// Cumulative wall-clock time spent in each pipeline stage, in milliseconds.
///
/// The recorders use `+=`, so a stage that executes more than once in a run
/// accumulates its total rather than overwriting it.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct PipelineStageTimings {
    pub collect_candidates_ms: u128,
    pub graph_expansion_ms: u128,
    pub chunk_attach_ms: u128,
    pub rerank_ms: u128,
    pub assemble_ms: u128,
}

impl PipelineStageTimings {
    // NOTE: `Duration::as_millis` already returns `u128`, so no cast is needed.
    fn record_collect_candidates(&mut self, duration: std::time::Duration) {
        self.collect_candidates_ms += duration.as_millis();
    }

    fn record_graph_expansion(&mut self, duration: std::time::Duration) {
        self.graph_expansion_ms += duration.as_millis();
    }

    fn record_chunk_attach(&mut self, duration: std::time::Duration) {
        self.chunk_attach_ms += duration.as_millis();
    }

    fn record_rerank(&mut self, duration: std::time::Duration) {
        self.rerank_ms += duration.as_millis();
    }

    fn record_assemble(&mut self, duration: std::time::Duration) {
        self.assemble_ms += duration.as_millis();
    }
}
/// Drives the retrieval pipeline from embedding through final assembly.
///
/// Logs a short, newline-free preview of the query, builds a fresh
/// `PipelineContext`, and runs the staged pipeline without diagnostics.
///
/// # Errors
/// Propagates any `AppError` raised by a pipeline stage.
pub async fn run_pipeline(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
    // Counted in chars (not bytes) so multi-byte input is previewed safely.
    let input_chars = input_text.chars().count();
    let input_preview: String = input_text.chars().take(120).collect();
    let input_preview_clean = input_preview.replace('\n', " ");
    let preview_len = input_preview_clean.chars().count();
    // NOTE(review): message says "ingestion" but this entry point serves all
    // retrieval callers — confirm whether the wording is intentional.
    info!(
        %user_id,
        input_chars,
        preview_truncated = input_chars > preview_len,
        preview = %input_preview_clean,
        "Starting ingestion retrieval pipeline"
    );
    let ctx = stages::PipelineContext::new(
        db_client,
        openai_client,
        input_text.to_owned(),
        user_id.to_owned(),
        config,
        reranker,
    );
    // `false`: this path never captures diagnostics.
    let outcome = run_pipeline_internal(ctx, false).await?;
    Ok(outcome.results)
}
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
let outcome = run_pipeline_internal(ctx, false).await?;
Ok(outcome.results)
}
/// Runs the pipeline with a precomputed embedding and returns stage metrics.
///
/// Diagnostics stay disabled (`false` below); only stage timings are captured.
///
/// # Errors
/// Propagates any `AppError` raised by a pipeline stage.
pub async fn run_pipeline_with_embedding_with_metrics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput, AppError> {
    // Context seeded with the caller-supplied embedding; the embed stage
    // will not need to recompute it.
    let ctx = stages::PipelineContext::with_embedding(
        db_client,
        openai_client,
        query_embedding,
        input_text.to_owned(),
        user_id.to_owned(),
        config,
        reranker,
    );
    run_pipeline_internal(ctx, false).await
}
/// Same as `run_pipeline_with_embedding_with_metrics`, but also enables
/// diagnostics capture (`true` below) so `PipelineRunOutput::diagnostics`
/// is populated.
///
/// # Errors
/// Propagates any `AppError` raised by a pipeline stage.
pub async fn run_pipeline_with_embedding_with_diagnostics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput, AppError> {
    let ctx = stages::PipelineContext::with_embedding(
        db_client,
        openai_client,
        query_embedding,
        input_text.to_owned(),
        user_id.to_owned(),
        config,
        reranker,
    );
    run_pipeline_internal(ctx, true).await
}
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
///
/// Produces a JSON array of `{"KnowledgeEntity": {...}}` objects with scores
/// rounded to three decimals via `round_score`.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
    let mut rendered = Vec::with_capacity(entities.len());
    for entry in entities {
        // Render the supporting chunks first, then wrap the entity payload.
        let chunk_payloads: Vec<serde_json::Value> = entry
            .chunks
            .iter()
            .map(|chunk| {
                serde_json::json!({
                    "score": round_score(chunk.score),
                    "content": chunk.chunk.chunk
                })
            })
            .collect();
        rendered.push(serde_json::json!({
            "KnowledgeEntity": {
                "id": entry.entity.id,
                "name": entry.entity.name,
                "description": entry.entity.description,
                "score": round_score(entry.score),
                "chunks": chunk_payloads
            }
        }));
    }
    serde_json::Value::Array(rendered)
}
/// Shared tail of every public entry point: optionally enables diagnostics,
/// drives the staged pipeline, then packages the results together with any
/// captured diagnostics and the stage timings.
async fn run_pipeline_internal(
    mut ctx: stages::PipelineContext<'_>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput, AppError> {
    if capture_diagnostics {
        ctx.enable_diagnostics();
    }
    let results = drive_pipeline(&mut ctx).await?;
    let stage_timings = ctx.take_stage_timings();
    let diagnostics = ctx.take_diagnostics();
    Ok(PipelineRunOutput {
        results,
        diagnostics,
        stage_timings,
    })
}
/// Executes the stages in their fixed order, threading the typestate machine
/// through each call so the compiler enforces the sequence
/// (Ready → Embedded → CandidatesLoaded → GraphExpanded → ChunksAttached →
/// Reranked → Completed, per the `HybridRetrievalMachine` definition).
async fn drive_pipeline(
    ctx: &mut stages::PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
    let machine = state::ready();
    let machine = stages::embed(machine, ctx).await?;
    let machine = stages::collect_candidates(machine, ctx).await?;
    let machine = stages::expand_graph(machine, ctx).await?;
    let machine = stages::attach_chunks(machine, ctx).await?;
    let machine = stages::rerank(machine, ctx).await?;
    // `assemble` is synchronous and consumes the final machine state.
    let results = stages::assemble(machine, ctx)?;
    Ok(results)
}
/// Rounds a similarity score to three decimal places for stable JSON output.
fn round_score(value: f32) -> f64 {
    let scaled = f64::from(value) * 1000.0;
    scaled.round() / 1000.0
}

View File

@@ -1,27 +0,0 @@
use state_machines::state_machine;
// Typestate machine backing the hybrid retrieval pipeline: each stage helper
// consumes the machine in one state and returns it in the next, so stage
// order is enforced at compile time. `abort` may fire from any mid-pipeline
// state (note: not from `Embedded`) and lands in `Failed`.
state_machine! {
name: HybridRetrievalMachine,
state: HybridRetrievalState,
initial: Ready,
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
events {
embed { transition: { from: Ready, to: Embedded } }
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
rerank { transition: { from: ChunksAttached, to: Reranked } }
assemble { transition: { from: Reranked, to: Completed } }
abort {
transition: { from: Ready, to: Failed }
transition: { from: CandidatesLoaded, to: Failed }
transition: { from: GraphExpanded, to: Failed }
transition: { from: ChunksAttached, to: Failed }
transition: { from: Reranked, to: Failed }
}
}
}
/// Creates a fresh machine in the `Ready` state with unit payload — the
/// starting point `drive_pipeline` hands to the first stage.
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
    HybridRetrievalMachine::new(())
}

View File

@@ -38,7 +38,7 @@ url = { workspace = true }
uuid = { workspace = true }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
json-stream-parser = { path = "../json-stream-parser" }
[build-dependencies]

View File

@@ -1,7 +1,7 @@
use common::storage::{db::SurrealDbClient, store::StorageManager};
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
use composite_retrieval::reranking::RerankerPool;
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
use std::sync::Arc;
use tracing::debug;
@@ -40,6 +40,14 @@ impl HtmlState {
reranker_pool,
})
}
/// Resolves the retrieval strategy from `AppConfig::retrieval_strategy`.
/// An unset or unparsable value silently falls back to
/// `RetrievalStrategy::Initial`.
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
    self.config
        .retrieval_strategy
        .as_deref()
        // `parse` goes through RetrievalStrategy's FromStr; errors discarded.
        .and_then(|value| value.parse().ok())
        .unwrap_or(RetrievalStrategy::Initial)
}
}
impl ProvidesDb for HtmlState {
fn db(&self) -> &Arc<SurrealDbClient> {

View File

@@ -8,16 +8,16 @@ use axum::{
Sse,
},
};
use composite_retrieval::{
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
retrieve_entities, retrieved_entities_to_json,
};
use futures::{
stream::{self, once},
Stream, StreamExt, TryStreamExt,
};
use json_stream_parser::JsonStreamParser;
use minijinja::Value;
use retrieval_pipeline::{
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput,
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use tokio::sync::{mpsc::channel, Mutex};
@@ -123,16 +123,24 @@ pub async fn get_response_stream(
None => None,
};
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.strategy = state.retrieval_strategy();
let entities = match retrieve_entities(
&state.db,
&state.openai_client,
&user_message.content,
&user.id,
retrieval_config,
rerank_lease,
)
.await
{
Ok(entities) => entities,
Ok(StrategyOutput::Entities(entities)) => entities,
Ok(StrategyOutput::Chunks(_)) => {
return Sse::new(create_error_stream(
"Chunk-only retrieval results are not supported in this route",
))
}
Err(_e) => {
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
}

View File

@@ -24,7 +24,7 @@ use common::{
},
utils::embedding::generate_embedding,
};
use composite_retrieval::{retrieve_entities, RetrievedEntity};
use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput};
use tracing::debug;
use uuid::Uuid;
@@ -284,11 +284,15 @@ pub async fn suggest_knowledge_relationships(
None => None,
};
if let Ok(results) = retrieve_entities(
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.strategy = state.retrieval_strategy();
if let Ok(StrategyOutput::Entities(results)) = retrieve_entities(
&state.db,
&state.openai_client,
&query,
&user.id,
retrieval_config,
rerank_lease,
)
.await

View File

@@ -32,7 +32,7 @@ lopdf = "0.32"
bytes = { workspace = true }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
async-trait = { workspace = true }
state-machines = { workspace = true }
[features]

View File

@@ -11,7 +11,7 @@ use common::{
},
},
};
use composite_retrieval::RetrievedEntity;
use retrieval_pipeline::RetrievedEntity;
use tracing::error;
use super::enrichment_result::LLMEnrichmentResult;

View File

@@ -28,7 +28,7 @@ use common::{
},
utils::config::AppConfig,
};
use composite_retrieval::reranking::RerankerPool;
use retrieval_pipeline::reranking::RerankerPool;
use tracing::{debug, info, warn};
use self::{

View File

@@ -19,8 +19,9 @@ use common::{
},
utils::{config::AppConfig, embedding::generate_embedding},
};
use composite_retrieval::{
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievedEntity,
use retrieval_pipeline::{
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig,
RetrievalStrategy, RetrievedEntity, StrategyOutput,
};
use text_splitter::TextSplitter;
@@ -124,6 +125,14 @@ impl DefaultPipelineServices {
Ok(request)
}
/// Resolves the retrieval strategy from config, defaulting to `Initial` on
/// missing or invalid values.
// NOTE(review): duplicates HtmlState::retrieval_strategy — consider a shared
// helper on AppConfig.
fn configured_strategy(&self) -> RetrievalStrategy {
    self.config
        .retrieval_strategy
        .as_deref()
        .and_then(|value| value.parse().ok())
        .unwrap_or(RetrievalStrategy::Initial)
}
async fn perform_analysis(
&self,
request: CreateChatCompletionRequest,
@@ -178,14 +187,24 @@ impl PipelineServices for DefaultPipelineServices {
None => None,
};
retrieve_entities(
let mut config = RetrievalConfig::default();
config.strategy = self.configured_strategy();
match retrieve_entities(
&self.db,
&self.openai_client,
&input_text,
&content.user_id,
config,
rerank_lease,
)
.await
{
Ok(StrategyOutput::Entities(entities)) => Ok(entities),
Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
"Chunk-only retrieval is not supported in ingestion".into(),
)),
Err(err) => Err(err),
}
}
async fn run_enrichment(

View File

@@ -16,7 +16,7 @@ use common::{
},
},
};
use composite_retrieval::{RetrievedChunk, RetrievedEntity};
use retrieval_pipeline::{RetrievedChunk, RetrievedEntity};
use tokio::sync::Mutex;
use uuid::Uuid;

View File

@@ -25,7 +25,7 @@ ingestion-pipeline = { path = "../ingestion-pipeline" }
api-router = { path = "../api-router" }
html-router = { path = "../html-router" }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
[dev-dependencies]
tower = "0.5"

View File

@@ -3,9 +3,9 @@ use axum::{extract::FromRef, Router};
use common::{
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
};
use composite_retrieval::reranking::RerankerPool;
use html_router::{html_routes, html_state::HtmlState};
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use retrieval_pipeline::reranking::RerankerPool;
use std::sync::Arc;
use tracing::{error, info};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

View File

@@ -5,8 +5,8 @@ use axum::{extract::FromRef, Router};
use common::{
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
};
use composite_retrieval::reranking::RerankerPool;
use html_router::{html_routes, html_state::HtmlState};
use retrieval_pipeline::reranking::RerankerPool;
use tracing::info;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

View File

@@ -3,8 +3,8 @@ use std::sync::Arc;
use common::{
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
};
use composite_retrieval::reranking::RerankerPool;
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use retrieval_pipeline::reranking::RerankerPool;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
#[tokio::main]

View File

@@ -1,5 +1,5 @@
[package]
name = "composite-retrieval"
name = "retrieval-pipeline"
version = "0.1.0"
edition = "2021"
license = "AGPL-3.0-or-later"
@@ -18,8 +18,8 @@ serde_json = { workspace = true }
surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
fastembed = { workspace = true }
common = { path = "../common", features = ["test-utils"] }
state-machines = { workspace = true }

View File

@@ -17,7 +17,10 @@ use common::{
use reranking::RerankerLease;
use tracing::instrument;
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
pub use pipeline::{
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, StrategyOutput,
};
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
@@ -41,14 +44,15 @@ pub async fn retrieve_entities(
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
) -> Result<StrategyOutput, AppError> {
pipeline::run_pipeline(
db_client,
openai_client,
input_text,
user_id,
RetrievalConfig::default(),
config,
reranker,
)
.await
@@ -63,7 +67,7 @@ mod tests {
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
};
use pipeline::RetrievalConfig;
use pipeline::{RetrievalConfig, RetrievalStrategy};
use uuid::Uuid;
fn test_embedding() -> Vec<f32> {
@@ -151,11 +155,16 @@ mod tests {
.await
.expect("Hybrid retrieval failed");
let entities = match results {
StrategyOutput::Entities(items) => items,
other => panic!("expected entity results, got {:?}", other),
};
assert!(
!results.is_empty(),
!entities.is_empty(),
"Expected at least one retrieval result"
);
let top = &results[0];
let top = &entities[0];
assert!(
top.entity.name.contains("Rust"),
"Expected Rust entity to be ranked first"
@@ -242,8 +251,13 @@ mod tests {
.await
.expect("Hybrid retrieval failed");
let entities = match results {
StrategyOutput::Entities(items) => items,
other => panic!("expected entity results, got {:?}", other),
};
let mut neighbor_entry = None;
for entity in &results {
for entity in &entities {
if entity.entity.id == neighbor.id {
neighbor_entry = Some(entity.clone());
}
@@ -264,4 +278,59 @@ mod tests {
"Neighbor entity should surface its own supporting chunks"
);
}
// Verifies that the `Revised` strategy returns chunk-only output
// (`StrategyOutput::Chunks`) rather than assembled entities.
#[tokio::test]
async fn test_revised_strategy_returns_chunks() {
    let db = setup_test_db().await;
    let user_id = "chunk_user";
    // Two chunks with distinct embeddings so vector search has candidates.
    let chunk_one = TextChunk::new(
        "src_alpha".into(),
        "Tokio tasks execute on worker threads managed by the runtime.".into(),
        chunk_embedding_primary(),
        user_id.into(),
    );
    let chunk_two = TextChunk::new(
        "src_beta".into(),
        "Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
        chunk_embedding_secondary(),
        user_id.into(),
    );
    db.store_item(chunk_one.clone())
        .await
        .expect("Failed to store chunk one");
    db.store_item(chunk_two.clone())
        .await
        .expect("Failed to store chunk two");
    let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised);
    // Embedding is precomputed, so the OpenAI client is never actually called.
    let openai_client = Client::new();
    let results = pipeline::run_pipeline_with_embedding(
        &db,
        &openai_client,
        test_embedding(),
        "tokio runtime worker behavior",
        user_id,
        config,
        None,
    )
    .await
    .expect("Revised retrieval failed");
    let chunks = match results {
        StrategyOutput::Chunks(items) => items,
        other => panic!("expected chunk output, got {:?}", other),
    };
    assert!(
        !chunks.is_empty(),
        "Revised strategy should return chunk-only responses"
    );
    assert!(
        chunks
            .iter()
            .any(|entry| entry.chunk.chunk.contains("Tokio")),
        "Chunk results should contain relevant snippets"
    );
}
}

View File

@@ -1,4 +1,40 @@
use serde::{Deserialize, Serialize};
use std::fmt;
/// Selects which retrieval pipeline variant to run.
///
/// * `Initial` — the original entity-centric pipeline (the default).
/// * `Revised` — the chunk-only pipeline variant.
///
/// The serde representation, `FromStr`, and `Display` all agree on the
/// lowercase names `"initial"` / `"revised"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
    // Derived `Default` with `#[default]` replaces the hand-written impl.
    #[default]
    Initial,
    Revised,
}

impl std::str::FromStr for RetrievalStrategy {
    type Err = String;

    /// Parses a strategy name case-insensitively.
    ///
    /// # Errors
    /// Returns a descriptive message for any unknown name.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match value.to_ascii_lowercase().as_str() {
            "initial" => Ok(Self::Initial),
            "revised" => Ok(Self::Revised),
            other => Err(format!("unknown retrieval strategy '{other}'")),
        }
    }
}

impl fmt::Display for RetrievalStrategy {
    /// Writes the lowercase wire name, matching `FromStr` and serde.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            RetrievalStrategy::Initial => "initial",
            RetrievalStrategy::Revised => "revised",
        };
        f.write_str(label)
    }
}
/// Tunable parameters that govern each retrieval stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -51,18 +87,34 @@ impl Default for RetrievalTuning {
/// Wrapper containing tuning plus future flags for per-request overrides.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
pub strategy: RetrievalStrategy,
pub tuning: RetrievalTuning,
}
impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self {
Self { tuning }
Self {
strategy: RetrievalStrategy::Initial,
tuning,
}
}
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
Self {
strategy,
tuning: RetrievalTuning::default(),
}
}
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
Self { strategy, tuning }
}
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
strategy: RetrievalStrategy::default(),
tuning: RetrievalTuning::default(),
}
}

View File

@@ -0,0 +1,397 @@
mod config;
mod diagnostics;
mod stages;
mod strategies;
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
};
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
use async_openai::Client;
use async_trait::async_trait;
use common::{error::AppError, storage::db::SurrealDbClient};
use std::time::{Duration, Instant};
use tracing::info;
use stages::PipelineContext;
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
/// Payload produced by a strategy run: entity-centric results for the
/// `Initial` strategy, bare chunks for the `Revised` strategy.
#[derive(Debug, Clone)]
pub enum StrategyOutput {
    Entities(Vec<RetrievedEntity>),
    Chunks(Vec<RetrievedChunk>),
}

impl StrategyOutput {
    /// Borrows the entity results, if this output carries entities.
    pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
        if let StrategyOutput::Entities(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Consumes the output, yielding the entity results if present.
    pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
        if let StrategyOutput::Entities(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Borrows the chunk results, if this output carries chunks.
    pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
        if let StrategyOutput::Chunks(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Consumes the output, yielding the chunk results if present.
    pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
        if let StrategyOutput::Chunks(items) = self {
            Some(items)
        } else {
            None
        }
    }
}
/// Bundle returned by the strategy runners: the strategy-specific results
/// plus optional diagnostics and per-stage timings.
#[derive(Debug)]
pub struct PipelineRunOutput<T> {
    // Generic over the driver's output type (entities, chunks, or the
    // `StrategyOutput` wrapper at the public API boundary).
    pub results: T,
    // Populated only when diagnostics capture was requested.
    pub diagnostics: Option<PipelineDiagnostics>,
    // Wall-clock milliseconds accumulated per stage.
    pub stage_timings: PipelineStageTimings,
}
/// Identifies a pipeline stage; used by `PipelineStageTimings::record` to
/// attribute elapsed time to the right counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StageKind {
    Embed,
    CollectCandidates,
    GraphExpansion,
    ChunkAttach,
    Rerank,
    Assemble,
}
/// Cumulative wall-clock time spent in each pipeline stage, in milliseconds.
///
/// `record` uses `+=`, so a stage that runs more than once in a single
/// pipeline execution accumulates its total.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct PipelineStageTimings {
    pub embed_ms: u128,
    pub collect_candidates_ms: u128,
    pub graph_expansion_ms: u128,
    pub chunk_attach_ms: u128,
    pub rerank_ms: u128,
    pub assemble_ms: u128,
}

impl PipelineStageTimings {
    /// Adds `duration` to the counter for `kind`.
    pub fn record(&mut self, kind: StageKind, duration: Duration) {
        // `Duration::as_millis` already returns `u128`; the previous
        // `as u128` cast was redundant.
        let elapsed = duration.as_millis();
        match kind {
            StageKind::Embed => self.embed_ms += elapsed,
            StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
            StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
            StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
            StageKind::Rerank => self.rerank_ms += elapsed,
            StageKind::Assemble => self.assemble_ms += elapsed,
        }
    }
}
/// A single pipeline step that mutates the shared `PipelineContext`.
#[async_trait]
pub trait PipelineStage: Send + Sync {
    /// Which stage this is, used to attribute timings in `run_with_driver`.
    fn kind(&self) -> StageKind;
    /// Runs the stage against the shared pipeline context.
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
}

// Owned, heterogeneous stage handle as returned by `StrategyDriver::stages`.
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;

/// Describes one retrieval strategy: the ordered stage list plus how to turn
/// the finished context into strategy-specific output.
pub trait StrategyDriver {
    /// Result type produced by `finalize`.
    type Output;
    /// The strategy this driver implements.
    fn strategy(&self) -> RetrievalStrategy;
    /// Ordered stages executed by `run_with_driver`.
    fn stages(&self) -> Vec<BoxedStage>;
    /// Hook letting a driver adjust tuning before the run; default is a no-op.
    fn override_tuning(&self, _config: &mut RetrievalConfig) {}
    /// Extracts the driver's output from the completed context.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
}
/// Embeds `input_text` and runs the configured retrieval strategy.
///
/// Returns `StrategyOutput::Entities` for `RetrievalStrategy::Initial` and
/// `StrategyOutput::Chunks` for `RetrievalStrategy::Revised`.
///
/// # Errors
/// Propagates any `AppError` raised by a pipeline stage.
pub async fn run_pipeline(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
    // Log only a short, newline-free preview of the query text.
    let input_chars = input_text.chars().count();
    let input_preview: String = input_text.chars().take(120).collect();
    let input_preview_clean = input_preview.replace('\n', " ");
    let preview_len = input_preview_clean.chars().count();
    info!(
        %user_id,
        input_chars,
        preview_truncated = input_chars > preview_len,
        preview = %input_preview_clean,
        // Fixed: previously said "ingestion retrieval pipeline" — stale copy
        // from the old crate; this entry point serves all retrieval callers.
        "Starting retrieval pipeline"
    );
    // Exhaustive match: adding a strategy variant becomes a compile error here
    // instead of silently falling through to the Revised path.
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                None,
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(StrategyOutput::Entities(run.results))
        }
        RetrievalStrategy::Revised => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                None,
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(StrategyOutput::Chunks(run.results))
        }
    }
}
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
if config.strategy == RetrievalStrategy::Initial {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
return Ok(StrategyOutput::Entities(run.results));
}
let driver = RevisedStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Chunks(run.results))
}
/// Runs the configured strategy with a precomputed embedding and returns the
/// results alongside per-stage timing metrics (diagnostics stay disabled).
pub async fn run_pipeline_with_embedding_with_metrics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Entities(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
        RetrievalStrategy::Revised => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Chunks(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
    }
}
/// Same as `run_pipeline_with_embedding_with_metrics`, but with diagnostics
/// capture enabled (`true` passed to `execute_strategy`), so the returned
/// `PipelineRunOutput::diagnostics` is populated.
pub async fn run_pipeline_with_embedding_with_diagnostics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                true,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Entities(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
        RetrievalStrategy::Revised => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                true,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Chunks(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
    }
}
/// Converts retrieved entities into the JSON array shape consumed by
/// downstream prompt templates; scores are rounded via `round_score`.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
    let rendered: Vec<serde_json::Value> = entities
        .iter()
        .map(|entry| {
            let chunks: Vec<serde_json::Value> = entry
                .chunks
                .iter()
                .map(|chunk| {
                    serde_json::json!({
                        "score": round_score(chunk.score),
                        "content": chunk.chunk.chunk
                    })
                })
                .collect();
            serde_json::json!({
                "KnowledgeEntity": {
                    "id": entry.entity.id,
                    "name": entry.entity.name,
                    "description": entry.entity.description,
                    "score": round_score(entry.score),
                    "chunks": chunks
                }
            })
        })
        .collect();
    serde_json::Value::Array(rendered)
}
/// Builds a `PipelineContext` — seeded with `query_embedding` when one was
/// supplied — and runs `driver` against it.
async fn execute_strategy<D: StrategyDriver>(
    driver: D,
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Option<Vec<f32>>,
    input_text: &str,
    user_id: &str,
    mut config: RetrievalConfig,
    reranker: Option<RerankerLease>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
    // Give the driver a chance to adjust tuning before the context is built.
    driver.override_tuning(&mut config);
    let ctx = if let Some(embedding) = query_embedding {
        PipelineContext::with_embedding(
            db_client,
            openai_client,
            embedding,
            input_text.to_owned(),
            user_id.to_owned(),
            config,
            reranker,
        )
    } else {
        PipelineContext::new(
            db_client,
            openai_client,
            input_text.to_owned(),
            user_id.to_owned(),
            config,
            reranker,
        )
    };
    run_with_driver(driver, ctx, capture_diagnostics).await
}
/// Runs every stage the driver declares, recording per-stage wall time, then
/// lets the driver extract its typed output from the finished context.
async fn run_with_driver<D: StrategyDriver>(
    driver: D,
    mut ctx: PipelineContext<'_>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
    if capture_diagnostics {
        ctx.enable_diagnostics();
    }
    // Stages run strictly in the order the driver listed them; the first
    // error aborts the run.
    for stage in driver.stages() {
        let start = Instant::now();
        stage.execute(&mut ctx).await?;
        ctx.record_stage_duration(stage.kind(), start.elapsed());
    }
    // Diagnostics and timings are taken *before* finalize, so work done in
    // `finalize` is neither timed nor traced — NOTE(review): confirm intended.
    let diagnostics = ctx.take_diagnostics();
    let stage_timings = ctx.take_stage_timings();
    let results = driver.finalize(&mut ctx)?;
    Ok(PipelineRunOutput {
        results,
        diagnostics,
        stage_timings,
    })
}
/// Quantizes a score to three decimal places before it is serialized.
fn round_score(value: f32) -> f64 {
    let millis = (f64::from(value) * 1000.0).round();
    millis / 1000.0
}

View File

@@ -1,4 +1,5 @@
use async_openai::Client;
use async_trait::async_trait;
use common::{
error::AppError,
storage::{
@@ -9,11 +10,9 @@ use common::{
};
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
use state_machines::core::GuardError;
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
time::Instant,
};
use tracing::{debug, instrument, warn};
@@ -25,21 +24,20 @@ use crate::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
},
vector::find_items_by_vector_similarity_with_embedding,
vector::{
find_chunk_snippets_by_vector_similarity_with_embedding,
find_items_by_vector_similarity_with_embedding, ChunkSnippet,
},
RetrievedChunk, RetrievedEntity,
};
use super::{
config::RetrievalConfig,
config::{RetrievalConfig, RetrievalTuning},
diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
},
state::{
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
Reranked,
},
PipelineStageTimings,
PipelineStage, PipelineStageTimings, StageKind,
};
pub struct PipelineContext<'a> {
@@ -53,8 +51,11 @@ pub struct PipelineContext<'a> {
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<ChunkSnippet>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>,
pub entity_results: Vec<RetrievedEntity>,
pub chunk_results: Vec<RetrievedChunk>,
stage_timings: PipelineStageTimings,
}
@@ -78,8 +79,11 @@ impl<'a> PipelineContext<'a> {
chunk_candidates: HashMap::new(),
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
revised_chunk_values: Vec::new(),
reranker,
diagnostics: None,
entity_results: Vec::new(),
chunk_results: Vec::new(),
stage_timings: PipelineStageTimings::default(),
}
}
@@ -145,36 +149,151 @@ impl<'a> PipelineContext<'a> {
self.diagnostics.take()
}
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_collect_candidates(duration);
}
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_graph_expansion(duration);
}
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_chunk_attach(duration);
}
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_rerank(duration);
}
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
self.stage_timings.record_assemble(duration);
}
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
std::mem::take(&mut self.stage_timings)
}
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
self.stage_timings.record(kind, duration);
}
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
std::mem::take(&mut self.entity_results)
}
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
std::mem::take(&mut self.chunk_results)
}
}
/// Stage wrapper that resolves (or reuses a cached) query embedding; see `embed`.
#[derive(Debug, Clone, Copy)]
pub struct EmbedStage;

#[async_trait]
impl PipelineStage for EmbedStage {
    fn kind(&self) -> StageKind {
        StageKind::Embed
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        embed(ctx).await
    }
}

/// Gathers initial entity/chunk candidates via vector + FTS search; see
/// `collect_candidates`.
#[derive(Debug, Clone, Copy)]
pub struct CollectCandidatesStage;

#[async_trait]
impl PipelineStage for CollectCandidatesStage {
    fn kind(&self) -> StageKind {
        StageKind::CollectCandidates
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        collect_candidates(ctx).await
    }
}

/// Expands entity candidates through graph relationships; see `expand_graph`.
#[derive(Debug, Clone, Copy)]
pub struct GraphExpansionStage;

#[async_trait]
impl PipelineStage for GraphExpansionStage {
    fn kind(&self) -> StageKind {
        StageKind::GraphExpansion
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        expand_graph(ctx).await
    }
}

/// Attaches text chunks to surviving entities; see `attach_chunks`.
#[derive(Debug, Clone, Copy)]
pub struct ChunkAttachStage;

#[async_trait]
impl PipelineStage for ChunkAttachStage {
    fn kind(&self) -> StageKind {
        StageKind::ChunkAttach
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        attach_chunks(ctx).await
    }
}

/// Applies the optional reranker to entity candidates; see `rerank`.
#[derive(Debug, Clone, Copy)]
pub struct RerankStage;

#[async_trait]
impl PipelineStage for RerankStage {
    fn kind(&self) -> StageKind {
        StageKind::Rerank
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        rerank(ctx).await
    }
}

/// Assembles the final entity results into the context; see `assemble`.
#[derive(Debug, Clone, Copy)]
pub struct AssembleEntitiesStage;

#[async_trait]
impl PipelineStage for AssembleEntitiesStage {
    fn kind(&self) -> StageKind {
        StageKind::Assemble
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        // `assemble` is synchronous; no await needed.
        assemble(ctx)
    }
}

/// Revised-strategy stage: fetches chunk snippets by vector similarity;
/// see `collect_vector_chunks`.
/// NOTE(review): reports timing under `StageKind::CollectCandidates`,
/// sharing the bucket with the entity stage — confirm this is intentional.
#[derive(Debug, Clone, Copy)]
pub struct ChunkVectorStage;

#[async_trait]
impl PipelineStage for ChunkVectorStage {
    fn kind(&self) -> StageKind {
        StageKind::CollectCandidates
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        collect_vector_chunks(ctx).await
    }
}

/// Revised-strategy stage: reranks chunk snippets; see `rerank_chunks`.
/// Shares the `Rerank` timing bucket with the entity rerank stage.
#[derive(Debug, Clone, Copy)]
pub struct ChunkRerankStage;

#[async_trait]
impl PipelineStage for ChunkRerankStage {
    fn kind(&self) -> StageKind {
        StageKind::Rerank
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        rerank_chunks(ctx).await
    }
}

/// Revised-strategy stage: converts snippets into final chunk results;
/// see `assemble_chunks`. Shares the `Assemble` timing bucket.
#[derive(Debug, Clone, Copy)]
pub struct ChunkAssembleStage;

#[async_trait]
impl PipelineStage for ChunkAssembleStage {
    fn kind(&self) -> StageKind {
        StageKind::Assemble
    }
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
        // `assemble_chunks` is synchronous; no await needed.
        assemble_chunks(ctx)
    }
}
#[instrument(level = "trace", skip_all)]
pub async fn embed(
machine: HybridRetrievalMachine<(), Ready>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let embedding_cached = ctx.query_embedding.is_some();
if embedding_cached {
debug!("Reusing cached query embedding for hybrid retrieval");
@@ -185,17 +304,11 @@ pub async fn embed(
ctx.query_embedding = Some(embedding);
}
machine
.embed()
.map_err(|(_, guard)| map_guard_error("embed", guard))
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn collect_candidates(
machine: HybridRetrievalMachine<(), Embedded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
let stage_start = Instant::now();
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
@@ -265,104 +378,80 @@ pub async fn collect_candidates(
apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
let next = machine
.collect_candidates()
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
ctx.record_collect_candidates_timing(stage_start.elapsed());
Ok(next)
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn expand_graph(
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
let stage_start = Instant::now();
pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Expanding candidates using graph relationships");
let next = {
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
if ctx.entity_candidates.is_empty() {
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
} else {
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
if ctx.entity_candidates.is_empty() {
return Ok(());
}
if graph_seeds.is_empty() {
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
} else {
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
let limit = tuning.graph_neighbor_limit;
futures.push(async move {
let neighbors =
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
(seed, neighbors)
});
}
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
if graph_seeds.is_empty() {
return Ok(());
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
let limit = tuning.graph_neighbor_limit;
futures.push(async move {
let neighbors = find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
(seed, neighbors)
});
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector =
clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
}?;
ctx.record_graph_expansion_timing(stage_start.elapsed());
Ok(next)
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn attach_chunks(
machine: HybridRetrievalMachine<(), GraphExpanded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
let stage_start = Instant::now();
pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Attaching chunks to surviving entities");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
@@ -438,19 +527,11 @@ pub async fn attach_chunks(
ctx.chunk_values = chunk_values;
let next = machine
.attach_chunks()
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
ctx.record_chunk_attach_timing(stage_start.elapsed());
Ok(next)
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn rerank(
machine: HybridRetrievalMachine<(), ChunksAttached>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
let stage_start = Instant::now();
pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let mut applied = false;
if let Some(reranker) = ctx.reranker.as_ref() {
@@ -490,19 +571,124 @@ pub async fn rerank(
debug!("Applied reranking adjustments to candidate ordering");
}
let next = machine
.rerank()
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
ctx.record_rerank_timing(stage_start.elapsed());
Ok(next)
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub fn assemble(
machine: HybridRetrievalMachine<(), Reranked>,
ctx: &mut PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let stage_start = Instant::now();
/// Revised-strategy candidate collection: fetches the top
/// `chunk_vector_take` chunk snippets by vector similarity and stores them,
/// sorted by fused score (descending), in `ctx.revised_chunk_values`.
///
/// # Errors
/// Fails if no query embedding is available or the vector search errors.
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    debug!("Collecting vector chunk candidates for revised strategy");
    // Requires the embed stage to have run (or a cached embedding).
    let embedding = ctx.ensure_embedding()?.clone();
    let tuning = &ctx.config.tuning;
    let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding(
        tuning.chunk_vector_take,
        embedding,
        ctx.db_client,
        &ctx.user_id,
    )
    .await?;
    if ctx.diagnostics_enabled() {
        // Entity/FTS counters are zero by construction: this strategy only
        // performs the chunk vector search.
        ctx.record_collect_candidates(CollectCandidatesStats {
            vector_entity_candidates: 0,
            vector_chunk_candidates: vector_chunks.len(),
            fts_entity_candidates: 0,
            fts_chunk_candidates: 0,
            vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
                chunk.scores.vector.unwrap_or(0.0)
            }),
            fts_chunk_scores: Vec::new(),
        });
    }
    // NaN-safe descending sort on the fused score.
    vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
    ctx.revised_chunk_values = vector_chunks;
    Ok(())
}
/// Reranks the revised chunk candidates with the leased reranker, if one is
/// available. Reranker failures are logged and swallowed so retrieval still
/// succeeds with the original vector ordering.
#[instrument(level = "trace", skip_all)]
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    // Nothing to reorder with zero or one candidate.
    if ctx.revised_chunk_values.len() <= 1 {
        return Ok(());
    }
    let Some(reranker) = ctx.reranker.as_ref() else {
        debug!("No reranker lease provided; skipping chunk rerank stage");
        return Ok(());
    };
    // Only the top `rerank_keep_top` candidates are sent to the reranker.
    let documents = build_snippet_rerank_documents(
        &ctx.revised_chunk_values,
        ctx.config.tuning.rerank_keep_top.max(1),
    );
    if documents.len() <= 1 {
        debug!("Skipping chunk reranking stage; insufficient chunk documents");
        return Ok(());
    }
    match reranker.rerank(&ctx.input_text, documents).await {
        Ok(results) if !results.is_empty() => {
            apply_snippet_rerank_results(
                &mut ctx.revised_chunk_values,
                &ctx.config.tuning,
                results,
            );
        }
        Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
        // Best-effort: a reranker failure must not fail the pipeline.
        Err(err) => warn!(
            error = %err,
            "Chunk reranking failed; continuing with original ordering"
        ),
    }
    Ok(())
}
/// Final stage of the revised strategy: boosts candidates by lexical overlap
/// with the query, truncates to `chunk_vector_take`, and converts the
/// surviving snippets into `RetrievedChunk`s stored in `ctx.chunk_results`.
#[instrument(level = "trace", skip_all)]
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    debug!("Assembling chunk-only retrieval results");
    // Take ownership of the candidates, leaving the context field empty.
    let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
    let question_terms = extract_keywords(&ctx.input_text);
    // Re-rank by fused score plus weighted lexical overlap, descending.
    rank_snippet_chunks_by_combined_score(
        &mut chunk_values,
        &question_terms,
        ctx.config.tuning.lexical_match_weight,
    );
    // Keep at least one result even if the configured take is zero.
    let limit = ctx.config.tuning.chunk_vector_take.max(1);
    if chunk_values.len() > limit {
        chunk_values.truncate(limit);
    }
    ctx.chunk_results = chunk_values
        .into_iter()
        .map(|chunk| {
            let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id);
            RetrievedChunk {
                chunk: text_chunk,
                score: chunk.fused,
            }
        })
        .collect();
    if ctx.diagnostics_enabled() {
        // The chunk-only strategy does not spend the token budget; the stats
        // record the untouched budget alongside the selected chunk count.
        ctx.record_assemble(AssembleStats {
            token_budget_start: ctx.config.tuning.token_budget_estimate,
            token_budget_spent: 0,
            token_budget_remaining: ctx.config.tuning.token_budget_estimate,
            budget_exhausted: false,
            chunks_selected: ctx.chunk_results.len(),
            chunks_skipped_due_budget: 0,
            entity_count: 0,
            entity_traces: Vec::new(),
        });
    }
    Ok(())
}
#[instrument(level = "trace", skip_all)]
pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let query_embedding = ctx.ensure_embedding()?.clone();
@@ -610,11 +796,8 @@ pub fn assemble(
});
}
machine
.assemble()
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
ctx.record_assemble_timing(stage_start.elapsed());
Ok(results)
ctx.entity_results = results;
Ok(())
}
const SCORE_SAMPLE_LIMIT: usize = 8;
@@ -630,12 +813,6 @@ where
.collect()
}
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
AppError::InternalError(format!(
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
err.guard, err.event, err.kind
))
}
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
let raw_scores: Vec<f32> = results
.iter()
@@ -873,6 +1050,23 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz
.collect()
}
/// Renders the top `max_chunks` snippet candidates as plain-text rerank
/// documents, each prefixed with its source id.
fn build_snippet_rerank_documents(
    chunks: &[Scored<ChunkSnippet>],
    max_chunks: usize,
) -> Vec<String> {
    let mut documents = Vec::with_capacity(max_chunks.min(chunks.len()));
    for scored in chunks.iter().take(max_chunks) {
        let body = scored.item.chunk.trim();
        documents.push(format!(
            "Source: {}\nChunk:\n{}",
            scored.item.source_id, body
        ));
    }
    documents
}
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
if results.is_empty() || ctx.filtered_entities.is_empty() {
return;
@@ -930,6 +1124,66 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
}
}
/// Reorders `chunks` according to reranker `results`, blending each
/// candidate's normalized rerank score into its fused score.
///
/// Candidates the reranker did not reference keep their original relative
/// order and are appended after the reranked ones; the final list is
/// truncated to `rerank_keep_top` when that limit is non-zero.
fn apply_snippet_rerank_results(
    chunks: &mut Vec<Scored<ChunkSnippet>>,
    tuning: &RetrievalTuning,
    results: Vec<RerankResult>,
) {
    if results.is_empty() || chunks.is_empty() {
        return;
    }
    // Move candidates into Option slots so reranker indices can take items
    // out of their original positions without shifting the remainder.
    let mut remaining: Vec<Option<Scored<ChunkSnippet>>> =
        std::mem::take(chunks).into_iter().map(Some).collect();
    let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
    let normalized_scores = min_max_normalize(&raw_scores);
    let use_only = tuning.rerank_scores_only;
    // blend = 1.0 means the rerank score fully replaces the fused score.
    let blend = if use_only {
        1.0
    } else {
        clamp_unit(tuning.rerank_blend_weight)
    };
    let mut reranked: Vec<Scored<ChunkSnippet>> = Vec::with_capacity(remaining.len());
    for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
        if let Some(slot) = remaining.get_mut(result.index) {
            if let Some(mut candidate) = slot.take() {
                let original = candidate.fused;
                let blended = if use_only {
                    clamp_unit(normalized)
                } else {
                    // Linear interpolation between the original fused score
                    // and the normalized rerank score.
                    clamp_unit(original * (1.0 - blend) + normalized * blend)
                };
                candidate.update_fused(blended);
                reranked.push(candidate);
            }
        } else {
            // Defensive: reranker indices refer to the documents we sent;
            // anything out of range is dropped with a warning.
            warn!(
                result_index = result.index,
                "Chunk reranker returned out-of-range index; skipping"
            );
        }
        // Every original candidate has been placed; extra results are moot.
        if reranked.len() == remaining.len() {
            break;
        }
    }
    // Append candidates the reranker never referenced, in original order.
    for slot in remaining.into_iter() {
        if let Some(candidate) = slot {
            reranked.push(candidate);
        }
    }
    let keep_top = tuning.rerank_keep_top;
    if keep_top > 0 && reranked.len() > keep_top {
        reranked.truncate(keep_top);
    }
    *chunks = reranked;
}
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
let chars = text.chars().count().max(1);
(chars / avg_chars_per_token).max(1)
@@ -963,6 +1217,32 @@ fn extract_keywords(text: &str) -> Vec<String> {
terms
}
/// Boosts each candidate's fused score by its weighted lexical overlap with
/// the question terms, then sorts candidates by fused score, descending.
fn rank_snippet_chunks_by_combined_score(
    candidates: &mut [Scored<ChunkSnippet>],
    question_terms: &[String],
    lexical_weight: f32,
) {
    let apply_lexical = lexical_weight > 0.0 && !question_terms.is_empty();
    if apply_lexical {
        for scored in candidates.iter_mut() {
            let overlap = lexical_overlap_score(question_terms, &scored.item.chunk);
            scored.update_fused(clamp_unit(scored.fused + lexical_weight * overlap));
        }
    }
    // NaN-safe descending order on the fused score.
    candidates.sort_by(|left, right| {
        right
            .fused
            .partial_cmp(&left.fused)
            .unwrap_or(Ordering::Equal)
    });
}
/// Converts a retrieved `ChunkSnippet` into a full `TextChunk` owned by
/// `user_id`, preserving the snippet's original chunk id.
fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk {
    // `snippet` is consumed, so its fields can be moved out directly; the
    // previous `source_id.clone()` was a redundant allocation.
    let mut chunk = TextChunk::new(
        snippet.source_id,
        snippet.chunk,
        Vec::new(),
        user_id.to_owned(),
    );
    // Keep the stored record's identity instead of any id `new` generates.
    chunk.id = snippet.id;
    chunk
}
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
if terms.is_empty() {
return 0.0;

View File

@@ -0,0 +1,75 @@
use super::{
stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage,
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
RerankStage,
},
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
};
use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError;
/// Strategy driver for the original entity-centric retrieval flow:
/// embed → collect candidates → graph expansion → chunk attach → rerank →
/// assemble entities.
#[derive(Debug, Clone, Copy, Default)]
pub struct InitialStrategyDriver;

impl InitialStrategyDriver {
    /// Creates a new driver; equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self
    }
}

impl StrategyDriver for InitialStrategyDriver {
    type Output = Vec<RetrievedEntity>;

    fn strategy(&self) -> RetrievalStrategy {
        RetrievalStrategy::Initial
    }

    /// Stage order mirrors the legacy hybrid retrieval pipeline.
    fn stages(&self) -> Vec<BoxedStage> {
        vec![
            Box::new(EmbedStage),
            Box::new(CollectCandidatesStage),
            Box::new(GraphExpansionStage),
            Box::new(ChunkAttachStage),
            Box::new(RerankStage),
            Box::new(AssembleEntitiesStage),
        ]
    }

    /// Drains the entity results assembled by `AssembleEntitiesStage`.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
        Ok(ctx.take_entity_results())
    }
}
/// Strategy driver for the revised, chunk-only retrieval flow:
/// embed → vector chunk fetch → chunk rerank → chunk assemble.
#[derive(Debug, Clone, Copy, Default)]
pub struct RevisedStrategyDriver;

impl RevisedStrategyDriver {
    /// Creates a new driver; equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self
    }
}

impl StrategyDriver for RevisedStrategyDriver {
    type Output = Vec<RetrievedChunk>;

    fn strategy(&self) -> RetrievalStrategy {
        RetrievalStrategy::Revised
    }

    fn stages(&self) -> Vec<BoxedStage> {
        vec![
            Box::new(EmbedStage),
            Box::new(ChunkVectorStage),
            Box::new(ChunkRerankStage),
            Box::new(ChunkAssembleStage),
        ]
    }

    /// Drains the chunk results produced by `ChunkAssembleStage`.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
        Ok(ctx.take_chunk_results())
    }

    /// Entity search is irrelevant for the chunk-only strategy, so zero the
    /// entity candidate budgets before the pipeline runs.
    fn override_tuning(&self, config: &mut RetrievalConfig) {
        config.tuning.entity_vector_take = 0;
        config.tuning.entity_fts_take = 0;
    }
}

View File

@@ -1,9 +1,11 @@
use std::collections::HashMap;
use common::storage::types::file_info::deserialize_flexible_id;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
storage::{
db::SurrealDbClient,
types::{file_info::deserialize_flexible_id, StoredObject},
},
utils::embedding::generate_embedding,
};
use serde::Deserialize;
@@ -156,3 +158,61 @@ where
Ok(scored)
}
/// Lightweight projection of a `text_chunk` row returned by the chunk
/// vector search (see `find_chunk_snippets_by_vector_similarity_with_embedding`).
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkSnippet {
    /// Record id of the chunk row.
    pub id: String,
    /// Id of the source the chunk was extracted from.
    pub source_id: String,
    /// The chunk's text content.
    pub chunk: String,
}
/// Raw query row for the chunk KNN search; `distance` comes from
/// `vector::distance::knn()` and is converted to a similarity score later.
#[derive(Debug, Deserialize)]
struct ChunkDistanceRow {
    /// KNN distance for the row, if the engine reported one.
    distance: Option<f32>,
    /// Record id, parsed via the project's flexible-id deserializer.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    pub source_id: String,
    pub chunk: String,
}
/// Fetches the `take` nearest `text_chunk` rows for `query_embedding`,
/// scoped to `user_id`, and returns them as vector-scored `ChunkSnippet`s.
///
/// # Errors
/// Returns an error if the embedding cannot be serialized to JSON or the
/// database query fails.
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    user_id: &str,
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
    // The embedding is inlined into the query as a JSON array literal.
    // NOTE(review): this assumes the KNN operator's operand cannot be a
    // bound parameter — confirm against the SurrealDB version in use.
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    // `<|{take},40|>` is SurrealDB's KNN operator (K = take; 40 is presumably
    // the HNSW search ef parameter — TODO confirm). `LIMIT $limit` caps the
    // returned rows to the same `take`.
    let closest_query = format!(
        "SELECT id, source_id, chunk, vector::distance::knn() AS distance \
         FROM text_chunk \
         WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
         LIMIT $limit",
        take = take,
        embedding = embedding_literal
    );
    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let rows: Vec<ChunkDistanceRow> = response.take(0)?;
    let mut scored = Vec::with_capacity(rows.len());
    for row in rows {
        // A missing distance maps to similarity 0.0 rather than an error.
        let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
        scored.push(
            Scored::new(ChunkSnippet {
                id: row.id,
                source_id: row.source_id,
                chunk: row.chunk,
            })
            .with_vector_score(similarity),
        );
    }
    Ok(scored)
}