diff --git a/README.md b/README.md index 3607aa9..cb8c650 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ You may switch and choose between models used, and have the possibility to change The application is built for speed and efficiency using Rust with a Server-Side Rendered (SSR) frontend (HTMX and minimal JavaScript). It's fully responsive, offering a complete mobile interface for reading, editing, and managing your content, including the graph database itself. **PWA (Progressive Web App) support** means you can "install" Minne to your device for a native-like experience. For quick capture on the go on iOS, a [**Shortcut**](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) makes sending content to your Minne instance a breeze. +A hybrid retrieval layer blends embeddings, full-text search, and graph signals to surface the best context when augmenting chat responses and when building new relationships during ingestion. + Minne is open source (AGPL), self-hostable, and can be deployed flexibly: via Nix, Docker Compose, pre-built binaries, or by building from source. It can run as a single `main` binary or as separate `server` and `worker` processes for optimized resource allocation. 
## Tech Stack diff --git a/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql b/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql new file mode 100644 index 0000000..cad213c --- /dev/null +++ b/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql @@ -0,0 +1,17 @@ +-- Add FTS indexes for searching name and description on entities + +DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer + TOKENIZERS class + FILTERS lowercase, ascii, snowball(english); + +DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity + FIELDS name + SEARCH ANALYZER app_en_fts_analyzer BM25; + +DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity + FIELDS description + SEARCH ANALYZER app_en_fts_analyzer BM25; + +DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk + FIELDS chunk + SEARCH ANALYZER app_en_fts_analyzer BM25; diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index dc3bc23..8937bfb 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -80,15 +80,18 @@ impl SurrealDbClient { /// Operation to rebuild indexes pub async fn rebuild_indexes(&self) -> Result<(), Error> { debug!("Rebuilding indexes"); - self.client - .query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk") - .await?; - self.client - .query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity") - .await?; - self.client - .query("REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content") - .await?; + let rebuild_sql = r#" + BEGIN TRANSACTION; + REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk; + REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity; + REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content; + REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity; + REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity; + REBUILD INDEX IF EXISTS 
text_chunk_fts_chunk_idx ON text_chunk; + COMMIT TRANSACTION; + "#; + + self.client.query(rebuild_sql).await?; Ok(()) } diff --git a/composite-retrieval/src/answer_retrieval.rs b/composite-retrieval/src/answer_retrieval.rs index 64c3782..87c5c58 100644 --- a/composite-retrieval/src/answer_retrieval.rs +++ b/composite-retrieval/src/answer_retrieval.rs @@ -11,7 +11,6 @@ use common::{ storage::{ db::SurrealDbClient, types::{ - knowledge_entity::KnowledgeEntity, message::{format_history, Message}, system_settings::SystemSettings, }, @@ -20,7 +19,7 @@ use common::{ use serde::Deserialize; use serde_json::{json, Value}; -use crate::retrieve_entities; +use crate::{retrieve_entities, RetrievedEntity}; use super::answer_retrieval_helper::get_query_response_schema; @@ -84,21 +83,31 @@ pub async fn get_answer_with_references( }) } -pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value { +pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value { json!(entities .iter() - .map(|entity| { + .map(|entry| { json!({ "KnowledgeEntity": { - "id": entity.id, - "name": entity.name, - "description": entity.description + "id": entry.entity.id, + "name": entry.entity.name, + "description": entry.entity.description, + "score": round_score(entry.score), + "chunks": entry.chunks.iter().map(|chunk| { + json!({ + "score": round_score(chunk.score), + "content": chunk.chunk.chunk + }) + }).collect::>() } }) }) .collect::>()) } +fn round_score(value: f32) -> f64 { + ((value as f64) * 1000.0).round() / 1000.0 +} pub fn create_user_message(entities_json: &Value, query: &str) -> String { format!( r#" diff --git a/composite-retrieval/src/fts.rs b/composite-retrieval/src/fts.rs new file mode 100644 index 0000000..28ca217 --- /dev/null +++ b/composite-retrieval/src/fts.rs @@ -0,0 +1,265 @@ +use std::collections::HashMap; + +use serde::Deserialize; +use tracing::debug; + +use common::{ + error::AppError, + storage::{db::SurrealDbClient, types::StoredObject}, +}; + +use 
crate::scoring::Scored; +use common::storage::types::file_info::deserialize_flexible_id; +use surrealdb::sql::Thing; + +#[derive(Debug, Deserialize)] +struct FtsScoreRow { + #[serde(deserialize_with = "deserialize_flexible_id")] + id: String, + fts_score: Option, +} + +/// Executes a full-text search query against SurrealDB and returns scored results. +/// +/// The function expects FTS indexes to exist for the provided table. Currently supports +/// `knowledge_entity` (name + description) and `text_chunk` (chunk). +pub async fn find_items_by_fts( + take: usize, + query: &str, + db_client: &SurrealDbClient, + table: &str, + user_id: &str, +) -> Result>, AppError> +where + T: for<'de> serde::Deserialize<'de> + StoredObject, +{ + let (filter_clause, score_clause) = match table { + "knowledge_entity" => ( + "(name @0@ $terms OR description @1@ $terms)", + "(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \ + (IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)", + ), + "text_chunk" => ( + "(chunk @0@ $terms)", + "IF search::score(0) != NONE THEN search::score(0) ELSE 0 END", + ), + _ => { + return Err(AppError::Validation(format!( + "FTS not configured for table '{table}'" + ))) + } + }; + + let sql = format!( + "SELECT id, {score_clause} AS fts_score \ + FROM {table} \ + WHERE {filter_clause} \ + AND user_id = $user_id \ + ORDER BY fts_score DESC \ + LIMIT $limit", + table = table, + filter_clause = filter_clause, + score_clause = score_clause + ); + + debug!( + table = table, + limit = take, + "Executing FTS query with filter clause: {}", + filter_clause + ); + + let mut response = db_client + .query(sql) + .bind(("terms", query.to_owned())) + .bind(("user_id", user_id.to_owned())) + .bind(("limit", take as i64)) + .await?; + + let score_rows: Vec = response.take(0)?; + + if score_rows.is_empty() { + return Ok(Vec::new()); + } + + let ids: Vec = score_rows.iter().map(|row| row.id.clone()).collect(); + let thing_ids: Vec = ids + .iter() + 
.map(|id| Thing::from((table, id.as_str()))) + .collect(); + + let mut items_response = db_client + .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id") + .bind(("table", table.to_owned())) + .bind(("things", thing_ids.clone())) + .bind(("user_id", user_id.to_owned())) + .await?; + + let items: Vec = items_response.take(0)?; + + let mut item_map: HashMap = items + .into_iter() + .map(|item| (item.get_id().to_owned(), item)) + .collect(); + + let mut results = Vec::with_capacity(score_rows.len()); + for row in score_rows { + if let Some(item) = item_map.remove(&row.id) { + let score = row.fts_score.unwrap_or_default(); + results.push(Scored::new(item).with_fts_score(score)); + } + } + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::*; + use common::storage::types::{ + knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, + text_chunk::TextChunk, + StoredObject, + }; + use uuid::Uuid; + + fn dummy_embedding() -> Vec { + vec![0.0; 1536] + } + + #[tokio::test] + async fn fts_preserves_single_field_score_for_name() { + let namespace = "fts_test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("failed to create in-memory surreal"); + + db.apply_migrations() + .await + .expect("failed to apply migrations"); + + let user_id = "user_fts"; + let entity = KnowledgeEntity::new( + "source_a".into(), + "Rustacean handbook".into(), + "completely unrelated description".into(), + KnowledgeEntityType::Document, + None, + dummy_embedding(), + user_id.into(), + ); + + db.store_item(entity.clone()) + .await + .expect("failed to insert entity"); + + db.rebuild_indexes() + .await + .expect("failed to rebuild indexes"); + + let results = find_items_by_fts::( + 5, + "rustacean", + &db, + KnowledgeEntity::table_name(), + user_id, + ) + .await + .expect("fts query failed"); + + assert!(!results.is_empty(), "expected at least one FTS result"); + assert!( + 
results[0].scores.fts.is_some(), + "expected an FTS score when only the name matched" + ); + } + + #[tokio::test] + async fn fts_preserves_single_field_score_for_description() { + let namespace = "fts_test_ns_desc"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("failed to create in-memory surreal"); + + db.apply_migrations() + .await + .expect("failed to apply migrations"); + + let user_id = "user_fts_desc"; + let entity = KnowledgeEntity::new( + "source_b".into(), + "neutral name".into(), + "Detailed notes about async runtimes".into(), + KnowledgeEntityType::Document, + None, + dummy_embedding(), + user_id.into(), + ); + + db.store_item(entity.clone()) + .await + .expect("failed to insert entity"); + + db.rebuild_indexes() + .await + .expect("failed to rebuild indexes"); + + let results = find_items_by_fts::( + 5, + "async", + &db, + KnowledgeEntity::table_name(), + user_id, + ) + .await + .expect("fts query failed"); + + assert!(!results.is_empty(), "expected at least one FTS result"); + assert!( + results[0].scores.fts.is_some(), + "expected an FTS score when only the description matched" + ); + } + + #[tokio::test] + async fn fts_preserves_scores_for_text_chunks() { + let namespace = "fts_test_ns_chunks"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("failed to create in-memory surreal"); + + db.apply_migrations() + .await + .expect("failed to apply migrations"); + + let user_id = "user_fts_chunk"; + let chunk = TextChunk::new( + "source_chunk".into(), + "GraphQL documentation reference".into(), + dummy_embedding(), + user_id.into(), + ); + + db.store_item(chunk.clone()) + .await + .expect("failed to insert chunk"); + + db.rebuild_indexes() + .await + .expect("failed to rebuild indexes"); + + let results = + find_items_by_fts::(5, "graphql", &db, TextChunk::table_name(), user_id) + .await + .expect("fts query 
failed"); + + assert!(!results.is_empty(), "expected at least one FTS result"); + assert!( + results[0].scores.fts.is_some(), + "expected an FTS score when chunk field matched" + ); + } +} diff --git a/composite-retrieval/src/graph.rs b/composite-retrieval/src/graph.rs index 24bd9ea..9eb740a 100644 --- a/composite-retrieval/src/graph.rs +++ b/composite-retrieval/src/graph.rs @@ -1,6 +1,14 @@ -use surrealdb::Error; +use std::collections::{HashMap, HashSet}; -use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}; +use surrealdb::{sql::Thing, Error}; + +use common::storage::{ + db::SurrealDbClient, + types::{ + knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, + StoredObject, + }, +}; /// Retrieves database entries that match a specific source identifier. /// @@ -30,18 +38,21 @@ use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEnt /// * The database query fails to execute /// * The results cannot be deserialized into type `T` pub async fn find_entities_by_source_ids( - source_id: Vec, - table_name: String, + source_ids: Vec, + table_name: &str, + user_id: &str, db: &SurrealDbClient, ) -> Result, Error> where T: for<'de> serde::Deserialize<'de>, { - let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids"; + let query = + "SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id"; db.query(query) - .bind(("table", table_name)) - .bind(("source_ids", source_id)) + .bind(("table", table_name.to_owned())) + .bind(("source_ids", source_ids)) + .bind(("user_id", user_id.to_owned())) .await? 
.take(0) } @@ -49,14 +60,92 @@ where /// Find entities by their relationship to the id pub async fn find_entities_by_relationship_by_id( db: &SurrealDbClient, - entity_id: String, + entity_id: &str, + user_id: &str, + limit: usize, ) -> Result, Error> { - let query = format!( - "SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`", - entity_id - ); + let mut relationships_response = db + .query( + " + SELECT * FROM relates_to + WHERE metadata.user_id = $user_id + AND (in = type::thing('knowledge_entity', $entity_id) + OR out = type::thing('knowledge_entity', $entity_id)) + ", + ) + .bind(("entity_id", entity_id.to_owned())) + .bind(("user_id", user_id.to_owned())) + .await?; - db.query(query).await?.take(0) + let relationships: Vec = relationships_response.take(0)?; + if relationships.is_empty() { + return Ok(Vec::new()); + } + + let mut neighbor_ids: Vec = Vec::new(); + let mut seen: HashSet = HashSet::new(); + for rel in relationships { + if rel.in_ == entity_id { + if seen.insert(rel.out.clone()) { + neighbor_ids.push(rel.out); + } + } else if rel.out == entity_id { + if seen.insert(rel.in_.clone()) { + neighbor_ids.push(rel.in_); + } + } else { + if seen.insert(rel.in_.clone()) { + neighbor_ids.push(rel.in_.clone()); + } + if seen.insert(rel.out.clone()) { + neighbor_ids.push(rel.out); + } + } + } + + neighbor_ids.retain(|id| id != entity_id); + + if neighbor_ids.is_empty() { + return Ok(Vec::new()); + } + + if limit > 0 && neighbor_ids.len() > limit { + neighbor_ids.truncate(limit); + } + + let thing_ids: Vec = neighbor_ids + .iter() + .map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str()))) + .collect(); + + let mut neighbors_response = db + .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id") + .bind(("table", KnowledgeEntity::table_name().to_owned())) + .bind(("things", thing_ids)) + .bind(("user_id", user_id.to_owned())) + .await?; + + let neighbors: Vec = 
neighbors_response.take(0)?; + if neighbors.is_empty() { + return Ok(Vec::new()); + } + + let mut neighbor_map: HashMap = neighbors + .into_iter() + .map(|entity| (entity.id.clone(), entity)) + .collect(); + + let mut ordered = Vec::new(); + for id in neighbor_ids { + if let Some(entity) = neighbor_map.remove(&id) { + ordered.push(entity); + } + if limit > 0 && ordered.len() >= limit { + break; + } + } + + Ok(ordered) } #[cfg(test)] @@ -146,7 +235,7 @@ mod tests { // Test finding entities by multiple source_ids let source_ids = vec![source_id1.clone(), source_id2.clone()]; let found_entities: Vec = - find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db) + find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db) .await .expect("Failed to find entities by source_ids"); @@ -177,7 +266,8 @@ mod tests { let single_source_id = vec![source_id1.clone()]; let found_entities: Vec = find_entities_by_source_ids( single_source_id, - KnowledgeEntity::table_name().to_string(), + KnowledgeEntity::table_name(), + &user_id, &db, ) .await @@ -202,7 +292,8 @@ mod tests { let non_existent_source_id = vec!["non_existent_source".to_string()]; let found_entities: Vec = find_entities_by_source_ids( non_existent_source_id, - KnowledgeEntity::table_name().to_string(), + KnowledgeEntity::table_name(), + &user_id, &db, ) .await @@ -327,11 +418,15 @@ mod tests { .expect("Failed to store relationship 2"); // Test finding entities related to the central entity - let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone()) - .await - .expect("Failed to find entities by relationship"); + let related_entities = + find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX) + .await + .expect("Failed to find entities by relationship"); // Check that we found relationships - assert!(related_entities.len() > 0, "Should find related entities"); + assert!( + related_entities.len() >= 2, + 
"Should find related entities in both directions" + ); } } diff --git a/composite-retrieval/src/lib.rs b/composite-retrieval/src/lib.rs index 90df670..e3e2ece 100644 --- a/composite-retrieval/src/lib.rs +++ b/composite-retrieval/src/lib.rs @@ -1,90 +1,714 @@ pub mod answer_retrieval; pub mod answer_retrieval_helper; +pub mod fts; pub mod graph; +pub mod scoring; pub mod vector; +use std::collections::{HashMap, HashSet}; + use common::{ error::AppError, storage::{ db::SurrealDbClient, - types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, + types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, }, + utils::embedding::generate_embedding, }; -use futures::future::{try_join, try_join_all}; +use futures::{stream::FuturesUnordered, StreamExt}; use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}; -use std::collections::HashMap; -use vector::find_items_by_vector_similarity; +use scoring::{ + clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc, + FusionWeights, Scored, +}; +use tracing::{debug, instrument, trace}; -/// Performs a comprehensive knowledge entity retrieval using multiple search strategies -/// to find the most relevant entities for a given query. -/// -/// # Strategy -/// The function employs a three-pronged approach to knowledge retrieval: -/// 1. Direct vector similarity search on knowledge entities -/// 2. Text chunk similarity search with source entity lookup -/// 3. Graph relationship traversal from related entities -/// -/// This combined approach ensures both semantic similarity matches and structurally -/// related content are included in the results. 
-/// -/// # Arguments -/// * `db_client` - SurrealDB client for database operations -/// * `openai_client` - OpenAI client for vector embeddings generation -/// * `query` - The search query string to find relevant knowledge entities -/// * 'user_id' - The user id of the current user -/// -/// # Returns -/// * `Result, AppError>` - A deduplicated vector of relevant -/// knowledge entities, or an error if the retrieval process fails +use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding}; + +// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping. +const ENTITY_VECTOR_TAKE: usize = 15; +const CHUNK_VECTOR_TAKE: usize = 20; +const ENTITY_FTS_TAKE: usize = 10; +const CHUNK_FTS_TAKE: usize = 20; +const SCORE_THRESHOLD: f32 = 0.35; +const FALLBACK_MIN_RESULTS: usize = 10; +const TOKEN_BUDGET_ESTIMATE: usize = 2800; +const AVG_CHARS_PER_TOKEN: usize = 4; +const MAX_CHUNKS_PER_ENTITY: usize = 4; +const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5; +const GRAPH_NEIGHBOR_LIMIT: usize = 6; +const GRAPH_SCORE_DECAY: f32 = 0.75; +const GRAPH_SEED_MIN_SCORE: f32 = 0.4; +const GRAPH_VECTOR_INHERITANCE: f32 = 0.6; + +#[derive(Debug, Clone)] +pub struct RetrievedChunk { + pub chunk: TextChunk, + pub score: f32, +} + +#[derive(Debug, Clone)] +pub struct RetrievedEntity { + pub entity: KnowledgeEntity, + pub score: f32, + pub chunks: Vec, +} + +#[instrument(skip_all, fields(user_id))] pub async fn retrieve_entities( db_client: &SurrealDbClient, openai_client: &async_openai::Client, query: &str, user_id: &str, -) -> Result, AppError> { - let (items_from_knowledge_entity_similarity, closest_chunks) = try_join( - find_items_by_vector_similarity( - 10, +) -> Result, AppError> { + trace!("Generating query embedding for hybrid retrieval"); + let query_embedding = generate_embedding(openai_client, query, db_client).await?; + retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await +} + +pub(crate) async fn 
retrieve_entities_with_embedding( + db_client: &SurrealDbClient, + query_embedding: Vec, + query: &str, + user_id: &str, +) -> Result, AppError> { + // 1) Gather first-pass candidates from vector search and BM25. + let weights = FusionWeights::default(); + + let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!( + find_items_by_vector_similarity_with_embedding( + ENTITY_VECTOR_TAKE, + query_embedding.clone(), + db_client, + "knowledge_entity", + user_id, + ), + find_items_by_vector_similarity_with_embedding( + CHUNK_VECTOR_TAKE, + query_embedding, + db_client, + "text_chunk", + user_id, + ), + find_items_by_fts( + ENTITY_FTS_TAKE, query, db_client, "knowledge_entity", - openai_client, - user_id, + user_id ), - find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id), + find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id), + )?; + + debug!( + vector_entities = vector_entities.len(), + vector_chunks = vector_chunks.len(), + fts_entities = fts_entities.len(), + fts_chunks = fts_chunks.len(), + "Hybrid retrieval initial candidate counts" + ); + + normalize_fts_scores(&mut fts_entities); + normalize_fts_scores(&mut fts_chunks); + + let mut entity_candidates: HashMap> = HashMap::new(); + let mut chunk_candidates: HashMap> = HashMap::new(); + + merge_scored_by_id(&mut entity_candidates, vector_entities); + merge_scored_by_id(&mut entity_candidates, fts_entities); + merge_scored_by_id(&mut chunk_candidates, vector_chunks); + merge_scored_by_id(&mut chunk_candidates, fts_chunks); + + // 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph. + apply_fusion(&mut entity_candidates, weights); + apply_fusion(&mut chunk_candidates, weights); + enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?; + + // 3) Track high-signal chunk sources so we can backfill missing entities. 
+ let chunk_by_source = group_chunks_by_source(&chunk_candidates); + let mut missing_sources = Vec::new(); + + for source_id in chunk_by_source.keys() { + if !entity_candidates + .values() + .any(|entity| entity.item.source_id == *source_id) + { + missing_sources.push(source_id.clone()); + } + } + + if !missing_sources.is_empty() { + let related_entities: Vec = find_entities_by_source_ids( + missing_sources.clone(), + "knowledge_entity", + user_id, + db_client, + ) + .await + .unwrap_or_default(); + + for entity in related_entities { + if let Some(chunks) = chunk_by_source.get(&entity.source_id) { + let best_chunk_score = chunks + .iter() + .map(|chunk| chunk.fused) + .fold(0.0f32, f32::max); + + let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score); + let fused = fuse_scores(&scored.scores, weights); + scored.update_fused(fused); + entity_candidates.insert(entity.id.clone(), scored); + } + } + } + + // Boost entities with evidence from high scoring chunks. + for entity in entity_candidates.values_mut() { + if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) { + let best_chunk_score = chunks + .iter() + .map(|chunk| chunk.fused) + .fold(0.0f32, f32::max); + + if best_chunk_score > 0.0 { + let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score); + entity.scores.vector = Some(boosted); + let fused = fuse_scores(&entity.scores, weights); + entity.update_fused(fused); + } + } + } + + let mut entity_results: Vec> = + entity_candidates.into_values().collect(); + sort_by_fused_desc(&mut entity_results); + + let mut filtered_entities: Vec> = entity_results + .iter() + .filter(|candidate| candidate.fused >= SCORE_THRESHOLD) + .cloned() + .collect(); + + if filtered_entities.len() < FALLBACK_MIN_RESULTS { + filtered_entities = entity_results + .into_iter() + .take(FALLBACK_MIN_RESULTS) + .collect(); + } + + // 4) Re-rank chunks and prepare for attachment to surviving entities. 
+ let mut chunk_results: Vec> = chunk_candidates.into_values().collect(); + sort_by_fused_desc(&mut chunk_results); + + let mut chunk_by_id: HashMap> = HashMap::new(); + for chunk in chunk_results { + chunk_by_id.insert(chunk.item.id.clone(), chunk); + } + + enrich_chunks_from_entities( + &mut chunk_by_id, + &filtered_entities, + db_client, + user_id, + weights, ) .await?; - let source_ids = closest_chunks - .iter() - .map(|chunk: &TextChunk| chunk.source_id.clone()) - .collect::>(); + let mut chunk_values: Vec> = chunk_by_id.into_values().collect(); + sort_by_fused_desc(&mut chunk_values); - let items_from_text_chunk_similarity: Vec = - find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?; - - let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity - .clone() - .into_iter() - .map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone())) - .collect(); - - let items_from_relationships = try_join_all(items_from_relationships_futures) - .await? - .into_iter() - .flatten() - .collect::>(); - - let entities: Vec = items_from_knowledge_entity_similarity - .into_iter() - .chain(items_from_text_chunk_similarity.into_iter()) - .chain(items_from_relationships.into_iter()) - .fold(HashMap::new(), |mut map, entity| { - map.insert(entity.id.clone(), entity); - map - }) - .into_values() - .collect(); - - Ok(entities) + Ok(assemble_results(filtered_entities, chunk_values)) +} + +#[derive(Clone)] +struct GraphSeed { + id: String, + fused: f32, +} + +async fn enrich_entities_from_graph( + entity_candidates: &mut HashMap>, + db_client: &SurrealDbClient, + user_id: &str, + weights: FusionWeights, +) -> Result<(), AppError> { + if entity_candidates.is_empty() { + return Ok(()); + } + + // Select a small frontier of high-confidence entities to seed the relationship walk. 
+ let mut seeds: Vec = entity_candidates + .values() + .filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE) + .map(|entity| GraphSeed { + id: entity.item.id.clone(), + fused: entity.fused, + }) + .collect(); + + if seeds.is_empty() { + return Ok(()); + } + + seeds.sort_by(|a, b| { + b.fused + .partial_cmp(&a.fused) + .unwrap_or(std::cmp::Ordering::Equal) + }); + seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT); + + let mut futures = FuturesUnordered::new(); + for seed in seeds.clone() { + let user_id = user_id.to_owned(); + futures.push(async move { + // Fetch neighbors concurrently to avoid serial graph round trips. + let neighbors = find_entities_by_relationship_by_id( + db_client, + &seed.id, + &user_id, + GRAPH_NEIGHBOR_LIMIT, + ) + .await; + (seed, neighbors) + }); + } + + while let Some((seed, neighbors_result)) = futures.next().await { + let neighbors = neighbors_result.map_err(AppError::from)?; + if neighbors.is_empty() { + continue; + } + + for neighbor in neighbors { + if neighbor.id == seed.id { + continue; + } + + let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY); + let entry = entity_candidates + .entry(neighbor.id.clone()) + .or_insert_with(|| Scored::new(neighbor.clone())); + + entry.item = neighbor; + + let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE); + let vector_existing = entry.scores.vector.unwrap_or(0.0); + if inherited_vector > vector_existing { + entry.scores.vector = Some(inherited_vector); + } + + let existing_graph = entry.scores.graph.unwrap_or(f32::MIN); + if graph_score > existing_graph { + entry.scores.graph = Some(graph_score); + } else if entry.scores.graph.is_none() { + entry.scores.graph = Some(graph_score); + } + + let fused = fuse_scores(&entry.scores, weights); + entry.update_fused(fused); + } + } + + Ok(()) +} + +fn normalize_fts_scores(results: &mut [Scored]) { + // Scale BM25 outputs into [0,1] to keep fusion weights predictable. 
+ let raw_scores: Vec = results + .iter() + .map(|candidate| candidate.scores.fts.unwrap_or(0.0)) + .collect(); + + let normalized = min_max_normalize(&raw_scores); + for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) { + candidate.scores.fts = Some(normalized_score); + candidate.update_fused(0.0); + } +} + +fn apply_fusion(candidates: &mut HashMap>, weights: FusionWeights) +where + T: StoredObject, +{ + // Collapse individual signals into a single fused score used for ranking. + for candidate in candidates.values_mut() { + let fused = fuse_scores(&candidate.scores, weights); + candidate.update_fused(fused); + } +} + +fn group_chunks_by_source( + chunks: &HashMap>, +) -> HashMap>> { + // Preserve chunk candidates keyed by their originating source entity. + let mut by_source: HashMap>> = HashMap::new(); + + for chunk in chunks.values() { + by_source + .entry(chunk.item.source_id.clone()) + .or_default() + .push(chunk.clone()); + } + by_source +} + +async fn enrich_chunks_from_entities( + chunk_candidates: &mut HashMap>, + entities: &[Scored], + db_client: &SurrealDbClient, + user_id: &str, + weights: FusionWeights, +) -> Result<(), AppError> { + // Fetch additional chunks referenced by entities that survived the fusion stage. + let mut source_ids: HashSet = HashSet::new(); + for entity in entities { + source_ids.insert(entity.item.source_id.clone()); + } + + if source_ids.is_empty() { + return Ok(()); + } + + let chunks = find_entities_by_source_ids::( + source_ids.into_iter().collect(), + "text_chunk", + user_id, + db_client, + ) + .await?; + + let mut entity_score_lookup: HashMap = HashMap::new(); + // Cache fused scores per source so chunks inherit the strength of their parent entity. + for entity in entities { + entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused); + } + + for chunk in chunks { + // Ensure each chunk is represented so downstream selection sees the latest content. 
+ let entry = chunk_candidates + .entry(chunk.id.clone()) + .or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0)); + + let entity_score = entity_score_lookup + .get(&chunk.source_id) + .copied() + .unwrap_or(0.0); + + // Lift chunk score toward the entity score so supporting evidence is prioritised. + entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8)); + let fused = fuse_scores(&entry.scores, weights); + entry.update_fused(fused); + entry.item = chunk; + } + + Ok(()) +} + +fn assemble_results( + entities: Vec>, + mut chunks: Vec>, +) -> Vec { + // Re-associate chunk candidates with their parent entity for ranked selection. + let mut chunk_by_source: HashMap>> = HashMap::new(); + for chunk in chunks.drain(..) { + chunk_by_source + .entry(chunk.item.source_id.clone()) + .or_default() + .push(chunk); + } + + for chunk_list in chunk_by_source.values_mut() { + sort_by_fused_desc(chunk_list); + } + + let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE; + let mut results = Vec::new(); + + for entity in entities { + // Attach best chunks first while respecting per-entity and global token caps. 
+ let mut selected_chunks = Vec::new(); + if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) { + let mut per_entity_count = 0; + candidates.sort_by(|a, b| { + b.fused + .partial_cmp(&a.fused) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + for candidate in candidates.iter() { + if per_entity_count >= MAX_CHUNKS_PER_ENTITY { + break; + } + let estimated_tokens = estimate_tokens(&candidate.item.chunk); + if estimated_tokens > token_budget_remaining { + continue; + } + + token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); + per_entity_count += 1; + + selected_chunks.push(RetrievedChunk { + chunk: candidate.item.clone(), + score: candidate.fused, + }); + } + } + + results.push(RetrievedEntity { + entity: entity.item.clone(), + score: entity.fused, + chunks: selected_chunks, + }); + + if token_budget_remaining == 0 { + break; + } + } + + results +} + +fn estimate_tokens(text: &str) -> usize { + // Simple heuristic to avoid calling a tokenizer in hot code paths. 
+ let chars = text.chars().count().max(1); + (chars / AVG_CHARS_PER_TOKEN).max(1) +} + +#[cfg(test)] +mod tests { + use super::*; + use common::storage::types::{ + knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship, + }; + use uuid::Uuid; + + fn test_embedding() -> Vec { + vec![0.9, 0.1, 0.0] + } + + fn entity_embedding_high() -> Vec { + vec![0.8, 0.2, 0.0] + } + + fn entity_embedding_low() -> Vec { + vec![0.1, 0.9, 0.0] + } + + fn chunk_embedding_primary() -> Vec { + vec![0.85, 0.15, 0.0] + } + + fn chunk_embedding_secondary() -> Vec { + vec![0.2, 0.8, 0.0] + } + + async fn setup_test_db() -> SurrealDbClient { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + db.query( + "BEGIN TRANSACTION; + REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk; + DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3; + REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity; + DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3; + COMMIT TRANSACTION;", + ) + .await + .expect("Failed to redefine vector indexes for tests"); + + db + } + + async fn seed_test_data(db: &SurrealDbClient, user_id: &str) { + let entity_relevant = KnowledgeEntity::new( + "source_a".into(), + "Rust Concurrency Patterns".into(), + "Discussion about async concurrency in Rust.".into(), + KnowledgeEntityType::Document, + None, + entity_embedding_high(), + user_id.into(), + ); + let entity_irrelevant = KnowledgeEntity::new( + "source_b".into(), + "Python Tips".into(), + "General Python programming tips.".into(), + KnowledgeEntityType::Document, + None, + entity_embedding_low(), + user_id.into(), + ); + + db.store_item(entity_relevant.clone()) + .await + 
.expect("Failed to store relevant entity"); + db.store_item(entity_irrelevant.clone()) + .await + .expect("Failed to store irrelevant entity"); + + let chunk_primary = TextChunk::new( + entity_relevant.source_id.clone(), + "Tokio enables async concurrency with lightweight tasks.".into(), + chunk_embedding_primary(), + user_id.into(), + ); + let chunk_secondary = TextChunk::new( + entity_irrelevant.source_id.clone(), + "Python focuses on readability and dynamic typing.".into(), + chunk_embedding_secondary(), + user_id.into(), + ); + + db.store_item(chunk_primary) + .await + .expect("Failed to store primary chunk"); + db.store_item(chunk_secondary) + .await + .expect("Failed to store secondary chunk"); + } + + #[tokio::test] + async fn test_hybrid_retrieval_prioritises_relevant_entity() { + let db = setup_test_db().await; + let user_id = "user123"; + seed_test_data(&db, user_id).await; + + let results = retrieve_entities_with_embedding( + &db, + test_embedding(), + "Rust concurrency async tasks", + user_id, + ) + .await + .expect("Hybrid retrieval failed"); + + assert!( + !results.is_empty(), + "Expected at least one retrieval result" + ); + + let top = &results[0]; + + assert!( + top.entity.name.contains("Rust"), + "Expected Rust entity to be ranked first" + ); + assert!( + !top.chunks.is_empty(), + "Expected Rust entity to include supporting chunks" + ); + + let chunk_texts: Vec<&str> = top + .chunks + .iter() + .map(|chunk| chunk.chunk.chunk.as_str()) + .collect(); + assert!( + chunk_texts.iter().any(|text| text.contains("Tokio")), + "Expected chunk discussing Tokio to be included" + ); + } + + #[tokio::test] + async fn test_graph_relationship_enriches_results() { + let db = setup_test_db().await; + let user_id = "graph_user"; + + let primary = KnowledgeEntity::new( + "primary_source".into(), + "Async Rust patterns".into(), + "Explores async runtimes and scheduling strategies.".into(), + KnowledgeEntityType::Document, + None, + entity_embedding_high(), + 
user_id.into(), + ); + let neighbor = KnowledgeEntity::new( + "neighbor_source".into(), + "Tokio Scheduler Deep Dive".into(), + "Details on Tokio's cooperative scheduler.".into(), + KnowledgeEntityType::Document, + None, + entity_embedding_low(), + user_id.into(), + ); + + db.store_item(primary.clone()) + .await + .expect("Failed to store primary entity"); + db.store_item(neighbor.clone()) + .await + .expect("Failed to store neighbor entity"); + + let primary_chunk = TextChunk::new( + primary.source_id.clone(), + "Rust async tasks use Tokio's cooperative scheduler.".into(), + chunk_embedding_primary(), + user_id.into(), + ); + let neighbor_chunk = TextChunk::new( + neighbor.source_id.clone(), + "Tokio's scheduler manages task fairness across executors.".into(), + chunk_embedding_secondary(), + user_id.into(), + ); + + db.store_item(primary_chunk) + .await + .expect("Failed to store primary chunk"); + db.store_item(neighbor_chunk) + .await + .expect("Failed to store neighbor chunk"); + + let relationship = KnowledgeRelationship::new( + primary.id.clone(), + neighbor.id.clone(), + user_id.into(), + "relationship_source".into(), + "references".into(), + ); + relationship + .store_relationship(&db) + .await + .expect("Failed to store relationship"); + + let results = retrieve_entities_with_embedding( + &db, + test_embedding(), + "Rust concurrency async tasks", + user_id, + ) + .await + .expect("Hybrid retrieval failed"); + + let mut neighbor_entry = None; + for entity in &results { + if entity.entity.id == neighbor.id { + neighbor_entry = Some(entity.clone()); + } + } + + let neighbor_entry = + neighbor_entry.expect("Graph-enriched neighbor should appear in results"); + + assert!( + neighbor_entry.score > 0.2, + "Graph-enriched entity should have a meaningful fused score" + ); + assert!( + neighbor_entry + .chunks + .iter() + .all(|chunk| chunk.chunk.source_id == neighbor.source_id), + "Neighbor entity should surface its own supporting chunks" + ); + } } diff --git 
a/composite-retrieval/src/scoring.rs b/composite-retrieval/src/scoring.rs new file mode 100644 index 0000000..d0f28b5 --- /dev/null +++ b/composite-retrieval/src/scoring.rs @@ -0,0 +1,180 @@ +use std::cmp::Ordering; + +use common::storage::types::StoredObject; + +/// Holds optional subscores gathered from different retrieval signals. +#[derive(Debug, Clone, Copy, Default)] +pub struct Scores { + pub fts: Option, + pub vector: Option, + pub graph: Option, +} + +/// Generic wrapper combining an item with its accumulated retrieval scores. +#[derive(Debug, Clone)] +pub struct Scored { + pub item: T, + pub scores: Scores, + pub fused: f32, +} + +impl Scored { + pub fn new(item: T) -> Self { + Self { + item, + scores: Scores::default(), + fused: 0.0, + } + } + + pub fn with_vector_score(mut self, score: f32) -> Self { + self.scores.vector = Some(score); + self + } + + pub fn with_fts_score(mut self, score: f32) -> Self { + self.scores.fts = Some(score); + self + } + + pub fn with_graph_score(mut self, score: f32) -> Self { + self.scores.graph = Some(score); + self + } + + pub fn update_fused(&mut self, fused: f32) { + self.fused = fused; + } +} + +/// Weights used for linear score fusion. 
+#[derive(Debug, Clone, Copy)] +pub struct FusionWeights { + pub vector: f32, + pub fts: f32, + pub graph: f32, + pub multi_bonus: f32, +} + +impl Default for FusionWeights { + fn default() -> Self { + Self { + vector: 0.5, + fts: 0.3, + graph: 0.2, + multi_bonus: 0.02, + } + } +} + +pub fn clamp_unit(value: f32) -> f32 { + value.max(0.0).min(1.0) +} + +pub fn distance_to_similarity(distance: f32) -> f32 { + if !distance.is_finite() { + return 0.0; + } + clamp_unit(1.0 / (1.0 + distance.max(0.0))) +} + +pub fn min_max_normalize(scores: &[f32]) -> Vec { + if scores.is_empty() { + return Vec::new(); + } + + let mut min = f32::MAX; + let mut max = f32::MIN; + + for s in scores { + if !s.is_finite() { + continue; + } + if *s < min { + min = *s; + } + if *s > max { + max = *s; + } + } + + if !min.is_finite() || !max.is_finite() { + return scores.iter().map(|_| 0.0).collect(); + } + + if (max - min).abs() < f32::EPSILON { + return vec![1.0; scores.len()]; + } + + scores + .iter() + .map(|score| { + if !score.is_finite() { + 0.0 + } else { + clamp_unit((score - min) / (max - min)) + } + }) + .collect() +} + +pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 { + let vector = scores.vector.unwrap_or(0.0); + let fts = scores.fts.unwrap_or(0.0); + let graph = scores.graph.unwrap_or(0.0); + + let mut fused = vector * weights.vector + fts * weights.fts + graph * weights.graph; + + let signals_present = scores + .vector + .iter() + .chain(scores.fts.iter()) + .chain(scores.graph.iter()) + .count(); + if signals_present >= 2 { + fused += weights.multi_bonus; + } + + clamp_unit(fused) +} + +pub fn merge_scored_by_id( + target: &mut std::collections::HashMap>, + incoming: Vec>, +) where + T: StoredObject + Clone, +{ + for scored in incoming { + let id = scored.item.get_id().to_owned(); + target + .entry(id) + .and_modify(|existing| { + if let Some(score) = scored.scores.vector { + existing.scores.vector = Some(score); + } + if let Some(score) = scored.scores.fts { 
+ existing.scores.fts = Some(score); + } + if let Some(score) = scored.scores.graph { + existing.scores.graph = Some(score); + } + }) + .or_insert_with(|| Scored { + item: scored.item.clone(), + scores: scored.scores, + fused: scored.fused, + }); + } +} + +pub fn sort_by_fused_desc(items: &mut [Scored]) +where + T: StoredObject, +{ + items.sort_by(|a, b| { + b.fused + .partial_cmp(&a.fused) + .unwrap_or(Ordering::Equal) + .then_with(|| a.item.get_id().cmp(b.item.get_id())) + }); +} diff --git a/composite-retrieval/src/vector.rs b/composite-retrieval/src/vector.rs index d69f49a..56bf2ef 100644 --- a/composite-retrieval/src/vector.rs +++ b/composite-retrieval/src/vector.rs @@ -1,4 +1,15 @@ -use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::generate_embedding}; +use std::collections::HashMap; + +use common::storage::types::file_info::deserialize_flexible_id; +use common::{ + error::AppError, + storage::{db::SurrealDbClient, types::StoredObject}, + utils::embedding::generate_embedding, +}; +use serde::Deserialize; +use surrealdb::sql::Thing; + +use crate::scoring::{distance_to_similarity, Scored}; /// Compares vectors and retrieves a number of items from the specified table. /// @@ -22,24 +33,90 @@ use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::ge /// /// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`. 
pub async fn find_items_by_vector_similarity( - take: u8, + take: usize, input_text: &str, db_client: &SurrealDbClient, table: &str, openai_client: &async_openai::Client, user_id: &str, -) -> Result, AppError> +) -> Result>, AppError> where - T: for<'de> serde::Deserialize<'de>, + T: for<'de> serde::Deserialize<'de> + StoredObject, { // Generate embeddings let input_embedding = generate_embedding(openai_client, input_text, db_client).await?; - - // Construct the query - let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE user_id = '{}' AND embedding <|{},40|> {:?} ORDER BY distance", table, user_id, take, input_embedding); - - // Perform query and deserialize to struct - let closest_entities: Vec = db_client.query(closest_query).await?.take(0)?; - - Ok(closest_entities) + find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id) + .await +} + +#[derive(Debug, Deserialize)] +struct DistanceRow { + #[serde(deserialize_with = "deserialize_flexible_id")] + id: String, + distance: Option, +} + +pub async fn find_items_by_vector_similarity_with_embedding( + take: usize, + query_embedding: Vec, + db_client: &SurrealDbClient, + table: &str, + user_id: &str, +) -> Result>, AppError> +where + T: for<'de> serde::Deserialize<'de> + StoredObject, +{ + let embedding_literal = serde_json::to_string(&query_embedding) + .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?; + let closest_query = format!( + "SELECT id, vector::distance::knn() AS distance \ + FROM {table} \ + WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \ + LIMIT $limit", + table = table, + take = take, + embedding = embedding_literal + ); + + let mut response = db_client + .query(closest_query) + .bind(("user_id", user_id.to_owned())) + .bind(("limit", take as i64)) + .await?; + + let distance_rows: Vec = response.take(0)?; + + if distance_rows.is_empty() { + return Ok(Vec::new()); + } 
+ + let ids: Vec = distance_rows.iter().map(|row| row.id.clone()).collect(); + let thing_ids: Vec = ids + .iter() + .map(|id| Thing::from((table, id.as_str()))) + .collect(); + + let mut items_response = db_client + .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id") + .bind(("table", table.to_owned())) + .bind(("things", thing_ids.clone())) + .bind(("user_id", user_id.to_owned())) + .await?; + + let items: Vec = items_response.take(0)?; + + let mut item_map: HashMap = items + .into_iter() + .map(|item| (item.get_id().to_owned(), item)) + .collect(); + + let mut scored = Vec::with_capacity(distance_rows.len()); + for row in distance_rows { + if let Some(item) = item_map.remove(&row.id) { + let similarity = row.distance.map(distance_to_similarity).unwrap_or_default(); + scored.push(Scored::new(item).with_vector_score(similarity)); + } + } + + Ok(scored) } diff --git a/ingestion-pipeline/src/enricher.rs b/ingestion-pipeline/src/enricher.rs index e2fc811..eab2c5b 100644 --- a/ingestion-pipeline/src/enricher.rs +++ b/ingestion-pipeline/src/enricher.rs @@ -7,13 +7,11 @@ use async_openai::types::{ }; use common::{ error::AppError, - storage::{ - db::SurrealDbClient, - types::{knowledge_entity::KnowledgeEntity, system_settings::SystemSettings}, - }, + storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, +}; +use composite_retrieval::{ + answer_retrieval::format_entities_json, retrieve_entities, RetrievedEntity, }; -use composite_retrieval::retrieve_entities; -use serde_json::json; use tracing::{debug, info}; use crate::{ @@ -61,7 +59,7 @@ impl IngestionEnricher { context: Option<&str>, text: &str, user_id: &str, - ) -> Result, AppError> { + ) -> Result, AppError> { let input_text = format!( "content: {}, category: {}, user_context: {:?}", text, category, context @@ -75,22 +73,11 @@ impl IngestionEnricher { category: &str, context: Option<&str>, text: &str, - similar_entities: &[KnowledgeEntity], + similar_entities: 
&[RetrievedEntity], ) -> Result { let settings = SystemSettings::get_current(&self.db_client).await?; - let entities_json = json!(similar_entities - .iter() - .map(|entity| { - json!({ - "KnowledgeEntity": { - "id": entity.id, - "name": entity.name, - "description": entity.description - } - }) - }) - .collect::>()); + let entities_json = format_entities_json(similar_entities); let user_message = format!( "Category:\n{}\ncontext:\n{:?}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",