retrieval-pipeline: v0

This commit is contained in:
Per Stark
2025-11-18 21:20:27 +01:00
parent 6b7befbd04
commit f535df7e61
32 changed files with 1189 additions and 453 deletions

View File

@@ -0,0 +1,117 @@
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
},
};
use common::{
error::AppError,
storage::types::{
message::{format_history, Message},
system_settings::SystemSettings,
},
};
use serde::Deserialize;
use serde_json::Value;
use super::answer_retrieval_helper::get_query_response_schema;
/// A single citation object emitted by the LLM alongside its answer
/// (matches the `reference` object in `get_query_response_schema`).
#[derive(Debug, Deserialize)]
pub struct Reference {
    /// Raw reference string produced by the model; not read anywhere yet,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub reference: String,
}
/// Deserialization target for the structured JSON the LLM is instructed to
/// return (see `get_query_response_schema`).
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
    /// Natural-language answer to the user's question.
    pub answer: String,
    /// Citations backing the answer; parsed but not consumed yet,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub references: Vec<Reference>,
}
/// Final answer representation with references flattened to plain strings.
#[derive(Debug)]
pub struct Answer {
    /// Answer text.
    pub content: String,
    /// Reference strings extracted from the LLM response.
    pub references: Vec<String>,
}
/// Builds the user-turn prompt: retrieved context followed by the question.
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
    format!(
        r"
Context Information:
==================
{}
User Question:
==================
{}
",
        entities_json, query
    )
}
/// Builds the user-turn prompt with the chat history prepended to the
/// retrieved context and the question.
pub fn create_user_message_with_history(
    entities_json: &Value,
    history: &[Message],
    query: &str,
) -> String {
    let history_text = format_history(history);
    format!(
        r"
Chat history:
==================
{history_text}
Context Information:
==================
{entities_json}
User Question:
==================
{query}
"
    )
}
/// Assembles the chat-completion request: system prompt + user message, with
/// a strict JSON-schema response format so the model must return the
/// structure expected by `LLMResponseFormat`.
///
/// # Errors
///
/// Returns an [`OpenAIError`] if the request builder rejects the arguments.
pub fn create_chat_request(
    user_message: String,
    settings: &SystemSettings,
) -> Result<CreateChatCompletionRequest, OpenAIError> {
    let json_schema = ResponseFormatJsonSchema {
        description: Some("Query answering AI".into()),
        name: "query_answering_with_uuids".into(),
        schema: Some(get_query_response_schema()),
        // Strict mode forces the model output to conform to the schema.
        strict: Some(true),
    };
    let system_message =
        ChatCompletionRequestSystemMessage::from(settings.query_system_prompt.clone());
    let user_message = ChatCompletionRequestUserMessage::from(user_message);
    CreateChatCompletionRequestArgs::default()
        .model(&settings.query_model)
        .messages([system_message.into(), user_message.into()])
        .response_format(ResponseFormat::JsonSchema { json_schema })
        .build()
}
/// Extracts and parses the structured JSON payload from a chat-completion
/// response into an [`LLMResponseFormat`].
///
/// # Errors
///
/// Returns [`AppError::LLMParsing`] when the response carries no message
/// content, or when the content is not valid JSON for [`LLMResponseFormat`].
pub async fn process_llm_response(
    response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, AppError> {
    let content = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref())
        // `ok_or_else` avoids allocating the error String on the success path.
        .ok_or_else(|| AppError::LLMParsing("No content found in LLM response".into()))?;
    serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
        AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
    })
}

View File

@@ -0,0 +1,26 @@
use common::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT;
use serde_json::{json, Value};
/// System prompt used for query answering; re-exported from the shared default.
pub static QUERY_SYSTEM_PROMPT: &str = DEFAULT_QUERY_SYSTEM_PROMPT;
/// JSON schema (OpenAI structured-output style) for query answers:
/// an `answer` string plus an array of `{ "reference": string }` objects.
/// Both objects forbid additional properties so strict mode can be enforced.
pub fn get_query_response_schema() -> Value {
    // Schema for one element of the `references` array.
    let reference_schema = json!({
        "type": "object",
        "properties": {
            "reference": { "type": "string" },
        },
        "required": ["reference"],
        "additionalProperties": false,
    });
    json!({
        "type": "object",
        "properties": {
            "answer": { "type": "string" },
            "references": {
                "type": "array",
                "items": reference_schema,
            }
        },
        "required": ["answer", "references"],
        "additionalProperties": false
    })
}

View File

