From a5bc72aedf9ca167d2c556bc2eb59c02bdfde32a Mon Sep 17 00:00:00 2001 From: Per Stark Date: Wed, 10 Dec 2025 13:54:08 +0100 Subject: [PATCH] passed wide smoke check --- .gitignore | 4 +- common/src/storage/db.rs | 20 -- common/src/storage/indexes.rs | 304 +++++++++++------- common/src/storage/store.rs | 3 + common/src/storage/types/text_chunk.rs | 89 +++-- common/src/utils/embedding.rs | 6 +- evaluations/src/types.rs | 30 +- .../routes/chat/message_response_stream.rs | 4 + html-router/src/routes/search/handlers.rs | 96 ++++-- ingestion-pipeline/src/pipeline/services.rs | 13 +- ingestion-pipeline/src/pipeline/stages/mod.rs | 54 ++-- main/src/main.rs | 15 +- 12 files changed, 403 insertions(+), 235 deletions(-) diff --git a/.gitignore b/.gitignore index a6df1fe..b9de9c3 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,8 @@ result data database -eval/cache/ -eval/reports/ +evaluations/cache/ +evaluations/reports/ # Devenv .devenv* diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index 83d1496..8aa9776 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -97,26 +97,6 @@ impl SurrealDbClient { Ok(()) } - /// Operation to rebuild indexes - pub async fn rebuild_indexes(&self) -> Result<(), AppError> { - debug!("Rebuilding indexes"); - let rebuild_sql = r#" - REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content; - REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity; - REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity; - REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk; - REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding; - REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding; - "#; - - self.client - .query(rebuild_sql) - .await - .map_err(|e| AppError::InternalError(e.to_string()))?; - - Ok(()) - } - /// Operation to store a object in SurrealDB, requires the struct to implement StoredObject /// /// # Arguments diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index 7a8f924..2b198e8 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -1,19 +1,9 @@ -#![allow( - clippy::missing_docs_in_private_items, - clippy::module_name_repetitions, - clippy::items_after_statements, - clippy::arithmetic_side_effects, - clippy::cast_precision_loss, - clippy::redundant_closure_for_method_calls, - clippy::single_match_else, - clippy::uninlined_format_args -)] use std::time::Duration; use anyhow::{Context, Result}; use futures::future::try_join_all; use serde::Deserialize; -use serde_json::Value; +use serde_json::{Map, Value}; use tracing::{debug, info, warn}; use crate::{error::AppError, storage::db::SurrealDbClient}; @@ -28,6 +18,82 @@ struct HnswIndexSpec { options: &'static str, } +const fn hnsw_index_specs() -> [HnswIndexSpec; 2] { + [ + HnswIndexSpec { + index_name: "idx_embedding_text_chunk_embedding", + table: "text_chunk_embedding", + options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", + }, + HnswIndexSpec { + index_name: "idx_embedding_knowledge_entity_embedding", + table: "knowledge_entity_embedding", + options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", + }, + ] +} + +const fn fts_index_specs() -> [FtsIndexSpec; 8] { + [ + FtsIndexSpec { + index_name: "text_content_fts_idx", + table: "text_content", + field: "text", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_context_fts_idx", + table: 
"text_content", + field: "context", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_file_name_fts_idx", + table: "text_content", + field: "file_info.file_name", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_url_fts_idx", + table: "text_content", + field: "url_info.url", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_url_title_fts_idx", + table: "text_content", + field: "url_info.title", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "knowledge_entity_fts_name_idx", + table: "knowledge_entity", + field: "name", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "knowledge_entity_fts_description_idx", + table: "knowledge_entity", + field: "description", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_chunk_fts_chunk_idx", + table: "text_chunk", + field: "chunk", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + ] +} + impl HnswIndexSpec { fn definition_if_not_exists(&self, dimension: usize) -> String { format!( @@ -75,6 +141,20 @@ impl FtsIndexSpec { field = self.field, ) } + + fn overwrite_definition(&self) -> String { + let analyzer_clause = self + .analyzer + .map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method)) + .unwrap_or_default(); + + format!( + "DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;", + index = self.index_name, + table = self.table, + field = self.field, + ) + } } /// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling. @@ -88,6 +168,13 @@ pub async fn ensure_runtime_indexes( .map_err(|err| AppError::InternalError(err.to_string())) } +/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined. +pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> { + rebuild_indexes_inner(db) + .await + .map_err(|err| AppError::InternalError(err.to_string())) +} + async fn ensure_runtime_indexes_inner( db: &SurrealDbClient, embedding_dimension: usize, @@ -147,32 +234,68 @@ async fn ensure_runtime_indexes_inner( Ok(()) } -async fn hnsw_index_state( +async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> { + debug!("Rebuilding indexes with concurrent definitions"); + create_fts_analyzer(db).await?; + + for spec in fts_index_specs() { + if !index_exists(db, spec.table, spec.index_name).await? { + debug!( + index = spec.index_name, + table = spec.table, + "Skipping FTS rebuild because index is missing" + ); + continue; + } + + create_index_with_polling( + db, + spec.overwrite_definition(), + spec.index_name, + spec.table, + Some(spec.table), + ) + .await?; + } + + let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move { + if !index_exists(db, spec.table, spec.index_name).await? { + debug!( + index = spec.index_name, + table = spec.table, + "Skipping HNSW rebuild because index is missing" + ); + return Ok(()); + } + + let Some(dimension) = existing_hnsw_dimension(db, &spec).await? 
else {
+            warn!(
+                index = spec.index_name,
+                table = spec.table,
+                "HNSW index missing dimension; skipping rebuild"
+            );
+            return Ok(());
+        };
+
+        create_index_with_polling(
+            db,
+            spec.definition_overwrite(dimension),
+            spec.index_name,
+            spec.table,
+            Some(spec.table),
+        )
+        .await
+    });
+
+    try_join_all(hnsw_tasks).await.map(|_| ())
+}
+
+async fn existing_hnsw_dimension(
     db: &SurrealDbClient,
     spec: &HnswIndexSpec,
-    expected_dimension: usize,
-) -> Result<HnswIndexState> {
-    let info_query = format!("INFO FOR TABLE {table};", table = spec.table);
-    let mut response = db
-        .client
-        .query(info_query)
-        .await
-        .with_context(|| format!("fetching table info for {}", spec.table))?;
-
-    let info: surrealdb::Value = response
-        .take(0)
-        .context("failed to take table info response")?;
-
-    let info_json: Value =
-        serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
-
-    let Some(indexes) = info_json
-        .get("Object")
-        .and_then(|o| o.get("indexes"))
-        .and_then(|i| i.get("Object"))
-        .and_then(|i| i.as_object())
-    else {
-        return Ok(HnswIndexState::Missing);
+) -> Result<Option<usize>> {
+    let Some(indexes) = table_index_definitions(db, spec.table).await? else {
+        return Ok(None);
     };
 
     let Some(definition) = indexes
@@ -180,17 +303,23 @@
         .and_then(|details| details.get("Strand"))
         .and_then(|v| v.as_str())
     else {
-        return Ok(HnswIndexState::Missing);
+        return Ok(None);
     };
 
-    let Some(current_dimension) = extract_dimension(definition) else {
-        return Ok(HnswIndexState::Missing);
-    };
+    Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
+}
 
-    if current_dimension == expected_dimension as u64 {
-        Ok(HnswIndexState::Matches)
-    } else {
-        Ok(HnswIndexState::Different(current_dimension))
+async fn hnsw_index_state(
+    db: &SurrealDbClient,
+    spec: &HnswIndexSpec,
+    expected_dimension: usize,
+) -> Result<HnswIndexState> {
+    match existing_hnsw_dimension(db, spec).await? {
+        None => Ok(HnswIndexState::Missing),
+        Some(current_dimension) if current_dimension == expected_dimension => {
+            Ok(HnswIndexState::Matches)
+        }
+        Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
     }
 }
 
@@ -492,7 +621,10 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<usize> {
     Ok(rows.first().map_or(0, |r| r.count))
}
 
-async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
+async fn table_index_definitions(
+    db: &SurrealDbClient,
+    table: &str,
+) -> Result<Option<Map<String, Value>>> {
     let info_query = format!("INFO FOR TABLE {table};");
     let mut response = db
         .client
@@ -507,94 +639,22 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
     let info_json: Value =
         serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
 
-    let Some(indexes) = info_json
+    Ok(info_json
         .get("Object")
         .and_then(|o| o.get("indexes"))
         .and_then(|i| i.get("Object"))
         .and_then(|i| i.as_object())
-    else {
+        .cloned())
+}
+
+async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
+    let Some(indexes) = table_index_definitions(db, table).await?
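+    // (A `None` map here means INFO FOR TABLE exposed no `indexes` object at
+    // all, so the index cannot exist yet.)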
else { return Ok(false); }; Ok(indexes.contains_key(index_name)) } -const fn hnsw_index_specs() -> [HnswIndexSpec; 2] { - [ - HnswIndexSpec { - index_name: "idx_embedding_text_chunk_embedding", - table: "text_chunk_embedding", - options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", - }, - HnswIndexSpec { - index_name: "idx_embedding_knowledge_entity_embedding", - table: "knowledge_entity_embedding", - options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", - }, - ] -} - -const fn fts_index_specs() -> [FtsIndexSpec; 8] { - [ - FtsIndexSpec { - index_name: "text_content_fts_idx", - table: "text_content", - field: "text", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "text_content_context_fts_idx", - table: "text_content", - field: "context", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "text_content_file_name_fts_idx", - table: "text_content", - field: "file_info.file_name", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "text_content_url_fts_idx", - table: "text_content", - field: "url_info.url", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "text_content_url_title_fts_idx", - table: "text_content", - field: "url_info.title", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "knowledge_entity_fts_name_idx", - table: "knowledge_entity", - field: "name", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "knowledge_entity_fts_description_idx", - table: "knowledge_entity", - field: "description", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - FtsIndexSpec { - index_name: "text_chunk_fts_chunk_idx", - table: "text_chunk", - field: "chunk", - analyzer: Some(FTS_ANALYZER_NAME), - method: "BM25", - }, - ] -} - #[cfg(test)] mod tests { use super::*; diff --git a/common/src/storage/store.rs b/common/src/storage/store.rs index 6ae5457..9039483 100644 --- a/common/src/storage/store.rs +++ b/common/src/storage/store.rs @@ -17,8 +17,11 @@ pub type DynStore = Arc; /// Storage manager with persistent state and proper lifecycle management. 
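+///
+/// A minimal usage sketch (hypothetical; this diff does not show the
+/// constructor, so `new_local` and its argument are assumptions):
+///
+/// ```ignore
+/// // Assumed constructor: pick the local-disk backend rooted at the given
+/// // path; `backend_kind` would then report local storage and `local_base`
+/// // would be Some(path).
+/// let storage = StorageManager::new_local("/var/lib/app/data").await?;
+/// ```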
#[derive(Clone)] pub struct StorageManager { + // Store from objectstore wrapped as dyn store: DynStore, + // Simple enum to track which kind backend_kind: StorageKind, + // Where on disk local_base: Option, } diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index faf1bd8..517f81a 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -61,36 +61,34 @@ impl TextChunk { embedding: Vec, db: &SurrealDbClient, ) -> Result<(), AppError> { - let emb = TextChunkEmbedding::new( - &chunk.id, - chunk.source_id.clone(), - embedding, - chunk.user_id.clone(), - ); + let chunk_id = chunk.id.clone(); + let source_id = chunk.source_id.clone(); + let user_id = chunk.user_id.clone(); - // Create both records in a single query - let query = format!( - " - BEGIN TRANSACTION; - CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk; - CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb; - COMMIT TRANSACTION; - ", - chunk_table = Self::table_name(), - emb_table = TextChunkEmbedding::table_name(), - ); + let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone()); - db.client - .query(query) - .bind(("chunk_id", chunk.id.clone())) + // Create both records in a single transaction so we don't orphan embeddings or chunks + let response = db + .client + .query("BEGIN TRANSACTION;") + .query(format!( + "CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;", + chunk_table = Self::table_name(), + )) + .query(format!( + "CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;", + emb_table = TextChunkEmbedding::table_name(), + )) + .query("COMMIT TRANSACTION;") + .bind(("chunk_id", chunk_id.clone())) .bind(("chunk", chunk)) .bind(("emb_id", emb.id.clone())) .bind(("emb", emb)) .await - .map_err(AppError::Database)? - .check() .map_err(AppError::Database)?; + response.check().map_err(AppError::Database)?; + Ok(()) } @@ -330,6 +328,7 @@ impl TextChunk { #[cfg(test)] mod tests { use super::*; + use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes}; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use surrealdb::RecordId; use uuid::Uuid; @@ -524,6 +523,46 @@ mod tests { assert_eq!(embedding.source_id, source_id); } + #[tokio::test] + async fn test_store_with_embedding_with_runtime_indexes() { + let namespace = "test_ns_runtime"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + + // Ensure runtime indexes are built with the expected dimension. 
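+        // Keep the dimension tiny so the embedding can be written inline below;
+        // the index definition and the stored vector must agree on this length,
+        // which the final assertion checks.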
+ let embedding_dimension = 3usize; + ensure_runtime_indexes(&db, embedding_dimension) + .await + .expect("ensure runtime indexes"); + + let chunk = TextChunk::new( + "runtime_src".to_string(), + "runtime chunk body".to_string(), + "runtime_user".to_string(), + ); + + TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) + .await + .expect("store with embedding"); + + let stored_chunk: Option = db.get_item(&chunk.id).await.unwrap(); + assert!(stored_chunk.is_some(), "chunk should be stored"); + + let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); + let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) + .await + .expect("get embedding"); + assert!(embedding.is_some(), "embedding should exist"); + assert_eq!( + embedding.unwrap().embedding.len(), + embedding_dimension, + "embedding dimension should match runtime index" + ); + } + #[tokio::test] async fn test_vector_search_returns_empty_when_no_embeddings() { let namespace = "test_ns"; @@ -625,7 +664,7 @@ mod tests { .expect("Failed to start in-memory surrealdb"); db.apply_migrations().await.expect("migrations"); ensure_chunk_fts_index(&db).await; - db.rebuild_indexes().await.expect("rebuild indexes"); + rebuild_indexes(&db).await.expect("rebuild indexes"); let results = TextChunk::fts_search(5, "hello", &db, "user") .await @@ -651,7 +690,7 @@ mod tests { user_id.to_string(), ); db.store_item(chunk.clone()).await.expect("store chunk"); - db.rebuild_indexes().await.expect("rebuild indexes"); + rebuild_indexes(&db).await.expect("rebuild indexes"); let results = TextChunk::fts_search(3, "rust", &db, user_id) .await @@ -698,7 +737,7 @@ mod tests { db.store_item(other_user_chunk) .await .expect("store other user chunk"); - db.rebuild_indexes().await.expect("rebuild indexes"); + rebuild_indexes(&db).await.expect("rebuild indexes"); let results = TextChunk::fts_search(3, "apple", &db, user_id) .await diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index b7813b8..63d4c67 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -20,8 +20,8 @@ use crate::{ #[allow(clippy::module_name_repetitions)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum EmbeddingBackend { - OpenAI, #[default] + OpenAI, FastEmbed, Hashed, } @@ -276,9 +276,7 @@ fn bucket(token: &str, dimension: usize) -> usize { let safe_dimension = dimension.max(1); let mut hasher = DefaultHasher::new(); token.hash(&mut hasher); - usize::try_from(hasher.finish()) - .unwrap_or_default() - % safe_dimension + usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension } // Backward compatibility function diff --git a/evaluations/src/types.rs b/evaluations/src/types.rs index 0a2d21d..b01e169 100644 --- a/evaluations/src/types.rs +++ b/evaluations/src/types.rs @@ -215,16 +215,30 @@ impl EvaluationCandidate { } } +fn candidates_from_entities(entities: Vec) -> Vec { + entities + .into_iter() + .map(EvaluationCandidate::from_entity) + .collect() +} + +fn candidates_from_chunks(chunks: Vec) -> Vec { + chunks + .into_iter() + .map(EvaluationCandidate::from_chunk) + .collect() +} + pub fn adapt_strategy_output(output: StrategyOutput) -> Vec { match output { - StrategyOutput::Entities(entities) => entities - .into_iter() - .map(EvaluationCandidate::from_entity) - .collect(), - StrategyOutput::Chunks(chunks) => chunks - .into_iter() - .map(EvaluationCandidate::from_chunk) - .collect(), + StrategyOutput::Entities(entities) => candidates_from_entities(entities), + 
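+        // Chunks map one-to-one; the Search variant below carries both lists,
+        // so it merges them and re-sorts by descending score.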
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
+        StrategyOutput::Search(search_result) => {
+            let mut candidates = candidates_from_entities(search_result.entities);
+            candidates.extend(candidates_from_chunks(search_result.chunks));
+            candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
+            candidates
+        }
     }
 }
diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs
index 5aec618..e6c5f9c 100644
--- a/html-router/src/routes/chat/message_response_stream.rs
+++ b/html-router/src/routes/chat/message_response_stream.rs
@@ -151,6 +151,10 @@ pub async fn get_response_stream(
         retrieval_pipeline::StrategyOutput::Entities(entities) => {
             retrieved_entities_to_json(&entities)
         }
+        retrieval_pipeline::StrategyOutput::Search(search_result) => {
+            // For chat, use chunks from the search result
+            chunks_to_chat_context(&search_result.chunks)
+        }
     };
     let formatted_user_message =
         create_user_message_with_history(&context_json, &history, &user_message.content);
diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs
index 910a447..eca8e38 100644
--- a/html-router/src/routes/search/handlers.rs
+++ b/html-router/src/routes/search/handlers.rs
@@ -1,17 +1,13 @@
-use std::{fmt, str::FromStr};
+use std::{fmt, str::FromStr, time::Duration};
 
 use axum::{
     extract::{Query, State},
     response::IntoResponse,
 };
-use common::storage::types::{
-    conversation::Conversation,
-    knowledge_entity::{KnowledgeEntity, KnowledgeEntitySearchResult},
-    text_content::{TextContent, TextContentSearchResult},
-    user::User,
-};
-use futures::future::try_join;
+use common::storage::types::{conversation::Conversation, user::User};
+use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
 use serde::{de, Deserialize, Deserializer, Serialize};
+use tokio::time::error::Elapsed;
 
 use crate::{
     html_state::HtmlState,
@@ -20,6 +16,7 @@ use crate::{
         response_middleware::{HtmlError, TemplateResponse},
     },
 };
+
 /// Serde deserialization decorator to map empty Strings to None,
 fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
 where
@@ -40,6 +37,26 @@ pub struct SearchParams {
     query: Option<String>,
 }
 
+/// Chunk result for template rendering
+#[derive(Serialize)]
+struct TextChunkForTemplate {
+    id: String,
+    source_id: String,
+    chunk: String,
+    score: f32,
+}
+
+/// Entity result for template rendering (from pipeline)
+#[derive(Serialize)]
+struct KnowledgeEntityForTemplate {
+    id: String,
+    name: String,
+    description: String,
+    entity_type: String,
+    source_id: String,
+    score: f32,
+}
+
 pub async fn search_result_handler(
     State(state): State<HtmlState>,
     Query(params): Query<SearchParams>,
@@ -50,9 +67,9 @@
         result_type: String,
         score: f32,
         #[serde(skip_serializing_if = "Option::is_none")]
-        text_content: Option<TextContentSearchResult>,
+        text_chunk: Option<TextChunkForTemplate>,
         #[serde(skip_serializing_if = "Option::is_none")]
-        knowledge_entity: Option<KnowledgeEntitySearchResult>,
+        knowledge_entity: Option<KnowledgeEntityForTemplate>,
     }
 
     #[derive(Serialize)]
@@ -70,37 +87,64 @@
     if trimmed_query.is_empty() {
         (Vec::<SearchResultForTemplate>::new(), String::new())
     } else {
-        const TOTAL_LIMIT: usize = 10;
-        let (text_results, entity_results) = try_join(
-            TextContent::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT),
-            KnowledgeEntity::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT),
+        // Use retrieval pipeline Search strategy
+        let config = RetrievalConfig::for_search(SearchTarget::Both);
+        let result = retrieval_pipeline::pipeline::run_pipeline(
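+            // Positional arguments follow the pipeline's signature; the two
+            // `None`s below opt out of a dedicated embedding provider and of
+            // reranking, as the inline notes indicate.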
&state.db, + &state.openai_client, + None, // No embedding provider in HtmlState + trimmed_query, + &user.id, + config, + None, // No reranker for now ) .await?; - let mut combined_results: Vec = - Vec::with_capacity(text_results.len() + entity_results.len()); + let search_result = match result { + StrategyOutput::Search(sr) => sr, + _ => SearchResult::new(vec![], vec![]), + }; - for text_result in text_results { - let score = text_result.score; + let mut combined_results: Vec = + Vec::with_capacity(search_result.chunks.len() + search_result.entities.len()); + + // Add chunk results + for chunk_result in search_result.chunks { combined_results.push(SearchResultForTemplate { - result_type: "text_content".to_string(), - score, - text_content: Some(text_result), + result_type: "text_chunk".to_string(), + score: chunk_result.score, + text_chunk: Some(TextChunkForTemplate { + id: chunk_result.chunk.id, + source_id: chunk_result.chunk.source_id, + chunk: chunk_result.chunk.chunk, + score: chunk_result.score, + }), knowledge_entity: None, }); } - for entity_result in entity_results { - let score = entity_result.score; + // Add entity results + for entity_result in search_result.entities { combined_results.push(SearchResultForTemplate { result_type: "knowledge_entity".to_string(), - score, - text_content: None, - knowledge_entity: Some(entity_result), + score: entity_result.score, + text_chunk: None, + knowledge_entity: Some(KnowledgeEntityForTemplate { + id: entity_result.entity.id, + name: entity_result.entity.name, + description: entity_result.entity.description, + entity_type: format!("{:?}", entity_result.entity.entity_type), + source_id: entity_result.entity.source_id, + score: entity_result.score, + }), }); } + // Sort by score descending combined_results.sort_by(|a, b| b.score.total_cmp(&a.score)); + + // Limit results + const TOTAL_LIMIT: usize = 10; combined_results.truncate(TOTAL_LIMIT); (combined_results, trimmed_query.to_string()) diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index e550322..61aa803 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -184,7 +184,8 @@ impl PipelineServices for DefaultPipelineServices { None => None, }; - let config = retrieval_pipeline::RetrievalConfig::for_ingestion(); + let config = + retrieval_pipeline::RetrievalConfig::for_search(retrieval_pipeline::SearchTarget::EntitiesOnly); match retrieval_pipeline::retrieve_entities( &self.db, &self.openai_client, @@ -197,6 +198,16 @@ impl PipelineServices for DefaultPipelineServices { .await { Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities), + Ok(retrieval_pipeline::StrategyOutput::Search(search)) => { + let chunk_count = search.chunks.len(); + let entities = search.entities; + tracing::debug!( + chunk_count, + entity_count = entities.len(), + "ingestion search results returned entities" + ); + Ok(entities) + } Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError( "Ingestion retrieval should return entities".into(), )), diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs b/ingestion-pipeline/src/pipeline/stages/mod.rs index ffe6ba8..20ceb6d 100644 --- a/ingestion-pipeline/src/pipeline/stages/mod.rs +++ b/ingestion-pipeline/src/pipeline/stages/mod.rs @@ -4,6 +4,7 @@ use common::{ error::AppError, storage::{ db::SurrealDbClient, + indexes::rebuild_indexes, types::{ ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, 
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, @@ -131,7 +132,7 @@ pub async fn enrich( }); return machine .enrich() - .map_err(|(_, guard)| map_guard_error("enrich", &guard)); + .map_err(|(_, guard)| map_guard_error("enrich", &guard)); } let content = ctx.text_content()?; @@ -173,18 +174,24 @@ pub async fn persist( let entity_count = entities.len(); let relationship_count = relationships.len(); - let ((), chunk_count) = tokio::try_join!( - store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships), - store_vector_chunks( - ctx.db, - ctx.task_id.as_str(), - &chunks, - &ctx.pipeline_config.tuning - ) - )?; + debug!("Were storing chunks"); + let chunk_count = store_vector_chunks( + ctx.db, + ctx.task_id.as_str(), + &chunks, + &ctx.pipeline_config.tuning, + ) + .await?; + + debug!("We stored chunks"); + store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?; + + debug!("Stored graph entities"); ctx.db.store_item(text_content).await?; - ctx.db.rebuild_indexes().await?; + + debug!("stored item"); + rebuild_indexes(ctx.db).await?; debug!( task_id = %ctx.task_id, @@ -268,17 +275,9 @@ async fn store_vector_chunks( let chunk_count = chunks.len(); let batch_size = tuning.chunk_insert_concurrency.max(1); - for embedded in chunks { - debug!( - task_id = %task_id, - chunk_id = %embedded.chunk.id, - chunk_len = embedded.chunk.chunk.chars().count(), - "chunk persisted" - ); - } for batch in chunks.chunks(batch_size) { - store_chunk_batch(db, batch, tuning).await?; + store_chunk_batch(db, batch, tuning, task_id).await?; } Ok(chunk_count) @@ -294,14 +293,25 @@ async fn store_chunk_batch( db: &SurrealDbClient, batch: &[EmbeddedTextChunk], _tuning: &super::config::IngestionTuning, + task_id: &str, ) -> Result<(), AppError> { if batch.is_empty() { return Ok(()); } for embedded in batch { - TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db) - .await?; + TextChunk::store_with_embedding( + embedded.chunk.to_owned(), + embedded.embedding.to_owned(), + db, + ) + .await?; + debug!( + task_id = %task_id, + chunk_id = %embedded.chunk.id, + chunk_len = embedded.chunk.chunk.chars().count(), + "chunk persisted" + ); } Ok(()) diff --git a/main/src/main.rs b/main/src/main.rs index 76be6a9..ca23d0e 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -2,7 +2,9 @@ use api_router::{api_routes_v1, api_state::ApiState}; use axum::{extract::FromRef, Router}; use common::{ storage::{ - db::SurrealDbClient, indexes::ensure_runtime_indexes, store::StorageManager, + db::SurrealDbClient, + indexes::ensure_runtime_indexes, + store::StorageManager, types::system_settings::SystemSettings, }, utils::config::get_config, @@ -112,7 +114,7 @@ async fn main() -> Result<(), Box> { .await .unwrap(), ); - let _settings = SystemSettings::get_current(&worker_db) + let settings = SystemSettings::get_current(&worker_db) .await .expect("failed to load system settings"); @@ -125,9 +127,12 @@ async fn main() -> Result<(), Box> { // Create embedding provider for ingestion let embedding_provider = Arc::new( - common::utils::embedding::EmbeddingProvider::new_fastembed(None) - .await - .expect("failed to create embedding provider"), + common::utils::embedding::EmbeddingProvider::new_openai( + openai_client.clone(), + settings.embedding_model, + settings.embedding_dimensions, + ) + .expect("failed to create embedding provider"), ); let ingestion_pipeline = Arc::new( IngestionPipeline::new(