mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-05 12:31:41 +02:00
retrieval-pipeline: v0
This commit is contained in:
@@ -0,0 +1,397 @@
|
||||
mod config;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use stages::PipelineContext;
|
||||
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
}
|
||||
|
||||
impl StrategyOutput {
|
||||
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StageKind {
|
||||
Embed,
|
||||
CollectCandidates,
|
||||
GraphExpansion,
|
||||
ChunkAttach,
|
||||
Rerank,
|
||||
Assemble,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub embed_ms: u128,
|
||||
pub collect_candidates_ms: u128,
|
||||
pub graph_expansion_ms: u128,
|
||||
pub chunk_attach_ms: u128,
|
||||
pub rerank_ms: u128,
|
||||
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
let elapsed = duration.as_millis() as u128;
|
||||
match kind {
|
||||
StageKind::Embed => self.embed_ms += elapsed,
|
||||
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
|
||||
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
|
||||
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
|
||||
StageKind::Rerank => self.rerank_ms += elapsed,
|
||||
StageKind::Assemble => self.assemble_ms += elapsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait PipelineStage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
|
||||
|
||||
pub trait StrategyDriver {
|
||||
type Output;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy;
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn override_tuning(&self, _config: &mut RetrievalConfig) {}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
});
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
return Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
});
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
serde_json::json!(entities
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
serde_json::json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entry.entity.id,
|
||||
"name": entry.entity.name,
|
||||
"description": entry.entity.description,
|
||||
"score": round_score(entry.score),
|
||||
"chunks": entry.chunks.iter().map(|chunk| {
|
||||
serde_json::json!({
|
||||
"score": round_score(chunk.score),
|
||||
"content": chunk.chunk.chunk
|
||||
})
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
mut config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
driver.override_tuning(&mut config);
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
None => PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
}
|
||||
|
||||
async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
for stage in driver.stages() {
|
||||
let start = Instant::now();
|
||||
stage.execute(&mut ctx).await?;
|
||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||
}
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx)?;
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
Reference in New Issue
Block a user