@@ -0,0 +1,265 @@
use std::collections::HashMap;
use serde::Deserialize;
use tracing::debug;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
};
use crate::scoring::Scored;
use common::storage::types::file_info::deserialize_flexible_id;
use surrealdb::sql::Thing;
/// Row shape for the first-pass FTS query: the record id plus its relevance score.
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
    /// Record id, normalized to a plain string by the flexible deserializer.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    /// Relevance score; `None` when the engine returns NONE for this row.
    fts_score: Option<f32>,
}
/// Executes a full-text search query against SurrealDB and returns scored results.
///
/// The function expects FTS indexes to exist for the provided table. Currently supports
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
///
/// Runs two queries: a first pass returning ids ordered by FTS score, then a
/// hydration pass fetching the full records. The original score ordering is
/// re-applied afterwards, since the hydration query does not preserve it.
///
/// # Errors
///
/// Returns [`AppError::Validation`] for tables without configured FTS support,
/// or the underlying database error if either query fails.
pub async fn find_items_by_fts<T>(
    take: usize,
    query: &str,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // Per-table match filter and score expression. `search::score(n)` yields
    // NONE when the n-th indexed field did not match, so each term is defaulted
    // to 0 before summing (otherwise a single-field match would score NONE).
    let (filter_clause, score_clause) = match table {
        "knowledge_entity" => (
            "(name @0@ $terms OR description @1@ $terms)",
            "(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
            (IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
        ),
        "text_chunk" => (
            "(chunk @0@ $terms)",
            "IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
        ),
        _ => {
            return Err(AppError::Validation(format!(
                "FTS not configured for table '{table}'"
            )))
        }
    };
    // Inline captures pick up the locals; the explicit `name = value` format
    // arguments previously duplicated them.
    let sql = format!(
        "SELECT id, {score_clause} AS fts_score \
        FROM {table} \
        WHERE {filter_clause} \
        AND user_id = $user_id \
        ORDER BY fts_score DESC \
        LIMIT $limit"
    );
    debug!(
        table = table,
        limit = take,
        "Executing FTS query with filter clause: {}",
        filter_clause
    );
    // Saturate instead of silently truncating if `take` ever exceeds i64::MAX.
    let limit = i64::try_from(take).unwrap_or(i64::MAX);
    let mut response = db_client
        .query(sql)
        .bind(("terms", query.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", limit))
        .await?;
    let score_rows: Vec<FtsScoreRow> = response.take(0)?;
    if score_rows.is_empty() {
        return Ok(Vec::new());
    }
    let thing_ids: Vec<Thing> = score_rows
        .iter()
        .map(|row| Thing::from((table, row.id.as_str())))
        .collect();
    // Hydrate the full records for the matched ids.
    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let items: Vec<T> = items_response.take(0)?;
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();
    // Walk the score rows (already relevance-ordered) so results keep that order.
    let mut results = Vec::with_capacity(score_rows.len());
    for row in score_rows {
        if let Some(item) = item_map.remove(&row.id) {
            let score = row.fts_score.unwrap_or_default();
            results.push(Scored::new(item).with_fts_score(score));
        }
    }
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::{
        knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
        text_chunk::TextChunk,
        StoredObject,
    };
    use uuid::Uuid;

    /// Zero vector sized to the production embedding dimension (1536).
    fn dummy_embedding() -> Vec<f32> {
        vec![0.0; 1536]
    }

    // Regression guard: a hit on only the `name` field must still produce a
    // score (exercises the NONE-defaulting in the two-field score expression).
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_name() {
        let namespace = "fts_test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        let user_id = "user_fts";
        // Only the name contains the search term; the description must not match.
        let entity = KnowledgeEntity::new(
            "source_a".into(),
            "Rustacean handbook".into(),
            "completely unrelated description".into(),
            KnowledgeEntityType::Document,
            None,
            dummy_embedding(),
            user_id.into(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to insert entity");
        // Rebuild indexes so the freshly inserted row is searchable.
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");
        let results = find_items_by_fts::<KnowledgeEntity>(
            5,
            "rustacean",
            &db,
            KnowledgeEntity::table_name(),
            user_id,
        )
        .await
        .expect("fts query failed");
        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when only the name matched"
        );
    }

    // Mirror of the test above for a description-only match.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_description() {
        let namespace = "fts_test_ns_desc";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        let user_id = "user_fts_desc";
        // Only the description contains the search term.
        let entity = KnowledgeEntity::new(
            "source_b".into(),
            "neutral name".into(),
            "Detailed notes about async runtimes".into(),
            KnowledgeEntityType::Document,
            None,
            dummy_embedding(),
            user_id.into(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to insert entity");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");
        let results = find_items_by_fts::<KnowledgeEntity>(
            5,
            "async",
            &db,
            KnowledgeEntity::table_name(),
            user_id,
        )
        .await
        .expect("fts query failed");
        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when only the description matched"
        );
    }

    // Covers the single-field `text_chunk` branch of `find_items_by_fts`.
    #[tokio::test]
    async fn fts_preserves_scores_for_text_chunks() {
        let namespace = "fts_test_ns_chunks";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        let user_id = "user_fts_chunk";
        let chunk = TextChunk::new(
            "source_chunk".into(),
            "GraphQL documentation reference".into(),
            dummy_embedding(),
            user_id.into(),
        );
        db.store_item(chunk.clone())
            .await
            .expect("failed to insert chunk");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");
        let results =
            find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
                .await
                .expect("fts query failed");
        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when chunk field matched"
        );
    }
}

View File

@@ -0,0 +1,432 @@
use std::collections::{HashMap, HashSet};
use surrealdb::{sql::Thing, Error};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
StoredObject,
},
};
/// Fetches all rows in `table_name` owned by `user_id` whose `source_id` is one
/// of the given `source_ids`, deserialized into `T`.
///
/// Commonly used to find related entities or to track the origin of database
/// entries.
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement
///   `serde::Deserialize`.
///
/// # Errors
///
/// Returns an [`Error`] when the query fails to execute or the rows cannot be
/// deserialized into `T`.
pub async fn find_entities_by_source_ids<T>(
    source_ids: Vec<String>,
    table_name: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
    T: for<'de> serde::Deserialize<'de>,
{
    let mut response = db
        .query(
            "SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id",
        )
        .bind(("table", table_name.to_owned()))
        .bind(("source_ids", source_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    response.take(0)
}
/// Find entities by their relationship to the id
///
/// Looks up `relates_to` edges owned by `user_id` that touch the given
/// `knowledge_entity` id in either direction, collects the distinct ids on the
/// other end of each edge in discovery order, then fetches those entities and
/// returns them in that same order, capped at `limit`.
/// NOTE(review): `limit == 0` is treated as "no limit" (both the truncate and
/// the early break are skipped) — confirm that is intentional; callers seen in
/// tests pass `usize::MAX` for unlimited.
pub async fn find_entities_by_relationship_by_id(
    db: &SurrealDbClient,
    entity_id: &str,
    user_id: &str,
    limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
    // Fetch every edge where this entity is either endpoint.
    let mut relationships_response = db
        .query(
            "
            SELECT * FROM relates_to
            WHERE metadata.user_id = $user_id
            AND (in = type::thing('knowledge_entity', $entity_id)
            OR out = type::thing('knowledge_entity', $entity_id))
            ",
        )
        .bind(("entity_id", entity_id.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
    if relationships.is_empty() {
        return Ok(Vec::new());
    }
    // Collect the "other end" of each edge, deduplicated but keeping first-seen
    // order (Vec preserves order, HashSet guards uniqueness).
    // Assumes `rel.in_` / `rel.out` hold plain id strings comparable to
    // `entity_id` — TODO confirm they are not `table:id` qualified.
    let mut neighbor_ids: Vec<String> = Vec::new();
    let mut seen: HashSet<String> = HashSet::new();
    for rel in relationships {
        if rel.in_ == entity_id {
            if seen.insert(rel.out.clone()) {
                neighbor_ids.push(rel.out);
            }
        } else if rel.out == entity_id {
            if seen.insert(rel.in_.clone()) {
                neighbor_ids.push(rel.in_);
            }
        } else {
            // Defensive: neither endpoint matched the id as a plain string;
            // keep both endpoints rather than dropping the edge.
            if seen.insert(rel.in_.clone()) {
                neighbor_ids.push(rel.in_.clone());
            }
            if seen.insert(rel.out.clone()) {
                neighbor_ids.push(rel.out);
            }
        }
    }
    // Never return the queried entity itself.
    neighbor_ids.retain(|id| id != entity_id);
    if neighbor_ids.is_empty() {
        return Ok(Vec::new());
    }
    // Trim before hydration so we never fetch more entities than needed.
    if limit > 0 && neighbor_ids.len() > limit {
        neighbor_ids.truncate(limit);
    }
    let thing_ids: Vec<Thing> = neighbor_ids
        .iter()
        .map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
        .collect();
    let mut neighbors_response = db
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", KnowledgeEntity::table_name().to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
    if neighbors.is_empty() {
        return Ok(Vec::new());
    }
    // Re-impose discovery order (the hydration query does not preserve it).
    let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
        .into_iter()
        .map(|entity| (entity.id.clone(), entity))
        .collect();
    let mut ordered = Vec::new();
    for id in neighbor_ids {
        if let Some(entity) = neighbor_map.remove(&id) {
            ordered.push(entity);
        }
        if limit > 0 && ordered.len() >= limit {
            break;
        }
    }
    Ok(ordered)
}
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
    use common::storage::types::knowledge_relationship::KnowledgeRelationship;
    use common::storage::types::StoredObject;
    use uuid::Uuid;

    // Covers multi-id, single-id, and no-match lookups against an in-memory DB.
    #[tokio::test]
    async fn test_find_entities_by_source_ids() {
        // Setup in-memory database for testing
        let namespace = "test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");
        // Create some test entities with different source_ids
        let source_id1 = "source123".to_string();
        let source_id2 = "source456".to_string();
        let source_id3 = "source789".to_string();
        let entity_type = KnowledgeEntityType::Document;
        let embedding = vec![0.1, 0.2, 0.3];
        let user_id = "user123".to_string();
        // Entity with source_id1
        let entity1 = KnowledgeEntity::new(
            source_id1.clone(),
            "Entity 1".to_string(),
            "Description 1".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Entity with source_id2
        let entity2 = KnowledgeEntity::new(
            source_id2.clone(),
            "Entity 2".to_string(),
            "Description 2".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Another entity with source_id1
        let entity3 = KnowledgeEntity::new(
            source_id1.clone(),
            "Entity 3".to_string(),
            "Description 3".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Entity with source_id3
        let entity4 = KnowledgeEntity::new(
            source_id3.clone(),
            "Entity 4".to_string(),
            "Description 4".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Store all entities
        db.store_item(entity1.clone())
            .await
            .expect("Failed to store entity 1");
        db.store_item(entity2.clone())
            .await
            .expect("Failed to store entity 2");
        db.store_item(entity3.clone())
            .await
            .expect("Failed to store entity 3");
        db.store_item(entity4.clone())
            .await
            .expect("Failed to store entity 4");
        // Test finding entities by multiple source_ids
        let source_ids = vec![source_id1.clone(), source_id2.clone()];
        let found_entities: Vec<KnowledgeEntity> =
            find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
                .await
                .expect("Failed to find entities by source_ids");
        // Should find 3 entities (2 with source_id1, 1 with source_id2)
        assert_eq!(
            found_entities.len(),
            3,
            "Should find 3 entities with the specified source_ids"
        );
        // Check that entities with source_id1 and source_id2 are found
        let found_source_ids: Vec<String> =
            found_entities.iter().map(|e| e.source_id.clone()).collect();
        assert!(
            found_source_ids.contains(&source_id1),
            "Should find entities with source_id1"
        );
        assert!(
            found_source_ids.contains(&source_id2),
            "Should find entities with source_id2"
        );
        assert!(
            !found_source_ids.contains(&source_id3),
            "Should not find entities with source_id3"
        );
        // Test finding entities by a single source_id
        let single_source_id = vec![source_id1.clone()];
        let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
            single_source_id,
            KnowledgeEntity::table_name(),
            &user_id,
            &db,
        )
        .await
        .expect("Failed to find entities by single source_id");
        // Should find 2 entities with source_id1
        assert_eq!(
            found_entities.len(),
            2,
            "Should find 2 entities with source_id1"
        );
        // Check that all found entities have source_id1
        for entity in found_entities {
            assert_eq!(
                entity.source_id, source_id1,
                "All found entities should have source_id1"
            );
        }
        // Test finding entities with non-existent source_id
        let non_existent_source_id = vec!["non_existent_source".to_string()];
        let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
            non_existent_source_id,
            KnowledgeEntity::table_name(),
            &user_id,
            &db,
        )
        .await
        .expect("Failed to find entities by non-existent source_id");
        // Should find 0 entities
        assert_eq!(
            found_entities.len(),
            0,
            "Should find 0 entities with non-existent source_id"
        );
    }

    // Builds a small graph (central -> related1, central -> related2, plus an
    // unrelated entity) and checks both neighbors come back.
    #[tokio::test]
    async fn test_find_entities_by_relationship_by_id() {
        // Setup in-memory database for testing
        let namespace = "test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");
        // Create some test entities
        let entity_type = KnowledgeEntityType::Document;
        let embedding = vec![0.1, 0.2, 0.3];
        let user_id = "user123".to_string();
        // Create the central entity we'll query relationships for
        let central_entity = KnowledgeEntity::new(
            "central_source".to_string(),
            "Central Entity".to_string(),
            "Central Description".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Create related entities
        let related_entity1 = KnowledgeEntity::new(
            "related_source1".to_string(),
            "Related Entity 1".to_string(),
            "Related Description 1".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        let related_entity2 = KnowledgeEntity::new(
            "related_source2".to_string(),
            "Related Entity 2".to_string(),
            "Related Description 2".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Create an unrelated entity
        let unrelated_entity = KnowledgeEntity::new(
            "unrelated_source".to_string(),
            "Unrelated Entity".to_string(),
            "Unrelated Description".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );
        // Store all entities, keeping the stored copies (ids may be assigned on store).
        let central_entity = db
            .store_item(central_entity.clone())
            .await
            .expect("Failed to store central entity")
            .unwrap();
        let related_entity1 = db
            .store_item(related_entity1.clone())
            .await
            .expect("Failed to store related entity 1")
            .unwrap();
        let related_entity2 = db
            .store_item(related_entity2.clone())
            .await
            .expect("Failed to store related entity 2")
            .unwrap();
        let _unrelated_entity = db
            .store_item(unrelated_entity.clone())
            .await
            .expect("Failed to store unrelated entity")
            .unwrap();
        // Create relationships
        let source_id = "relationship_source".to_string();
        // Create relationship 1: central -> related1
        let relationship1 = KnowledgeRelationship::new(
            central_entity.id.clone(),
            related_entity1.id.clone(),
            user_id.clone(),
            source_id.clone(),
            "references".to_string(),
        );
        // Create relationship 2: central -> related2
        let relationship2 = KnowledgeRelationship::new(
            central_entity.id.clone(),
            related_entity2.id.clone(),
            user_id.clone(),
            source_id.clone(),
            "contains".to_string(),
        );
        // Store relationships
        relationship1
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship 1");
        relationship2
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship 2");
        // Test finding entities related to the central entity (usize::MAX = no cap)
        let related_entities =
            find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
                .await
                .expect("Failed to find entities by relationship");
        // Check that we found relationships
        assert!(
            related_entities.len() >= 2,
            "Should find related entities in both directions"
        );
    }
}

View File

@@ -0,0 +1,336 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod pipeline;
pub mod reranking;
pub mod scoring;
pub mod vector;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
};
use reranking::RerankerLease;
use tracing::instrument;
pub use pipeline::{
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, StrategyOutput,
};
/// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
    /// The underlying text chunk.
    pub chunk: TextChunk,
    /// Fused retrieval score for this chunk.
    pub score: f32,
}
/// Final entity representation returned to callers, enriched with ranked chunks.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
    /// The retrieved knowledge entity.
    pub entity: KnowledgeEntity,
    /// Fused retrieval score for the entity.
    pub score: f32,
    /// Supporting chunks attached to this entity.
    pub chunks: Vec<RetrievedChunk>,
}
/// Primary orchestrator for retrieving `KnowledgeEntity` items related to
/// `input_text`; delegates to [`pipeline::run_pipeline`] with the supplied
/// configuration and optional reranker lease.
///
/// # Errors
///
/// Propagates any [`AppError`] raised by the underlying pipeline stages.
//
// `fields(user_id = %user_id)` records the value on the span; the previous
// `fields(user_id)` only declared an empty field that was never recorded.
#[instrument(skip_all, fields(user_id = %user_id))]
pub async fn retrieve_entities(
    db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
    pipeline::run_pipeline(
        db_client,
        openai_client,
        input_text,
        user_id,
        config,
        reranker,
    )
    .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::Client;
    use common::storage::types::{
        knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
        knowledge_relationship::KnowledgeRelationship,
        text_chunk::TextChunk,
    };
    use pipeline::{RetrievalConfig, RetrievalStrategy};
    use uuid::Uuid;

    // 3-dimensional toy embeddings: the "high"/"primary" vectors point near the
    // query vector, the "low"/"secondary" ones point away from it.
    fn test_embedding() -> Vec<f32> {
        vec![0.9, 0.1, 0.0]
    }
    fn entity_embedding_high() -> Vec<f32> {
        vec![0.8, 0.2, 0.0]
    }
    fn entity_embedding_low() -> Vec<f32> {
        vec![0.1, 0.9, 0.0]
    }
    fn chunk_embedding_primary() -> Vec<f32> {
        vec![0.85, 0.15, 0.0]
    }
    fn chunk_embedding_secondary() -> Vec<f32> {
        vec![0.2, 0.8, 0.0]
    }

    /// Fresh in-memory DB with migrations applied and the HNSW indexes
    /// redefined at DIMENSION 3 to match the toy embeddings above.
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");
        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");
        db.query(
            "BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
        )
        .await
        .expect("Failed to configure indices");
        db
    }

    // Happy path: one entity + one supporting chunk; the default (Initial)
    // strategy should rank the entity first and attach its chunk.
    #[tokio::test]
    async fn test_retrieve_entities_with_embedding_basic_flow() {
        let db = setup_test_db().await;
        let user_id = "test_user";
        let entity = KnowledgeEntity::new(
            "source_1".into(),
            "Rust async guide".into(),
            "Detailed notes about async runtimes".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_high(),
            user_id.into(),
        );
        let chunk = TextChunk::new(
            entity.source_id.clone(),
            "Tokio uses cooperative scheduling for fairness.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        db.store_item(entity.clone())
            .await
            .expect("Failed to store entity");
        db.store_item(chunk.clone())
            .await
            .expect("Failed to store chunk");
        let openai_client = Client::new();
        // run_pipeline_with_embedding takes a precomputed embedding, so no real
        // OpenAI call is made despite the client argument.
        let results = pipeline::run_pipeline_with_embedding(
            &db,
            &openai_client,
            test_embedding(),
            "Rust concurrency async tasks",
            user_id,
            RetrievalConfig::default(),
            None,
        )
        .await
        .expect("Hybrid retrieval failed");
        let entities = match results {
            StrategyOutput::Entities(items) => items,
            other => panic!("expected entity results, got {:?}", other),
        };
        assert!(
            !entities.is_empty(),
            "Expected at least one retrieval result"
        );
        let top = &entities[0];
        assert!(
            top.entity.name.contains("Rust"),
            "Expected Rust entity to be ranked first"
        );
        assert!(
            !top.chunks.is_empty(),
            "Expected Rust entity to include supporting chunks"
        );
    }

    // Graph expansion: a neighbor linked via a relationship should appear in
    // the results with a meaningful fused score and only its own chunks.
    #[tokio::test]
    async fn test_graph_relationship_enriches_results() {
        let db = setup_test_db().await;
        let user_id = "graph_user";
        let primary = KnowledgeEntity::new(
            "primary_source".into(),
            "Async Rust patterns".into(),
            "Explores async runtimes and scheduling strategies.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_high(),
            user_id.into(),
        );
        // Deliberately far from the query vector — only reachable via the graph.
        let neighbor = KnowledgeEntity::new(
            "neighbor_source".into(),
            "Tokio Scheduler Deep Dive".into(),
            "Details on Tokio's cooperative scheduler.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_low(),
            user_id.into(),
        );
        db.store_item(primary.clone())
            .await
            .expect("Failed to store primary entity");
        db.store_item(neighbor.clone())
            .await
            .expect("Failed to store neighbor entity");
        let primary_chunk = TextChunk::new(
            primary.source_id.clone(),
            "Rust async tasks use Tokio's cooperative scheduler.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        let neighbor_chunk = TextChunk::new(
            neighbor.source_id.clone(),
            "Tokio's scheduler manages task fairness across executors.".into(),
            chunk_embedding_secondary(),
            user_id.into(),
        );
        db.store_item(primary_chunk)
            .await
            .expect("Failed to store primary chunk");
        db.store_item(neighbor_chunk)
            .await
            .expect("Failed to store neighbor chunk");
        let openai_client = Client::new();
        let relationship = KnowledgeRelationship::new(
            primary.id.clone(),
            neighbor.id.clone(),
            user_id.into(),
            "relationship_source".into(),
            "references".into(),
        );
        relationship
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship");
        let results = pipeline::run_pipeline_with_embedding(
            &db,
            &openai_client,
            test_embedding(),
            "Rust concurrency async tasks",
            user_id,
            RetrievalConfig::default(),
            None,
        )
        .await
        .expect("Hybrid retrieval failed");
        let entities = match results {
            StrategyOutput::Entities(items) => items,
            other => panic!("expected entity results, got {:?}", other),
        };
        let mut neighbor_entry = None;
        for entity in &entities {
            if entity.entity.id == neighbor.id {
                neighbor_entry = Some(entity.clone());
            }
        }
        let neighbor_entry =
            neighbor_entry.expect("Graph-enriched neighbor should appear in results");
        assert!(
            neighbor_entry.score > 0.2,
            "Graph-enriched entity should have a meaningful fused score"
        );
        assert!(
            neighbor_entry
                .chunks
                .iter()
                .all(|chunk| chunk.chunk.source_id == neighbor.source_id),
            "Neighbor entity should surface its own supporting chunks"
        );
    }

    // The Revised strategy returns chunk-only output rather than entities.
    #[tokio::test]
    async fn test_revised_strategy_returns_chunks() {
        let db = setup_test_db().await;
        let user_id = "chunk_user";
        let chunk_one = TextChunk::new(
            "src_alpha".into(),
            "Tokio tasks execute on worker threads managed by the runtime.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        let chunk_two = TextChunk::new(
            "src_beta".into(),
            "Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
            chunk_embedding_secondary(),
            user_id.into(),
        );
        db.store_item(chunk_one.clone())
            .await
            .expect("Failed to store chunk one");
        db.store_item(chunk_two.clone())
            .await
            .expect("Failed to store chunk two");
        let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised);
        let openai_client = Client::new();
        let results = pipeline::run_pipeline_with_embedding(
            &db,
            &openai_client,
            test_embedding(),
            "tokio runtime worker behavior",
            user_id,
            config,
            None,
        )
        .await
        .expect("Revised retrieval failed");
        let chunks = match results {
            StrategyOutput::Chunks(items) => items,
            other => panic!("expected chunk output, got {:?}", other),
        };
        assert!(
            !chunks.is_empty(),
            "Revised strategy should return chunk-only responses"
        );
        assert!(
            chunks
                .iter()
                .any(|entry| entry.chunk.chunk.contains("Tokio")),
            "Chunk results should contain relevant snippets"
        );
    }
}

View File

@@ -0,0 +1,121 @@
use serde::{Deserialize, Serialize};
use std::fmt;
/// Selects which retrieval pipeline variant to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
    /// The original pipeline (the default); produces entity-centric output.
    Initial,
    /// The revised pipeline; produces chunk-centric output.
    Revised,
}

impl Default for RetrievalStrategy {
    /// `Initial` is the default strategy.
    fn default() -> Self {
        Self::Initial
    }
}
impl std::str::FromStr for RetrievalStrategy {
    type Err = String;

    /// Parses a strategy name case-insensitively ("initial" / "revised").
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let normalized = value.to_ascii_lowercase();
        if normalized == "initial" {
            Ok(Self::Initial)
        } else if normalized == "revised" {
            Ok(Self::Revised)
        } else {
            Err(format!("unknown retrieval strategy '{normalized}'"))
        }
    }
}
impl fmt::Display for RetrievalStrategy {
    /// Writes the lowercase name, matching the `FromStr` accepted forms.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RetrievalStrategy::Initial => f.write_str("initial"),
            RetrievalStrategy::Revised => f.write_str("revised"),
        }
    }
}
/// Tunable parameters that govern each retrieval stage.
///
/// Field semantics below are inferred from names and default values — verify
/// against the pipeline stages that consume them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
    // Candidate counts requested from the vector-search stage.
    pub entity_vector_take: usize,
    pub chunk_vector_take: usize,
    // Candidate counts requested from the full-text-search stage.
    pub entity_fts_take: usize,
    pub chunk_fts_take: usize,
    // Minimum fused score to survive filtering; if fewer than
    // `fallback_min_results` survive, presumably lower-scoring candidates are
    // kept as a fallback — TODO confirm.
    pub score_threshold: f32,
    pub fallback_min_results: usize,
    // Rough prompt-size budgeting: estimated token budget and the
    // characters-per-token ratio used to estimate chunk cost.
    pub token_budget_estimate: usize,
    pub avg_chars_per_token: usize,
    // Cap on supporting chunks attached to any single entity.
    pub max_chunks_per_entity: usize,
    // Weight given to lexical (FTS) matches when fusing scores — assumed; verify.
    pub lexical_match_weight: f32,
    // Graph-expansion knobs: seed count, neighbors per seed, per-hop score
    // decay, minimum seed score, and vector-score inheritance for neighbors.
    pub graph_traversal_seed_limit: usize,
    pub graph_neighbor_limit: usize,
    pub graph_score_decay: f32,
    pub graph_seed_min_score: f32,
    pub graph_vector_inheritance: f32,
    // Reranker blending: reranker-vs-fused weight, whether reranker scores
    // replace fused scores entirely, and how many top results to keep.
    pub rerank_blend_weight: f32,
    pub rerank_scores_only: bool,
    pub rerank_keep_top: usize,
}
impl Default for RetrievalTuning {
    /// Hand-tuned defaults; the specific values appear empirical rather than
    /// derived, so treat them as a starting point, not invariants.
    fn default() -> Self {
        Self {
            entity_vector_take: 15,
            chunk_vector_take: 20,
            entity_fts_take: 10,
            chunk_fts_take: 20,
            score_threshold: 0.35,
            fallback_min_results: 10,
            token_budget_estimate: 10000,
            avg_chars_per_token: 4,
            max_chunks_per_entity: 4,
            lexical_match_weight: 0.15,
            graph_traversal_seed_limit: 5,
            graph_neighbor_limit: 6,
            graph_score_decay: 0.75,
            graph_seed_min_score: 0.4,
            graph_vector_inheritance: 0.6,
            rerank_blend_weight: 0.65,
            rerank_scores_only: false,
            rerank_keep_top: 8,
        }
    }
}
/// Wrapper containing tuning plus future flags for per-request overrides.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Which pipeline variant to run.
    pub strategy: RetrievalStrategy,
    /// Stage-level tuning parameters.
    pub tuning: RetrievalTuning,
}
impl RetrievalConfig {
    /// Config with custom tuning and the default (`Initial`) strategy.
    pub fn new(tuning: RetrievalTuning) -> Self {
        Self {
            strategy: RetrievalStrategy::Initial,
            tuning,
        }
    }

