mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-30 03:10:45 +02:00
chore: harden knowledge graph storage and clear common clippy warnings
Enforce stable 1:1 entity embeddings, relationship endpoint auth, and user-scoped deletes; align schemas/migrations and resolve common crate clippy findings.
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
-- Harden knowledge entity embeddings and graph storage invariants.
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
|
||||
|
||||
-- Backfill denormalized source_id from the linked entity.
-- Only embedding rows whose source_id is NONE or the empty string are visited. For each
-- one, the knowledge_entity referenced by entity_id is read and its source_id is copied
-- onto the embedding row.
-- NOTE(review): rows whose entity lookup yields nothing are skipped and keep their missing
-- source_id. Confirm the re-key step later in this migration tolerates such rows now that
-- source_id is declared TYPE string.
|
||||
FOR $emb IN (SELECT * FROM knowledge_entity_embedding WHERE source_id = NONE OR source_id = '') {
|
||||
-- Take the first row of the lookup; the IF below treats NONE as "no entity found".
LET $entity = (SELECT source_id FROM $emb.entity_id)[0];
|
||||
IF $entity != NONE {
|
||||
UPDATE $emb.id SET source_id = $entity.source_id;
|
||||
}
|
||||
};
|
||||
|
||||
-- Re-key embeddings so record id matches entity id (stable 1:1 identity).
-- For every embedding row, the canonical record id is built from the key part of its
-- entity_id. Rows already stored under that id are left alone; every other row is copied
-- to the canonical id and the old record is deleted.
-- NOTE(review): if several rows reference the same entity they all target the same
-- canonical record, so whichever is processed last wins. Confirm that is acceptable.
-- NOTE(review): source_id is copied as-is; a row that still has no source_id at this point
-- would be written with a missing value. Verify against the TYPE string field definition.
|
||||
FOR $emb IN (SELECT * FROM knowledge_entity_embedding) {
|
||||
-- Key portion of the linked entity's record id (without the table name).
LET $entity_key = record::id($emb.entity_id);
|
||||
-- Target record: same key, but in the embedding table.
LET $canonical = type::thing('knowledge_entity_embedding', $entity_key);
|
||||
IF $emb.id != $canonical {
|
||||
-- Copy every stored field, including the original timestamps, to the canonical record.
UPSERT $canonical CONTENT {
|
||||
entity_id: $emb.entity_id,
|
||||
embedding: $emb.embedding,
|
||||
user_id: $emb.user_id,
|
||||
source_id: $emb.source_id,
|
||||
created_at: $emb.created_at,
|
||||
updated_at: $emb.updated_at
|
||||
};
|
||||
-- Remove the row stored under the previous, non-canonical id.
DELETE $emb.id;
|
||||
}
|
||||
};
|
||||
|
||||
-- Replace the entity_id index with a UNIQUE one. The index is removed first because a
-- definition with this name may already exist without UNIQUE, and the IF NOT EXISTS form
-- below would presumably leave such an existing definition untouched.
REMOVE INDEX IF EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
|
||||
-- Lookup index on the denormalized source_id of embeddings.
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
|
||||
-- Composite index on knowledge_entity covering user_id together with source_id.
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -68,7 +68,7 @@\n\n # Defines the schema for the 'knowledge_entity' table.\n\n-DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n@@ -90,6 +90,7 @@\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n@@ -102,6 +103,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n+DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;\n\n -- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;\n@@ -109,8 +111,9 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n+DEFINE INDEX IF NOT 
EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;\n\n # Defines the schema for the 'message' table.\n\n@@ -135,19 +138,17 @@\n # Defines the 'relates_to' edge table for KnowledgeRelationships.\n # Edges connect nodes, in this case knowledge_entity records.\n\n-# Define the edge table itself, enforcing connections between knowledge_entity records\n-# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\n-DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+\n+DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;\n+DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;\n\n-# Define the metadata field within the edge\n # RelationshipMetadata is a struct, store as object\n DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n+DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n-# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n-# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n-\n # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n","events":null}
|
||||
@@ -1,6 +1,6 @@
|
||||
# Defines the schema for the 'knowledge_entity' table.
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
||||
@@ -22,5 +22,6 @@ DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||
|
||||
@@ -7,6 +7,7 @@ DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
|
||||
|
||||
-- Custom fields
|
||||
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||
@@ -14,5 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<fl
|
||||
|
||||
-- Indexes
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
# Defines the 'relates_to' edge table for KnowledgeRelationships.
|
||||
# Edges connect nodes, in this case knowledge_entity records.
|
||||
|
||||
# Define the edge table itself, enforcing connections between knowledge_entity records
|
||||
# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary
|
||||
DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;
|
||||
DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
|
||||
|
||||
# Define the metadata field within the edge
|
||||
# RelationshipMetadata is a struct, store as object
|
||||
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
|
||||
|
||||
# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)
|
||||
# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
||||
# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
||||
# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
||||
|
||||
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
|
||||
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
|
||||
|
||||
@@ -11,6 +11,31 @@ use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
|
||||
|
||||
/// Tuning options appended to HNSW index definitions created at runtime.
///
/// Includes the `CONCURRENTLY` keyword in addition to the distance, type, EFC and M settings.
pub const HNSW_INDEX_OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY";

