chore: harden knowledge graph storage and clear common clippy warnings

Enforce stable 1:1 entity embeddings, relationship endpoint auth, and
user-scoped deletes; align schemas/migrations and resolve common crate
clippy findings.
This commit is contained in:
Per Stark
2026-05-28 21:46:35 +02:00
parent 189adb1a5f
commit 5724f11dc1
17 changed files with 533 additions and 209 deletions
@@ -0,0 +1,33 @@
-- Harden knowledge entity embeddings and graph storage invariants.
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
-- Backfill denormalized source_id from the linked entity.
FOR $emb IN (SELECT * FROM knowledge_entity_embedding WHERE source_id = NONE OR source_id = '') {
LET $entity = (SELECT source_id FROM $emb.entity_id)[0];
IF $entity != NONE {
UPDATE $emb.id SET source_id = $entity.source_id;
}
};
-- Re-key embeddings so record id matches entity id (stable 1:1 identity).
FOR $emb IN (SELECT * FROM knowledge_entity_embedding) {
LET $entity_key = record::id($emb.entity_id);
LET $canonical = type::thing('knowledge_entity_embedding', $entity_key);
IF $emb.id != $canonical {
UPSERT $canonical CONTENT {
entity_id: $emb.entity_id,
embedding: $emb.embedding,
user_id: $emb.user_id,
source_id: $emb.source_id,
created_at: $emb.created_at,
updated_at: $emb.updated_at
};
DELETE $emb.id;
}
};
REMOVE INDEX IF EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -68,7 +68,7 @@\n\n # Defines the schema for the 'knowledge_entity' table.\n\n-DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n@@ -90,6 +90,7 @@\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n@@ -102,6 +103,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n+DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;\n\n -- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;\n@@ -109,8 +111,9 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;\n\n # Defines the schema for the 'message' table.\n\n@@ -135,19 +138,17 @@\n # Defines the 'relates_to' edge table for KnowledgeRelationships.\n # Edges connect nodes, in this case knowledge_entity records.\n\n-# Define the edge table itself, enforcing connections between knowledge_entity records\n-# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\n-DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+\n+DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;\n+DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;\n\n-# Define the metadata field within the edge\n # RelationshipMetadata is a struct, store as object\n DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n+DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n-# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n-# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n-\n # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n","events":null}
+2 -1
View File
@@ -1,6 +1,6 @@
# Defines the schema for the 'knowledge_entity' table. # Defines the schema for the 'knowledge_entity' table.
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS; DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;
# Standard fields # Standard fields
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime; DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
@@ -22,5 +22,6 @@ DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536; -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type; DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
@@ -7,6 +7,7 @@ DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime; DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime; DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string; DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
-- Custom fields -- Custom fields
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>; DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
@@ -14,5 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<fl
-- Indexes -- Indexes
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536; -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
+7 -9
View File
@@ -1,18 +1,16 @@
# Defines the 'relates_to' edge table for KnowledgeRelationships. # Defines the 'relates_to' edge table for KnowledgeRelationships.
# Edges connect nodes, in this case knowledge_entity records. # Edges connect nodes, in this case knowledge_entity records.
# Define the edge table itself, enforcing connections between knowledge_entity records DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;
# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary
DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity; DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
# Define the metadata field within the edge
# RelationshipMetadata is a struct, store as object # RelationshipMetadata is a struct, store as object
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object; DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table) DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string; DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships) # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id; DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
+27 -2
View File
@@ -11,6 +11,31 @@ use crate::{error::AppError, storage::db::SurrealDbClient};
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50); const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer"; const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
/// HNSW index options used by runtime index creation (includes CONCURRENTLY).
pub const HNSW_INDEX_OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY";
/// HNSW index options for use inside transactions (CONCURRENTLY not supported).
pub const HNSW_INDEX_OPTIONS_SYNC: &str = "DIST COSINE TYPE F32 EFC 100 M 8";
/// Builds a `DEFINE INDEX OVERWRITE ... HNSW` statement matching runtime index options.
#[must_use]
pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize) -> String {
format!(
"DEFINE INDEX OVERWRITE {index_name} ON TABLE {table} \
FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS};"
)
}
/// Recreates an HNSW index inside a transaction (for tests and dimension migrations).
#[must_use]
pub fn hnsw_index_redefine_transaction_sql(index_name: &str, table: &str, dimension: usize) -> String {
format!(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};
DEFINE INDEX {index_name} ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS_SYNC};
COMMIT TRANSACTION;"
)
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct HnswIndexSpec { struct HnswIndexSpec {
index_name: &'static str, index_name: &'static str,
@@ -23,12 +48,12 @@ const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
HnswIndexSpec { HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding", index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding", table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", options: HNSW_INDEX_OPTIONS,
}, },
HnswIndexSpec { HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding", index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding", table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", options: HNSW_INDEX_OPTIONS,
}, },
] ]
} }
+20 -7
View File
@@ -518,8 +518,8 @@ mod tests {
#[test] #[test]
fn test_sidebar_conversation_deserializes_plain_string_id() { fn test_sidebar_conversation_deserializes_plain_string_id() {
let item: SidebarConversation = let item: SidebarConversation = serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#).unwrap(); .expect("valid sidebar conversation json");
assert_eq!(item.id, "conv-plain"); assert_eq!(item.id, "conv-plain");
assert_eq!(item.title, "My chat"); assert_eq!(item.title, "My chat");
} }
@@ -543,8 +543,9 @@ mod tests {
.await .await
.expect("Failed to load sidebar"); .expect("Failed to load sidebar");
assert_eq!(items.len(), 1); assert_eq!(items.len(), 1);
assert_eq!(items[0].id, expected_id); let item = items.first().expect("expected one sidebar item");
assert_eq!(items[0].title, "Sidebar title"); assert_eq!(item.id, expected_id);
assert_eq!(item.title, "Sidebar title");
} }
#[tokio::test] #[tokio::test]
@@ -570,7 +571,13 @@ mod tests {
let owner_messages = let owner_messages =
fetch_messages_for_owner(&db, &conversation_id, owner).await?; fetch_messages_for_owner(&db, &conversation_id, owner).await?;
assert_eq!(owner_messages.len(), 1); assert_eq!(owner_messages.len(), 1);
assert_eq!(owner_messages[0].content, "secret message"); assert_eq!(
owner_messages
.first()
.expect("expected owner message")
.content,
"secret message"
);
let intruder_messages = let intruder_messages =
fetch_messages_for_owner(&db, &conversation_id, intruder).await?; fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
@@ -617,8 +624,14 @@ mod tests {
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?; Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
assert_eq!(messages.len(), 2); assert_eq!(messages.len(), 2);
assert_eq!(messages[0].content, "first"); assert_eq!(
assert_eq!(messages[1].content, "second"); messages.first().expect("expected first message").content,
"first"
);
assert_eq!(
messages.get(1).expect("expected second message").content,
"second"
);
Ok(()) Ok(())
} }
+4 -6
View File
@@ -90,9 +90,9 @@ impl FileInfo {
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in /// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in
/// both the stem and extension. /// both the stem and extension.
fn sanitize_file_name(file_name: &str) -> String { fn sanitize_file_name(file_name: &str) -> String {
if let Some(idx) = file_name.rfind('.') { if let Some((stem, ext)) = file_name.rsplit_once('.') {
let name = Self::sanitize_name_segment(&file_name[..idx]); let name = Self::sanitize_name_segment(stem);
let ext = Self::sanitize_name_segment(&file_name[idx + 1..]); let ext = Self::sanitize_name_segment(ext);
if ext.is_empty() { if ext.is_empty() {
name name
} else { } else {
@@ -321,7 +321,6 @@ mod tests {
use anyhow::{self, Context}; use anyhow::{self, Context};
use super::*; use super::*;
use crate::error::AppError;
use crate::storage::store::testing::TestStorageManager; use crate::storage::store::testing::TestStorageManager;
use axum::http::HeaderMap; use axum::http::HeaderMap;
use axum_typed_multipart::FieldMetadata; use axum_typed_multipart::FieldMetadata;
@@ -844,8 +843,7 @@ mod tests {
assert!(file_info.sha256.len() == 64); assert!(file_info.sha256.len() == 64);
let bytes = file_info let bytes = file_info
.get_content_with_storage(test_storage.storage()) .get_content_with_storage(test_storage.storage())
.await .await?;
.map_err(AppError::from)?;
assert!(bytes.is_empty()); assert!(bytes.is_empty());
Ok(()) Ok(())
@@ -80,7 +80,7 @@ impl IngestionPayload {
}); });
for (index, file) in files.into_iter().enumerate() { for (index, file) in files.into_iter().enumerate() {
let is_last_file = index + 1 == file_count; let is_last_file = index.saturating_add(1) == file_count;
if content_follows || !is_last_file { if content_follows || !is_last_file {
let Some(shared) = fields.as_ref() else { let Some(shared) = fields.as_ref() else {
return Err(AppError::internal("shared ingest fields consumed early")); return Err(AppError::internal("shared ingest fields consumed early"));
@@ -411,7 +411,9 @@ mod tests {
)?; )?;
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
match (&result[0], &result[1]) { let first = result.first().expect("expected first payload");
let second = result.get(1).expect("expected second payload");
match (first, second) {
( (
IngestionPayload::File { IngestionPayload::File {
file_info: payload_file, file_info: payload_file,
@@ -499,8 +501,8 @@ mod tests {
)?; )?;
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
assert!(matches!(result[0], IngestionPayload::File { .. })); assert!(matches!(result.first(), Some(IngestionPayload::File { .. })));
assert!(matches!(result[1], IngestionPayload::File { .. })); assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
Ok(()) Ok(())
} }
} }
+8 -6
View File
@@ -150,8 +150,8 @@ fn invalid_transition(state: TaskState, event: TaskTransition) -> AppError {
)) ))
} }
fn worker_id_for_bind(worker_id: &Option<String>) -> String { fn worker_id_for_bind(worker_id: Option<&String>) -> String {
worker_id.as_deref().unwrap_or("").to_string() worker_id.cloned().unwrap_or_default()
} }
stored_object!(IngestionTask, "ingestion_task", { stored_object!(IngestionTask, "ingestion_task", {
@@ -360,7 +360,7 @@ impl IngestionTask {
"#; "#;
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let worker_id = worker_id_for_bind(&self.worker_id); let worker_id = worker_id_for_bind(self.worker_id.as_ref());
let mut result = db let mut result = db
.client .client
.query(START_PROCESSING_QUERY) .query(START_PROCESSING_QUERY)
@@ -398,7 +398,7 @@ impl IngestionTask {
"#; "#;
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let worker_id = worker_id_for_bind(&self.worker_id); let worker_id = worker_id_for_bind(self.worker_id.as_ref());
let mut result = db let mut result = db
.client .client
.query(COMPLETE_QUERY) .query(COMPLETE_QUERY)
@@ -450,7 +450,7 @@ impl IngestionTask {
) )
.unwrap_or(now); .unwrap_or(now);
let worker_id = worker_id_for_bind(&self.worker_id); let worker_id = worker_id_for_bind(self.worker_id.as_ref());
let mut result = db let mut result = db
.client .client
.query(FAIL_QUERY) .query(FAIL_QUERY)
@@ -680,7 +680,9 @@ mod tests {
IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?; IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
assert_eq!(created.len(), 2); assert_eq!(created.len(), 2);
assert_ne!(created[0].id, created[1].id); let first = created.first().expect("expected first task");
let second = created.get(1).expect("expected second task");
assert_ne!(first.id, second.id);
for task in &created { for task in &created {
let stored: Option<IngestionTask> = db.get_item::<IngestionTask>(&task.id).await?; let stored: Option<IngestionTask> = db.get_item::<IngestionTask>(&task.id).await?;
+97 -41
View File
@@ -7,7 +7,9 @@ use std::fmt::Write;
use crate::{ use crate::{
error::AppError, storage::db::SurrealDbClient, error::AppError, storage::db::SurrealDbClient,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object, storage::indexes::hnsw_index_overwrite_sql,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
storage::types::system_settings::SystemSettings, stored_object,
utils::embedding::generate_embedding, utils::embedding::generate_embedding,
}; };
use async_openai::{config::OpenAIConfig, Client}; use async_openai::{config::OpenAIConfig, Client};
@@ -189,13 +191,25 @@ impl KnowledgeEntity {
embedding: Vec<f32>, embedding: Vec<f32>,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone()); let settings = SystemSettings::get_current(db).await?;
KnowledgeEntityEmbedding::validate_dimension(
&embedding,
settings.embedding_dimensions as usize,
)?;
let entity_id = entity.id.clone();
let emb = KnowledgeEntityEmbedding::new(
&entity_id,
entity.source_id.clone(),
embedding,
entity.user_id.clone(),
);
let query = format!( let query = format!(
" "
BEGIN TRANSACTION; BEGIN TRANSACTION;
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity; CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb; UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
COMMIT TRANSACTION; COMMIT TRANSACTION;
", ",
entity_table = Self::table_name(), entity_table = Self::table_name(),
@@ -204,9 +218,8 @@ impl KnowledgeEntity {
db.client db.client
.query(query) .query(query)
.bind(("entity_id", entity.id.clone())) .bind(("entity_id", entity_id))
.bind(("entity", entity)) .bind(("entity", entity))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb)) .bind(("emb", emb))
.await .await
.map_err(AppError::Database)? .map_err(AppError::Database)?
@@ -279,8 +292,25 @@ impl KnowledgeEntity {
"name: {name}, description: {description}, type: {entity_type:?}", "name: {name}, description: {description}, type: {entity_type:?}",
); );
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?; let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
let user_id = Self::get_user_id_by_id(id, db_client).await?;
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id); let entity: KnowledgeEntity = db_client
.get_item(id)
.await
.map_err(AppError::Database)?
.ok_or_else(|| AppError::NotFound(format!("entity {id} not found")))?;
let settings = SystemSettings::get_current(db_client).await?;
KnowledgeEntityEmbedding::validate_dimension(
&embedding,
settings.embedding_dimensions as usize,
)?;
let emb = KnowledgeEntityEmbedding::new(
id,
entity.source_id,
embedding,
entity.user_id,
);
let now = Utc::now(); let now = Utc::now();
@@ -293,7 +323,7 @@ impl KnowledgeEntity {
description = $description, description = $description,
updated_at = $updated_at, updated_at = $updated_at,
entity_type = $entity_type; entity_type = $entity_type;
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb; UPSERT type::thing($emb_table, $id) CONTENT $emb;
COMMIT TRANSACTION;", COMMIT TRANSACTION;",
) )
.bind(("table", Self::table_name())) .bind(("table", Self::table_name()))
@@ -302,33 +332,16 @@ impl KnowledgeEntity {
.bind(("name", name.to_string())) .bind(("name", name.to_string()))
.bind(("updated_at", surrealdb::Datetime::from(now))) .bind(("updated_at", surrealdb::Datetime::from(now)))
.bind(("entity_type", entity_type.to_owned())) .bind(("entity_type", entity_type.to_owned()))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb)) .bind(("emb", emb))
.bind(("description", description.to_string())) .bind(("description", description.to_string()))
.await?; .await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(()) Ok(())
} }
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let mut response = db_client
.client
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
.bind(("table", Self::table_name()))
.bind(("id", id.to_string()))
.await
.map_err(AppError::Database)?;
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.first()
.map(|r| r.user_id.clone())
.ok_or_else(|| AppError::internal("user not found for entity"))
}
/// Re-creates embeddings for all knowledge entities in the database. /// Re-creates embeddings for all knowledge entities in the database.
/// ///
/// This is a costly operation that should be run in the background. It follows the same /// This is a costly operation that should be run in the background. It follows the same
@@ -359,7 +372,7 @@ impl KnowledgeEntity {
info!("Found {total_entities} entities to process."); info!("Found {total_entities} entities to process.");
// Generate all new embeddings in memory // Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new(); let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all entities..."); info!("Generating new embeddings for all entities...");
for entity in &all_entities { for entity in &all_entities {
let embedding_input = format!( let embedding_input = format!(
@@ -387,7 +400,14 @@ impl KnowledgeEntity {
error!("{err_msg}"); error!("{err_msg}");
return Err(AppError::internal(err_msg)); return Err(AppError::internal(err_msg));
} }
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); new_embeddings.insert(
entity.id.clone(),
(
embedding,
entity.user_id.clone(),
entity.source_id.clone(),
),
);
} }
info!("Successfully generated all new embeddings."); info!("Successfully generated all new embeddings.");
@@ -396,7 +416,7 @@ impl KnowledgeEntity {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements to the embedding table // Add all update statements to the embedding table
for (id, (embedding, user_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| { let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}")) AppError::internal(format!("embedding serialization failed: {e}"))
})?; })?;
@@ -406,6 +426,7 @@ impl KnowledgeEntity {
entity_id = type::thing('knowledge_entity', '{id}'), \ entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \ embedding = {embedding}, \
user_id = '{user_id}', \ user_id = '{user_id}', \
source_id = '{source_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();", updated_at = time::now();",
) )
@@ -414,7 +435,12 @@ impl KnowledgeEntity {
write!( write!(
transaction_query, transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", "{}",
hnsw_index_overwrite_sql(
"idx_embedding_knowledge_entity_embedding",
KnowledgeEntityEmbedding::table_name(),
new_dimensions as usize,
)
) )
.map_err(AppError::internal)?; .map_err(AppError::internal)?;
@@ -431,6 +457,7 @@ impl KnowledgeEntity {
/// ///
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.) /// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes. /// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
#[allow(clippy::too_many_lines)]
pub async fn update_all_embeddings_with_provider( pub async fn update_all_embeddings_with_provider(
db: &SurrealDbClient, db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider, provider: &crate::utils::embedding::EmbeddingProvider,
@@ -453,7 +480,7 @@ impl KnowledgeEntity {
info!(entities = total_entities, "Found entities to process"); info!(entities = total_entities, "Found entities to process");
// Generate all new embeddings in memory // Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new(); let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all entities..."); info!("Generating new embeddings for all entities...");
for (i, entity) in all_entities.iter().enumerate() { for (i, entity) in all_entities.iter().enumerate() {
@@ -484,7 +511,14 @@ impl KnowledgeEntity {
error!("{err_msg}"); error!("{err_msg}");
return Err(AppError::internal(err_msg)); return Err(AppError::internal(err_msg));
} }
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); new_embeddings.insert(
entity.id.clone(),
(
embedding,
entity.user_id.clone(),
entity.source_id.clone(),
),
);
} }
info!("Successfully generated all new embeddings."); info!("Successfully generated all new embeddings.");
@@ -517,7 +551,7 @@ impl KnowledgeEntity {
info!("Applying embedding updates in a transaction..."); info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| { let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}")) AppError::internal(format!("embedding serialization failed: {e}"))
})?; })?;
@@ -527,6 +561,7 @@ impl KnowledgeEntity {
entity_id = type::thing('knowledge_entity', '{id}'), \ entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \ embedding = {embedding}, \
user_id = '{user_id}', \ user_id = '{user_id}', \
source_id = '{source_id}', \
created_at = time::now(), \ created_at = time::now(), \
updated_at = time::now();", updated_at = time::now();",
) )
@@ -535,7 +570,12 @@ impl KnowledgeEntity {
write!( write!(
transaction_query, transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", "{}",
hnsw_index_overwrite_sql(
"idx_embedding_knowledge_entity_embedding",
KnowledgeEntityEmbedding::table_name(),
new_dimensions,
)
) )
.map_err(AppError::internal)?; .map_err(AppError::internal)?;
@@ -559,10 +599,21 @@ mod tests {
#![allow(clippy::expect_used, clippy::must_use_candidate)] #![allow(clippy::expect_used, clippy::must_use_candidate)]
use super::*; use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use crate::storage::types::system_settings::SystemSettings;
use anyhow::{self, Context}; use anyhow::{self, Context};
use serde_json::json; use serde_json::json;
use uuid::Uuid; use uuid::Uuid;
async fn configure_embedding_dimension(
db: &SurrealDbClient,
dimension: u32,
) -> anyhow::Result<()> {
let mut settings = SystemSettings::get_current(db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(db, settings).await?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn test_knowledge_entity_creation() -> anyhow::Result<()> { async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
let source_id = "source123".to_string(); let source_id = "source123".to_string();
@@ -656,14 +707,15 @@ mod tests {
.await .await
.with_context(|| "Failed to apply migrations".to_string())?; .with_context(|| "Failed to apply migrations".to_string())?;
let source_id = "source123".to_string(); configure_embedding_dimension(&db, 5).await?;
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
.await .await
.with_context(|| "Failed to redefine index length".to_string())?; .with_context(|| "Failed to redefine index length".to_string())?;
let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
let entity1 = KnowledgeEntity::new( let entity1 = KnowledgeEntity::new(
source_id.clone(), source_id.clone(),
"Entity 1".to_string(), "Entity 1".to_string(),
@@ -763,6 +815,7 @@ mod tests {
.await .await
.expect("Failed to apply migrations"); .expect("Failed to apply migrations");
configure_embedding_dimension(&db, 3).await.expect("configure dim");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("Failed to redefine index length"); .expect("Failed to redefine index length");
@@ -847,6 +900,7 @@ mod tests {
.await .await
.with_context(|| "Failed to apply migrations".to_string())?; .with_context(|| "Failed to apply migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.with_context(|| "Failed to redefine index length".to_string())?; .with_context(|| "Failed to redefine index length".to_string())?;
@@ -914,6 +968,7 @@ mod tests {
.await .await
.with_context(|| "Failed to apply migrations".to_string())?; .with_context(|| "Failed to apply migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.with_context(|| "Failed to redefine index length".to_string())?; .with_context(|| "Failed to redefine index length".to_string())?;
@@ -1012,6 +1067,7 @@ mod tests {
.await .await
.with_context(|| "Failed to apply migrations".to_string())?; .with_context(|| "Failed to apply migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.with_context(|| "Failed to redefine index length".to_string())?; .with_context(|| "Failed to redefine index length".to_string())?;
@@ -2,11 +2,17 @@ use std::collections::HashMap;
use surrealdb::RecordId; use surrealdb::RecordId;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use crate::{
error::AppError,
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
stored_object,
};
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", { stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
entity_id: RecordId, entity_id: RecordId,
embedding: Vec<f32>, embedding: Vec<f32>,
/// Denormalized source id for bulk deletes
source_id: String,
/// Denormalized user id for query scoping /// Denormalized user id for query scoping
user_id: String user_id: String
}); });
@@ -17,12 +23,10 @@ impl KnowledgeEntityEmbedding {
db: &SurrealDbClient, db: &SurrealDbClient,
dimension: usize, dimension: usize,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let query = format!( let query = hnsw_index_redefine_transaction_sql(
"BEGIN TRANSACTION; "idx_embedding_knowledge_entity_embedding",
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table}; Self::table_name(),
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension}; dimension,
COMMIT TRANSACTION;",
table = Self::table_name(),
); );
let res = db.client.query(query).await.map_err(AppError::Database)?; let res = db.client.query(query).await.map_err(AppError::Database)?;
@@ -31,16 +35,36 @@ impl KnowledgeEntityEmbedding {
Ok(()) Ok(())
} }
/// Create a new knowledge entity embedding /// Validates that an embedding vector matches the configured HNSW dimension.
#[allow(clippy::result_large_err)]
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
if embedding.len() != expected {
return Err(AppError::Validation(format!(
"embedding dimension mismatch: got {}, expected {expected}",
embedding.len()
)));
}
Ok(())
}
/// Create a new knowledge entity embedding.
///
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
#[must_use] #[must_use]
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self { pub fn new(
entity_id: &str,
source_id: String,
embedding: Vec<f32>,
user_id: String,
) -> Self {
let now = Utc::now(); let now = Utc::now();
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: entity_id.to_owned(),
created_at: now, created_at: now,
updated_at: now, updated_at: now,
entity_id: RecordId::from_table_key("knowledge_entity", entity_id), entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
embedding, embedding,
source_id,
user_id, user_id,
} }
} }
@@ -73,8 +97,6 @@ impl KnowledgeEntityEmbedding {
return Ok(HashMap::new()); return Ok(HashMap::new());
} }
let ids_list: Vec<RecordId> = entity_ids.to_vec();
let query = format!( let query = format!(
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids", "SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
Self::table_name() Self::table_name()
@@ -82,7 +104,7 @@ impl KnowledgeEntityEmbedding {
let mut result = db let mut result = db
.client .client
.query(query) .query(query)
.bind(("entity_ids", ids_list)) .bind(("entity_ids", entity_ids.to_vec()))
.await .await
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?; let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
@@ -106,32 +128,28 @@ impl KnowledgeEntityEmbedding {
.query(query) .query(query)
.bind(("entity_id", entity_id.clone())) .bind(("entity_id", entity_id.clone()))
.await .await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
Ok(()) Ok(())
} }
/// Delete embeddings by source_id (via joining to knowledge_entity table) /// Delete all embeddings with the given denormalized `source_id`.
pub async fn delete_by_source_id( pub async fn delete_by_source_id(
source_id: &str, source_id: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
#[derive(Deserialize)] let query = format!(
struct IdRow { "DELETE FROM {} WHERE source_id = $source_id",
id: RecordId, Self::table_name()
} );
db.client
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
let mut res = db
.client
.query(query) .query(query)
.bind(("source_id", source_id.to_owned())) .bind(("source_id", source_id.to_owned()))
.await .await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
for row in ids {
Self::delete_by_entity_id(&row.id, db).await?;
}
Ok(()) Ok(())
} }
} }
@@ -142,6 +160,7 @@ mod tests {
use super::*; use super::*;
use crate::storage::db::SurrealDbClient; use crate::storage::db::SurrealDbClient;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use crate::storage::types::system_settings::SystemSettings;
use anyhow::{self, Context}; use anyhow::{self, Context};
use chrono::Utc; use chrono::Utc;
use surrealdb::Value as SurrealValue; use surrealdb::Value as SurrealValue;
@@ -161,6 +180,24 @@ mod tests {
Ok(db) Ok(db)
} }
async fn setup_test_db_with_embedding_dimension(
dimension: u32,
) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db().await?;
let mut settings = SystemSettings::get_current(&db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(&db, settings).await?;
Ok(db)
}
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db_with_embedding_dimension(dimension).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
.await
.with_context(|| format!("set test index dimension to {dimension}"))?;
Ok(db)
}
fn build_knowledge_entity_with_id( fn build_knowledge_entity_with_id(
key: &str, key: &str,
source_id: &str, source_id: &str,
@@ -179,12 +216,27 @@ mod tests {
} }
} }
#[test]
fn new_uses_entity_id_as_record_id() {
let emb = KnowledgeEntityEmbedding::new(
"entity-abc",
"source-1".to_owned(),
vec![0.1, 0.2],
"user-1".to_owned(),
);
assert_eq!(emb.id, "entity-abc");
}
#[test]
fn validate_dimension_rejects_mismatch() {
let err = KnowledgeEntityEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2)
.expect_err("expected dimension mismatch");
assert!(matches!(err, AppError::Validation(_)));
}
#[tokio::test] #[tokio::test]
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> { async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = prepare_test_db(3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let entity_key = "entity-1"; let entity_key = "entity-1";
let source_id = "source-ke"; let source_id = "source-ke";
@@ -203,7 +255,9 @@ mod tests {
.with_context(|| "Failed to get embedding by entity_id".to_string())? .with_context(|| "Failed to get embedding by entity_id".to_string())?
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; .ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(fetched.id, entity_key);
assert_eq!(fetched.user_id, user_id); assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.source_id, source_id);
assert_eq!(fetched.entity_id, entity_rid); assert_eq!(fetched.entity_id, entity_rid);
assert_eq!(fetched.embedding, embedding_vec); assert_eq!(fetched.embedding, embedding_vec);
@@ -212,10 +266,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_delete_by_entity_id() -> anyhow::Result<()> { async fn test_delete_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = prepare_test_db(3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let entity_key = "entity-delete"; let entity_key = "entity-delete";
let source_id = "source-del"; let source_id = "source-del";
@@ -247,15 +298,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> { async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = prepare_test_db(3).await?;
let user_id = "user_store"; let user_id = "user_store";
let source_id = "source_store"; let source_id = "source_store";
let embedding = vec![0.2_f32, 0.3, 0.4]; let embedding = vec![0.2_f32, 0.3, 0.4];
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
.await
.with_context(|| "set test index dimension".to_string())?;
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id); let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db) KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
@@ -274,18 +321,30 @@ mod tests {
.with_context(|| "Failed to fetch embedding".to_string())?; .with_context(|| "Failed to fetch embedding".to_string())?;
let stored_embedding = let stored_embedding =
stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(stored_embedding.id, entity.id);
assert_eq!(stored_embedding.user_id, user_id); assert_eq!(stored_embedding.user_id, user_id);
assert_eq!(stored_embedding.source_id, source_id);
assert_eq!(stored_embedding.entity_id, entity_rid); assert_eq!(stored_embedding.entity_id, entity_rid);
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
let result =
KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
assert!(matches!(result, Err(AppError::Validation(_))));
Ok(())
}
#[tokio::test] #[tokio::test]
async fn test_delete_by_source_id() -> anyhow::Result<()> { async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await?; let db = prepare_test_db(3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let source_id = "shared-ke"; let source_id = "shared-ke";
let other_source = "other-ke"; let other_source = "other-ke";
@@ -363,6 +422,10 @@ mod tests {
idx_sql.contains("DIMENSION 16"), idx_sql.contains("DIMENSION 16"),
"expected index definition to contain new dimension, got: {idx_sql}" "expected index definition to contain new dimension, got: {idx_sql}"
); );
assert!(
idx_sql.contains("DIST COSINE"),
"expected index definition to use cosine distance, got: {idx_sql}"
);
Ok(()) Ok(())
} }
@@ -374,10 +437,7 @@ mod tests {
entity_id: KnowledgeEntity, entity_id: KnowledgeEntity,
} }
let db = setup_test_db().await?; let db = prepare_test_db(3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let entity_key = "entity-fetch"; let entity_key = "entity-fetch";
let source_id = "source-fetch"; let source_id = "source-fetch";
@@ -412,4 +472,47 @@ mod tests {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let user_id = "user-upsert";
let source_id = "source-upsert";
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
.await
.with_context(|| "initial store".to_string())?;
let replacement = KnowledgeEntityEmbedding::new(
&entity.id,
source_id.to_owned(),
vec![0.0, 1.0, 0.0],
user_id.to_owned(),
);
db.upsert_item(replacement)
.await
.with_context(|| "upsert replacement embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let rows: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {} WHERE entity_id = $entity_id",
KnowledgeEntityEmbedding::table_name()
))
.bind(("entity_id", entity_rid))
.await
.with_context(|| "count embeddings".to_string())?
.take(0)
.with_context(|| "take embeddings".to_string())?;
assert_eq!(rows.len(), 1);
let row = rows.first().expect("expected one embedding row");
assert_eq!(row.id, entity.id);
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
Ok(())
}
} }
@@ -1,4 +1,5 @@
use crate::storage::types::serde_helpers::deserialize_flexible_id; use crate::storage::types::serde_helpers::deserialize_flexible_id;
use crate::storage::types::user::User;
use crate::{error::AppError, storage::db::SurrealDbClient}; use crate::{error::AppError, storage::db::SurrealDbClient};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
@@ -40,7 +41,21 @@ impl KnowledgeRelationship {
}, },
} }
} }
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> { pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
User::get_and_validate_knowledge_entity(
&self.in_,
&self.metadata.user_id,
db_client,
)
.await?;
User::get_and_validate_knowledge_entity(
&self.out,
&self.metadata.user_id,
db_client,
)
.await?;
db_client db_client
.client .client
.query( .query(
@@ -61,22 +76,30 @@ impl KnowledgeRelationship {
.bind(("user_id", self.metadata.user_id.clone())) .bind(("user_id", self.metadata.user_id.clone()))
.bind(("source_id", self.metadata.source_id.clone())) .bind(("source_id", self.metadata.source_id.clone()))
.bind(("relationship_type", self.metadata.relationship_type.clone())) .bind(("relationship_type", self.metadata.relationship_type.clone()))
.await? .await
.check()?; .map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(()) Ok(())
} }
pub async fn delete_relationships_by_source_id( pub async fn delete_relationships_by_source_id(
source_id: &str, source_id: &str,
user_id: &str,
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
db_client db_client
.client .client
.query("DELETE FROM relates_to WHERE metadata.source_id = $source_id") .query(
"DELETE FROM relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id",
)
.bind(("source_id", source_id.to_owned())) .bind(("source_id", source_id.to_owned()))
.await? .bind(("user_id", user_id.to_owned()))
.check()?; .await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(()) Ok(())
} }
@@ -86,23 +109,30 @@ impl KnowledgeRelationship {
user_id: &str, user_id: &str,
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let mut authorized_result = db_client let mut delete_result = db_client
.client .client
.query( .query(
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id", "DELETE type::thing('relates_to', $id) WHERE metadata.user_id = $user_id RETURN BEFORE;",
) )
.bind(("id", id.to_owned())) .bind(("id", id.to_owned()))
.bind(("user_id", user_id.to_owned())) .bind(("user_id", user_id.to_owned()))
.await?; .await
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0)?; .map_err(AppError::Database)?;
let deleted: Vec<KnowledgeRelationship> =
delete_result.take(0).map_err(AppError::Database)?;
if !deleted.is_empty() {
return Ok(());
}
if authorized.is_empty() {
let mut exists_result = db_client let mut exists_result = db_client
.client .client
.query("SELECT * FROM type::thing('relates_to', $id)") .query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", id.to_owned())) .bind(("id", id.to_owned()))
.await?; .await
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?; .map_err(AppError::Database)?;
let existing: Option<KnowledgeRelationship> =
exists_result.take(0).map_err(AppError::Database)?;
if existing.is_some() { if existing.is_some() {
Err(AppError::Auth( Err(AppError::Auth(
@@ -111,15 +141,6 @@ impl KnowledgeRelationship {
} else { } else {
Err(AppError::NotFound(format!("Relationship {id} not found"))) Err(AppError::NotFound(format!("Relationship {id} not found")))
} }
} else {
db_client
.client
.query("DELETE type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?
.check()?;
Ok(())
}
} }
} }
@@ -158,11 +179,14 @@ mod tests {
result.take(0).expect("failed to take relationship by id") result.take(0).expect("failed to take relationship by id")
} }
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> { async fn create_test_entity(
name: &str,
user_id: &str,
db_client: &SurrealDbClient,
) -> anyhow::Result<String> {
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let description = format!("Description for {name}"); let description = format!("Description for {name}");
let entity_type = KnowledgeEntityType::Document; let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
let entity = KnowledgeEntity::new( let entity = KnowledgeEntity::new(
source_id, source_id,
@@ -170,7 +194,7 @@ mod tests {
description, description,
entity_type, entity_type,
None, None,
user_id, user_id.to_string(),
); );
let stored: Option<KnowledgeEntity> = db_client let stored: Option<KnowledgeEntity> = db_client
@@ -211,18 +235,18 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> { async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?; let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let user_id = "user123".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let relationship_type = "references".to_string(); let relationship_type = "references".to_string();
let relationship = KnowledgeRelationship::new( let relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity2_id.clone(), entity2_id.clone(),
user_id.clone(), user_id.to_string(),
source_id.clone(), source_id.clone(),
relationship_type, relationship_type,
); );
@@ -257,16 +281,38 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> { async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await?; let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?; let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
let relationship = KnowledgeRelationship::new(
owner_entity,
other_entity,
"owner-user".to_string(),
"source123".to_string(),
"references".to_string(),
);
let result = relationship.store_relationship(&db).await;
assert!(matches!(result, Err(AppError::Auth(_))));
Ok(())
}
#[tokio::test]
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let relationship = KnowledgeRelationship::new( let relationship = KnowledgeRelationship::new(
entity1_id, entity1_id,
entity2_id, entity2_id,
"user'123".to_string(), user_id.to_string(),
"source123'; DELETE FROM relates_to; --".to_string(), "source123'; DELETE FROM relates_to; --".to_string(),
"references'; UPDATE user SET admin = true; --".to_string(), "references'; UPDATE user SET admin = true; --".to_string(),
); );
@@ -297,18 +343,18 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_store_and_delete_relationship() -> anyhow::Result<()> { async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?; let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let user_id = "user123".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let relationship_type = "references".to_string(); let relationship_type = "references".to_string();
let relationship = KnowledgeRelationship::new( let relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity2_id.clone(), entity2_id.clone(),
user_id.clone(), user_id.to_string(),
source_id.clone(), source_id.clone(),
relationship_type, relationship_type,
); );
@@ -319,9 +365,9 @@ mod tests {
.with_context(|| "Failed to store relationship".to_string())?; .with_context(|| "Failed to store relationship".to_string())?;
let mut existing_before_delete = db let mut existing_before_delete = db
.query(format!( .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'" .bind(("user_id", user_id.to_string()))
)) .bind(("source_id", source_id.clone()))
.await .await
.with_context(|| "Query failed".to_string())?; .with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> = let before_results: Vec<KnowledgeRelationship> =
@@ -331,14 +377,14 @@ mod tests {
"Relationship should exist before deletion" "Relationship should exist before deletion"
); );
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db) KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db)
.await .await
.with_context(|| "Failed to delete relationship by ID".to_string())?; .with_context(|| "Failed to delete relationship by ID".to_string())?;
let mut result = db let mut result = db
.query(format!( .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'" .bind(("user_id", user_id.to_string()))
)) .bind(("source_id", source_id))
.await .await
.with_context(|| "Query failed".to_string())?; .with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default(); let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
@@ -351,17 +397,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> { async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let owner_user_id = "owner-user";
let entity1_id = create_test_entity("Entity 1", &db).await?; let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?; let entity2_id = create_test_entity("Entity 2", owner_user_id, &db).await?;
let owner_user_id = "owner-user".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let relationship = KnowledgeRelationship::new( let relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity2_id.clone(), entity2_id.clone(),
owner_user_id.clone(), owner_user_id.to_string(),
source_id, source_id,
"references".to_string(), "references".to_string(),
); );
@@ -372,9 +418,8 @@ mod tests {
.with_context(|| "Failed to store relationship".to_string())?; .with_context(|| "Failed to store relationship".to_string())?;
let mut before_attempt = db let mut before_attempt = db
.query(format!( .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'" .bind(("user_id", owner_user_id.to_string()))
))
.await .await
.with_context(|| "Query failed".to_string())?; .with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default(); let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
@@ -398,9 +443,8 @@ mod tests {
} }
let mut after_attempt = db let mut after_attempt = db
.query(format!( .query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'" .bind(("user_id", owner_user_id.to_string()))
))
.await .await
.with_context(|| "Query failed".to_string())?; .with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default(); let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
@@ -416,19 +460,19 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_store_relationship_exists() -> anyhow::Result<()> { async fn test_store_relationship_exists() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?; let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await?; let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
let user_id = "user123".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let different_source_id = "different_source".to_string(); let different_source_id = "different_source".to_string();
let relationship1 = KnowledgeRelationship::new( let relationship1 = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity2_id.clone(), entity2_id.clone(),
user_id.clone(), user_id.to_string(),
source_id.clone(), source_id.clone(),
"references".to_string(), "references".to_string(),
); );
@@ -436,7 +480,7 @@ mod tests {
let relationship2 = KnowledgeRelationship::new( let relationship2 = KnowledgeRelationship::new(
entity2_id.clone(), entity2_id.clone(),
entity3_id.clone(), entity3_id.clone(),
user_id.clone(), user_id.to_string(),
source_id.clone(), source_id.clone(),
"contains".to_string(), "contains".to_string(),
); );
@@ -444,7 +488,7 @@ mod tests {
let different_relationship = KnowledgeRelationship::new( let different_relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity3_id.clone(), entity3_id.clone(),
user_id.clone(), user_id.to_string(),
different_source_id.clone(), different_source_id.clone(),
"mentions".to_string(), "mentions".to_string(),
); );
@@ -480,7 +524,7 @@ mod tests {
before_delete_different.take(0).unwrap_or_default(); before_delete_different.take(0).unwrap_or_default();
assert_eq!(before_delete_different_rows.len(), 1); assert_eq!(before_delete_different_rows.len(), 1);
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db) KnowledgeRelationship::delete_relationships_by_source_id(&source_id, user_id, &db)
.await .await
.with_context(|| "Failed to delete relationships by source_id".to_string())?; .with_context(|| "Failed to delete relationships by source_id".to_string())?;
@@ -497,19 +541,60 @@ mod tests {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_a = "user-a";
let user_b = "user-b";
let shared_source = "shared-source";
let a1 = create_test_entity("A1", user_a, &db).await?;
let a2 = create_test_entity("A2", user_a, &db).await?;
let b1 = create_test_entity("B1", user_b, &db).await?;
let b2 = create_test_entity("B2", user_b, &db).await?;
let rel_a = KnowledgeRelationship::new(
a1,
a2,
user_a.to_string(),
shared_source.to_string(),
"references".to_string(),
);
let rel_b = KnowledgeRelationship::new(
b1,
b2,
user_b.to_string(),
shared_source.to_string(),
"references".to_string(),
);
rel_a.store_relationship(&db).await?;
rel_b.store_relationship(&db).await?;
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
.await?;
assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none());
assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some());
Ok(())
}
#[tokio::test] #[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
{ {
let db = setup_test_db().await; let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?; let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await?; let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
let safe_relationship = KnowledgeRelationship::new( let safe_relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity2_id.clone(), entity2_id.clone(),
"user123".to_string(), user_id.to_string(),
"safe_source".to_string(), "safe_source".to_string(),
"references".to_string(), "references".to_string(),
); );
@@ -517,7 +602,7 @@ mod tests {
let other_relationship = KnowledgeRelationship::new( let other_relationship = KnowledgeRelationship::new(
entity2_id, entity2_id,
entity3_id, entity3_id,
"user123".to_string(), user_id.to_string(),
"other_source".to_string(), "other_source".to_string(),
"contains".to_string(), "contains".to_string(),
); );
@@ -531,7 +616,11 @@ mod tests {
.await .await
.expect("store other relationship"); .expect("store other relationship");
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db) KnowledgeRelationship::delete_relationships_by_source_id(
"safe_source' OR 1=1 --",
user_id,
&db,
)
.await .await
.expect("delete call should succeed"); .expect("delete call should succeed");
+1 -1
View File
@@ -62,7 +62,7 @@ impl fmt::Display for Message {
pub fn format_history(history: &[Message]) -> String { pub fn format_history(history: &[Message]) -> String {
let estimated: usize = history let estimated: usize = history
.iter() .iter()
.map(|m| m.content.len() + 10) .map(|m| m.content.len().saturating_add(10))
.sum(); .sum();
let mut out = String::with_capacity(estimated); let mut out = String::with_capacity(estimated);
for (i, msg) in history.iter().enumerate() { for (i, msg) in history.iter().enumerate() {
+1
View File
@@ -308,6 +308,7 @@ fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches>
entities.push(entity.clone()); entities.push(entity.clone());
entity_embeddings.push(KnowledgeEntityEmbedding::new( entity_embeddings.push(KnowledgeEntityEmbedding::new(
&entity.id, &entity.id,
entity.source_id.clone(),
embedded_entity.embedding.clone(), embedded_entity.embedding.clone(),
entity.user_id.clone(), entity.user_id.clone(),
)); ));
+1 -1
View File
@@ -86,7 +86,7 @@ pub async fn delete_text_content(
// Delete the text content and any related data // Delete the text content and any related data
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?; TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?; KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db).await?; KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db).await?;
state state
.db .db
.delete_item::<TextContent>(&text_content.id) .delete_item::<TextContent>(&text_content.id)
+1 -1
View File
@@ -293,7 +293,7 @@ async fn reserve_task(
payload: IngestionPayload, payload: IngestionPayload,
user_id: &str, user_id: &str,
) -> anyhow::Result<IngestionTask> { ) -> anyhow::Result<IngestionTask> {
let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db).await?; let task = IngestionTask::create_and_add_to_db(payload, user_id, db).await?;
let lease = task.lease_duration(); let lease = task.lease_duration();
let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease) let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
.await? .await?