    /// Config with the given strategy and default tuning.
    pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
        Self::with_tuning(strategy, RetrievalTuning::default())
    }

    /// Config with both strategy and tuning supplied explicitly.
    pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
        Self { strategy, tuning }
    }
}
impl Default for RetrievalConfig {
    /// Default config: `Initial` strategy with default tuning.
    fn default() -> Self {
        Self::new(RetrievalTuning::default())
    }
}

View File

@@ -0,0 +1,51 @@
use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
///
/// Each field is `None` when the corresponding stage did not run or did not
/// report diagnostics.
#[derive(Debug, Clone, Default, Serialize)]
pub struct PipelineDiagnostics {
    pub collect_candidates: Option<CollectCandidatesStats>,
    pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
    pub assemble: Option<AssembleStats>,
}
/// Counts and score samples from the candidate-collection stage
/// (vector search and full-text search, for entities and chunks).
#[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats {
    pub vector_entity_candidates: usize,
    pub vector_chunk_candidates: usize,
    pub fts_entity_candidates: usize,
    pub fts_chunk_candidates: usize,
    pub vector_chunk_scores: Vec<f32>,
    pub fts_chunk_scores: Vec<f32>,
}
/// Counters from the chunk-enrichment stage: how many entities survived
/// filtering and how the chunk candidate pool changed during enrichment.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
    pub filtered_entity_count: usize,
    pub fallback_min_results: usize,
    pub chunk_sources_considered: usize,
    pub chunk_candidates_before_enrichment: usize,
    pub chunk_candidates_after_enrichment: usize,
    pub top_chunk_scores: Vec<f32>,
}
/// Token-budget accounting and selection counters from the assembly stage,
/// plus a per-entity trace of what was selected.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats {
    pub token_budget_start: usize,
    pub token_budget_spent: usize,
    pub token_budget_remaining: usize,
    pub budget_exhausted: bool,
    pub chunks_selected: usize,
    pub chunks_skipped_due_budget: usize,
    pub entity_count: usize,
    pub entity_traces: Vec<EntityAssemblyTrace>,
}
/// Per-entity record of chunk selection during assembly: which chunk ids and
/// scores were kept, and how many candidates the budget forced us to skip.
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
    pub entity_id: String,
    pub source_id: String,
    pub inspected_candidates: usize,
    pub selected_chunk_ids: Vec<String>,
    pub selected_chunk_scores: Vec<f32>,
    pub skipped_due_budget: usize,
}

