mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-22 09:29:51 +01:00
retrieval-pipeline: v1
This commit is contained in:
@@ -473,7 +473,8 @@ pub(crate) async fn load_or_init_system_settings(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusParagraph, CorpusQuestion};
|
||||
use crate::ingest::store::CorpusParagraph;
|
||||
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion};
|
||||
use chrono::Utc;
|
||||
use common::storage::types::text_content::TextContent;
|
||||
|
||||
|
||||
@@ -294,16 +294,16 @@ pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageL
|
||||
}
|
||||
|
||||
StageLatencyBreakdown {
|
||||
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms)),
|
||||
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms())),
|
||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.collect_candidates_ms
|
||||
entry.collect_candidates_ms()
|
||||
})),
|
||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.graph_expansion_ms
|
||||
entry.graph_expansion_ms()
|
||||
})),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms())),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
mod config;
|
||||
mod orchestrator;
|
||||
mod store;
|
||||
pub(crate) mod store;
|
||||
|
||||
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
|
||||
pub use orchestrator::ensure_corpus;
|
||||
|
||||
@@ -16,7 +16,7 @@ use json_stream_parser::JsonStreamParser;
|
||||
use minijinja::Value;
|
||||
use retrieval_pipeline::{
|
||||
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||
retrieve_entities, retrieved_entities_to_json, RetrievalConfig, StrategyOutput,
|
||||
retrieved_entities_to_json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
@@ -123,23 +123,22 @@ pub async fn get_response_stream(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.strategy = state.retrieval_strategy();
|
||||
let entities = match retrieve_entities(
|
||||
let strategy = state.retrieval_strategy();
|
||||
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
|
||||
|
||||
let entities = match retrieval_pipeline::retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
retrieval_config,
|
||||
config,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(StrategyOutput::Entities(entities)) => entities,
|
||||
Ok(StrategyOutput::Chunks(_)) => {
|
||||
return Sse::new(create_error_stream(
|
||||
"Chunk-only retrieval results are not supported in this route",
|
||||
))
|
||||
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => entities,
|
||||
Ok(retrieval_pipeline::StrategyOutput::Chunks(_chunks)) => {
|
||||
return Sse::new(create_error_stream("Chat retrieval currently only supports Entity-based strategies (Initial). Revised strategy returns Chunks which are not yet supported by this handler."));
|
||||
}
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||
|
||||
@@ -24,7 +24,7 @@ use common::{
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use retrieval_pipeline::{retrieve_entities, RetrievalConfig, RetrievedEntity, StrategyOutput};
|
||||
use retrieval_pipeline;
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -284,20 +284,18 @@ pub async fn suggest_knowledge_relationships(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.strategy = state.retrieval_strategy();
|
||||
|
||||
if let Ok(StrategyOutput::Entities(results)) = retrieve_entities(
|
||||
let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion();
|
||||
if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = retrieval_pipeline::retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&query,
|
||||
&user.id,
|
||||
retrieval_config,
|
||||
config,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
for RetrievedEntity { entity, score, .. } in results {
|
||||
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results {
|
||||
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -20,8 +20,7 @@ use common::{
|
||||
utils::{config::AppConfig, embedding::generate_embedding},
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievedEntity, StrategyOutput,
|
||||
reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity,
|
||||
};
|
||||
use text_splitter::TextSplitter;
|
||||
|
||||
@@ -125,14 +124,6 @@ impl DefaultPipelineServices {
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
fn configured_strategy(&self) -> RetrievalStrategy {
|
||||
self.config
|
||||
.retrieval_strategy
|
||||
.as_deref()
|
||||
.and_then(|value| value.parse().ok())
|
||||
.unwrap_or(RetrievalStrategy::Initial)
|
||||
}
|
||||
|
||||
async fn perform_analysis(
|
||||
&self,
|
||||
request: CreateChatCompletionRequest,
|
||||
@@ -187,9 +178,8 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut config = RetrievalConfig::default();
|
||||
config.strategy = self.configured_strategy();
|
||||
match retrieve_entities(
|
||||
let config = retrieval_pipeline::RetrievalConfig::for_ingestion();
|
||||
match retrieval_pipeline::retrieve_entities(
|
||||
&self.db,
|
||||
&self.openai_client,
|
||||
&input_text,
|
||||
@@ -199,11 +189,11 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(StrategyOutput::Entities(entities)) => Ok(entities),
|
||||
Ok(StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||
"Chunk-only retrieval is not supported in ingestion".into(),
|
||||
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
|
||||
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
|
||||
"Ingestion retrieval should return entities".into(),
|
||||
)),
|
||||
Err(err) => Err(err),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,9 +17,16 @@ use common::{
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
// Strategy output variants - defined before pipeline module
|
||||
#[derive(Debug)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
}
|
||||
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, StrategyOutput,
|
||||
RetrievalStrategy, RetrievalTuning,
|
||||
};
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
@@ -37,7 +44,7 @@ pub struct RetrievedEntity {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
||||
/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
|
||||
@@ -6,6 +6,8 @@ use std::fmt;
|
||||
pub enum RetrievalStrategy {
|
||||
Initial,
|
||||
Revised,
|
||||
RelationshipSuggestion,
|
||||
Ingestion,
|
||||
}
|
||||
|
||||
impl Default for RetrievalStrategy {
|
||||
@@ -21,6 +23,8 @@ impl std::str::FromStr for RetrievalStrategy {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"initial" => Ok(Self::Initial),
|
||||
"revised" => Ok(Self::Revised),
|
||||
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
||||
"ingestion" => Ok(Self::Ingestion),
|
||||
other => Err(format!("unknown retrieval strategy '{other}'")),
|
||||
}
|
||||
}
|
||||
@@ -31,6 +35,8 @@ impl fmt::Display for RetrievalStrategy {
|
||||
let label = match self {
|
||||
RetrievalStrategy::Initial => "initial",
|
||||
RetrievalStrategy::Revised => "revised",
|
||||
RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion",
|
||||
RetrievalStrategy::Ingestion => "ingestion",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
@@ -109,6 +115,21 @@ impl RetrievalConfig {
|
||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||
Self { strategy, tuning }
|
||||
}
|
||||
|
||||
/// Create config for chat retrieval with strategy selection support
|
||||
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
|
||||
Self::with_strategy(strategy)
|
||||
}
|
||||
|
||||
/// Create config for relationship suggestion (entity-only retrieval)
|
||||
pub fn for_relationship_suggestion() -> Self {
|
||||
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
|
||||
}
|
||||
|
||||
/// Create config for ingestion pipeline (entity-only retrieval)
|
||||
pub fn for_ingestion() -> Self {
|
||||
Self::with_strategy(RetrievalStrategy::Ingestion)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
|
||||
@@ -9,7 +9,7 @@ pub use diagnostics::{
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
@@ -17,52 +17,15 @@ use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use stages::PipelineContext;
|
||||
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
|
||||
use strategies::{
|
||||
IngestionDriver, InitialStrategyDriver, RelationshipSuggestionDriver, RevisedStrategyDriver,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
}
|
||||
// Export StrategyOutput publicly from this module
|
||||
// (it's defined in lib.rs but we re-export it here)
|
||||
|
||||
impl StrategyOutput {
|
||||
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
// Stage type enum
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum StageKind {
|
||||
Embed,
|
||||
CollectCandidates,
|
||||
@@ -72,48 +35,80 @@ pub enum StageKind {
|
||||
Assemble,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub embed_ms: u128,
|
||||
pub collect_candidates_ms: u128,
|
||||
pub graph_expansion_ms: u128,
|
||||
pub chunk_attach_ms: u128,
|
||||
pub rerank_ms: u128,
|
||||
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
let elapsed = duration.as_millis() as u128;
|
||||
match kind {
|
||||
StageKind::Embed => self.embed_ms += elapsed,
|
||||
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
|
||||
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
|
||||
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
|
||||
StageKind::Rerank => self.rerank_ms += elapsed,
|
||||
StageKind::Assemble => self.assemble_ms += elapsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pipeline stage trait
|
||||
#[async_trait]
|
||||
pub trait PipelineStage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
|
||||
// Type alias for boxed stages
|
||||
pub type BoxedStage = Box<dyn PipelineStage>;
|
||||
|
||||
pub trait StrategyDriver {
|
||||
// Strategy driver trait
|
||||
#[async_trait]
|
||||
pub trait StrategyDriver: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy;
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn override_tuning(&self, _config: &mut RetrievalConfig) {}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||
}
|
||||
|
||||
// Pipeline stage timings tracker
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct PipelineStageTimings {
|
||||
timings: Vec<(StageKind, Duration)>,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
self.timings.push((kind, duration));
|
||||
}
|
||||
|
||||
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
|
||||
self.timings
|
||||
}
|
||||
|
||||
// Helper methods to get duration for each stage type (for backward compatibility)
|
||||
fn get_stage_ms(&self, kind: StageKind) -> u128 {
|
||||
self.timings
|
||||
.iter()
|
||||
.find(|(k, _)| *k == kind)
|
||||
.map(|(_, d)| d.as_millis())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn embed_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Embed)
|
||||
}
|
||||
|
||||
pub fn collect_candidates_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::CollectCandidates)
|
||||
}
|
||||
|
||||
pub fn graph_expansion_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::GraphExpansion)
|
||||
}
|
||||
|
||||
pub fn chunk_attach_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::ChunkAttach)
|
||||
}
|
||||
|
||||
pub fn rerank_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Rerank)
|
||||
}
|
||||
|
||||
pub fn assemble_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Assemble)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
@@ -131,40 +126,76 @@ pub async fn run_pipeline(
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
strategy = %config.strategy,
|
||||
"Starting retrieval pipeline"
|
||||
);
|
||||
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
match config.strategy {
|
||||
RetrievalStrategy::Initial => {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Revised => {
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
@@ -176,39 +207,79 @@ pub async fn run_pipeline_with_embedding(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
match config.strategy {
|
||||
RetrievalStrategy::Initial => {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Revised => {
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
// Note: The metrics/diagnostics variants would follow the same pattern,
|
||||
// but for brevity I'm only updating the main ones used by callers.
|
||||
// If metrics/diagnostics are needed for non-chat strategies, they should be updated too.
|
||||
// For now, I'll update them to support at least Initial/Revised as before.
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
@@ -218,45 +289,52 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
});
|
||||
match config.strategy {
|
||||
RetrievalStrategy::Initial => {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
RetrievalStrategy::Revised => {
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
|
||||
_ => Err(AppError::InternalError(
|
||||
"Metrics not supported for this strategy".into(),
|
||||
)),
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
@@ -268,45 +346,51 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
return Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
});
|
||||
match config.strategy {
|
||||
RetrievalStrategy::Initial => {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Entities(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
RetrievalStrategy::Revised => {
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
_ => Err(AppError::InternalError(
|
||||
"Diagnostics not supported for this strategy".into(),
|
||||
)),
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
@@ -338,11 +422,10 @@ async fn execute_strategy<D: StrategyDriver>(
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
mut config: RetrievalConfig,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
driver.override_tuning(&mut config);
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::{
|
||||
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
|
||||
RerankStage,
|
||||
},
|
||||
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
|
||||
BoxedStage, StrategyDriver,
|
||||
};
|
||||
use crate::{RetrievedChunk, RetrievedEntity};
|
||||
use common::error::AppError;
|
||||
@@ -20,10 +20,6 @@ impl InitialStrategyDriver {
|
||||
impl StrategyDriver for InitialStrategyDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy {
|
||||
RetrievalStrategy::Initial
|
||||
}
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
@@ -51,10 +47,6 @@ impl RevisedStrategyDriver {
|
||||
impl StrategyDriver for RevisedStrategyDriver {
|
||||
type Output = Vec<RetrievedChunk>;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy {
|
||||
RetrievalStrategy::Revised
|
||||
}
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
@@ -67,9 +59,58 @@ impl StrategyDriver for RevisedStrategyDriver {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
}
|
||||
|
||||
fn override_tuning(&self, config: &mut RetrievalConfig) {
|
||||
config.tuning.entity_vector_take = 0;
|
||||
config.tuning.entity_fts_take = 0;
|
||||
pub struct RelationshipSuggestionDriver;
|
||||
|
||||
impl RelationshipSuggestionDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for RelationshipSuggestionDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
// Skip ChunkAttachStage
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IngestionDriver;
|
||||
|
||||
impl IngestionDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for IngestionDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
// Skip ChunkAttachStage
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user