feat: hybrid search

This commit is contained in:
Per Stark
2025-10-14 20:38:43 +02:00
parent aa0b1462a1
commit dc40cf7663
10 changed files with 1390 additions and 131 deletions

View File

@@ -11,7 +11,6 @@ use common::{
storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity,
message::{format_history, Message},
system_settings::SystemSettings,
},
@@ -20,7 +19,7 @@ use common::{
use serde::Deserialize;
use serde_json::{json, Value};
use crate::retrieve_entities;
use crate::{retrieve_entities, RetrievedEntity};
use super::answer_retrieval_helper::get_query_response_schema;
@@ -84,21 +83,31 @@ pub async fn get_answer_with_references(
})
}
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value {
json!(entities
.iter()
.map(|entity| {
.map(|entry| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
"id": entry.entity.id,
"name": entry.entity.name,
"description": entry.entity.description,
"score": round_score(entry.score),
"chunks": entry.chunks.iter().map(|chunk| {
json!({
"score": round_score(chunk.score),
"content": chunk.chunk.chunk
})
}).collect::<Vec<_>>()
}
})
})
.collect::<Vec<_>>())
}
/// Rounds a retrieval score to three decimal places for JSON presentation.
///
/// Widens to `f64` first so the rounded value serializes cleanly.
fn round_score(value: f32) -> f64 {
    let thousandths = f64::from(value) * 1000.0;
    thousandths.round() / 1000.0
}
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r#"

View File

