mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-31 03:40:38 +02:00
chore: refactor retrieval pipeline to chunk-first RRF with derived entities and slimmer eval surface.
Collapse the multi-strategy entity engine into one benchmarked chunk retrieval path, derive entities from retrieved chunks, and update consumers, docs, and clippy fixes across the workspace.
This commit is contained in:
@@ -10,16 +10,14 @@ workspace = true
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
@@ -1,61 +1,66 @@
|
||||
//! Chat answer assembly: retrieval context formatting and structured LLM request/response types.
|
||||
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||
ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
use common::storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
/// JSON schema describing the structured chat answer (answer text + references).
|
||||
fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Reference {
|
||||
#[allow(dead_code)]
|
||||
pub reference: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LLMResponseFormat {
|
||||
pub answer: String,
|
||||
#[allow(dead_code)]
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r"
|
||||
Context Information:
|
||||
==================
|
||||
{entities_json}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{query}
|
||||
"
|
||||
)
|
||||
}
|
||||
|
||||
/// Convert chunk-based retrieval results to JSON format for LLM context
|
||||
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
impl LLMResponseFormat {
|
||||
pub fn reference_ids(&self) -> Vec<String> {
|
||||
self.references
|
||||
.iter()
|
||||
.map(|entry| entry.reference.clone())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert chunk-based retrieval results to JSON format for LLM context.
|
||||
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
use crate::round_score;
|
||||
|
||||
serde_json::json!(chunks
|
||||
.iter()
|
||||
@@ -70,7 +75,7 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
}
|
||||
|
||||
pub fn create_user_message_with_history(
|
||||
entities_json: &Value,
|
||||
context_json: &Value,
|
||||
history: &[Message],
|
||||
query: &str,
|
||||
) -> String {
|
||||
@@ -89,7 +94,7 @@ pub fn create_user_message_with_history(
|
||||
{}
|
||||
",
|
||||
format_history(history),
|
||||
entities_json,
|
||||
context_json,
|
||||
query
|
||||
)
|
||||
}
|
||||
@@ -116,18 +121,3 @@ pub fn create_chat_request(
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub fn process_llm_response(
|
||||
response: &CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, Box<AppError>> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use surrealdb::{sql::Thing, Error};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
StoredObject,
|
||||
},
|
||||
};
|
||||
|
||||
/// Find entities related to the given entity via graph relationships.
|
||||
///
|
||||
/// Queries the `relates_to` edge table for all relationships involving the entity,
|
||||
/// then fetches and returns the neighboring entities.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db` - Database client
|
||||
/// * `entity_id` - ID of the entity to find neighbors for
|
||||
/// * `user_id` - User ID for access control
|
||||
/// * `limit` - Maximum number of neighbors to return
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let mut relationships_response = db
|
||||
.query(
|
||||
"
|
||||
SELECT * FROM relates_to
|
||||
WHERE metadata.user_id = $user_id
|
||||
AND (in = type::thing('knowledge_entity', $entity_id)
|
||||
OR out = type::thing('knowledge_entity', $entity_id))
|
||||
",
|
||||
)
|
||||
.bind(("entity_id", entity_id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
|
||||
if relationships.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_ids: Vec<String> = Vec::with_capacity(relationships.len());
|
||||
let mut seen: HashSet<String> = HashSet::with_capacity(relationships.len());
|
||||
for rel in relationships {
|
||||
if rel.in_ == entity_id {
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
} else if rel.out == entity_id {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_);
|
||||
}
|
||||
} else {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_.clone());
|
||||
}
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_ids.retain(|id| id != entity_id);
|
||||
|
||||
if neighbor_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
if limit > 0 && neighbor_ids.len() > limit {
|
||||
neighbor_ids.truncate(limit);
|
||||
}
|
||||
|
||||
let thing_ids: Vec<Thing> = neighbor_ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut neighbors_response = db
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", KnowledgeEntity::table_name().to_owned()))
|
||||
.bind(("things", thing_ids))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
|
||||
if neighbors.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
|
||||
.into_iter()
|
||||
.map(|entity| (entity.id.clone(), entity))
|
||||
.collect();
|
||||
|
||||
let mut ordered = Vec::with_capacity(neighbor_ids.len());
|
||||
for id in neighbor_ids {
|
||||
if let Some(entity) = neighbor_map.remove(&id) {
|
||||
ordered.push(entity);
|
||||
}
|
||||
if limit > 0 && ordered.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ordered)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let central_entity = KnowledgeEntity::new(
|
||||
"central_source".to_string(),
|
||||
"Central Entity".to_string(),
|
||||
"Central Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let related_entity1 = KnowledgeEntity::new(
|
||||
"related_source1".to_string(),
|
||||
"Related Entity 1".to_string(),
|
||||
"Related Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let related_entity2 = KnowledgeEntity::new(
|
||||
"related_source2".to_string(),
|
||||
"Related Entity 2".to_string(),
|
||||
"Related Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let unrelated_entity = KnowledgeEntity::new(
|
||||
"unrelated_source".to_string(),
|
||||
"Unrelated Entity".to_string(),
|
||||
"Unrelated Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let central_entity = db
|
||||
.store_item(central_entity.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store central entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
|
||||
let related_entity1 = db
|
||||
.store_item(related_entity1.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store related entity 1".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
|
||||
let related_entity2 = db
|
||||
.store_item(related_entity2.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store related entity 2".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
|
||||
let _unrelated_entity = db
|
||||
.store_item(unrelated_entity.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store unrelated entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
|
||||
|
||||
let source_id = "relationship_source".to_string();
|
||||
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity1.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity2.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||
|
||||
let related_entities =
|
||||
find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX)
|
||||
.await
|
||||
.with_context(|| "Failed to find entities by relationship".to_string())?;
|
||||
|
||||
assert!(
|
||||
related_entities.len() >= 2,
|
||||
"Should find related entities in both directions"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
+63
-115
@@ -1,10 +1,9 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
|
||||
pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod reranking;
|
||||
pub mod scoring;
|
||||
|
||||
pub(crate) mod scoring;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -16,39 +15,28 @@ use common::{
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
// Strategy output variants - defined before pipeline module
|
||||
/// Result of a retrieval run.
|
||||
///
|
||||
/// Chunk retrieval is always performed; entities are only present when the caller
|
||||
/// requested entity resolution via [`RetrievalConfig::with_entities`].
|
||||
#[derive(Debug)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
pub enum RetrievalOutput {
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
Search(SearchResult),
|
||||
}
|
||||
|
||||
/// Unified search result containing both chunks and entities
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchResult {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
pub entities: Vec<RetrievedEntity>,
|
||||
}
|
||||
|
||||
impl SearchResult {
|
||||
pub fn new(chunks: Vec<RetrievedChunk>, entities: Vec<RetrievedEntity>) -> Self {
|
||||
Self { chunks, entities }
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.chunks.is_empty() && self.entities.is_empty()
|
||||
}
|
||||
WithEntities {
|
||||
chunks: Vec<RetrievedChunk>,
|
||||
entities: Vec<RetrievedEntity>,
|
||||
},
|
||||
}
|
||||
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind,
|
||||
StageTimings,
|
||||
};
|
||||
|
||||
// Backward-compatible type aliases for external consumers
|
||||
pub type PipelineDiagnostics = Diagnostics;
|
||||
pub type PipelineStageTimings = StageTimings;
|
||||
/// Round a score to three decimal places for JSON output.
|
||||
pub(crate) fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -57,7 +45,7 @@ pub struct RetrievedChunk {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
// Final entity representation returned to callers, enriched with ranked chunks.
|
||||
// Knowledge entity resolved from retrieved chunks, enriched with its contributing chunks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedEntity {
|
||||
pub entity: KnowledgeEntity,
|
||||
@@ -65,9 +53,9 @@ pub struct RetrievedEntity {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
|
||||
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
pub async fn retrieve(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
@@ -75,8 +63,8 @@ pub async fn retrieve_entities(
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let params = pipeline::StrategyParams {
|
||||
) -> Result<RetrievalOutput, AppError> {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
@@ -94,6 +82,7 @@ mod tests {
|
||||
use anyhow::{self};
|
||||
use async_openai::Client;
|
||||
use common::storage::indexes::ensure_runtime;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::system_settings::SystemSettings;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -133,7 +122,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
|
||||
async fn test_chunk_retrieval_returns_chunks() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "test_user";
|
||||
let chunk = TextChunk::new(
|
||||
@@ -145,7 +134,7 @@ mod tests {
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
@@ -154,12 +143,13 @@ mod tests {
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
RetrievalOutput::Chunks(items) => items,
|
||||
RetrievalOutput::WithEntities { .. } => {
|
||||
anyhow::bail!("expected chunk results, got entities")
|
||||
}
|
||||
};
|
||||
|
||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||
@@ -171,8 +161,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_returns_chunks_from_multiple_sources(
|
||||
) -> anyhow::Result<()> {
|
||||
async fn test_chunk_retrieval_returns_chunks_from_multiple_sources() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "multi_source_user";
|
||||
|
||||
@@ -191,7 +180,7 @@ mod tests {
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
@@ -200,12 +189,13 @@ mod tests {
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
RetrievalOutput::Chunks(items) => items,
|
||||
RetrievalOutput::WithEntities { .. } => {
|
||||
anyhow::bail!("expected chunk results, got entities")
|
||||
}
|
||||
};
|
||||
|
||||
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
|
||||
@@ -223,96 +213,54 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
|
||||
async fn test_with_entities_resolves_owning_entities() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "chunk_user";
|
||||
let chunk_one = TextChunk::new(
|
||||
"src_alpha".into(),
|
||||
"Tokio tasks execute on worker threads managed by the runtime.".into(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk_two = TextChunk::new(
|
||||
"src_beta".into(),
|
||||
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
|
||||
user_id.into(),
|
||||
);
|
||||
let user_id = "entity_user";
|
||||
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "tokio runtime worker behavior",
|
||||
user_id,
|
||||
config,
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
!chunks.is_empty(),
|
||||
"Revised strategy should return chunk-only responses"
|
||||
);
|
||||
assert!(
|
||||
chunks
|
||||
.iter()
|
||||
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
||||
"Chunk results should contain relevant snippets"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "search_user";
|
||||
let chunk = TextChunk::new(
|
||||
"search_src".into(),
|
||||
"Async Rust programming uses Tokio runtime for concurrent tasks.".into(),
|
||||
"entity_source".into(),
|
||||
"Async Rust programming uses the Tokio runtime for concurrent tasks.".into(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
|
||||
let entity = KnowledgeEntity::new(
|
||||
"entity_source".into(),
|
||||
"Tokio Runtime".into(),
|
||||
"Async runtime for Rust".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.into(),
|
||||
);
|
||||
db.store_item(entity).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let params = pipeline::StrategyParams {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "async rust programming",
|
||||
user_id,
|
||||
config,
|
||||
config: RetrievalConfig::with_entities(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
let results = pipeline::run_with_embedding(params, test_embedding()).await?;
|
||||
|
||||
let StrategyOutput::Search(search_result) = results else {
|
||||
anyhow::bail!("expected Search output");
|
||||
let RetrievalOutput::WithEntities { chunks, entities } = results else {
|
||||
anyhow::bail!("expected WithEntities output");
|
||||
};
|
||||
|
||||
// Should return chunks (entities may be empty if none stored)
|
||||
assert!(!chunks.is_empty(), "Should return chunks");
|
||||
assert!(
|
||||
!search_result.chunks.is_empty(),
|
||||
"Search strategy should return chunks"
|
||||
entities.iter().any(|e| e.entity.name == "Tokio Runtime"),
|
||||
"Should resolve the entity owning the retrieved chunk"
|
||||
);
|
||||
assert!(
|
||||
search_result
|
||||
.chunks
|
||||
entities
|
||||
.iter()
|
||||
.any(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Search results should contain relevant chunks"
|
||||
.find(|e| e.entity.name == "Tokio Runtime")
|
||||
.is_some_and(|e| !e.chunks.is_empty()),
|
||||
"Resolved entity should carry its contributing chunks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,22 +1,5 @@
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use crate::scoring::FusionWeights;
|
||||
|
||||
pub use common::utils::config::RetrievalStrategy;
|
||||
|
||||
/// Configures which result types to include in Search strategy
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SearchTarget {
|
||||
/// Return only text chunks
|
||||
ChunksOnly,
|
||||
/// Return only knowledge entities
|
||||
EntitiesOnly,
|
||||
/// Return both chunks and entities (default)
|
||||
#[default]
|
||||
Both,
|
||||
}
|
||||
|
||||
/// Two-variant flag that serializes as a bool for backward compatibility.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum BoolFlag {
|
||||
@@ -62,30 +45,20 @@ impl<'de> Deserialize<'de> for BoolFlag {
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuningFlags {
|
||||
pub rerank_scores_only: BoolFlag,
|
||||
pub normalize_vector_scores: BoolFlag,
|
||||
pub normalize_fts_scores: BoolFlag,
|
||||
pub chunk_rrf_use_vector: BoolFlag,
|
||||
pub chunk_rrf_use_fts: BoolFlag,
|
||||
}
|
||||
|
||||
impl RetrievalTuningFlags {
|
||||
pub const fn rerank_scores_only(&self) -> bool {
|
||||
pub const fn rerank_scores_only(self) -> bool {
|
||||
self.rerank_scores_only.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_vector_scores(&self) -> bool {
|
||||
self.normalize_vector_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_fts_scores(&self) -> bool {
|
||||
self.normalize_fts_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_vector(&self) -> bool {
|
||||
pub const fn chunk_rrf_use_vector(self) -> bool {
|
||||
self.chunk_rrf_use_vector.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_fts(&self) -> bool {
|
||||
pub const fn chunk_rrf_use_fts(self) -> bool {
|
||||
self.chunk_rrf_use_fts.as_bool()
|
||||
}
|
||||
}
|
||||
@@ -94,146 +67,70 @@ impl Default for RetrievalTuningFlags {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
rerank_scores_only: BoolFlag::Disabled,
|
||||
normalize_vector_scores: BoolFlag::Disabled,
|
||||
normalize_fts_scores: BoolFlag::Enabled,
|
||||
chunk_rrf_use_vector: BoolFlag::Enabled,
|
||||
chunk_rrf_use_fts: BoolFlag::Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
/// Tunable parameters governing the chunk-first hybrid (vector + FTS, RRF-fused) retrieval.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
pub entity_vector_take: usize,
|
||||
/// Number of vector candidates to pull from the chunk embedding index.
|
||||
pub chunk_vector_take: usize,
|
||||
pub entity_fts_take: usize,
|
||||
/// Number of full-text candidates to pull from the chunk index.
|
||||
pub chunk_fts_take: usize,
|
||||
pub score_threshold: f32,
|
||||
pub fallback_min_results: usize,
|
||||
pub token_budget_estimate: usize,
|
||||
pub avg_chars_per_token: usize,
|
||||
/// Maximum chunks attached to each resolved entity.
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub lexical_match_weight: f32,
|
||||
pub graph_traversal_seed_limit: usize,
|
||||
pub graph_neighbor_limit: usize,
|
||||
pub graph_score_decay: f32,
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
/// Blend weight applied when mixing reranker scores with fused scores.
|
||||
pub rerank_blend_weight: f32,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
/// Keep top-N candidates after reranking.
|
||||
pub rerank_keep_top: usize,
|
||||
/// Maximum number of chunks returned to callers.
|
||||
pub chunk_result_cap: usize,
|
||||
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
||||
pub fusion_weights: Option<FusionWeights>,
|
||||
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
||||
#[serde(default = "default_chunk_rrf_k")]
|
||||
/// Reciprocal rank fusion k value for chunk merging.
|
||||
pub chunk_rrf_k: f32,
|
||||
/// Weight applied to vector ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_vector_weight")]
|
||||
pub chunk_rrf_vector_weight: f32,
|
||||
/// Weight applied to chunk FTS ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_fts_weight")]
|
||||
pub chunk_rrf_fts_weight: f32,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
entity_vector_take: 15,
|
||||
chunk_vector_take: 20,
|
||||
entity_fts_take: 10,
|
||||
chunk_fts_take: 20,
|
||||
score_threshold: 0.35,
|
||||
fallback_min_results: 10,
|
||||
token_budget_estimate: 10000,
|
||||
avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
lexical_match_weight: 0.15,
|
||||
graph_traversal_seed_limit: 5,
|
||||
graph_neighbor_limit: 6,
|
||||
graph_score_decay: 0.75,
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
rerank_keep_top: 8,
|
||||
chunk_result_cap: 5,
|
||||
fusion_weights: None,
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
||||
chunk_rrf_k: 60.0,
|
||||
chunk_rrf_vector_weight: 1.0,
|
||||
chunk_rrf_fts_weight: 1.0,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
/// Per-request retrieval configuration.
|
||||
///
|
||||
/// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities`
|
||||
/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved
|
||||
/// chunks (search, ingestion linking, relationship suggestion).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
/// Target for Search strategy (chunks, entities, or both)
|
||||
pub search_target: SearchTarget,
|
||||
pub resolve_entities: bool,
|
||||
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||
/// Chunk retrieval that also resolves the owning knowledge entities.
|
||||
pub fn with_entities() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::Default,
|
||||
tuning,
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
tuning,
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for chat retrieval with strategy selection support
|
||||
pub fn for_chat(strategy: RetrievalStrategy) -> Self {
|
||||
Self::with_strategy(strategy)
|
||||
}
|
||||
|
||||
/// Create config for relationship suggestion (entity-only retrieval)
|
||||
pub fn for_relationship_suggestion() -> Self {
|
||||
Self::with_strategy(RetrievalStrategy::RelationshipSuggestion)
|
||||
}
|
||||
|
||||
/// Create config for ingestion pipeline (entity-only retrieval)
|
||||
pub fn for_ingestion() -> Self {
|
||||
Self::with_strategy(RetrievalStrategy::Ingestion)
|
||||
}
|
||||
|
||||
/// Create config for unified search (chunks and/or entities)
|
||||
pub fn for_search(target: SearchTarget) -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::Search,
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: target,
|
||||
resolve_entities: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
60.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_vector_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_fts_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity};
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
diagnostics::{AssembleStats, Diagnostics, SearchStats},
|
||||
StageKind, StageTimings, RetrievalParams,
|
||||
};
|
||||
|
||||
/// Mutable working state threaded through every retrieval stage.
|
||||
pub(crate) struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
pub query_embedding: Option<Vec<f32>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
pub chunk_results: Vec<RetrievedChunk>,
|
||||
stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(params: RetrievalParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
config: params.config,
|
||||
query_embedding: None,
|
||||
chunk_values: Vec::new(),
|
||||
reranker: params.reranker,
|
||||
diagnostics: None,
|
||||
entity_results: Vec::new(),
|
||||
chunk_results: Vec::new(),
|
||||
stage_timings: StageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_embedding(params: RetrievalParams<'a>, query_embedding: Vec<f32>) -> Self {
|
||||
let mut ctx = Self::new(params);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
Box::new(AppError::InternalError(
|
||||
"query embedding missing before candidate search".to_string(),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(Diagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diagnostics_enabled(&self) -> bool {
|
||||
self.diagnostics.is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn record_search(&mut self, stats: SearchStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.search = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn record_assemble(&mut self, stats: AssembleStats) {
|
||||
if let Some(diag) = self.diagnostics.as_mut() {
|
||||
diag.assemble = Some(stats);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> StageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
|
||||
pub fn record_stage_duration(&mut self, kind: StageKind, duration: std::time::Duration) {
|
||||
self.stage_timings.record(kind, duration);
|
||||
}
|
||||
|
||||
pub fn take_entity_results(&mut self) -> Vec<RetrievedEntity> {
|
||||
std::mem::take(&mut self.entity_results)
|
||||
}
|
||||
|
||||
pub fn take_chunk_results(&mut self) -> Vec<RetrievedChunk> {
|
||||
std::mem::take(&mut self.chunk_results)
|
||||
}
|
||||
}
|
||||
@@ -1,51 +1,21 @@
|
||||
use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
||||
/// Captures instrumentation for the retrieval stages when diagnostics are enabled.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct Diagnostics {
|
||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
pub search: Option<SearchStats>,
|
||||
pub assemble: Option<AssembleStats>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct CollectCandidatesStats {
|
||||
pub vector_entity_candidates: usize,
|
||||
pub struct SearchStats {
|
||||
pub vector_chunk_candidates: usize,
|
||||
pub fts_entity_candidates: usize,
|
||||
pub fts_chunk_candidates: usize,
|
||||
pub vector_chunk_scores: Vec<f32>,
|
||||
pub fts_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct ChunkEnrichmentStats {
|
||||
pub filtered_entity_count: usize,
|
||||
pub fallback_min_results: usize,
|
||||
pub chunk_sources_considered: usize,
|
||||
pub chunk_candidates_before_enrichment: usize,
|
||||
pub chunk_candidates_after_enrichment: usize,
|
||||
pub top_chunk_scores: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct AssembleStats {
|
||||
pub token_budget_start: usize,
|
||||
pub token_budget_spent: usize,
|
||||
pub token_budget_remaining: usize,
|
||||
pub budget_exhausted: bool,
|
||||
pub chunks_selected: usize,
|
||||
pub chunks_skipped_due_budget: usize,
|
||||
pub entity_count: usize,
|
||||
pub entity_traces: Vec<EntityAssemblyTrace>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct EntityAssemblyTrace {
|
||||
pub entity_id: String,
|
||||
pub source_id: String,
|
||||
pub inspected_candidates: usize,
|
||||
pub selected_chunk_ids: Vec<String>,
|
||||
pub selected_chunk_scores: Vec<f32>,
|
||||
pub skipped_due_budget: usize,
|
||||
}
|
||||
|
||||
@@ -1,61 +1,68 @@
|
||||
mod config;
|
||||
mod context;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{
|
||||
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
|
||||
};
|
||||
pub use config::RetrievalConfig;
|
||||
pub use diagnostics::Diagnostics;
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||
use crate::{round_score, RetrievalOutput, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use stages::PipelineContext;
|
||||
use strategies::{
|
||||
DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver, SearchStrategyDriver,
|
||||
use stages::{
|
||||
ChunkAssembleStage, ChunkRerankStage, ChunkSearchStage, EmbedStage, ResolveEntitiesStage,
|
||||
};
|
||||
|
||||
// Export StrategyOutput publicly from this module
|
||||
// (it's defined in lib.rs but we re-export it here)
|
||||
|
||||
// Stage type enum
|
||||
/// Identifies a retrieval stage for timing and instrumentation.
|
||||
///
|
||||
/// [`StageKind::ALL`] lists every kind in pipeline order; consumers (e.g. the evaluation
|
||||
/// harness) iterate it generically so that adding a stage requires no changes outside this
|
||||
/// crate — add the variant, extend `ALL`, and give it a [`StageKind::label`].
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum StageKind {
|
||||
Embed,
|
||||
CollectCandidates,
|
||||
GraphExpansion,
|
||||
ChunkAttach,
|
||||
Search,
|
||||
Rerank,
|
||||
ResolveEntities,
|
||||
Assemble,
|
||||
}
|
||||
|
||||
// Pipeline stage trait
|
||||
impl StageKind {
|
||||
/// Every stage kind in canonical pipeline order.
|
||||
pub const ALL: [StageKind; 5] = [
|
||||
StageKind::Embed,
|
||||
StageKind::Search,
|
||||
StageKind::Rerank,
|
||||
StageKind::ResolveEntities,
|
||||
StageKind::Assemble,
|
||||
];
|
||||
|
||||
/// Stable, machine-friendly identifier for the stage (used as a metrics key).
|
||||
pub const fn label(self) -> &'static str {
|
||||
match self {
|
||||
StageKind::Embed => "embed",
|
||||
StageKind::Search => "search",
|
||||
StageKind::Rerank => "rerank",
|
||||
StageKind::ResolveEntities => "resolve_entities",
|
||||
StageKind::Assemble => "assemble",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single composable step in the retrieval pipeline.
|
||||
#[async_trait]
|
||||
pub trait Stage: Send + Sync {
|
||||
pub(crate) trait Stage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
async fn execute(&self, ctx: &mut context::PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
// Type alias for boxed stages
|
||||
pub type BoxedStage = Box<dyn Stage>;
|
||||
pub(crate) type BoxedStage = Box<dyn Stage>;
|
||||
|
||||
// Strategy driver trait
|
||||
#[async_trait]
|
||||
pub trait StrategyDriver: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
|
||||
}
|
||||
|
||||
// Pipeline stage timings tracker
|
||||
/// Per-stage execution timings recorded during a run.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct StageTimings {
|
||||
timings: Vec<(StageKind, Duration)>,
|
||||
@@ -66,41 +73,13 @@ impl StageTimings {
|
||||
self.timings.push((kind, duration));
|
||||
}
|
||||
|
||||
pub fn into_vec(self) -> Vec<(StageKind, Duration)> {
|
||||
self.timings
|
||||
}
|
||||
|
||||
// Helper methods to get duration for each stage type (for backward compatibility)
|
||||
fn get_stage_ms(&self, kind: StageKind) -> u128 {
|
||||
/// Milliseconds recorded for `kind`, or `0` if the stage did not run.
|
||||
pub fn stage_ms(&self, kind: StageKind) -> u128 {
|
||||
self.timings
|
||||
.iter()
|
||||
.find(|(k, _)| *k == kind)
|
||||
.map_or(0, |(_, d)| d.as_millis())
|
||||
}
|
||||
|
||||
pub fn embed_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Embed)
|
||||
}
|
||||
|
||||
pub fn collect_candidates_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::CollectCandidates)
|
||||
}
|
||||
|
||||
pub fn graph_expansion_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::GraphExpansion)
|
||||
}
|
||||
|
||||
pub fn chunk_attach_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::ChunkAttach)
|
||||
}
|
||||
|
||||
pub fn rerank_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Rerank)
|
||||
}
|
||||
|
||||
pub fn assemble_ms(&self) -> u128 {
|
||||
self.get_stage_ms(StageKind::Assemble)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RunOutput<T> {
|
||||
@@ -109,7 +88,35 @@ pub struct RunOutput<T> {
|
||||
pub stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
|
||||
/// Inputs required to run a retrieval.
|
||||
pub struct RetrievalParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
pub reranker: Option<crate::reranking::RerankerLease>,
|
||||
}
|
||||
|
||||
fn build_stages(config: &RetrievalConfig) -> Vec<BoxedStage> {
|
||||
let mut stages: Vec<BoxedStage> = vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkSearchStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
];
|
||||
if config.resolve_entities {
|
||||
stages.push(Box::new(ResolveEntitiesStage));
|
||||
}
|
||||
stages.push(Box::new(ChunkAssembleStage));
|
||||
stages
|
||||
}
|
||||
|
||||
async fn run(
|
||||
params: RetrievalParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<RetrievalOutput>, AppError> {
|
||||
let input_chars = params.input_text.chars().count();
|
||||
let input_preview: String = params.input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
@@ -119,110 +126,67 @@ pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppEr
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
strategy = %params.config.strategy,
|
||||
resolve_entities = params.config.resolve_entities,
|
||||
"Starting retrieval pipeline"
|
||||
);
|
||||
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
let resolve_entities = params.config.resolve_entities;
|
||||
let mut ctx = match query_embedding {
|
||||
Some(embedding) => context::PipelineContext::with_embedding(params, embedding),
|
||||
None => context::PipelineContext::new(params),
|
||||
};
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
for stage in build_stages(&ctx.config) {
|
||||
let start = Instant::now();
|
||||
stage.execute(&mut ctx).await?;
|
||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||
}
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let chunks = ctx.take_chunk_results();
|
||||
let results = if resolve_entities {
|
||||
RetrievalOutput::WithEntities {
|
||||
chunks,
|
||||
entities: ctx.take_entity_results(),
|
||||
}
|
||||
} else {
|
||||
RetrievalOutput::Chunks(chunks)
|
||||
};
|
||||
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
/// Run the retrieval pipeline, generating the query embedding internally if needed.
|
||||
pub async fn execute(params: RetrievalParams<'_>) -> Result<RetrievalOutput, AppError> {
|
||||
Ok(run(params, None, false).await?.results)
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
params: StrategyParams<'_>,
|
||||
/// Run the retrieval pipeline with a pre-computed query embedding.
|
||||
pub async fn run_with_embedding(
|
||||
params: RetrievalParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
_ => Err(AppError::InternalError(
|
||||
"Metrics not supported for this strategy".into(),
|
||||
)),
|
||||
}
|
||||
) -> Result<RetrievalOutput, AppError> {
|
||||
Ok(run(params, Some(query_embedding), false).await?.results)
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
params: StrategyParams<'_>,
|
||||
/// Run with a pre-computed embedding, returning results and per-stage timings.
|
||||
///
|
||||
/// When `capture_diagnostics` is true, pipeline search/assemble stats are included.
|
||||
pub async fn run_with_embedding_instrumented(
|
||||
params: RetrievalParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
_ => Err(AppError::InternalError(
|
||||
"Diagnostics not supported for this strategy".into(),
|
||||
)),
|
||||
}
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<RetrievalOutput>, AppError> {
|
||||
run(params, Some(query_embedding), capture_diagnostics).await
|
||||
}
|
||||
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
@@ -246,57 +210,3 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub struct StrategyParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(params, embedding),
|
||||
None => PipelineContext::new(params),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
}
|
||||
|
||||
async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
|
||||
for stage in driver.stages() {
|
||||
let start = Instant::now();
|
||||
stage.execute(&mut ctx).await?;
|
||||
ctx.record_stage_duration(stage.kind(), start.elapsed());
|
||||
}
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
|
||||
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
})
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
use async_trait::async_trait;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
scoring::{
|
||||
clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored,
|
||||
},
|
||||
RetrievedChunk, RetrievedEntity,
|
||||
};
|
||||
|
||||
use super::{
|
||||
config::RetrievalTuning,
|
||||
context::PipelineContext,
|
||||
diagnostics::{AssembleStats, SearchStats},
|
||||
Stage, StageKind,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct EmbedStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for EmbedStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Embed
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
embed(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkSearchStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ChunkSearchStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Search
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
search_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkRerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ChunkRerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
rerank_chunks(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ResolveEntitiesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ResolveEntitiesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::ResolveEntities
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
resolve_entities(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChunkAssembleStage;
|
||||
|
||||
#[async_trait]
|
||||
impl Stage for ChunkAssembleStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
assemble_chunks(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.query_embedding.is_some() {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await?
|
||||
} else {
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||
};
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting chunk candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let fts_take = tuning.chunk_fts_take;
|
||||
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
|
||||
|
||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||
TextChunk::vector_search(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
),
|
||||
async {
|
||||
if fts_enabled {
|
||||
TextChunk::fts_search(fts_take, &fts_query, ctx.db_client, &ctx.user_id).await
|
||||
} else {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
)?;
|
||||
|
||||
let vector_candidates = vector_rows.len();
|
||||
let fts_candidates = fts_rows.len();
|
||||
|
||||
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||
.collect();
|
||||
|
||||
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
||||
.collect();
|
||||
|
||||
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
||||
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
|
||||
// For very short keyword queries, lean more on lexical ranking.
|
||||
fts_weight *= 1.5;
|
||||
}
|
||||
|
||||
let rrf_config = RrfConfig {
|
||||
k: tuning.chunk_rrf_k,
|
||||
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||
fts_weight,
|
||||
use_vector: tuning.flags.chunk_rrf_use_vector(),
|
||||
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
|
||||
};
|
||||
|
||||
let chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||
|
||||
debug!(
|
||||
total_merged = chunks.len(),
|
||||
vector_only = chunks.iter().filter(|c| c.scores.fts.is_none()).count(),
|
||||
fts_only = chunks.iter().filter(|c| c.scores.vector.is_none()).count(),
|
||||
both_signals = chunks
|
||||
.iter()
|
||||
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||
.count(),
|
||||
rrf_k = rrf_config.k,
|
||||
"Merged chunk candidates with RRF"
|
||||
);
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_search(SearchStats {
|
||||
vector_chunk_candidates: vector_candidates,
|
||||
fts_chunk_candidates: fts_candidates,
|
||||
vector_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.vector.unwrap_or(0.0)),
|
||||
fts_chunk_scores: sample_scores(&chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
|
||||
});
|
||||
}
|
||||
|
||||
ctx.chunk_values = chunks;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.chunk_values.len() <= 1 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(reranker) = ctx.reranker.as_ref() else {
|
||||
debug!("No reranker lease provided; skipping chunk rerank stage");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let documents =
|
||||
build_chunk_rerank_documents(&ctx.chunk_values, ctx.config.tuning.rerank_keep_top.max(1));
|
||||
if documents.len() <= 1 {
|
||||
debug!("Skipping chunk reranking stage; insufficient chunk documents");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match reranker.rerank(&ctx.input_text, documents).await {
|
||||
Ok(results) if !results.is_empty() => {
|
||||
apply_chunk_rerank_results(&mut ctx.chunk_values, &ctx.config.tuning, results);
|
||||
}
|
||||
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
|
||||
Err(err) => warn!(
|
||||
error = %err,
|
||||
"Chunk reranking failed; continuing with original ordering"
|
||||
),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolve the `KnowledgeEntity` rows that own the retrieved chunks.
|
||||
///
|
||||
/// Entities are derived directly from the (benchmarked) chunk retrieval: chunks are grouped
|
||||
/// by `source_id`, the owning entities are loaded, scored by their best contributing chunk,
|
||||
/// and the contributing chunks are attached.
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
if ctx.chunk_values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
|
||||
|
||||
let mut source_order: Vec<String> = Vec::new();
|
||||
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
|
||||
let mut best_score: HashMap<String, f32> = HashMap::new();
|
||||
|
||||
for scored in &ctx.chunk_values {
|
||||
let source = scored.item.source_id.clone();
|
||||
let attached = chunks_by_source.entry(source.clone()).or_default();
|
||||
if attached.is_empty() {
|
||||
source_order.push(source.clone());
|
||||
best_score.insert(source.clone(), scored.fused);
|
||||
}
|
||||
if attached.len() < max_chunks {
|
||||
attached.push(RetrievedChunk {
|
||||
chunk: scored.item.clone(),
|
||||
score: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let entities =
|
||||
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
||||
|
||||
let mut entities_by_source: HashMap<String, Vec<KnowledgeEntity>> = HashMap::new();
|
||||
for entity in entities {
|
||||
entities_by_source
|
||||
.entry(entity.source_id.clone())
|
||||
.or_default()
|
||||
.push(entity);
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for source in &source_order {
|
||||
let Some(entities) = entities_by_source.remove(source) else {
|
||||
continue;
|
||||
};
|
||||
let score = best_score.get(source).copied().unwrap_or(0.0);
|
||||
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
|
||||
for entity in entities {
|
||||
results.push(RetrievedEntity {
|
||||
entity,
|
||||
score,
|
||||
chunks: chunks.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
sources = source_order.len(),
|
||||
entities = results.len(),
|
||||
"Resolved entities from retrieved chunks"
|
||||
);
|
||||
|
||||
ctx.entity_results = results;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Assembling chunk retrieval results");
|
||||
let mut chunk_values = std::mem::take(&mut ctx.chunk_values);
|
||||
// Limit how many chunks we return to keep context size reasonable.
|
||||
let limit = ctx
|
||||
.config
|
||||
.tuning
|
||||
.chunk_result_cap
|
||||
.max(1)
|
||||
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
||||
|
||||
if chunk_values.len() > limit {
|
||||
chunk_values.truncate(limit);
|
||||
}
|
||||
|
||||
ctx.chunk_results = chunk_values
|
||||
.into_iter()
|
||||
.map(|chunk| RetrievedChunk {
|
||||
chunk: chunk.item,
|
||||
score: chunk.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_assemble(AssembleStats {
|
||||
chunks_selected: ctx.chunk_results.len(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
|
||||
}
|
||||
|
||||
fn normalize_fts_query(input: &str) -> (String, usize) {
|
||||
const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"];
|
||||
let mut cleaned = String::with_capacity(input.len());
|
||||
for ch in input.chars() {
|
||||
if ch.is_alphanumeric() {
|
||||
cleaned.extend(ch.to_lowercase());
|
||||
} else if ch.is_whitespace() {
|
||||
cleaned.push(' ');
|
||||
}
|
||||
}
|
||||
let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3));
|
||||
for token in cleaned.split_whitespace() {
|
||||
if !STOPWORDS.contains(&token) && !token.is_empty() {
|
||||
tokens.push(token.to_string());
|
||||
}
|
||||
}
|
||||
let normalized = tokens.join(" ");
|
||||
(normalized, tokens.len())
|
||||
}
|
||||
|
||||
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
||||
chunks
|
||||
.iter()
|
||||
.take(max_chunks)
|
||||
.map(|chunk| {
|
||||
format!(
|
||||
"Source: {}\nChunk:\n{}",
|
||||
chunk.item.source_id,
|
||||
chunk.item.chunk.trim()
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn apply_chunk_rerank_results(
|
||||
chunks: &mut Vec<Scored<TextChunk>>,
|
||||
tuning: &RetrievalTuning,
|
||||
results: Vec<RerankResult>,
|
||||
) {
|
||||
if results.is_empty() || chunks.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<TextChunk>>> =
|
||||
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
clamp_unit(tuning.rerank_blend_weight)
|
||||
};
|
||||
|
||||
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
let original = candidate.fused;
|
||||
let blended = if use_only {
|
||||
clamp_unit(normalized)
|
||||
} else {
|
||||
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||
};
|
||||
candidate.update_fused(blended);
|
||||
reranked.push(candidate);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
result_index = result.index,
|
||||
"Chunk reranker returned out-of-range index; skipping"
|
||||
);
|
||||
}
|
||||
if reranked.len() == remaining.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
let keep_top = tuning.rerank_keep_top;
|
||||
if keep_top > 0 && reranked.len() > keep_top {
|
||||
reranked.truncate(keep_top);
|
||||
}
|
||||
|
||||
*chunks = reranked;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
||||
use super::{
|
||||
stages::{
|
||||
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
|
||||
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
|
||||
},
|
||||
BoxedStage, StrategyDriver,
|
||||
};
|
||||
use crate::{RetrievedChunk, RetrievedEntity};
|
||||
use common::error::AppError;
|
||||
|
||||
pub struct DefaultStrategyDriver;
|
||||
|
||||
impl DefaultStrategyDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for DefaultStrategyDriver {
|
||||
type Output = Vec<RetrievedChunk>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RelationshipSuggestionDriver;
|
||||
|
||||
impl RelationshipSuggestionDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for RelationshipSuggestionDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
// Skip ChunkAttachStage
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IngestionDriver;
|
||||
|
||||
impl IngestionDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for IngestionDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
// Skip ChunkAttachStage
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
use super::config::SearchTarget;
|
||||
use crate::SearchResult;
|
||||
|
||||
/// Search strategy driver that retrieves both chunks and entities
|
||||
pub struct SearchStrategyDriver {
|
||||
target: SearchTarget,
|
||||
}
|
||||
|
||||
impl SearchStrategyDriver {
|
||||
pub fn new(target: SearchTarget) -> Self {
|
||||
Self { target }
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for SearchStrategyDriver {
|
||||
type Output = SearchResult;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
match self.target {
|
||||
SearchTarget::ChunksOnly => vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
],
|
||||
SearchTarget::EntitiesOnly => vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
],
|
||||
SearchTarget::Both => vec![
|
||||
Box::new(EmbedStage),
|
||||
// Chunk retrieval path
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
// Entity retrieval path (runs after chunk stages)
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
let chunks = match self.target {
|
||||
SearchTarget::EntitiesOnly => Vec::new(),
|
||||
_ => ctx.take_chunk_results(),
|
||||
};
|
||||
let entities = match self.target {
|
||||
SearchTarget::ChunksOnly => Vec::new(),
|
||||
_ => ctx.take_entity_results(),
|
||||
};
|
||||
Ok(SearchResult::new(chunks, entities))
|
||||
}
|
||||
}
|
||||
@@ -97,8 +97,7 @@ impl RerankerPool {
|
||||
|
||||
fn default_pool_size() -> usize {
|
||||
available_parallelism()
|
||||
.map(|value| value.get().min(2))
|
||||
.unwrap_or(2)
|
||||
.map_or(2, |value| value.get().min(2))
|
||||
.max(1)
|
||||
}
|
||||
|
||||
@@ -156,6 +155,7 @@ pub struct RerankerLease {
|
||||
}
|
||||
|
||||
impl RerankerLease {
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub async fn rerank(
|
||||
&self,
|
||||
query: &str,
|
||||
@@ -165,7 +165,9 @@ impl RerankerLease {
|
||||
let engine = Arc::clone(&self.engine);
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut guard = engine.lock().expect("reranker engine mutex poisoned");
|
||||
let mut guard = engine.lock().map_err(|_| {
|
||||
AppError::InternalError("reranker engine mutex poisoned".into())
|
||||
})?;
|
||||
guard
|
||||
.rerank(query, documents, false, None)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
@@ -1,14 +1,12 @@
|
||||
use std::{cmp::Ordering, collections::HashMap};
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Holds optional subscores gathered from different retrieval signals.
|
||||
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Scores {
|
||||
pub fts: Option<f32>,
|
||||
pub vector: Option<f32>,
|
||||
pub graph: Option<f32>,
|
||||
}
|
||||
|
||||
/// Generic wrapper combining an item with its accumulated retrieval scores.
|
||||
@@ -40,40 +38,11 @@ impl<T> Scored<T> {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_graph_score(mut self, score: f32) -> Self {
|
||||
self.scores.graph = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub const fn update_fused(&mut self, fused: f32) {
|
||||
self.fused = fused;
|
||||
}
|
||||
}
|
||||
|
||||
/// Weights used for linear score fusion.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct FusionWeights {
|
||||
pub vector: f32,
|
||||
pub fts: f32,
|
||||
pub graph: f32,
|
||||
pub multi_bonus: f32,
|
||||
}
|
||||
|
||||
impl Default for FusionWeights {
|
||||
fn default() -> Self {
|
||||
// Default weights favor vector search, which typically performs better
|
||||
// FTS is used as a complement when there's good overlap
|
||||
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
|
||||
Self {
|
||||
vector: 0.8,
|
||||
fts: 0.2,
|
||||
graph: 0.2,
|
||||
multi_bonus: 0.3, // Increased to boost chunks with both signals
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for reciprocal rank fusion.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RrfConfig {
|
||||
@@ -84,29 +53,10 @@ pub struct RrfConfig {
|
||||
pub use_fts: bool,
|
||||
}
|
||||
|
||||
impl Default for RrfConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
k: 60.0,
|
||||
vector_weight: 1.0,
|
||||
fts_weight: 1.0,
|
||||
use_vector: true,
|
||||
use_fts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn clamp_unit(value: f32) -> f32 {
|
||||
value.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
pub fn distance_to_similarity(distance: f32) -> f32 {
|
||||
if !distance.is_finite() {
|
||||
return 0.0;
|
||||
}
|
||||
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
|
||||
}
|
||||
|
||||
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||
if scores.is_empty() {
|
||||
return Vec::new();
|
||||
@@ -147,69 +97,6 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
let vector = scores.vector.unwrap_or(0.0);
|
||||
let fts = scores.fts.unwrap_or(0.0);
|
||||
let graph = scores.graph.unwrap_or(0.0);
|
||||
|
||||
let mut fused = graph.mul_add(
|
||||
weights.graph,
|
||||
vector.mul_add(weights.vector, fts * weights.fts),
|
||||
);
|
||||
|
||||
let signals_present = scores
|
||||
.vector
|
||||
.iter()
|
||||
.chain(scores.fts.iter())
|
||||
.chain(scores.graph.iter())
|
||||
.count();
|
||||
|
||||
// Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
|
||||
if signals_present >= 2 {
|
||||
// For chunks with both vector and FTS, give a significant boost
|
||||
// This helps identify the "golden chunk" that appears in both searches
|
||||
if scores.vector.is_some() && scores.fts.is_some() {
|
||||
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
|
||||
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
|
||||
fused *= 1.0 + weights.multi_bonus;
|
||||
} else {
|
||||
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
|
||||
fused += weights.multi_bonus;
|
||||
}
|
||||
}
|
||||
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>, S>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
{
|
||||
for scored in incoming {
|
||||
let id = scored.item.id().to_owned();
|
||||
target
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if let Some(score) = scored.scores.vector {
|
||||
existing.scores.vector = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.fts {
|
||||
existing.scores.fts = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.graph {
|
||||
existing.scores.graph = Some(score);
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| Scored {
|
||||
item: scored.item.clone(),
|
||||
scores: scored.scores,
|
||||
fused: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||
where
|
||||
T: StoredObject,
|
||||
@@ -222,6 +109,10 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
/// Fuse two ranked candidate lists into a single ranking using reciprocal rank fusion.
|
||||
///
|
||||
/// This is the sole fusion mechanism for the retrieval pipeline: vector and full-text
|
||||
/// candidates each contribute `weight / (k + rank + 1)` to a shared fused score.
|
||||
pub fn reciprocal_rank_fusion<T>(
|
||||
mut vector_ranked: Vec<Scored<T>>,
|
||||
mut fts_ranked: Vec<Scored<T>>,
|
||||
@@ -266,9 +157,7 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
@@ -296,9 +185,7 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user