diff --git a/common/migrations/20260528_000003_text_chunk_embedding_storage_hardening.surql b/common/migrations/20260528_000003_text_chunk_embedding_storage_hardening.surql new file mode 100644 index 0000000..3fcb7ff --- /dev/null +++ b/common/migrations/20260528_000003_text_chunk_embedding_storage_hardening.surql @@ -0,0 +1,21 @@ +-- Harden text chunk embeddings storage invariants. + +-- Re-key embeddings so record id matches chunk id (stable 1:1 identity). +FOR $emb IN (SELECT * FROM text_chunk_embedding) { + LET $chunk_key = record::id($emb.chunk_id); + LET $canonical = type::thing('text_chunk_embedding', $chunk_key); + IF $emb.id != $canonical { + UPSERT $canonical CONTENT { + chunk_id: $emb.chunk_id, + embedding: $emb.embedding, + user_id: $emb.user_id, + source_id: $emb.source_id, + created_at: $emb.created_at, + updated_at: $emb.updated_at + }; + DELETE $emb.id; + } +}; + +REMOVE INDEX IF EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE; diff --git a/common/migrations/definitions/20260528_000003_text_chunk_embedding_storage_hardening.json b/common/migrations/definitions/20260528_000003_text_chunk_embedding_storage_hardening.json new file mode 100644 index 0000000..563d558 --- /dev/null +++ b/common/migrations/definitions/20260528_000003_text_chunk_embedding_storage_hardening.json @@ -0,0 +1 @@ +{"schemas":"--- original\n+++ modified\n@@ -237,7 +237,7 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;\n+DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;\n\n","events":null} \ No newline at end of file diff --git a/common/schemas/text_chunk_embedding.surql b/common/schemas/text_chunk_embedding.surql index 5a43d55..d2b4999 100644 --- a/common/schemas/text_chunk_embedding.surql +++ b/common/schemas/text_chunk_embedding.surql @@ -15,6 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array; -- Indexes -- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536; -DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE; DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id; DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id; diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 3b4ea92..b158d4e 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -2,6 +2,8 @@ use std::collections::HashMap; use std::fmt::Write; +use crate::storage::indexes::hnsw_index_overwrite_sql; +use crate::storage::types::system_settings::SystemSettings; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use async_openai::{config::OpenAIConfig, Client}; @@ -10,7 +12,7 @@ use tokio_retry::{ Retry, }; -use tracing::{error, info}; +use tracing::{error, info, warn}; use uuid::Uuid; stored_object!(TextChunk, "text_chunk", { @@ -45,14 +47,17 @@ impl TextChunk { source_id: &str, db_client: &SurrealDbClient, ) -> Result<(), AppError> { - // Delete embeddings first - TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?; - db_client .client - .query("DELETE FROM type::table($table) WHERE source_id = $source_id") - .bind(("table", Self::table_name())) + .query("BEGIN TRANSACTION;") + .query(format!( + "DELETE FROM {} WHERE source_id = $source_id;", + TextChunkEmbedding::table_name() + )) + .query("DELETE FROM type::table($table) WHERE source_id = $source_id;") + .query("COMMIT TRANSACTION;") .bind(("source_id", source_id.to_owned())) + .bind(("table", Self::table_name())) .await .map_err(AppError::Database)? .check() @@ -68,34 +73,41 @@ impl TextChunk { embedding: Vec, db: &SurrealDbClient, ) -> Result<(), AppError> { + let settings = SystemSettings::get_current(db).await?; + TextChunkEmbedding::validate_dimension( + &embedding, + settings.embedding_dimensions as usize, + )?; + let chunk_id = chunk.id.clone(); - let source_id = chunk.source_id.clone(); - let user_id = chunk.user_id.clone(); + let emb = TextChunkEmbedding::new( + &chunk_id, + chunk.source_id.clone(), + embedding, + chunk.user_id.clone(), + ); - let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone()); + let query = format!( + " + BEGIN TRANSACTION; + CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk; + UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb; + COMMIT TRANSACTION; + ", + chunk_table = Self::table_name(), + emb_table = TextChunkEmbedding::table_name(), + ); - // Create both records in a single transaction so we don't orphan embeddings or chunks - let response = db - .client - .query("BEGIN TRANSACTION;") - .query(format!( - "CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;", - chunk_table = Self::table_name(), - )) - .query(format!( - "CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;", - emb_table = TextChunkEmbedding::table_name(), - )) - .query("COMMIT TRANSACTION;") - .bind(("chunk_id", chunk_id.clone())) + db.client + .query(query) + .bind(("chunk_id", chunk_id)) .bind(("chunk", chunk)) - .bind(("emb_id", emb.id.clone())) .bind(("emb", emb)) .await + .map_err(AppError::Database)? + .check() .map_err(AppError::Database)?; - response.check().map_err(AppError::Database)?; - Ok(()) } @@ -147,6 +159,9 @@ impl TextChunk { r.chunk_id.map(|chunk| TextChunkSearchResult { chunk, score: r.score, + }).or_else(|| { + warn!("vector search hit orphaned text_chunk_embedding row with missing chunk"); + None }) }) .collect()) @@ -296,22 +311,32 @@ impl TextChunk { let embedding = serde_json::to_string(&embedding).map_err(|e| { AppError::internal(format!("embedding serialization failed: {e}")) })?; + let id = surql_json_string(&id)?; + let user_id = surql_json_string(&user_id)?; + let source_id = surql_json_string(&source_id)?; write!( &mut transaction_query, - "UPSERT type::thing('text_chunk_embedding', '{id}') SET \ - chunk_id = type::thing('text_chunk', '{id}'), \ - source_id = '{source_id}', \ + "UPSERT type::thing('{emb_table}', {id}) SET \ + chunk_id = type::thing('{chunk_table}', {id}), \ + source_id = {source_id}, \ embedding = {embedding}, \ - user_id = '{user_id}', \ + user_id = {user_id}, \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ updated_at = time::now();", + emb_table = TextChunkEmbedding::table_name(), + chunk_table = Self::table_name(), ) .map_err(AppError::internal)?; } write!( &mut transaction_query, - "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", + "{}", + hnsw_index_overwrite_sql( + "idx_embedding_text_chunk_embedding", + TextChunkEmbedding::table_name(), + new_dimensions as usize, + ) ) .map_err(AppError::internal)?; @@ -408,22 +433,32 @@ impl TextChunk { let embedding = serde_json::to_string(&embedding).map_err(|e| { AppError::internal(format!("embedding serialization failed: {e}")) })?; + let id = surql_json_string(&id)?; + let user_id = surql_json_string(&user_id)?; + let source_id = surql_json_string(&source_id)?; write!( &mut transaction_query, - "CREATE type::thing('text_chunk_embedding', '{id}') SET \ - chunk_id = type::thing('text_chunk', '{id}'), \ - source_id = '{source_id}', \ + "CREATE type::thing('{emb_table}', {id}) SET \ + chunk_id = type::thing('{chunk_table}', {id}), \ + source_id = {source_id}, \ embedding = {embedding}, \ - user_id = '{user_id}', \ + user_id = {user_id}, \ created_at = time::now(), \ updated_at = time::now();", + emb_table = TextChunkEmbedding::table_name(), + chunk_table = Self::table_name(), ) .map_err(AppError::internal)?; } write!( &mut transaction_query, - "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", + "{}", + hnsw_index_overwrite_sql( + "idx_embedding_text_chunk_embedding", + TextChunkEmbedding::table_name(), + new_dimensions, + ) ) .map_err(AppError::internal)?; @@ -441,6 +476,12 @@ impl TextChunk { } } +#[allow(clippy::result_large_err)] +fn surql_json_string(value: &str) -> Result { + serde_json::to_string(value) + .map_err(|e| AppError::internal(format!("string serialization failed: {e}"))) +} + #[cfg(test)] mod tests { #![allow(clippy::expect_used, clippy::must_use_candidate)] @@ -448,10 +489,21 @@ mod tests { use super::*; use crate::storage::indexes::{ensure_runtime, rebuild}; + use crate::storage::types::system_settings::SystemSettings; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use surrealdb::RecordId; use uuid::Uuid; + async fn configure_embedding_dimension( + db: &SurrealDbClient, + dimension: u32, + ) -> anyhow::Result<()> { + let mut settings = SystemSettings::get_current(db).await?; + settings.embedding_dimensions = dimension; + SystemSettings::update(db, settings).await?; + Ok(()) + } + async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> { let snowball_sql = r#" DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); @@ -500,6 +552,7 @@ mod tests { let source_id = "source123".to_string(); let user_id = "user123".to_string(); + configure_embedding_dimension(&db, 5).await?; TextChunkEmbedding::redefine_hnsw_index(&db, 5) .await .with_context(|| "redefine index".to_string())?; @@ -578,6 +631,7 @@ mod tests { db.apply_migrations() .await .with_context(|| "migrations".to_string())?; + configure_embedding_dimension(&db, 5).await?; TextChunkEmbedding::redefine_hnsw_index(&db, 5) .await .with_context(|| "redefine index".to_string())?; @@ -619,6 +673,7 @@ mod tests { .await .expect("Failed to start in-memory surrealdb"); db.apply_migrations().await.expect("migrations"); + configure_embedding_dimension(&db, 5).await.expect("configure dim"); TextChunkEmbedding::redefine_hnsw_index(&db, 5) .await .expect("redefine index"); @@ -677,6 +732,7 @@ mod tests { let user_id = "user_store".to_string(); let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone()); + configure_embedding_dimension(&db, 3).await?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "redefine index".to_string())?; @@ -699,6 +755,7 @@ mod tests { .with_context(|| "get embedding".to_string())? .with_context(|| "expected embedding".to_string())?; assert_eq!(embedding.chunk_id, rid); + assert_eq!(embedding.id, chunk.id); assert_eq!(embedding.user_id, user_id); assert_eq!(embedding.source_id, source_id); Ok(()) @@ -716,6 +773,11 @@ mod tests { .with_context(|| "migrations".to_string())?; let embedding_dimension = 3usize; + configure_embedding_dimension( + &db, + u32::try_from(embedding_dimension).expect("test embedding dimension fits in u32"), + ) + .await?; ensure_runtime(&db, embedding_dimension) .await .with_context(|| "ensure runtime indexes".to_string())?; @@ -761,6 +823,7 @@ mod tests { .await .with_context(|| "migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "redefine index".to_string())?; @@ -784,6 +847,7 @@ mod tests { .await .with_context(|| "migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "redefine index".to_string())?; @@ -824,6 +888,7 @@ mod tests { .await .with_context(|| "migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "redefine index".to_string())?; @@ -973,4 +1038,77 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> { + let namespace = "test_ns_dim"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations() + .await + .with_context(|| "migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; + + let chunk = TextChunk::new( + "src".to_string(), + "body".to_string(), + "user".to_string(), + ); + + let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db) + .await + .expect_err("expected dimension validation failure"); + assert!(matches!(err, AppError::Validation(_))); + + Ok(()) + } + + #[tokio::test] + async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> { + let namespace = "test_ns_orphan_chunk"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations() + .await + .with_context(|| "migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; + TextChunkEmbedding::redefine_hnsw_index(&db, 3) + .await + .with_context(|| "redefine index".to_string())?; + + let user_id = "user".to_string(); + let chunk = TextChunk::new( + "src".to_string(), + "orphan chunk".to_string(), + user_id.clone(), + ); + + TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) + .await + .with_context(|| "store chunk with embedding".to_string())?; + + db.client + .query(format!( + "DELETE type::thing('{table}', $id);", + table = TextChunk::table_name() + )) + .bind(("id", chunk.id.clone())) + .await + .with_context(|| "delete chunk".to_string())?; + + let results = TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + .await + .with_context(|| "search should succeed even with orphans".to_string())?; + + assert!( + results.is_empty(), + "should return empty result for orphan, got: {results:?}" + ); + + Ok(()) + } } diff --git a/common/src/storage/types/text_chunk_embedding.rs b/common/src/storage/types/text_chunk_embedding.rs index 35545e0..69b45e0 100644 --- a/common/src/storage/types/text_chunk_embedding.rs +++ b/common/src/storage/types/text_chunk_embedding.rs @@ -1,7 +1,11 @@ use surrealdb::RecordId; use crate::storage::types::text_chunk::TextChunk; -use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; +use crate::{ + error::AppError, + storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql}, + stored_object, +}; stored_object!(TextChunkEmbedding, "text_chunk_embedding", { /// Record link to the owning text_chunk @@ -23,12 +27,10 @@ impl TextChunkEmbedding { db: &SurrealDbClient, dimension: usize, ) -> Result<(), AppError> { - let query = format!( - "BEGIN TRANSACTION; - REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table}; - DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension}; - COMMIT TRANSACTION;", - table = Self::table_name(), + let query = hnsw_index_redefine_transaction_sql( + "idx_embedding_text_chunk_embedding", + Self::table_name(), + dimension, ); let res = db.client.query(query).await.map_err(AppError::Database)?; @@ -37,20 +39,30 @@ impl TextChunkEmbedding { Ok(()) } - /// Create a new text chunk embedding + /// Validates that an embedding vector matches the configured HNSW dimension. + #[allow(clippy::result_large_err)] + pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> { + if embedding.len() != expected { + return Err(AppError::Validation(format!( + "embedding dimension mismatch: got {}, expected {expected}", + embedding.len() + ))); + } + Ok(()) + } + + /// Create a new text chunk embedding. /// - /// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), - /// not "text_chunk:uuid". + /// The embedding record id equals `chunk_id` so each chunk has at most one embedding row. + /// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid". #[must_use] pub fn new(chunk_id: &str, source_id: String, embedding: Vec, user_id: String) -> Self { let now = Utc::now(); Self { - // NOTE: `stored_object!` macro defines `id` as `String` - id: uuid::Uuid::new_v4().to_string(), + id: chunk_id.to_owned(), created_at: now, updated_at: now, - // Create a record link: text_chunk: chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id), source_id, embedding, @@ -132,10 +144,10 @@ mod tests { use super::*; use crate::storage::db::SurrealDbClient; + use crate::storage::types::system_settings::SystemSettings; use surrealdb::Value as SurrealValue; use uuid::Uuid; - /// Helper to create an in-memory DB and apply migrations async fn setup_test_db() -> anyhow::Result { let namespace = "test_ns"; let database = Uuid::new_v4().to_string(); @@ -150,7 +162,24 @@ mod tests { Ok(db) } - /// Helper: create a text_chunk with a known key, return its RecordId + async fn setup_test_db_with_embedding_dimension( + dimension: u32, + ) -> anyhow::Result { + let db = setup_test_db().await?; + let mut settings = SystemSettings::get_current(&db).await?; + settings.embedding_dimensions = dimension; + SystemSettings::update(&db, settings).await?; + Ok(db) + } + + async fn prepare_test_db(dimension: u32) -> anyhow::Result { + let db = setup_test_db_with_embedding_dimension(dimension).await?; + TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize) + .await + .with_context(|| format!("set test index dimension to {dimension}"))?; + Ok(db) + } + async fn create_text_chunk_with_id( db: &SurrealDbClient, key: &str, @@ -196,18 +225,34 @@ mod tests { Ok(idx_sql) } + #[test] + fn new_uses_chunk_id_as_record_id() { + let emb = TextChunkEmbedding::new( + "chunk-abc", + "source-1".to_owned(), + vec![0.1, 0.2], + "user-1".to_owned(), + ); + assert_eq!(emb.id, "chunk-abc"); + } + + #[test] + fn validate_dimension_rejects_mismatch() { + let err = TextChunkEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2) + .expect_err("expected dimension mismatch"); + assert!(matches!(err, AppError::Validation(_))); + } + #[tokio::test] async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = prepare_test_db(3).await?; let user_id = "user_a"; let chunk_key = "chunk-123"; let source_id = "source-1"; - // 1) Create a text_chunk with a known key let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?; - // 2) Create and store an embedding for that chunk let embedding_vec = vec![0.1_f32, 0.2, 0.3]; let emb = TextChunkEmbedding::new( chunk_key, @@ -216,24 +261,16 @@ mod tests { user_id.to_string(), ); - TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len()) + db.upsert_item(emb) .await - .with_context(|| "Failed to redefine index length".to_string())?; + .with_context(|| "Failed to store embedding".to_string())?; - let _: Option = db - .client - .create(TextChunkEmbedding::table_name()) - .content(emb) - .await - .with_context(|| "Failed to store embedding".to_string())? - .with_context(|| "Failed to deserialize stored embedding".to_string())?; - - // 3) Fetch it via get_by_chunk_id let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) .await .with_context(|| "Failed to get embedding by chunk_id".to_string())? .with_context(|| "Expected an embedding to be found".to_string())?; + assert_eq!(fetched.id, chunk_key); assert_eq!(fetched.user_id, user_id); assert_eq!(fetched.chunk_id, chunk_rid); assert_eq!(fetched.embedding, embedding_vec); @@ -242,7 +279,7 @@ mod tests { #[tokio::test] async fn test_delete_by_chunk_id() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = prepare_test_db(3).await?; let user_id = "user_b"; let chunk_key = "chunk-delete"; @@ -257,30 +294,19 @@ mod tests { user_id.to_string(), ); - TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len()) + db.upsert_item(emb) .await - .with_context(|| "Failed to redefine index length".to_string())?; + .with_context(|| "Failed to store embedding".to_string())?; - let _: Option = db - .client - .create(TextChunkEmbedding::table_name()) - .content(emb) - .await - .with_context(|| "Failed to store embedding".to_string())? - .with_context(|| "Failed to deserialize stored embedding".to_string())?; - - // Ensure it exists let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) .await .with_context(|| "Failed to get embedding before delete".to_string())?; assert!(existing.is_some(), "Embedding should exist before delete"); - // Delete by chunk_id TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db) .await .with_context(|| "Failed to delete by chunk_id".to_string())?; - // Ensure it no longer exists let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) .await .with_context(|| "Failed to get embedding after delete".to_string())?; @@ -290,56 +316,33 @@ mod tests { #[tokio::test] async fn test_delete_by_source_id() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = prepare_test_db(1).await?; let user_id = "user_c"; let source_id = "shared-source"; let other_source = "other-source"; - // Two chunks with the same source_id let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?; let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?; - - // One chunk with a different source_id let chunk_other_rid = create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?; - // Create embeddings for all three - let emb1 = TextChunkEmbedding::new( - "chunk-s1", - source_id.to_string(), - vec![0.1], - user_id.to_string(), - ); - let emb2 = TextChunkEmbedding::new( - "chunk-s2", - source_id.to_string(), - vec![0.2], - user_id.to_string(), - ); - let emb3 = TextChunkEmbedding::new( - "chunk-other", - other_source.to_string(), - vec![0.3], - user_id.to_string(), - ); - - // Update length on index - TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len()) - .await - .with_context(|| "Failed to redefine index length".to_string())?; - - for emb in [emb1, emb2, emb3] { - let _: Option = db - .client - .create(TextChunkEmbedding::table_name()) - .content(emb) + for (key, src, vec) in [ + ("chunk-s1", source_id, vec![0.1]), + ("chunk-s2", source_id, vec![0.2]), + ("chunk-other", other_source, vec![0.3]), + ] { + let emb = TextChunkEmbedding::new( + key, + src.to_string(), + vec, + user_id.to_string(), + ); + db.upsert_item(emb) .await - .with_context(|| "Failed to store embedding".to_string())? - .with_context(|| "Failed to deserialize stored embedding".to_string())?; + .with_context(|| format!("store embedding for {key}"))?; } - // Sanity check: they all exist assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) .await .with_context(|| "get chunk1".to_string())? @@ -353,12 +356,10 @@ mod tests { .with_context(|| "get chunk_other".to_string())? .is_some()); - // Delete embeddings by source_id (shared-source) TextChunkEmbedding::delete_by_source_id(source_id, &db) .await .with_context(|| "Failed to delete by source_id".to_string())?; - // Chunks from shared-source should have no embeddings assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) .await .with_context(|| "check chunk1".to_string())? @@ -367,8 +368,6 @@ mod tests { .await .with_context(|| "check chunk2".to_string())? .is_none()); - - // The other chunk should still have its embedding assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) .await .with_context(|| "check chunk_other".to_string())? @@ -376,11 +375,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> { + let db = prepare_test_db(3).await?; + + let user_id = "user-upsert"; + let source_id = "source-upsert"; + let chunk_key = "chunk-upsert"; + + create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?; + + let initial = TextChunkEmbedding::new( + chunk_key, + source_id.to_owned(), + vec![1.0_f32, 0.0, 0.0], + user_id.to_owned(), + ); + db.upsert_item(initial) + .await + .with_context(|| "initial upsert".to_string())?; + + let replacement = TextChunkEmbedding::new( + chunk_key, + source_id.to_owned(), + vec![0.0, 1.0, 0.0], + user_id.to_owned(), + ); + db.upsert_item(replacement) + .await + .with_context(|| "upsert replacement embedding".to_string())?; + + let chunk_rid = RecordId::from_table_key(TextChunk::table_name(), chunk_key); + let rows: Vec = db + .client + .query(format!( + "SELECT * FROM {} WHERE chunk_id = $chunk_id", + TextChunkEmbedding::table_name() + )) + .bind(("chunk_id", chunk_rid)) + .await + .with_context(|| "count embeddings".to_string())? + .take(0) + .with_context(|| "take embeddings".to_string())?; + + assert_eq!(rows.len(), 1); + let row = rows.first().expect("expected one embedding row"); + assert_eq!(row.id, chunk_key); + assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]); + + Ok(()) + } + #[tokio::test] async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> { let db = setup_test_db().await?; - // Change the index dimension from default (1536) to a smaller test value. TextChunkEmbedding::redefine_hnsw_index(&db, 8) .await .with_context(|| "failed to redefine index".to_string())?; @@ -391,6 +440,10 @@ mod tests { idx_sql.contains("DIMENSION 8"), "expected index definition to contain new dimension, got: {idx_sql}" ); + assert!( + idx_sql.contains("DIST COSINE"), + "expected index definition to use cosine distance, got: {idx_sql}" + ); Ok(()) } diff --git a/common/src/storage/types/text_content.rs b/common/src/storage/types/text_content.rs index 25c8aef..27326dd 100644 --- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -101,7 +101,7 @@ impl TextContent { ) -> Result<(), AppError> { let now = Utc::now(); - let _res: Option = db + let updated: Option = db .update((Self::table_name(), id)) .patch(PatchOp::replace("/context", context)) .patch(PatchOp::replace("/category", category)) @@ -110,7 +110,14 @@ impl TextContent { "/updated_at", surrealdb::Datetime::from(now), )) - .await?; + .await + .map_err(AppError::Database)?; + + if updated.is_none() { + return Err(AppError::NotFound(format!( + "text content {id} not found" + ))); + } Ok(()) } @@ -128,9 +135,10 @@ impl TextContent { .bind(("table_name", TextContent::table_name())) .bind(("file_id", file_id.to_owned())) .bind(("exclude_id", exclude_id.to_owned())) - .await?; + .await + .map_err(AppError::Database)?; - let existing: Option = response.take(0)?; + let existing: Option = response.take(0).map_err(AppError::Database)?; Ok(existing.is_some()) } @@ -141,7 +149,8 @@ impl TextContent { user_id: &str, limit: usize, ) -> Result, AppError> { - let sql = r#" + let sql = format!( + r#" SELECT *, search::highlight('', '', 0) AS highlighted_text, @@ -158,7 +167,7 @@ impl TextContent { IF search::score(4) != NONE THEN search::score(4) ELSE 0 END + IF search::score(5) != NONE THEN search::score(5) ELSE 0 END ) AS score - FROM text_content + FROM {table} WHERE ( text @0@ $terms OR @@ -171,16 +180,19 @@ impl TextContent { AND user_id = $user_id ORDER BY score DESC LIMIT $limit; - "#; + "#, + table = Self::table_name(), + ); - Ok(db - .client + db.client .query(sql) .bind(("terms", search_terms.to_owned())) .bind(("user_id", user_id.to_owned())) .bind(("limit", limit)) - .await? - .take(0)?) + .await + .map_err(AppError::Database)? + .take(0) + .map_err(AppError::Database) } } @@ -190,6 +202,25 @@ mod tests { use anyhow::{self, Context}; use super::*; + use crate::storage::indexes::{ensure_runtime, rebuild}; + + async fn setup_test_db() -> anyhow::Result { + let namespace = "test_ns"; + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, &database) + .await + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations() + .await + .with_context(|| "Failed to apply migrations".to_string())?; + ensure_runtime(&db, 1536) + .await + .with_context(|| "ensure runtime indexes".to_string())?; + rebuild(&db) + .await + .with_context(|| "rebuild indexes".to_string())?; + Ok(db) + } #[tokio::test] async fn test_text_content_creation() -> anyhow::Result<()> { @@ -306,6 +337,18 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_text_content_patch_not_found() -> anyhow::Result<()> { + let db = setup_test_db().await?; + + let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db) + .await + .expect_err("expected not found"); + + assert!(matches!(err, AppError::NotFound(_))); + Ok(()) + } + #[tokio::test] async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> { let namespace = "test_ns"; @@ -366,4 +409,60 @@ mod tests { assert!(!has_other_after); Ok(()) } + + #[tokio::test] + async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> { + let db = setup_test_db().await?; + + let results = TextContent::search(&db, "hello", "user", 5) + .await + .with_context(|| "search".to_string())?; + + assert!(results.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> { + let db = setup_test_db().await?; + let user_id = "search_user"; + + let matching = TextContent::new( + "rust programming language".to_string(), + Some("context".to_string()), + "notes".to_string(), + None, + None, + user_id.to_string(), + ); + let other_user = TextContent::new( + "rust programming language".to_string(), + None, + "notes".to_string(), + None, + None, + "other_user".to_string(), + ); + + db.store_item(matching.clone()) + .await + .with_context(|| "store matching".to_string())?; + db.store_item(other_user) + .await + .with_context(|| "store other user".to_string())?; + rebuild(&db) + .await + .with_context(|| "rebuild indexes".to_string())?; + + let results = TextContent::search(&db, "rust", user_id, 5) + .await + .with_context(|| "search".to_string())?; + + assert_eq!(results.len(), 1); + let row = results.first().context("expected one result")?; + assert_eq!(row.id, matching.id); + assert_eq!(row.user_id, user_id); + assert!(row.score.is_finite()); + Ok(()) + } }