@@ -0,0 +1,265 @@
use std::collections::HashMap;
use serde::Deserialize;
use tracing::debug;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
};
use crate::scoring::Scored;
use common::storage::types::file_info::deserialize_flexible_id;
use surrealdb::sql::Thing;
/// Row shape returned by the FTS scoring query: a record id plus its BM25 score.
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
    // SurrealDB may return the id as a Thing or a plain string; the flexible
    // deserializer normalizes both to a `String`.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // `None` when the engine reports no score; treated as 0.0 downstream.
    fts_score: Option<f32>,
}
/// Executes a full-text search query against SurrealDB and returns scored results.
///
/// The function expects FTS indexes to exist for the provided table. Currently supports
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
///
/// # Arguments
/// * `take` - Maximum number of results to return
/// * `query` - Raw search terms passed to the FTS matcher
/// * `db_client` - SurrealDB client used for both queries
/// * `table` - Table to search; must have matching FTS indexes configured
/// * `user_id` - Restricts results to rows owned by this user
///
/// # Errors
/// Returns `AppError::Validation` for unsupported tables, or propagates database
/// errors if either query fails or results cannot be deserialized into `T`.
pub async fn find_items_by_fts<T>(
    take: usize,
    query: &str,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // The @N@ matcher positions must line up with the search::score(N) calls.
    let (filter_clause, score_clause) = match table {
        "knowledge_entity" => (
            "(name @0@ $terms OR description @1@ $terms)",
            "(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
             (IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
        ),
        "text_chunk" => (
            "(chunk @0@ $terms)",
            "IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
        ),
        _ => {
            return Err(AppError::Validation(format!(
                "FTS not configured for table '{table}'"
            )))
        }
    };
    let sql = format!(
        "SELECT id, {score_clause} AS fts_score \
         FROM {table} \
         WHERE {filter_clause} \
         AND user_id = $user_id \
         ORDER BY fts_score DESC \
         LIMIT $limit"
    );
    debug!(
        table = table,
        limit = take,
        "Executing FTS query with filter clause: {}",
        filter_clause
    );
    let mut response = db_client
        .query(sql)
        .bind(("terms", query.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let score_rows: Vec<FtsScoreRow> = response.take(0)?;
    if score_rows.is_empty() {
        return Ok(Vec::new());
    }
    // Fetch the full records in one round trip; build Things directly from the
    // score rows instead of cloning ids into an intermediate Vec<String>.
    let thing_ids: Vec<Thing> = score_rows
        .iter()
        .map(|row| Thing::from((table, row.id.as_str())))
        .collect();
    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let items: Vec<T> = items_response.take(0)?;
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();
    // Re-emit in score order; rows whose record vanished between queries are dropped.
    let mut results = Vec::with_capacity(score_rows.len());
    for row in score_rows {
        if let Some(item) = item_map.remove(&row.id) {
            let score = row.fts_score.unwrap_or_default();
            results.push(Scored::new(item).with_fts_score(score));
        }
    }
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::{
        knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
        text_chunk::TextChunk,
        StoredObject,
    };
    use uuid::Uuid;

    // Embedding dimensionality matches the production schema; the values are
    // irrelevant here because these tests exercise FTS only.
    fn dummy_embedding() -> Vec<f32> {
        vec![0.0; 1536]
    }

    // A match on only the `name` field must still carry an FTS score.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_name() {
        let namespace = "fts_test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        let user_id = "user_fts";
        let entity = KnowledgeEntity::new(
            "source_a".into(),
            "Rustacean handbook".into(),
            "completely unrelated description".into(),
            KnowledgeEntityType::Document,
            None,
            dummy_embedding(),
            user_id.into(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to insert entity");
        // Rebuild so the freshly stored row is visible to the FTS index.
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");
        let results = find_items_by_fts::<KnowledgeEntity>(
            5,
            "rustacean",
            &db,
            KnowledgeEntity::table_name(),
            user_id,
        )
        .await
        .expect("fts query failed");
        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when only the name matched"
        );
    }

    // A match on only the `description` field must still carry an FTS score.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_description() {
        let namespace = "fts_test_ns_desc";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        let user_id = "user_fts_desc";
        let entity = KnowledgeEntity::new(
            "source_b".into(),
            "neutral name".into(),
            "Detailed notes about async runtimes".into(),
            KnowledgeEntityType::Document,
            None,
            dummy_embedding(),
            user_id.into(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to insert entity");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");
        let results = find_items_by_fts::<KnowledgeEntity>(
            5,
            "async",
            &db,
            KnowledgeEntity::table_name(),
            user_id,
        )
        .await
        .expect("fts query failed");
        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when only the description matched"
        );
    }

    // The `text_chunk` table uses the single-field score path.
    #[tokio::test]
    async fn fts_preserves_scores_for_text_chunks() {
        let namespace = "fts_test_ns_chunks";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        let user_id = "user_fts_chunk";
        let chunk = TextChunk::new(
            "source_chunk".into(),
            "GraphQL documentation reference".into(),
            dummy_embedding(),
            user_id.into(),
        );
        db.store_item(chunk.clone())
            .await
            .expect("failed to insert chunk");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");
        let results =
            find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
                .await
                .expect("fts query failed");
        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when chunk field matched"
        );
    }
}

View File

@@ -1,6 +1,14 @@
use surrealdb::Error;
use std::collections::{HashMap, HashSet};
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
use surrealdb::{sql::Thing, Error};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
StoredObject,
},
};
/// Retrieves database entries that match a specific source identifier.
///
@@ -30,18 +38,21 @@ use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEnt
/// * The database query fails to execute
/// * The results cannot be deserialized into type `T`
pub async fn find_entities_by_source_ids<T>(
source_id: Vec<String>,
table_name: String,
source_ids: Vec<String>,
table_name: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
T: for<'de> serde::Deserialize<'de>,
{
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
let query =
"SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id";
db.query(query)
.bind(("table", table_name))
.bind(("source_ids", source_id))
.bind(("table", table_name.to_owned()))
.bind(("source_ids", source_ids))
.bind(("user_id", user_id.to_owned()))
.await?
.take(0)
}
@@ -49,14 +60,92 @@ where
/// Find entities by their relationship to the id
pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient,
entity_id: String,
entity_id: &str,
user_id: &str,
limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
let query = format!(
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
entity_id
);
let mut relationships_response = db
.query(
"
SELECT * FROM relates_to
WHERE metadata.user_id = $user_id
AND (in = type::thing('knowledge_entity', $entity_id)
OR out = type::thing('knowledge_entity', $entity_id))
",
)
.bind(("entity_id", entity_id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
db.query(query).await?.take(0)
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
if relationships.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_ids: Vec<String> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();
for rel in relationships {
if rel.in_ == entity_id {
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
} else if rel.out == entity_id {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_);
}
} else {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_.clone());
}
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
}
}
neighbor_ids.retain(|id| id != entity_id);
if neighbor_ids.is_empty() {
return Ok(Vec::new());
}
if limit > 0 && neighbor_ids.len() > limit {
neighbor_ids.truncate(limit);
}
let thing_ids: Vec<Thing> = neighbor_ids
.iter()
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
.collect();
let mut neighbors_response = db
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", KnowledgeEntity::table_name().to_owned()))
.bind(("things", thing_ids))
.bind(("user_id", user_id.to_owned()))
.await?;
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
if neighbors.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
.into_iter()
.map(|entity| (entity.id.clone(), entity))
.collect();
let mut ordered = Vec::new();
for id in neighbor_ids {
if let Some(entity) = neighbor_map.remove(&id) {
ordered.push(entity);
}
if limit > 0 && ordered.len() >= limit {
break;
}
}
Ok(ordered)
}
#[cfg(test)]
@@ -146,7 +235,7 @@ mod tests {
// Test finding entities by multiple source_ids
let source_ids = vec![source_id1.clone(), source_id2.clone()];
let found_entities: Vec<KnowledgeEntity> =
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db)
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
.await
.expect("Failed to find entities by source_ids");
@@ -177,7 +266,8 @@ mod tests {
let single_source_id = vec![source_id1.clone()];
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
single_source_id,
KnowledgeEntity::table_name().to_string(),
KnowledgeEntity::table_name(),
&user_id,
&db,
)
.await
@@ -202,7 +292,8 @@ mod tests {
let non_existent_source_id = vec!["non_existent_source".to_string()];
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
non_existent_source_id,
KnowledgeEntity::table_name().to_string(),
KnowledgeEntity::table_name(),
&user_id,
&db,
)
.await
@@ -327,11 +418,15 @@ mod tests {
.expect("Failed to store relationship 2");
// Test finding entities related to the central entity
let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone())
.await
.expect("Failed to find entities by relationship");
let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await
.expect("Failed to find entities by relationship");
// Check that we found relationships
assert!(related_entities.len() > 0, "Should find related entities");
assert!(
related_entities.len() >= 2,
"Should find related entities in both directions"
);
}
}

View File