View File

@@ -0,0 +1,397 @@
mod config;
mod diagnostics;
mod stages;
mod strategies;
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics,
};
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
use async_openai::Client;
use async_trait::async_trait;
use common::{error::AppError, storage::db::SurrealDbClient};
use std::time::{Duration, Instant};
use tracing::info;
use stages::PipelineContext;
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
/// Result payload of a pipeline run; the variant depends on the strategy
/// (entity-centric for `Initial`, chunk-centric otherwise).
#[derive(Debug, Clone)]
pub enum StrategyOutput {
    Entities(Vec<RetrievedEntity>),
    Chunks(Vec<RetrievedChunk>),
}
impl StrategyOutput {
    /// Borrow the entity results, if this output carries entities.
    pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
        if let StrategyOutput::Entities(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Consume the output, yielding the entity results if present.
    pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
        if let StrategyOutput::Entities(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Borrow the chunk results, if this output carries chunks.
    pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
        if let StrategyOutput::Chunks(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Consume the output, yielding the chunk results if present.
    pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
        if let StrategyOutput::Chunks(items) = self {
            Some(items)
        } else {
            None
        }
    }
}
/// Results of a pipeline run plus its instrumentation.
#[derive(Debug)]
pub struct PipelineRunOutput<T> {
    pub results: T,
    // `None` unless diagnostics capture was enabled for the run.
    pub diagnostics: Option<PipelineDiagnostics>,
    pub stage_timings: PipelineStageTimings,
}
/// Identifies a pipeline stage for timing/diagnostic bookkeeping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StageKind {
    Embed,
    CollectCandidates,
    GraphExpansion,
    ChunkAttach,
    Rerank,
    Assemble,
}
/// Cumulative wall-clock time spent in each pipeline stage, in milliseconds.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct PipelineStageTimings {
    pub embed_ms: u128,
    pub collect_candidates_ms: u128,
    pub graph_expansion_ms: u128,
    pub chunk_attach_ms: u128,
    pub rerank_ms: u128,
    pub assemble_ms: u128,
}
impl PipelineStageTimings {
    /// Add `duration` to the accumulated total for the given stage.
    ///
    /// Totals are accumulated (`+=`) rather than overwritten so a stage that
    /// runs more than once reports its combined cost.
    pub fn record(&mut self, kind: StageKind, duration: Duration) {
        // `Duration::as_millis` already returns `u128`; the former
        // `as u128` cast was redundant.
        let elapsed = duration.as_millis();
        let slot = match kind {
            StageKind::Embed => &mut self.embed_ms,
            StageKind::CollectCandidates => &mut self.collect_candidates_ms,
            StageKind::GraphExpansion => &mut self.graph_expansion_ms,
            StageKind::ChunkAttach => &mut self.chunk_attach_ms,
            StageKind::Rerank => &mut self.rerank_ms,
            StageKind::Assemble => &mut self.assemble_ms,
        };
        *slot += elapsed;
    }
}
/// A single stage of the retrieval pipeline. Stages mutate the shared
/// context and are executed in the order the driver returns them.
#[async_trait]
pub trait PipelineStage: Send + Sync {
    /// The timing/diagnostics bucket this stage reports under.
    fn kind(&self) -> StageKind;
    /// Run the stage against the shared pipeline context.
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
}
// Boxed form so drivers can return heterogeneous stage lists.
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;
/// Strategy-specific wiring for the pipeline: which stages run, optional
/// tuning overrides, and how final results are extracted from the context.
pub trait StrategyDriver {
    /// Result type produced by `finalize`.
    type Output;
    /// The strategy this driver implements.
    fn strategy(&self) -> RetrievalStrategy;
    /// Ordered list of stages to execute.
    fn stages(&self) -> Vec<BoxedStage>;
    /// Hook to adjust tuning before the context is built; no-op by default.
    fn override_tuning(&self, _config: &mut RetrievalConfig) {}
    /// Extract the driver's output after all stages have run.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
}
/// Run the retrieval pipeline for `input_text`, embedding it internally.
///
/// Logs a short preview of the input, then dispatches to the driver selected
/// by `config.strategy` and wraps the driver's results in [`StrategyOutput`].
pub async fn run_pipeline(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
    let total_chars = input_text.chars().count();
    // First 120 chars, newlines flattened so the log stays on one line.
    let preview = input_text
        .chars()
        .take(120)
        .collect::<String>()
        .replace('\n', " ");
    let preview_chars = preview.chars().count();
    info!(
        %user_id,
        input_chars = total_chars,
        preview_truncated = total_chars > preview_chars,
        preview = %preview,
        "Starting ingestion retrieval pipeline"
    );
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                None,
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(StrategyOutput::Entities(run.results))
        }
        _ => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                None,
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(StrategyOutput::Chunks(run.results))
        }
    }
}
/// Run the retrieval pipeline with a caller-supplied query embedding.
///
/// Dispatches to the driver selected by `config.strategy`; no input preview
/// logging is performed on this path.
pub async fn run_pipeline_with_embedding(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(StrategyOutput::Entities(run.results))
        }
        _ => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(StrategyOutput::Chunks(run.results))
        }
    }
}
/// Like [`run_pipeline_with_embedding`] but returns the full
/// [`PipelineRunOutput`] including stage timings. Diagnostics capture stays
/// disabled on this path (the flag passed to `execute_strategy` is `false`).
pub async fn run_pipeline_with_embedding_with_metrics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Entities(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
        _ => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                false,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Chunks(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
    }
}
/// Like [`run_pipeline_with_embedding_with_metrics`] but with diagnostics
/// capture enabled (the flag passed to `execute_strategy` is `true`), so
/// `diagnostics` in the output is populated by the stages that record stats.
pub async fn run_pipeline_with_embedding_with_diagnostics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
    match config.strategy {
        RetrievalStrategy::Initial => {
            let run = execute_strategy(
                InitialStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                true,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Entities(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
        _ => {
            let run = execute_strategy(
                RevisedStrategyDriver::new(),
                db_client,
                openai_client,
                Some(query_embedding),
                input_text,
                user_id,
                config,
                reranker,
                true,
            )
            .await?;
            Ok(PipelineRunOutput {
                results: StrategyOutput::Chunks(run.results),
                diagnostics: run.diagnostics,
                stage_timings: run.stage_timings,
            })
        }
    }
}
/// Serialize retrieved entities (with their attached chunks) into the JSON
/// shape expected by the answer-generation prompt. Scores are rounded to
/// three decimals via `round_score`.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
    let mapped: Vec<serde_json::Value> = entities
        .iter()
        .map(|entry| {
            let chunks: Vec<serde_json::Value> = entry
                .chunks
                .iter()
                .map(|chunk| {
                    serde_json::json!({
                        "score": round_score(chunk.score),
                        "content": chunk.chunk.chunk
                    })
                })
                .collect();
            serde_json::json!({
                "KnowledgeEntity": {
                    "id": entry.entity.id,
                    "name": entry.entity.name,
                    "description": entry.entity.description,
                    "score": round_score(entry.score),
                    "chunks": chunks
                }
            })
        })
        .collect();
    serde_json::json!(mapped)
}
/// Build a pipeline context for `driver` and run it.
///
/// The driver may adjust tuning first; the context is constructed with or
/// without a precomputed embedding depending on `query_embedding`.
async fn execute_strategy<D: StrategyDriver>(
    driver: D,
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Option<Vec<f32>>,
    input_text: &str,
    user_id: &str,
    mut config: RetrievalConfig,
    reranker: Option<RerankerLease>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
    // Give the driver a chance to adjust tuning before the context is built.
    driver.override_tuning(&mut config);
    let ctx = if let Some(embedding) = query_embedding {
        PipelineContext::with_embedding(
            db_client,
            openai_client,
            embedding,
            input_text.to_owned(),
            user_id.to_owned(),
            config,
            reranker,
        )
    } else {
        PipelineContext::new(
            db_client,
            openai_client,
            input_text.to_owned(),
            user_id.to_owned(),
            config,
            reranker,
        )
    };
    run_with_driver(driver, ctx, capture_diagnostics).await
}
/// Execute the driver's stages in order, timing each one, then finalize.
async fn run_with_driver<D: StrategyDriver>(
    driver: D,
    mut ctx: PipelineContext<'_>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
    if capture_diagnostics {
        ctx.enable_diagnostics();
    }
    for stage in driver.stages() {
        let started = Instant::now();
        stage.execute(&mut ctx).await?;
        ctx.record_stage_duration(stage.kind(), started.elapsed());
    }
    // Pull instrumentation out of the context before finalize consumes results.
    let diagnostics = ctx.take_diagnostics();
    let stage_timings = ctx.take_stage_timings();
    let results = driver.finalize(&mut ctx)?;
    Ok(PipelineRunOutput {
        results,
        diagnostics,
        stage_timings,
    })
}
/// Round a score to three decimal places for JSON output.
fn round_score(value: f32) -> f64 {
    let scaled = f64::from(value) * 1000.0;
    scaled.round() / 1000.0
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,75 @@
use super::{
stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage,
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
RerankStage,
},
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
};
use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError;
/// Driver for the entity-centric `Initial` retrieval strategy.
///
/// The type is a stateless unit struct; `Default` is derived so it satisfies
/// clippy's `new_without_default` and works with generic construction.
#[derive(Debug, Default, Clone, Copy)]
pub struct InitialStrategyDriver;

impl InitialStrategyDriver {
    /// Create a new driver. Free: the type carries no state.
    pub const fn new() -> Self {
        Self
    }
}
impl StrategyDriver for InitialStrategyDriver {
    /// The initial strategy yields entity-centric results.
    type Output = Vec<RetrievedEntity>;

    fn strategy(&self) -> RetrievalStrategy {
        RetrievalStrategy::Initial
    }

    /// Stages run in the listed order: embed, collect candidates, expand the
    /// graph, attach chunks, rerank, assemble entities.
    fn stages(&self) -> Vec<BoxedStage> {
        vec![
            Box::new(EmbedStage),
            Box::new(CollectCandidatesStage),
            Box::new(GraphExpansionStage),
            Box::new(ChunkAttachStage),
            Box::new(RerankStage),
            Box::new(AssembleEntitiesStage),
        ]
    }

    /// Hand back the entity results accumulated in the context.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
        Ok(ctx.take_entity_results())
    }
}
/// Driver for the chunk-centric `Revised` retrieval strategy.
///
/// Stateless unit struct; `Default` is derived so it satisfies clippy's
/// `new_without_default` and works with generic construction.
#[derive(Debug, Default, Clone, Copy)]
pub struct RevisedStrategyDriver;

