mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-24 09:48:32 +02:00
retrieval-pipeline: v0
This commit is contained in:
48
Cargo.lock
generated
48
Cargo.lock
generated
@@ -1457,26 +1457,6 @@ dependencies = [
|
|||||||
"static_assertions",
|
"static_assertions",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "composite-retrieval"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"async-openai",
|
|
||||||
"axum",
|
|
||||||
"common",
|
|
||||||
"fastembed",
|
|
||||||
"futures",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"state-machines",
|
|
||||||
"surrealdb",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
"tokio",
|
|
||||||
"tracing",
|
|
||||||
"uuid",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "compression-codecs"
|
name = "compression-codecs"
|
||||||
version = "0.4.30"
|
version = "0.4.30"
|
||||||
@@ -2197,7 +2177,6 @@ dependencies = [
|
|||||||
"async-trait",
|
"async-trait",
|
||||||
"chrono",
|
"chrono",
|
||||||
"common",
|
"common",
|
||||||
"composite-retrieval",
|
|
||||||
"criterion",
|
"criterion",
|
||||||
"fastembed",
|
"fastembed",
|
||||||
"futures",
|
"futures",
|
||||||
@@ -2205,6 +2184,7 @@ dependencies = [
|
|||||||
"object_store 0.11.2",
|
"object_store 0.11.2",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
|
"retrieval-pipeline",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
@@ -2880,7 +2860,6 @@ dependencies = [
|
|||||||
"chrono",
|
"chrono",
|
||||||
"chrono-tz",
|
"chrono-tz",
|
||||||
"common",
|
"common",
|
||||||
"composite-retrieval",
|
|
||||||
"futures",
|
"futures",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
"json-stream-parser",
|
"json-stream-parser",
|
||||||
@@ -2888,6 +2867,7 @@ dependencies = [
|
|||||||
"minijinja-autoreload",
|
"minijinja-autoreload",
|
||||||
"minijinja-contrib",
|
"minijinja-contrib",
|
||||||
"minijinja-embed",
|
"minijinja-embed",
|
||||||
|
"retrieval-pipeline",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"surrealdb",
|
"surrealdb",
|
||||||
@@ -3342,13 +3322,13 @@ dependencies = [
|
|||||||
"bytes",
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
"common",
|
"common",
|
||||||
"composite-retrieval",
|
|
||||||
"dom_smoothie",
|
"dom_smoothie",
|
||||||
"futures",
|
"futures",
|
||||||
"headless_chrome",
|
"headless_chrome",
|
||||||
"lopdf 0.32.0",
|
"lopdf 0.32.0",
|
||||||
"pdf-extract",
|
"pdf-extract",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"retrieval-pipeline",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"state-machines",
|
"state-machines",
|
||||||
@@ -3802,10 +3782,10 @@ dependencies = [
|
|||||||
"async-openai",
|
"async-openai",
|
||||||
"axum",
|
"axum",
|
||||||
"common",
|
"common",
|
||||||
"composite-retrieval",
|
|
||||||
"futures",
|
"futures",
|
||||||
"html-router",
|
"html-router",
|
||||||
"ingestion-pipeline",
|
"ingestion-pipeline",
|
||||||
|
"retrieval-pipeline",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"surrealdb",
|
"surrealdb",
|
||||||
@@ -5475,6 +5455,26 @@ dependencies = [
|
|||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "retrieval-pipeline"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"async-openai",
|
||||||
|
"async-trait",
|
||||||
|
"axum",
|
||||||
|
"common",
|
||||||
|
"fastembed",
|
||||||
|
"futures",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"surrealdb",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
"uuid",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "revision"
|
name = "revision"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ members = [
|
|||||||
"api-router",
|
"api-router",
|
||||||
"html-router",
|
"html-router",
|
||||||
"ingestion-pipeline",
|
"ingestion-pipeline",
|
||||||
"composite-retrieval",
|
"retrieval-pipeline",
|
||||||
"json-stream-parser",
|
"json-stream-parser",
|
||||||
"eval"
|
"eval"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ pub struct AppConfig {
|
|||||||
pub fastembed_show_download_progress: Option<bool>,
|
pub fastembed_show_download_progress: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub fastembed_max_length: Option<usize>,
|
pub fastembed_max_length: Option<usize>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub retrieval_strategy: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_data_dir() -> String {
|
fn default_data_dir() -> String {
|
||||||
@@ -117,6 +119,7 @@ impl Default for AppConfig {
|
|||||||
fastembed_cache_dir: None,
|
fastembed_cache_dir: None,
|
||||||
fastembed_show_download_progress: None,
|
fastembed_show_download_progress: None,
|
||||||
fastembed_max_length: None,
|
fastembed_max_length: None,
|
||||||
|
retrieval_strategy: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,212 +0,0 @@
|
|||||||
mod config;
|
|
||||||
mod diagnostics;
|
|
||||||
mod stages;
|
|
||||||
mod state;
|
|
||||||
|
|
||||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
|
||||||
pub use diagnostics::{
|
|
||||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
|
||||||
PipelineDiagnostics,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{reranking::RerankerLease, RetrievedEntity};
|
|
||||||
use async_openai::Client;
|
|
||||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct PipelineRunOutput {
|
|
||||||
pub results: Vec<RetrievedEntity>,
|
|
||||||
pub diagnostics: Option<PipelineDiagnostics>,
|
|
||||||
pub stage_timings: PipelineStageTimings,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
|
||||||
pub struct PipelineStageTimings {
|
|
||||||
pub collect_candidates_ms: u128,
|
|
||||||
pub graph_expansion_ms: u128,
|
|
||||||
pub chunk_attach_ms: u128,
|
|
||||||
pub rerank_ms: u128,
|
|
||||||
pub assemble_ms: u128,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PipelineStageTimings {
|
|
||||||
fn record_collect_candidates(&mut self, duration: std::time::Duration) {
|
|
||||||
self.collect_candidates_ms += duration.as_millis() as u128;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_graph_expansion(&mut self, duration: std::time::Duration) {
|
|
||||||
self.graph_expansion_ms += duration.as_millis() as u128;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_chunk_attach(&mut self, duration: std::time::Duration) {
|
|
||||||
self.chunk_attach_ms += duration.as_millis() as u128;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_rerank(&mut self, duration: std::time::Duration) {
|
|
||||||
self.rerank_ms += duration.as_millis() as u128;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_assemble(&mut self, duration: std::time::Duration) {
|
|
||||||
self.assemble_ms += duration.as_millis() as u128;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Drives the retrieval pipeline from embedding through final assembly.
|
|
||||||
pub async fn run_pipeline(
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
|
||||||
input_text: &str,
|
|
||||||
user_id: &str,
|
|
||||||
config: RetrievalConfig,
|
|
||||||
reranker: Option<RerankerLease>,
|
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
|
||||||
let input_chars = input_text.chars().count();
|
|
||||||
let input_preview: String = input_text.chars().take(120).collect();
|
|
||||||
let input_preview_clean = input_preview.replace('\n', " ");
|
|
||||||
let preview_len = input_preview_clean.chars().count();
|
|
||||||
info!(
|
|
||||||
%user_id,
|
|
||||||
input_chars,
|
|
||||||
preview_truncated = input_chars > preview_len,
|
|
||||||
preview = %input_preview_clean,
|
|
||||||
"Starting ingestion retrieval pipeline"
|
|
||||||
);
|
|
||||||
let ctx = stages::PipelineContext::new(
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
input_text.to_owned(),
|
|
||||||
user_id.to_owned(),
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
);
|
|
||||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
|
||||||
|
|
||||||
Ok(outcome.results)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding(
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
|
||||||
query_embedding: Vec<f32>,
|
|
||||||
input_text: &str,
|
|
||||||
user_id: &str,
|
|
||||||
config: RetrievalConfig,
|
|
||||||
reranker: Option<RerankerLease>,
|
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
|
||||||
let ctx = stages::PipelineContext::with_embedding(
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
query_embedding,
|
|
||||||
input_text.to_owned(),
|
|
||||||
user_id.to_owned(),
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
);
|
|
||||||
let outcome = run_pipeline_internal(ctx, false).await?;
|
|
||||||
|
|
||||||
Ok(outcome.results)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Runs the pipeline with a precomputed embedding and returns stage metrics.
|
|
||||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
|
||||||
query_embedding: Vec<f32>,
|
|
||||||
input_text: &str,
|
|
||||||
user_id: &str,
|
|
||||||
config: RetrievalConfig,
|
|
||||||
reranker: Option<RerankerLease>,
|
|
||||||
) -> Result<PipelineRunOutput, AppError> {
|
|
||||||
let ctx = stages::PipelineContext::with_embedding(
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
query_embedding,
|
|
||||||
input_text.to_owned(),
|
|
||||||
user_id.to_owned(),
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
);
|
|
||||||
|
|
||||||
run_pipeline_internal(ctx, false).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
|
||||||
db_client: &SurrealDbClient,
|
|
||||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
|
||||||
query_embedding: Vec<f32>,
|
|
||||||
input_text: &str,
|
|
||||||
user_id: &str,
|
|
||||||
config: RetrievalConfig,
|
|
||||||
reranker: Option<RerankerLease>,
|
|
||||||
) -> Result<PipelineRunOutput, AppError> {
|
|
||||||
let ctx = stages::PipelineContext::with_embedding(
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
query_embedding,
|
|
||||||
input_text.to_owned(),
|
|
||||||
user_id.to_owned(),
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
);
|
|
||||||
|
|
||||||
run_pipeline_internal(ctx, true).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
|
||||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
|
||||||
serde_json::json!(entities
|
|
||||||
.iter()
|
|
||||||
.map(|entry| {
|
|
||||||
serde_json::json!({
|
|
||||||
"KnowledgeEntity": {
|
|
||||||
"id": entry.entity.id,
|
|
||||||
"name": entry.entity.name,
|
|
||||||
"description": entry.entity.description,
|
|
||||||
"score": round_score(entry.score),
|
|
||||||
"chunks": entry.chunks.iter().map(|chunk| {
|
|
||||||
serde_json::json!({
|
|
||||||
"score": round_score(chunk.score),
|
|
||||||
"content": chunk.chunk.chunk
|
|
||||||
})
|
|
||||||
}).collect::<Vec<_>>()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_pipeline_internal(
|
|
||||||
mut ctx: stages::PipelineContext<'_>,
|
|
||||||
capture_diagnostics: bool,
|
|
||||||
) -> Result<PipelineRunOutput, AppError> {
|
|
||||||
if capture_diagnostics {
|
|
||||||
ctx.enable_diagnostics();
|
|
||||||
}
|
|
||||||
|
|
||||||
let results = drive_pipeline(&mut ctx).await?;
|
|
||||||
let diagnostics = ctx.take_diagnostics();
|
|
||||||
|
|
||||||
Ok(PipelineRunOutput {
|
|
||||||
results,
|
|
||||||
diagnostics,
|
|
||||||
stage_timings: ctx.take_stage_timings(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn drive_pipeline(
|
|
||||||
ctx: &mut stages::PipelineContext<'_>,
|
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
|
||||||
let machine = state::ready();
|
|
||||||
let machine = stages::embed(machine, ctx).await?;
|
|
||||||
let machine = stages::collect_candidates(machine, ctx).await?;
|
|
||||||
let machine = stages::expand_graph(machine, ctx).await?;
|
|
||||||
let machine = stages::attach_chunks(machine, ctx).await?;
|
|
||||||
let machine = stages::rerank(machine, ctx).await?;
|
|
||||||
let results = stages::assemble(machine, ctx)?;
|
|
||||||
Ok(results)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn round_score(value: f32) -> f64 {
|
|
||||||
(f64::from(value) * 1000.0).round() / 1000.0
|
|
||||||
}
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
use state_machines::state_machine;
|
|
||||||
|
|
||||||
state_machine! {
|
|
||||||
name: HybridRetrievalMachine,
|
|
||||||
state: HybridRetrievalState,
|
|
||||||
initial: Ready,
|
|
||||||
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
|
|
||||||
events {
|
|
||||||
embed { transition: { from: Ready, to: Embedded } }
|
|
||||||
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
|
|
||||||
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
|
|
||||||
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
|
|
||||||
rerank { transition: { from: ChunksAttached, to: Reranked } }
|
|
||||||
assemble { transition: { from: Reranked, to: Completed } }
|
|
||||||
abort {
|
|
||||||
transition: { from: Ready, to: Failed }
|
|
||||||
transition: { from: CandidatesLoaded, to: Failed }
|
|
||||||
transition: { from: GraphExpanded, to: Failed }
|
|
||||||
transition: { from: ChunksAttached, to: Failed }
|
|
||||||
transition: { from: Reranked, to: Failed }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
|
|
||||||
HybridRetrievalMachine::new(())
|
|
||||||
}
|
|
||||||
@@ -38,7 +38,7 @@ url = { workspace = true }
|
|||||||
uuid = { workspace = true }
|
uuid = { workspace = true }
|
||||||
|
|
||||||
common = { path = "../common" }
|
common = { path = "../common" }
|
||||||
composite-retrieval = { path = "../composite-retrieval" }
|
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||||
json-stream-parser = { path = "../json-stream-parser" }
|
json-stream-parser = { path = "../json-stream-parser" }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||||
use composite_retrieval::reranking::RerankerPool;
|
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
@@ -40,6 +40,14 @@ impl HtmlState {
|
|||||||
reranker_pool,
|
reranker_pool,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
|
||||||
|
self.config
|
||||||
|
.retrieval_strategy
|
||||||
|
.as_deref()
|
||||||
|
.and_then(|value| value.parse().ok())
|
||||||
|
.unwrap_or(RetrievalStrategy::Initial)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
impl ProvidesDb for HtmlState {
|
impl ProvidesDb for HtmlState {
|
||||||
fn db(&self) -> &Arc<SurrealDbClient> {
|
fn db(&self) -> &Arc<SurrealDbClient> {
|
||||||
|
|||||||
@@ -8,16 +8,16 @@ use axum::{
|
|||||||
Sse,
|
Sse,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use composite_retrieval::{
|
|
||||||
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
|
||||||
retrieve_entities, retrieved_entities_to_json,
|
|
||||||
};
|
|
||||||
use futures::{
|
use futures::{
|
||||||
stream::{self, once},
|
stream::{self, once},
|
||||||
Stream, StreamExt, TryStreamExt,
|
Stream, StreamExt, TryStreamExt,
|
||||||
};
|
};
|
||||||
use json_stream_parser::JsonStreamParser;
|
use json_stream_parser::JsonStreamParser;
|
||||||
use minijinja::Value;
|
use minijinja::Value;
|
||||||
|
use retrieval_pipeline::{
|
||||||
|
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||||
|
retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput,
|
||||||
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
use tokio::sync::{mpsc::channel, Mutex};
|
use tokio::sync::{mpsc::channel, Mutex};
|
||||||
@@ -123,16 +123,24 @@ pub async fn get_response_stream(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut retrieval_config = RetrievalConfig::default();
|
||||||
|
retrieval_config.strategy = state.retrieval_strategy();
|
||||||
let entities = match retrieve_entities(
|
let entities = match retrieve_entities(
|
||||||
&state.db,
|
&state.db,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
&user_message.content,
|
&user_message.content,
|
||||||
&user.id,
|
&user.id,
|
||||||
|
retrieval_config,
|
||||||
rerank_lease,
|
rerank_lease,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(entities) => entities,
|
Ok(StrategyOutput::Entities(entities)) => entities,
|
||||||
|
Ok(StrategyOutput::Chunks(_)) => {
|
||||||
|
return Sse::new(create_error_stream(
|
||||||
|
"Chunk-only retrieval results are not supported in this route",
|
||||||
|
))
|
||||||
|
}
|
||||||
Err(_e) => {
|
Err(_e) => {
|
||||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ use common::{
|
|||||||
},
|
},
|
||||||
utils::embedding::generate_embedding,
|
utils::embedding::generate_embedding,
|
||||||
};
|
};
|
||||||
use composite_retrieval::{retrieve_entities, RetrievedEntity};
|
use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -284,11 +284,15 @@ pub async fn suggest_knowledge_relationships(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Ok(results) = retrieve_entities(
|
let mut retrieval_config = RetrievalConfig::default();
|
||||||
|
retrieval_config.strategy = state.retrieval_strategy();
|
||||||
|
|
||||||
|
if let Ok(StrategyOutput::Entities(results)) = retrieve_entities(
|
||||||
&state.db,
|
&state.db,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
&query,
|
&query,
|
||||||
&user.id,
|
&user.id,
|
||||||
|
retrieval_config,
|
||||||
rerank_lease,
|
rerank_lease,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ lopdf = "0.32"
|
|||||||
bytes = { workspace = true }
|
bytes = { workspace = true }
|
||||||
|
|
||||||
common = { path = "../common" }
|
common = { path = "../common" }
|
||||||
composite-retrieval = { path = "../composite-retrieval" }
|
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
state-machines = { workspace = true }
|
state-machines = { workspace = true }
|
||||||
[features]
|
[features]
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ use common::{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use composite_retrieval::RetrievedEntity;
|
use retrieval_pipeline::RetrievedEntity;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use super::enrichment_result::LLMEnrichmentResult;
|
use super::enrichment_result::LLMEnrichmentResult;
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ use common::{
|
|||||||
},
|
},
|
||||||
utils::config::AppConfig,
|
utils::config::AppConfig,
|
||||||
};
|
};
|
||||||
use composite_retrieval::reranking::RerankerPool;
|
use retrieval_pipeline::reranking::RerankerPool;
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
use self::{
|
use self::{
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ use common::{
|
|||||||
},
|
},
|
||||||
utils::{config::AppConfig, embedding::generate_embedding},
|
utils::{config::AppConfig, embedding::generate_embedding},
|
||||||
};
|
};
|
||||||
use composite_retrieval::{
|
use retrieval_pipeline::{
|
||||||
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievedEntity,
|
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig,
|
||||||
|
RetrievalStrategy, RetrievedEntity, StrategyOutput,
|
||||||
};
|
};
|
||||||
use text_splitter::TextSplitter;
|
use text_splitter::TextSplitter;
|
||||||
|
|
||||||
@@ -124,6 +125,14 @@ impl DefaultPipelineServices {
|
|||||||
Ok(request)
|
Ok(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn configured_strategy(&self) -> RetrievalStrategy {
|
||||||
|
self.config
|
||||||
|
.retrieval_strategy
|
||||||
|
.as_deref()
|
||||||
|
.and_then(|value| value.parse().ok())
|
||||||
|
.unwrap_or(RetrievalStrategy::Initial)
|
||||||
|
}
|
||||||
|
|
||||||
async fn perform_analysis(
|
async fn perform_analysis(
|
||||||
&self,
|
&self,
|
||||||
request: CreateChatCompletionRequest,
|
request: CreateChatCompletionRequest,
|
||||||
@@ -178,14 +187,24 @@ impl PipelineServices for DefaultPipelineServices {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
retrieve_entities(
|
let mut config = RetrievalConfig::default();
|
||||||
|
config.strategy = self.configured_strategy();
|
||||||
|
match retrieve_entities(
|
||||||
&self.db,
|
&self.db,
|
||||||
&self.openai_client,
|
&self.openai_client,
|
||||||
&input_text,
|
&input_text,
|
||||||
&content.user_id,
|
&content.user_id,
|
||||||
|
config,
|
||||||
rerank_lease,
|
rerank_lease,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
{
|
||||||
|
Ok(StrategyOutput::Entities(entities)) => Ok(entities),
|
||||||
|
Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||||
|
"Chunk-only retrieval is not supported in ingestion".into(),
|
||||||
|
)),
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_enrichment(
|
async fn run_enrichment(
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ use common::{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use composite_retrieval::{RetrievedChunk, RetrievedEntity};
|
use retrieval_pipeline::{RetrievedChunk, RetrievedEntity};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ ingestion-pipeline = { path = "../ingestion-pipeline" }
|
|||||||
api-router = { path = "../api-router" }
|
api-router = { path = "../api-router" }
|
||||||
html-router = { path = "../html-router" }
|
html-router = { path = "../html-router" }
|
||||||
common = { path = "../common" }
|
common = { path = "../common" }
|
||||||
composite-retrieval = { path = "../composite-retrieval" }
|
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tower = "0.5"
|
tower = "0.5"
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ use axum::{extract::FromRef, Router};
|
|||||||
use common::{
|
use common::{
|
||||||
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
||||||
};
|
};
|
||||||
use composite_retrieval::reranking::RerankerPool;
|
|
||||||
use html_router::{html_routes, html_state::HtmlState};
|
use html_router::{html_routes, html_state::HtmlState};
|
||||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||||
|
use retrieval_pipeline::reranking::RerankerPool;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ use axum::{extract::FromRef, Router};
|
|||||||
use common::{
|
use common::{
|
||||||
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
||||||
};
|
};
|
||||||
use composite_retrieval::reranking::RerankerPool;
|
|
||||||
use html_router::{html_routes, html_state::HtmlState};
|
use html_router::{html_routes, html_state::HtmlState};
|
||||||
|
use retrieval_pipeline::reranking::RerankerPool;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ use std::sync::Arc;
|
|||||||
use common::{
|
use common::{
|
||||||
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config,
|
||||||
};
|
};
|
||||||
use composite_retrieval::reranking::RerankerPool;
|
|
||||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||||
|
use retrieval_pipeline::reranking::RerankerPool;
|
||||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "composite-retrieval"
|
name = "retrieval-pipeline"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "AGPL-3.0-or-later"
|
license = "AGPL-3.0-or-later"
|
||||||
@@ -18,8 +18,8 @@ serde_json = { workspace = true }
|
|||||||
surrealdb = { workspace = true }
|
surrealdb = { workspace = true }
|
||||||
futures = { workspace = true }
|
futures = { workspace = true }
|
||||||
async-openai = { workspace = true }
|
async-openai = { workspace = true }
|
||||||
|
async-trait = { workspace = true }
|
||||||
uuid = { workspace = true }
|
uuid = { workspace = true }
|
||||||
fastembed = { workspace = true }
|
fastembed = { workspace = true }
|
||||||
|
|
||||||
common = { path = "../common", features = ["test-utils"] }
|
common = { path = "../common", features = ["test-utils"] }
|
||||||
state-machines = { workspace = true }
|
|
||||||
@@ -17,7 +17,10 @@ use common::{
|
|||||||
use reranking::RerankerLease;
|
use reranking::RerankerLease;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
|
pub use pipeline::{
|
||||||
|
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||||
|
RetrievalStrategy, RetrievalTuning, StrategyOutput,
|
||||||
|
};
|
||||||
|
|
||||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -41,14 +44,15 @@ pub async fn retrieve_entities(
|
|||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
input_text: &str,
|
input_text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
reranker: Option<RerankerLease>,
|
reranker: Option<RerankerLease>,
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
) -> Result<StrategyOutput, AppError> {
|
||||||
pipeline::run_pipeline(
|
pipeline::run_pipeline(
|
||||||
db_client,
|
db_client,
|
||||||
openai_client,
|
openai_client,
|
||||||
input_text,
|
input_text,
|
||||||
user_id,
|
user_id,
|
||||||
RetrievalConfig::default(),
|
config,
|
||||||
reranker,
|
reranker,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -63,7 +67,7 @@ mod tests {
|
|||||||
knowledge_relationship::KnowledgeRelationship,
|
knowledge_relationship::KnowledgeRelationship,
|
||||||
text_chunk::TextChunk,
|
text_chunk::TextChunk,
|
||||||
};
|
};
|
||||||
use pipeline::RetrievalConfig;
|
use pipeline::{RetrievalConfig, RetrievalStrategy};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
fn test_embedding() -> Vec<f32> {
|
fn test_embedding() -> Vec<f32> {
|
||||||
@@ -151,11 +155,16 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("Hybrid retrieval failed");
|
.expect("Hybrid retrieval failed");
|
||||||
|
|
||||||
|
let entities = match results {
|
||||||
|
StrategyOutput::Entities(items) => items,
|
||||||
|
other => panic!("expected entity results, got {:?}", other),
|
||||||
|
};
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!results.is_empty(),
|
!entities.is_empty(),
|
||||||
"Expected at least one retrieval result"
|
"Expected at least one retrieval result"
|
||||||
);
|
);
|
||||||
let top = &results[0];
|
let top = &entities[0];
|
||||||
assert!(
|
assert!(
|
||||||
top.entity.name.contains("Rust"),
|
top.entity.name.contains("Rust"),
|
||||||
"Expected Rust entity to be ranked first"
|
"Expected Rust entity to be ranked first"
|
||||||
@@ -242,8 +251,13 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("Hybrid retrieval failed");
|
.expect("Hybrid retrieval failed");
|
||||||
|
|
||||||
|
let entities = match results {
|
||||||
|
StrategyOutput::Entities(items) => items,
|
||||||
|
other => panic!("expected entity results, got {:?}", other),
|
||||||
|
};
|
||||||
|
|
||||||
let mut neighbor_entry = None;
|
let mut neighbor_entry = None;
|
||||||
for entity in &results {
|
for entity in &entities {
|
||||||
if entity.entity.id == neighbor.id {
|
if entity.entity.id == neighbor.id {
|
||||||
neighbor_entry = Some(entity.clone());
|
neighbor_entry = Some(entity.clone());
|
||||||
}
|
}
|
||||||
@@ -264,4 +278,59 @@ mod tests {
|
|||||||
"Neighbor entity should surface its own supporting chunks"
|
"Neighbor entity should surface its own supporting chunks"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_revised_strategy_returns_chunks() {
|
||||||
|
let db = setup_test_db().await;
|
||||||
|
let user_id = "chunk_user";
|
||||||
|
let chunk_one = TextChunk::new(
|
||||||
|
"src_alpha".into(),
|
||||||
|
"Tokio tasks execute on worker threads managed by the runtime.".into(),
|
||||||
|
chunk_embedding_primary(),
|
||||||
|
user_id.into(),
|
||||||
|
);
|
||||||
|
let chunk_two = TextChunk::new(
|
||||||
|
"src_beta".into(),
|
||||||
|
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
|
||||||
|
chunk_embedding_secondary(),
|
||||||
|
user_id.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.store_item(chunk_one.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store chunk one");
|
||||||
|
db.store_item(chunk_two.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store chunk two");
|
||||||
|
|
||||||
|
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised);
|
||||||
|
let openai_client = Client::new();
|
||||||
|
let results = pipeline::run_pipeline_with_embedding(
|
||||||
|
&db,
|
||||||
|
&openai_client,
|
||||||
|
test_embedding(),
|
||||||
|
"tokio runtime worker behavior",
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("Revised retrieval failed");
|
||||||
|
|
||||||
|
let chunks = match results {
|
||||||
|
StrategyOutput::Chunks(items) => items,
|
||||||
|
other => panic!("expected chunk output, got {:?}", other),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!chunks.is_empty(),
|
||||||
|
"Revised strategy should return chunk-only responses"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
chunks
|
||||||
|
.iter()
|
||||||
|
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
||||||
|
"Chunk results should contain relevant snippets"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,40 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum RetrievalStrategy {
|
||||||
|
Initial,
|
||||||
|
Revised,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RetrievalStrategy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Initial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for RetrievalStrategy {
|
||||||
|
type Err = String;
|
||||||
|
|
||||||
|
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||||
|
match value.to_ascii_lowercase().as_str() {
|
||||||
|
"initial" => Ok(Self::Initial),
|
||||||
|
"revised" => Ok(Self::Revised),
|
||||||
|
other => Err(format!("unknown retrieval strategy '{other}'")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for RetrievalStrategy {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
let label = match self {
|
||||||
|
RetrievalStrategy::Initial => "initial",
|
||||||
|
RetrievalStrategy::Revised => "revised",
|
||||||
|
};
|
||||||
|
f.write_str(label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Tunable parameters that govern each retrieval stage.
|
/// Tunable parameters that govern each retrieval stage.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -51,18 +87,34 @@ impl Default for RetrievalTuning {
|
|||||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RetrievalConfig {
|
pub struct RetrievalConfig {
|
||||||
|
pub strategy: RetrievalStrategy,
|
||||||
pub tuning: RetrievalTuning,
|
pub tuning: RetrievalTuning,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RetrievalConfig {
|
impl RetrievalConfig {
|
||||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||||
Self { tuning }
|
Self {
|
||||||
|
strategy: RetrievalStrategy::Initial,
|
||||||
|
tuning,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
|
||||||
|
Self {
|
||||||
|
strategy,
|
||||||
|
tuning: RetrievalTuning::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||||
|
Self { strategy, tuning }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalConfig {
|
impl Default for RetrievalConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
strategy: RetrievalStrategy::default(),
|
||||||
tuning: RetrievalTuning::default(),
|
tuning: RetrievalTuning::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
397
retrieval-pipeline/src/pipeline/mod.rs
Normal file
397
retrieval-pipeline/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
mod config;
|
||||||
|
mod diagnostics;
|
||||||
|
mod stages;
|
||||||
|
mod strategies;
|
||||||
|
|
||||||
|
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
|
||||||
|
pub use diagnostics::{
|
||||||
|
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||||
|
PipelineDiagnostics,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
||||||
|
use async_openai::Client;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use stages::PipelineContext;
|
||||||
|
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum StrategyOutput {
|
||||||
|
Entities(Vec<RetrievedEntity>),
|
||||||
|
Chunks(Vec<RetrievedChunk>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StrategyOutput {
|
||||||
|
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
|
||||||
|
match self {
|
||||||
|
StrategyOutput::Entities(items) => Some(items),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
|
||||||
|
match self {
|
||||||
|
StrategyOutput::Entities(items) => Some(items),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
|
||||||
|
match self {
|
||||||
|
StrategyOutput::Chunks(items) => Some(items),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
|
||||||
|
match self {
|
||||||
|
StrategyOutput::Chunks(items) => Some(items),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct PipelineRunOutput<T> {
|
||||||
|
pub results: T,
|
||||||
|
pub diagnostics: Option<PipelineDiagnostics>,
|
||||||
|
pub stage_timings: PipelineStageTimings,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum StageKind {
|
||||||
|
Embed,
|
||||||
|
CollectCandidates,
|
||||||
|
GraphExpansion,
|
||||||
|
ChunkAttach,
|
||||||
|
Rerank,
|
||||||
|
Assemble,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||||
|
pub struct PipelineStageTimings {
|
||||||
|
pub embed_ms: u128,
|
||||||
|
pub collect_candidates_ms: u128,
|
||||||
|
pub graph_expansion_ms: u128,
|
||||||
|
pub chunk_attach_ms: u128,
|
||||||
|
pub rerank_ms: u128,
|
||||||
|
pub assemble_ms: u128,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PipelineStageTimings {
|
||||||
|
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||||
|
let elapsed = duration.as_millis() as u128;
|
||||||
|
match kind {
|
||||||
|
StageKind::Embed => self.embed_ms += elapsed,
|
||||||
|
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
|
||||||
|
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
|
||||||
|
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
|
||||||
|
StageKind::Rerank => self.rerank_ms += elapsed,
|
||||||
|
StageKind::Assemble => self.assemble_ms += elapsed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait PipelineStage: Send + Sync {
|
||||||
|
fn kind(&self) -> StageKind;
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
|
||||||
|
|
||||||
|
pub trait StrategyDriver {
|
||||||
|
type Output;
|
||||||
|
|
||||||
|
fn strategy(&self) -> RetrievalStrategy;
|
||||||
|
fn stages(&self) -> Vec<BoxedStage>;
|
||||||
|
fn override_tuning(&self, _config: &mut RetrievalConfig) {}
|
||||||
|
|
||||||
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_pipeline(
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
reranker: Option<RerankerLease>,
|
||||||
|
) -> Result<StrategyOutput, AppError> {
|
||||||
|
let input_chars = input_text.chars().count();
|
||||||
|
let input_preview: String = input_text.chars().take(120).collect();
|
||||||
|
let input_preview_clean = input_preview.replace('\n', " ");
|
||||||
|
let preview_len = input_preview_clean.chars().count();
|
||||||
|
info!(
|
||||||
|
%user_id,
|
||||||
|
input_chars,
|
||||||
|
preview_truncated = input_chars > preview_len,
|
||||||
|
preview = %input_preview_clean,
|
||||||
|
"Starting ingestion retrieval pipeline"
|
||||||
|
);
|
||||||
|
|
||||||
|
if config.strategy == RetrievalStrategy::Initial {
|
||||||
|
let driver = InitialStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
None,
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(StrategyOutput::Entities(run.results));
|
||||||
|
}
|
||||||
|
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
None,
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Chunks(run.results))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_pipeline_with_embedding(
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query_embedding: Vec<f32>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
reranker: Option<RerankerLease>,
|
||||||
|
) -> Result<StrategyOutput, AppError> {
|
||||||
|
if config.strategy == RetrievalStrategy::Initial {
|
||||||
|
let driver = InitialStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(StrategyOutput::Entities(run.results));
|
||||||
|
}
|
||||||
|
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Chunks(run.results))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query_embedding: Vec<f32>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
reranker: Option<RerankerLease>,
|
||||||
|
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||||
|
if config.strategy == RetrievalStrategy::Initial {
|
||||||
|
let driver = InitialStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(PipelineRunOutput {
|
||||||
|
results: StrategyOutput::Entities(run.results),
|
||||||
|
diagnostics: run.diagnostics,
|
||||||
|
stage_timings: run.stage_timings,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(PipelineRunOutput {
|
||||||
|
results: StrategyOutput::Chunks(run.results),
|
||||||
|
diagnostics: run.diagnostics,
|
||||||
|
stage_timings: run.stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query_embedding: Vec<f32>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
config: RetrievalConfig,
|
||||||
|
reranker: Option<RerankerLease>,
|
||||||
|
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||||
|
if config.strategy == RetrievalStrategy::Initial {
|
||||||
|
let driver = InitialStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(PipelineRunOutput {
|
||||||
|
results: StrategyOutput::Entities(run.results),
|
||||||
|
diagnostics: run.diagnostics,
|
||||||
|
stage_timings: run.stage_timings,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(PipelineRunOutput {
|
||||||
|
results: StrategyOutput::Chunks(run.results),
|
||||||
|
diagnostics: run.diagnostics,
|
||||||
|
stage_timings: run.stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||||
|
serde_json::json!(entities
|
||||||
|
.iter()
|
||||||
|
.map(|entry| {
|
||||||
|
serde_json::json!({
|
||||||
|
"KnowledgeEntity": {
|
||||||
|
"id": entry.entity.id,
|
||||||
|
"name": entry.entity.name,
|
||||||
|
"description": entry.entity.description,
|
||||||
|
"score": round_score(entry.score),
|
||||||
|
"chunks": entry.chunks.iter().map(|chunk| {
|
||||||
|
serde_json::json!({
|
||||||
|
"score": round_score(chunk.score),
|
||||||
|
"content": chunk.chunk.chunk
|
||||||
|
})
|
||||||
|
}).collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_strategy<D: StrategyDriver>(
|
||||||
|
driver: D,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query_embedding: Option<Vec<f32>>,
|
||||||
|
input_text: &str,
|
||||||
|
user_id: &str,
|
||||||
|
mut config: RetrievalConfig,
|
||||||
|
reranker: Option<RerankerLease>,
|
||||||
|
capture_diagnostics: bool,
|
||||||
|
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||||
|
driver.override_tuning(&mut config);
|
||||||
|
let ctx = match query_embedding {
|
||||||
|
Some(embedding) => PipelineContext::with_embedding(
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
embedding,
|
||||||
|
input_text.to_owned(),
|
||||||
|
user_id.to_owned(),
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
),
|
||||||
|
None => PipelineContext::new(
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
input_text.to_owned(),
|
||||||
|
user_id.to_owned(),
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_with_driver<D: StrategyDriver>(
|
||||||
|
driver: D,
|
||||||
|
mut ctx: PipelineContext<'_>,
|
||||||
|
capture_diagnostics: bool,
|
||||||
|
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||||
|
if capture_diagnostics {
|
||||||
|
ctx.enable_diagnostics();
|
||||||
|
}
|
||||||
|
|
||||||
|
for stage in driver.stages() {
|
||||||
|
let start = Instant::now();
|
||||||
|
stage.execute(&mut ctx).await?;
|
||||||
|
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||||
|
}
|
||||||
|
|
||||||
|
let diagnostics = ctx.take_diagnostics();
|
||||||
|
let stage_timings = ctx.take_stage_timings();
|
||||||
|
let results = driver.finalize(&mut ctx)?;
|
||||||
|
|
||||||
|
Ok(PipelineRunOutput {
|
||||||
|
results,
|
||||||
|
diagnostics,
|
||||||
|
stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn round_score(value: f32) -> f64 {
|
||||||
|
(f64::from(value) * 1000.0).round() / 1000.0
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
|
use async_trait::async_trait;
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{
|
storage::{
|
||||||
@@ -9,11 +10,9 @@ use common::{
|
|||||||
};
|
};
|
||||||
use fastembed::RerankResult;
|
use fastembed::RerankResult;
|
||||||
use futures::{stream::FuturesUnordered, StreamExt};
|
use futures::{stream::FuturesUnordered, StreamExt};
|
||||||
use state_machines::core::GuardError;
|
|
||||||
use std::{
|
use std::{
|
||||||
cmp::Ordering,
|
cmp::Ordering,
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
time::Instant,
|
|
||||||
};
|
};
|
||||||
use tracing::{debug, instrument, warn};
|
use tracing::{debug, instrument, warn};
|
||||||
|
|
||||||
@@ -25,21 +24,20 @@ use crate::{
|
|||||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||||
FusionWeights, Scored,
|
FusionWeights, Scored,
|
||||||
},
|
},
|
||||||
vector::find_items_by_vector_similarity_with_embedding,
|
vector::{
|
||||||
|
find_chunk_snippets_by_vector_similarity_with_embedding,
|
||||||
|
find_items_by_vector_similarity_with_embedding, ChunkSnippet,
|
||||||
|
},
|
||||||
RetrievedChunk, RetrievedEntity,
|
RetrievedChunk, RetrievedEntity,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
config::RetrievalConfig,
|
config::{RetrievalConfig, RetrievalTuning},
|
||||||
diagnostics::{
|
diagnostics::{
|
||||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||||
PipelineDiagnostics,
|
PipelineDiagnostics,
|
||||||
},
|
},
|
||||||
state::{
|
PipelineStage, PipelineStageTimings, StageKind,
|
||||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
|
||||||
Reranked,
|
|
||||||
},
|
|
||||||
PipelineStageTimings,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct PipelineContext<'a> {
|
pub struct PipelineContext<'a> {
|
||||||
@@ -53,8 +51,11 @@ pub struct PipelineContext<'a> {
|
|||||||
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||||
|
pub revised_chunk_values: Vec<Scored<ChunkSnippet>>,
|
||||||
pub reranker: Option<RerankerLease>,
|
pub reranker: Option<RerankerLease>,
|
||||||
pub diagnostics: Option<PipelineDiagnostics>,
|
pub diagnostics: Option<PipelineDiagnostics>,
|
||||||
|
pub entity_results: Vec<RetrievedEntity>,
|
||||||
|
pub chunk_results: Vec<RetrievedChunk>,
|
||||||
stage_timings: PipelineStageTimings,
|
stage_timings: PipelineStageTimings,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,8 +79,11 @@ impl<'a> PipelineContext<'a> {
|
|||||||
chunk_candidates: HashMap::new(),
|
chunk_candidates: HashMap::new(),
|
||||||
filtered_entities: Vec::new(),
|
filtered_entities: Vec::new(),
|
||||||
chunk_values: Vec::new(),
|
chunk_values: Vec::new(),
|
||||||
|
revised_chunk_values: Vec::new(),
|
||||||
reranker,
|
reranker,
|
||||||
diagnostics: None,
|
diagnostics: None,
|
||||||
|
entity_results: Vec::new(),
|
||||||
|
chunk_results: Vec::new(),
|
||||||
stage_timings: PipelineStageTimings::default(),
|
stage_timings: PipelineStageTimings::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -145,36 +149,151 @@ impl<'a> PipelineContext<'a> {
|
|||||||
self.diagnostics.take()
|
self.diagnostics.take()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn record_collect_candidates_timing(&mut self, duration: std::time::Duration) {
|
|
||||||
self.stage_timings.record_collect_candidates(duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_graph_expansion_timing(&mut self, duration: std::time::Duration) {
|
|
||||||
self.stage_timings.record_graph_expansion(duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_chunk_attach_timing(&mut self, duration: std::time::Duration) {
|
|
||||||
self.stage_timings.record_chunk_attach(duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_rerank_timing(&mut self, duration: std::time::Duration) {
|
|
||||||
self.stage_timings.record_rerank(duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_assemble_timing(&mut self, duration: std::time::Duration) {
|
|
||||||
self.stage_timings.record_assemble(duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||||
std::mem::take(&mut self.stage_timings)
|
std::mem::take(&mut self.stage_timings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
|
||||||
|
self.stage_timings.record(kind, duration);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
|
||||||
|
std::mem::take(&mut self.entity_results)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
|
||||||
|
std::mem::take(&mut self.chunk_results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct EmbedStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for EmbedStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Embed
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
embed(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct CollectCandidatesStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for CollectCandidatesStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::CollectCandidates
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
collect_candidates(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct GraphExpansionStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for GraphExpansionStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::GraphExpansion
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
expand_graph(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkAttachStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for ChunkAttachStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::ChunkAttach
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
attach_chunks(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct RerankStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for RerankStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Rerank
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
rerank(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct AssembleEntitiesStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for AssembleEntitiesStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Assemble
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
assemble(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkVectorStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for ChunkVectorStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::CollectCandidates
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
collect_vector_chunks(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkRerankStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for ChunkRerankStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Rerank
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
rerank_chunks(ctx).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChunkAssembleStage;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PipelineStage for ChunkAssembleStage {
|
||||||
|
fn kind(&self) -> StageKind {
|
||||||
|
StageKind::Assemble
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
assemble_chunks(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn embed(
|
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
machine: HybridRetrievalMachine<(), Ready>,
|
|
||||||
ctx: &mut PipelineContext<'_>,
|
|
||||||
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
|
|
||||||
let embedding_cached = ctx.query_embedding.is_some();
|
let embedding_cached = ctx.query_embedding.is_some();
|
||||||
if embedding_cached {
|
if embedding_cached {
|
||||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||||
@@ -185,17 +304,11 @@ pub async fn embed(
|
|||||||
ctx.query_embedding = Some(embedding);
|
ctx.query_embedding = Some(embedding);
|
||||||
}
|
}
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.embed()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("embed", guard))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn collect_candidates(
|
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
machine: HybridRetrievalMachine<(), Embedded>,
|
|
||||||
ctx: &mut PipelineContext<'_>,
|
|
||||||
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
|
||||||
let stage_start = Instant::now();
|
|
||||||
debug!("Collecting initial candidates via vector and FTS search");
|
debug!("Collecting initial candidates via vector and FTS search");
|
||||||
let embedding = ctx.ensure_embedding()?.clone();
|
let embedding = ctx.ensure_embedding()?.clone();
|
||||||
let tuning = &ctx.config.tuning;
|
let tuning = &ctx.config.tuning;
|
||||||
@@ -265,104 +378,80 @@ pub async fn collect_candidates(
|
|||||||
apply_fusion(&mut ctx.entity_candidates, weights);
|
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||||
apply_fusion(&mut ctx.chunk_candidates, weights);
|
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||||
|
|
||||||
let next = machine
|
Ok(())
|
||||||
.collect_candidates()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))?;
|
|
||||||
ctx.record_collect_candidates_timing(stage_start.elapsed());
|
|
||||||
Ok(next)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn expand_graph(
|
pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
|
||||||
ctx: &mut PipelineContext<'_>,
|
|
||||||
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
|
||||||
let stage_start = Instant::now();
|
|
||||||
debug!("Expanding candidates using graph relationships");
|
debug!("Expanding candidates using graph relationships");
|
||||||
let next = {
|
let tuning = &ctx.config.tuning;
|
||||||
let tuning = &ctx.config.tuning;
|
let weights = FusionWeights::default();
|
||||||
let weights = FusionWeights::default();
|
|
||||||
|
|
||||||
if ctx.entity_candidates.is_empty() {
|
if ctx.entity_candidates.is_empty() {
|
||||||
machine
|
return Ok(());
|
||||||
.expand_graph()
|
}
|
||||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
|
||||||
} else {
|
|
||||||
let graph_seeds = seeds_from_candidates(
|
|
||||||
&ctx.entity_candidates,
|
|
||||||
tuning.graph_seed_min_score,
|
|
||||||
tuning.graph_traversal_seed_limit,
|
|
||||||
);
|
|
||||||
|
|
||||||
if graph_seeds.is_empty() {
|
let graph_seeds = seeds_from_candidates(
|
||||||
machine
|
&ctx.entity_candidates,
|
||||||
.expand_graph()
|
tuning.graph_seed_min_score,
|
||||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
tuning.graph_traversal_seed_limit,
|
||||||
} else {
|
);
|
||||||
let mut futures = FuturesUnordered::new();
|
|
||||||
for seed in graph_seeds {
|
|
||||||
let db = ctx.db_client;
|
|
||||||
let user = ctx.user_id.clone();
|
|
||||||
let limit = tuning.graph_neighbor_limit;
|
|
||||||
futures.push(async move {
|
|
||||||
let neighbors =
|
|
||||||
find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
|
|
||||||
(seed, neighbors)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
if graph_seeds.is_empty() {
|
||||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
return Ok(());
|
||||||
if neighbors.is_empty() {
|
}
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
for neighbor in neighbors {
|
let mut futures = FuturesUnordered::new();
|
||||||
if neighbor.id == seed.id {
|
for seed in graph_seeds {
|
||||||
continue;
|
let db = ctx.db_client;
|
||||||
}
|
let user = ctx.user_id.clone();
|
||||||
|
let limit = tuning.graph_neighbor_limit;
|
||||||
|
futures.push(async move {
|
||||||
|
let neighbors = find_entities_by_relationship_by_id(db, &seed.id, &user, limit).await;
|
||||||
|
(seed, neighbors)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||||
let entry = ctx
|
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||||
.entity_candidates
|
if neighbors.is_empty() {
|
||||||
.entry(neighbor.id.clone())
|
continue;
|
||||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
|
||||||
|
|
||||||
entry.item = neighbor;
|
|
||||||
|
|
||||||
let inherited_vector =
|
|
||||||
clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
|
||||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
|
||||||
if inherited_vector > vector_existing {
|
|
||||||
entry.scores.vector = Some(inherited_vector);
|
|
||||||
}
|
|
||||||
|
|
||||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
|
||||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
|
||||||
entry.scores.graph = Some(graph_score);
|
|
||||||
}
|
|
||||||
|
|
||||||
let fused = fuse_scores(&entry.scores, weights);
|
|
||||||
entry.update_fused(fused);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
machine
|
|
||||||
.expand_graph()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}?;
|
|
||||||
ctx.record_graph_expansion_timing(stage_start.elapsed());
|
for neighbor in neighbors {
|
||||||
Ok(next)
|
if neighbor.id == seed.id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||||
|
let entry = ctx
|
||||||
|
.entity_candidates
|
||||||
|
.entry(neighbor.id.clone())
|
||||||
|
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||||
|
|
||||||
|
entry.item = neighbor;
|
||||||
|
|
||||||
|
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||||
|
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||||
|
if inherited_vector > vector_existing {
|
||||||
|
entry.scores.vector = Some(inherited_vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||||
|
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||||
|
entry.scores.graph = Some(graph_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
let fused = fuse_scores(&entry.scores, weights);
|
||||||
|
entry.update_fused(fused);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn attach_chunks(
|
pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
|
||||||
ctx: &mut PipelineContext<'_>,
|
|
||||||
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
|
||||||
let stage_start = Instant::now();
|
|
||||||
debug!("Attaching chunks to surviving entities");
|
debug!("Attaching chunks to surviving entities");
|
||||||
let tuning = &ctx.config.tuning;
|
let tuning = &ctx.config.tuning;
|
||||||
let weights = FusionWeights::default();
|
let weights = FusionWeights::default();
|
||||||
@@ -438,19 +527,11 @@ pub async fn attach_chunks(
|
|||||||
|
|
||||||
ctx.chunk_values = chunk_values;
|
ctx.chunk_values = chunk_values;
|
||||||
|
|
||||||
let next = machine
|
Ok(())
|
||||||
.attach_chunks()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))?;
|
|
||||||
ctx.record_chunk_attach_timing(stage_start.elapsed());
|
|
||||||
Ok(next)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn rerank(
|
pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
|
||||||
ctx: &mut PipelineContext<'_>,
|
|
||||||
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
|
|
||||||
let stage_start = Instant::now();
|
|
||||||
let mut applied = false;
|
let mut applied = false;
|
||||||
|
|
||||||
if let Some(reranker) = ctx.reranker.as_ref() {
|
if let Some(reranker) = ctx.reranker.as_ref() {
|
||||||
@@ -490,19 +571,124 @@ pub async fn rerank(
|
|||||||
debug!("Applied reranking adjustments to candidate ordering");
|
debug!("Applied reranking adjustments to candidate ordering");
|
||||||
}
|
}
|
||||||
|
|
||||||
let next = machine
|
Ok(())
|
||||||
.rerank()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))?;
|
|
||||||
ctx.record_rerank_timing(stage_start.elapsed());
|
|
||||||
Ok(next)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub fn assemble(
|
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
machine: HybridRetrievalMachine<(), Reranked>,
|
debug!("Collecting vector chunk candidates for revised strategy");
|
||||||
ctx: &mut PipelineContext<'_>,
|
let embedding = ctx.ensure_embedding()?.clone();
|
||||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
let tuning = &ctx.config.tuning;
|
||||||
let stage_start = Instant::now();
|
let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||||
|
tuning.chunk_vector_take,
|
||||||
|
embedding,
|
||||||
|
ctx.db_client,
|
||||||
|
&ctx.user_id,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if ctx.diagnostics_enabled() {
|
||||||
|
ctx.record_collect_candidates(CollectCandidatesStats {
|
||||||
|
vector_entity_candidates: 0,
|
||||||
|
vector_chunk_candidates: vector_chunks.len(),
|
||||||
|
fts_entity_candidates: 0,
|
||||||
|
fts_chunk_candidates: 0,
|
||||||
|
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
||||||
|
chunk.scores.vector.unwrap_or(0.0)
|
||||||
|
}),
|
||||||
|
fts_chunk_scores: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||||
|
ctx.revised_chunk_values = vector_chunks;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
if ctx.revised_chunk_values.len() <= 1 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(reranker) = ctx.reranker.as_ref() else {
|
||||||
|
debug!("No reranker lease provided; skipping chunk rerank stage");
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let documents = build_snippet_rerank_documents(
|
||||||
|
&ctx.revised_chunk_values,
|
||||||
|
ctx.config.tuning.rerank_keep_top.max(1),
|
||||||
|
);
|
||||||
|
if documents.len() <= 1 {
|
||||||
|
debug!("Skipping chunk reranking stage; insufficient chunk documents");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match reranker.rerank(&ctx.input_text, documents).await {
|
||||||
|
Ok(results) if !results.is_empty() => {
|
||||||
|
apply_snippet_rerank_results(
|
||||||
|
&mut ctx.revised_chunk_values,
|
||||||
|
&ctx.config.tuning,
|
||||||
|
results,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
|
||||||
|
Err(err) => warn!(
|
||||||
|
error = %err,
|
||||||
|
"Chunk reranking failed; continuing with original ordering"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
|
debug!("Assembling chunk-only retrieval results");
|
||||||
|
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
|
||||||
|
let question_terms = extract_keywords(&ctx.input_text);
|
||||||
|
rank_snippet_chunks_by_combined_score(
|
||||||
|
&mut chunk_values,
|
||||||
|
&question_terms,
|
||||||
|
ctx.config.tuning.lexical_match_weight,
|
||||||
|
);
|
||||||
|
|
||||||
|
let limit = ctx.config.tuning.chunk_vector_take.max(1);
|
||||||
|
if chunk_values.len() > limit {
|
||||||
|
chunk_values.truncate(limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.chunk_results = chunk_values
|
||||||
|
.into_iter()
|
||||||
|
.map(|chunk| {
|
||||||
|
let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id);
|
||||||
|
RetrievedChunk {
|
||||||
|
chunk: text_chunk,
|
||||||
|
score: chunk.fused,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if ctx.diagnostics_enabled() {
|
||||||
|
ctx.record_assemble(AssembleStats {
|
||||||
|
token_budget_start: ctx.config.tuning.token_budget_estimate,
|
||||||
|
token_budget_spent: 0,
|
||||||
|
token_budget_remaining: ctx.config.tuning.token_budget_estimate,
|
||||||
|
budget_exhausted: false,
|
||||||
|
chunks_selected: ctx.chunk_results.len(),
|
||||||
|
chunks_skipped_due_budget: 0,
|
||||||
|
entity_count: 0,
|
||||||
|
entity_traces: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip_all)]
|
||||||
|
pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
debug!("Assembling final retrieved entities");
|
debug!("Assembling final retrieved entities");
|
||||||
let tuning = &ctx.config.tuning;
|
let tuning = &ctx.config.tuning;
|
||||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||||
@@ -610,11 +796,8 @@ pub fn assemble(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
machine
|
ctx.entity_results = results;
|
||||||
.assemble()
|
Ok(())
|
||||||
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
|
||||||
ctx.record_assemble_timing(stage_start.elapsed());
|
|
||||||
Ok(results)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||||
@@ -630,12 +813,6 @@ where
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
|
||||||
AppError::InternalError(format!(
|
|
||||||
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
|
||||||
err.guard, err.event, err.kind
|
|
||||||
))
|
|
||||||
}
|
|
||||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||||
let raw_scores: Vec<f32> = results
|
let raw_scores: Vec<f32> = results
|
||||||
.iter()
|
.iter()
|
||||||
@@ -873,6 +1050,23 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_snippet_rerank_documents(
|
||||||
|
chunks: &[Scored<ChunkSnippet>],
|
||||||
|
max_chunks: usize,
|
||||||
|
) -> Vec<String> {
|
||||||
|
chunks
|
||||||
|
.iter()
|
||||||
|
.take(max_chunks)
|
||||||
|
.map(|chunk| {
|
||||||
|
format!(
|
||||||
|
"Source: {}\nChunk:\n{}",
|
||||||
|
chunk.item.source_id,
|
||||||
|
chunk.item.chunk.trim()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
|
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
|
||||||
if results.is_empty() || ctx.filtered_entities.is_empty() {
|
if results.is_empty() || ctx.filtered_entities.is_empty() {
|
||||||
return;
|
return;
|
||||||
@@ -930,6 +1124,66 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn apply_snippet_rerank_results(
|
||||||
|
chunks: &mut Vec<Scored<ChunkSnippet>>,
|
||||||
|
tuning: &RetrievalTuning,
|
||||||
|
results: Vec<RerankResult>,
|
||||||
|
) {
|
||||||
|
if results.is_empty() || chunks.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut remaining: Vec<Option<Scored<ChunkSnippet>>> =
|
||||||
|
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||||
|
|
||||||
|
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||||
|
let normalized_scores = min_max_normalize(&raw_scores);
|
||||||
|
|
||||||
|
let use_only = tuning.rerank_scores_only;
|
||||||
|
let blend = if use_only {
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
clamp_unit(tuning.rerank_blend_weight)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut reranked: Vec<Scored<ChunkSnippet>> = Vec::with_capacity(remaining.len());
|
||||||
|
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||||
|
if let Some(slot) = remaining.get_mut(result.index) {
|
||||||
|
if let Some(mut candidate) = slot.take() {
|
||||||
|
let original = candidate.fused;
|
||||||
|
let blended = if use_only {
|
||||||
|
clamp_unit(normalized)
|
||||||
|
} else {
|
||||||
|
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||||
|
};
|
||||||
|
candidate.update_fused(blended);
|
||||||
|
reranked.push(candidate);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
result_index = result.index,
|
||||||
|
"Chunk reranker returned out-of-range index; skipping"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if reranked.len() == remaining.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for slot in remaining.into_iter() {
|
||||||
|
if let Some(candidate) = slot {
|
||||||
|
reranked.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let keep_top = tuning.rerank_keep_top;
|
||||||
|
if keep_top > 0 && reranked.len() > keep_top {
|
||||||
|
reranked.truncate(keep_top);
|
||||||
|
}
|
||||||
|
|
||||||
|
*chunks = reranked;
|
||||||
|
}
|
||||||
|
|
||||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||||
let chars = text.chars().count().max(1);
|
let chars = text.chars().count().max(1);
|
||||||
(chars / avg_chars_per_token).max(1)
|
(chars / avg_chars_per_token).max(1)
|
||||||
@@ -963,6 +1217,32 @@ fn extract_keywords(text: &str) -> Vec<String> {
|
|||||||
terms
|
terms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rank_snippet_chunks_by_combined_score(
|
||||||
|
candidates: &mut [Scored<ChunkSnippet>],
|
||||||
|
question_terms: &[String],
|
||||||
|
lexical_weight: f32,
|
||||||
|
) {
|
||||||
|
if lexical_weight > 0.0 && !question_terms.is_empty() {
|
||||||
|
for candidate in candidates.iter_mut() {
|
||||||
|
let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk);
|
||||||
|
let combined = clamp_unit(candidate.fused + lexical_weight * lexical);
|
||||||
|
candidate.update_fused(combined);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk {
|
||||||
|
let mut chunk = TextChunk::new(
|
||||||
|
snippet.source_id.clone(),
|
||||||
|
snippet.chunk,
|
||||||
|
Vec::new(),
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
chunk.id = snippet.id;
|
||||||
|
chunk
|
||||||
|
}
|
||||||
|
|
||||||
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
||||||
if terms.is_empty() {
|
if terms.is_empty() {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
75
retrieval-pipeline/src/pipeline/strategies.rs
Normal file
75
retrieval-pipeline/src/pipeline/strategies.rs
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
use super::{
|
||||||
|
stages::{
|
||||||
|
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage,
|
||||||
|
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
|
||||||
|
RerankStage,
|
||||||
|
},
|
||||||
|
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
|
||||||
|
};
|
||||||
|
use crate::{RetrievedChunk, RetrievedEntity};
|
||||||
|
use common::error::AppError;
|
||||||
|
|
||||||
|
pub struct InitialStrategyDriver;
|
||||||
|
|
||||||
|
impl InitialStrategyDriver {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StrategyDriver for InitialStrategyDriver {
|
||||||
|
type Output = Vec<RetrievedEntity>;
|
||||||
|
|
||||||
|
fn strategy(&self) -> RetrievalStrategy {
|
||||||
|
RetrievalStrategy::Initial
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stages(&self) -> Vec<BoxedStage> {
|
||||||
|
vec![
|
||||||
|
Box::new(EmbedStage),
|
||||||
|
Box::new(CollectCandidatesStage),
|
||||||
|
Box::new(GraphExpansionStage),
|
||||||
|
Box::new(ChunkAttachStage),
|
||||||
|
Box::new(RerankStage),
|
||||||
|
Box::new(AssembleEntitiesStage),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||||
|
Ok(ctx.take_entity_results())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RevisedStrategyDriver;
|
||||||
|
|
||||||
|
impl RevisedStrategyDriver {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StrategyDriver for RevisedStrategyDriver {
|
||||||
|
type Output = Vec<RetrievedChunk>;
|
||||||
|
|
||||||
|
fn strategy(&self) -> RetrievalStrategy {
|
||||||
|
RetrievalStrategy::Revised
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stages(&self) -> Vec<BoxedStage> {
|
||||||
|
vec![
|
||||||
|
Box::new(EmbedStage),
|
||||||
|
Box::new(ChunkVectorStage),
|
||||||
|
Box::new(ChunkRerankStage),
|
||||||
|
Box::new(ChunkAssembleStage),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||||
|
Ok(ctx.take_chunk_results())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn override_tuning(&self, config: &mut RetrievalConfig) {
|
||||||
|
config.tuning.entity_vector_take = 0;
|
||||||
|
config.tuning.entity_fts_take = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use common::storage::types::file_info::deserialize_flexible_id;
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{db::SurrealDbClient, types::StoredObject},
|
storage::{
|
||||||
|
db::SurrealDbClient,
|
||||||
|
types::{file_info::deserialize_flexible_id, StoredObject},
|
||||||
|
},
|
||||||
utils::embedding::generate_embedding,
|
utils::embedding::generate_embedding,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
@@ -156,3 +158,61 @@ where
|
|||||||
|
|
||||||
Ok(scored)
|
Ok(scored)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct ChunkSnippet {
|
||||||
|
pub id: String,
|
||||||
|
pub source_id: String,
|
||||||
|
pub chunk: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ChunkDistanceRow {
|
||||||
|
distance: Option<f32>,
|
||||||
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
|
pub id: String,
|
||||||
|
pub source_id: String,
|
||||||
|
pub chunk: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||||
|
take: usize,
|
||||||
|
query_embedding: Vec<f32>,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
|
||||||
|
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||||
|
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||||
|
|
||||||
|
let closest_query = format!(
|
||||||
|
"SELECT id, source_id, chunk, vector::distance::knn() AS distance \
|
||||||
|
FROM text_chunk \
|
||||||
|
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||||
|
LIMIT $limit",
|
||||||
|
take = take,
|
||||||
|
embedding = embedding_literal
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut response = db_client
|
||||||
|
.query(closest_query)
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.bind(("limit", take as i64))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let rows: Vec<ChunkDistanceRow> = response.take(0)?;
|
||||||
|
|
||||||
|
let mut scored = Vec::with_capacity(rows.len());
|
||||||
|
for row in rows {
|
||||||
|
let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
|
||||||
|
scored.push(
|
||||||
|
Scored::new(ChunkSnippet {
|
||||||
|
id: row.id,
|
||||||
|
source_id: row.source_id,
|
||||||
|
chunk: row.chunk,
|
||||||
|
})
|
||||||
|
.with_vector_score(similarity),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user