@@ -1,90 +1,714 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod scoring;
pub mod vector;
use std::collections::{HashMap, HashSet};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
utils::embedding::generate_embedding,
};
use futures::future::{try_join, try_join_all};
use futures::{stream::FuturesUnordered, StreamExt};
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
use std::collections::HashMap;
use vector::find_items_by_vector_similarity;
use scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
};
use tracing::{debug, instrument, trace};
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
/// to find the most relevant entities for a given query.
///
/// # Strategy
/// The function employs a three-pronged approach to knowledge retrieval:
/// 1. Direct vector similarity search on knowledge entities
/// 2. Text chunk similarity search with source entity lookup
/// 3. Graph relationship traversal from related entities
///
/// This combined approach ensures both semantic similarity matches and structurally
/// related content are included in the results.
///
/// # Arguments
/// * `db_client` - SurrealDB client for database operations
/// * `openai_client` - OpenAI client for vector embeddings generation
/// * `query` - The search query string to find relevant knowledge entities
/// * 'user_id' - The user id of the current user
///
/// # Returns
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
/// knowledge entities, or an error if the retrieval process fails
use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding};
// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping.
// First-pass candidate counts per retrieval signal (vector vs. BM25, entity vs. chunk).
const ENTITY_VECTOR_TAKE: usize = 15;
const CHUNK_VECTOR_TAKE: usize = 20;
const ENTITY_FTS_TAKE: usize = 10;
const CHUNK_FTS_TAKE: usize = 20;
// Minimum fused score for an entity to survive; if fewer than
// FALLBACK_MIN_RESULTS survive, the top-N entities are kept regardless.
const SCORE_THRESHOLD: f32 = 0.35;
const FALLBACK_MIN_RESULTS: usize = 10;
// Rough context budget for attached chunks, estimated at ~4 chars per token,
// with a per-entity cap on supporting chunks.
const TOKEN_BUDGET_ESTIMATE: usize = 2800;
const AVG_CHARS_PER_TOKEN: usize = 4;
const MAX_CHUNKS_PER_ENTITY: usize = 4;
// Graph expansion: how many high-confidence seeds walk the graph, how many
// neighbors each pulls in, the score decay applied to neighbors, the minimum
// fused score to qualify as a seed, and how much of the decayed score is
// inherited as a vector signal.
const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5;
const GRAPH_NEIGHBOR_LIMIT: usize = 6;
const GRAPH_SCORE_DECAY: f32 = 0.75;
const GRAPH_SEED_MIN_SCORE: f32 = 0.4;
const GRAPH_VECTOR_INHERITANCE: f32 = 0.6;
/// A text chunk attached to a retrieval result, together with its fused score.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
    // The underlying chunk record.
    pub chunk: TextChunk,
    // Fused relevance score used when selecting chunks for an entity.
    pub score: f32,
}
/// A knowledge entity returned from hybrid retrieval, with its fused score and
/// the supporting chunks selected under the token budget.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
    pub entity: KnowledgeEntity,
    // Fused relevance score used for ranking entities.
    pub score: f32,
    // Supporting chunks, best-first, capped per entity and by the global budget.
    pub chunks: Vec<RetrievedChunk>,
}
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
find_items_by_vector_similarity(
10,
) -> Result<Vec<RetrievedEntity>, AppError> {
trace!("Generating query embedding for hybrid retrieval");
let query_embedding = generate_embedding(openai_client, query, db_client).await?;
retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await
}
pub(crate) async fn retrieve_entities_with_embedding(
db_client: &SurrealDbClient,
query_embedding: Vec<f32>,
query: &str,
user_id: &str,
) -> Result<Vec<RetrievedEntity>, AppError> {
// 1) Gather first-pass candidates from vector search and BM25.
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
ENTITY_VECTOR_TAKE,
query_embedding.clone(),
db_client,
"knowledge_entity",
user_id,
),
find_items_by_vector_similarity_with_embedding(
CHUNK_VECTOR_TAKE,
query_embedding,
db_client,
"text_chunk",
user_id,
),
find_items_by_fts(
ENTITY_FTS_TAKE,
query,
db_client,
"knowledge_entity",
openai_client,
user_id,
user_id
),
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id),
)?;
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts"
);
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
let mut entity_candidates: HashMap<String, Scored<KnowledgeEntity>> = HashMap::new();
let mut chunk_candidates: HashMap<String, Scored<TextChunk>> = HashMap::new();
merge_scored_by_id(&mut entity_candidates, vector_entities);
merge_scored_by_id(&mut entity_candidates, fts_entities);
merge_scored_by_id(&mut chunk_candidates, vector_chunks);
merge_scored_by_id(&mut chunk_candidates, fts_chunks);
// 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph.
apply_fusion(&mut entity_candidates, weights);
apply_fusion(&mut chunk_candidates, weights);
enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?;
// 3) Track high-signal chunk sources so we can backfill missing entities.
let chunk_by_source = group_chunks_by_source(&chunk_candidates);
let mut missing_sources = Vec::new();
for source_id in chunk_by_source.keys() {
if !entity_candidates
.values()
.any(|entity| entity.item.source_id == *source_id)
{
missing_sources.push(source_id.clone());
}
}
if !missing_sources.is_empty() {
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
missing_sources.clone(),
"knowledge_entity",
user_id,
db_client,
)
.await
.unwrap_or_default();
for entity in related_entities {
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
entity_candidates.insert(entity.id.clone(), scored);
}
}
}
// Boost entities with evidence from high scoring chunks.
for entity in entity_candidates.values_mut() {
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
if best_chunk_score > 0.0 {
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
entity.scores.vector = Some(boosted);
let fused = fuse_scores(&entity.scores, weights);
entity.update_fused(fused);
}
}
}
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
entity_candidates.into_values().collect();
sort_by_fused_desc(&mut entity_results);
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
.iter()
.filter(|candidate| candidate.fused >= SCORE_THRESHOLD)
.cloned()
.collect();
if filtered_entities.len() < FALLBACK_MIN_RESULTS {
filtered_entities = entity_results
.into_iter()
.take(FALLBACK_MIN_RESULTS)
.collect();
}
// 4) Re-rank chunks and prepare for attachment to surviving entities.
let mut chunk_results: Vec<Scored<TextChunk>> = chunk_candidates.into_values().collect();
sort_by_fused_desc(&mut chunk_results);
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
for chunk in chunk_results {
chunk_by_id.insert(chunk.item.id.clone(), chunk);
}
enrich_chunks_from_entities(
&mut chunk_by_id,
&filtered_entities,
db_client,
user_id,
weights,
)
.await?;
let source_ids = closest_chunks
.iter()
.map(|chunk: &TextChunk| chunk.source_id.clone())
.collect::<Vec<String>>();
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
.clone()
.into_iter()
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
.collect();
let items_from_relationships = try_join_all(items_from_relationships_futures)
.await?
.into_iter()
.flatten()
.collect::<Vec<KnowledgeEntity>>();
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
.into_iter()
.chain(items_from_text_chunk_similarity.into_iter())
.chain(items_from_relationships.into_iter())
.fold(HashMap::new(), |mut map, entity| {
map.insert(entity.id.clone(), entity);
map
})
.into_values()
.collect();
Ok(entities)
Ok(assemble_results(filtered_entities, chunk_values))
}
// Seed for the graph walk: an entity id plus the fused score its neighbors
// will inherit (after decay).
#[derive(Clone)]
struct GraphSeed {
    id: String,
    fused: f32,
}
/// Expands the candidate set by walking graph relationships from high-confidence seeds.
///
/// Entities whose fused score clears `GRAPH_SEED_MIN_SCORE` each pull in up to
/// `GRAPH_NEIGHBOR_LIMIT` neighbors; neighbors inherit a decayed score so they
/// rank below the entities that surfaced them.
///
/// # Errors
/// Propagates database errors from the relationship lookups.
async fn enrich_entities_from_graph(
    entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
    db_client: &SurrealDbClient,
    user_id: &str,
    weights: FusionWeights,
) -> Result<(), AppError> {
    if entity_candidates.is_empty() {
        return Ok(());
    }
    // Select a small frontier of high-confidence entities to seed the relationship walk.
    let mut seeds: Vec<GraphSeed> = entity_candidates
        .values()
        .filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE)
        .map(|entity| GraphSeed {
            id: entity.item.id.clone(),
            fused: entity.fused,
        })
        .collect();
    if seeds.is_empty() {
        return Ok(());
    }
    seeds.sort_by(|a, b| {
        b.fused
            .partial_cmp(&a.fused)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT);
    let mut futures = FuturesUnordered::new();
    // `seeds` is consumed here (the original's `.clone()` was unnecessary).
    for seed in seeds {
        let user_id = user_id.to_owned();
        futures.push(async move {
            // Fetch neighbors concurrently to avoid serial graph round trips.
            let neighbors = find_entities_by_relationship_by_id(
                db_client,
                &seed.id,
                &user_id,
                GRAPH_NEIGHBOR_LIMIT,
            )
            .await;
            (seed, neighbors)
        });
    }
    while let Some((seed, neighbors_result)) = futures.next().await {
        let neighbors = neighbors_result.map_err(AppError::from)?;
        if neighbors.is_empty() {
            continue;
        }
        for neighbor in neighbors {
            if neighbor.id == seed.id {
                continue;
            }
            let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY);
            let entry = entity_candidates
                .entry(neighbor.id.clone())
                .or_insert_with(|| Scored::new(neighbor.clone()));
            // Refresh the stored item so graph hits carry the latest entity data.
            entry.item = neighbor;
            // Neighbors inherit a dampened vector score so fusion does not zero them out.
            let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE);
            let vector_existing = entry.scores.vector.unwrap_or(0.0);
            if inherited_vector > vector_existing {
                entry.scores.vector = Some(inherited_vector);
            }
            // Keep the best graph score seen so far. The original's extra
            // `is_none()` branch was unreachable (unwrap_or(f32::MIN) already
            // made the comparison succeed for empty scores) and is folded away.
            entry.scores.graph = Some(match entry.scores.graph {
                Some(existing) => existing.max(graph_score),
                None => graph_score,
            });
            let fused = fuse_scores(&entry.scores, weights);
            entry.update_fused(fused);
        }
    }
    Ok(())
}
/// Rescales raw BM25 outputs into `[0, 1]` so fusion weights stay predictable.
///
/// Also zeroes each candidate's fused value; fusion is recomputed in a later pass.
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
    let raw: Vec<f32> = results
        .iter()
        .map(|candidate| candidate.scores.fts.unwrap_or(0.0))
        .collect();
    let scaled = min_max_normalize(&raw);
    for (candidate, score) in results.iter_mut().zip(scaled) {
        candidate.scores.fts = Some(score);
        // Invalidate any stale fused value until apply_fusion recomputes it.
        candidate.update_fused(0.0);
    }
}
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
where
T: StoredObject,
{
// Collapse individual signals into a single fused score used for ranking.
for candidate in candidates.values_mut() {
let fused = fuse_scores(&candidate.scores, weights);
candidate.update_fused(fused);
}
}
/// Groups chunk candidates by the source entity they originate from.
fn group_chunks_by_source(
    chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
    // Fold the candidate map into per-source buckets keyed by source_id.
    chunks
        .values()
        .fold(HashMap::new(), |mut by_source, chunk| {
            by_source
                .entry(chunk.item.source_id.clone())
                .or_default()
                .push(chunk.clone());
            by_source
        })
}
/// Fetches additional chunks referenced by entities that survived fusion and
/// lifts their scores toward the parent entity's fused score (80% inheritance),
/// so supporting evidence for strong entities is prioritised.
///
/// # Errors
/// Propagates database errors from the chunk lookup.
async fn enrich_chunks_from_entities(
    chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
    entities: &[Scored<KnowledgeEntity>],
    db_client: &SurrealDbClient,
    user_id: &str,
    weights: FusionWeights,
) -> Result<(), AppError> {
    // Unique source ids of the surviving entities.
    let source_ids: HashSet<String> = entities
        .iter()
        .map(|entity| entity.item.source_id.clone())
        .collect();
    if source_ids.is_empty() {
        return Ok(());
    }
    let chunks = find_entities_by_source_ids::<TextChunk>(
        source_ids.into_iter().collect(),
        "text_chunk",
        user_id,
        db_client,
    )
    .await?;
    // Cache fused scores per source so chunks inherit the strength of their
    // parent entity; on duplicate sources the last entity wins, matching the
    // original insertion-order semantics.
    let entity_score_lookup: HashMap<String, f32> = entities
        .iter()
        .map(|entity| (entity.item.source_id.clone(), entity.fused))
        .collect();
    for chunk in chunks {
        let parent_score = entity_score_lookup
            .get(&chunk.source_id)
            .copied()
            .unwrap_or(0.0);
        // Ensure each chunk is represented so downstream selection sees it.
        let entry = chunk_candidates
            .entry(chunk.id.clone())
            .or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
        let current_vector = entry.scores.vector.unwrap_or(0.0);
        entry.scores.vector = Some(current_vector.max(parent_score * 0.8));
        let fused = fuse_scores(&entry.scores, weights);
        entry.update_fused(fused);
        // Always refresh the stored chunk so selection sees the latest content.
        entry.item = chunk;
    }
    Ok(())
}
/// Re-associates chunk candidates with their parent entities and builds the
/// final ranked results, respecting the per-entity chunk cap and the global
/// token budget.
fn assemble_results(
    entities: Vec<Scored<KnowledgeEntity>>,
    mut chunks: Vec<Scored<TextChunk>>,
) -> Vec<RetrievedEntity> {
    // Group chunk candidates by their originating source entity.
    let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
    for chunk in chunks.drain(..) {
        chunk_by_source
            .entry(chunk.item.source_id.clone())
            .or_default()
            .push(chunk);
    }
    // Sort each group once, best-first. The original re-sorted each list with
    // an identical comparator inside the entity loop; that duplicate sort has
    // been removed.
    for chunk_list in chunk_by_source.values_mut() {
        sort_by_fused_desc(chunk_list);
    }
    let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE;
    let mut results = Vec::new();
    for entity in entities {
        // Attach best chunks first while respecting per-entity and global token caps.
        let mut selected_chunks = Vec::new();
        if let Some(candidates) = chunk_by_source.get(&entity.item.source_id) {
            let mut per_entity_count = 0;
            for candidate in candidates.iter() {
                if per_entity_count >= MAX_CHUNKS_PER_ENTITY {
                    break;
                }
                let estimated_tokens = estimate_tokens(&candidate.item.chunk);
                // Skip (rather than stop) so a smaller chunk may still fit the budget.
                if estimated_tokens > token_budget_remaining {
                    continue;
                }
                token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
                per_entity_count += 1;
                selected_chunks.push(RetrievedChunk {
                    chunk: candidate.item.clone(),
                    score: candidate.fused,
                });
            }
        }
        results.push(RetrievedEntity {
            entity: entity.item.clone(),
            score: entity.fused,
            chunks: selected_chunks,
        });
        if token_budget_remaining == 0 {
            break;
        }
    }
    results
}
/// Cheap token-count heuristic (~`AVG_CHARS_PER_TOKEN` chars per token) that
/// avoids invoking a real tokenizer in hot code paths.
///
/// Always reports at least one token, even for empty text.
fn estimate_tokens(text: &str) -> usize {
    let char_count = std::cmp::max(text.chars().count(), 1);
    std::cmp::max(char_count / AVG_CHARS_PER_TOKEN, 1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::{
        knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship,
    };
    use uuid::Uuid;

    // Tiny 3-dimensional embeddings keep the tests fast; the test DB redefines
    // its HNSW indexes to DIMENSION 3 to match (see setup_test_db).
    fn test_embedding() -> Vec<f32> {
        vec![0.9, 0.1, 0.0]
    }
    // Close to the query embedding — should rank highly.
    fn entity_embedding_high() -> Vec<f32> {
        vec![0.8, 0.2, 0.0]
    }
    // Far from the query embedding — should rank low.
    fn entity_embedding_low() -> Vec<f32> {
        vec![0.1, 0.9, 0.0]
    }
    fn chunk_embedding_primary() -> Vec<f32> {
        vec![0.85, 0.15, 0.0]
    }
    fn chunk_embedding_secondary() -> Vec<f32> {
        vec![0.2, 0.8, 0.0]
    }

    // Creates an in-memory SurrealDB with migrations applied and the vector
    // indexes rebuilt for the 3-dimensional test embeddings.
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");
        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");
        db.query(
            "BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
        )
        .await
        .expect("Failed to redefine vector indexes for tests");
        db
    }

    // Seeds one relevant and one irrelevant entity, each with a matching chunk.
    async fn seed_test_data(db: &SurrealDbClient, user_id: &str) {
        let entity_relevant = KnowledgeEntity::new(
            "source_a".into(),
            "Rust Concurrency Patterns".into(),
            "Discussion about async concurrency in Rust.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_high(),
            user_id.into(),
        );
        let entity_irrelevant = KnowledgeEntity::new(
            "source_b".into(),
            "Python Tips".into(),
            "General Python programming tips.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_low(),
            user_id.into(),
        );
        db.store_item(entity_relevant.clone())
            .await
            .expect("Failed to store relevant entity");
        db.store_item(entity_irrelevant.clone())
            .await
            .expect("Failed to store irrelevant entity");
        let chunk_primary = TextChunk::new(
            entity_relevant.source_id.clone(),
            "Tokio enables async concurrency with lightweight tasks.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        let chunk_secondary = TextChunk::new(
            entity_irrelevant.source_id.clone(),
            "Python focuses on readability and dynamic typing.".into(),
            chunk_embedding_secondary(),
            user_id.into(),
        );
        db.store_item(chunk_primary)
            .await
            .expect("Failed to store primary chunk");
        db.store_item(chunk_secondary)
            .await
            .expect("Failed to store secondary chunk");
    }

    // The hybrid pipeline should rank the semantically-close entity first and
    // attach its supporting chunk.
    #[tokio::test]
    async fn test_hybrid_retrieval_prioritises_relevant_entity() {
        let db = setup_test_db().await;
        let user_id = "user123";
        seed_test_data(&db, user_id).await;
        let results = retrieve_entities_with_embedding(
            &db,
            test_embedding(),
            "Rust concurrency async tasks",
            user_id,
        )
        .await
        .expect("Hybrid retrieval failed");
        assert!(
            !results.is_empty(),
            "Expected at least one retrieval result"
        );
        let top = &results[0];
        assert!(
            top.entity.name.contains("Rust"),
            "Expected Rust entity to be ranked first"
        );
        assert!(
            !top.chunks.is_empty(),
            "Expected Rust entity to include supporting chunks"
        );
        let chunk_texts: Vec<&str> = top
            .chunks
            .iter()
            .map(|chunk| chunk.chunk.chunk.as_str())
            .collect();
        assert!(
            chunk_texts.iter().any(|text| text.contains("Tokio")),
            "Expected chunk discussing Tokio to be included"
        );
    }

    // A low-similarity neighbor connected by a relationship should still be
    // surfaced via graph enrichment, with its own chunks attached.
    #[tokio::test]
    async fn test_graph_relationship_enriches_results() {
        let db = setup_test_db().await;
        let user_id = "graph_user";
        let primary = KnowledgeEntity::new(
            "primary_source".into(),
            "Async Rust patterns".into(),
            "Explores async runtimes and scheduling strategies.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_high(),
            user_id.into(),
        );
        let neighbor = KnowledgeEntity::new(
            "neighbor_source".into(),
            "Tokio Scheduler Deep Dive".into(),
            "Details on Tokio's cooperative scheduler.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_low(),
            user_id.into(),
        );
        db.store_item(primary.clone())
            .await
            .expect("Failed to store primary entity");
        db.store_item(neighbor.clone())
            .await
            .expect("Failed to store neighbor entity");
        let primary_chunk = TextChunk::new(
            primary.source_id.clone(),
            "Rust async tasks use Tokio's cooperative scheduler.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        let neighbor_chunk = TextChunk::new(
            neighbor.source_id.clone(),
            "Tokio's scheduler manages task fairness across executors.".into(),
            chunk_embedding_secondary(),
            user_id.into(),
        );
        db.store_item(primary_chunk)
            .await
            .expect("Failed to store primary chunk");
        db.store_item(neighbor_chunk)
            .await
            .expect("Failed to store neighbor chunk");
        let relationship = KnowledgeRelationship::new(
            primary.id.clone(),
            neighbor.id.clone(),
            user_id.into(),
            "relationship_source".into(),
            "references".into(),
        );
        relationship
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship");
        let results = retrieve_entities_with_embedding(
            &db,
            test_embedding(),
            "Rust concurrency async tasks",
            user_id,
        )
        .await
        .expect("Hybrid retrieval failed");
        let mut neighbor_entry = None;
        for entity in &results {
            if entity.entity.id == neighbor.id {
                neighbor_entry = Some(entity.clone());
            }
        }
        let neighbor_entry =
            neighbor_entry.expect("Graph-enriched neighbor should appear in results");
        assert!(
            neighbor_entry.score > 0.2,
            "Graph-enriched entity should have a meaningful fused score"
        );
        assert!(
            neighbor_entry
                .chunks
                .iter()
                .all(|chunk| chunk.chunk.source_id == neighbor.source_id),
            "Neighbor entity should surface its own supporting chunks"
        );
    }
}

