From 5724f11dc1b25976e6456e5b6fe7340351733a50 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Thu, 28 May 2026 21:46:35 +0200 Subject: [PATCH] chore: harden knowledge graph storage and clear common clippy warnings Enforce stable 1:1 entity embeddings, relationship endpoint auth, and user-scoped deletes; align schemas/migrations and resolve common crate clippy findings. --- ...02_knowledge_graph_storage_hardening.surql | 33 +++ ...002_knowledge_graph_storage_hardening.json | 1 + common/schemas/knowledge_entity.surql | 3 +- .../schemas/knowledge_entity_embedding.surql | 4 +- common/schemas/relates_to.surql | 16 +- common/src/storage/indexes.rs | 29 ++- common/src/storage/types/conversation.rs | 27 +- common/src/storage/types/file_info.rs | 10 +- common/src/storage/types/ingestion_payload.rs | 10 +- common/src/storage/types/ingestion_task.rs | 14 +- common/src/storage/types/knowledge_entity.rs | 144 ++++++---- .../types/knowledge_entity_embedding.rs | 199 ++++++++++---- .../storage/types/knowledge_relationship.rs | 245 ++++++++++++------ common/src/storage/types/message.rs | 2 +- evaluations/src/corpus/store.rs | 1 + html-router/src/routes/index/handlers.rs | 2 +- ingestion-pipeline/src/pipeline/tests.rs | 2 +- 17 files changed, 533 insertions(+), 209 deletions(-) create mode 100644 common/migrations/20260528_000002_knowledge_graph_storage_hardening.surql create mode 100644 common/migrations/definitions/20260528_000002_knowledge_graph_storage_hardening.json diff --git a/common/migrations/20260528_000002_knowledge_graph_storage_hardening.surql b/common/migrations/20260528_000002_knowledge_graph_storage_hardening.surql new file mode 100644 index 0000000..309d219 --- /dev/null +++ b/common/migrations/20260528_000002_knowledge_graph_storage_hardening.surql @@ -0,0 +1,33 @@ +-- Harden knowledge entity embeddings and graph storage invariants. + +DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string; + +-- Backfill denormalized source_id from the linked entity. +FOR $emb IN (SELECT * FROM knowledge_entity_embedding WHERE source_id = NONE OR source_id = '') { + LET $entity = (SELECT source_id FROM $emb.entity_id)[0]; + IF $entity != NONE { + UPDATE $emb.id SET source_id = $entity.source_id; + } +}; + +-- Re-key embeddings so record id matches entity id (stable 1:1 identity). +FOR $emb IN (SELECT * FROM knowledge_entity_embedding) { + LET $entity_key = record::id($emb.entity_id); + LET $canonical = type::thing('knowledge_entity_embedding', $entity_key); + IF $emb.id != $canonical { + UPSERT $canonical CONTENT { + entity_id: $emb.entity_id, + embedding: $emb.embedding, + user_id: $emb.user_id, + source_id: $emb.source_id, + created_at: $emb.created_at, + updated_at: $emb.updated_at + }; + DELETE $emb.id; + } +}; + +REMOVE INDEX IF EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding; +DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE; +DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id; +DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id; diff --git a/common/migrations/definitions/20260528_000002_knowledge_graph_storage_hardening.json b/common/migrations/definitions/20260528_000002_knowledge_graph_storage_hardening.json new file mode 100644 index 0000000..f02b4e6 --- /dev/null +++ b/common/migrations/definitions/20260528_000002_knowledge_graph_storage_hardening.json @@ -0,0 +1 @@ +{"schemas":"--- original\n+++ modified\n@@ -68,7 +68,7 @@\n\n # Defines the schema for the 'knowledge_entity' table.\n\n-DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n@@ -90,6 +90,7 @@\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n@@ -102,6 +103,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n+DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;\n\n -- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record;\n@@ -109,8 +111,9 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;\n\n # Defines the schema for the 'message' table.\n\n@@ -135,19 +138,17 @@\n # Defines the 'relates_to' edge table for KnowledgeRelationships.\n # Edges connect nodes, in this case knowledge_entity records.\n\n-# Define the edge table itself, enforcing connections between knowledge_entity records\n-# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\n-DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+\n+DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record;\n+DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record;\n\n-# Define the metadata field within the edge\n # RelationshipMetadata is a struct, store as object\n DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n+DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n-# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n-# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n-\n # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n","events":null} \ No newline at end of file diff --git a/common/schemas/knowledge_entity.surql b/common/schemas/knowledge_entity.surql index 6c6be77..9894b6b 100644 --- a/common/schemas/knowledge_entity.surql +++ b/common/schemas/knowledge_entity.surql @@ -1,6 +1,6 @@ # Defines the schema for the 'knowledge_entity' table. -DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS; +DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL; # Standard fields DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime; @@ -22,5 +22,6 @@ DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string; -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536; DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id; +DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type; DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; diff --git a/common/schemas/knowledge_entity_embedding.surql b/common/schemas/knowledge_entity_embedding.surql index 7f852b4..546c4f6 100644 --- a/common/schemas/knowledge_entity_embedding.surql +++ b/common/schemas/knowledge_entity_embedding.surql @@ -7,6 +7,7 @@ DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL; DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime; DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime; DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string; +DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string; -- Custom fields DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record; @@ -14,5 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array; +DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record; -# Define the metadata field within the edge # RelationshipMetadata is a struct, store as object DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object; - -# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table) -# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string; -# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string; -# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string; +DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string; +DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string; +DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string; # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships) DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id; diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index d9a3885..40a0e4d 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -11,6 +11,31 @@ use crate::{error::AppError, storage::db::SurrealDbClient}; const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50); const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer"; +/// HNSW index options used by runtime index creation (includes CONCURRENTLY). +pub const HNSW_INDEX_OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY"; +/// HNSW index options for use inside transactions (CONCURRENTLY not supported). +pub const HNSW_INDEX_OPTIONS_SYNC: &str = "DIST COSINE TYPE F32 EFC 100 M 8"; + +/// Builds a `DEFINE INDEX OVERWRITE ... HNSW` statement matching runtime index options. +#[must_use] +pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize) -> String { + format!( + "DEFINE INDEX OVERWRITE {index_name} ON TABLE {table} \ + FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS};" + ) +} + +/// Recreates an HNSW index inside a transaction (for tests and dimension migrations). +#[must_use] +pub fn hnsw_index_redefine_transaction_sql(index_name: &str, table: &str, dimension: usize) -> String { + format!( + "BEGIN TRANSACTION; + REMOVE INDEX IF EXISTS {index_name} ON TABLE {table}; + DEFINE INDEX {index_name} ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS_SYNC}; + COMMIT TRANSACTION;" + ) +} + #[derive(Clone, Copy)] struct HnswIndexSpec { index_name: &'static str, @@ -23,12 +48,12 @@ const fn hnsw_index_specs() -> [HnswIndexSpec; 2] { HnswIndexSpec { index_name: "idx_embedding_text_chunk_embedding", table: "text_chunk_embedding", - options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", + options: HNSW_INDEX_OPTIONS, }, HnswIndexSpec { index_name: "idx_embedding_knowledge_entity_embedding", table: "knowledge_entity_embedding", - options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", + options: HNSW_INDEX_OPTIONS, }, ] } diff --git a/common/src/storage/types/conversation.rs b/common/src/storage/types/conversation.rs index 3dedf4c..8c7418d 100644 --- a/common/src/storage/types/conversation.rs +++ b/common/src/storage/types/conversation.rs @@ -518,8 +518,8 @@ mod tests { #[test] fn test_sidebar_conversation_deserializes_plain_string_id() { - let item: SidebarConversation = - serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#).unwrap(); + let item: SidebarConversation = serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#) + .expect("valid sidebar conversation json"); assert_eq!(item.id, "conv-plain"); assert_eq!(item.title, "My chat"); } @@ -543,8 +543,9 @@ mod tests { .await .expect("Failed to load sidebar"); assert_eq!(items.len(), 1); - assert_eq!(items[0].id, expected_id); - assert_eq!(items[0].title, "Sidebar title"); + let item = items.first().expect("expected one sidebar item"); + assert_eq!(item.id, expected_id); + assert_eq!(item.title, "Sidebar title"); } #[tokio::test] @@ -570,7 +571,13 @@ mod tests { let owner_messages = fetch_messages_for_owner(&db, &conversation_id, owner).await?; assert_eq!(owner_messages.len(), 1); - assert_eq!(owner_messages[0].content, "secret message"); + assert_eq!( + owner_messages + .first() + .expect("expected owner message") + .content, + "secret message" + ); let intruder_messages = fetch_messages_for_owner(&db, &conversation_id, intruder).await?; @@ -617,8 +624,14 @@ mod tests { Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?; assert_eq!(messages.len(), 2); - assert_eq!(messages[0].content, "first"); - assert_eq!(messages[1].content, "second"); + assert_eq!( + messages.first().expect("expected first message").content, + "first" + ); + assert_eq!( + messages.get(1).expect("expected second message").content, + "second" + ); Ok(()) } diff --git a/common/src/storage/types/file_info.rs b/common/src/storage/types/file_info.rs index 4282148..aad4494 100644 --- a/common/src/storage/types/file_info.rs +++ b/common/src/storage/types/file_info.rs @@ -90,9 +90,9 @@ impl FileInfo { /// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in /// both the stem and extension. fn sanitize_file_name(file_name: &str) -> String { - if let Some(idx) = file_name.rfind('.') { - let name = Self::sanitize_name_segment(&file_name[..idx]); - let ext = Self::sanitize_name_segment(&file_name[idx + 1..]); + if let Some((stem, ext)) = file_name.rsplit_once('.') { + let name = Self::sanitize_name_segment(stem); + let ext = Self::sanitize_name_segment(ext); if ext.is_empty() { name } else { @@ -321,7 +321,6 @@ mod tests { use anyhow::{self, Context}; use super::*; - use crate::error::AppError; use crate::storage::store::testing::TestStorageManager; use axum::http::HeaderMap; use axum_typed_multipart::FieldMetadata; @@ -844,8 +843,7 @@ mod tests { assert!(file_info.sha256.len() == 64); let bytes = file_info .get_content_with_storage(test_storage.storage()) - .await - .map_err(AppError::from)?; + .await?; assert!(bytes.is_empty()); Ok(()) diff --git a/common/src/storage/types/ingestion_payload.rs b/common/src/storage/types/ingestion_payload.rs index 5788381..f597af2 100644 --- a/common/src/storage/types/ingestion_payload.rs +++ b/common/src/storage/types/ingestion_payload.rs @@ -80,7 +80,7 @@ impl IngestionPayload { }); for (index, file) in files.into_iter().enumerate() { - let is_last_file = index + 1 == file_count; + let is_last_file = index.saturating_add(1) == file_count; if content_follows || !is_last_file { let Some(shared) = fields.as_ref() else { return Err(AppError::internal("shared ingest fields consumed early")); @@ -411,7 +411,9 @@ mod tests { )?; assert_eq!(result.len(), 2); - match (&result[0], &result[1]) { + let first = result.first().expect("expected first payload"); + let second = result.get(1).expect("expected second payload"); + match (first, second) { ( IngestionPayload::File { file_info: payload_file, @@ -499,8 +501,8 @@ mod tests { )?; assert_eq!(result.len(), 2); - assert!(matches!(result[0], IngestionPayload::File { .. })); - assert!(matches!(result[1], IngestionPayload::File { .. })); + assert!(matches!(result.first(), Some(IngestionPayload::File { .. }))); + assert!(matches!(result.get(1), Some(IngestionPayload::File { .. }))); Ok(()) } } diff --git a/common/src/storage/types/ingestion_task.rs b/common/src/storage/types/ingestion_task.rs index 2e5930d..00c8bfa 100644 --- a/common/src/storage/types/ingestion_task.rs +++ b/common/src/storage/types/ingestion_task.rs @@ -150,8 +150,8 @@ fn invalid_transition(state: TaskState, event: TaskTransition) -> AppError { )) } -fn worker_id_for_bind(worker_id: &Option) -> String { - worker_id.as_deref().unwrap_or("").to_string() +fn worker_id_for_bind(worker_id: Option<&String>) -> String { + worker_id.cloned().unwrap_or_default() } stored_object!(IngestionTask, "ingestion_task", { @@ -360,7 +360,7 @@ impl IngestionTask { "#; let now = chrono::Utc::now(); - let worker_id = worker_id_for_bind(&self.worker_id); + let worker_id = worker_id_for_bind(self.worker_id.as_ref()); let mut result = db .client .query(START_PROCESSING_QUERY) @@ -398,7 +398,7 @@ impl IngestionTask { "#; let now = chrono::Utc::now(); - let worker_id = worker_id_for_bind(&self.worker_id); + let worker_id = worker_id_for_bind(self.worker_id.as_ref()); let mut result = db .client .query(COMPLETE_QUERY) @@ -450,7 +450,7 @@ impl IngestionTask { ) .unwrap_or(now); - let worker_id = worker_id_for_bind(&self.worker_id); + let worker_id = worker_id_for_bind(self.worker_id.as_ref()); let mut result = db .client .query(FAIL_QUERY) @@ -680,7 +680,9 @@ mod tests { IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?; assert_eq!(created.len(), 2); - assert_ne!(created[0].id, created[1].id); + let first = created.first().expect("expected first task"); + let second = created.get(1).expect("expected second task"); + assert_ne!(first.id, second.id); for task in &created { let stored: Option = db.get_item::(&task.id).await?; diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 1c7e07d..23d49c5 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -7,7 +7,9 @@ use std::fmt::Write; use crate::{ error::AppError, storage::db::SurrealDbClient, - storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object, + storage::indexes::hnsw_index_overwrite_sql, + storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, + storage::types::system_settings::SystemSettings, stored_object, utils::embedding::generate_embedding, }; use async_openai::{config::OpenAIConfig, Client}; @@ -189,13 +191,25 @@ impl KnowledgeEntity { embedding: Vec, db: &SurrealDbClient, ) -> Result<(), AppError> { - let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone()); + let settings = SystemSettings::get_current(db).await?; + KnowledgeEntityEmbedding::validate_dimension( + &embedding, + settings.embedding_dimensions as usize, + )?; + + let entity_id = entity.id.clone(); + let emb = KnowledgeEntityEmbedding::new( + &entity_id, + entity.source_id.clone(), + embedding, + entity.user_id.clone(), + ); let query = format!( " BEGIN TRANSACTION; CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity; - CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb; + UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb; COMMIT TRANSACTION; ", entity_table = Self::table_name(), @@ -204,9 +218,8 @@ impl KnowledgeEntity { db.client .query(query) - .bind(("entity_id", entity.id.clone())) + .bind(("entity_id", entity_id)) .bind(("entity", entity)) - .bind(("emb_id", emb.id.clone())) .bind(("emb", emb)) .await .map_err(AppError::Database)? @@ -275,12 +288,29 @@ impl KnowledgeEntity { db_client: &SurrealDbClient, ai_client: &Client, ) -> Result<(), AppError> { - let embedding_input = format!( - "name: {name}, description: {description}, type: {entity_type:?}", - ); + let embedding_input = format!( + "name: {name}, description: {description}, type: {entity_type:?}", + ); let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?; - let user_id = Self::get_user_id_by_id(id, db_client).await?; - let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id); + + let entity: KnowledgeEntity = db_client + .get_item(id) + .await + .map_err(AppError::Database)? + .ok_or_else(|| AppError::NotFound(format!("entity {id} not found")))?; + + let settings = SystemSettings::get_current(db_client).await?; + KnowledgeEntityEmbedding::validate_dimension( + &embedding, + settings.embedding_dimensions as usize, + )?; + + let emb = KnowledgeEntityEmbedding::new( + id, + entity.source_id, + embedding, + entity.user_id, + ); let now = Utc::now(); @@ -293,7 +323,7 @@ impl KnowledgeEntity { description = $description, updated_at = $updated_at, entity_type = $entity_type; - UPSERT type::thing($emb_table, $emb_id) CONTENT $emb; + UPSERT type::thing($emb_table, $id) CONTENT $emb; COMMIT TRANSACTION;", ) .bind(("table", Self::table_name())) @@ -302,33 +332,16 @@ impl KnowledgeEntity { .bind(("name", name.to_string())) .bind(("updated_at", surrealdb::Datetime::from(now))) .bind(("entity_type", entity_type.to_owned())) - .bind(("emb_id", emb.id.clone())) .bind(("emb", emb)) .bind(("description", description.to_string())) - .await?; + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; Ok(()) } - async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result { - #[derive(Deserialize)] - struct Row { - user_id: String, - } - - let mut response = db_client - .client - .query("SELECT user_id FROM type::thing($table, $id) LIMIT 1") - .bind(("table", Self::table_name())) - .bind(("id", id.to_string())) - .await - .map_err(AppError::Database)?; - let rows: Vec = response.take(0).map_err(AppError::Database)?; - rows.first() - .map(|r| r.user_id.clone()) - .ok_or_else(|| AppError::internal("user not found for entity")) - } - /// Re-creates embeddings for all knowledge entities in the database. /// /// This is a costly operation that should be run in the background. It follows the same @@ -359,7 +372,7 @@ impl KnowledgeEntity { info!("Found {total_entities} entities to process."); // Generate all new embeddings in memory - let mut new_embeddings: HashMap, String)> = HashMap::new(); + let mut new_embeddings: HashMap, String, String)> = HashMap::new(); info!("Generating new embeddings for all entities..."); for entity in &all_entities { let embedding_input = format!( @@ -387,7 +400,14 @@ impl KnowledgeEntity { error!("{err_msg}"); return Err(AppError::internal(err_msg)); } - new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); + new_embeddings.insert( + entity.id.clone(), + ( + embedding, + entity.user_id.clone(), + entity.source_id.clone(), + ), + ); } info!("Successfully generated all new embeddings."); @@ -396,7 +416,7 @@ impl KnowledgeEntity { let mut transaction_query = String::from("BEGIN TRANSACTION;"); // Add all update statements to the embedding table - for (id, (embedding, user_id)) in new_embeddings { + for (id, (embedding, user_id, source_id)) in new_embeddings { let embedding = serde_json::to_string(&embedding).map_err(|e| { AppError::internal(format!("embedding serialization failed: {e}")) })?; @@ -406,6 +426,7 @@ impl KnowledgeEntity { entity_id = type::thing('knowledge_entity', '{id}'), \ embedding = {embedding}, \ user_id = '{user_id}', \ + source_id = '{source_id}', \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ updated_at = time::now();", ) @@ -414,7 +435,12 @@ impl KnowledgeEntity { write!( transaction_query, - "DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", + "{}", + hnsw_index_overwrite_sql( + "idx_embedding_knowledge_entity_embedding", + KnowledgeEntityEmbedding::table_name(), + new_dimensions as usize, + ) ) .map_err(AppError::internal)?; @@ -431,6 +457,7 @@ impl KnowledgeEntity { /// /// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.) /// instead of directly calling OpenAI. Used during startup when embedding configuration changes. + #[allow(clippy::too_many_lines)] pub async fn update_all_embeddings_with_provider( db: &SurrealDbClient, provider: &crate::utils::embedding::EmbeddingProvider, @@ -453,7 +480,7 @@ impl KnowledgeEntity { info!(entities = total_entities, "Found entities to process"); // Generate all new embeddings in memory - let mut new_embeddings: HashMap, String)> = HashMap::new(); + let mut new_embeddings: HashMap, String, String)> = HashMap::new(); info!("Generating new embeddings for all entities..."); for (i, entity) in all_entities.iter().enumerate() { @@ -484,7 +511,14 @@ impl KnowledgeEntity { error!("{err_msg}"); return Err(AppError::internal(err_msg)); } - new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); + new_embeddings.insert( + entity.id.clone(), + ( + embedding, + entity.user_id.clone(), + entity.source_id.clone(), + ), + ); } info!("Successfully generated all new embeddings."); @@ -517,7 +551,7 @@ impl KnowledgeEntity { info!("Applying embedding updates in a transaction..."); let mut transaction_query = String::from("BEGIN TRANSACTION;"); - for (id, (embedding, user_id)) in new_embeddings { + for (id, (embedding, user_id, source_id)) in new_embeddings { let embedding = serde_json::to_string(&embedding).map_err(|e| { AppError::internal(format!("embedding serialization failed: {e}")) })?; @@ -527,6 +561,7 @@ impl KnowledgeEntity { entity_id = type::thing('knowledge_entity', '{id}'), \ embedding = {embedding}, \ user_id = '{user_id}', \ + source_id = '{source_id}', \ created_at = time::now(), \ updated_at = time::now();", ) @@ -535,7 +570,12 @@ impl KnowledgeEntity { write!( transaction_query, - "DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", + "{}", + hnsw_index_overwrite_sql( + "idx_embedding_knowledge_entity_embedding", + KnowledgeEntityEmbedding::table_name(), + new_dimensions, + ) ) .map_err(AppError::internal)?; @@ -559,10 +599,21 @@ mod tests { #![allow(clippy::expect_used, clippy::must_use_candidate)] use super::*; use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; + use crate::storage::types::system_settings::SystemSettings; use anyhow::{self, Context}; use serde_json::json; use uuid::Uuid; + async fn configure_embedding_dimension( + db: &SurrealDbClient, + dimension: u32, + ) -> anyhow::Result<()> { + let mut settings = SystemSettings::get_current(db).await?; + settings.embedding_dimensions = dimension; + SystemSettings::update(db, settings).await?; + Ok(()) + } + #[tokio::test] async fn test_knowledge_entity_creation() -> anyhow::Result<()> { let source_id = "source123".to_string(); @@ -656,14 +707,15 @@ mod tests { .await .with_context(|| "Failed to apply migrations".to_string())?; - let source_id = "source123".to_string(); - let entity_type = KnowledgeEntityType::Document; - let user_id = "user123".to_string(); - + configure_embedding_dimension(&db, 5).await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5) .await .with_context(|| "Failed to redefine index length".to_string())?; + let source_id = "source123".to_string(); + let entity_type = KnowledgeEntityType::Document; + let user_id = "user123".to_string(); + let entity1 = KnowledgeEntity::new( source_id.clone(), "Entity 1".to_string(), @@ -763,6 +815,7 @@ mod tests { .await .expect("Failed to apply migrations"); + configure_embedding_dimension(&db, 3).await.expect("configure dim"); KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await .expect("Failed to redefine index length"); @@ -847,6 +900,7 @@ mod tests { .await .with_context(|| "Failed to apply migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "Failed to redefine index length".to_string())?; @@ -914,6 +968,7 @@ mod tests { .await .with_context(|| "Failed to apply migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "Failed to redefine index length".to_string())?; @@ -1012,6 +1067,7 @@ mod tests { .await .with_context(|| "Failed to apply migrations".to_string())?; + configure_embedding_dimension(&db, 3).await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await .with_context(|| "Failed to redefine index length".to_string())?; diff --git a/common/src/storage/types/knowledge_entity_embedding.rs b/common/src/storage/types/knowledge_entity_embedding.rs index 4fa994d..32d0c6b 100644 --- a/common/src/storage/types/knowledge_entity_embedding.rs +++ b/common/src/storage/types/knowledge_entity_embedding.rs @@ -2,11 +2,17 @@ use std::collections::HashMap; use surrealdb::RecordId; -use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; +use crate::{ + error::AppError, + storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql}, + stored_object, +}; stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", { entity_id: RecordId, embedding: Vec, + /// Denormalized source id for bulk deletes + source_id: String, /// Denormalized user id for query scoping user_id: String }); @@ -17,12 +23,10 @@ impl KnowledgeEntityEmbedding { db: &SurrealDbClient, dimension: usize, ) -> Result<(), AppError> { - let query = format!( - "BEGIN TRANSACTION; - REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table}; - DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension}; - COMMIT TRANSACTION;", - table = Self::table_name(), + let query = hnsw_index_redefine_transaction_sql( + "idx_embedding_knowledge_entity_embedding", + Self::table_name(), + dimension, ); let res = db.client.query(query).await.map_err(AppError::Database)?; @@ -31,16 +35,36 @@ impl KnowledgeEntityEmbedding { Ok(()) } - /// Create a new knowledge entity embedding + /// Validates that an embedding vector matches the configured HNSW dimension. + #[allow(clippy::result_large_err)] + pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> { + if embedding.len() != expected { + return Err(AppError::Validation(format!( + "embedding dimension mismatch: got {}, expected {expected}", + embedding.len() + ))); + } + Ok(()) + } + + /// Create a new knowledge entity embedding. + /// + /// The embedding record id equals `entity_id` so each entity has at most one embedding row. #[must_use] - pub fn new(entity_id: &str, embedding: Vec, user_id: String) -> Self { + pub fn new( + entity_id: &str, + source_id: String, + embedding: Vec, + user_id: String, + ) -> Self { let now = Utc::now(); Self { - id: uuid::Uuid::new_v4().to_string(), + id: entity_id.to_owned(), created_at: now, updated_at: now, entity_id: RecordId::from_table_key("knowledge_entity", entity_id), embedding, + source_id, user_id, } } @@ -73,8 +97,6 @@ impl KnowledgeEntityEmbedding { return Ok(HashMap::new()); } - let ids_list: Vec = entity_ids.to_vec(); - let query = format!( "SELECT * FROM {} WHERE entity_id INSIDE $entity_ids", Self::table_name() @@ -82,7 +104,7 @@ impl KnowledgeEntityEmbedding { let mut result = db .client .query(query) - .bind(("entity_ids", ids_list)) + .bind(("entity_ids", entity_ids.to_vec())) .await .map_err(AppError::Database)?; let embeddings: Vec = result.take(0).map_err(AppError::Database)?; @@ -106,32 +128,28 @@ impl KnowledgeEntityEmbedding { .query(query) .bind(("entity_id", entity_id.clone())) .await + .map_err(AppError::Database)? + .check() .map_err(AppError::Database)?; Ok(()) } - /// Delete embeddings by source_id (via joining to knowledge_entity table) + /// Delete all embeddings with the given denormalized `source_id`. pub async fn delete_by_source_id( source_id: &str, db: &SurrealDbClient, ) -> Result<(), AppError> { - #[derive(Deserialize)] - struct IdRow { - id: RecordId, - } - - let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id"; - let mut res = db - .client + let query = format!( + "DELETE FROM {} WHERE source_id = $source_id", + Self::table_name() + ); + db.client .query(query) .bind(("source_id", source_id.to_owned())) .await + .map_err(AppError::Database)? + .check() .map_err(AppError::Database)?; - let ids: Vec = res.take(0).map_err(AppError::Database)?; - - for row in ids { - Self::delete_by_entity_id(&row.id, db).await?; - } Ok(()) } } @@ -142,6 +160,7 @@ mod tests { use super::*; use crate::storage::db::SurrealDbClient; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; + use crate::storage::types::system_settings::SystemSettings; use anyhow::{self, Context}; use chrono::Utc; use surrealdb::Value as SurrealValue; @@ -161,6 +180,24 @@ mod tests { Ok(db) } + async fn setup_test_db_with_embedding_dimension( + dimension: u32, + ) -> anyhow::Result { + let db = setup_test_db().await?; + let mut settings = SystemSettings::get_current(&db).await?; + settings.embedding_dimensions = dimension; + SystemSettings::update(&db, settings).await?; + Ok(db) + } + + async fn prepare_test_db(dimension: u32) -> anyhow::Result { + let db = setup_test_db_with_embedding_dimension(dimension).await?; + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize) + .await + .with_context(|| format!("set test index dimension to {dimension}"))?; + Ok(db) + } + fn build_knowledge_entity_with_id( key: &str, source_id: &str, @@ -179,12 +216,27 @@ mod tests { } } + #[test] + fn new_uses_entity_id_as_record_id() { + let emb = KnowledgeEntityEmbedding::new( + "entity-abc", + "source-1".to_owned(), + vec![0.1, 0.2], + "user-1".to_owned(), + ); + assert_eq!(emb.id, "entity-abc"); + } + + #[test] + fn validate_dimension_rejects_mismatch() { + let err = KnowledgeEntityEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2) + .expect_err("expected dimension mismatch"); + assert!(matches!(err, AppError::Validation(_))); + } + #[tokio::test] async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> { - let db = setup_test_db().await?; - KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) - .await - .with_context(|| "set test index dimension".to_string())?; + let db = prepare_test_db(3).await?; let user_id = "user_ke"; let entity_key = "entity-1"; let source_id = "source-ke"; @@ -203,7 +255,9 @@ mod tests { .with_context(|| "Failed to get embedding by entity_id".to_string())? .ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; + assert_eq!(fetched.id, entity_key); assert_eq!(fetched.user_id, user_id); + assert_eq!(fetched.source_id, source_id); assert_eq!(fetched.entity_id, entity_rid); assert_eq!(fetched.embedding, embedding_vec); @@ -212,10 +266,7 @@ mod tests { #[tokio::test] async fn test_delete_by_entity_id() -> anyhow::Result<()> { - let db = setup_test_db().await?; - KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) - .await - .with_context(|| "set test index dimension".to_string())?; + let db = prepare_test_db(3).await?; let user_id = "user_ke"; let entity_key = "entity-delete"; let source_id = "source-del"; @@ -247,15 +298,11 @@ mod tests { #[tokio::test] async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = prepare_test_db(3).await?; let user_id = "user_store"; let source_id = "source_store"; let embedding = vec![0.2_f32, 0.3, 0.4]; - KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len()) - .await - .with_context(|| "set test index dimension".to_string())?; - let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id); KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db) @@ -274,18 +321,30 @@ mod tests { .with_context(|| "Failed to fetch embedding".to_string())?; let stored_embedding = stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; + assert_eq!(stored_embedding.id, entity.id); assert_eq!(stored_embedding.user_id, user_id); + assert_eq!(stored_embedding.source_id, source_id); assert_eq!(stored_embedding.entity_id, entity_rid); Ok(()) } + #[tokio::test] + async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> { + let db = prepare_test_db(3).await?; + + let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim"); + let result = + KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await; + + assert!(matches!(result, Err(AppError::Validation(_)))); + + Ok(()) + } + #[tokio::test] async fn test_delete_by_source_id() -> anyhow::Result<()> { - let db = setup_test_db().await?; - KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) - .await - .with_context(|| "set test index dimension".to_string())?; + let db = prepare_test_db(3).await?; let user_id = "user_ke"; let source_id = "shared-ke"; let other_source = "other-ke"; @@ -363,6 +422,10 @@ mod tests { idx_sql.contains("DIMENSION 16"), "expected index definition to contain new dimension, got: {idx_sql}" ); + assert!( + idx_sql.contains("DIST COSINE"), + "expected index definition to use cosine distance, got: {idx_sql}" + ); Ok(()) } @@ -374,10 +437,7 @@ mod tests { entity_id: KnowledgeEntity, } - let db = setup_test_db().await?; - KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) - .await - .with_context(|| "set test index dimension".to_string())?; + let db = prepare_test_db(3).await?; let user_id = "user_ke"; let entity_key = "entity-fetch"; let source_id = "source-fetch"; @@ -412,4 +472,47 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> { + let db = prepare_test_db(3).await?; + + let user_id = "user-upsert"; + let source_id = "source-upsert"; + let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id); + + KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db) + .await + .with_context(|| "initial store".to_string())?; + + let replacement = KnowledgeEntityEmbedding::new( + &entity.id, + source_id.to_owned(), + vec![0.0, 1.0, 0.0], + user_id.to_owned(), + ); + db.upsert_item(replacement) + .await + .with_context(|| "upsert replacement embedding".to_string())?; + + let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); + let rows: Vec = db + .client + .query(format!( + "SELECT * FROM {} WHERE entity_id = $entity_id", + KnowledgeEntityEmbedding::table_name() + )) + .bind(("entity_id", entity_rid)) + .await + .with_context(|| "count embeddings".to_string())? + .take(0) + .with_context(|| "take embeddings".to_string())?; + + assert_eq!(rows.len(), 1); + let row = rows.first().expect("expected one embedding row"); + assert_eq!(row.id, entity.id); + assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]); + + Ok(()) + } } diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index f22c0c8..f73a6f3 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -1,4 +1,5 @@ use crate::storage::types::serde_helpers::deserialize_flexible_id; +use crate::storage::types::user::User; use crate::{error::AppError, storage::db::SurrealDbClient}; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -40,7 +41,21 @@ impl KnowledgeRelationship { }, } } + pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> { + User::get_and_validate_knowledge_entity( + &self.in_, + &self.metadata.user_id, + db_client, + ) + .await?; + User::get_and_validate_knowledge_entity( + &self.out, + &self.metadata.user_id, + db_client, + ) + .await?; + db_client .client .query( @@ -61,22 +76,30 @@ impl KnowledgeRelationship { .bind(("user_id", self.metadata.user_id.clone())) .bind(("source_id", self.metadata.source_id.clone())) .bind(("relationship_type", self.metadata.relationship_type.clone())) - .await? - .check()?; + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; Ok(()) } pub async fn delete_relationships_by_source_id( source_id: &str, + user_id: &str, db_client: &SurrealDbClient, ) -> Result<(), AppError> { db_client .client - .query("DELETE FROM relates_to WHERE metadata.source_id = $source_id") + .query( + "DELETE FROM relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id", + ) .bind(("source_id", source_id.to_owned())) - .await? - .check()?; + .bind(("user_id", user_id.to_owned())) + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; Ok(()) } @@ -86,39 +109,37 @@ impl KnowledgeRelationship { user_id: &str, db_client: &SurrealDbClient, ) -> Result<(), AppError> { - let mut authorized_result = db_client + let mut delete_result = db_client .client .query( - "SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id", + "DELETE type::thing('relates_to', $id) WHERE metadata.user_id = $user_id RETURN BEFORE;", ) .bind(("id", id.to_owned())) .bind(("user_id", user_id.to_owned())) - .await?; - let authorized: Vec = authorized_result.take(0)?; + .await + .map_err(AppError::Database)?; + let deleted: Vec = + delete_result.take(0).map_err(AppError::Database)?; - if authorized.is_empty() { - let mut exists_result = db_client - .client - .query("SELECT * FROM type::thing('relates_to', $id)") - .bind(("id", id.to_owned())) - .await?; - let existing: Option = exists_result.take(0)?; + if !deleted.is_empty() { + return Ok(()); + } - if existing.is_some() { - Err(AppError::Auth( - "Not authorized to delete relationship".into(), - )) - } else { - Err(AppError::NotFound(format!("Relationship {id} not found"))) - } + let mut exists_result = db_client + .client + .query("SELECT * FROM type::thing('relates_to', $id)") + .bind(("id", id.to_owned())) + .await + .map_err(AppError::Database)?; + let existing: Option = + exists_result.take(0).map_err(AppError::Database)?; + + if existing.is_some() { + Err(AppError::Auth( + "Not authorized to delete relationship".into(), + )) } else { - db_client - .client - .query("DELETE type::thing('relates_to', $id)") - .bind(("id", id.to_owned())) - .await? - .check()?; - Ok(()) + Err(AppError::NotFound(format!("Relationship {id} not found"))) } } } @@ -158,11 +179,14 @@ mod tests { result.take(0).expect("failed to take relationship by id") } - async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result { + async fn create_test_entity( + name: &str, + user_id: &str, + db_client: &SurrealDbClient, + ) -> anyhow::Result { let source_id = "source123".to_string(); let description = format!("Description for {name}"); let entity_type = KnowledgeEntityType::Document; - let user_id = "user123".to_string(); let entity = KnowledgeEntity::new( source_id, @@ -170,7 +194,7 @@ mod tests { description, entity_type, None, - user_id, + user_id.to_string(), ); let stored: Option = db_client @@ -211,18 +235,18 @@ mod tests { #[tokio::test] async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> { let db = setup_test_db().await; + let user_id = "user123"; - let entity1_id = create_test_entity("Entity 1", &db).await?; - let entity2_id = create_test_entity("Entity 2", &db).await?; + let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; + let entity2_id = create_test_entity("Entity 2", user_id, &db).await?; - let user_id = "user123".to_string(); let source_id = "source123".to_string(); let relationship_type = "references".to_string(); let relationship = KnowledgeRelationship::new( entity1_id.clone(), entity2_id.clone(), - user_id.clone(), + user_id.to_string(), source_id.clone(), relationship_type, ); @@ -257,16 +281,38 @@ mod tests { } #[tokio::test] - async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> { + async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> { let db = setup_test_db().await; - let entity1_id = create_test_entity("Entity 1", &db).await?; - let entity2_id = create_test_entity("Entity 2", &db).await?; + let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?; + let other_entity = create_test_entity("Other entity", "other-user", &db).await?; + + let relationship = KnowledgeRelationship::new( + owner_entity, + other_entity, + "owner-user".to_string(), + "source123".to_string(), + "references".to_string(), + ); + + let result = relationship.store_relationship(&db).await; + assert!(matches!(result, Err(AppError::Auth(_)))); + + Ok(()) + } + + #[tokio::test] + async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> { + let db = setup_test_db().await; + let user_id = "user123"; + + let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; + let entity2_id = create_test_entity("Entity 2", user_id, &db).await?; let relationship = KnowledgeRelationship::new( entity1_id, entity2_id, - "user'123".to_string(), + user_id.to_string(), "source123'; DELETE FROM relates_to; --".to_string(), "references'; UPDATE user SET admin = true; --".to_string(), ); @@ -297,18 +343,18 @@ mod tests { #[tokio::test] async fn test_store_and_delete_relationship() -> anyhow::Result<()> { let db = setup_test_db().await; + let user_id = "user123"; - let entity1_id = create_test_entity("Entity 1", &db).await?; - let entity2_id = create_test_entity("Entity 2", &db).await?; + let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; + let entity2_id = create_test_entity("Entity 2", user_id, &db).await?; - let user_id = "user123".to_string(); let source_id = "source123".to_string(); let relationship_type = "references".to_string(); let relationship = KnowledgeRelationship::new( entity1_id.clone(), entity2_id.clone(), - user_id.clone(), + user_id.to_string(), source_id.clone(), relationship_type, ); @@ -319,9 +365,9 @@ mod tests { .with_context(|| "Failed to store relationship".to_string())?; let mut existing_before_delete = db - .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'" - )) + .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id") + .bind(("user_id", user_id.to_string())) + .bind(("source_id", source_id.clone())) .await .with_context(|| "Query failed".to_string())?; let before_results: Vec = @@ -331,14 +377,14 @@ mod tests { "Relationship should exist before deletion" ); - KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db) + KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db) .await .with_context(|| "Failed to delete relationship by ID".to_string())?; let mut result = db - .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'" - )) + .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id") + .bind(("user_id", user_id.to_string())) + .bind(("source_id", source_id)) .await .with_context(|| "Query failed".to_string())?; let results: Vec = result.take(0).unwrap_or_default(); @@ -351,17 +397,17 @@ mod tests { #[tokio::test] async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> { let db = setup_test_db().await; + let owner_user_id = "owner-user"; - let entity1_id = create_test_entity("Entity 1", &db).await?; - let entity2_id = create_test_entity("Entity 2", &db).await?; + let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?; + let entity2_id = create_test_entity("Entity 2", owner_user_id, &db).await?; - let owner_user_id = "owner-user".to_string(); let source_id = "source123".to_string(); let relationship = KnowledgeRelationship::new( entity1_id.clone(), entity2_id.clone(), - owner_user_id.clone(), + owner_user_id.to_string(), source_id, "references".to_string(), ); @@ -372,9 +418,8 @@ mod tests { .with_context(|| "Failed to store relationship".to_string())?; let mut before_attempt = db - .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'" - )) + .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id") + .bind(("user_id", owner_user_id.to_string())) .await .with_context(|| "Query failed".to_string())?; let before_results: Vec = before_attempt.take(0).unwrap_or_default(); @@ -398,9 +443,8 @@ mod tests { } let mut after_attempt = db - .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'" - )) + .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id") + .bind(("user_id", owner_user_id.to_string())) .await .with_context(|| "Query failed".to_string())?; let results: Vec = after_attempt.take(0).unwrap_or_default(); @@ -416,19 +460,19 @@ mod tests { #[tokio::test] async fn test_store_relationship_exists() -> anyhow::Result<()> { let db = setup_test_db().await; + let user_id = "user123"; - let entity1_id = create_test_entity("Entity 1", &db).await?; - let entity2_id = create_test_entity("Entity 2", &db).await?; - let entity3_id = create_test_entity("Entity 3", &db).await?; + let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; + let entity2_id = create_test_entity("Entity 2", user_id, &db).await?; + let entity3_id = create_test_entity("Entity 3", user_id, &db).await?; - let user_id = "user123".to_string(); let source_id = "source123".to_string(); let different_source_id = "different_source".to_string(); let relationship1 = KnowledgeRelationship::new( entity1_id.clone(), entity2_id.clone(), - user_id.clone(), + user_id.to_string(), source_id.clone(), "references".to_string(), ); @@ -436,7 +480,7 @@ mod tests { let relationship2 = KnowledgeRelationship::new( entity2_id.clone(), entity3_id.clone(), - user_id.clone(), + user_id.to_string(), source_id.clone(), "contains".to_string(), ); @@ -444,7 +488,7 @@ mod tests { let different_relationship = KnowledgeRelationship::new( entity1_id.clone(), entity3_id.clone(), - user_id.clone(), + user_id.to_string(), different_source_id.clone(), "mentions".to_string(), ); @@ -480,7 +524,7 @@ mod tests { before_delete_different.take(0).unwrap_or_default(); assert_eq!(before_delete_different_rows.len(), 1); - KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db) + KnowledgeRelationship::delete_relationships_by_source_id(&source_id, user_id, &db) .await .with_context(|| "Failed to delete relationships by source_id".to_string())?; @@ -497,19 +541,60 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> { + let db = setup_test_db().await; + + let user_a = "user-a"; + let user_b = "user-b"; + let shared_source = "shared-source"; + + let a1 = create_test_entity("A1", user_a, &db).await?; + let a2 = create_test_entity("A2", user_a, &db).await?; + let b1 = create_test_entity("B1", user_b, &db).await?; + let b2 = create_test_entity("B2", user_b, &db).await?; + + let rel_a = KnowledgeRelationship::new( + a1, + a2, + user_a.to_string(), + shared_source.to_string(), + "references".to_string(), + ); + let rel_b = KnowledgeRelationship::new( + b1, + b2, + user_b.to_string(), + shared_source.to_string(), + "references".to_string(), + ); + + rel_a.store_relationship(&db).await?; + rel_b.store_relationship(&db).await?; + + KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db) + .await?; + + assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none()); + assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some()); + + Ok(()) + } + #[tokio::test] async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> { let db = setup_test_db().await; + let user_id = "user123"; - let entity1_id = create_test_entity("Entity 1", &db).await?; - let entity2_id = create_test_entity("Entity 2", &db).await?; - let entity3_id = create_test_entity("Entity 3", &db).await?; + let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; + let entity2_id = create_test_entity("Entity 2", user_id, &db).await?; + let entity3_id = create_test_entity("Entity 3", user_id, &db).await?; let safe_relationship = KnowledgeRelationship::new( entity1_id.clone(), entity2_id.clone(), - "user123".to_string(), + user_id.to_string(), "safe_source".to_string(), "references".to_string(), ); @@ -517,7 +602,7 @@ mod tests { let other_relationship = KnowledgeRelationship::new( entity2_id, entity3_id, - "user123".to_string(), + user_id.to_string(), "other_source".to_string(), "contains".to_string(), ); @@ -531,9 +616,13 @@ mod tests { .await .expect("store other relationship"); - KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db) - .await - .expect("delete call should succeed"); + KnowledgeRelationship::delete_relationships_by_source_id( + "safe_source' OR 1=1 --", + user_id, + &db, + ) + .await + .expect("delete call should succeed"); let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await; let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await; diff --git a/common/src/storage/types/message.rs b/common/src/storage/types/message.rs index 00910e3..4a0f073 100644 --- a/common/src/storage/types/message.rs +++ b/common/src/storage/types/message.rs @@ -62,7 +62,7 @@ impl fmt::Display for Message { pub fn format_history(history: &[Message]) -> String { let estimated: usize = history .iter() - .map(|m| m.content.len() + 10) + .map(|m| m.content.len().saturating_add(10)) .sum(); let mut out = String::with_capacity(estimated); for (i, msg) in history.iter().enumerate() { diff --git a/evaluations/src/corpus/store.rs b/evaluations/src/corpus/store.rs index e8bc6a3..1ad2d53 100644 --- a/evaluations/src/corpus/store.rs +++ b/evaluations/src/corpus/store.rs @@ -308,6 +308,7 @@ fn build_manifest_batches(manifest: &CorpusManifest) -> Result entities.push(entity.clone()); entity_embeddings.push(KnowledgeEntityEmbedding::new( &entity.id, + entity.source_id.clone(), embedded_entity.embedding.clone(), entity.user_id.clone(), )); diff --git a/html-router/src/routes/index/handlers.rs b/html-router/src/routes/index/handlers.rs index eed1791..c174405 100644 --- a/html-router/src/routes/index/handlers.rs +++ b/html-router/src/routes/index/handlers.rs @@ -86,7 +86,7 @@ pub async fn delete_text_content( // Delete the text content and any related data TextChunk::delete_by_source_id(&text_content.id, &state.db).await?; KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?; - KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db).await?; + KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db).await?; state .db .delete_item::(&text_content.id) diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 52d26de..27c7ad9 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -293,7 +293,7 @@ async fn reserve_task( payload: IngestionPayload, user_id: &str, ) -> anyhow::Result { - let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db).await?; + let task = IngestionTask::create_and_add_to_db(payload, user_id, db).await?; let lease = task.lease_duration(); let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease) .await?