/// The same tuning options without `CONCURRENTLY`, for statements that run inside a
/// transaction (where `CONCURRENTLY` is not supported).
pub const HNSW_INDEX_OPTIONS_SYNC: &str = "DIST COSINE TYPE F32 EFC 100 M 8";

/// Renders one `DEFINE INDEX OVERWRITE ... HNSW` statement over the `embedding` field.
///
/// The statement uses [`HNSW_INDEX_OPTIONS`], so it ends with `CONCURRENTLY`.
/// `index_name` and `table` are interpolated verbatim; callers are expected to pass
/// trusted identifiers.
///
/// NOTE(review): because the output contains `CONCURRENTLY`, confirm that callers which
/// splice it into a `BEGIN TRANSACTION; ... COMMIT TRANSACTION;` string are supported by
/// the database version in use.
#[must_use]
pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize) -> String {
    format!(
        "DEFINE INDEX OVERWRITE {name} ON TABLE {tbl} \
         FIELDS embedding HNSW DIMENSION {dim} {opts};",
        name = index_name,
        tbl = table,
        dim = dimension,
        opts = HNSW_INDEX_OPTIONS,
    )
}

/// Renders a transaction that drops and re-creates an HNSW index on the `embedding` field
/// (used by tests and dimension migrations).
///
/// The result consists of four newline-separated statements: `BEGIN TRANSACTION`,
/// `REMOVE INDEX IF EXISTS`, `DEFINE INDEX` with [`HNSW_INDEX_OPTIONS_SYNC`], and
/// `COMMIT TRANSACTION`.
#[must_use]
pub fn hnsw_index_redefine_transaction_sql(index_name: &str, table: &str, dimension: usize) -> String {
    let remove_stmt = format!("REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};");
    let define_stmt = format!(
        "DEFINE INDEX {index_name} ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS_SYNC};"
    );

    [
        "BEGIN TRANSACTION;",
        remove_stmt.as_str(),
        define_stmt.as_str(),
        "COMMIT TRANSACTION;",
    ]
    .join("\n")
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct HnswIndexSpec {
|
||||
index_name: &'static str,
|
||||
@@ -23,12 +48,12 @@ const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
|
||||
HnswIndexSpec {
|
||||
index_name: "idx_embedding_text_chunk_embedding",
|
||||
table: "text_chunk_embedding",
|
||||
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
|
||||
options: HNSW_INDEX_OPTIONS,
|
||||
},
|
||||
HnswIndexSpec {
|
||||
index_name: "idx_embedding_knowledge_entity_embedding",
|
||||
table: "knowledge_entity_embedding",
|
||||
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
|
||||
options: HNSW_INDEX_OPTIONS,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -518,8 +518,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_sidebar_conversation_deserializes_plain_string_id() {
|
||||
let item: SidebarConversation =
|
||||
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#).unwrap();
|
||||
let item: SidebarConversation = serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
|
||||
.expect("valid sidebar conversation json");
|
||||
assert_eq!(item.id, "conv-plain");
|
||||
assert_eq!(item.title, "My chat");
|
||||
}
|
||||
@@ -543,8 +543,9 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to load sidebar");
|
||||
assert_eq!(items.len(), 1);
|
||||
assert_eq!(items[0].id, expected_id);
|
||||
assert_eq!(items[0].title, "Sidebar title");
|
||||
let item = items.first().expect("expected one sidebar item");
|
||||
assert_eq!(item.id, expected_id);
|
||||
assert_eq!(item.title, "Sidebar title");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -570,7 +571,13 @@ mod tests {
|
||||
let owner_messages =
|
||||
fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||
assert_eq!(owner_messages.len(), 1);
|
||||
assert_eq!(owner_messages[0].content, "secret message");
|
||||
assert_eq!(
|
||||
owner_messages
|
||||
.first()
|
||||
.expect("expected owner message")
|
||||
.content,
|
||||
"secret message"
|
||||
);
|
||||
|
||||
let intruder_messages =
|
||||
fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||
@@ -617,8 +624,14 @@ mod tests {
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
|
||||
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].content, "first");
|
||||
assert_eq!(messages[1].content, "second");
|
||||
assert_eq!(
|
||||
messages.first().expect("expected first message").content,
|
||||
"first"
|
||||
);
|
||||
assert_eq!(
|
||||
messages.get(1).expect("expected second message").content,
|
||||
"second"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -90,9 +90,9 @@ impl FileInfo {
|
||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in
|
||||
/// both the stem and extension.
|
||||
fn sanitize_file_name(file_name: &str) -> String {
|
||||
if let Some(idx) = file_name.rfind('.') {
|
||||
let name = Self::sanitize_name_segment(&file_name[..idx]);
|
||||
let ext = Self::sanitize_name_segment(&file_name[idx + 1..]);
|
||||
if let Some((stem, ext)) = file_name.rsplit_once('.') {
|
||||
let name = Self::sanitize_name_segment(stem);
|
||||
let ext = Self::sanitize_name_segment(ext);
|
||||
if ext.is_empty() {
|
||||
name
|
||||
} else {
|
||||
@@ -321,7 +321,6 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::error::AppError;
|
||||
use crate::storage::store::testing::TestStorageManager;
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
@@ -844,8 +843,7 @@ mod tests {
|
||||
assert!(file_info.sha256.len() == 64);
|
||||
let bytes = file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
.await?;
|
||||
assert!(bytes.is_empty());
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -80,7 +80,7 @@ impl IngestionPayload {
|
||||
});
|
||||
|
||||
for (index, file) in files.into_iter().enumerate() {
|
||||
let is_last_file = index + 1 == file_count;
|
||||
let is_last_file = index.saturating_add(1) == file_count;
|
||||
if content_follows || !is_last_file {
|
||||
let Some(shared) = fields.as_ref() else {
|
||||
return Err(AppError::internal("shared ingest fields consumed early"));
|
||||
@@ -411,7 +411,9 @@ mod tests {
|
||||
)?;
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
match (&result[0], &result[1]) {
|
||||
let first = result.first().expect("expected first payload");
|
||||
let second = result.get(1).expect("expected second payload");
|
||||
match (first, second) {
|
||||
(
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file,
|
||||
@@ -499,8 +501,8 @@ mod tests {
|
||||
)?;
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert!(matches!(result[0], IngestionPayload::File { .. }));
|
||||
assert!(matches!(result[1], IngestionPayload::File { .. }));
|
||||
assert!(matches!(result.first(), Some(IngestionPayload::File { .. })));
|
||||
assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,8 +150,8 @@ fn invalid_transition(state: TaskState, event: TaskTransition) -> AppError {
|
||||
))
|
||||
}
|
||||
|
||||
fn worker_id_for_bind(worker_id: &Option<String>) -> String {
|
||||
worker_id.as_deref().unwrap_or("").to_string()
|
||||
fn worker_id_for_bind(worker_id: Option<&String>) -> String {
|
||||
worker_id.cloned().unwrap_or_default()
|
||||
}
|
||||
|
||||
stored_object!(IngestionTask, "ingestion_task", {
|
||||
@@ -360,7 +360,7 @@ impl IngestionTask {
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let worker_id = worker_id_for_bind(&self.worker_id);
|
||||
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
|
||||
let mut result = db
|
||||
.client
|
||||
.query(START_PROCESSING_QUERY)
|
||||
@@ -398,7 +398,7 @@ impl IngestionTask {
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let worker_id = worker_id_for_bind(&self.worker_id);
|
||||
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
|
||||
let mut result = db
|
||||
.client
|
||||
.query(COMPLETE_QUERY)
|
||||
@@ -450,7 +450,7 @@ impl IngestionTask {
|
||||
)
|
||||
.unwrap_or(now);
|
||||
|
||||
let worker_id = worker_id_for_bind(&self.worker_id);
|
||||
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
|
||||
let mut result = db
|
||||
.client
|
||||
.query(FAIL_QUERY)
|
||||
@@ -680,7 +680,9 @@ mod tests {
|
||||
IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
|
||||
|
||||
assert_eq!(created.len(), 2);
|
||||
assert_ne!(created[0].id, created[1].id);
|
||||
let first = created.first().expect("expected first task");
|
||||
let second = created.get(1).expect("expected second task");
|
||||
assert_ne!(first.id, second.id);
|
||||
|
||||
for task in &created {
|
||||
let stored: Option<IngestionTask> = db.get_item::<IngestionTask>(&task.id).await?;
|
||||
|
||||
@@ -7,7 +7,9 @@ use std::fmt::Write;
|
||||
|
||||
use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object,
|
||||
storage::indexes::hnsw_index_overwrite_sql,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
storage::types::system_settings::SystemSettings, stored_object,
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
@@ -189,13 +191,25 @@ impl KnowledgeEntity {
|
||||
embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone());
|
||||
let settings = SystemSettings::get_current(db).await?;
|
||||
KnowledgeEntityEmbedding::validate_dimension(
|
||||
&embedding,
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let entity_id = entity.id.clone();
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
&entity_id,
|
||||
entity.source_id.clone(),
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
|
||||
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
|
||||
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
entity_table = Self::table_name(),
|
||||
@@ -204,9 +218,8 @@ impl KnowledgeEntity {
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity.id.clone()))
|
||||
.bind(("entity_id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
@@ -275,12 +288,29 @@ impl KnowledgeEntity {
|
||||
db_client: &SurrealDbClient,
|
||||
ai_client: &Client<OpenAIConfig>,
|
||||
) -> Result<(), AppError> {
|
||||
let embedding_input = format!(
|
||||
"name: {name}, description: {description}, type: {entity_type:?}",
|
||||
);
|
||||
let embedding_input = format!(
|
||||
"name: {name}, description: {description}, type: {entity_type:?}",
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
|
||||
let user_id = Self::get_user_id_by_id(id, db_client).await?;
|
||||
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id);
|
||||
|
||||
let entity: KnowledgeEntity = db_client
|
||||
.get_item(id)
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.ok_or_else(|| AppError::NotFound(format!("entity {id} not found")))?;
|
||||
|
||||
let settings = SystemSettings::get_current(db_client).await?;
|
||||
KnowledgeEntityEmbedding::validate_dimension(
|
||||
&embedding,
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
id,
|
||||
entity.source_id,
|
||||
embedding,
|
||||
entity.user_id,
|
||||
);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
@@ -293,7 +323,7 @@ impl KnowledgeEntity {
|
||||
description = $description,
|
||||
updated_at = $updated_at,
|
||||
entity_type = $entity_type;
|
||||
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb;
|
||||
UPSERT type::thing($emb_table, $id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
@@ -302,33 +332,16 @@ impl KnowledgeEntity {
|
||||
.bind(("name", name.to_string()))
|
||||
.bind(("updated_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("entity_type", entity_type.to_owned()))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.bind(("description", description.to_string()))
|
||||
.await?;
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
|
||||
rows.first()
|
||||
.map(|r| r.user_id.clone())
|
||||
.ok_or_else(|| AppError::internal("user not found for entity"))
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities in the database.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It follows the same
|
||||
@@ -359,7 +372,7 @@ impl KnowledgeEntity {
|
||||
info!("Found {total_entities} entities to process.");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
for entity in &all_entities {
|
||||
let embedding_input = format!(
|
||||
@@ -387,7 +400,14 @@ impl KnowledgeEntity {
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
entity.source_id.clone(),
|
||||
),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
@@ -396,7 +416,7 @@ impl KnowledgeEntity {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements to the embedding table
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
@@ -406,6 +426,7 @@ impl KnowledgeEntity {
|
||||
entity_id = type::thing('knowledge_entity', '{id}'), \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
source_id = '{source_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
@@ -414,7 +435,12 @@ impl KnowledgeEntity {
|
||||
|
||||
write!(
|
||||
transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
"{}",
|
||||
hnsw_index_overwrite_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
KnowledgeEntityEmbedding::table_name(),
|
||||
new_dimensions as usize,
|
||||
)
|
||||
)
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
@@ -431,6 +457,7 @@ impl KnowledgeEntity {
|
||||
///
|
||||
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
|
||||
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn update_all_embeddings_with_provider(
|
||||
db: &SurrealDbClient,
|
||||
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||
@@ -453,7 +480,7 @@ impl KnowledgeEntity {
|
||||
info!(entities = total_entities, "Found entities to process");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
|
||||
for (i, entity) in all_entities.iter().enumerate() {
|
||||
@@ -484,7 +511,14 @@ impl KnowledgeEntity {
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
entity.source_id.clone(),
|
||||
),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
@@ -517,7 +551,7 @@ impl KnowledgeEntity {
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
@@ -527,6 +561,7 @@ impl KnowledgeEntity {
|
||||
entity_id = type::thing('knowledge_entity', '{id}'), \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
source_id = '{source_id}', \
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
@@ -535,7 +570,12 @@ impl KnowledgeEntity {
|
||||
|
||||
write!(
|
||||
transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
"{}",
|
||||
hnsw_index_overwrite_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
KnowledgeEntityEmbedding::table_name(),
|
||||
new_dimensions,
|
||||
)
|
||||
)
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
@@ -559,10 +599,21 @@ mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use anyhow::{self, Context};
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn configure_embedding_dimension(
|
||||
db: &SurrealDbClient,
|
||||
dimension: u32,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut settings = SystemSettings::get_current(db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(db, settings).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
|
||||
let source_id = "source123".to_string();
|
||||
@@ -656,14 +707,15 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 1".to_string(),
|
||||
@@ -763,6 +815,7 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
configure_embedding_dimension(&db, 3).await.expect("configure dim");
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
@@ -847,6 +900,7 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
@@ -914,6 +968,7 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
@@ -1012,6 +1067,7 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
@@ -2,11 +2,17 @@ use std::collections::HashMap;
|
||||
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||
entity_id: RecordId,
|
||||
embedding: Vec<f32>,
|
||||
/// Denormalized source id for bulk deletes
|
||||
source_id: String,
|
||||
/// Denormalized user id for query scoping
|
||||
user_id: String
|
||||
});
|
||||
@@ -17,12 +23,10 @@ impl KnowledgeEntityEmbedding {
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
|
||||
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
||||
COMMIT TRANSACTION;",
|
||||
table = Self::table_name(),
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
||||
@@ -31,16 +35,36 @@ impl KnowledgeEntityEmbedding {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding.
|
||||
///
|
||||
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||
#[must_use]
|
||||
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
pub fn new(
|
||||
entity_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
id: entity_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||
embedding,
|
||||
source_id,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
@@ -73,8 +97,6 @@ impl KnowledgeEntityEmbedding {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let ids_list: Vec<RecordId> = entity_ids.to_vec();
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
|
||||
Self::table_name()
|
||||
@@ -82,7 +104,7 @@ impl KnowledgeEntityEmbedding {
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_ids", ids_list))
|
||||
.bind(("entity_ids", entity_ids.to_vec()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
@@ -106,32 +128,28 @@ impl KnowledgeEntityEmbedding {
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete embeddings by source_id (via joining to knowledge_entity table)
|
||||
/// Delete all embeddings with the given denormalized `source_id`.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
|
||||
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
|
||||
let mut res = db
|
||||
.client
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
for row in ids {
|
||||
Self::delete_by_entity_id(&row.id, db).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -142,6 +160,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use anyhow::{self, Context};
|
||||
use chrono::Utc;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
@@ -161,6 +180,24 @@ mod tests {
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn setup_test_db_with_embedding_dimension(
|
||||
dimension: u32,
|
||||
) -> anyhow::Result<SurrealDbClient> {
|
||||
let db = setup_test_db().await?;
|
||||
let mut settings = SystemSettings::get_current(&db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(&db, settings).await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
|
||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||
.await
|
||||
.with_context(|| format!("set test index dimension to {dimension}"))?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn build_knowledge_entity_with_id(
|
||||
key: &str,
|
||||
source_id: &str,
|
||||
@@ -179,12 +216,27 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// The constructor must reuse the entity key as the embedding's own id.
#[test]
fn new_uses_entity_id_as_record_id() {
    let entity_key = "entity-abc";
    let source = String::from("source-1");
    let owner = String::from("user-1");

    let embedding = KnowledgeEntityEmbedding::new(entity_key, source, vec![0.1, 0.2], owner);

    assert_eq!(embedding.id, entity_key);
}
|
||||
|
||||
/// A three-component vector checked against dimension 2 must fail with a validation error.
#[test]
fn validate_dimension_rejects_mismatch() {
    let three_components = [0.1_f32, 0.2, 0.3];

    let outcome = KnowledgeEntityEmbedding::validate_dimension(&three_components, 2);
    let error = outcome.expect_err("expected dimension mismatch");

    assert!(matches!(error, AppError::Validation(_)));
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let db = prepare_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-1";
|
||||
let source_id = "source-ke";
|
||||
@@ -203,7 +255,9 @@ mod tests {
|
||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
|
||||
assert_eq!(fetched.id, entity_key);
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.source_id, source_id);
|
||||
assert_eq!(fetched.entity_id, entity_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
|
||||
@@ -212,10 +266,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let db = prepare_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-delete";
|
||||
let source_id = "source-del";
|
||||
@@ -247,15 +298,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let db = prepare_test_db(3).await?;
|
||||
let user_id = "user_store";
|
||||
let source_id = "source_store";
|
||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
|
||||
.await
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||
@@ -274,18 +321,30 @@ mod tests {
|
||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||
let stored_embedding =
|
||||
stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
assert_eq!(stored_embedding.id, entity.id);
|
||||
assert_eq!(stored_embedding.user_id, user_id);
|
||||
assert_eq!(stored_embedding.source_id, source_id);
|
||||
assert_eq!(stored_embedding.entity_id, entity_rid);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||
let result =
|
||||
KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||
|
||||
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let db = prepare_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let source_id = "shared-ke";
|
||||
let other_source = "other-ke";
|
||||
@@ -363,6 +422,10 @@ mod tests {
|
||||
idx_sql.contains("DIMENSION 16"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
assert!(
|
||||
idx_sql.contains("DIST COSINE"),
|
||||
"expected index definition to use cosine distance, got: {idx_sql}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -374,10 +437,7 @@ mod tests {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let db = prepare_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-fetch";
|
||||
let source_id = "source-fetch";
|
||||
@@ -412,4 +472,47 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
|
||||
let user_id = "user-upsert";
|
||||
let source_id = "source-upsert";
|
||||
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
|
||||
.await
|
||||
.with_context(|| "initial store".to_string())?;
|
||||
|
||||
let replacement = KnowledgeEntityEmbedding::new(
|
||||
&entity.id,
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
.with_context(|| "upsert replacement embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let rows: Vec<KnowledgeEntityEmbedding> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE entity_id = $entity_id",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.bind(("entity_id", entity_rid))
|
||||
.await
|
||||
.with_context(|| "count embeddings".to_string())?
|
||||
.take(0)
|
||||
.with_context(|| "take embeddings".to_string())?;
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
let row = rows.first().expect("expected one embedding row");
|
||||
assert_eq!(row.id, entity.id);
|
||||
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::storage::types::serde_helpers::deserialize_flexible_id;
|
||||
use crate::storage::types::user::User;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
@@ -40,7 +41,21 @@ impl KnowledgeRelationship {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
User::get_and_validate_knowledge_entity(
|
||||
&self.in_,
|
||||
&self.metadata.user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
User::get_and_validate_knowledge_entity(
|
||||
&self.out,
|
||||
&self.metadata.user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
@@ -61,22 +76,30 @@ impl KnowledgeRelationship {
|
||||
.bind(("user_id", self.metadata.user_id.clone()))
|
||||
.bind(("source_id", self.metadata.source_id.clone()))
|
||||
.bind(("relationship_type", self.metadata.relationship_type.clone()))
|
||||
.await?
|
||||
.check()?;
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_relationships_by_source_id(
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.query(
|
||||
"DELETE FROM relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id",
|
||||
)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -86,39 +109,37 @@ impl KnowledgeRelationship {
|
||||
user_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let mut authorized_result = db_client
|
||||
let mut delete_result = db_client
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
|
||||
"DELETE type::thing('relates_to', $id) WHERE metadata.user_id = $user_id RETURN BEFORE;",
|
||||
)
|
||||
.bind(("id", id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0)?;
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let deleted: Vec<KnowledgeRelationship> =
|
||||
delete_result.take(0).map_err(AppError::Database)?;
|
||||
|
||||
if authorized.is_empty() {
|
||||
let mut exists_result = db_client
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?;
|
||||
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
|
||||
if !deleted.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if existing.is_some() {
|
||||
Err(AppError::Auth(
|
||||
"Not authorized to delete relationship".into(),
|
||||
))
|
||||
} else {
|
||||
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||
}
|
||||
let mut exists_result = db_client
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let existing: Option<KnowledgeRelationship> =
|
||||
exists_result.take(0).map_err(AppError::Database)?;
|
||||
|
||||
if existing.is_some() {
|
||||
Err(AppError::Auth(
|
||||
"Not authorized to delete relationship".into(),
|
||||
))
|
||||
} else {
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,11 +179,14 @@ mod tests {
|
||||
result.take(0).expect("failed to take relationship by id")
|
||||
}
|
||||
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
async fn create_test_entity(
|
||||
name: &str,
|
||||
user_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> anyhow::Result<String> {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {name}");
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id,
|
||||
@@ -170,7 +194,7 @@ mod tests {
|
||||
description,
|
||||
entity_type,
|
||||
None,
|
||||
user_id,
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
let stored: Option<KnowledgeEntity> = db_client
|
||||
@@ -211,18 +235,18 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id.clone(),
|
||||
user_id.to_string(),
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
@@ -257,16 +281,38 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
|
||||
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
owner_entity,
|
||||
other_entity,
|
||||
"owner-user".to_string(),
|
||||
"source123".to_string(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
let result = relationship.store_relationship(&db).await;
|
||||
assert!(matches!(result, Err(AppError::Auth(_))));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id,
|
||||
entity2_id,
|
||||
"user'123".to_string(),
|
||||
user_id.to_string(),
|
||||
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||
);
|
||||
@@ -297,18 +343,18 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id.clone(),
|
||||
user_id.to_string(),
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
@@ -319,9 +365,9 @@ mod tests {
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let mut existing_before_delete = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
|
||||
))
|
||||
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let before_results: Vec<KnowledgeRelationship> =
|
||||
@@ -331,14 +377,14 @@ mod tests {
|
||||
"Relationship should exist before deletion"
|
||||
);
|
||||
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||
|
||||
let mut result = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
|
||||
))
|
||||
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("source_id", source_id))
|
||||
.await
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
@@ -351,17 +397,17 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let owner_user_id = "owner-user";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", owner_user_id, &db).await?;
|
||||
|
||||
let owner_user_id = "owner-user".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
owner_user_id.clone(),
|
||||
owner_user_id.to_string(),
|
||||
source_id,
|
||||
"references".to_string(),
|
||||
);
|
||||
@@ -372,9 +418,8 @@ mod tests {
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let mut before_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
|
||||
))
|
||||
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
|
||||
.bind(("user_id", owner_user_id.to_string()))
|
||||
.await
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||
@@ -398,9 +443,8 @@ mod tests {
|
||||
}
|
||||
|
||||
let mut after_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
|
||||
))
|
||||
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
|
||||
.bind(("user_id", owner_user_id.to_string()))
|
||||
.await
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||
@@ -416,19 +460,19 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await?;
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let different_source_id = "different_source".to_string();
|
||||
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id.clone(),
|
||||
user_id.to_string(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
@@ -436,7 +480,7 @@ mod tests {
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
entity2_id.clone(),
|
||||
entity3_id.clone(),
|
||||
user_id.clone(),
|
||||
user_id.to_string(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
@@ -444,7 +488,7 @@ mod tests {
|
||||
let different_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity3_id.clone(),
|
||||
user_id.clone(),
|
||||
user_id.to_string(),
|
||||
different_source_id.clone(),
|
||||
"mentions".to_string(),
|
||||
);
|
||||
@@ -480,7 +524,7 @@ mod tests {
|
||||
before_delete_different.take(0).unwrap_or_default();
|
||||
assert_eq!(before_delete_different_rows.len(), 1);
|
||||
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, user_id, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||
|
||||
@@ -497,19 +541,60 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_a = "user-a";
|
||||
let user_b = "user-b";
|
||||
let shared_source = "shared-source";
|
||||
|
||||
let a1 = create_test_entity("A1", user_a, &db).await?;
|
||||
let a2 = create_test_entity("A2", user_a, &db).await?;
|
||||
let b1 = create_test_entity("B1", user_b, &db).await?;
|
||||
let b2 = create_test_entity("B2", user_b, &db).await?;
|
||||
|
||||
let rel_a = KnowledgeRelationship::new(
|
||||
a1,
|
||||
a2,
|
||||
user_a.to_string(),
|
||||
shared_source.to_string(),
|
||||
"references".to_string(),
|
||||
);
|
||||
let rel_b = KnowledgeRelationship::new(
|
||||
b1,
|
||||
b2,
|
||||
user_b.to_string(),
|
||||
shared_source.to_string(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
rel_a.store_relationship(&db).await?;
|
||||
rel_b.store_relationship(&db).await?;
|
||||
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||
.await?;
|
||||
|
||||
assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none());
|
||||
assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
|
||||
{
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await?;
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
|
||||
|
||||
let safe_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
"user123".to_string(),
|
||||
user_id.to_string(),
|
||||
"safe_source".to_string(),
|
||||
"references".to_string(),
|
||||
);
|
||||
@@ -517,7 +602,7 @@ mod tests {
|
||||
let other_relationship = KnowledgeRelationship::new(
|
||||
entity2_id,
|
||||
entity3_id,
|
||||
"user123".to_string(),
|
||||
user_id.to_string(),
|
||||
"other_source".to_string(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
@@ -531,9 +616,13 @@ mod tests {
|
||||
.await
|
||||
.expect("store other relationship");
|
||||
|
||||
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(
|
||||
"safe_source' OR 1=1 --",
|
||||
user_id,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
|
||||
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
|
||||
|
||||
@@ -62,7 +62,7 @@ impl fmt::Display for Message {
|
||||
pub fn format_history(history: &[Message]) -> String {
|
||||
let estimated: usize = history
|
||||
.iter()
|
||||
.map(|m| m.content.len() + 10)
|
||||
.map(|m| m.content.len().saturating_add(10))
|
||||
.sum();
|
||||
let mut out = String::with_capacity(estimated);
|
||||
for (i, msg) in history.iter().enumerate() {
|
||||
|
||||
@@ -308,6 +308,7 @@ fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches>
|
||||
entities.push(entity.clone());
|
||||
entity_embeddings.push(KnowledgeEntityEmbedding::new(
|
||||
&entity.id,
|
||||
entity.source_id.clone(),
|
||||
embedded_entity.embedding.clone(),
|
||||
entity.user_id.clone(),
|
||||
));
|
||||
|
||||
@@ -86,7 +86,7 @@ pub async fn delete_text_content(
|
||||
// Delete the text content and any related data
|
||||
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db).await?;
|
||||
state
|
||||
.db
|
||||
.delete_item::<TextContent>(&text_content.id)
|
||||
|
||||
@@ -293,7 +293,7 @@ async fn reserve_task(
|
||||
payload: IngestionPayload,
|
||||
user_id: &str,
|
||||
) -> anyhow::Result<IngestionTask> {
|
||||
let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db).await?;
|
||||
let task = IngestionTask::create_and_add_to_db(payload, user_id, db).await?;
|
||||
let lease = task.lease_duration();
|
||||
let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
|
||||
.await?
|
||||
|
||||
Reference in New Issue
Block a user