impl RevisedStrategyDriver {
    /// Create a new driver. Free: the type carries no state.
    pub const fn new() -> Self {
        Self
    }
}
impl StrategyDriver for RevisedStrategyDriver {
    /// The revised strategy yields chunk-centric results.
    type Output = Vec<RetrievedChunk>;

    fn strategy(&self) -> RetrievalStrategy {
        RetrievalStrategy::Revised
    }

    /// Stages run in the listed order: embed, chunk vector search, chunk
    /// rerank, chunk assembly.
    fn stages(&self) -> Vec<BoxedStage> {
        vec![
            Box::new(EmbedStage),
            Box::new(ChunkVectorStage),
            Box::new(ChunkRerankStage),
            Box::new(ChunkAssembleStage),
        ]
    }

    /// Hand back the chunk results accumulated in the context.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
        Ok(ctx.take_chunk_results())
    }

    // Zero out the entity candidate takes for this strategy.
    // NOTE(review): presumably the stages treat a zero take as "fetch no
    // entity candidates" — confirm in the stage implementations.
    fn override_tuning(&self, config: &mut RetrievalConfig) {
        config.tuning.entity_vector_take = 0;
        config.tuning.entity_fts_take = 0;
    }
}

View File

@@ -0,0 +1,170 @@
use std::{
env, fs,
path::{Path, PathBuf},
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
thread::available_parallelism,
};
use common::{error::AppError, utils::config::AppConfig};
use fastembed::{RerankInitOptions, RerankResult, TextRerank};
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
use tracing::debug;
// Process-wide round-robin cursor used to spread work across engines.
static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);