View File

@@ -0,0 +1,180 @@
use std::cmp::Ordering;
use common::storage::types::StoredObject;
/// Holds optional subscores gathered from different retrieval signals.
/// A `None` field means that signal produced no hit for the item.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
    pub fts: Option<f32>,
    pub vector: Option<f32>,
    pub graph: Option<f32>,
}

/// Generic wrapper combining an item with its accumulated retrieval scores.
/// `fused` starts at 0.0 and is filled in later via [`Scored::update_fused`].
#[derive(Debug, Clone)]
pub struct Scored<T> {
    pub item: T,
    pub scores: Scores,
    pub fused: f32,
}

impl<T> Scored<T> {
    /// Wraps `item` with no subscores and a zero fused score.
    pub fn new(item: T) -> Self {
        Scored {
            item,
            scores: Scores::default(),
            fused: 0.0,
        }
    }

    /// Builder-style setter for the full-text-search subscore.
    pub fn with_fts_score(mut self, score: f32) -> Self {
        self.scores.fts = Some(score);
        self
    }

    /// Builder-style setter for the vector-similarity subscore.
    pub fn with_vector_score(mut self, score: f32) -> Self {
        self.scores.vector = Some(score);
        self
    }

    /// Builder-style setter for the graph-expansion subscore.
    pub fn with_graph_score(mut self, score: f32) -> Self {
        self.scores.graph = Some(score);
        self
    }

