mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-21 17:09:51 +01:00
retrieval-pipeline: v0
This commit is contained in:
48
Cargo.lock
generated
48
Cargo.lock
generated
@@ -1457,26 +1457,6 @@ dependencies = [
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "composite-retrieval"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-openai",
|
||||
"axum",
|
||||
"common",
|
||||
"fastembed",
|
||||
"futures",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"state-machines",
|
||||
"surrealdb",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "compression-codecs"
|
||||
version = "0.4.30"
|
||||
@@ -2197,7 +2177,6 @@ dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"common",
|
||||
"composite-retrieval",
|
||||
"criterion",
|
||||
"fastembed",
|
||||
"futures",
|
||||
@@ -2205,6 +2184,7 @@ dependencies = [
|
||||
"object_store 0.11.2",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"retrieval-pipeline",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
@@ -2880,7 +2860,6 @@ dependencies = [
|
||||
"chrono",
|
||||
"chrono-tz",
|
||||
"common",
|
||||
"composite-retrieval",
|
||||
"futures",
|
||||
"include_dir",
|
||||
"json-stream-parser",
|
||||
@@ -2888,6 +2867,7 @@ dependencies = [
|
||||
"minijinja-autoreload",
|
||||
"minijinja-contrib",
|
||||
"minijinja-embed",
|
||||
"retrieval-pipeline",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"surrealdb",
|
||||
@@ -3342,13 +3322,13 @@ dependencies = [
|
||||
"bytes",
|
||||
"chrono",
|
||||
"common",
|
||||
"composite-retrieval",
|
||||
"dom_smoothie",
|
||||
"futures",
|
||||
"headless_chrome",
|
||||
"lopdf 0.32.0",
|
||||
"pdf-extract",
|
||||
"reqwest",
|
||||
"retrieval-pipeline",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"state-machines",
|
||||
@@ -3802,10 +3782,10 @@ dependencies = [
|
||||
"async-openai",
|
||||
"axum",
|
||||
"common",
|
||||
"composite-retrieval",
|
||||
"futures",
|
||||
"html-router",
|
||||
"ingestion-pipeline",
|
||||
"retrieval-pipeline",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"surrealdb",
|
||||
@@ -5475,6 +5455,26 @@ dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "retrieval-pipeline"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"common",
|
||||
"fastembed",
|
||||
"futures",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"surrealdb",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "revision"
|
||||
version = "0.10.0"
|
||||
|
||||
@@ -5,7 +5,7 @@ members = [
|
||||
"api-router",
|
||||
"html-router",
|
||||
"ingestion-pipeline",
|
||||
"composite-retrieval",
|
||||
"retrieval-pipeline",
|
||||
"json-stream-parser",
|
||||
"eval"
|
||||
]
|
||||
|
||||
@@ -54,6 +54,8 @@ pub struct AppConfig {
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub retrieval_strategy: Option<String>,
|
||||
}
|
||||
|
||||
fn default_data_dir() -> String {
|
||||
@@ -117,6 +119,7 @@ impl Default for AppConfig {
|
||||
fastembed_cache_dir: None,
|
||||
fastembed_show_download_progress: None,
|
||||
fastembed_max_length: None,
|
||||
retrieval_strategy: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
mod config;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineRunOutput {
|
||||
pub results: Vec<RetrievedEntity>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub collect_candidates_ms: u128,
|
||||
pub graph_expansion_ms: u128,
|
||||
pub chunk_attach_ms: u128,
|
||||
pub rerank_ms: u128,
|
||||
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
fn record_collect_candidates(&mut self, duration: std::time::Duration) {
|
||||
self.collect_candidates_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_graph_expansion(&mut self, duration: std::time::Duration) {
|
||||
self.graph_expansion_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_chunk_attach(&mut self, duration: std::time::Duration) {
|
||||
self.chunk_attach_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_rerank(&mut self, duration: std::time::Duration) {
|
||||
self.rerank_ms += duration.as_millis() as u128;
|
||||
}
|
||||
|
||||
fn record_assemble(&mut self, duration: std::time::Duration) {
|
||||
self.assemble_ms += duration.as_millis() as u128;
|
||||
}
|
||||
}
|
||||
|
||||
/// Drives the retrieval pipeline from embedding through final assembly.
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
let ctx = stages::PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
||||
|
||||
Ok(outcome.results)
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
||||
|
||||
Ok(outcome.results)
|
||||
}
|
||||
|
||||
/// Runs the pipeline with a precomputed embedding and returns stage metrics.
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
|
||||
run_pipeline_internal(ctx, false).await
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
let ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
|
||||
run_pipeline_internal(ctx, true).await
|
||||
}
|
||||
|
||||
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
serde_json::json!(entities
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
serde_json::json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entry.entity.id,
|
||||
"name": entry.entity.name,
|
||||
"description": entry.entity.description,
|
||||
"score": round_score(entry.score),
|
||||
"chunks": entry.chunks.iter().map(|chunk| {
|
||||
serde_json::json!({
|
||||
"score": round_score(chunk.score),
|
||||
"content": chunk.chunk.chunk
|
||||
})
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
async fn run_pipeline_internal(
|
||||
mut ctx: stages::PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
let results = drive_pipeline(&mut ctx).await?;
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings: ctx.take_stage_timings(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drive_pipeline(
|
||||
ctx: &mut stages::PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let machine = stages::embed(machine, ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, ctx).await?;
|
||||
let machine = stages::expand_graph(machine, ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, ctx).await?;
|
||||
let machine = stages::rerank(machine, ctx).await?;
|
||||
let results = stages::assemble(machine, ctx)?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: HybridRetrievalMachine,
|
||||
state: HybridRetrievalState,
|
||||
initial: Ready,
|
||||
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
|
||||
events {
|
||||
embed { transition: { from: Ready, to: Embedded } }
|
||||
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
|
||||
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
|
||||
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
|
||||
rerank { transition: { from: ChunksAttached, to: Reranked } }
|
||||
assemble { transition: { from: Reranked, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: CandidatesLoaded, to: Failed }
|
||||
transition: { from: GraphExpanded, to: Failed }
|
||||
transition: { from: ChunksAttached, to: Failed }
|
||||
transition: { from: Reranked, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
|
||||
HybridRetrievalMachine::new(())
|
||||
}
|
||||
@@ -38,7 +38,7 @@ url = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
json-stream-parser = { path = "../json-stream-parser" }
|
||||
|
||||
[build-dependencies]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
@@ -40,6 +40,14 @@ impl HtmlState {
|
||||
reranker_pool,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
|
||||
self.config
|
||||
.retrieval_strategy
|
||||
.as_deref()
|
||||
.and_then(|value| value.parse().ok())
|
||||
.unwrap_or(RetrievalStrategy::Initial)
|
||||
}
|
||||
}
|
||||
impl ProvidesDb for HtmlState {
|
||||
fn db(&self) -> &Arc<SurrealDbClient> {
|
||||
|
||||
@@ -8,16 +8,16 @@ use axum::{
|
||||
Sse,
|
||||
},
|
||||
};
|
||||
use composite_retrieval::{
|
||||
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||
retrieve_entities, retrieved_entities_to_json,
|
||||
};
|
||||
use futures::{
|
||||
stream::{self, once},
|
||||
Stream, StreamExt, TryStreamExt,
|
||||
};
|
||||
use json_stream_parser::JsonStreamParser;
|
||||
use minijinja::Value;
|
||||
use retrieval_pipeline::{
|
||||
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||
retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use tokio::sync::{mpsc::channel, Mutex};
|
||||
@@ -123,16 +123,24 @@ pub async fn get_response_stream(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.strategy = state.retrieval_strategy();
|
||||
let entities = match retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
retrieval_config,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(entities) => entities,
|
||||
Ok(StrategyOutput::Entities(entities)) => entities,
|
||||
Ok(StrategyOutput::Chunks(_)) => {
|
||||
return Sse::new(create_error_stream(
|
||||
"Chunk-only retrieval results are not supported in this route",
|
||||
))
|
||||
}
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ use common::{
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use composite_retrieval::{retrieve_entities, RetrievedEntity};
|
||||
use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput};
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -284,11 +284,15 @@ pub async fn suggest_knowledge_relationships(
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Ok(results) = retrieve_entities(
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.strategy = state.retrieval_strategy();
|
||||
|
||||
if let Ok(StrategyOutput::Entities(results)) = retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&query,
|
||||
&user.id,
|
||||
retrieval_config,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -32,7 +32,7 @@ lopdf = "0.32"
|
||||
bytes = { workspace = true }
|
||||
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
async-trait = { workspace = true }
|
||||
state-machines = { workspace = true }
|
||||
[features]
|
||||
|
||||
@@ -11,7 +11,7 @@ use common::{
|
||||
},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::RetrievedEntity;
|
||||
use retrieval_pipeline::RetrievedEntity;
|
||||
use tracing::error;
|
||||
|
||||
use super::enrichment_result::LLMEnrichmentResult;
|
||||
|
||||
@@ -28,7 +28,7 @@ use common::{
|
||||
},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use self::{
|
||||
|
||||
@@ -19,8 +19,9 @@ use common::{
|
||||
},
|
||||
utils::{config::AppConfig, embedding::generate_embedding},
|
||||
};
|
||||
use composite_retrieval::{
|
||||
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievedEntity,
|
||||
use retrieval_pipeline::{
|
||||
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievedEntity, StrategyOutput,
|
||||
};
|
||||
use text_splitter::TextSplitter;
|
||||
|
||||
@@ -124,6 +125,14 @@ impl DefaultPipelineServices {
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
fn configured_strategy(&self) -> RetrievalStrategy {
|
||||
self.config
|
||||
.retrieval_strategy
|
||||
.as_deref()
|
||||
.and_then(|value| value.parse().ok())
|
||||
.unwrap_or(RetrievalStrategy::Initial)
|
||||
}
|
||||
|
||||
async fn perform_analysis(
|
||||
&self,
|
||||
request: CreateChatCompletionRequest,
|
||||
@@ -178,14 +187,24 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
None => None,
|
||||
};
|
||||
|
||||
retrieve_entities(
|
||||
let mut config = RetrievalConfig::default();
|
||||
config.strategy = self.configured_strategy();
|
||||
match retrieve_entities(
|
||||
&self.db,
|
||||
&self.openai_client,
|
||||
&input_text,
|
||||
&content.user_id,
|
||||
config,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(StrategyOutput::Entities(entities)) => Ok(entities),
|
||||
Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||
"Chunk-only retrieval is not supported in ingestion".into(),
|
||||
)),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_enrichment(
|
||||
|
||||
@@ -16,7 +16,7 @@ use common::{
|
||||
},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::{RetrievedChunk, RetrievedEntity};
|
||||
use retrieval_pipeline::{RetrievedChunk, RetrievedEntity};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ ingestion-pipeline = { path = "../ingestion-pipeline" }
|
||||
api-router = { path = "../api-router" }
|
||||
html-router = { path = "../html-router" }
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
|
||||
[dev-dependencies]
|
||||
tower = "0.5"
|
||||
|
||||
@@ -3,9 +3,9 @@ use axum::{extract::FromRef, Router};
|
||||
use common::{
|
||||
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
||||
};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
@@ -5,8 +5,8 @@ use axum::{extract::FromRef, Router};
|
||||
use common::{
|
||||
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
||||
};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ use std::sync::Arc;
|
||||
use common::{
|
||||
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
||||
};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
#[tokio::main]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "composite-retrieval"
|
||||
name = "retrieval-pipeline"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
@@ -18,8 +18,8 @@ serde_json = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
state-machines = { workspace = true }
|
||||
@@ -17,7 +17,10 @@ use common::{
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, StrategyOutput,
|
||||
};
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -41,14 +44,15 @@ pub async fn retrieve_entities(
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
pipeline::run_pipeline(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
config,
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
@@ -63,7 +67,7 @@ mod tests {
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
};
|
||||
use pipeline::RetrievalConfig;
|
||||
use pipeline::{RetrievalConfig, RetrievalStrategy};
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
@@ -151,11 +155,16 @@ mod tests {
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
let entities = match results {
|
||||
StrategyOutput::Entities(items) => items,
|
||||
other => panic!("expected entity results, got {:?}", other),
|
||||
};
|
||||
|
||||
assert!(
|
||||
!results.is_empty(),
|
||||
!entities.is_empty(),
|
||||
"Expected at least one retrieval result"
|
||||
);
|
||||
let top = &results[0];
|
||||
let top = &entities[0];
|
||||
assert!(
|
||||
top.entity.name.contains("Rust"),
|
||||
"Expected Rust entity to be ranked first"
|
||||
@@ -242,8 +251,13 @@ mod tests {
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
let entities = match results {
|
||||
StrategyOutput::Entities(items) => items,
|
||||
other => panic!("expected entity results, got {:?}", other),
|
||||
};
|
||||
|
||||
let mut neighbor_entry = None;
|
||||
for entity in &results {
|
||||
for entity in &entities {
|
||||
if entity.entity.id == neighbor.id {
|
||||
neighbor_entry = Some(entity.clone());
|
||||
}
|
||||
@@ -264,4 +278,59 @@ mod tests {
|
||||
"Neighbor entity should surface its own supporting chunks"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_revised_strategy_returns_chunks() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "chunk_user";
|
||||
let chunk_one = TextChunk::new(
|
||||
"src_alpha".into(),
|
||||
"Tokio tasks execute on worker threads managed by the runtime.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk_two = TextChunk::new(
|
||||
"src_beta".into(),
|
||||
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
|
||||
chunk_embedding_secondary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(chunk_one.clone())
|
||||
.await
|
||||
.expect("Failed to store chunk one");
|
||||
db.store_item(chunk_two.clone())
|
||||
.await
|
||||
.expect("Failed to store chunk two");
|
||||
|
||||
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised);
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
test_embedding(),
|
||||
"tokio runtime worker behavior",
|
||||
user_id,
|
||||
config,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Revised retrieval failed");
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk output, got {:?}", other),
|
||||
};
|
||||
|
||||
assert!(
|
||||
!chunks.is_empty(),
|
||||
"Revised strategy should return chunk-only responses"
|
||||
);
|
||||
assert!(
|
||||
chunks
|
||||
.iter()
|
||||
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
||||
"Chunk results should contain relevant snippets"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,40 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
Initial,
|
||||
Revised,
|
||||
}
|
||||
|
||||
impl Default for RetrievalStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Initial
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RetrievalStrategy {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"initial" => Ok(Self::Initial),
|
||||
"revised" => Ok(Self::Revised),
|
||||
other => Err(format!("unknown retrieval strategy '{other}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RetrievalStrategy {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let label = match self {
|
||||
RetrievalStrategy::Initial => "initial",
|
||||
RetrievalStrategy::Revised => "revised",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -51,18 +87,34 @@ impl Default for RetrievalTuning {
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||
Self { tuning }
|
||||
Self {
|
||||
strategy: RetrievalStrategy::Initial,
|
||||
tuning,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
tuning: RetrievalTuning::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||
Self { strategy, tuning }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::default(),
|
||||
tuning: RetrievalTuning::default(),
|
||||
}
|
||||
}
|
||||
397
retrieval-pipeline/src/pipeline/mod.rs
Normal file
397
retrieval-pipeline/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
||||
mod config;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use stages::PipelineContext;
|
||||
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
}
|
||||
|
||||
impl StrategyOutput {
|
||||
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StageKind {
|
||||
Embed,
|
||||
CollectCandidates,
|
||||
GraphExpansion,
|
||||
ChunkAttach,
|
||||
Rerank,
|
||||
Assemble,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub embed_ms: u128,
|
||||
pub collect_candidates_ms: u128,
|
||||
pub graph_expansion_ms: u128,
|
||||
pub chunk_attach_ms: u128,
|
||||
pub rerank_ms: u128,
|
||||
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
let elapsed = duration.as_millis() as u128;
|
||||
match kind {
|
||||
StageKind::Embed => self.embed_ms += elapsed,
|
||||
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
|
||||
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
|
||||
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
|
||||
StageKind::Rerank => self.rerank_ms += elapsed,
|
||||
StageKind::Assemble => self.assemble_ms += elapsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait PipelineStage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
|
||||
|
||||
pub trait StrategyDriver {
|
||||
type Output;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy;
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn override_tuning(&self, _config: &mut RetrievalConfig) {}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
});
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
return Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
});
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
/// Serializes retrieved entities — with their attached chunks and scores
/// rounded to three decimals — into a JSON array of
/// `{"KnowledgeEntity": {...}}` objects.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
    let items: Vec<serde_json::Value> = entities
        .iter()
        .map(|entry| {
            let chunks: Vec<serde_json::Value> = entry
                .chunks
                .iter()
                .map(|chunk| {
                    serde_json::json!({
                        "score": round_score(chunk.score),
                        "content": chunk.chunk.chunk
                    })
                })
                .collect();
            serde_json::json!({
                "KnowledgeEntity": {
                    "id": entry.entity.id,
                    "name": entry.entity.name,
                    "description": entry.entity.description,
                    "score": round_score(entry.score),
                    "chunks": chunks
                }
            })
        })
        .collect();
    serde_json::Value::Array(items)
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
mut config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
driver.override_tuning(&mut config);
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
None => PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
}
|
||||
|
||||
async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
for stage in driver.stages() {
|
||||
let start = Instant::now();
|
||||
stage.execute(&mut ctx).await?;
|
||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||
}
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx)?;
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
/// Rounds a score to three decimal places, widening to `f64` so the
/// serialized JSON shows a short, stable representation.
fn round_score(value: f32) -> f64 {
    const SCALE: f64 = 1000.0;
    let scaled = f64::from(value) * SCALE;
    scaled.round() / SCALE
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
@@ -9,11 +10,9 @@ use common::{
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use state_machines::core::GuardError;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::{HashMap, HashSet},
|
||||
time::Instant,
|
||||
};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
@@ -25,21 +24,20 @@ use crate::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
},
|
||||
vector::find_items_by_vector_similarity_with_embedding,
|
||||
vector::{
|
||||
find_chunk_snippets_by_vector_similarity_with_embedding,
|
||||
find_items_by_vector_similarity_with_embedding, ChunkSnippet,
|
||||
},
|
||||
RetrievedChunk, RetrievedEntity,
|
||||
};
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
config::{RetrievalConfig, RetrievalTuning},
|
||||
diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
},
|
||||
state::{
|
||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||
Reranked,
|
||||
},
|
||||
PipelineStageTimings,
|
||||
PipelineStage, PipelineStageTimings, StageKind,
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
@@ -53,8 +51,11 @@ pub struct PipelineContext<'a> {
|
||||
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub revised_chunk_values: Vec<Scored<ChunkSnippet>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
pub chunk_results: Vec<RetrievedChunk>,
|
||||
stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
@@ -78,8 +79,11 @@ impl<'a> PipelineContext<'a> {
|
||||
chunk_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
revised_chunk_values: Vec::new(),
|
||||
reranker,
|
||||
diagnostics: None,
|
||||
entity_results: Vec::new(),
|
||||
chunk_results: Vec::new(),
|
||||
stage_timings: PipelineStageTimings::default(),
|
||||
}
|
||||
}
|
||||
@@ -145,36 +149,151 @@ impl<'a> PipelineContext<'a> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_collect_candidates(duration);
|
||||
}
|
||||
|
||||
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_graph_expansion(duration);
|
||||
}
|
||||
|
||||
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_chunk_attach(duration);
|
||||
}
|
||||
|
||||
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_rerank(duration);
|
||||
}
|
||||
|
||||
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
|
||||
self.stage_timings.record_assemble(duration);
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
|
||||
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
|
||||
self.stage_timings.record(kind, duration);
|
||||
}
|
||||
|
||||
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
|
||||
std::mem::take(&mut self.entity_results)
|
||||
}
|
||||
|
||||
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
|
||||
std::mem::take(&mut self.chunk_results)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct EmbedStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for EmbedStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Embed
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
embed(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CollectCandidatesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for CollectCandidatesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
collect_candidates(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GraphExpansionStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for GraphExpansionStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::GraphExpansion
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
expand_graph(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkAttachStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkAttachStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::ChunkAttach
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
attach_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for RerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
rerank(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AssembleEntitiesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for AssembleEntitiesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
assemble(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkVectorStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkVectorStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
collect_vector_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkRerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkRerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
rerank_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkAssembleStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkAssembleStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
assemble_chunks(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn embed(
|
||||
machine: HybridRetrievalMachine<(), Ready>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
|
||||
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let embedding_cached = ctx.query_embedding.is_some();
|
||||
if embedding_cached {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
@@ -185,17 +304,11 @@ pub async fn embed(
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
machine
|
||||
.embed()
|
||||
.map_err(|(_, guard)| map_guard_error("embed", guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_candidates(
|
||||
machine: HybridRetrievalMachine<(), Embedded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
@@ -265,104 +378,80 @@ pub async fn collect_candidates(
|
||||
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||
|
||||
let next = machine
|
||||
.collect_candidates()
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
|
||||
ctx.record_collect_candidates_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn expand_graph(
|
||||
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Expanding candidates using graph relationships");
|
||||
let next = {
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
} else {
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if graph_seeds.is_empty() {
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
} else {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
let limit = tuning.graph_neighbor_limit;
|
||||
futures.push(async move {
|
||||
let neighbors =
|
||||
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if graph_seeds.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
let limit = tuning.graph_neighbor_limit;
|
||||
futures.push(async move {
|
||||
let neighbors = find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector =
|
||||
clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
}?;
|
||||
ctx.record_graph_expansion_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn attach_chunks(
|
||||
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Attaching chunks to surviving entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
@@ -438,19 +527,11 @@ pub async fn attach_chunks(
|
||||
|
||||
ctx.chunk_values = chunk_values;
|
||||
|
||||
let next = machine
|
||||
.attach_chunks()
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
|
||||
ctx.record_chunk_attach_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn rerank(
|
||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let mut applied = false;
|
||||
|
||||
if let Some(reranker) = ctx.reranker.as_ref() {
|
||||
@@ -490,19 +571,124 @@ pub async fn rerank(
|
||||
debug!("Applied reranking adjustments to candidate ordering");
|
||||
}
|
||||
|
||||
let next = machine
|
||||
.rerank()
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
|
||||
ctx.record_rerank_timing(stage_start.elapsed());
|
||||
Ok(next)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble(
|
||||
machine: HybridRetrievalMachine<(), Reranked>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let stage_start = Instant::now();
|
||||
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting vector chunk candidates for revised strategy");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_collect_candidates(CollectCandidatesStats {
|
||||
vector_entity_candidates: 0,
|
||||
vector_chunk_candidates: vector_chunks.len(),
|
||||
fts_entity_candidates: 0,
|
||||
fts_chunk_candidates: 0,
|
||||
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
||||
chunk.scores.vector.unwrap_or(0.0)
|
||||
}),
|
||||
fts_chunk_scores: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||
ctx.revised_chunk_values = vector_chunks;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.revised_chunk_values.len() <= 1 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(reranker) = ctx.reranker.as_ref() else {
|
||||
debug!("No reranker lease provided; skipping chunk rerank stage");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let documents = build_snippet_rerank_documents(
|
||||
&ctx.revised_chunk_values,
|
||||
ctx.config.tuning.rerank_keep_top.max(1),
|
||||
);
|
||||
if documents.len() <= 1 {
|
||||
debug!("Skipping chunk reranking stage; insufficient chunk documents");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match reranker.rerank(&ctx.input_text, documents).await {
|
||||
Ok(results) if !results.is_empty() => {
|
||||
apply_snippet_rerank_results(
|
||||
&mut ctx.revised_chunk_values,
|
||||
&ctx.config.tuning,
|
||||
results,
|
||||
);
|
||||
}
|
||||
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
|
||||
Err(err) => warn!(
|
||||
error = %err,
|
||||
"Chunk reranking failed; continuing with original ordering"
|
||||
),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Assembling chunk-only retrieval results");
|
||||
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
|
||||
let question_terms = extract_keywords(&ctx.input_text);
|
||||
rank_snippet_chunks_by_combined_score(
|
||||
&mut chunk_values,
|
||||
&question_terms,
|
||||
ctx.config.tuning.lexical_match_weight,
|
||||
);
|
||||
|
||||
let limit = ctx.config.tuning.chunk_vector_take.max(1);
|
||||
if chunk_values.len() > limit {
|
||||
chunk_values.truncate(limit);
|
||||
}
|
||||
|
||||
ctx.chunk_results = chunk_values
|
||||
.into_iter()
|
||||
.map(|chunk| {
|
||||
let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id);
|
||||
RetrievedChunk {
|
||||
chunk: text_chunk,
|
||||
score: chunk.fused,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_assemble(AssembleStats {
|
||||
token_budget_start: ctx.config.tuning.token_budget_estimate,
|
||||
token_budget_spent: 0,
|
||||
token_budget_remaining: ctx.config.tuning.token_budget_estimate,
|
||||
budget_exhausted: false,
|
||||
chunks_selected: ctx.chunk_results.len(),
|
||||
chunks_skipped_due_budget: 0,
|
||||
entity_count: 0,
|
||||
entity_traces: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
@@ -610,11 +796,8 @@ pub fn assemble(
|
||||
});
|
||||
}
|
||||
|
||||
machine
|
||||
.assemble()
|
||||
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
||||
ctx.record_assemble_timing(stage_start.elapsed());
|
||||
Ok(results)
|
||||
ctx.entity_results = results;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
@@ -630,12 +813,6 @@ where
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
||||
AppError::InternalError(format!(
|
||||
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
||||
err.guard, err.event, err.kind
|
||||
))
|
||||
}
|
||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||
let raw_scores: Vec<f32> = results
|
||||
.iter()
|
||||
@@ -873,6 +1050,23 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build_snippet_rerank_documents(
|
||||
chunks: &[Scored<ChunkSnippet>],
|
||||
max_chunks: usize,
|
||||
) -> Vec<String> {
|
||||
chunks
|
||||
.iter()
|
||||
.take(max_chunks)
|
||||
.map(|chunk| {
|
||||
format!(
|
||||
"Source: {}\nChunk:\n{}",
|
||||
chunk.item.source_id,
|
||||
chunk.item.chunk.trim()
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
|
||||
if results.is_empty() || ctx.filtered_entities.is_empty() {
|
||||
return;
|
||||
@@ -930,6 +1124,66 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_snippet_rerank_results(
|
||||
chunks: &mut Vec<Scored<ChunkSnippet>>,
|
||||
tuning: &RetrievalTuning,
|
||||
results: Vec<RerankResult>,
|
||||
) {
|
||||
if results.is_empty() || chunks.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<ChunkSnippet>>> =
|
||||
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = tuning.rerank_scores_only;
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
clamp_unit(tuning.rerank_blend_weight)
|
||||
};
|
||||
|
||||
let mut reranked: Vec<Scored<ChunkSnippet>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
let original = candidate.fused;
|
||||
let blended = if use_only {
|
||||
clamp_unit(normalized)
|
||||
} else {
|
||||
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||
};
|
||||
candidate.update_fused(blended);
|
||||
reranked.push(candidate);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
result_index = result.index,
|
||||
"Chunk reranker returned out-of-range index; skipping"
|
||||
);
|
||||
}
|
||||
if reranked.len() == remaining.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
let keep_top = tuning.rerank_keep_top;
|
||||
if keep_top > 0 && reranked.len() > keep_top {
|
||||
reranked.truncate(keep_top);
|
||||
}
|
||||
|
||||
*chunks = reranked;
|
||||
}
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
@@ -963,6 +1217,32 @@ fn extract_keywords(text: &str) -> Vec<String> {
|
||||
terms
|
||||
}
|
||||
|
||||
fn rank_snippet_chunks_by_combined_score(
|
||||
candidates: &mut [Scored<ChunkSnippet>],
|
||||
question_terms: &[String],
|
||||
lexical_weight: f32,
|
||||
) {
|
||||
if lexical_weight > 0.0 && !question_terms.is_empty() {
|
||||
for candidate in candidates.iter_mut() {
|
||||
let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk);
|
||||
let combined = clamp_unit(candidate.fused + lexical_weight * lexical);
|
||||
candidate.update_fused(combined);
|
||||
}
|
||||
}
|
||||
candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||
}
|
||||
|
||||
fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk {
|
||||
let mut chunk = TextChunk::new(
|
||||
snippet.source_id.clone(),
|
||||
snippet.chunk,
|
||||
Vec::new(),
|
||||
user_id.to_owned(),
|
||||
);
|
||||
chunk.id = snippet.id;
|
||||
chunk
|
||||
}
|
||||
|
||||
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
||||
if terms.is_empty() {
|
||||
return 0.0;
|
||||
75
retrieval-pipeline/src/pipeline/strategies.rs
Normal file
75
retrieval-pipeline/src/pipeline/strategies.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use super::{
|
||||
stages::{
|
||||
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage,
|
||||
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
|
||||
RerankStage,
|
||||
},
|
||||
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
|
||||
};
|
||||
use crate::{RetrievedChunk, RetrievedEntity};
|
||||
use common::error::AppError;
|
||||
|
||||
pub struct InitialStrategyDriver;
|
||||
|
||||
impl InitialStrategyDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for InitialStrategyDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy {
|
||||
RetrievalStrategy::Initial
|
||||
}
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
Box::new(ChunkAttachStage),
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RevisedStrategyDriver;
|
||||
|
||||
impl RevisedStrategyDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for RevisedStrategyDriver {
|
||||
type Output = Vec<RetrievedChunk>;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy {
|
||||
RetrievalStrategy::Revised
|
||||
}
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
|
||||
fn override_tuning(&self, config: &mut RetrievalConfig) {
|
||||
config.tuning.entity_vector_take = 0;
|
||||
config.tuning.entity_fts_take = 0;
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common::storage::types::file_info::deserialize_flexible_id;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::StoredObject},
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{file_info::deserialize_flexible_id, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
@@ -156,3 +158,61 @@ where
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChunkSnippet {
|
||||
pub id: String,
|
||||
pub source_id: String,
|
||||
pub chunk: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChunkDistanceRow {
|
||||
distance: Option<f32>,
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
pub source_id: String,
|
||||
pub chunk: String,
|
||||
}
|
||||
|
||||
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, source_id, chunk, vector::distance::knn() AS distance \
|
||||
FROM text_chunk \
|
||||
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||
LIMIT $limit",
|
||||
take = take,
|
||||
embedding = embedding_literal
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(closest_query)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let rows: Vec<ChunkDistanceRow> = response.take(0)?;
|
||||
|
||||
let mut scored = Vec::with_capacity(rows.len());
|
||||
for row in rows {
|
||||
let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
|
||||
scored.push(
|
||||
Scored::new(ChunkSnippet {
|
||||
id: row.id,
|
||||
source_id: row.source_id,
|
||||
chunk: row.chunk,
|
||||
})
|
||||
.with_vector_score(similarity),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
Reference in New Issue
Block a user