/// Pick the next engine index in round-robin order.
/// Panics on `pool_len == 0` (callers guarantee a non-empty pool).
fn pick_engine_index(pool_len: usize) -> usize {
    let ticket = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
    ticket % pool_len
}
/// Fixed-size pool of rerank engines guarded by a semaphore for backpressure.
pub struct RerankerPool {
    // Each engine is independently lockable; one permit exists per engine.
    engines: Vec<Arc<Mutex<TextRerank>>>,
    semaphore: Arc<Semaphore>,
}
impl RerankerPool {
    /// Build the pool at startup.
    /// `pool_size` controls max parallel reranks.
    pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
        Self::new_with_options(pool_size, RerankInitOptions::default())
    }

    /// Build a pool of `pool_size` engines sharing `init_options`.
    ///
    /// # Errors
    /// Returns `AppError::Validation` for a zero pool size and
    /// `AppError::InternalError` when a rerank model fails to load.
    fn new_with_options(
        pool_size: usize,
        init_options: RerankInitOptions,
    ) -> Result<Arc<Self>, AppError> {
        if pool_size == 0 {
            return Err(AppError::Validation(
                "RERANKING_POOL_SIZE must be greater than zero".to_string(),
            ));
        }
        // Ensure the model cache directory exists before engines load from it.
        fs::create_dir_all(&init_options.cache_dir)?;
        let mut engines = Vec::with_capacity(pool_size);
        for x in 0..pool_size {
            debug!("Creating reranking engine: {x}");
            let model = TextRerank::try_new(init_options.clone())
                .map_err(|e| AppError::InternalError(e.to_string()))?;
            engines.push(Arc::new(Mutex::new(model)));
        }
        Ok(Arc::new(Self {
            engines,
            // One permit per engine: checkout blocks once all engines are busy.
            semaphore: Arc::new(Semaphore::new(pool_size)),
        }))
    }

    /// Initialize a pool using application configuration.
    /// Returns `Ok(None)` when reranking is disabled in config.
    pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
        if !config.reranking_enabled {
            return Ok(None);
        }
        let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
        let init_options = build_rerank_init_options(config)?;
        Self::new_with_options(pool_size, init_options).map(Some)
    }

    /// Check out capacity + pick an engine.
    /// This returns a lease that can perform rerank().
    pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
        // Acquire a permit. This enforces backpressure.
        let permit = self
            .semaphore
            .clone()
            .acquire_owned()
            .await
            // The semaphore is never closed by this type, so this cannot fire
            // unless the pool is misused.
            .expect("semaphore closed");
        // Pick an engine.
        // This is naive: just pick based on a simple modulo counter.
        // We use an atomic counter to avoid always choosing index 0.
        let idx = pick_engine_index(self.engines.len());
        let engine = self.engines[idx].clone();
        RerankerLease {
            _permit: permit,
            engine,
        }
    }
}
/// Default engine count: detected parallelism capped at two, floored at one;
/// falls back to 2 when parallelism cannot be determined.
fn default_pool_size() -> usize {
    let detected = available_parallelism().map(|value| value.get().min(2));
    detected.unwrap_or(2).max(1)
}
/// Case-insensitive truthiness check for env-style flag values.
fn is_truthy(value: &str) -> bool {
    let normalized = value.trim().to_ascii_lowercase();
    ["1", "true", "yes", "on"].contains(&normalized.as_str())
}
/// Resolve fastembed rerank init options from config and environment.
///
/// Precedence for each setting is: `AppConfig` value, then the
/// `RERANKING_*` environment variable, then the `FASTEMBED_*` variable,
/// then a built-in default.
///
/// # Errors
/// Fails if the cache directory cannot be created.
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
    let mut options = RerankInitOptions::default();
    let cache_dir = config
        .fastembed_cache_dir
        .as_ref()
        .map(PathBuf::from)
        .or_else(|| env::var("RERANKING_CACHE_DIR").ok().map(PathBuf::from))
        .or_else(|| env::var("FASTEMBED_CACHE_DIR").ok().map(PathBuf::from))
        .unwrap_or_else(|| {
            // Default: <data_dir>/fastembed/reranker
            Path::new(&config.data_dir)
                .join("fastembed")
                .join("reranker")
        });
    fs::create_dir_all(&cache_dir)?;
    options.cache_dir = cache_dir;
    // Download progress defaults to on unless explicitly disabled.
    let show_progress = config
        .fastembed_show_download_progress
        .or_else(|| env_bool("RERANKING_SHOW_DOWNLOAD_PROGRESS"))
        .or_else(|| env_bool("FASTEMBED_SHOW_DOWNLOAD_PROGRESS"))
        .unwrap_or(true);
    options.show_download_progress = show_progress;
    // Max token length is only overridden when config or env provides one.
    if let Some(max_length) = config.fastembed_max_length.or_else(|| {
        env::var("RERANKING_MAX_LENGTH")
            .ok()
            .and_then(|value| value.parse().ok())
    }) {
        options.max_length = max_length;
    }
    Ok(options)
}
/// Read `key` from the environment and interpret it as a boolean flag.
/// Returns `None` when the variable is unset or not valid unicode.
fn env_bool(key: &str) -> Option<bool> {
    match env::var(key) {
        Ok(value) => Some(is_truthy(&value)),
        Err(_) => None,
    }
}
/// Active lease on a single TextRerank instance.
pub struct RerankerLease {
    // When this drops the semaphore permit is released.
    _permit: OwnedSemaphorePermit,
    // The specific engine this lease is allowed to use.
    engine: Arc<Mutex<TextRerank>>,
}
impl RerankerLease {
    /// Rerank `documents` against `query` on the leased engine.
    ///
    /// # Errors
    /// Wraps any fastembed failure in `AppError::InternalError`.
    pub async fn rerank(
        &self,
        query: &str,
        documents: Vec<String>,
    ) -> Result<Vec<RerankResult>, AppError> {
        // Lock this specific engine so we get &mut TextRerank
        let mut guard = self.engine.lock().await;
        // NOTE(review): the `false` appears to skip returning document text
        // and `None` leaves batch size to fastembed — confirm against the
        // fastembed `TextRerank::rerank` docs.
        guard
            .rerank(query.to_owned(), documents, false, None)
            .map_err(|e| AppError::InternalError(e.to_string()))
    }
}