    /// Overwrites the combined score computed from the subscores.
    pub fn update_fused(&mut self, fused: f32) {
        self.fused = fused;
    }
}
/// Weights used for linear score fusion.
///
/// The three signal weights scale their respective subscores; `multi_bonus`
/// is a flat boost added when an item was found by two or more signals.
#[derive(Debug, Clone, Copy)]
pub struct FusionWeights {
    pub vector: f32,
    pub fts: f32,
    pub graph: f32,
    pub multi_bonus: f32,
}

impl Default for FusionWeights {
    /// Default fusion mix: vector similarity dominates (0.5), then full-text
    /// (0.3), then graph expansion (0.2), with a small 0.02 multi-signal bonus.
    fn default() -> Self {
        FusionWeights {
            vector: 0.5,
            fts: 0.3,
            graph: 0.2,
            multi_bonus: 0.02,
        }
    }
}
/// Clamps `value` into the unit interval `[0.0, 1.0]`.
///
/// Uses `max`/`min` rather than `f32::clamp` deliberately: `f32::max`/`min`
/// return the other operand when one side is NaN, so a NaN input collapses
/// to `0.0` instead of propagating.
pub fn clamp_unit(value: f32) -> f32 {
    f32::min(f32::max(value, 0.0), 1.0)
}
/// Converts a KNN distance into a similarity in `(0.0, 1.0]`.
///
/// Distance 0 maps to 1.0 and larger distances decay as `1 / (1 + d)`.
/// Non-finite distances (NaN, infinities) yield 0.0; negative distances are
/// treated as 0 (i.e. maximum similarity).
pub fn distance_to_similarity(distance: f32) -> f32 {
    if !distance.is_finite() {
        return 0.0;
    }
    // 1/(1+d) with d >= 0 is already within (0, 1]; the clamp is a safety net.
    let similarity = 1.0 / (1.0 + distance.max(0.0));
    similarity.max(0.0).min(1.0)
}
/// Min-max normalizes `scores` into `[0.0, 1.0]`.
///
/// Only finite values participate in determining the min/max bounds.
/// Non-finite entries (NaN, infinities) always normalize to 0.0 — including
/// in the degenerate case where every finite score is equal (the original
/// implementation gave non-finite entries 1.0 there, inconsistent with the
/// normal path). When all finite scores are equal, those entries map to 1.0;
/// when no finite score exists, everything maps to 0.0.
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }
    // Compute bounds over finite values only; `None` means no finite value seen.
    let mut bounds: Option<(f32, f32)> = None;
    for &s in scores {
        if s.is_finite() {
            bounds = Some(match bounds {
                Some((lo, hi)) => (lo.min(s), hi.max(s)),
                None => (s, s),
            });
        }
    }
    let Some((min, max)) = bounds else {
        // No finite scores at all: nothing to normalize against.
        return vec![0.0; scores.len()];
    };
    let range = max - min;
    scores
        .iter()
        .map(|&s| {
            if !s.is_finite() {
                0.0
            } else if range.abs() < f32::EPSILON {
                // All finite scores identical: treat each as maximal.
                1.0
            } else {
                // Inline unit clamp; (s - min) / range is already in [0, 1]
                // for finite s within bounds, so this is just a safety net.
                ((s - min) / range).max(0.0).min(1.0)
            }
        })
        .collect()
}
/// Linearly fuses the available subscores into a single score in `[0.0, 1.0]`.
///
/// Missing subscores contribute 0. Items confirmed by two or more signals
/// receive a flat `multi_bonus` on top of the weighted sum; the result is
/// clamped to the unit interval.
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
    let mut fused = scores.vector.unwrap_or(0.0) * weights.vector
        + scores.fts.unwrap_or(0.0) * weights.fts
        + scores.graph.unwrap_or(0.0) * weights.graph;

    // Reward corroboration: count how many signals actually fired.
    let signal_count = [scores.vector, scores.fts, scores.graph]
        .iter()
        .filter(|signal| signal.is_some())
        .count();
    if signal_count >= 2 {
        fused += weights.multi_bonus;
    }

    clamp_unit(fused)
}
/// Merges `incoming` scored items into `target`, keyed by item id.
///
/// For an id already present, each subscore that `incoming` carries
/// overwrites the existing one (signals are merged, not summed); the stored
/// item and its `fused` value are left untouched — callers are expected to
/// recompute fusion after merging (assumption based on this module's API;
/// confirm against call sites). New ids are inserted as-is.
///
/// Uses the entry-match form instead of `and_modify`/`or_insert_with` so the
/// vacant path can move the `Scored` in directly, avoiding the redundant
/// `item.clone()` the closure-based version required.
pub fn merge_scored_by_id<T>(
    target: &mut std::collections::HashMap<String, Scored<T>>,
    incoming: Vec<Scored<T>>,
) where
    T: StoredObject + Clone,
{
    use std::collections::hash_map::Entry;

    for scored in incoming {
        let id = scored.item.get_id().to_owned();
        match target.entry(id) {
            Entry::Occupied(mut occupied) => {
                let existing = occupied.get_mut();
                if let Some(score) = scored.scores.vector {
                    existing.scores.vector = Some(score);
                }
                if let Some(score) = scored.scores.fts {
                    existing.scores.fts = Some(score);
                }
                if let Some(score) = scored.scores.graph {
                    existing.scores.graph = Some(score);
                }
            }
            // No clone needed: the whole Scored moves into the map.
            Entry::Vacant(vacant) => {
                vacant.insert(scored);
            }
        }
    }
}
/// Sorts in place by fused score, highest first.
///
/// Ties (and incomparable NaN scores, where `partial_cmp` is `None`) fall
/// back to ascending id order so the output is deterministic.
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
    T: StoredObject,
{
    items.sort_by(|a, b| match b.fused.partial_cmp(&a.fused) {
        Some(Ordering::Equal) | None => a.item.get_id().cmp(b.item.get_id()),
        Some(order) => order,
    });
}

