clippy: adhere to pedantic clippy, uniform test error handling

This commit is contained in:
Per Stark
2026-05-26 11:43:45 +02:00
parent 6a5d631287
commit 000852c94c
68 changed files with 2468 additions and 2547 deletions
+4 -6
View File
@@ -118,18 +118,16 @@ pub fn create_chat_request(
}
pub fn process_llm_response(
response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, AppError> {
response: &CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, Box<AppError>> {
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.ok_or(AppError::LLMParsing(
"No content found in LLM response".into(),
))
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
.and_then(|content| {
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
})
})
}
+16 -26
View File
@@ -20,7 +20,6 @@ use common::storage::{
/// * `entity_id` - ID of the entity to find neighbors for
/// * `user_id` - User ID for access control
/// * `limit` - Maximum number of neighbors to return
pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient,
entity_id: &str,
@@ -113,25 +112,23 @@ pub async fn find_entities_by_relationship_by_id(
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use uuid::Uuid;
#[tokio::test]
async fn test_find_entities_by_relationship_by_id() {
// Setup in-memory database for testing
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create some test entities
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
// Create the central entity we'll query relationships for
let central_entity = KnowledgeEntity::new(
"central_source".to_string(),
"Central Entity".to_string(),
@@ -141,7 +138,6 @@ mod tests {
user_id.clone(),
);
// Create related entities
let related_entity1 = KnowledgeEntity::new(
"related_source1".to_string(),
"Related Entity 1".to_string(),
@@ -160,7 +156,6 @@ mod tests {
user_id.clone(),
);
// Create an unrelated entity
let unrelated_entity = KnowledgeEntity::new(
"unrelated_source".to_string(),
"Unrelated Entity".to_string(),
@@ -170,32 +165,29 @@ mod tests {
user_id.clone(),
);
// Store all entities
let central_entity = db
.store_item(central_entity.clone())
.await
.expect("Failed to store central entity")
.unwrap();
.with_context(|| "Failed to store central entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
let related_entity1 = db
.store_item(related_entity1.clone())
.await
.expect("Failed to store related entity 1")
.unwrap();
.with_context(|| "Failed to store related entity 1".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
let related_entity2 = db
.store_item(related_entity2.clone())
.await
.expect("Failed to store related entity 2")
.unwrap();
.with_context(|| "Failed to store related entity 2".to_string())?
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
let _unrelated_entity = db
.store_item(unrelated_entity.clone())
.await
.expect("Failed to store unrelated entity")
.unwrap();
.with_context(|| "Failed to store unrelated entity".to_string())?
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
// Create relationships
let source_id = "relationship_source".to_string();
// Create relationship 1: central -> related1
let relationship1 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity1.id.clone(),
@@ -204,7 +196,6 @@ mod tests {
"references".to_string(),
);
// Create relationship 2: central -> related2
let relationship2 = KnowledgeRelationship::new(
central_entity.id.clone(),
related_entity2.id.clone(),
@@ -213,26 +204,25 @@ mod tests {
"contains".to_string(),
);
// Store relationships
relationship1
.store_relationship(&db)
.await
.expect("Failed to store relationship 1");
.with_context(|| "Failed to store relationship 1".to_string())?;
relationship2
.store_relationship(&db)
.await
.expect("Failed to store relationship 2");
.with_context(|| "Failed to store relationship 2".to_string())?;
// Test finding entities related to the central entity
let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await
.expect("Failed to find entities by relationship");
.with_context(|| "Failed to find entities by relationship".to_string())?;
// Check that we found relationships
assert!(
related_entities.len() >= 2,
"Should find related entities in both directions"
);
Ok(())
}
}
+80 -100
View File
@@ -42,10 +42,14 @@ impl SearchResult {
}
pub use pipeline::{
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, SearchTarget,
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
};
// Backward-compatible type aliases for external consumers
/// Former name of [`Diagnostics`]; kept as an alias so code written against the old name still compiles.
pub type PipelineDiagnostics = Diagnostics;
/// Former name of [`StageTimings`]; kept as an alias so code written against the old name still compiles.
pub type PipelineStageTimings = StageTimings;
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
@@ -61,7 +65,7 @@ pub struct RetrievedEntity {
pub chunks: Vec<RetrievedChunk>,
}
/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
db_client: &SurrealDbClient,
@@ -72,7 +76,7 @@ pub async fn retrieve_entities(
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
pipeline::run_pipeline(
let params = pipeline::StrategyParams {
db_client,
openai_client,
embedding_provider,
@@ -80,17 +84,16 @@ pub async fn retrieve_entities(
user_id,
config,
reranker,
)
.await
};
pipeline::execute(params).await
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::{self};
use async_openai::Client;
use common::storage::indexes::ensure_runtime_indexes;
use common::storage::types::text_chunk::TextChunk;
use pipeline::{RetrievalConfig, RetrievalStrategy};
use common::storage::indexes::ensure_runtime;
use uuid::Uuid;
fn test_embedding() -> Vec<f32> {
@@ -105,27 +108,21 @@ mod tests {
vec![0.2, 0.8, 0.0]
}
async fn setup_test_db() -> SurrealDbClient {
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db.apply_migrations().await?;
ensure_runtime_indexes(&db, 3)
.await
.expect("failed to build runtime indexes");
ensure_runtime(&db, 3).await?;
db
Ok(db)
}
#[tokio::test]
async fn test_default_strategy_retrieves_chunks() {
let db = setup_test_db().await;
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "test_user";
let chunk = TextChunk::new(
"source_1".into(),
@@ -133,39 +130,38 @@ mod tests {
user_id.into(),
);
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
.await
.expect("Failed to store chunk");
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"Rust concurrency async tasks",
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
None,
)
.await
.expect("Default strategy retrieval failed");
config: RetrievalConfig::default(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk results, got {:?}", other),
other => anyhow::bail!("expected chunk results, got {other:?}"),
};
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
assert!(
chunks[0].chunk.chunk.contains("Tokio"),
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
"Expected chunk about Tokio"
);
Ok(())
}
#[tokio::test]
async fn test_default_strategy_returns_chunks_from_multiple_sources() {
let db = setup_test_db().await;
async fn test_default_strategy_returns_chunks_from_multiple_sources(
) -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "multi_source_user";
let primary_chunk = TextChunk::new(
@@ -179,30 +175,25 @@ mod tests {
user_id.into(),
);
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
.await
.expect("Failed to store primary chunk");
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db)
.await
.expect("Failed to store secondary chunk");
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?;
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"Rust concurrency async tasks",
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
None,
)
.await
.expect("Default strategy retrieval failed");
config: RetrievalConfig::default(),
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk results, got {:?}", other),
other => anyhow::bail!("expected chunk results, got {other:?}"),
};
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
@@ -216,11 +207,12 @@ mod tests {
.any(|c| c.chunk.source_id == "secondary_source"),
"Should include secondary source chunk"
);
Ok(())
}
#[tokio::test]
async fn test_revised_strategy_returns_chunks() {
let db = setup_test_db().await;
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "chunk_user";
let chunk_one = TextChunk::new(
"src_alpha".into(),
@@ -233,31 +225,26 @@ mod tests {
user_id.into(),
);
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db)
.await
.expect("Failed to store chunk one");
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db)
.await
.expect("Failed to store chunk two");
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"tokio runtime worker behavior",
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "tokio runtime worker behavior",
user_id,
config,
None,
)
.await
.expect("Revised retrieval failed");
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk output, got {:?}", other),
other => anyhow::bail!("expected chunk results, got {other:?}"),
};
assert!(
@@ -270,11 +257,12 @@ mod tests {
.any(|entry| entry.chunk.chunk.contains("Tokio")),
"Chunk results should contain relevant snippets"
);
Ok(())
}
#[tokio::test]
async fn test_search_strategy_returns_search_result() {
let db = setup_test_db().await;
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "search_user";
let chunk = TextChunk::new(
"search_src".into(),
@@ -282,33 +270,24 @@ mod tests {
user_id.into(),
);
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
.await
.expect("Failed to store chunk");
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"async rust programming",
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: None,
input_text: "async rust programming",
user_id,
config,
None,
)
.await
.expect("Search strategy retrieval failed");
reranker: None,
};
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.await?;
assert!(
matches!(results, StrategyOutput::Search(_)),
"expected Search output, got {:?}",
results
);
let search_result = match results {
StrategyOutput::Search(sr) => sr,
_ => unreachable!(),
let StrategyOutput::Search(search_result) = results else {
anyhow::bail!("expected Search output");
};
// Should return chunks (entities may be empty if none stored)
@@ -323,5 +302,6 @@ mod tests {
.any(|c| c.chunk.chunk.contains("Tokio")),
"Search results should contain relevant chunks"
);
Ok(())
}
}
+91 -44
View File
@@ -1,12 +1,13 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use crate::scoring::FusionWeights;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
/// Primary hybrid chunk retrieval for search/chat (formerly Revised)
#[default]
Default,
/// Entity retrieval for suggesting relationships when creating manual entities
RelationshipSuggestion,
@@ -29,12 +30,6 @@ pub enum SearchTarget {
Both,
}
impl Default for RetrievalStrategy {
fn default() -> Self {
Self::Default
}
}
impl std::str::FromStr for RetrievalStrategy {
type Err = String;
@@ -70,6 +65,91 @@ impl fmt::Display for RetrievalStrategy {
}
}
/// Two-variant flag that serializes as a bool for backward compatibility.
///
/// The default value is [`BoolFlag::Disabled`]. Convert from a plain `bool`
/// with [`From`] and back with [`BoolFlag::as_bool`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BoolFlag {
    /// The flag is off; corresponds to `false`.
    #[default]
    Disabled,
    /// The flag is on; corresponds to `true`.
    Enabled,
}

impl BoolFlag {
    /// Returns `true` for [`BoolFlag::Enabled`] and `false` for
    /// [`BoolFlag::Disabled`].
    // `#[must_use]`: a pure getter whose result is the only reason to call it
    // (this is what `clippy::must_use_candidate` asks for under `pedantic`).
    #[must_use]
    pub const fn as_bool(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for BoolFlag {
    /// Maps `true` to [`BoolFlag::Enabled`] and `false` to
    /// [`BoolFlag::Disabled`].
    fn from(value: bool) -> Self {
        if value {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
impl Serialize for BoolFlag {
    /// Writes the flag as a plain boolean: `true` when enabled, `false`
    /// otherwise.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let enabled = self.as_bool();
        serializer.serialize_bool(enabled)
    }
}
impl<'de> Deserialize<'de> for BoolFlag {
    /// Reads a plain boolean and converts it into a flag.
    ///
    /// The bool-to-flag mapping is delegated to the `From<bool>` impl so the
    /// conversion rule lives in exactly one place.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `bool::deserialize` reports for the input.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        bool::deserialize(deserializer).map(Self::from)
    }
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RetrievalTuningFlags {
pub rerank_scores_only: BoolFlag,
pub normalize_vector_scores: BoolFlag,
pub normalize_fts_scores: BoolFlag,
pub chunk_rrf_use_vector: BoolFlag,
pub chunk_rrf_use_fts: BoolFlag,
}
/// Plain-`bool` accessors, one per flag field.
///
/// Each accessor is `#[must_use]`: they are pure getters, so a discarded
/// result is always a mistake (and `clippy::must_use_candidate` flags public
/// getters without it under the `pedantic` group).
impl RetrievalTuningFlags {
    /// Returns `rerank_scores_only` as a `bool`.
    #[must_use]
    pub const fn rerank_scores_only(&self) -> bool {
        self.rerank_scores_only.as_bool()
    }

    /// Returns `normalize_vector_scores` as a `bool`.
    #[must_use]
    pub const fn normalize_vector_scores(&self) -> bool {
        self.normalize_vector_scores.as_bool()
    }

    /// Returns `normalize_fts_scores` as a `bool`.
    #[must_use]
    pub const fn normalize_fts_scores(&self) -> bool {
        self.normalize_fts_scores.as_bool()
    }

    /// Returns `chunk_rrf_use_vector` as a `bool`.
    #[must_use]
    pub const fn chunk_rrf_use_vector(&self) -> bool {
        self.chunk_rrf_use_vector.as_bool()
    }

    /// Returns `chunk_rrf_use_fts` as a `bool`.
    #[must_use]
    pub const fn chunk_rrf_use_fts(&self) -> bool {
        self.chunk_rrf_use_fts.as_bool()
    }
}
impl Default for RetrievalTuningFlags {
    /// Builds the stock flag set: reranker scores-only and vector-score
    /// normalization are off; FTS normalization and both RRF inputs are on.
    fn default() -> Self {
        let (on, off) = (BoolFlag::Enabled, BoolFlag::Disabled);

        Self {
            rerank_scores_only: off,
            // The notes that previously accompanied these defaults: vector
            // (cosine similarity) scores are already in a bounded range, so
            // normalizing them only helps when there is significant variation.
            normalize_vector_scores: off,
            // ...whereas FTS (BM25) scores are unbounded, so normalization
            // helps more there.
            normalize_fts_scores: on,
            chunk_rrf_use_vector: on,
            chunk_rrf_use_fts: on,
        }
    }
}
/// Tunable parameters that govern each retrieval stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
@@ -89,15 +169,11 @@ pub struct RetrievalTuning {
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
pub rerank_blend_weight: f32,
pub rerank_scores_only: bool,
pub flags: RetrievalTuningFlags,
pub rerank_keep_top: usize,
pub chunk_result_cap: usize,
/// Optional fusion weights for hybrid search. If None, uses default weights.
pub fusion_weights: Option<FusionWeights>,
/// Normalize vector similarity scores before fusion (default: true)
pub normalize_vector_scores: bool,
/// Normalize FTS (BM25) scores before fusion (default: true)
pub normalize_fts_scores: bool,
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
#[serde(default = "default_chunk_rrf_k")]
pub chunk_rrf_k: f32,
@@ -107,12 +183,6 @@ pub struct RetrievalTuning {
/// Weight applied to chunk FTS ranks in RRF.
#[serde(default = "default_chunk_rrf_fts_weight")]
pub chunk_rrf_fts_weight: f32,
/// Whether to include vector rankings in RRF.
#[serde(default = "default_chunk_rrf_use_vector")]
pub chunk_rrf_use_vector: bool,
/// Whether to include chunk FTS rankings in RRF.
#[serde(default = "default_chunk_rrf_use_fts")]
pub chunk_rrf_use_fts: bool,
}
impl Default for RetrievalTuning {
@@ -134,26 +204,19 @@ impl Default for RetrievalTuning {
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65,
rerank_scores_only: false,
flags: RetrievalTuningFlags::default(),
rerank_keep_top: 8,
chunk_result_cap: 5,
fusion_weights: None,
// Vector scores (cosine similarity) are already in [0,1] range
// Normalization only helps when there's significant variation
normalize_vector_scores: false,
// FTS scores (BM25) are unbounded, normalization helps more
normalize_fts_scores: true,
chunk_rrf_k: default_chunk_rrf_k(),
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
}
}
}
/// Wrapper containing tuning plus future flags for per-request overrides.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub struct RetrievalConfig {
pub strategy: RetrievalStrategy,
pub tuning: RetrievalTuning,
@@ -211,16 +274,6 @@ impl RetrievalConfig {
}
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
strategy: RetrievalStrategy::default(),
tuning: RetrievalTuning::default(),
search_target: SearchTarget::default(),
}
}
}
/// Fallback value for `chunk_rrf_k`, used both as the serde default for that
/// field and when building the default tuning.
const fn default_chunk_rrf_k() -> f32 {
    const RRF_K: f32 = 60.0;
    RRF_K
}
@@ -233,10 +286,4 @@ const fn default_chunk_rrf_fts_weight() -> f32 {
1.0
}
const fn default_chunk_rrf_use_vector() -> bool {
true
}
const fn default_chunk_rrf_use_fts() -> bool {
true
}
@@ -2,7 +2,7 @@ use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
#[derive(Debug, Clone, Default, Serialize)]
pub struct PipelineDiagnostics {
pub struct Diagnostics {
pub collect_candidates: Option<CollectCandidatesStats>,
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
pub assemble: Option<AssembleStats>,
+66 -223
View File
@@ -3,10 +3,11 @@ mod diagnostics;
mod stages;
mod strategies;
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning, SearchTarget};
pub use config::{
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
};
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
@@ -37,13 +38,13 @@ pub enum StageKind {
// Pipeline stage trait
#[async_trait]
pub trait PipelineStage: Send + Sync {
pub trait Stage: Send + Sync {
fn kind(&self) -> StageKind;
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
}
// Type alias for boxed stages
pub type BoxedStage = Box<dyn PipelineStage>;
pub type BoxedStage = Box<dyn Stage>;
// Strategy driver trait
#[async_trait]
@@ -51,16 +52,16 @@ pub trait StrategyDriver: Send + Sync {
type Output;
fn stages(&self) -> Vec<BoxedStage>;
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
}
// Pipeline stage timings tracker
#[derive(Debug, Default, Clone)]
pub struct PipelineStageTimings {
pub struct StageTimings {
timings: Vec<(StageKind, Duration)>,
}
impl PipelineStageTimings {
impl StageTimings {
pub fn record(&mut self, kind: StageKind, duration: Duration) {
self.timings.push((kind, duration));
}
@@ -74,8 +75,7 @@ impl PipelineStageTimings {
self.timings
.iter()
.find(|(k, _)| *k == kind)
.map(|(_, d)| d.as_millis())
.unwrap_or(0)
.map_or(0, |(_, d)| d.as_millis())
}
pub fn embed_ms(&self) -> u128 {
@@ -103,228 +103,100 @@ impl PipelineStageTimings {
}
}
pub struct PipelineRunOutput<T> {
pub struct RunOutput<T> {
pub results: T,
pub diagnostics: Option<PipelineDiagnostics>,
pub stage_timings: PipelineStageTimings,
pub diagnostics: Option<Diagnostics>,
pub stage_timings: StageTimings,
}
pub async fn run_pipeline(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
let input_chars = params.input_text.chars().count();
let input_preview: String = params.input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
let preview_len = input_preview_clean.chars().count();
info!(
%user_id,
user_id = %params.user_id,
input_chars,
preview_truncated = input_chars > preview_len,
preview = %input_preview_clean,
strategy = %config.strategy,
strategy = %params.config.strategy,
"Starting retrieval pipeline"
);
match config.strategy {
let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let search_target = config.search_target;
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, None, false).await?;
Ok(StrategyOutput::Search(run.results))
}
}
}
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
params: StrategyParams<'_>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
match config.strategy {
let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Chunks(run.results))
}
RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Search => {
let search_target = config.search_target;
let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(StrategyOutput::Search(run.results))
}
}
}
// Note: the metrics/diagnostics variants below follow the same per-strategy
// pattern as `run_pipeline_with_embedding`, but only `RetrievalStrategy::Default`
// is wired up. The metrics variant returns an `InternalError` for every other
// strategy. If metrics or diagnostics are needed for the remaining strategies,
// add the corresponding match arms there as well.
pub async fn run_pipeline_with_embedding_with_metrics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
params: StrategyParams<'_>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy {
) -> Result<RunOutput<StrategyOutput>, AppError> {
let strategy = params.config.strategy;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(PipelineRunOutput {
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
_ => Err(AppError::InternalError(
"Metrics not supported for this strategy".into(),
)),
@@ -332,32 +204,16 @@ pub async fn run_pipeline_with_embedding_with_metrics(
}
pub async fn run_pipeline_with_embedding_with_diagnostics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
params: StrategyParams<'_>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy {
) -> Result<RunOutput<StrategyOutput>, AppError> {
let strategy = params.config.strategy;
match strategy {
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
true,
)
.await?;
Ok(PipelineRunOutput {
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
Ok(RunOutput {
results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
@@ -391,38 +247,25 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
.collect::<Vec<_>>())
}
/// Inputs for one retrieval run, bundled so the pipeline entry points take a
/// single argument instead of a long positional list.
///
/// The borrowed fields live for `'a`; `config` and `reranker` are owned and are
/// moved into the pipeline context when the run starts.
pub struct StrategyParams<'a> {
/// Database client handed to the pipeline context.
pub db_client: &'a SurrealDbClient,
/// OpenAI client handed to the pipeline context.
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
/// Optional embedding provider; `None` when the caller does not supply one.
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
/// Query text; the pipeline context stores its own owned copy.
pub input_text: &'a str,
/// Id of the requesting user; logged at pipeline start and copied into the context.
pub user_id: &'a str,
/// Retrieval configuration; its `strategy` (and `search_target`) select the driver to run.
pub config: RetrievalConfig,
/// Optional reranker lease, moved into the pipeline context.
pub reranker: Option<RerankerLease>,
}
async fn execute_strategy<D: StrategyDriver>(
driver: D,
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
params: StrategyParams<'_>,
query_embedding: Option<Vec<f32>>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
) -> Result<RunOutput<D::Output>, AppError> {
let ctx = match query_embedding {
Some(embedding) => PipelineContext::with_embedding(
db_client,
openai_client,
embedding_provider,
embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
),
None => PipelineContext::new(
db_client,
openai_client,
embedding_provider,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
),
Some(embedding) => PipelineContext::with_embedding(params, embedding),
None => PipelineContext::new(params),
};
run_with_driver(driver, ctx, capture_diagnostics).await
@@ -432,7 +275,7 @@ async fn run_with_driver<D: StrategyDriver>(
driver: D,
mut ctx: PipelineContext<'_>,
capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
) -> Result<RunOutput<D::Output>, AppError> {
if capture_diagnostics {
ctx.enable_diagnostics();
}
@@ -445,9 +288,9 @@ async fn run_with_driver<D: StrategyDriver>(
let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings();
let results = driver.finalize(&mut ctx)?;
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
Ok(PipelineRunOutput {
Ok(RunOutput {
results,
diagnostics,
stage_timings,
+58 -85
View File
@@ -27,9 +27,9 @@ use super::{
config::{RetrievalConfig, RetrievalTuning},
diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
Diagnostics,
},
PipelineStage, PipelineStageTimings, StageKind,
StageTimings, Stage, StageKind, StrategyParams,
};
pub struct PipelineContext<'a> {
@@ -45,76 +45,51 @@ pub struct PipelineContext<'a> {
pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>,
pub diagnostics: Option<Diagnostics>,
pub entity_results: Vec<RetrievedEntity>,
pub chunk_results: Vec<RetrievedChunk>,
stage_timings: PipelineStageTimings,
stage_timings: StageTimings,
}
impl<'a> PipelineContext<'a> {
pub fn new(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
pub fn new(params: StrategyParams<'a>) -> Self {
Self {
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
db_client: params.db_client,
openai_client: params.openai_client,
embedding_provider: params.embedding_provider,
input_text: params.input_text.to_owned(),
user_id: params.user_id.to_owned(),
config: params.config,
query_embedding: None,
entity_candidates: HashMap::new(),
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
revised_chunk_values: Vec::new(),
reranker,
reranker: params.reranker,
diagnostics: None,
entity_results: Vec::new(),
chunk_results: Vec::new(),
stage_timings: PipelineStageTimings::default(),
stage_timings: StageTimings::default(),
}
}
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
let mut ctx = Self::new(
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
reranker,
);
pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec<f32>) -> Self {
let mut ctx = Self::new(params);
ctx.query_embedding = Some(query_embedding);
ctx
}
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
self.query_embedding.as_ref().ok_or_else(|| {
AppError::InternalError(
Box::new(AppError::InternalError(
"query embedding missing before candidate collection".to_string(),
)
))
})
}
/// Turns on diagnostics collection for this context.
///
/// Installs a default `Diagnostics` value only when none is present, so
/// calling this again does not discard diagnostics already recorded.
pub fn enable_diagnostics(&mut self) {
    if self.diagnostics.is_none() {
        self.diagnostics = Some(Diagnostics::default());
    }
}
@@ -140,11 +115,11 @@ impl<'a> PipelineContext<'a> {
}
}
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
self.diagnostics.take()
}
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
pub fn take_stage_timings(&mut self) -> StageTimings {
std::mem::take(&mut self.stage_timings)
}
@@ -165,7 +140,7 @@ impl<'a> PipelineContext<'a> {
pub struct EmbedStage;
#[async_trait]
impl PipelineStage for EmbedStage {
impl Stage for EmbedStage {
/// Identifies this stage as `StageKind::Embed`.
fn kind(&self) -> StageKind {
StageKind::Embed
}
@@ -179,7 +154,7 @@ impl PipelineStage for EmbedStage {
pub struct CollectCandidatesStage;
#[async_trait]
impl PipelineStage for CollectCandidatesStage {
impl Stage for CollectCandidatesStage {
/// Identifies this stage as `StageKind::CollectCandidates`.
fn kind(&self) -> StageKind {
StageKind::CollectCandidates
}
@@ -193,7 +168,7 @@ impl PipelineStage for CollectCandidatesStage {
pub struct GraphExpansionStage;
#[async_trait]
impl PipelineStage for GraphExpansionStage {
impl Stage for GraphExpansionStage {
/// Identifies this stage as `StageKind::GraphExpansion`.
fn kind(&self) -> StageKind {
StageKind::GraphExpansion
}
@@ -207,7 +182,7 @@ impl PipelineStage for GraphExpansionStage {
pub struct RerankStage;
#[async_trait]
impl PipelineStage for RerankStage {
impl Stage for RerankStage {
/// Identifies this stage as `StageKind::Rerank`.
fn kind(&self) -> StageKind {
StageKind::Rerank
}
@@ -221,7 +196,7 @@ impl PipelineStage for RerankStage {
pub struct AssembleEntitiesStage;
#[async_trait]
impl PipelineStage for AssembleEntitiesStage {
impl Stage for AssembleEntitiesStage {
/// Identifies this stage as `StageKind::Assemble`.
fn kind(&self) -> StageKind {
StageKind::Assemble
}
@@ -235,7 +210,7 @@ impl PipelineStage for AssembleEntitiesStage {
pub struct ChunkVectorStage;
#[async_trait]
impl PipelineStage for ChunkVectorStage {
impl Stage for ChunkVectorStage {
/// Identifies this stage as `StageKind::CollectCandidates`, the same kind
/// that `CollectCandidatesStage` reports.
fn kind(&self) -> StageKind {
StageKind::CollectCandidates
}
@@ -249,7 +224,7 @@ impl PipelineStage for ChunkVectorStage {
pub struct ChunkRerankStage;
#[async_trait]
impl PipelineStage for ChunkRerankStage {
impl Stage for ChunkRerankStage {
/// Identifies this stage as `StageKind::Rerank`, the same kind that
/// `RerankStage` reports.
fn kind(&self) -> StageKind {
StageKind::Rerank
}
@@ -263,7 +238,7 @@ impl PipelineStage for ChunkRerankStage {
pub struct ChunkAssembleStage;
#[async_trait]
impl PipelineStage for ChunkAssembleStage {
impl Stage for ChunkAssembleStage {
/// Identifies this stage as `StageKind::Assemble`, the same kind that
/// `AssembleEntitiesStage` reports.
fn kind(&self) -> StageKind {
StageKind::Assemble
}
@@ -283,8 +258,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let embedding = if let Some(provider) = ctx.embedding_provider {
provider.embed(&ctx.input_text).await.map_err(|e| {
AppError::InternalError(format!(
"Failed to generate embedding with provider: {}",
e
"Failed to generate embedding with provider: {e}",
))
})?
} else {
@@ -299,7 +273,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
#[instrument(level = "trace", skip_all)]
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone();
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
@@ -487,11 +461,11 @@ pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
#[instrument(level = "trace", skip_all)]
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting vector chunk candidates for revised strategy");
let embedding = ctx.ensure_embedding()?.clone();
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning;
let fts_take = tuning.chunk_fts_take;
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty();
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
let (vector_rows, fts_rows) = tokio::try_join!(
TextChunk::vector_search(
@@ -532,8 +506,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight,
use_vector: tuning.chunk_rrf_use_vector,
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0,
use_vector: tuning.flags.chunk_rrf_use_vector(),
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
};
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
@@ -715,7 +689,7 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let mut per_entity_count = 0;
for candidate in candidates.iter() {
if let Some(trace) = entity_trace.as_mut() {
trace.inspected_candidates += 1;
trace.inspected_candidates = trace.inspected_candidates.saturating_add(1);
}
if per_entity_count >= tuning.max_chunks_per_entity {
break;
@@ -723,17 +697,17 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let estimated_tokens =
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
if estimated_tokens > token_budget_remaining {
chunks_skipped_due_budget += 1;
chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1);
if let Some(trace) = entity_trace.as_mut() {
trace.skipped_due_budget += 1;
trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1);
}
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
tokens_spent += estimated_tokens;
per_entity_count += 1;
chunks_selected += 1;
tokens_spent = tokens_spent.saturating_add(estimated_tokens);
per_entity_count = per_entity_count.saturating_add(1);
chunks_selected = chunks_selected.saturating_add(1);
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
@@ -780,14 +754,14 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
const SCORE_SAMPLE_LIMIT: usize = 8;
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where
F: FnMut(&Scored<T>) -> f32,
{
items
.iter()
.take(SCORE_SAMPLE_LIMIT)
.map(|item| extractor(item))
.map(extractor)
.collect()
}
@@ -912,7 +886,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = ctx.config.tuning.rerank_scores_only;
let use_only = ctx.config.tuning.flags.rerank_scores_only();
let blend = if use_only {
1.0
} else {
@@ -942,11 +916,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
}
}
for slot in remaining.into_iter() {
if let Some(candidate) = slot {
reranked.push(candidate);
}
}
reranked.extend(remaining.into_iter().flatten());
ctx.filtered_entities = reranked;
let keep_top = ctx.config.tuning.rerank_keep_top;
@@ -970,7 +940,7 @@ fn apply_chunk_rerank_results(
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = tuning.rerank_scores_only;
let use_only = tuning.flags.rerank_scores_only();
let blend = if use_only {
1.0
} else {
@@ -1001,11 +971,7 @@ fn apply_chunk_rerank_results(
}
}
for slot in remaining.into_iter() {
if let Some(candidate) = slot {
reranked.push(candidate);
}
}
reranked.extend(remaining.into_iter().flatten());
let keep_top = tuning.rerank_keep_top;
if keep_top > 0 && reranked.len() > keep_top {
@@ -1017,7 +983,7 @@ fn apply_chunk_rerank_results(
/// Estimates how many tokens `text` occupies from its character count.
///
/// The count of Unicode scalar values (at least 1) is divided by
/// `avg_chars_per_token`, and the result is never below 1. A zero
/// `avg_chars_per_token` yields 1 instead of panicking on division by zero.
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
    let chars = text.chars().count().max(1);
    chars
        .checked_div(avg_chars_per_token)
        .map_or(1, |tokens| tokens.max(1))
}
fn rank_chunks_by_combined_score(
@@ -1053,13 +1019,20 @@ fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
return 0.0;
}
let lower = haystack.to_ascii_lowercase();
let mut matches = 0usize;
let mut matches: u32 = 0;
for term in terms {
if lower.contains(term) {
matches += 1;
matches = matches.saturating_add(1);
}
}
(matches as f32) / (terms.len() as f32)
let total = u32::try_from(terms.len()).unwrap_or(u32::MAX);
if total == 0 {
return 0.0;
}
let num = matches.min(total);
let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX);
let den_f32 = u16::try_from(total).map(f32::from).unwrap_or(f32::MAX);
num_f32 / den_f32
}
#[derive(Clone)]
@@ -28,7 +28,7 @@ impl StrategyDriver for DefaultStrategyDriver {
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_chunk_results())
}
}
@@ -55,7 +55,7 @@ impl StrategyDriver for RelationshipSuggestionDriver {
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
@@ -82,7 +82,7 @@ impl StrategyDriver for IngestionDriver {
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results())
}
}
@@ -134,7 +134,7 @@ impl StrategyDriver for SearchStrategyDriver {
}
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
let chunks = match self.target {
SearchTarget::EntitiesOnly => Vec::new(),
_ => ctx.take_chunk_results(),
+24 -26
View File
@@ -17,7 +17,7 @@ static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);
fn pick_engine_index(pool_len: usize) -> usize {
let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
n % pool_len
n.checked_rem(pool_len).unwrap_or(0)
}
pub struct RerankerPool {
@@ -28,30 +28,30 @@ pub struct RerankerPool {
impl RerankerPool {
/// Build the pool at startup.
/// `pool_size` controls max parallel reranks.
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
Self::new_with_options(
pool_size,
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn),
)
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
let init_options =
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
Self::new_with_options(pool_size, &init_options)
}
fn new_with_options(
pool_size: usize,
init_options: RerankInitOptions,
) -> Result<Arc<Self>, AppError> {
init_options: &RerankInitOptions,
) -> Result<Arc<Self>, Box<AppError>> {
if pool_size == 0 {
return Err(AppError::Validation(
return Err(Box::new(AppError::Validation(
"RERANKING_POOL_SIZE must be greater than zero".to_string(),
));
)));
}
fs::create_dir_all(&init_options.cache_dir)?;
fs::create_dir_all(&init_options.cache_dir)
.map_err(|e| Box::new(AppError::from(e)))?;
let mut engines = Vec::with_capacity(pool_size);
for x in 0..pool_size {
debug!("Creating reranking engine: {x}");
let model = TextRerank::try_new(init_options.clone())
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(|e| Box::new(AppError::InternalError(e.to_string())))?;
engines.push(Arc::new(Mutex::new(model)));
}
@@ -62,7 +62,7 @@ impl RerankerPool {
}
/// Initialize a pool using application configuration.
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, Box<AppError>> {
if !config.reranking_enabled {
return Ok(None);
}
@@ -70,30 +70,28 @@ impl RerankerPool {
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
let init_options = build_rerank_init_options(config)?;
Self::new_with_options(pool_size, init_options).map(Some)
Self::new_with_options(pool_size, &init_options).map(Some)
}
/// Check out capacity + pick an engine.
/// This returns a lease that can perform rerank().
pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
/// This returns a lease that can perform `rerank()`.
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
// Acquire a permit. This enforces backpressure.
let permit = self
.semaphore
.clone()
let permit = Arc::clone(&self.semaphore)
.acquire_owned()
.await
.expect("semaphore closed");
.ok()?;
// Pick an engine.
// This is naive: just pick based on a simple modulo counter.
// We use an atomic counter to avoid always choosing index 0.
let idx = pick_engine_index(self.engines.len());
let engine = self.engines[idx].clone();
let engine = self.engines.get(idx).map(Arc::clone)?;
RerankerLease {
Some(RerankerLease {
_permit: permit,
engine,
}
})
}
}
@@ -111,7 +109,7 @@ fn is_truthy(value: &str) -> bool {
)
}
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Box<AppError>> {
let mut options = RerankInitOptions::default();
let cache_dir = config
@@ -125,7 +123,7 @@ fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Ap
.join("fastembed")
.join("reranker")
});
fs::create_dir_all(&cache_dir)?;
fs::create_dir_all(&cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
options.cache_dir = cache_dir;
let show_progress = config
@@ -150,7 +148,7 @@ fn env_bool(key: &str) -> Option<bool> {
env::var(key).ok().map(|value| is_truthy(&value))
}
/// Active lease on a single TextRerank instance.
/// Active lease on a single `TextRerank` instance.
pub struct RerankerLease {
// When this drops the semaphore permit is released.
_permit: OwnedSemaphorePermit,
+14 -5
View File
@@ -28,16 +28,19 @@ impl<T> Scored<T> {
}
}
/// Stores `score` as the vector score and returns the updated value.
#[must_use]
pub const fn with_vector_score(mut self, score: f32) -> Self {
self.scores.vector = Some(score);
self
}
/// Stores `score` as the FTS score and returns the updated value.
#[must_use]
pub const fn with_fts_score(mut self, score: f32) -> Self {
self.scores.fts = Some(score);
self
}
#[must_use]
pub const fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score);
self
@@ -168,7 +171,7 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
if scores.vector.is_some() && scores.fts.is_some() {
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
fused = fused * (1.0 + weights.multi_bonus);
fused *= 1.0 + weights.multi_bonus;
} else {
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
fused += weights.multi_bonus;
@@ -178,8 +181,8 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
clamp_unit(fused)
}
pub fn merge_scored_by_id<T>(
target: &mut std::collections::HashMap<String, Scored<T>>,
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
target: &mut std::collections::HashMap<String, Scored<T>, S>,
incoming: Vec<Scored<T>>,
) where
T: StoredObject + Clone,
@@ -263,7 +266,10 @@ where
}
}
entry.item = candidate.item;
entry.fused += vector_weight / (k + rank as f32 + 1.0);
let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
entry.fused += vector_weight / (k + rank_f32 + 1.0);
}
}
@@ -290,7 +296,10 @@ where
}
}
entry.item = candidate.item;
entry.fused += fts_weight / (k + rank as f32 + 1.0);
let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
entry.fused += fts_weight / (k + rank_f32 + 1.0);
}
}