diff --git a/Cargo.lock b/Cargo.lock index 37523aa..5234710 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3291,7 +3291,7 @@ checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" [[package]] name = "main" -version = "0.2.1" +version = "0.2.2" dependencies = [ "anyhow", "api-router", diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 131ad6c..87bc8bd 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -150,7 +150,18 @@ impl KnowledgeEntity { let all_entities: Vec = db.select(Self::table_name()).await?; let total_entities = all_entities.len(); if total_entities == 0 { - info!("No knowledge entities to update. Skipping."); + info!("No knowledge entities to update. Just updating the idx"); + + let mut transaction_query = String::from("BEGIN TRANSACTION;"); + transaction_query + .push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;"); + transaction_query.push_str(&format!( + "DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};", + new_dimensions + )); + transaction_query.push_str("COMMIT TRANSACTION;"); + + db.query(transaction_query).await?; return Ok(()); } info!("Found {} entities to process.", total_entities); diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index 04c9e0b..77112ca 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -53,11 +53,60 @@ impl SystemSettings { #[cfg(test)] mod tests { - use crate::storage::types::text_chunk::TextChunk; + use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}; + use async_openai::Client; use super::*; use uuid::Uuid; + async fn get_hnsw_index_dimension( + db: &SurrealDbClient, + table_name: &str, + index_name: &str, + ) -> u32 { + let query = format!("INFO FOR TABLE {table_name};"); + let mut response = db + .client + .query(query) + .await + .expect("Failed to fetch table info"); + + let info: Option = response + .take(0) + .expect("Failed to extract table info response"); + + let info = info.expect("Table info result missing"); + + let indexes = info + .get("indexes") + .or_else(|| { + info.get("tables") + .and_then(|tables| tables.get(table_name)) + .and_then(|table| table.get("indexes")) + }) + .unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}")); + + let definition = indexes + .get(index_name) + .and_then(|definition| definition.as_str()) + .unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}")); + + let dimension_part = definition + .split("DIMENSION") + .nth(1) + .expect("Index definition missing DIMENSION clause"); + + let dimension_token = dimension_part + .split_whitespace() + .next() + .expect("Dimension value missing in definition") + .trim_end_matches(';'); + + dimension_token + .parse::() + .expect("Dimension value is not a valid number") + } + #[tokio::test] async fn test_settings_initialization() { // Setup in-memory database for testing @@ -255,4 +304,74 @@ mod tests { assert!(migration_result.is_ok(), "Migrations should not fail"); } + + #[tokio::test] + async fn test_should_change_embedding_length_on_indexes_when_switching_length() { + let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string()) + .await + .expect("Failed to start DB"); + + // Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536. + db.apply_migrations() + .await + .expect("Initial migration failed"); + + let mut current_settings = SystemSettings::get_current(&db) + .await + .expect("Failed to load current settings"); + + let initial_chunk_dimension = + get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await; + + assert_eq!( + initial_chunk_dimension, current_settings.embedding_dimensions, + "embedding size should match initial system settings" + ); + + let new_dimension = 768; + let new_model = "new-test-embedding-model".to_string(); + + current_settings.embedding_dimensions = new_dimension; + current_settings.embedding_model = new_model.clone(); + + let updated_settings = SystemSettings::update(&db, current_settings) + .await + .expect("Failed to update settings"); + + assert_eq!( + updated_settings.embedding_dimensions, new_dimension, + "Settings should reflect the new embedding dimension" + ); + + let openai_client = Client::new(); + + TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) + .await + .expect("TextChunk re-embedding should succeed on fresh DB"); + KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) + .await + .expect("KnowledgeEntity re-embedding should succeed on fresh DB"); + + let text_chunk_dimension = + get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await; + let knowledge_dimension = + get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await; + + assert_eq!( + text_chunk_dimension, new_dimension, + "text_chunk index dimension should update" + ); + assert_eq!( + knowledge_dimension, new_dimension, + "knowledge_entity index dimension should update" + ); + + let persisted_settings = SystemSettings::get_current(&db) + .await + .expect("Failed to reload updated settings"); + assert_eq!( + persisted_settings.embedding_dimensions, new_dimension, + "Settings should persist new embedding dimension" + ); + } } diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 0d498ab..11a574a 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -68,7 +68,17 @@ impl TextChunk { let all_chunks: Vec = db.select(Self::table_name()).await?; let total_chunks = all_chunks.len(); if total_chunks == 0 { - info!("No text chunks to update. Skipping."); + info!("No text chunks to update. Just updating the idx"); + + let mut transaction_query = String::from("BEGIN TRANSACTION;"); + transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;"); + transaction_query.push_str(&format!( + "DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};", + new_dimensions)); + transaction_query.push_str("COMMIT TRANSACTION;"); + + db.query(transaction_query).await?; + return Ok(()); } info!("Found {} chunks to process.", total_chunks);