retrieval-pipeline: v0

This commit is contained in:
Per Stark
2025-11-18 21:20:27 +01:00
parent 73e709153a
commit 97d35a8982
32 changed files with 1189 additions and 453 deletions
+397
View File
@@ -0,0 +1,397 @@
mod config;
mod diagnostics;
mod stages;
mod strategies;
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
};
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
use async_openai::Client;
use async_trait::async_trait;
use common::{error::AppError, storage::db::SurrealDbClient};
use std::time::{Duration, Instant};
use tracing::info;
use stages::PipelineContext;
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
#[derive(Debug, Clone)]
pub enum StrategyOutput {
Entities(Vec<RetrievedEntity>),
Chunks(Vec<RetrievedChunk>),
}
impl StrategyOutput {
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
match self {
StrategyOutput::Entities(items) => Some(items),
_ => None,
}
}
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
match self {
StrategyOutput::Entities(items) => Some(items),
_ => None,
}
}
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
match self {
StrategyOutput::Chunks(items) => Some(items),
_ => None,
}
}
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
match self {
StrategyOutput::Chunks(items) => Some(items),
_ => None,
}
}
}
#[derive(Debug)]
pub struct PipelineRunOutput<T> {
pub results: T,
pub diagnostics: Option<PipelineDiagnostics>,
pub stage_timings: PipelineStageTimings,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StageKind {
Embed,
CollectCandidates,
GraphExpansion,
ChunkAttach,
Rerank,
Assemble,
}
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct PipelineStageTimings {
pub embed_ms: u128,
pub collect_candidates_ms: u128,
pub graph_expansion_ms: u128,
pub chunk_attach_ms: u128,
pub rerank_ms: u128,
pub assemble_ms: u128,
}
impl PipelineStageTimings {
pub fn record(&mut self, kind: StageKind, duration: Duration) {
let elapsed = duration.as_millis() as u128;
match kind {
StageKind::Embed => self.embed_ms += elapsed,
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
StageKind::Rerank => self.rerank_ms += elapsed,
StageKind::Assemble => self.assemble_ms += elapsed,
}
}
}
#[async_trait]
pub trait PipelineStage: Send + Sync {
fn kind(&self) -> StageKind;
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
}
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
pub trait StrategyDriver {
type Output;
fn strategy(&self) -> RetrievalStrategy;
fn stages(&self) -> Vec<BoxedStage>;
fn override_tuning(&self, _config: &mut RetrievalConfig) {}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
}
pub async fn run_pipeline(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
let preview_len = input_preview_clean.chars().count();
info!(
%user_id,
input_chars,
preview_truncated = input_chars > preview_len,
preview = %input_preview_clean,
"Starting ingestion retrieval pipeline"
);
if config.strategy == RetrievalStrategy::Initial {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
return Ok(StrategyOutput::Entities(run.results));
}
let driver = RevisedStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Chunks(run.results))
}
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
if config.strategy == RetrievalStrategy::Initial {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
return Ok(StrategyOutput::Entities(run.results));
}
let driver = RevisedStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Chunks(run.results))
}
pub async fn run_pipeline_with_embedding_with_metrics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
if config.strategy == RetrievalStrategy::Initial {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
return Ok(PipelineRunOutput {
results: StrategyOutput::Entities(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
});
}
let driver = RevisedStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
pub async fn run_pipeline_with_embedding_with_diagnostics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
if config.strategy == RetrievalStrategy::Initial {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
true,
)
.await?;
return Ok(PipelineRunOutput {
results: StrategyOutput::Entities(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
});
}
let driver = RevisedStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
true,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
serde_json::json!(entities
.iter()
.map(|entry| {
serde_json::json!({
"KnowledgeEntity": {
"id": entry.entity.id,
"name": entry.entity.name,
"description": entry.entity.description,
"score": round_score(entry.score),
"chunks": entry.chunks.iter().map(|chunk| {
serde_json::json!({
"score": round_score(chunk.score),
"content": chunk.chunk.chunk
})
}).collect::<Vec<_>>()
}
})
})
.collect::<Vec<_>>())
}
async fn execute_strategy<D: StrategyDriver>(
driver: D,
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Option<Vec<f32>>,
input_text: &str,
user_id: &str,
mut config: RetrievalConfig,
reranker: Option<RerankerLease>,
capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
driver.override_tuning(&mut config);
let ctx = match query_embedding {
Some(embedding) => PipelineContext::with_embedding(
db_client,
openai_client,
embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
),
None => PipelineContext::new(
db_client,
openai_client,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
),
};
run_with_driver(driver, ctx, capture_diagnostics).await
}
async fn run_with_driver<D: StrategyDriver>(
driver: D,
mut ctx: PipelineContext<'_>,
capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
if capture_diagnostics {
ctx.enable_diagnostics();
}
for stage in driver.stages() {
let start = Instant::now();
stage.execute(&mut ctx).await?;
ctx.record_stage_duration(stage.kind(), start.elapsed());
}
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let results = driver.finalize(&mut ctx)?;
Ok(PipelineRunOutput {
results,
diagnostics,
stage_timings,
})
}
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}