View File

@@ -0,0 +1,183 @@
use std::cmp::Ordering;
use common::storage::types::StoredObject;
/// Holds optional subscores gathered from different retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
    pub fts: Option<f32>,
    pub vector: Option<f32>,
    pub graph: Option<f32>,
}

/// Generic wrapper combining an item with its accumulated retrieval scores.
#[derive(Debug, Clone)]
pub struct Scored<T> {
    pub item: T,
    pub scores: Scores,
    pub fused: f32,
}

impl<T> Scored<T> {
    /// Wrap `item` with no recorded signals and a zero fused score.
    pub fn new(item: T) -> Self {
        let scores = Scores::default();
        Self {
            item,
            scores,
            fused: 0.0,
        }
    }

    /// Builder-style setter for the vector-similarity subscore.
    pub const fn with_vector_score(mut self, score: f32) -> Self {
        self.scores.vector = Some(score);
        self
    }

    /// Builder-style setter for the full-text-search subscore.
    pub const fn with_fts_score(mut self, score: f32) -> Self {
        self.scores.fts = Some(score);
        self
    }

    /// Builder-style setter for the graph subscore.
    pub const fn with_graph_score(mut self, score: f32) -> Self {
        self.scores.graph = Some(score);
        self
    }

    /// Overwrite the fused (combined) score.
    pub const fn update_fused(&mut self, fused: f32) {
        self.fused = fused;
    }
}
/// Weights used for linear score fusion.
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy)]
pub struct FusionWeights {
    pub vector: f32,
    pub fts: f32,
    pub graph: f32,
    /// Flat bonus applied when two or more signals are present.
    pub multi_bonus: f32,
}

