mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-28 10:29:30 +02:00
clippy: adhere to pedantic clippy, uniform test error handling
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
use crate::scoring::FusionWeights;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
/// Primary hybrid chunk retrieval for search/chat (formerly Revised)
|
||||
#[default]
|
||||
Default,
|
||||
/// Entity retrieval for suggesting relationships when creating manual entities
|
||||
RelationshipSuggestion,
|
||||
@@ -29,12 +30,6 @@ pub enum SearchTarget {
|
||||
Both,
|
||||
}
|
||||
|
||||
impl Default for RetrievalStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Default
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RetrievalStrategy {
|
||||
type Err = String;
|
||||
|
||||
@@ -70,6 +65,91 @@ impl fmt::Display for RetrievalStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-variant flag that serializes as a bool for backward compatibility.
///
/// `Disabled` corresponds to `false` and `Enabled` to `true`; the derived
/// [`Default`] is `Disabled`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BoolFlag {
    /// The flag is off (`false`).
    #[default]
    Disabled,
    /// The flag is on (`true`).
    Enabled,
}

impl BoolFlag {
    /// Returns `true` for [`BoolFlag::Enabled`] and `false` for [`BoolFlag::Disabled`].
    ///
    /// Marked `#[must_use]` because the call has no side effects; discarding the
    /// result is always a mistake.
    #[must_use]
    pub const fn as_bool(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for BoolFlag {
    /// Maps `true` to [`BoolFlag::Enabled`] and `false` to [`BoolFlag::Disabled`].
    fn from(value: bool) -> Self {
        if value {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
|
||||
|
||||
impl Serialize for BoolFlag {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serializer.serialize_bool(self.as_bool())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for BoolFlag {
|
||||
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||
bool::deserialize(deserializer).map(|b| {
|
||||
if b {
|
||||
BoolFlag::Enabled
|
||||
} else {
|
||||
BoolFlag::Disabled
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuningFlags {
|
||||
pub rerank_scores_only: BoolFlag,
|
||||
pub normalize_vector_scores: BoolFlag,
|
||||
pub normalize_fts_scores: BoolFlag,
|
||||
pub chunk_rrf_use_vector: BoolFlag,
|
||||
pub chunk_rrf_use_fts: BoolFlag,
|
||||
}
|
||||
|
||||
impl RetrievalTuningFlags {
|
||||
pub const fn rerank_scores_only(&self) -> bool {
|
||||
self.rerank_scores_only.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_vector_scores(&self) -> bool {
|
||||
self.normalize_vector_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_fts_scores(&self) -> bool {
|
||||
self.normalize_fts_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_vector(&self) -> bool {
|
||||
self.chunk_rrf_use_vector.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_fts(&self) -> bool {
|
||||
self.chunk_rrf_use_fts.as_bool()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuningFlags {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
rerank_scores_only: BoolFlag::Disabled,
|
||||
normalize_vector_scores: BoolFlag::Disabled,
|
||||
normalize_fts_scores: BoolFlag::Enabled,
|
||||
chunk_rrf_use_vector: BoolFlag::Enabled,
|
||||
chunk_rrf_use_fts: BoolFlag::Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
@@ -89,15 +169,11 @@ pub struct RetrievalTuning {
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
pub rerank_blend_weight: f32,
|
||||
pub rerank_scores_only: bool,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
pub rerank_keep_top: usize,
|
||||
pub chunk_result_cap: usize,
|
||||
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
||||
pub fusion_weights: Option<FusionWeights>,
|
||||
/// Normalize vector similarity scores before fusion (default: true)
|
||||
pub normalize_vector_scores: bool,
|
||||
/// Normalize FTS (BM25) scores before fusion (default: true)
|
||||
pub normalize_fts_scores: bool,
|
||||
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
||||
#[serde(default = "default_chunk_rrf_k")]
|
||||
pub chunk_rrf_k: f32,
|
||||
@@ -107,12 +183,6 @@ pub struct RetrievalTuning {
|
||||
/// Weight applied to chunk FTS ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_fts_weight")]
|
||||
pub chunk_rrf_fts_weight: f32,
|
||||
/// Whether to include vector rankings in RRF.
|
||||
#[serde(default = "default_chunk_rrf_use_vector")]
|
||||
pub chunk_rrf_use_vector: bool,
|
||||
/// Whether to include chunk FTS rankings in RRF.
|
||||
#[serde(default = "default_chunk_rrf_use_fts")]
|
||||
pub chunk_rrf_use_fts: bool,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
@@ -134,26 +204,19 @@ impl Default for RetrievalTuning {
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
rerank_scores_only: false,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
rerank_keep_top: 8,
|
||||
chunk_result_cap: 5,
|
||||
fusion_weights: None,
|
||||
// Vector scores (cosine similarity) are already in [0,1] range
|
||||
// Normalization only helps when there's significant variation
|
||||
normalize_vector_scores: false,
|
||||
// FTS scores (BM25) are unbounded, normalization helps more
|
||||
normalize_fts_scores: true,
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
||||
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
|
||||
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
@@ -211,16 +274,6 @@ impl RetrievalConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::default(),
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
60.0
|
||||
}
|
||||
@@ -233,10 +286,4 @@ const fn default_chunk_rrf_fts_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_use_vector() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_use_fts() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct PipelineDiagnostics {
|
||||
pub struct Diagnostics {
|
||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
pub assemble: Option<AssembleStats>,
|
||||
|
||||
@@ -3,10 +3,11 @@ mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning, SearchTarget};
|
||||
pub use config::{
|
||||
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||
@@ -37,13 +38,13 @@ pub enum StageKind {
|
||||
|
||||
// Pipeline stage trait
|
||||
#[async_trait]
|
||||
pub trait PipelineStage: Send + Sync {
|
||||
pub trait Stage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
// Type alias for boxed stages
|
||||
pub type BoxedStage = Box<dyn PipelineStage>;
|
||||
pub type BoxedStage = Box<dyn Stage>;
|
||||
|
||||
// Strategy driver trait
|
||||
#[async_trait]
|
||||
@@ -51,16 +52,16 @@ pub trait StrategyDriver: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
|
||||
}
|
||||
|
||||
// Pipeline stage timings tracker
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub struct StageTimings {
|
||||
timings: Vec<(StageKind, Duration)>,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
impl StageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
self.timings.push((kind, duration));
|
||||
}
|
||||
@@ -74,8 +75,7 @@ impl PipelineStageTimings {
|
||||
self.timings
|
||||
.iter()
|
||||
.find(|(k, _)| *k == kind)
|
||||
.map(|(_, d)| d.as_millis())
|
||||
.unwrap_or(0)
|
||||
.map_or(0, |(_, d)| d.as_millis())
|
||||
}
|
||||
|
||||
pub fn embed_ms(&self) -> u128 {
|
||||
@@ -103,228 +103,100 @@ impl PipelineStageTimings {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub struct RunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = params.input_text.chars().count();
|
||||
let input_preview: String = params.input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
user_id = %params.user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
strategy = %config.strategy,
|
||||
strategy = %params.config.strategy,
|
||||
"Starting retrieval pipeline"
|
||||
);
|
||||
|
||||
match config.strategy {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let search_target = config.search_target;
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
match config.strategy {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let search_target = config.search_target;
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The metrics/diagnostics variants would follow the same pattern,
|
||||
// but for brevity I'm only updating the main ones used by callers.
|
||||
// If metrics/diagnostics are needed for non-chat strategies, they should be updated too.
|
||||
// For now, they handle only the Default strategy (formerly Revised).
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
match config.strategy {
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
|
||||
_ => Err(AppError::InternalError(
|
||||
"Metrics not supported for this strategy".into(),
|
||||
)),
|
||||
@@ -332,32 +204,16 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
match config.strategy {
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
@@ -391,38 +247,25 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
/// Bundle of inputs for one retrieval run: borrowed clients and text plus the
/// owned config and optional reranker. `PipelineContext::new` consumes it.
pub struct StrategyParams<'a> {
|
||||
/// Database client; stored by reference in the pipeline context.
pub db_client: &'a SurrealDbClient,
|
||||
/// OpenAI client; stored by reference in the pipeline context.
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
/// Optional embedding provider; when present, the embed stage uses it to embed the input text.
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
/// Query text; copied into an owned `String` when the context is built.
pub input_text: &'a str,
|
||||
/// User identifier; copied into an owned `String` when the context is built.
pub user_id: &'a str,
|
||||
/// Strategy, tuning and search target for this run.
pub config: RetrievalConfig,
|
||||
/// Optional reranker lease, moved into the pipeline context.
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
None => PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
Some(embedding) => PipelineContext::with_embedding(params, embedding),
|
||||
None => PipelineContext::new(params),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
@@ -432,7 +275,7 @@ async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
@@ -445,9 +288,9 @@ async fn run_with_driver<D: StrategyDriver>(
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx)?;
|
||||
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
|
||||
@@ -27,9 +27,9 @@ use super::{
|
||||
config::{RetrievalConfig, RetrievalTuning},
|
||||
diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
Diagnostics,
|
||||
},
|
||||
PipelineStage, PipelineStageTimings, StageKind,
|
||||
StageTimings, Stage, StageKind, StrategyParams,
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
@@ -45,76 +45,51 @@ pub struct PipelineContext<'a> {
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub revised_chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
pub chunk_results: Vec<RetrievedChunk>,
|
||||
stage_timings: PipelineStageTimings,
|
||||
stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
pub fn new(params: StrategyParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
config: params.config,
|
||||
query_embedding: None,
|
||||
entity_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
revised_chunk_values: Vec::new(),
|
||||
reranker,
|
||||
reranker: params.reranker,
|
||||
diagnostics: None,
|
||||
entity_results: Vec::new(),
|
||||
chunk_results: Vec::new(),
|
||||
stage_timings: PipelineStageTimings::default(),
|
||||
stage_timings: StageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
let mut ctx = Self::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec<f32>) -> Self {
|
||||
let mut ctx = Self::new(params);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
AppError::InternalError(
|
||||
Box::new(AppError::InternalError(
|
||||
"query embedding missing before candidate collection".to_string(),
|
||||
)
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(PipelineDiagnostics::default());
|
||||
self.diagnostics = Some(Diagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,11 +115,11 @@ impl<'a> PipelineContext<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
|
||||
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||
pub fn take_stage_timings(&mut self) -> StageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
|
||||
@@ -165,7 +140,7 @@ impl<'a> PipelineContext<'a> {
|
||||
pub struct EmbedStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for EmbedStage {
|
||||
impl Stage for EmbedStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Embed
|
||||
}
|
||||
@@ -179,7 +154,7 @@ impl PipelineStage for EmbedStage {
|
||||
pub struct CollectCandidatesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for CollectCandidatesStage {
|
||||
impl Stage for CollectCandidatesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
@@ -193,7 +168,7 @@ impl PipelineStage for CollectCandidatesStage {
|
||||
pub struct GraphExpansionStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for GraphExpansionStage {
|
||||
impl Stage for GraphExpansionStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::GraphExpansion
|
||||
}
|
||||
@@ -207,7 +182,7 @@ impl PipelineStage for GraphExpansionStage {
|
||||
pub struct RerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for RerankStage {
|
||||
impl Stage for RerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
@@ -221,7 +196,7 @@ impl PipelineStage for RerankStage {
|
||||
pub struct AssembleEntitiesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for AssembleEntitiesStage {
|
||||
impl Stage for AssembleEntitiesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
@@ -235,7 +210,7 @@ impl PipelineStage for AssembleEntitiesStage {
|
||||
pub struct ChunkVectorStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkVectorStage {
|
||||
impl Stage for ChunkVectorStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
@@ -249,7 +224,7 @@ impl PipelineStage for ChunkVectorStage {
|
||||
pub struct ChunkRerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkRerankStage {
|
||||
impl Stage for ChunkRerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
@@ -263,7 +238,7 @@ impl PipelineStage for ChunkRerankStage {
|
||||
pub struct ChunkAssembleStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkAssembleStage {
|
||||
impl Stage for ChunkAssembleStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
@@ -283,8 +258,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await.map_err(|e| {
|
||||
AppError::InternalError(format!(
|
||||
"Failed to generate embedding with provider: {}",
|
||||
e
|
||||
"Failed to generate embedding with provider: {e}",
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
@@ -299,7 +273,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let weights = FusionWeights::default();
|
||||
@@ -487,11 +461,11 @@ pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting vector chunk candidates for revised strategy");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let fts_take = tuning.chunk_fts_take;
|
||||
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty();
|
||||
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
|
||||
|
||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||
TextChunk::vector_search(
|
||||
@@ -532,8 +506,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
||||
k: tuning.chunk_rrf_k,
|
||||
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||
fts_weight,
|
||||
use_vector: tuning.chunk_rrf_use_vector,
|
||||
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0,
|
||||
use_vector: tuning.flags.chunk_rrf_use_vector(),
|
||||
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
|
||||
};
|
||||
|
||||
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||
@@ -715,7 +689,7 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let mut per_entity_count = 0;
|
||||
for candidate in candidates.iter() {
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.inspected_candidates += 1;
|
||||
trace.inspected_candidates = trace.inspected_candidates.saturating_add(1);
|
||||
}
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
@@ -723,17 +697,17 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
chunks_skipped_due_budget += 1;
|
||||
chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1);
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.skipped_due_budget += 1;
|
||||
trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
tokens_spent += estimated_tokens;
|
||||
per_entity_count += 1;
|
||||
chunks_selected += 1;
|
||||
tokens_spent = tokens_spent.saturating_add(estimated_tokens);
|
||||
per_entity_count = per_entity_count.saturating_add(1);
|
||||
chunks_selected = chunks_selected.saturating_add(1);
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
@@ -780,14 +754,14 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items
|
||||
.iter()
|
||||
.take(SCORE_SAMPLE_LIMIT)
|
||||
.map(|item| extractor(item))
|
||||
.map(extractor)
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -912,7 +886,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = ctx.config.tuning.rerank_scores_only;
|
||||
let use_only = ctx.config.tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
@@ -942,11 +916,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
ctx.filtered_entities = reranked;
|
||||
let keep_top = ctx.config.tuning.rerank_keep_top;
|
||||
@@ -970,7 +940,7 @@ fn apply_chunk_rerank_results(
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = tuning.rerank_scores_only;
|
||||
let use_only = tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
@@ -1001,11 +971,7 @@ fn apply_chunk_rerank_results(
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
let keep_top = tuning.rerank_keep_top;
|
||||
if keep_top > 0 && reranked.len() > keep_top {
|
||||
@@ -1017,7 +983,7 @@ fn apply_chunk_rerank_results(
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
chars.checked_div(avg_chars_per_token).map_or(1, |v| v.max(1))
|
||||
}
|
||||
|
||||
fn rank_chunks_by_combined_score(
|
||||
@@ -1053,13 +1019,20 @@ fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
||||
return 0.0;
|
||||
}
|
||||
let lower = haystack.to_ascii_lowercase();
|
||||
let mut matches = 0usize;
|
||||
let mut matches: u32 = 0;
|
||||
for term in terms {
|
||||
if lower.contains(term) {
|
||||
matches += 1;
|
||||
matches = matches.saturating_add(1);
|
||||
}
|
||||
}
|
||||
(matches as f32) / (terms.len() as f32)
|
||||
let total = u32::try_from(terms.len()).unwrap_or(u32::MAX);
|
||||
if total == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let num = matches.min(total);
|
||||
let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX);
|
||||
let den_f32 = u16::try_from(total).map(f32::from).unwrap_or(f32::MAX);
|
||||
num_f32 / den_f32
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -28,7 +28,7 @@ impl StrategyDriver for DefaultStrategyDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
}
|
||||
@@ -55,7 +55,7 @@ impl StrategyDriver for RelationshipSuggestionDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
@@ -82,7 +82,7 @@ impl StrategyDriver for IngestionDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
@@ -134,7 +134,7 @@ impl StrategyDriver for SearchStrategyDriver {
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
let chunks = match self.target {
|
||||
SearchTarget::EntitiesOnly => Vec::new(),
|
||||
_ => ctx.take_chunk_results(),
|
||||
|
||||
Reference in New Issue
Block a user