chore: harden knowledge graph storage and clear common clippy warnings

Enforce stable 1:1 entity embeddings, relationship endpoint auth, and
user-scoped deletes; align schemas/migrations and resolve common crate
clippy findings.
This commit is contained in:
Per Stark
2026-05-28 21:46:35 +02:00
parent 189adb1a5f
commit 5724f11dc1
17 changed files with 533 additions and 209 deletions
@@ -0,0 +1,33 @@
-- Harden knowledge entity embeddings and graph storage invariants.
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
-- Backfill denormalized source_id from the linked entity.
FOR $emb IN (SELECT * FROM knowledge_entity_embedding WHERE source_id = NONE OR source_id = '') {
LET $entity = (SELECT source_id FROM $emb.entity_id)[0];
IF $entity != NONE {
UPDATE $emb.id SET source_id = $entity.source_id;
}
};
-- Re-key embeddings so record id matches entity id (stable 1:1 identity).
FOR $emb IN (SELECT * FROM knowledge_entity_embedding) {
LET $entity_key = record::id($emb.entity_id);
LET $canonical = type::thing('knowledge_entity_embedding', $entity_key);
IF $emb.id != $canonical {
UPSERT $canonical CONTENT {
entity_id: $emb.entity_id,
embedding: $emb.embedding,
user_id: $emb.user_id,
source_id: $emb.source_id,
created_at: $emb.created_at,
updated_at: $emb.updated_at
};
DELETE $emb.id;
}
};
REMOVE INDEX IF EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -68,7 +68,7 @@\n\n # Defines the schema for the 'knowledge_entity' table.\n\n-DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n@@ -90,6 +90,7 @@\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n@@ -102,6 +103,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n+DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;\n\n -- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;\n@@ -109,8 +111,9 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;\n\n # Defines the schema for the 'message' table.\n\n@@ -135,19 +138,17 @@\n # Defines the 'relates_to' edge table for KnowledgeRelationships.\n # Edges connect nodes, in this case knowledge_entity records.\n\n-# Define the edge table itself, enforcing connections between knowledge_entity records\n-# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\n-DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+\n+DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;\n+DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;\n\n-# Define the metadata field within the edge\n # RelationshipMetadata is a struct, store as object\n DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n+DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n-# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n-# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n-\n # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n","events":null}
+2 -1
View File
@@ -1,6 +1,6 @@
# Defines the schema for the 'knowledge_entity' table.
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;
# Standard fields
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
@@ -22,5 +22,6 @@ DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
@@ -7,6 +7,7 @@ DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
-- Custom fields
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
@@ -14,5 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<fl
-- Indexes
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
+7 -9
View File
@@ -1,18 +1,16 @@
# Defines the 'relates_to' edge table for KnowledgeRelationships.
# Edges connect nodes, in this case knowledge_entity records.
# Define the edge table itself, enforcing connections between knowledge_entity records
# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary
DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;
DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
# Define the metadata field within the edge
# RelationshipMetadata is a struct, store as object
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)
# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
+27 -2
View File
@@ -11,6 +11,31 @@ use crate::{error::AppError, storage::db::SurrealDbClient};
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
/// HNSW index options used by runtime index creation (includes CONCURRENTLY).
pub const HNSW_INDEX_OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY";
/// HNSW index options for use inside transactions (CONCURRENTLY not supported).
pub const HNSW_INDEX_OPTIONS_SYNC: &str = "DIST COSINE TYPE F32 EFC 100 M 8";
/// Builds a `DEFINE INDEX OVERWRITE ... HNSW` statement matching runtime index options.
#[must_use]
pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize) -> String {
format!(
"DEFINE INDEX OVERWRITE {index_name} ON TABLE {table} \
FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS};"
)
}
/// Recreates an HNSW index inside a transaction (for tests and dimension migrations).
#[must_use]
pub fn hnsw_index_redefine_transaction_sql(index_name: &str, table: &str, dimension: usize) -> String {
format!(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};
DEFINE INDEX {index_name} ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS_SYNC};
COMMIT TRANSACTION;"
)
}
#[derive(Clone, Copy)]
struct HnswIndexSpec {
index_name: &'static str,
@@ -23,12 +48,12 @@ const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
options: HNSW_INDEX_OPTIONS,
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
options: HNSW_INDEX_OPTIONS,
},
]
}
+20 -7
View File
@@ -518,8 +518,8 @@ mod tests {
#[test]
fn test_sidebar_conversation_deserializes_plain_string_id() {
let item: SidebarConversation =
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#).unwrap();
let item: SidebarConversation = serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
.expect("valid sidebar conversation json");
assert_eq!(item.id, "conv-plain");
assert_eq!(item.title, "My chat");
}
@@ -543,8 +543,9 @@ mod tests {
.await
.expect("Failed to load sidebar");
assert_eq!(items.len(), 1);
assert_eq!(items[0].id, expected_id);
assert_eq!(items[0].title, "Sidebar title");
let item = items.first().expect("expected one sidebar item");
assert_eq!(item.id, expected_id);
assert_eq!(item.title, "Sidebar title");
}
#[tokio::test]
@@ -570,7 +571,13 @@ mod tests {
let owner_messages =
fetch_messages_for_owner(&db, &conversation_id, owner).await?;
assert_eq!(owner_messages.len(), 1);
assert_eq!(owner_messages[0].content, "secret message");
assert_eq!(
owner_messages
.first()
.expect("expected owner message")
.content,
"secret message"
);
let intruder_messages =
fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
@@ -617,8 +624,14 @@ mod tests {
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
assert_eq!(messages.len(), 2);
assert_eq!(messages[0].content, "first");
assert_eq!(messages[1].content, "second");
assert_eq!(
messages.first().expect("expected first message").content,
"first"
);
assert_eq!(
messages.get(1).expect("expected second message").content,
"second"
);
Ok(())
}
+4 -6
View File
@@ -90,9 +90,9 @@ impl FileInfo {
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in
/// both the stem and extension.
fn sanitize_file_name(file_name: &str) -> String {
if let Some(idx) = file_name.rfind('.') {
let name = Self::sanitize_name_segment(&file_name[..idx]);
let ext = Self::sanitize_name_segment(&file_name[idx + 1..]);
if let Some((stem, ext)) = file_name.rsplit_once('.') {
let name = Self::sanitize_name_segment(stem);
let ext = Self::sanitize_name_segment(ext);
if ext.is_empty() {
name
} else {
@@ -321,7 +321,6 @@ mod tests {
use anyhow::{self, Context};
use super::*;
use crate::error::AppError;
use crate::storage::store::testing::TestStorageManager;
use axum::http::HeaderMap;
use axum_typed_multipart::FieldMetadata;
@@ -844,8 +843,7 @@ mod tests {
assert!(file_info.sha256.len() == 64);
let bytes = file_info
.get_content_with_storage(test_storage.storage())
.await
.map_err(AppError::from)?;
.await?;
assert!(bytes.is_empty());
Ok(())
@@ -80,7 +80,7 @@ impl IngestionPayload {
});
for (index, file) in files.into_iter().enumerate() {
let is_last_file = index + 1 == file_count;
let is_last_file = index.saturating_add(1) == file_count;
if content_follows || !is_last_file {
let Some(shared) = fields.as_ref() else {
return Err(AppError::internal("shared ingest fields consumed early"));
@@ -411,7 +411,9 @@ mod tests {
)?;
assert_eq!(result.len(), 2);
match (&result[0], &result[1]) {
let first = result.first().expect("expected first payload");
let second = result.get(1).expect("expected second payload");
match (first, second) {
(
IngestionPayload::File {
file_info: payload_file,
@@ -499,8 +501,8 @@ mod tests {
)?;
assert_eq!(result.len(), 2);
assert!(matches!(result[0], IngestionPayload::File { .. }));
assert!(matches!(result[1], IngestionPayload::File { .. }));
assert!(matches!(result.first(), Some(IngestionPayload::File { .. })));
assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
Ok(())
}
}
+8 -6
View File
@@ -150,8 +150,8 @@ fn invalid_transition(state: TaskState, event: TaskTransition) -> AppError {
))
}
fn worker_id_for_bind(worker_id: &Option<String>) -> String {
worker_id.as_deref().unwrap_or("").to_string()
fn worker_id_for_bind(worker_id: Option<&String>) -> String {
worker_id.cloned().unwrap_or_default()
}
stored_object!(IngestionTask, "ingestion_task", {
@@ -360,7 +360,7 @@ impl IngestionTask {
"#;
let now = chrono::Utc::now();
let worker_id = worker_id_for_bind(&self.worker_id);
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
let mut result = db
.client
.query(START_PROCESSING_QUERY)
@@ -398,7 +398,7 @@ impl IngestionTask {
"#;
let now = chrono::Utc::now();
let worker_id = worker_id_for_bind(&self.worker_id);
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
let mut result = db
.client
.query(COMPLETE_QUERY)
@@ -450,7 +450,7 @@ impl IngestionTask {
)
.unwrap_or(now);
let worker_id = worker_id_for_bind(&self.worker_id);
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
let mut result = db
.client
.query(FAIL_QUERY)
@@ -680,7 +680,9 @@ mod tests {
IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
assert_eq!(created.len(), 2);
assert_ne!(created[0].id, created[1].id);
let first = created.first().expect("expected first task");
let second = created.get(1).expect("expected second task");
assert_ne!(first.id, second.id);
for task in &created {
let stored: Option<IngestionTask> = db.get_item::<IngestionTask>(&task.id).await?;
+100 -44
View File
@@ -7,7 +7,9 @@ use std::fmt::Write;
use crate::{
error::AppError, storage::db::SurrealDbClient,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object,
storage::indexes::hnsw_index_overwrite_sql,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
storage::types::system_settings::SystemSettings, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
@@ -189,13 +191,25 @@ impl KnowledgeEntity {
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone());
let settings = SystemSettings::get_current(db).await?;
KnowledgeEntityEmbedding::validate_dimension(
&embedding,
settings.embedding_dimensions as usize,
)?;
let entity_id = entity.id.clone();
let emb = KnowledgeEntityEmbedding::new(
&entity_id,
entity.source_id.clone(),
embedding,
entity.user_id.clone(),
);
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
COMMIT TRANSACTION;
",
entity_table = Self::table_name(),
@@ -204,9 +218,8 @@ impl KnowledgeEntity {
db.client
.query(query)
.bind(("entity_id", entity.id.clone()))
.bind(("entity_id", entity_id))
.bind(("entity", entity))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?
@@ -275,12 +288,29 @@ impl KnowledgeEntity {
db_client: &SurrealDbClient,
ai_client: &Client<OpenAIConfig>,
) -> Result<(), AppError> {
let embedding_input = format!(
"name: {name}, description: {description}, type: {entity_type:?}",
);
let embedding_input = format!(
"name: {name}, description: {description}, type: {entity_type:?}",
);
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
let user_id = Self::get_user_id_by_id(id, db_client).await?;
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id);
let entity: KnowledgeEntity = db_client
.get_item(id)
.await
.map_err(AppError::Database)?
.ok_or_else(|| AppError::NotFound(format!("entity {id} not found")))?;
let settings = SystemSettings::get_current(db_client).await?;
KnowledgeEntityEmbedding::validate_dimension(
&embedding,
settings.embedding_dimensions as usize,
)?;
let emb = KnowledgeEntityEmbedding::new(
id,
entity.source_id,
embedding,
entity.user_id,
);
let now = Utc::now();
@@ -293,7 +323,7 @@ impl KnowledgeEntity {
description = $description,
updated_at = $updated_at,
entity_type = $entity_type;
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb;
UPSERT type::thing($emb_table, $id) CONTENT $emb;
COMMIT TRANSACTION;",
)
.bind(("table", Self::table_name()))
@@ -302,33 +332,16 @@ impl KnowledgeEntity {
.bind(("name", name.to_string()))
.bind(("updated_at", surrealdb::Datetime::from(now)))
.bind(("entity_type", entity_type.to_owned()))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.bind(("description", description.to_string()))
.await?;
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let mut response = db_client
.client
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
.bind(("table", Self::table_name()))
.bind(("id", id.to_string()))
.await
.map_err(AppError::Database)?;
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.first()
.map(|r| r.user_id.clone())
.ok_or_else(|| AppError::internal("user not found for entity"))
}
/// Re-creates embeddings for all knowledge entities in the database.
///
/// This is a costly operation that should be run in the background. It follows the same
@@ -359,7 +372,7 @@ impl KnowledgeEntity {
info!("Found {total_entities} entities to process.");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in &all_entities {
let embedding_input = format!(
@@ -387,7 +400,14 @@ impl KnowledgeEntity {
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
new_embeddings.insert(
entity.id.clone(),
(
embedding,
entity.user_id.clone(),
entity.source_id.clone(),
),
);
}
info!("Successfully generated all new embeddings.");
@@ -396,7 +416,7 @@ impl KnowledgeEntity {
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements to the embedding table
for (id, (embedding, user_id)) in new_embeddings {
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}"))
})?;
@@ -406,6 +426,7 @@ impl KnowledgeEntity {
entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \
user_id = '{user_id}', \
source_id = '{source_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
)
@@ -414,7 +435,12 @@ impl KnowledgeEntity {
write!(
transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
"{}",
hnsw_index_overwrite_sql(
"idx_embedding_knowledge_entity_embedding",
KnowledgeEntityEmbedding::table_name(),
new_dimensions as usize,
)
)
.map_err(AppError::internal)?;
@@ -431,6 +457,7 @@ impl KnowledgeEntity {
///
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
#[allow(clippy::too_many_lines)]
pub async fn update_all_embeddings_with_provider(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
@@ -453,7 +480,7 @@ impl KnowledgeEntity {
info!(entities = total_entities, "Found entities to process");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all entities...");
for (i, entity) in all_entities.iter().enumerate() {
@@ -484,7 +511,14 @@ impl KnowledgeEntity {
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
new_embeddings.insert(
entity.id.clone(),
(
embedding,
entity.user_id.clone(),
entity.source_id.clone(),
),
);
}
info!("Successfully generated all new embeddings.");
@@ -517,7 +551,7 @@ impl KnowledgeEntity {
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id)) in new_embeddings {
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}"))
})?;
@@ -527,6 +561,7 @@ impl KnowledgeEntity {
entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \
user_id = '{user_id}', \
source_id = '{source_id}', \
created_at = time::now(), \
updated_at = time::now();",
)
@@ -535,7 +570,12 @@ impl KnowledgeEntity {
write!(
transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
"{}",
hnsw_index_overwrite_sql(
"idx_embedding_knowledge_entity_embedding",
KnowledgeEntityEmbedding::table_name(),
new_dimensions,
)
)
.map_err(AppError::internal)?;
@@ -559,10 +599,21 @@ mod tests {
#![allow(clippy::expect_used, clippy::must_use_candidate)]
use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use crate::storage::types::system_settings::SystemSettings;
use anyhow::{self, Context};
use serde_json::json;
use uuid::Uuid;
async fn configure_embedding_dimension(
db: &SurrealDbClient,
dimension: u32,
) -> anyhow::Result<()> {
let mut settings = SystemSettings::get_current(db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(db, settings).await?;
Ok(())
}
#[tokio::test]
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
let source_id = "source123".to_string();
@@ -656,14 +707,15 @@ mod tests {
.await
.with_context(|| "Failed to apply migrations".to_string())?;
let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
configure_embedding_dimension(&db, 5).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
.await
.with_context(|| "Failed to redefine index length".to_string())?;
let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
let entity1 = KnowledgeEntity::new(
source_id.clone(),
"Entity 1".to_string(),
@@ -763,6 +815,7 @@ mod tests {
.await
.expect("Failed to apply migrations");
configure_embedding_dimension(&db, 3).await.expect("configure dim");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
@@ -847,6 +900,7 @@ mod tests {
.await
.with_context(|| "Failed to apply migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "Failed to redefine index length".to_string())?;
@@ -914,6 +968,7 @@ mod tests {
.await
.with_context(|| "Failed to apply migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "Failed to redefine index length".to_string())?;
@@ -1012,6 +1067,7 @@ mod tests {
.await
.with_context(|| "Failed to apply migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "Failed to redefine index length".to_string())?;
@@ -2,11 +2,17 @@ use std::collections::HashMap;
use surrealdb::RecordId;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use crate::{
error::AppError,
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
stored_object,
};
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
entity_id: RecordId,
embedding: Vec<f32>,
/// Denormalized source id for bulk deletes
source_id: String,
/// Denormalized user id for query scoping
user_id: String
});
@@ -17,12 +23,10 @@ impl KnowledgeEntityEmbedding {
db: &SurrealDbClient,
dimension: usize,
) -> Result<(), AppError> {
let query = format!(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
COMMIT TRANSACTION;",
table = Self::table_name(),
let query = hnsw_index_redefine_transaction_sql(
"idx_embedding_knowledge_entity_embedding",
Self::table_name(),
dimension,
);
let res = db.client.query(query).await.map_err(AppError::Database)?;
@@ -31,16 +35,36 @@ impl KnowledgeEntityEmbedding {
Ok(())
}
/// Create a new knowledge entity embedding
/// Validates that an embedding vector matches the configured HNSW dimension.
#[allow(clippy::result_large_err)]
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
if embedding.len() != expected {
return Err(AppError::Validation(format!(
"embedding dimension mismatch: got {}, expected {expected}",
embedding.len()
)));
}
Ok(())
}
/// Create a new knowledge entity embedding.
///
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
#[must_use]
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
pub fn new(
entity_id: &str,
source_id: String,
embedding: Vec<f32>,
user_id: String,
) -> Self {
let now = Utc::now();
Self {
id: uuid::Uuid::new_v4().to_string(),
id: entity_id.to_owned(),
created_at: now,
updated_at: now,
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
embedding,
source_id,
user_id,
}
}
@@ -73,8 +97,6 @@ impl KnowledgeEntityEmbedding {
return Ok(HashMap::new());
}
let ids_list: Vec<RecordId> = entity_ids.to_vec();
let query = format!(
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
Self::table_name()
@@ -82,7 +104,7 @@ impl KnowledgeEntityEmbedding {
let mut result = db
.client
.query(query)
.bind(("entity_ids", ids_list))
.bind(("entity_ids", entity_ids.to_vec()))
.await
.map_err(AppError::Database)?;
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
@@ -106,32 +128,28 @@ impl KnowledgeEntityEmbedding {
.query(query)
.bind(("entity_id", entity_id.clone()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
/// Delete embeddings by source_id (via joining to knowledge_entity table)
/// Delete all embeddings with the given denormalized `source_id`.
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
let mut res = db
.client
let query = format!(
"DELETE FROM {} WHERE source_id = $source_id",
Self::table_name()
);
db.client
.query(query)
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
for row in ids {
Self::delete_by_entity_id(&row.id, db).await?;
}
Ok(())
}
}
@@ -142,6 +160,7 @@ mod tests {
use super::*;
use crate::storage::db::SurrealDbClient;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use crate::storage::types::system_settings::SystemSettings;
use anyhow::{self, Context};
use chrono::Utc;
use surrealdb::Value as SurrealValue;
@@ -161,6 +180,24 @@ mod tests {
Ok(db)
}
async fn setup_test_db_with_embedding_dimension(
dimension: u32,
) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db().await?;
let mut settings = SystemSettings::get_current(&db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(&db, settings).await?;
Ok(db)
}
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db_with_embedding_dimension(dimension).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
.await
.with_context(|| format!("set test index dimension to {dimension}"))?;
Ok(db)
}
fn build_knowledge_entity_with_id(
key: &str,
source_id: &str,
@@ -179,12 +216,27 @@ mod tests {
}
}
#[test]
fn new_uses_entity_id_as_record_id() {
let emb = KnowledgeEntityEmbedding::new(
"entity-abc",
"source-1".to_owned(),
vec![0.1, 0.2],
"user-1".to_owned(),
);
assert_eq!(emb.id, "entity-abc");
}
#[test]
fn validate_dimension_rejects_mismatch() {
let err = KnowledgeEntityEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2)
.expect_err("expected dimension mismatch");
assert!(matches!(err, AppError::Validation(_)));
}
#[tokio::test]
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let db = prepare_test_db(3).await?;
let user_id = "user_ke";
let entity_key = "entity-1";
let source_id = "source-ke";
@@ -203,7 +255,9 @@ mod tests {
.with_context(|| "Failed to get embedding by entity_id".to_string())?
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(fetched.id, entity_key);
assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.source_id, source_id);
assert_eq!(fetched.entity_id, entity_rid);
assert_eq!(fetched.embedding, embedding_vec);
@@ -212,10 +266,7 @@ mod tests {
#[tokio::test]
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let db = prepare_test_db(3).await?;
let user_id = "user_ke";
let entity_key = "entity-delete";
let source_id = "source-del";
@@ -247,15 +298,11 @@ mod tests {
#[tokio::test]
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let db = prepare_test_db(3).await?;
let user_id = "user_store";
let source_id = "source_store";
let embedding = vec![0.2_f32, 0.3, 0.4];
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
.await
.with_context(|| "set test index dimension".to_string())?;
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
@@ -274,18 +321,30 @@ mod tests {
.with_context(|| "Failed to fetch embedding".to_string())?;
let stored_embedding =
stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(stored_embedding.id, entity.id);
assert_eq!(stored_embedding.user_id, user_id);
assert_eq!(stored_embedding.source_id, source_id);
assert_eq!(stored_embedding.entity_id, entity_rid);
Ok(())
}
#[tokio::test]
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
let result =
KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
assert!(matches!(result, Err(AppError::Validation(_))));
Ok(())
}
#[tokio::test]
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let db = prepare_test_db(3).await?;
let user_id = "user_ke";
let source_id = "shared-ke";
let other_source = "other-ke";
@@ -363,6 +422,10 @@ mod tests {
idx_sql.contains("DIMENSION 16"),
"expected index definition to contain new dimension, got: {idx_sql}"
);
assert!(
idx_sql.contains("DIST COSINE"),
"expected index definition to use cosine distance, got: {idx_sql}"
);
Ok(())
}
@@ -374,10 +437,7 @@ mod tests {
entity_id: KnowledgeEntity,
}
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.with_context(|| "set test index dimension".to_string())?;
let db = prepare_test_db(3).await?;
let user_id = "user_ke";
let entity_key = "entity-fetch";
let source_id = "source-fetch";
@@ -412,4 +472,47 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let user_id = "user-upsert";
let source_id = "source-upsert";
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
.await
.with_context(|| "initial store".to_string())?;
let replacement = KnowledgeEntityEmbedding::new(
&entity.id,
source_id.to_owned(),
vec![0.0, 1.0, 0.0],
user_id.to_owned(),
);
db.upsert_item(replacement)
.await
.with_context(|| "upsert replacement embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let rows: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {} WHERE entity_id = $entity_id",
KnowledgeEntityEmbedding::table_name()
))
.bind(("entity_id", entity_rid))
.await
.with_context(|| "count embeddings".to_string())?
.take(0)
.with_context(|| "take embeddings".to_string())?;
assert_eq!(rows.len(), 1);
let row = rows.first().expect("expected one embedding row");
assert_eq!(row.id, entity.id);
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
Ok(())
}
}
@@ -1,4 +1,5 @@
use crate::storage::types::serde_helpers::deserialize_flexible_id;
use crate::storage::types::user::User;
use crate::{error::AppError, storage::db::SurrealDbClient};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
@@ -40,7 +41,21 @@ impl KnowledgeRelationship {
},
}
}
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
User::get_and_validate_knowledge_entity(
&self.in_,
&self.metadata.user_id,
db_client,
)
.await?;
User::get_and_validate_knowledge_entity(
&self.out,
&self.metadata.user_id,
db_client,
)
.await?;
db_client
.client
.query(
@@ -61,22 +76,30 @@ impl KnowledgeRelationship {
.bind(("user_id", self.metadata.user_id.clone()))
.bind(("source_id", self.metadata.source_id.clone()))
.bind(("relationship_type", self.metadata.relationship_type.clone()))
.await?
.check()?;
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
pub async fn delete_relationships_by_source_id(
source_id: &str,
user_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
db_client
.client
.query("DELETE FROM relates_to WHERE metadata.source_id = $source_id")
.query(
"DELETE FROM relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id",
)
.bind(("source_id", source_id.to_owned()))
.await?
.check()?;
.bind(("user_id", user_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
@@ -86,39 +109,37 @@ impl KnowledgeRelationship {
user_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let mut authorized_result = db_client
let mut delete_result = db_client
.client
.query(
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
"DELETE type::thing('relates_to', $id) WHERE metadata.user_id = $user_id RETURN BEFORE;",
)
.bind(("id", id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0)?;
.await
.map_err(AppError::Database)?;
let deleted: Vec<KnowledgeRelationship> =
delete_result.take(0).map_err(AppError::Database)?;
if authorized.is_empty() {
let mut exists_result = db_client
.client
.query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
if !deleted.is_empty() {
return Ok(());
}
if existing.is_some() {
Err(AppError::Auth(
"Not authorized to delete relationship".into(),
))
} else {
Err(AppError::NotFound(format!("Relationship {id} not found")))
}
let mut exists_result = db_client
.client
.query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await
.map_err(AppError::Database)?;
let existing: Option<KnowledgeRelationship> =
exists_result.take(0).map_err(AppError::Database)?;
if existing.is_some() {
Err(AppError::Auth(
"Not authorized to delete relationship".into(),
))
} else {
db_client
.client
.query("DELETE type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?
.check()?;
Ok(())
Err(AppError::NotFound(format!("Relationship {id} not found")))
}
}
}
@@ -158,11 +179,14 @@ mod tests {
result.take(0).expect("failed to take relationship by id")
}
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> {
async fn create_test_entity(
name: &str,
user_id: &str,
db_client: &SurrealDbClient,
) -> anyhow::Result<String> {
let source_id = "source123".to_string();
let description = format!("Description for {name}");
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
let entity = KnowledgeEntity::new(
source_id,
@@ -170,7 +194,7 @@ mod tests {
description,
entity_type,
None,
user_id,
user_id.to_string(),
);
let stored: Option<KnowledgeEntity> = db_client
@@ -211,18 +235,18 @@ mod tests {
#[tokio::test]
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let user_id = "user123".to_string();
let source_id = "source123".to_string();
let relationship_type = "references".to_string();
let relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
user_id.clone(),
user_id.to_string(),
source_id.clone(),
relationship_type,
);
@@ -257,16 +281,38 @@ mod tests {
}
#[tokio::test]
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
let relationship = KnowledgeRelationship::new(
owner_entity,
other_entity,
"owner-user".to_string(),
"source123".to_string(),
"references".to_string(),
);
let result = relationship.store_relationship(&db).await;
assert!(matches!(result, Err(AppError::Auth(_))));
Ok(())
}
#[tokio::test]
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let relationship = KnowledgeRelationship::new(
entity1_id,
entity2_id,
"user'123".to_string(),
user_id.to_string(),
"source123'; DELETE FROM relates_to; --".to_string(),
"references'; UPDATE user SET admin = true; --".to_string(),
);
@@ -297,18 +343,18 @@ mod tests {
#[tokio::test]
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let user_id = "user123".to_string();
let source_id = "source123".to_string();
let relationship_type = "references".to_string();
let relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
user_id.clone(),
user_id.to_string(),
source_id.clone(),
relationship_type,
);
@@ -319,9 +365,9 @@ mod tests {
.with_context(|| "Failed to store relationship".to_string())?;
let mut existing_before_delete = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
))
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
.bind(("user_id", user_id.to_string()))
.bind(("source_id", source_id.clone()))
.await
.with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> =
@@ -331,14 +377,14 @@ mod tests {
"Relationship should exist before deletion"
);
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db)
.await
.with_context(|| "Failed to delete relationship by ID".to_string())?;
let mut result = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
))
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
.bind(("user_id", user_id.to_string()))
.bind(("source_id", source_id))
.await
.with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
@@ -351,17 +397,17 @@ mod tests {
#[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
let db = setup_test_db().await;
let owner_user_id = "owner-user";
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", owner_user_id, &db).await?;
let owner_user_id = "owner-user".to_string();
let source_id = "source123".to_string();
let relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
owner_user_id.clone(),
owner_user_id.to_string(),
source_id,
"references".to_string(),
);
@@ -372,9 +418,8 @@ mod tests {
.with_context(|| "Failed to store relationship".to_string())?;
let mut before_attempt = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
))
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
.bind(("user_id", owner_user_id.to_string()))
.await
.with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
@@ -398,9 +443,8 @@ mod tests {
}
let mut after_attempt = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
))
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
.bind(("user_id", owner_user_id.to_string()))
.await
.with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
@@ -416,19 +460,19 @@ mod tests {
#[tokio::test]
async fn test_store_relationship_exists() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await?;
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
let user_id = "user123".to_string();
let source_id = "source123".to_string();
let different_source_id = "different_source".to_string();
let relationship1 = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
user_id.clone(),
user_id.to_string(),
source_id.clone(),
"references".to_string(),
);
@@ -436,7 +480,7 @@ mod tests {
let relationship2 = KnowledgeRelationship::new(
entity2_id.clone(),
entity3_id.clone(),
user_id.clone(),
user_id.to_string(),
source_id.clone(),
"contains".to_string(),
);
@@ -444,7 +488,7 @@ mod tests {
let different_relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity3_id.clone(),
user_id.clone(),
user_id.to_string(),
different_source_id.clone(),
"mentions".to_string(),
);
@@ -480,7 +524,7 @@ mod tests {
before_delete_different.take(0).unwrap_or_default();
assert_eq!(before_delete_different_rows.len(), 1);
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, user_id, &db)
.await
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
@@ -497,19 +541,60 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
let db = setup_test_db().await;
let user_a = "user-a";
let user_b = "user-b";
let shared_source = "shared-source";
let a1 = create_test_entity("A1", user_a, &db).await?;
let a2 = create_test_entity("A2", user_a, &db).await?;
let b1 = create_test_entity("B1", user_b, &db).await?;
let b2 = create_test_entity("B2", user_b, &db).await?;
let rel_a = KnowledgeRelationship::new(
a1,
a2,
user_a.to_string(),
shared_source.to_string(),
"references".to_string(),
);
let rel_b = KnowledgeRelationship::new(
b1,
b2,
user_b.to_string(),
shared_source.to_string(),
"references".to_string(),
);
rel_a.store_relationship(&db).await?;
rel_b.store_relationship(&db).await?;
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
.await?;
assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none());
assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some());
Ok(())
}
#[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
{
let db = setup_test_db().await;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await?;
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
let safe_relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
"user123".to_string(),
user_id.to_string(),
"safe_source".to_string(),
"references".to_string(),
);
@@ -517,7 +602,7 @@ mod tests {
let other_relationship = KnowledgeRelationship::new(
entity2_id,
entity3_id,
"user123".to_string(),
user_id.to_string(),
"other_source".to_string(),
"contains".to_string(),
);
@@ -531,9 +616,13 @@ mod tests {
.await
.expect("store other relationship");
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db)
.await
.expect("delete call should succeed");
KnowledgeRelationship::delete_relationships_by_source_id(
"safe_source' OR 1=1 --",
user_id,
&db,
)
.await
.expect("delete call should succeed");
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
+1 -1
View File
@@ -62,7 +62,7 @@ impl fmt::Display for Message {
pub fn format_history(history: &[Message]) -> String {
let estimated: usize = history
.iter()
.map(|m| m.content.len() + 10)
.map(|m| m.content.len().saturating_add(10))
.sum();
let mut out = String::with_capacity(estimated);
for (i, msg) in history.iter().enumerate() {