mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-28 10:29:30 +02:00
clippy: adhere to pedantic clippy, uniform test error handling
This commit is contained in:
@@ -118,18 +118,16 @@ pub fn create_chat_request(
|
||||
}
|
||||
|
||||
pub fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response: &CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, Box<AppError>> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))
|
||||
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
|
||||
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ use common::storage::{
|
||||
/// * `entity_id` - ID of the entity to find neighbors for
|
||||
/// * `user_id` - User ID for access control
|
||||
/// * `limit` - Maximum number of neighbors to return
|
||||
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: &str,
|
||||
@@ -113,25 +112,23 @@ pub async fn find_entities_by_relationship_by_id(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_relationship_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create some test entities
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create the central entity we'll query relationships for
|
||||
let central_entity = KnowledgeEntity::new(
|
||||
"central_source".to_string(),
|
||||
"Central Entity".to_string(),
|
||||
@@ -141,7 +138,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create related entities
|
||||
let related_entity1 = KnowledgeEntity::new(
|
||||
"related_source1".to_string(),
|
||||
"Related Entity 1".to_string(),
|
||||
@@ -160,7 +156,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an unrelated entity
|
||||
let unrelated_entity = KnowledgeEntity::new(
|
||||
"unrelated_source".to_string(),
|
||||
"Unrelated Entity".to_string(),
|
||||
@@ -170,32 +165,29 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store all entities
|
||||
let central_entity = db
|
||||
.store_item(central_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store central entity")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store central entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
|
||||
let related_entity1 = db
|
||||
.store_item(related_entity1.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 1")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store related entity 1".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
|
||||
let related_entity2 = db
|
||||
.store_item(related_entity2.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 2")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store related entity 2".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
|
||||
let _unrelated_entity = db
|
||||
.store_item(unrelated_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store unrelated entity")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store unrelated entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
|
||||
|
||||
// Create relationships
|
||||
let source_id = "relationship_source".to_string();
|
||||
|
||||
// Create relationship 1: central -> related1
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity1.id.clone(),
|
||||
@@ -204,7 +196,6 @@ mod tests {
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
// Create relationship 2: central -> related2
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity2.id.clone(),
|
||||
@@ -213,26 +204,25 @@ mod tests {
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Store relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||
|
||||
// Test finding entities related to the central entity
|
||||
let related_entities =
|
||||
find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX)
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
.with_context(|| "Failed to find entities by relationship".to_string())?;
|
||||
|
||||
// Check that we found relationships
|
||||
assert!(
|
||||
related_entities.len() >= 2,
|
||||
"Should find related entities in both directions"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+80
-100
@@ -42,10 +42,14 @@ impl SearchResult {
|
||||
}
|
||||
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, SearchTarget,
|
||||
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
|
||||
// Backward-compatible type aliases for external consumers
|
||||
pub type PipelineDiagnostics = Diagnostics;
|
||||
pub type PipelineStageTimings = StageTimings;
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedChunk {
|
||||
@@ -61,7 +65,7 @@ pub struct RetrievedEntity {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
||||
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
@@ -72,7 +76,7 @@ pub async fn retrieve_entities(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
pipeline::run_pipeline(
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
@@ -80,17 +84,16 @@ pub async fn retrieve_entities(
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
};
|
||||
pipeline::execute(params).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self};
|
||||
use async_openai::Client;
|
||||
use common::storage::indexes::ensure_runtime_indexes;
|
||||
use common::storage::types::text_chunk::TextChunk;
|
||||
use pipeline::{RetrievalConfig, RetrievalStrategy};
|
||||
use common::storage::indexes::ensure_runtime;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
@@ -105,27 +108,21 @@ mod tests {
|
||||
vec![0.2, 0.8, 0.0]
|
||||
}
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
db.apply_migrations().await?;
|
||||
|
||||
ensure_runtime_indexes(&db, 3)
|
||||
.await
|
||||
.expect("failed to build runtime indexes");
|
||||
ensure_runtime(&db, 3).await?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_retrieves_chunks() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "test_user";
|
||||
let chunk = TextChunk::new(
|
||||
"source_1".into(),
|
||||
@@ -133,39 +130,38 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Default strategy retrieval failed");
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk results, got {:?}", other),
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||
assert!(
|
||||
chunks[0].chunk.chunk.contains("Tokio"),
|
||||
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Expected chunk about Tokio"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_returns_chunks_from_multiple_sources() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_default_strategy_returns_chunks_from_multiple_sources(
|
||||
) -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "multi_source_user";
|
||||
|
||||
let primary_chunk = TextChunk::new(
|
||||
@@ -179,30 +175,25 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store primary chunk");
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db)
|
||||
.await
|
||||
.expect("Failed to store secondary chunk");
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Default strategy retrieval failed");
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk results, got {:?}", other),
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
|
||||
@@ -216,11 +207,12 @@ mod tests {
|
||||
.any(|c| c.chunk.source_id == "secondary_source"),
|
||||
"Should include secondary source chunk"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_revised_strategy_returns_chunks() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "chunk_user";
|
||||
let chunk_one = TextChunk::new(
|
||||
"src_alpha".into(),
|
||||
@@ -233,31 +225,26 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk one");
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk two");
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"tokio runtime worker behavior",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "tokio runtime worker behavior",
|
||||
user_id,
|
||||
config,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Revised retrieval failed");
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk output, got {:?}", other),
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
@@ -270,11 +257,12 @@ mod tests {
|
||||
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
||||
"Chunk results should contain relevant snippets"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_strategy_returns_search_result() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "search_user";
|
||||
let chunk = TextChunk::new(
|
||||
"search_src".into(),
|
||||
@@ -282,33 +270,24 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"async rust programming",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "async rust programming",
|
||||
user_id,
|
||||
config,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Search strategy retrieval failed");
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
assert!(
|
||||
matches!(results, StrategyOutput::Search(_)),
|
||||
"expected Search output, got {:?}",
|
||||
results
|
||||
);
|
||||
let search_result = match results {
|
||||
StrategyOutput::Search(sr) => sr,
|
||||
_ => unreachable!(),
|
||||
let StrategyOutput::Search(search_result) = results else {
|
||||
anyhow::bail!("expected Search output");
|
||||
};
|
||||
|
||||
// Should return chunks (entities may be empty if none stored)
|
||||
@@ -323,5 +302,6 @@ mod tests {
|
||||
.any(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Search results should contain relevant chunks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
use crate::scoring::FusionWeights;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
/// Primary hybrid chunk retrieval for search/chat (formerly Revised)
|
||||
#[default]
|
||||
Default,
|
||||
/// Entity retrieval for suggesting relationships when creating manual entities
|
||||
RelationshipSuggestion,
|
||||
@@ -29,12 +30,6 @@ pub enum SearchTarget {
|
||||
Both,
|
||||
}
|
||||
|
||||
impl Default for RetrievalStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Default
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RetrievalStrategy {
|
||||
type Err = String;
|
||||
|
||||
@@ -70,6 +65,91 @@ impl fmt::Display for RetrievalStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-variant flag that serializes as a bool for backward compatibility.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum BoolFlag {
|
||||
#[default]
|
||||
Disabled,
|
||||
Enabled,
|
||||
}
|
||||
|
||||
impl BoolFlag {
|
||||
pub const fn as_bool(self) -> bool {
|
||||
matches!(self, BoolFlag::Enabled)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bool> for BoolFlag {
|
||||
fn from(value: bool) -> Self {
|
||||
if value {
|
||||
BoolFlag::Enabled
|
||||
} else {
|
||||
BoolFlag::Disabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for BoolFlag {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serializer.serialize_bool(self.as_bool())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for BoolFlag {
|
||||
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||
bool::deserialize(deserializer).map(|b| {
|
||||
if b {
|
||||
BoolFlag::Enabled
|
||||
} else {
|
||||
BoolFlag::Disabled
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuningFlags {
|
||||
pub rerank_scores_only: BoolFlag,
|
||||
pub normalize_vector_scores: BoolFlag,
|
||||
pub normalize_fts_scores: BoolFlag,
|
||||
pub chunk_rrf_use_vector: BoolFlag,
|
||||
pub chunk_rrf_use_fts: BoolFlag,
|
||||
}
|
||||
|
||||
impl RetrievalTuningFlags {
|
||||
pub const fn rerank_scores_only(&self) -> bool {
|
||||
self.rerank_scores_only.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_vector_scores(&self) -> bool {
|
||||
self.normalize_vector_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_fts_scores(&self) -> bool {
|
||||
self.normalize_fts_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_vector(&self) -> bool {
|
||||
self.chunk_rrf_use_vector.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_fts(&self) -> bool {
|
||||
self.chunk_rrf_use_fts.as_bool()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuningFlags {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
rerank_scores_only: BoolFlag::Disabled,
|
||||
normalize_vector_scores: BoolFlag::Disabled,
|
||||
normalize_fts_scores: BoolFlag::Enabled,
|
||||
chunk_rrf_use_vector: BoolFlag::Enabled,
|
||||
chunk_rrf_use_fts: BoolFlag::Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
@@ -89,15 +169,11 @@ pub struct RetrievalTuning {
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
pub rerank_blend_weight: f32,
|
||||
pub rerank_scores_only: bool,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
pub rerank_keep_top: usize,
|
||||
pub chunk_result_cap: usize,
|
||||
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
||||
pub fusion_weights: Option<FusionWeights>,
|
||||
/// Normalize vector similarity scores before fusion (default: true)
|
||||
pub normalize_vector_scores: bool,
|
||||
/// Normalize FTS (BM25) scores before fusion (default: true)
|
||||
pub normalize_fts_scores: bool,
|
||||
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
||||
#[serde(default = "default_chunk_rrf_k")]
|
||||
pub chunk_rrf_k: f32,
|
||||
@@ -107,12 +183,6 @@ pub struct RetrievalTuning {
|
||||
/// Weight applied to chunk FTS ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_fts_weight")]
|
||||
pub chunk_rrf_fts_weight: f32,
|
||||
/// Whether to include vector rankings in RRF.
|
||||
#[serde(default = "default_chunk_rrf_use_vector")]
|
||||
pub chunk_rrf_use_vector: bool,
|
||||
/// Whether to include chunk FTS rankings in RRF.
|
||||
#[serde(default = "default_chunk_rrf_use_fts")]
|
||||
pub chunk_rrf_use_fts: bool,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
@@ -134,26 +204,19 @@ impl Default for RetrievalTuning {
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
rerank_scores_only: false,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
rerank_keep_top: 8,
|
||||
chunk_result_cap: 5,
|
||||
fusion_weights: None,
|
||||
// Vector scores (cosine similarity) are already in [0,1] range
|
||||
// Normalization only helps when there's significant variation
|
||||
normalize_vector_scores: false,
|
||||
// FTS scores (BM25) are unbounded, normalization helps more
|
||||
normalize_fts_scores: true,
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
||||
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
|
||||
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
@@ -211,16 +274,6 @@ impl RetrievalConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::default(),
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
60.0
|
||||
}
|
||||
@@ -233,10 +286,4 @@ const fn default_chunk_rrf_fts_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_use_vector() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_use_fts() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct PipelineDiagnostics {
|
||||
pub struct Diagnostics {
|
||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
pub assemble: Option<AssembleStats>,
|
||||
|
||||
@@ -3,10 +3,11 @@ mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning, SearchTarget};
|
||||
pub use config::{
|
||||
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||
@@ -37,13 +38,13 @@ pub enum StageKind {
|
||||
|
||||
// Pipeline stage trait
|
||||
#[async_trait]
|
||||
pub trait PipelineStage: Send + Sync {
|
||||
pub trait Stage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
// Type alias for boxed stages
|
||||
pub type BoxedStage = Box<dyn PipelineStage>;
|
||||
pub type BoxedStage = Box<dyn Stage>;
|
||||
|
||||
// Strategy driver trait
|
||||
#[async_trait]
|
||||
@@ -51,16 +52,16 @@ pub trait StrategyDriver: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
|
||||
}
|
||||
|
||||
// Pipeline stage timings tracker
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub struct StageTimings {
|
||||
timings: Vec<(StageKind, Duration)>,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
impl StageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
self.timings.push((kind, duration));
|
||||
}
|
||||
@@ -74,8 +75,7 @@ impl PipelineStageTimings {
|
||||
self.timings
|
||||
.iter()
|
||||
.find(|(k, _)| *k == kind)
|
||||
.map(|(_, d)| d.as_millis())
|
||||
.unwrap_or(0)
|
||||
.map_or(0, |(_, d)| d.as_millis())
|
||||
}
|
||||
|
||||
pub fn embed_ms(&self) -> u128 {
|
||||
@@ -103,228 +103,100 @@ impl PipelineStageTimings {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub struct RunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = params.input_text.chars().count();
|
||||
let input_preview: String = params.input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
user_id = %params.user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
strategy = %config.strategy,
|
||||
strategy = %params.config.strategy,
|
||||
"Starting retrieval pipeline"
|
||||
);
|
||||
|
||||
match config.strategy {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let search_target = config.search_target;
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
match config.strategy {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let search_target = config.search_target;
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The metrics/diagnostics variants would follow the same pattern,
|
||||
// but for brevity I'm only updating the main ones used by callers.
|
||||
// If metrics/diagnostics are needed for non-chat strategies, they should be updated too.
|
||||
// For now, I'll update them to support at least Initial/Revised as before.
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
match config.strategy {
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
|
||||
_ => Err(AppError::InternalError(
|
||||
"Metrics not supported for this strategy".into(),
|
||||
)),
|
||||
@@ -332,32 +204,16 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
match config.strategy {
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
@@ -391,38 +247,25 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub struct StrategyParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
None => PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
Some(embedding) => PipelineContext::with_embedding(params, embedding),
|
||||
None => PipelineContext::new(params),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
@@ -432,7 +275,7 @@ async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
@@ -445,9 +288,9 @@ async fn run_with_driver<D: StrategyDriver>(
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx)?;
|
||||
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
|
||||
@@ -27,9 +27,9 @@ use super::{
|
||||
config::{RetrievalConfig, RetrievalTuning},
|
||||
diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
Diagnostics,
|
||||
},
|
||||
PipelineStage, PipelineStageTimings, StageKind,
|
||||
StageTimings, Stage, StageKind, StrategyParams,
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
@@ -45,76 +45,51 @@ pub struct PipelineContext<'a> {
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub revised_chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
pub chunk_results: Vec<RetrievedChunk>,
|
||||
stage_timings: PipelineStageTimings,
|
||||
stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
pub fn new(params: StrategyParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
config: params.config,
|
||||
query_embedding: None,
|
||||
entity_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
revised_chunk_values: Vec::new(),
|
||||
reranker,
|
||||
reranker: params.reranker,
|
||||
diagnostics: None,
|
||||
entity_results: Vec::new(),
|
||||
chunk_results: Vec::new(),
|
||||
stage_timings: PipelineStageTimings::default(),
|
||||
stage_timings: StageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
let mut ctx = Self::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec<f32>) -> Self {
|
||||
let mut ctx = Self::new(params);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
AppError::InternalError(
|
||||
Box::new(AppError::InternalError(
|
||||
"query embedding missing before candidate collection".to_string(),
|
||||
)
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(PipelineDiagnostics::default());
|
||||
self.diagnostics = Some(Diagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,11 +115,11 @@ impl<'a> PipelineContext<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
|
||||
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||
pub fn take_stage_timings(&mut self) -> StageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
|
||||
@@ -165,7 +140,7 @@ impl<'a> PipelineContext<'a> {
|
||||
pub struct EmbedStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for EmbedStage {
|
||||
impl Stage for EmbedStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Embed
|
||||
}
|
||||
@@ -179,7 +154,7 @@ impl PipelineStage for EmbedStage {
|
||||
pub struct CollectCandidatesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for CollectCandidatesStage {
|
||||
impl Stage for CollectCandidatesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
@@ -193,7 +168,7 @@ impl PipelineStage for CollectCandidatesStage {
|
||||
pub struct GraphExpansionStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for GraphExpansionStage {
|
||||
impl Stage for GraphExpansionStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::GraphExpansion
|
||||
}
|
||||
@@ -207,7 +182,7 @@ impl PipelineStage for GraphExpansionStage {
|
||||
pub struct RerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for RerankStage {
|
||||
impl Stage for RerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
@@ -221,7 +196,7 @@ impl PipelineStage for RerankStage {
|
||||
pub struct AssembleEntitiesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for AssembleEntitiesStage {
|
||||
impl Stage for AssembleEntitiesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
@@ -235,7 +210,7 @@ impl PipelineStage for AssembleEntitiesStage {
|
||||
pub struct ChunkVectorStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkVectorStage {
|
||||
impl Stage for ChunkVectorStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
@@ -249,7 +224,7 @@ impl PipelineStage for ChunkVectorStage {
|
||||
pub struct ChunkRerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkRerankStage {
|
||||
impl Stage for ChunkRerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
@@ -263,7 +238,7 @@ impl PipelineStage for ChunkRerankStage {
|
||||
pub struct ChunkAssembleStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkAssembleStage {
|
||||
impl Stage for ChunkAssembleStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
@@ -283,8 +258,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await.map_err(|e| {
|
||||
AppError::InternalError(format!(
|
||||
"Failed to generate embedding with provider: {}",
|
||||
e
|
||||
"Failed to generate embedding with provider: {e}",
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
@@ -299,7 +273,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let weights = FusionWeights::default();
|
||||
@@ -487,11 +461,11 @@ pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting vector chunk candidates for revised strategy");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let fts_take = tuning.chunk_fts_take;
|
||||
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty();
|
||||
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
|
||||
|
||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||
TextChunk::vector_search(
|
||||
@@ -532,8 +506,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
||||
k: tuning.chunk_rrf_k,
|
||||
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||
fts_weight,
|
||||
use_vector: tuning.chunk_rrf_use_vector,
|
||||
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0,
|
||||
use_vector: tuning.flags.chunk_rrf_use_vector(),
|
||||
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
|
||||
};
|
||||
|
||||
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||
@@ -715,7 +689,7 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let mut per_entity_count = 0;
|
||||
for candidate in candidates.iter() {
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.inspected_candidates += 1;
|
||||
trace.inspected_candidates = trace.inspected_candidates.saturating_add(1);
|
||||
}
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
@@ -723,17 +697,17 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
chunks_skipped_due_budget += 1;
|
||||
chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1);
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.skipped_due_budget += 1;
|
||||
trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
tokens_spent += estimated_tokens;
|
||||
per_entity_count += 1;
|
||||
chunks_selected += 1;
|
||||
tokens_spent = tokens_spent.saturating_add(estimated_tokens);
|
||||
per_entity_count = per_entity_count.saturating_add(1);
|
||||
chunks_selected = chunks_selected.saturating_add(1);
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
@@ -780,14 +754,14 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items
|
||||
.iter()
|
||||
.take(SCORE_SAMPLE_LIMIT)
|
||||
.map(|item| extractor(item))
|
||||
.map(extractor)
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -912,7 +886,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = ctx.config.tuning.rerank_scores_only;
|
||||
let use_only = ctx.config.tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
@@ -942,11 +916,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
ctx.filtered_entities = reranked;
|
||||
let keep_top = ctx.config.tuning.rerank_keep_top;
|
||||
@@ -970,7 +940,7 @@ fn apply_chunk_rerank_results(
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = tuning.rerank_scores_only;
|
||||
let use_only = tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
@@ -1001,11 +971,7 @@ fn apply_chunk_rerank_results(
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
let keep_top = tuning.rerank_keep_top;
|
||||
if keep_top > 0 && reranked.len() > keep_top {
|
||||
@@ -1017,7 +983,7 @@ fn apply_chunk_rerank_results(
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
chars.checked_div(avg_chars_per_token).map_or(1, |v| v.max(1))
|
||||
}
|
||||
|
||||
fn rank_chunks_by_combined_score(
|
||||
@@ -1053,13 +1019,20 @@ fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
||||
return 0.0;
|
||||
}
|
||||
let lower = haystack.to_ascii_lowercase();
|
||||
let mut matches = 0usize;
|
||||
let mut matches: u32 = 0;
|
||||
for term in terms {
|
||||
if lower.contains(term) {
|
||||
matches += 1;
|
||||
matches = matches.saturating_add(1);
|
||||
}
|
||||
}
|
||||
(matches as f32) / (terms.len() as f32)
|
||||
let total = u32::try_from(terms.len()).unwrap_or(u32::MAX);
|
||||
if total == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let num = matches.min(total);
|
||||
let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX);
|
||||
let den_f32 = u16::try_from(total).map(f32::from).unwrap_or(f32::MAX);
|
||||
num_f32 / den_f32
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -28,7 +28,7 @@ impl StrategyDriver for DefaultStrategyDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
}
|
||||
@@ -55,7 +55,7 @@ impl StrategyDriver for RelationshipSuggestionDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
@@ -82,7 +82,7 @@ impl StrategyDriver for IngestionDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
@@ -134,7 +134,7 @@ impl StrategyDriver for SearchStrategyDriver {
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
let chunks = match self.target {
|
||||
SearchTarget::EntitiesOnly => Vec::new(),
|
||||
_ => ctx.take_chunk_results(),
|
||||
|
||||
@@ -17,7 +17,7 @@ static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
fn pick_engine_index(pool_len: usize) -> usize {
|
||||
let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
|
||||
n % pool_len
|
||||
n.checked_rem(pool_len).unwrap_or(0)
|
||||
}
|
||||
|
||||
pub struct RerankerPool {
|
||||
@@ -28,30 +28,30 @@ pub struct RerankerPool {
|
||||
impl RerankerPool {
|
||||
/// Build the pool at startup.
|
||||
/// `pool_size` controls max parallel reranks.
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
|
||||
Self::new_with_options(
|
||||
pool_size,
|
||||
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn),
|
||||
)
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
|
||||
let init_options =
|
||||
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
|
||||
Self::new_with_options(pool_size, &init_options)
|
||||
}
|
||||
|
||||
fn new_with_options(
|
||||
pool_size: usize,
|
||||
init_options: RerankInitOptions,
|
||||
) -> Result<Arc<Self>, AppError> {
|
||||
init_options: &RerankInitOptions,
|
||||
) -> Result<Arc<Self>, Box<AppError>> {
|
||||
if pool_size == 0 {
|
||||
return Err(AppError::Validation(
|
||||
return Err(Box::new(AppError::Validation(
|
||||
"RERANKING_POOL_SIZE must be greater than zero".to_string(),
|
||||
));
|
||||
)));
|
||||
}
|
||||
|
||||
fs::create_dir_all(&init_options.cache_dir)?;
|
||||
fs::create_dir_all(&init_options.cache_dir)
|
||||
.map_err(|e| Box::new(AppError::from(e)))?;
|
||||
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for x in 0..pool_size {
|
||||
debug!("Creating reranking engine: {x}");
|
||||
let model = TextRerank::try_new(init_options.clone())
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(|e| Box::new(AppError::InternalError(e.to_string())))?;
|
||||
engines.push(Arc::new(Mutex::new(model)));
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ impl RerankerPool {
|
||||
}
|
||||
|
||||
/// Initialize a pool using application configuration.
|
||||
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
|
||||
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, Box<AppError>> {
|
||||
if !config.reranking_enabled {
|
||||
return Ok(None);
|
||||
}
|
||||
@@ -70,30 +70,28 @@ impl RerankerPool {
|
||||
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
|
||||
|
||||
let init_options = build_rerank_init_options(config)?;
|
||||
Self::new_with_options(pool_size, init_options).map(Some)
|
||||
Self::new_with_options(pool_size, &init_options).map(Some)
|
||||
}
|
||||
|
||||
/// Check out capacity + pick an engine.
|
||||
/// This returns a lease that can perform rerank().
|
||||
pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
|
||||
/// This returns a lease that can perform `rerank()`.
|
||||
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
|
||||
// Acquire a permit. This enforces backpressure.
|
||||
let permit = self
|
||||
.semaphore
|
||||
.clone()
|
||||
let permit = Arc::clone(&self.semaphore)
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("semaphore closed");
|
||||
.ok()?;
|
||||
|
||||
// Pick an engine.
|
||||
// This is naive: just pick based on a simple modulo counter.
|
||||
// We use an atomic counter to avoid always choosing index 0.
|
||||
let idx = pick_engine_index(self.engines.len());
|
||||
let engine = self.engines[idx].clone();
|
||||
let engine = self.engines.get(idx).map(Arc::clone)?;
|
||||
|
||||
RerankerLease {
|
||||
Some(RerankerLease {
|
||||
_permit: permit,
|
||||
engine,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +109,7 @@ fn is_truthy(value: &str) -> bool {
|
||||
)
|
||||
}
|
||||
|
||||
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
|
||||
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Box<AppError>> {
|
||||
let mut options = RerankInitOptions::default();
|
||||
|
||||
let cache_dir = config
|
||||
@@ -125,7 +123,7 @@ fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Ap
|
||||
.join("fastembed")
|
||||
.join("reranker")
|
||||
});
|
||||
fs::create_dir_all(&cache_dir)?;
|
||||
fs::create_dir_all(&cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
|
||||
options.cache_dir = cache_dir;
|
||||
|
||||
let show_progress = config
|
||||
@@ -150,7 +148,7 @@ fn env_bool(key: &str) -> Option<bool> {
|
||||
env::var(key).ok().map(|value| is_truthy(&value))
|
||||
}
|
||||
|
||||
/// Active lease on a single TextRerank instance.
|
||||
/// Active lease on a single `TextRerank` instance.
|
||||
pub struct RerankerLease {
|
||||
// When this drops the semaphore permit is released.
|
||||
_permit: OwnedSemaphorePermit,
|
||||
|
||||
@@ -28,16 +28,19 @@ impl<T> Scored<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_vector_score(mut self, score: f32) -> Self {
|
||||
self.scores.vector = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_fts_score(mut self, score: f32) -> Self {
|
||||
self.scores.fts = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_graph_score(mut self, score: f32) -> Self {
|
||||
self.scores.graph = Some(score);
|
||||
self
|
||||
@@ -168,7 +171,7 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
if scores.vector.is_some() && scores.fts.is_some() {
|
||||
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
|
||||
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
|
||||
fused = fused * (1.0 + weights.multi_bonus);
|
||||
fused *= 1.0 + weights.multi_bonus;
|
||||
} else {
|
||||
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
|
||||
fused += weights.multi_bonus;
|
||||
@@ -178,8 +181,8 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>>,
|
||||
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>, S>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
@@ -263,7 +266,10 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
entry.fused += vector_weight / (k + rank as f32 + 1.0);
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,7 +296,10 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
entry.fused += fts_weight / (k + rank as f32 + 1.0);
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user