View File

@@ -1,4 +1,15 @@
use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::generate_embedding};
use std::collections::HashMap;
use common::storage::types::file_info::deserialize_flexible_id;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
utils::embedding::generate_embedding,
};
use serde::Deserialize;
use surrealdb::sql::Thing;
use crate::scoring::{distance_to_similarity, Scored};
/// Compares vectors and retrieves a number of items from the specified table.
///
@@ -22,24 +33,90 @@ use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::ge
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
take: u8,
take: usize,
input_text: &str,
db_client: &SurrealDbClient,
table: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str,
) -> Result<Vec<T>, AppError>
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de>,
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
// Generate embeddings
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE user_id = '{}' AND embedding <|{},40|> {:?} ORDER BY distance", table, user_id, take, input_embedding);
// Perform query and deserialize to struct
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
Ok(closest_entities)
find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
.await
}
/// Row shape for the KNN query: just the record id and its reported distance.
#[derive(Debug, Deserialize)]
struct DistanceRow {
    // Flexible deserializer normalizes SurrealDB's id representation to a
    // plain String (same pattern as FtsScoreRow above).
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // May be absent in a result row; downstream treats None as similarity 0.
    distance: Option<f32>,
}
/// Runs a KNN vector search against `table` for a precomputed embedding and
/// returns the matching records wrapped with their vector-similarity scores.
///
/// Two queries are issued: first a KNN pass that returns only ids and
/// distances, then a batch fetch of the full records for those ids. Results
/// preserve the KNN ranking order; rows whose record cannot be re-fetched
/// (or that lack a distance) are dropped or scored 0 respectively.
///
/// # Errors
///
/// Returns `AppError` if the embedding cannot be serialized or either
/// database query fails.
pub async fn find_items_by_vector_similarity_with_embedding<T>(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // The embedding is interpolated into the query text rather than bound;
    // it is a JSON array of floats produced locally, so no injection risk.
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    let closest_query = format!(
        "SELECT id, vector::distance::knn() AS distance \
         FROM {table} \
         WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
         LIMIT $limit",
        table = table,
        take = take,
        embedding = embedding_literal
    );
    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let distance_rows: Vec<DistanceRow> = response.take(0)?;
    if distance_rows.is_empty() {
        return Ok(Vec::new());
    }

    // Build the Thing ids directly from the distance rows (no intermediate
    // String vec, and the vec is moved into the bind instead of cloned).
    let thing_ids: Vec<Thing> = distance_rows
        .iter()
        .map(|row| Thing::from((table, row.id.as_str())))
        .collect();
    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let items: Vec<T> = items_response.take(0)?;
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();

    // Re-assemble in KNN order, converting each distance to a similarity.
    let mut scored = Vec::with_capacity(distance_rows.len());
    for row in distance_rows {
        if let Some(item) = item_map.remove(&row.id) {
            let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
            scored.push(Scored::new(item).with_vector_score(similarity));
        }
    }
    Ok(scored)
}