mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-25 02:08:30 +02:00
retrieval-pipeline: v1
This commit is contained in:
@@ -473,7 +473,8 @@ pub(crate) async fn load_or_init_system_settings(
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusParagraph, CorpusQuestion};
|
use crate::ingest::store::CorpusParagraph;
|
||||||
|
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use common::storage::types::text_content::TextContent;
|
use common::storage::types::text_content::TextContent;
|
||||||
|
|
||||||
|
|||||||
@@ -294,16 +294,16 @@ pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageL
|
|||||||
}
|
}
|
||||||
|
|
||||||
StageLatencyBreakdown {
|
StageLatencyBreakdown {
|
||||||
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms)),
|
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms())),
|
||||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||||
entry.collect_candidates_ms
|
entry.collect_candidates_ms()
|
||||||
})),
|
})),
|
||||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||||
entry.graph_expansion_ms
|
entry.graph_expansion_ms()
|
||||||
})),
|
})),
|
||||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
|
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms())),
|
||||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
|
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())),
|
||||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
|
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
mod config;
|
mod config;
|
||||||
mod orchestrator;
|
mod orchestrator;
|
||||||
mod store;
|
pub(crate) mod store;
|
||||||
|
|
||||||
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
|
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
|
||||||
pub use orchestrator::ensure_corpus;
|
pub use orchestrator::ensure_corpus;
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ use json_stream_parser::JsonStreamParser;
|
|||||||
use minijinja::Value;
|
use minijinja::Value;
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||||
retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput,
|
retrieved_entities_to_json,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
@@ -123,23 +123,22 @@ pub async fn get_response_stream(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut retrieval_config = RetrievalConfig::default();
|
let strategy = state.retrieval_strategy();
|
||||||
retrieval_config.strategy = state.retrieval_strategy();
|
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
|
||||||
let entities = match retrieve_entities(
|
|
||||||
|
let entities = match retrieval_pipeline::retrieve_entities(
|
||||||
&state.db,
|
&state.db,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
&user_message.content,
|
&user_message.content,
|
||||||
&user.id,
|
&user.id,
|
||||||
retrieval_config,
|
config,
|
||||||
rerank_lease,
|
rerank_lease,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(StrategyOutput::Entities(entities)) => entities,
|
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => entities,
|
||||||
Ok(StrategyOutput::Chunks(_)) => {
|
Ok(retrieval_pipeline::StrategyOutput::Chunks(_chunks)) => {
|
||||||
return Sse::new(create_error_stream(
|
return Sse::new(create_error_stream("Chat retrieval currently only supports Entity-based strategies (Initial). Revised strategy returns Chunks which are not yet supported by this handler."));
|
||||||
"Chunk-only retrieval results are not supported in this route",
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
Err(_e) => {
|
Err(_e) => {
|
||||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ use common::{
|
|||||||
},
|
},
|
||||||
utils::embedding::generate_embedding,
|
utils::embedding::generate_embedding,
|
||||||
};
|
};
|
||||||
use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput};
|
use retrieval_pipeline;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -284,20 +284,18 @@ pub async fn suggest_knowledge_relationships(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut retrieval_config = RetrievalConfig::default();
|
let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion();
|
||||||
retrieval_config.strategy = state.retrieval_strategy();
|
if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = retrieval_pipeline::retrieve_entities(
|
||||||
|
|
||||||
if let Ok(StrategyOutput::Entities(results)) = retrieve_entities(
|
|
||||||
&state.db,
|
&state.db,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
&query,
|
&query,
|
||||||
&user.id,
|
&user.id,
|
||||||
retrieval_config,
|
config,
|
||||||
rerank_lease,
|
rerank_lease,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
for RetrievedEntity { entity, score, .. } in results {
|
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results {
|
||||||
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ use common::{
|
|||||||
utils::{config::AppConfig, embedding::generate_embedding},
|
utils::{config::AppConfig, embedding::generate_embedding},
|
||||||
};
|
};
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig,
|
reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity,
|
||||||
RetrievalStrategy, RetrievedEntity, StrategyOutput,
|
|
||||||
};
|
};
|
||||||
use text_splitter::TextSplitter;
|
use text_splitter::TextSplitter;
|
||||||
|
|
||||||
@@ -125,14 +124,6 @@ impl DefaultPipelineServices {
|
|||||||
Ok(request)
|
Ok(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn configured_strategy(&self) -> RetrievalStrategy {
|
|
||||||
self.config
|
|
||||||
.retrieval_strategy
|
|
||||||
.as_deref()
|
|
||||||
.and_then(|value| value.parse().ok())
|
|
||||||
.unwrap_or(RetrievalStrategy::Initial)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn perform_analysis(
|
async fn perform_analysis(
|
||||||
&self,
|
&self,
|
||||||
request: CreateChatCompletionRequest,
|
request: CreateChatCompletionRequest,
|
||||||
@@ -187,9 +178,8 @@ impl PipelineServices for DefaultPipelineServices {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut config = RetrievalConfig::default();
|
let config = retrieval_pipeline::RetrievalConfig::for_ingestion();
|
||||||
config.strategy = self.configured_strategy();
|
match retrieval_pipeline::retrieve_entities(
|
||||||
match retrieve_entities(
|
|
||||||
&self.db,
|
&self.db,
|
||||||
&self.openai_client,
|
&self.openai_client,
|
||||||
&input_text,
|
&input_text,
|
||||||
@@ -199,11 +189,11 @@ impl PipelineServices for DefaultPipelineServices {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(StrategyOutput::Entities(entities)) => Ok(entities),
|
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
|
||||||
Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||||
"Chunk-only retrieval is not supported in ingestion".into(),
|
"Ingestion retrieval should return entities".into(),
|
||||||
)),
|
)),
|
||||||
Err(err) => Err(err),
|
Err(e) => Err(e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,9 +17,16 @@ use common::{
|
|||||||
use reranking::RerankerLease;
|
use reranking::RerankerLease;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
|
// Strategy output variants - defined before pipeline module
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum StrategyOutput {
|
||||||
|
Entities(Vec<RetrievedEntity>),
|
||||||
|
Chunks(Vec<RetrievedChunk>),
|
||||||
|
}
|
||||||
|
|
||||||
pub use pipeline::{
|
pub use pipeline::{
|
||||||
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||||
RetrievalStrategy, RetrievalTuning, StrategyOutput,
|
RetrievalStrategy, RetrievalTuning,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||||
@@ -37,7 +44,7 @@ pub struct RetrievedEntity {
|
|||||||
pub chunks: Vec<RetrievedChunk>,
|
pub chunks: Vec<RetrievedChunk>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
/// Primary orchestrator for retrieving the KnowledgeEntities related to an input_text
|
||||||
#[instrument(skip_all, fields(user_id))]
|
#[instrument(skip_all, fields(user_id))]
|
||||||
pub async fn retrieve_entities(
|
pub async fn retrieve_entities(
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ use std::fmt;
|
|||||||
pub enum RetrievalStrategy {
|
pub enum RetrievalStrategy {
|
||||||
Initial,
|
Initial,
|
||||||
Revised,
|
Revised,
|
||||||
|
RelationshipSuggestion,
|
||||||
|
Ingestion,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalStrategy {
|
impl Default for RetrievalStrategy {
|
||||||
@@ -21,6 +23,8 @@ impl std::str::FromStr for RetrievalStrategy {
|
|||||||
match value.to_ascii_lowercase().as_str() {
|
match value.to_ascii_lowercase().as_str() {
|
||||||
"initial" => Ok(Self::Initial),
|
"initial" => Ok(Self::Initial),
|
||||||
"revised" => Ok(Self::Revised),
|
"revised" => Ok(Self::Revised),
|
||||||
|
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
||||||
|
"ingestion" => Ok(Self::Ingestion),
|
||||||
other => Err(format!("unknown retrieval strategy '{other}'")),
|
other => Err(format!("unknown retrieval strategy '{other}'")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -31,6 +35,8 @@ impl fmt::Display for RetrievalStrategy {
|
|||||||
let label = match self {
|
let label = match self {
|
||||||
RetrievalStrategy::Initial => "initial",
|
RetrievalStrategy::Initial => "initial",
|
||||||
RetrievalStrategy::Revised => "revised",
|
RetrievalStrategy::Revised => "revised",
|
||||||
|
RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion",
|
||||||
|
RetrievalStrategy::Ingestion => "ingestion",
|
||||||
};
|
};
|
||||||
f.write_str(label)
|
f.write_str(label)
|
||||||
}
|
}
|
||||||
@@ -109,6 +115,21 @@ impl RetrievalConfig {
|
|||||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||||
Self { strategy, tuning }
|
Self { strategy, tuning }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create config for chat retrieval with strategy selection support
|
||||||
|
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
|
||||||
|
Self::with_strategy(strategy)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create config for relationship suggestion (entity-only retrieval)
|
||||||
|
pub fn for_relationship_suggestion() -> Self {
|
||||||
|
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create config for ingestion pipeline (entity-only retrieval)
|
||||||
|
pub fn for_ingestion() -> Self {
|
||||||
|
Self::with_strategy(RetrievalStrategy::Ingestion)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalConfig {
|
impl Default for RetrievalConfig {
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ pub use diagnostics::{
|
|||||||
PipelineDiagnostics,
|
PipelineDiagnostics,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||||
@@ -17,52 +17,15 @@ use std::time::{Duration, Instant};
|
|||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use stages::PipelineContext;
|
use stages::PipelineContext;
|
||||||
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
|
use strategies::{
|
||||||
|
IngestionDriver, InitialStrategyDriver, RelationshipSuggestionDriver, RevisedStrategyDriver,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
// StrategyOutput is defined in lib.rs and brought into scope by the `use crate::{..}` import above.
|
||||||
pub enum StrategyOutput {
|
// NOTE(review): no `pub use` re-export is visible in this module - confirm before relying on one.
|
||||||
Entities(Vec<RetrievedEntity>),
|
|
||||||
Chunks(Vec<RetrievedChunk>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StrategyOutput {
|
// Stage type enum
|
||||||
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
match self {
|
|
||||||
StrategyOutput::Entities(items) => Some(items),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
|
|
||||||
match self {
|
|
||||||
StrategyOutput::Entities(items) => Some(items),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
|
|
||||||
match self {
|
|
||||||
StrategyOutput::Chunks(items) => Some(items),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
|
|
||||||
match self {
|
|
||||||
StrategyOutput::Chunks(items) => Some(items),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct PipelineRunOutput<T> {
|
|
||||||
pub results: T,
|
|
||||||
pub diagnostics: Option<PipelineDiagnostics>,
|
|
||||||
pub stage_timings: PipelineStageTimings,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum StageKind {
|
pub enum StageKind {
|
||||||
Embed,
|
Embed,
|
||||||
CollectCandidates,
|
CollectCandidates,
|
||||||
@@ -72,48 +35,80 @@ pub enum StageKind {
|
|||||||
Assemble,
|
Assemble,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
// Pipeline stage trait
|
||||||
pub struct PipelineStageTimings {
|
|
||||||
pub embed_ms: u128,
|
|
||||||
pub collect_candidates_ms: u128,
|
|
||||||
pub graph_expansion_ms: u128,
|
|
||||||
pub chunk_attach_ms: u128,
|
|
||||||
pub rerank_ms: u128,
|
|
||||||
pub assemble_ms: u128,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PipelineStageTimings {
|
|
||||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
|
||||||
let elapsed = duration.as_millis() as u128;
|
|
||||||
match kind {
|
|
||||||
StageKind::Embed => self.embed_ms += elapsed,
|
|
||||||
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
|
|
||||||
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
|
|
||||||
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
|
|
||||||
StageKind::Rerank => self.rerank_ms += elapsed,
|
|
||||||
StageKind::Assemble => self.assemble_ms += elapsed,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait PipelineStage: Send + Sync {
|
pub trait PipelineStage: Send + Sync {
|
||||||
fn kind(&self) -> StageKind;
|
fn kind(&self) -> StageKind;
|
||||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
|
// Type alias for boxed stages
|
||||||
|
pub type BoxedStage = Box<dyn PipelineStage>;
|
||||||
|
|
||||||
pub trait StrategyDriver {
|
// Strategy driver trait
|
||||||
|
#[async_trait]
|
||||||
|
pub trait StrategyDriver: Send + Sync {
|
||||||
type Output;
|
type Output;
|
||||||
|
|
||||||
fn strategy(&self) -> RetrievalStrategy;
|
|
||||||
fn stages(&self) -> Vec<BoxedStage>;
|
fn stages(&self) -> Vec<BoxedStage>;
|
||||||
fn override_tuning(&self, _config: &mut RetrievalConfig) {}
|
|
||||||
|
|
||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pipeline stage timings tracker
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
pub struct PipelineStageTimings {
|
||||||
|
timings: Vec<(StageKind, Duration)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PipelineStageTimings {
|
||||||
|
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||||
|
self.timings.push((kind, duration));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
|
||||||
|
self.timings
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods to get duration for each stage type (for backward compatibility)
|
||||||
|
fn get_stage_ms(&self, kind: StageKind) -> u128 {
|
||||||
|
self.timings
|
||||||
|
.iter()
|
||||||
|
.find(|(k, _)| *k == kind)
|
||||||
|
.map(|(_, d)| d.as_millis())
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_ms(&self) -> u128 {
|
||||||
|
self.get_stage_ms(StageKind::Embed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn collect_candidates_ms(&self) -> u128 {
|
||||||
|
self.get_stage_ms(StageKind::CollectCandidates)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn graph_expansion_ms(&self) -> u128 {
|
||||||
|
self.get_stage_ms(StageKind::GraphExpansion)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn chunk_attach_ms(&self) -> u128 {
|
||||||
|
self.get_stage_ms(StageKind::ChunkAttach)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rerank_ms(&self) -> u128 {
|
||||||
|
self.get_stage_ms(StageKind::Rerank)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assemble_ms(&self) -> u128 {
|
||||||
|
self.get_stage_ms(StageKind::Assemble)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PipelineRunOutput<T> {
|
||||||
|
pub results: T,
|
||||||
|
pub diagnostics: Option<PipelineDiagnostics>,
|
||||||
|
pub stage_timings: PipelineStageTimings,
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn run_pipeline(
|
pub async fn run_pipeline(
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
@@ -131,40 +126,76 @@ pub async fn run_pipeline(
|
|||||||
input_chars,
|
input_chars,
|
||||||
preview_truncated = input_chars > preview_len,
|
preview_truncated = input_chars > preview_len,
|
||||||
preview = %input_preview_clean,
|
preview = %input_preview_clean,
|
||||||
"Starting ingestion retrieval pipeline"
|
strategy = %config.strategy,
|
||||||
|
"Starting retrieval pipeline"
|
||||||
);
|
);
|
||||||
|
|
||||||
if config.strategy == RetrievalStrategy::Initial {
|
match config.strategy {
|
||||||
let driver = InitialStrategyDriver::new();
|
RetrievalStrategy::Initial => {
|
||||||
let run = execute_strategy(
|
let driver = InitialStrategyDriver::new();
|
||||||
driver,
|
let run = execute_strategy(
|
||||||
db_client,
|
driver,
|
||||||
openai_client,
|
db_client,
|
||||||
None,
|
openai_client,
|
||||||
input_text,
|
None,
|
||||||
user_id,
|
input_text,
|
||||||
config,
|
user_id,
|
||||||
reranker,
|
config,
|
||||||
false,
|
reranker,
|
||||||
)
|
false,
|
||||||
.await?;
|
)
|
||||||
return Ok(StrategyOutput::Entities(run.results));
|
.await?;
|
||||||
|
Ok(StrategyOutput::Entities(run.results))
|
||||||
|
}
|
||||||
|
RetrievalStrategy::Revised => {
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
None,
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Chunks(run.results))
|
||||||
|
}
|
||||||
|
RetrievalStrategy::RelationshipSuggestion => {
|
||||||
|
let driver = RelationshipSuggestionDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
None,
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Entities(run.results))
|
||||||
|
}
|
||||||
|
RetrievalStrategy::Ingestion => {
|
||||||
|
let driver = IngestionDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
None,
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Entities(run.results))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let driver = RevisedStrategyDriver::new();
|
|
||||||
let run = execute_strategy(
|
|
||||||
driver,
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
None,
|
|
||||||
input_text,
|
|
||||||
user_id,
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(StrategyOutput::Chunks(run.results))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding(
|
pub async fn run_pipeline_with_embedding(
|
||||||
@@ -176,39 +207,79 @@ pub async fn run_pipeline_with_embedding(
|
|||||||
config: RetrievalConfig,
|
config: RetrievalConfig,
|
||||||
reranker: Option<RerankerLease>,
|
reranker: Option<RerankerLease>,
|
||||||
) -> Result<StrategyOutput, AppError> {
|
) -> Result<StrategyOutput, AppError> {
|
||||||
if config.strategy == RetrievalStrategy::Initial {
|
match config.strategy {
|
||||||
let driver = InitialStrategyDriver::new();
|
RetrievalStrategy::Initial => {
|
||||||
let run = execute_strategy(
|
let driver = InitialStrategyDriver::new();
|
||||||
driver,
|
let run = execute_strategy(
|
||||||
db_client,
|
driver,
|
||||||
openai_client,
|
db_client,
|
||||||
Some(query_embedding),
|
openai_client,
|
||||||
input_text,
|
Some(query_embedding),
|
||||||
user_id,
|
input_text,
|
||||||
config,
|
user_id,
|
||||||
reranker,
|
config,
|
||||||
false,
|
reranker,
|
||||||
)
|
false,
|
||||||
.await?;
|
)
|
||||||
return Ok(StrategyOutput::Entities(run.results));
|
.await?;
|
||||||
|
Ok(StrategyOutput::Entities(run.results))
|
||||||
|
}
|
||||||
|
RetrievalStrategy::Revised => {
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Chunks(run.results))
|
||||||
|
}
|
||||||
|
RetrievalStrategy::RelationshipSuggestion => {
|
||||||
|
let driver = RelationshipSuggestionDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Entities(run.results))
|
||||||
|
}
|
||||||
|
RetrievalStrategy::Ingestion => {
|
||||||
|
let driver = IngestionDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StrategyOutput::Entities(run.results))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let driver = RevisedStrategyDriver::new();
|
|
||||||
let run = execute_strategy(
|
|
||||||
driver,
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
Some(query_embedding),
|
|
||||||
input_text,
|
|
||||||
user_id,
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(StrategyOutput::Chunks(run.results))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: the metrics and diagnostics variants below follow the same dispatch pattern,
|
||||||
|
// but they only handle the Initial and Revised strategies.
|
||||||
|
// Every other strategy currently returns an `AppError::InternalError` from these variants.
|
||||||
|
// TODO: add RelationshipSuggestion and Ingestion arms if metrics/diagnostics are needed for them.
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||||
@@ -218,45 +289,52 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
|||||||
config: RetrievalConfig,
|
config: RetrievalConfig,
|
||||||
reranker: Option<RerankerLease>,
|
reranker: Option<RerankerLease>,
|
||||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||||
if config.strategy == RetrievalStrategy::Initial {
|
match config.strategy {
|
||||||
let driver = InitialStrategyDriver::new();
|
RetrievalStrategy::Initial => {
|
||||||
let run = execute_strategy(
|
let driver = InitialStrategyDriver::new();
|
||||||
driver,
|
let run = execute_strategy(
|
||||||
db_client,
|
driver,
|
||||||
openai_client,
|
db_client,
|
||||||
Some(query_embedding),
|
openai_client,
|
||||||
input_text,
|
Some(query_embedding),
|
||||||
user_id,
|
input_text,
|
||||||
config,
|
user_id,
|
||||||
reranker,
|
config,
|
||||||
false,
|
reranker,
|
||||||
)
|
false,
|
||||||
.await?;
|
)
|
||||||
return Ok(PipelineRunOutput {
|
.await?;
|
||||||
results: StrategyOutput::Entities(run.results),
|
Ok(PipelineRunOutput {
|
||||||
diagnostics: run.diagnostics,
|
results: StrategyOutput::Entities(run.results),
|
||||||
stage_timings: run.stage_timings,
|
diagnostics: run.diagnostics,
|
||||||
});
|
stage_timings: run.stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
RetrievalStrategy::Revised => {
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(PipelineRunOutput {
|
||||||
|
results: StrategyOutput::Chunks(run.results),
|
||||||
|
diagnostics: run.diagnostics,
|
||||||
|
stage_timings: run.stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
|
||||||
|
_ => Err(AppError::InternalError(
|
||||||
|
"Metrics not supported for this strategy".into(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
let driver = RevisedStrategyDriver::new();
|
|
||||||
let run = execute_strategy(
|
|
||||||
driver,
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
Some(query_embedding),
|
|
||||||
input_text,
|
|
||||||
user_id,
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(PipelineRunOutput {
|
|
||||||
results: StrategyOutput::Chunks(run.results),
|
|
||||||
diagnostics: run.diagnostics,
|
|
||||||
stage_timings: run.stage_timings,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||||
@@ -268,45 +346,51 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
|
|||||||
config: RetrievalConfig,
|
config: RetrievalConfig,
|
||||||
reranker: Option<RerankerLease>,
|
reranker: Option<RerankerLease>,
|
||||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||||
if config.strategy == RetrievalStrategy::Initial {
|
match config.strategy {
|
||||||
let driver = InitialStrategyDriver::new();
|
RetrievalStrategy::Initial => {
|
||||||
let run = execute_strategy(
|
let driver = InitialStrategyDriver::new();
|
||||||
driver,
|
let run = execute_strategy(
|
||||||
db_client,
|
driver,
|
||||||
openai_client,
|
db_client,
|
||||||
Some(query_embedding),
|
openai_client,
|
||||||
input_text,
|
Some(query_embedding),
|
||||||
user_id,
|
input_text,
|
||||||
config,
|
user_id,
|
||||||
reranker,
|
config,
|
||||||
true,
|
reranker,
|
||||||
)
|
true,
|
||||||
.await?;
|
)
|
||||||
return Ok(PipelineRunOutput {
|
.await?;
|
||||||
results: StrategyOutput::Entities(run.results),
|
Ok(PipelineRunOutput {
|
||||||
diagnostics: run.diagnostics,
|
results: StrategyOutput::Entities(run.results),
|
||||||
stage_timings: run.stage_timings,
|
diagnostics: run.diagnostics,
|
||||||
});
|
stage_timings: run.stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
RetrievalStrategy::Revised => {
|
||||||
|
let driver = RevisedStrategyDriver::new();
|
||||||
|
let run = execute_strategy(
|
||||||
|
driver,
|
||||||
|
db_client,
|
||||||
|
openai_client,
|
||||||
|
Some(query_embedding),
|
||||||
|
input_text,
|
||||||
|
user_id,
|
||||||
|
config,
|
||||||
|
reranker,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(PipelineRunOutput {
|
||||||
|
results: StrategyOutput::Chunks(run.results),
|
||||||
|
diagnostics: run.diagnostics,
|
||||||
|
stage_timings: run.stage_timings,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => Err(AppError::InternalError(
|
||||||
|
"Diagnostics not supported for this strategy".into(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
let driver = RevisedStrategyDriver::new();
|
|
||||||
let run = execute_strategy(
|
|
||||||
driver,
|
|
||||||
db_client,
|
|
||||||
openai_client,
|
|
||||||
Some(query_embedding),
|
|
||||||
input_text,
|
|
||||||
user_id,
|
|
||||||
config,
|
|
||||||
reranker,
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(PipelineRunOutput {
|
|
||||||
results: StrategyOutput::Chunks(run.results),
|
|
||||||
diagnostics: run.diagnostics,
|
|
||||||
stage_timings: run.stage_timings,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||||
@@ -338,11 +422,10 @@ async fn execute_strategy<D: StrategyDriver>(
|
|||||||
query_embedding: Option<Vec<f32>>,
|
query_embedding: Option<Vec<f32>>,
|
||||||
input_text: &str,
|
input_text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
mut config: RetrievalConfig,
|
config: RetrievalConfig,
|
||||||
reranker: Option<RerankerLease>,
|
reranker: Option<RerankerLease>,
|
||||||
capture_diagnostics: bool,
|
capture_diagnostics: bool,
|
||||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||||
driver.override_tuning(&mut config);
|
|
||||||
let ctx = match query_embedding {
|
let ctx = match query_embedding {
|
||||||
Some(embedding) => PipelineContext::with_embedding(
|
Some(embedding) => PipelineContext::with_embedding(
|
||||||
db_client,
|
db_client,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use super::{
|
|||||||
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
|
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
|
||||||
RerankStage,
|
RerankStage,
|
||||||
},
|
},
|
||||||
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
|
BoxedStage, StrategyDriver,
|
||||||
};
|
};
|
||||||
use crate::{RetrievedChunk, RetrievedEntity};
|
use crate::{RetrievedChunk, RetrievedEntity};
|
||||||
use common::error::AppError;
|
use common::error::AppError;
|
||||||
@@ -20,10 +20,6 @@ impl InitialStrategyDriver {
|
|||||||
impl StrategyDriver for InitialStrategyDriver {
|
impl StrategyDriver for InitialStrategyDriver {
|
||||||
type Output = Vec<RetrievedEntity>;
|
type Output = Vec<RetrievedEntity>;
|
||||||
|
|
||||||
fn strategy(&self) -> RetrievalStrategy {
|
|
||||||
RetrievalStrategy::Initial
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage> {
|
fn stages(&self) -> Vec<BoxedStage> {
|
||||||
vec![
|
vec![
|
||||||
Box::new(EmbedStage),
|
Box::new(EmbedStage),
|
||||||
@@ -51,10 +47,6 @@ impl RevisedStrategyDriver {
|
|||||||
impl StrategyDriver for RevisedStrategyDriver {
|
impl StrategyDriver for RevisedStrategyDriver {
|
||||||
type Output = Vec<RetrievedChunk>;
|
type Output = Vec<RetrievedChunk>;
|
||||||
|
|
||||||
fn strategy(&self) -> RetrievalStrategy {
|
|
||||||
RetrievalStrategy::Revised
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stages(&self) -> Vec<BoxedStage> {
|
fn stages(&self) -> Vec<BoxedStage> {
|
||||||
vec![
|
vec![
|
||||||
Box::new(EmbedStage),
|
Box::new(EmbedStage),
|
||||||
@@ -67,9 +59,58 @@ impl StrategyDriver for RevisedStrategyDriver {
|
|||||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||||
Ok(ctx.take_chunk_results())
|
Ok(ctx.take_chunk_results())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn override_tuning(&self, config: &mut RetrievalConfig) {
|
pub struct RelationshipSuggestionDriver;
|
||||||
config.tuning.entity_vector_take = 0;
|
|
||||||
config.tuning.entity_fts_take = 0;
|
impl RelationshipSuggestionDriver {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StrategyDriver for RelationshipSuggestionDriver {
|
||||||
|
type Output = Vec<RetrievedEntity>;
|
||||||
|
|
||||||
|
fn stages(&self) -> Vec<BoxedStage> {
|
||||||
|
vec![
|
||||||
|
Box::new(EmbedStage),
|
||||||
|
Box::new(CollectCandidatesStage),
|
||||||
|
Box::new(GraphExpansionStage),
|
||||||
|
// Skip ChunkAttachStage
|
||||||
|
Box::new(RerankStage),
|
||||||
|
Box::new(AssembleEntitiesStage),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||||
|
Ok(ctx.take_entity_results())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct IngestionDriver;
|
||||||
|
|
||||||
|
impl IngestionDriver {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StrategyDriver for IngestionDriver {
|
||||||
|
type Output = Vec<RetrievedEntity>;
|
||||||
|
|
||||||
|
fn stages(&self) -> Vec<BoxedStage> {
|
||||||
|
vec![
|
||||||
|
Box::new(EmbedStage),
|
||||||
|
Box::new(CollectCandidatesStage),
|
||||||
|
Box::new(GraphExpansionStage),
|
||||||
|
// Skip ChunkAttachStage
|
||||||
|
Box::new(RerankStage),
|
||||||
|
Box::new(AssembleEntitiesStage),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||||
|
Ok(ctx.take_entity_results())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user