impl Default for FusionWeights {
    /// Defaults favor vector similarity (0.5) over FTS (0.3) and graph (0.2),
    /// with a small 0.02 multi-signal bonus.
    fn default() -> Self {
        Self {
            multi_bonus: 0.02,
            graph: 0.2,
            fts: 0.3,
            vector: 0.5,
        }
    }
}
/// Clamp `value` to the unit interval `[0, 1]`.
/// NaN passes through unchanged: neither comparison matches, so the input
/// is returned as-is (identical to `f32::clamp` semantics).
pub const fn clamp_unit(value: f32) -> f32 {
    if value < 0.0 {
        0.0
    } else if value > 1.0 {
        1.0
    } else {
        value
    }
}
/// Map a distance to a similarity in `[0, 1]` via `1 / (1 + d)`.
/// Non-finite distances yield `0.0`; negative distances are treated as zero.
pub fn distance_to_similarity(distance: f32) -> f32 {
    if !distance.is_finite() {
        return 0.0;
    }
    let non_negative = distance.max(0.0);
    (1.0 / (1.0 + non_negative)).clamp(0.0, 1.0)
}
/// Min-max normalize `scores` into `[0, 1]`.
///
/// Non-finite entries never contribute to the extrema and always normalize
/// to `0.0`. When every finite score is equal, the finite entries map to
/// `1.0`. When there are no finite entries at all, everything maps to `0.0`.
///
/// Bug fix: previously, when all finite scores were equal, the early
/// `vec![1.0; len]` return also assigned `1.0` to non-finite entries —
/// inconsistent with the normal path, which maps non-finite to `0.0`.
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }
    // Track finite extrema only.
    let mut min = f32::INFINITY;
    let mut max = f32::NEG_INFINITY;
    for s in scores {
        if s.is_finite() {
            min = min.min(*s);
            max = max.max(*s);
        }
    }
    // No finite values at all: nothing can be normalized.
    if min > max {
        return vec![0.0; scores.len()];
    }
    let span = max - min;
    let degenerate = span.abs() < f32::EPSILON;
    scores
        .iter()
        .map(|score| {
            if !score.is_finite() {
                0.0
            } else if degenerate {
                // All finite scores identical: treat each as maximal.
                1.0
            } else {
                ((score - min) / span).clamp(0.0, 1.0)
            }
        })
        .collect()
}
/// Linearly fuse the available subscores with `weights`, adding a flat bonus
/// when at least two signals are present, and clamp the result to `[0, 1]`.
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
    let vector = scores.vector.unwrap_or(0.0);
    let fts = scores.fts.unwrap_or(0.0);
    let graph = scores.graph.unwrap_or(0.0);
    // Keep the fused multiply-adds: their single-rounding results must match
    // the previous implementation bit-for-bit.
    let mut fused = graph.mul_add(
        weights.graph,
        vector.mul_add(weights.vector, fts * weights.fts),
    );
    let signals_present = [scores.vector, scores.fts, scores.graph]
        .iter()
        .filter(|signal| signal.is_some())
        .count();
    if signals_present >= 2 {
        fused += weights.multi_bonus;
    }
    clamp_unit(fused)
}
/// Merge `incoming` scored items into `target`, keyed by item id.
///
/// For an id already present, each incoming signal score that is `Some`
/// overwrites the stored one; the existing `fused` value is left untouched
/// (presumably recomputed later by the caller — confirm against call sites).
/// New ids are inserted as-is.
pub fn merge_scored_by_id<T>(
    target: &mut std::collections::HashMap<String, Scored<T>>,
    incoming: Vec<Scored<T>>,
) where
    T: StoredObject + Clone,
{
    use std::collections::hash_map::Entry;

    for scored in incoming {
        let id = scored.item.get_id().to_owned();
        match target.entry(id) {
            Entry::Occupied(mut occupied) => {
                let existing = occupied.get_mut();
                if let Some(score) = scored.scores.vector {
                    existing.scores.vector = Some(score);
                }
                if let Some(score) = scored.scores.fts {
                    existing.scores.fts = Some(score);
                }
                if let Some(score) = scored.scores.graph {
                    existing.scores.graph = Some(score);
                }
            }
            Entry::Vacant(vacant) => {
                // Move the value in directly. The previous version had to
                // `item.clone()` because `and_modify` and `or_insert_with`
                // both captured `scored`; the Entry match avoids that copy.
                vacant.insert(scored);
            }
        }
    }
}
/// Sort descending by fused score, breaking ties by ascending item id so the
/// ordering is deterministic. Non-comparable (NaN) pairs compare equal.
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
    T: StoredObject,
{
    items.sort_by(|left, right| {
        let by_fused = right
            .fused
            .partial_cmp(&left.fused)
            .unwrap_or(Ordering::Equal);
        by_fused.then_with(|| left.item.get_id().cmp(right.item.get_id()))
    });
}

View File

@@ -0,0 +1,218 @@
use std::collections::HashMap;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{file_info::deserialize_flexible_id, StoredObject},
},
utils::embedding::generate_embedding,
};
use serde::Deserialize;
use surrealdb::sql::Thing;
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
/// Retrieves the `take` closest matches to `input_text` from `table` using
/// vector similarity.
///
/// Embeds `input_text`, then delegates to
/// [`find_items_by_vector_similarity_with_embedding`] for the actual search.
///
/// # Arguments
///
/// * `take` - Maximum number of items to retrieve.
/// * `input_text` - Text to embed and compare against stored embeddings.
/// * `db_client` - SurrealDB client used for querying.
/// * `table` - Table to search.
/// * `openai_client` - OpenAI client used to generate the query embedding.
/// * `user_id` - Id of the current user; results are scoped to this user.
///
/// # Errors
///
/// Returns an `AppError` when embedding generation or the database query fails.
///
/// # Type Parameters
///
/// * `T` - Type to deserialize the query results into.
pub async fn find_items_by_vector_similarity<T>(
    take: usize,
    input_text: &str,
    db_client: &SurrealDbClient,
    table: &str,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    let query_embedding = generate_embedding(openai_client, input_text, db_client).await?;
    find_items_by_vector_similarity_with_embedding(
        take,
        query_embedding,
        db_client,
        table,
        user_id,
    )
    .await
}
// Minimal row returned by the knn id/distance query.
#[derive(Debug, Deserialize)]
struct DistanceRow {
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // Distance may be absent in the response; treated as "no score".
    distance: Option<f32>,
}
/// KNN search over `table` using a precomputed `query_embedding`.
///
/// Two-step query: first fetch ids + knn distances, then load the full rows
/// for those ids (scoped to `user_id`). Distances are min-max normalized
/// across this result set and inverted into similarities in `[0, 1]`; when
/// all distances are equal, the reciprocal `distance_to_similarity` mapping
/// is used instead. Result order follows the knn response.
///
/// # Errors
/// Fails when the embedding cannot be serialized or a database query errors.
pub async fn find_items_by_vector_similarity_with_embedding<T>(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // NOTE(review): the embedding is interpolated into the query text rather
    // than bound — presumably because the `<|take,ef|>` knn operator needs a
    // literal; confirm against SurrealDB docs. The literal is JSON-serialized
    // floats, so injection is not a concern here.
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    let closest_query = format!(
        "SELECT id, vector::distance::knn() AS distance \
         FROM {table} \
         WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
         LIMIT $limit",
        table = table,
        take = take,
        embedding = embedding_literal
    );
    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let distance_rows: Vec<DistanceRow> = response.take(0)?;
    if distance_rows.is_empty() {
        return Ok(Vec::new());
    }
    // Second query: fetch the full rows for the matched ids.
    let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
    let thing_ids: Vec<Thing> = ids
        .iter()
        .map(|id| Thing::from((table, id.as_str())))
        .collect();
    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids.clone()))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let items: Vec<T> = items_response.take(0)?;
    // Index rows by id so we can emit them in knn order below.
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();
    // Compute finite distance extrema for min-max normalization.
    let mut min_distance = f32::MAX;
    let mut max_distance = f32::MIN;
    for row in &distance_rows {
        if let Some(distance) = row.distance {
            if distance.is_finite() {
                if distance < min_distance {
                    min_distance = distance;
                }
                if distance > max_distance {
                    max_distance = distance;
                }
            }
        }
    }
    // Only normalize when there is a meaningful spread of finite distances.
    let normalize = min_distance.is_finite()
        && max_distance.is_finite()
        && (max_distance - min_distance).abs() > f32::EPSILON;
    let mut scored = Vec::with_capacity(distance_rows.len());
    for row in distance_rows {
        // Rows whose full item was not returned (or not owned) are skipped.
        if let Some(item) = item_map.remove(&row.id) {
            let similarity = row
                .distance
                .map(|distance| {
                    if normalize {
                        // Invert: smallest distance → similarity 1.0.
                        let span = max_distance - min_distance;
                        if span.abs() < f32::EPSILON {
                            1.0
                        } else {
                            let normalized = 1.0 - ((distance - min_distance) / span);
                            clamp_unit(normalized)
                        }
                    } else {
                        distance_to_similarity(distance)
                    }
                })
                .unwrap_or_default();
            scored.push(Scored::new(item).with_vector_score(similarity));
        }
    }
    Ok(scored)
}
/// Lightweight chunk projection (id, owning source, and text) used when the
/// full chunk record is not needed.
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkSnippet {
    pub id: String,
    pub source_id: String,
    pub chunk: String,
}
// Row returned by the chunk knn query: snippet fields plus the distance.
#[derive(Debug, Deserialize)]
struct ChunkDistanceRow {
    distance: Option<f32>,
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    pub source_id: String,
    pub chunk: String,
}
/// KNN search over `text_chunk` returning lightweight snippets scored by
/// vector similarity (distances mapped via `distance_to_similarity`).
///
/// # Errors
/// Fails when the embedding cannot be serialized or the query errors.
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    user_id: &str,
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    let closest_query = format!(
        "SELECT id, source_id, chunk, vector::distance::knn() AS distance \
         FROM text_chunk \
         WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
         LIMIT $limit",
        take = take,
        embedding = embedding_literal
    );
    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let rows: Vec<ChunkDistanceRow> = response.take(0)?;
    let scored = rows
        .into_iter()
        .map(|row| {
            let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
            Scored::new(ChunkSnippet {
                id: row.id,
                source_id: row.source_id,
                chunk: row.chunk,
            })
            .with_vector_score(similarity)
        })
        .collect();
    Ok(scored)
}