From 77b157b3b80c138ea51c0a6e40b265cd5f0b8a1f Mon Sep 17 00:00:00 2001 From: Per Stark Date: Mon, 14 Apr 2025 17:24:04 +0200 Subject: [PATCH] tests: testing all db interactions and types --- Cargo.lock | 134 ++++++- Cargo.toml | 2 +- common/Cargo.toml | 5 +- common/src/storage/db.rs | 101 +++++ common/src/storage/types/analytics.rs | 235 ++++++++++- common/src/storage/types/conversation.rs | 175 +++++++++ common/src/storage/types/file_info.rs | 364 +++++++++++++++++- common/src/storage/types/ingestion_payload.rs | 237 +++++++++++- common/src/storage/types/ingestion_task.rs | 195 +++++++++- common/src/storage/types/knowledge_entity.rs | 197 +++++++++- .../storage/types/knowledge_relationship.rs | 255 ++++++++++++ common/src/storage/types/message.rs | 127 +++++- common/src/storage/types/mod.rs | 2 +- common/src/storage/types/system_settings.rs | 216 +++++++++-- common/src/storage/types/text_chunk.rs | 172 +++++++++ common/src/storage/types/text_content.rs | 117 ++++++ common/src/storage/types/user.rs | 273 +++++++++++++ composite-retrieval/Cargo.toml | 4 +- composite-retrieval/src/graph.rs | 277 +++++++++++++ 19 files changed, 3017 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 745e182..a648df8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "Inflector" @@ -190,6 +190,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "argon2" version = "0.5.3" @@ -732,6 +738,12 @@ dependencies = [ "serde", ] +[[package]] +name = "bitmaps" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d084b0137aaa901caf9f1e8b21daa6aa24d41cd806e111335541eff9683bd6" + [[package]] name = "bitvec" version = "1.0.1" @@ -1078,6 +1090,7 @@ dependencies = [ "thiserror", "tokio", "tracing", + "uuid", ] [[package]] @@ -1183,6 +1196,25 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -1473,6 +1505,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "echodb" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1eccc44ff21b80ca7e883ff57423a12610965a33637d5d0bef4adebcd81749" +dependencies = [ + "arc-swap", + "imbl", + "thiserror", + "tokio", +] + [[package]] name = "ego-tree" version = "0.10.0" @@ -1552,6 +1596,19 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "ext-sort" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf73e44617eab501beba39234441a194cf138629d3b6447f81f573e1c3d0a13" +dependencies = [ + "log", + "rayon", + "rmp-serde", + "serde", + "tempfile", +] + [[package]] name = "fancy-regex" version = "0.13.0" @@ -2019,7 +2076,6 @@ dependencies = [ "composite-retrieval", "futures", "include_dir", - "json-stream-parser", "minijinja", "minijinja-autoreload", "minijinja-contrib", @@ -2374,6 +2430,28 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "imbl" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978d142c8028edf52095703af2fad11d6f611af1246685725d6b850634647085" +dependencies = [ + "bitmaps", + "imbl-sized-chunks", + "rand_core", + "rand_xoshiro", + "version_check", +] + +[[package]] +name = "imbl-sized-chunks" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "144006fb58ed787dcae3f54575ff4349755b00ccc99f4b4873860b654be1ed63" +dependencies = [ + "bitmaps", +] + [[package]] name = "include_dir" version = "0.7.4" @@ -2523,15 +2601,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "json-stream-parser" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a70ab2b05e827e0604229fcf11b24560b036a21286a41517a6cac271f12a6a9" -dependencies = [ - "serde_json", -] - [[package]] name = "json5" version = "0.4.1" @@ -3740,12 +3809,41 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "rawpointer" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "reblessive" version = "0.4.1" @@ -4038,6 +4136,17 @@ dependencies = [ "paste", ] +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rmpv" version = "1.3.0" @@ -4832,6 +4941,8 @@ dependencies = [ "dashmap 5.5.3", "deunicode", "dmp", + "echodb", + "ext-sort", "fst", "futures", "fuzzy-matcher", @@ -4875,6 +4986,7 @@ dependencies = [ "storekey", "subtle", "surrealdb-derive", + "tempfile", "thiserror", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 1b33f69..f740acf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,6 @@ serde_json = "1.0.128" thiserror = "1.0.63" anyhow = "1.0.94" tracing = "0.1.40" -surrealdb = "2.0.4" +surrealdb = { version = "2.0.4", features = ["kv-mem"] } futures = "0.3.31" async-openai = "0.24.1" diff --git a/common/Cargo.toml b/common/Cargo.toml index 240fca6..5a2ef9e 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -12,7 +12,7 @@ tracing = { workspace = true } anyhow = { workspace = true } thiserror = { workspace = true } serde_json = { workspace = true } -surrealdb = { workspace = true } +surrealdb = { workspace = true, features = ["kv-mem"] } async-openai = { workspace = true } futures = { workspace = true } @@ -35,3 +35,6 @@ minijinja = { version = "2.5.0", features = ["loader", "multi_template"] } minijinja-autoreload = "2.5.0" minijinja-embed = { version = "2.8.0" } minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] } + +[features] +test-utils = [] diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index e43c4b2..fdc4791 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -189,3 +189,104 @@ impl Deref for SurrealDbClient { &self.client } } + +#[cfg(any(test, feature = "test-utils"))] +impl SurrealDbClient { + /// Create an in-memory SurrealDB client for testing. + pub async fn memory(namespace: &str, database: &str) -> Result { + let db = connect("mem://").await?; + + db.use_ns(namespace).use_db(database).await?; + + Ok(SurrealDbClient { client: db }) + } +} + +#[cfg(test)] +mod tests { + use crate::stored_object; + + use super::*; + use uuid::Uuid; + + stored_object!(Dummy, "dummy", { + name: String + }); + + #[tokio::test] + async fn test_initialization_and_crud() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); // ensures isolation per test run + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Call your initialization + db.ensure_initialized() + .await + .expect("Failed to initialize schema"); + + // Test basic CRUD + let dummy = Dummy { + id: "abc".to_string(), + name: "first".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + // Store + let stored = db.store_item(dummy.clone()).await.expect("Failed to store"); + assert!(stored.is_some()); + + // Read + let fetched = db + .get_item::(&dummy.id) + .await + .expect("Failed to fetch"); + assert_eq!(fetched, Some(dummy.clone())); + + // Read all + let all = db + .get_all_stored_items::() + .await + .expect("Failed to fetch all"); + assert!(all.contains(&dummy)); + + // Delete + let deleted = db + .delete_item::(&dummy.id) + .await + .expect("Failed to delete"); + assert_eq!(deleted, Some(dummy)); + + // After delete, should not be present + let fetch_post = db + .get_item::("abc") + .await + .expect("Failed fetch post delete"); + assert!(fetch_post.is_none()); + } + + #[tokio::test] + async fn test_setup_auth() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); // ensures isolation per test run + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Should not panic or fail + db.setup_auth().await.expect("Failed to setup auth"); + } + + #[tokio::test] + async fn test_build_indexes() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + db.build_indexes().await.expect("Failed to build indexes"); + } +} diff --git a/common/src/storage/types/analytics.rs b/common/src/storage/types/analytics.rs index 43dee3a..56b34d5 100644 --- a/common/src/storage/types/analytics.rs +++ b/common/src/storage/types/analytics.rs @@ -1,4 +1,5 @@ use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject}; +use axum::async_trait; use serde::{Deserialize, Serialize}; use crate::{error::AppError, storage::db::SurrealDbClient}; @@ -11,32 +12,40 @@ pub struct Analytics { pub visitors: i64, } +#[async_trait] +impl StoredObject for Analytics { + fn table_name() -> &'static str { + "analytics" + } + + fn get_id(&self) -> &str { + &self.id + } +} + impl Analytics { pub async fn ensure_initialized(db: &SurrealDbClient) -> Result { - let analytics = db.select(("analytics", "current")).await?; + let analytics = db.get_item::("current").await?; if analytics.is_none() { - let created: Option = db - .create(("analytics", "current")) - .content(Analytics { - id: "current".to_string(), - visitors: 0, - page_loads: 0, - }) - .await?; + let created_analytics = Analytics { + id: "current".to_string(), + visitors: 0, + page_loads: 0, + }; - return created.ok_or(AppError::Validation("Failed to initialize settings".into())); - }; + let stored: Option = db.store_item(created_analytics).await?; + return stored.ok_or(AppError::Validation( + "Failed to initialize analytics".into(), + )); + } - analytics.ok_or(AppError::Validation("Failed to initialize settings".into())) + analytics.ok_or(AppError::Validation( + "Failed to initialize analytics".into(), + )) } pub async fn get_current(db: &SurrealDbClient) -> Result { - let analytics: Option = db - .client - .query("SELECT * FROM type::thing('analytics', 'current')") - .await? - .take(0)?; - + let analytics: Option = db.get_item("current").await?; analytics.ok_or(AppError::NotFound("Analytics not found".into())) } @@ -61,6 +70,7 @@ impl Analytics { } pub async fn get_users_amount(db: &SurrealDbClient) -> Result { + // We need to use a direct query for COUNT aggregation #[derive(Debug, Deserialize)] struct CountResult { count: i64, @@ -76,3 +86,192 @@ impl Analytics { Ok(result.map(|r| r.count).unwrap_or(0)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::stored_object; + use uuid::Uuid; + + stored_object!(TestUser, "user", { + email: String, + password: String, + user_id: String + }); + + #[tokio::test] + async fn test_analytics_initialization() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Test initialization of analytics + let analytics = Analytics::ensure_initialized(&db) + .await + .expect("Failed to initialize analytics"); + + // Verify initial state after initialization + assert_eq!(analytics.id, "current"); + assert_eq!(analytics.page_loads, 0); + assert_eq!(analytics.visitors, 0); + + // Test idempotency - ensure calling it again doesn't change anything + let analytics_again = Analytics::ensure_initialized(&db) + .await + .expect("Failed to get analytics after initialization"); + + assert_eq!(analytics.id, analytics_again.id); + assert_eq!(analytics.page_loads, analytics_again.page_loads); + assert_eq!(analytics.visitors, analytics_again.visitors); + } + + #[tokio::test] + async fn test_get_current_analytics() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Initialize analytics + Analytics::ensure_initialized(&db) + .await + .expect("Failed to initialize analytics"); + + // Test get_current method + let analytics = Analytics::get_current(&db) + .await + .expect("Failed to get current analytics"); + + assert_eq!(analytics.id, "current"); + assert_eq!(analytics.page_loads, 0); + assert_eq!(analytics.visitors, 0); + } + + #[tokio::test] + async fn test_increment_visitors() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Initialize analytics + Analytics::ensure_initialized(&db) + .await + .expect("Failed to initialize analytics"); + + // Test increment_visitors method + let analytics = Analytics::increment_visitors(&db) + .await + .expect("Failed to increment visitors"); + + assert_eq!(analytics.visitors, 1); + assert_eq!(analytics.page_loads, 0); + + // Increment again and check + let analytics = Analytics::increment_visitors(&db) + .await + .expect("Failed to increment visitors again"); + + assert_eq!(analytics.visitors, 2); + assert_eq!(analytics.page_loads, 0); + } + + #[tokio::test] + async fn test_increment_page_loads() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Initialize analytics + Analytics::ensure_initialized(&db) + .await + .expect("Failed to initialize analytics"); + + // Test increment_page_loads method + let analytics = Analytics::increment_page_loads(&db) + .await + .expect("Failed to increment page loads"); + + assert_eq!(analytics.visitors, 0); + assert_eq!(analytics.page_loads, 1); + + // Increment again and check + let analytics = Analytics::increment_page_loads(&db) + .await + .expect("Failed to increment page loads again"); + + assert_eq!(analytics.visitors, 0); + assert_eq!(analytics.page_loads, 2); + } + + #[tokio::test] + async fn test_get_users_amount() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Test with no users + let count = Analytics::get_users_amount(&db) + .await + .expect("Failed to get users amount"); + assert_eq!(count, 0); + + // Create a few test users + for i in 0..3 { + let user = TestUser { + id: format!("user{}", i), + email: format!("user{}@example.com", i), + password: "password".to_string(), + user_id: format!("uid{}", i), + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + db.store_item(user) + .await + .expect("Failed to create test user"); + } + + // Test users amount after adding users + let count = Analytics::get_users_amount(&db) + .await + .expect("Failed to get users amount after adding users"); + assert_eq!(count, 3); + } + + #[tokio::test] + async fn test_get_current_nonexistent() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Don't initialize analytics and try to get it + let result = Analytics::get_current(&db).await; + + assert!(result.is_err()); + if let Err(err) = result { + match err { + AppError::NotFound(_) => { + // Expected error + } + _ => panic!("Expected NotFound error, got: {:?}", err), + } + } + } +} diff --git a/common/src/storage/types/conversation.rs b/common/src/storage/types/conversation.rs index 2a6675f..6b7c157 100644 --- a/common/src/storage/types/conversation.rs +++ b/common/src/storage/types/conversation.rs @@ -47,3 +47,178 @@ impl Conversation { Ok((conversation, messages)) } } + +#[cfg(test)] +mod tests { + use crate::storage::types::message::MessageRole; + + use super::*; + + #[tokio::test] + async fn test_create_conversation() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create a new conversation + let user_id = "test_user"; + let title = "Test Conversation"; + let conversation = Conversation::new(user_id.to_string(), title.to_string()); + + // Verify conversation properties + assert_eq!(conversation.user_id, user_id); + assert_eq!(conversation.title, title); + assert!(!conversation.id.is_empty()); + + // Store the conversation + let result = db.store_item(conversation.clone()).await; + assert!(result.is_ok()); + + // Verify it can be retrieved + let retrieved: Option = db + .get_item(&conversation.id) + .await + .expect("Failed to retrieve conversation"); + assert!(retrieved.is_some()); + + let retrieved = retrieved.unwrap(); + assert_eq!(retrieved.id, conversation.id); + assert_eq!(retrieved.user_id, user_id); + assert_eq!(retrieved.title, title); + } + + #[tokio::test] + async fn test_get_complete_conversation_not_found() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Try to get a conversation that doesn't exist + let result = + Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await; + assert!(result.is_err()); + + match result { + Err(AppError::NotFound(_)) => { /* expected error */ } + _ => panic!("Expected NotFound error"), + } + } + + #[tokio::test] + async fn test_get_complete_conversation_unauthorized() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create and store a conversation for user_id_1 + let user_id_1 = "user_1"; + let conversation = + Conversation::new(user_id_1.to_string(), "Private Conversation".to_string()); + let conversation_id = conversation.id.clone(); + + db.store_item(conversation) + .await + .expect("Failed to store conversation"); + + // Try to access with a different user + let user_id_2 = "user_2"; + let result = + Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await; + assert!(result.is_err()); + + match result { + Err(AppError::Auth(_)) => { /* expected error */ } + _ => panic!("Expected Auth error"), + } + } + + #[tokio::test] + async fn test_get_complete_conversation_with_messages() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create and store a conversation for user_id_1 + let user_id_1 = "user_1"; + let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string()); + let conversation_id = conversation.id.clone(); + + db.store_item(conversation) + .await + .expect("Failed to store conversation"); + + // Create messages + let message1 = Message::new( + conversation_id.clone(), + MessageRole::User, + "Hello, AI!".to_string(), + None, + ); + let message2 = Message::new( + conversation_id.clone(), + MessageRole::AI, + "Hello, human! How can I help you today?".to_string(), + None, + ); + let message3 = Message::new( + conversation_id.clone(), + MessageRole::User, + "Tell me about Rust programming.".to_string(), + None, + ); + + // Store messages + db.store_item(message1) + .await + .expect("Failed to store message1"); + db.store_item(message2) + .await + .expect("Failed to store message2"); + db.store_item(message3) + .await + .expect("Failed to store message3"); + + // Retrieve the complete conversation + let result = + Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await; + assert!(result.is_ok(), "Failed to retrieve complete conversation"); + + let (retrieved_conversation, messages) = result.unwrap(); + + // Verify conversation data + assert_eq!(retrieved_conversation.id, conversation_id); + assert_eq!(retrieved_conversation.user_id, user_id_1); + assert_eq!(retrieved_conversation.title, "Conversation"); + + // Verify messages + assert_eq!(messages.len(), 3); + + // Verify messages are sorted by updated_at + let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect(); + assert!(message_contents.contains(&"Hello, AI!")); + assert!(message_contents.contains(&"Hello, human! How can I help you today?")); + assert!(message_contents.contains(&"Tell me about Rust programming.")); + + // Make sure we can't access with different user + let user_id_2 = "user_2"; + let unauthorized_result = + Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await; + assert!(unauthorized_result.is_err()); + match unauthorized_result { + Err(AppError::Auth(_)) => { /* expected error */ } + _ => panic!("Expected Auth error"), + } + } +} diff --git a/common/src/storage/types/file_info.rs b/common/src/storage/types/file_info.rs index 8ff3342..f556597 100644 --- a/common/src/storage/types/file_info.rs +++ b/common/src/storage/types/file_info.rs @@ -131,16 +131,32 @@ impl FileInfo { /// Sanitizes the file name to prevent security vulnerabilities like directory traversal. /// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores. fn sanitize_file_name(file_name: &str) -> String { - file_name - .chars() - .map(|c| { - if c.is_ascii_alphanumeric() || c == '.' || c == '_' { - c - } else { - '_' - } - }) - .collect() + if let Some(idx) = file_name.rfind('.') { + let (name, ext) = file_name.split_at(idx); + let sanitized_name: String = name + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '_' { + c + } else { + '_' + } + }) + .collect(); + format!("{}{}", sanitized_name, ext) + } else { + // No extension + file_name + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '_' { + c + } else { + '_' + } + }) + .collect() + } } /// Persists the file to the filesystem under `./data/{user_id}/{uuid}/{file_name}`. @@ -243,3 +259,331 @@ impl FileInfo { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::HeaderMap; + use axum_typed_multipart::FieldMetadata; + use std::io::Write; + use tempfile::NamedTempFile; + + /// Creates a test temporary file with the given content + fn create_test_file(content: &[u8], file_name: &str) -> FieldData { + let mut temp_file = NamedTempFile::new().expect("Failed to create temp file"); + temp_file + .write_all(content) + .expect("Failed to write to temp file"); + + let metadata = FieldMetadata { + name: Some("file".to_string()), + file_name: Some(file_name.to_string()), + content_type: None, + headers: HeaderMap::default(), + }; + + let field_data = FieldData { + metadata, + contents: temp_file, + }; + + field_data + } + + #[tokio::test] + async fn test_file_creation() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create a test file + let content = b"This is a test file content"; + let file_name = "test_file.txt"; + let field_data = create_test_file(content, file_name); + + // Create a FileInfo instance + let user_id = "test_user"; + let file_info = FileInfo::new(field_data, &db, user_id).await; + + // We can't fully test persistence to disk in unit tests, + // but we can verify the database record was created + assert!(file_info.is_ok()); + let file_info = file_info.unwrap(); + + // Check essential properties + assert!(!file_info.id.is_empty()); + assert_eq!(file_info.file_name, file_name); + assert!(!file_info.sha256.is_empty()); + assert!(!file_info.path.is_empty()); + assert!(file_info.mime_type.contains("text/plain")); + + // Verify it's in the database + let stored: Option = db + .get_item(&file_info.id) + .await + .expect("Failed to retrieve file info"); + assert!(stored.is_some()); + let stored = stored.unwrap(); + assert_eq!(stored.id, file_info.id); + assert_eq!(stored.file_name, file_info.file_name); + assert_eq!(stored.sha256, file_info.sha256); + } + + #[tokio::test] + async fn test_file_duplicate_detection() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // First, store a file with known content + let content = b"This is a test file for duplicate detection"; + let file_name = "original.txt"; + let user_id = "test_user"; + + let field_data1 = create_test_file(content, file_name); + let original_file_info = FileInfo::new(field_data1, &db, user_id) + .await + .expect("Failed to create original file"); + + // Now try to store another file with the same content but different name + let duplicate_name = "duplicate.txt"; + let field_data2 = create_test_file(content, duplicate_name); + + // The system should detect it's the same file and return the original FileInfo + let duplicate_file_info = FileInfo::new(field_data2, &db, user_id) + .await + .expect("Failed to process duplicate file"); + + // The returned FileInfo should match the original + assert_eq!(duplicate_file_info.id, original_file_info.id); + assert_eq!(duplicate_file_info.sha256, original_file_info.sha256); + + // But it should retain the original file name, not the duplicate's name + assert_eq!(duplicate_file_info.file_name, file_name); + assert_ne!(duplicate_file_info.file_name, duplicate_name); + } + + #[tokio::test] + async fn test_guess_mime_type() { + // Test common file extensions + assert_eq!( + FileInfo::guess_mime_type(Path::new("test.txt")), + "text/plain".to_string() + ); + assert_eq!( + FileInfo::guess_mime_type(Path::new("image.png")), + "image/png".to_string() + ); + assert_eq!( + FileInfo::guess_mime_type(Path::new("document.pdf")), + "application/pdf".to_string() + ); + assert_eq!( + FileInfo::guess_mime_type(Path::new("data.json")), + "application/json".to_string() + ); + + // Test unknown extension + assert_eq!( + FileInfo::guess_mime_type(Path::new("unknown.929yz")), + "application/octet-stream".to_string() + ); + } + + #[tokio::test] + async fn test_sanitize_file_name() { + // Safe characters should remain unchanged + assert_eq!( + FileInfo::sanitize_file_name("normal_file.txt"), + "normal_file.txt" + ); + assert_eq!(FileInfo::sanitize_file_name("file123.doc"), "file123.doc"); + + // Unsafe characters should be replaced with underscores + assert_eq!( + FileInfo::sanitize_file_name("file with spaces.txt"), + "file_with_spaces.txt" + ); + assert_eq!( + FileInfo::sanitize_file_name("file/with/path.txt"), + "file_with_path.txt" + ); + assert_eq!( + FileInfo::sanitize_file_name("file:with:colons.txt"), + "file_with_colons.txt" + ); + assert_eq!( + FileInfo::sanitize_file_name("../dangerous.txt"), + "___dangerous.txt" + ); + } + + #[tokio::test] + async fn test_get_by_sha_not_found() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Try to find a file with a SHA that doesn't exist + let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await; + assert!(result.is_err()); + + match result { + Err(FileError::FileNotFound(_)) => { + // Expected error + } + _ => panic!("Expected FileNotFound error"), + } + } + + #[tokio::test] + async fn test_manual_file_info_creation() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create a FileInfo instance directly + let now = Utc::now(); + let file_info = FileInfo { + id: Uuid::new_v4().to_string(), + created_at: now, + updated_at: now, + sha256: "test_sha256_hash".to_string(), + path: "/path/to/file.txt".to_string(), + file_name: "manual_file.txt".to_string(), + mime_type: "text/plain".to_string(), + }; + + // Store it in the database + let result = db.store_item(file_info.clone()).await; + assert!(result.is_ok()); + + // Verify it can be retrieved + let retrieved: Option = db + .get_item(&file_info.id) + .await + .expect("Failed to retrieve file info"); + assert!(retrieved.is_some()); + + let retrieved = retrieved.unwrap(); + assert_eq!(retrieved.id, file_info.id); + assert_eq!(retrieved.sha256, file_info.sha256); + assert_eq!(retrieved.file_name, file_info.file_name); + assert_eq!(retrieved.path, file_info.path); + assert_eq!(retrieved.mime_type, file_info.mime_type); + } + + #[tokio::test] + async fn test_delete_by_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create a FileInfo instance directly (without persistence to disk) + let now = Utc::now(); + let file_id = Uuid::new_v4().to_string(); + + // Create a temporary directory that mimics the structure we would have on disk + let base_dir = Path::new("./data"); + let user_id = "test_user"; + let user_dir = base_dir.join(user_id); + let uuid_dir = user_dir.join(&file_id); + + tokio::fs::create_dir_all(&uuid_dir) + .await + .expect("Failed to create test directories"); + + // Create a test file in the directory + let test_file_path = uuid_dir.join("test_file.txt"); + tokio::fs::write(&test_file_path, b"test content") + .await + .expect("Failed to write test file"); + + // The file path should point to our test file + let file_info = FileInfo { + id: file_id.clone(), + created_at: now, + updated_at: now, + sha256: "test_sha256_hash".to_string(), + path: test_file_path.to_string_lossy().to_string(), + file_name: "test_file.txt".to_string(), + mime_type: "text/plain".to_string(), + }; + + // Store it in the database + db.store_item(file_info.clone()) + .await + .expect("Failed to store file info"); + + // Verify file exists on disk + assert!(tokio::fs::try_exists(&test_file_path) + .await + .unwrap_or(false)); + + // Delete the file + let delete_result = FileInfo::delete_by_id(&file_id, &db).await; + + // Delete should be successful + assert!( + delete_result.is_ok(), + "Failed to delete file: {:?}", + delete_result + ); + + // Verify the file is removed from the database + let retrieved: Option = db + .get_item(&file_id) + .await + .expect("Failed to query database"); + assert!( + retrieved.is_none(), + "FileInfo should be deleted from the database" + ); + + // Verify directory is gone + assert!( + !tokio::fs::try_exists(&uuid_dir).await.unwrap_or(true), + "UUID directory should be deleted" + ); + + // Clean up test directory if it exists + let _ = tokio::fs::remove_dir_all(base_dir).await; + } + + #[tokio::test] + async fn test_delete_by_id_not_found() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Try to delete a file that doesn't exist + let result = FileInfo::delete_by_id("nonexistent_id", &db).await; + + // Should fail with FileNotFound error + assert!(result.is_err()); + match result { + Err(FileError::FileNotFound(_)) => { + // Expected error + } + _ => panic!("Expected FileNotFound error"), + } + } +} diff --git a/common/src/storage/types/ingestion_payload.rs b/common/src/storage/types/ingestion_payload.rs index 7d7a523..e3b5851 100644 --- a/common/src/storage/types/ingestion_payload.rs +++ b/common/src/storage/types/ingestion_payload.rs @@ -1,9 +1,10 @@ use crate::{error::AppError, storage::types::file_info::FileInfo}; +use chrono::Utc; use serde::{Deserialize, Serialize}; use tracing::info; use url::Url; -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub enum IngestionPayload { Url { url: String, @@ -93,3 +94,237 @@ impl IngestionPayload { Ok(object_list) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Create a mock FileInfo for testing + #[derive(Debug, Clone, PartialEq)] + struct MockFileInfo { + id: String, + } + + impl From for FileInfo { + fn from(mock: MockFileInfo) -> Self { + // This is just a test implementation, the actual fields don't matter + // as we're just testing the IngestionPayload functionality + FileInfo { + id: mock.id, + sha256: "mock-sha256".to_string(), + path: "/mock/path".to_string(), + file_name: "mock.txt".to_string(), + mime_type: "text/plain".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + } + } + } + + #[test] + fn test_create_ingestion_payload_with_url() { + let url = "https://example.com"; + let instructions = "Process this URL"; + let category = "websites"; + let user_id = "user123"; + let files = vec![]; + + let result = IngestionPayload::create_ingestion_payload( + Some(url.to_string()), + instructions.to_string(), + category.to_string(), + files, + user_id, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + match &result[0] { + IngestionPayload::Url { + url: payload_url, + instructions: payload_instructions, + category: payload_category, + user_id: payload_user_id, + } => { + // URL parser may normalize the URL by adding a trailing slash + assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url)); + assert_eq!(payload_instructions, &instructions); + assert_eq!(payload_category, &category); + assert_eq!(payload_user_id, &user_id); + } + _ => panic!("Expected Url variant"), + } + } + + #[test] + fn test_create_ingestion_payload_with_text() { + let text = "This is some text content"; + let instructions = "Process this text"; + let category = "notes"; + let user_id = "user123"; + let files = vec![]; + + let result = IngestionPayload::create_ingestion_payload( + Some(text.to_string()), + instructions.to_string(), + category.to_string(), + files, + user_id, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + match &result[0] { + IngestionPayload::Text { + text: payload_text, + instructions: payload_instructions, + category: payload_category, + user_id: payload_user_id, + } => { + assert_eq!(payload_text, text); + assert_eq!(payload_instructions, instructions); + assert_eq!(payload_category, category); + assert_eq!(payload_user_id, user_id); + } + _ => panic!("Expected Text variant"), + } + } + + #[test] + fn test_create_ingestion_payload_with_file() { + let instructions = "Process this file"; + let category = "documents"; + let user_id = "user123"; + + // Create a mock FileInfo + let mock_file = MockFileInfo { + id: "file123".to_string(), + }; + + let file_info: FileInfo = mock_file.into(); + let files = vec![file_info.clone()]; + + let result = IngestionPayload::create_ingestion_payload( + None, + instructions.to_string(), + category.to_string(), + files, + user_id, + ) + .unwrap(); + + assert_eq!(result.len(), 1); + match &result[0] { + IngestionPayload::File { + file_info: payload_file_info, + instructions: payload_instructions, + category: payload_category, + user_id: payload_user_id, + } => { + assert_eq!(payload_file_info.id, file_info.id); + assert_eq!(payload_instructions, instructions); + assert_eq!(payload_category, category); + assert_eq!(payload_user_id, user_id); + } + _ => panic!("Expected File variant"), + } + } + + #[test] + fn test_create_ingestion_payload_with_url_and_file() { + let url = "https://example.com"; + let instructions = "Process this data"; + let category = "mixed"; + let user_id = "user123"; + + // Create a mock FileInfo + let mock_file = MockFileInfo { + id: "file123".to_string(), + }; + + let file_info: FileInfo = mock_file.into(); + let files = vec![file_info.clone()]; + + let result = IngestionPayload::create_ingestion_payload( + Some(url.to_string()), + instructions.to_string(), + category.to_string(), + files, + user_id, + ) + .unwrap(); + + assert_eq!(result.len(), 2); + + // Check first item is URL + match &result[0] { + IngestionPayload::Url { + url: payload_url, .. + } => { + // URL parser may normalize the URL by adding a trailing slash + assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url)); + } + _ => panic!("Expected first item to be Url variant"), + } + + // Check second item is File + match &result[1] { + IngestionPayload::File { + file_info: payload_file_info, + .. + } => { + assert_eq!(payload_file_info.id, file_info.id); + } + _ => panic!("Expected second item to be File variant"), + } + } + + #[test] + fn test_create_ingestion_payload_empty_input() { + let instructions = "Process something"; + let category = "empty"; + let user_id = "user123"; + let files = vec![]; + + let result = IngestionPayload::create_ingestion_payload( + None, + instructions.to_string(), + category.to_string(), + files, + user_id, + ); + + assert!(result.is_err()); + match result { + Err(AppError::NotFound(msg)) => { + assert_eq!(msg, "No valid content or files provided"); + } + _ => panic!("Expected NotFound error"), + } + } + + #[test] + fn test_create_ingestion_payload_with_empty_text() { + let text = ""; // Empty text + let instructions = "Process this"; + let category = "notes"; + let user_id = "user123"; + let files = vec![]; + + let result = IngestionPayload::create_ingestion_payload( + Some(text.to_string()), + instructions.to_string(), + category.to_string(), + files, + user_id, + ); + + assert!(result.is_err()); + match result { + Err(AppError::NotFound(msg)) => { + assert_eq!(msg, "No valid content or files provided"); + } + _ => panic!("Expected NotFound error"), + } + } +} diff --git a/common/src/storage/types/ingestion_task.rs b/common/src/storage/types/ingestion_task.rs index 1a89df1..a7d4957 100644 --- a/common/src/storage/types/ingestion_task.rs +++ b/common/src/storage/types/ingestion_task.rs @@ -6,7 +6,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use super::ingestion_payload::IngestionPayload; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum IngestionTaskStatus { Created, InProgress { @@ -100,3 +100,196 @@ impl IngestionTask { Ok(jobs) } } + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Utc; + + // Helper function to create a test ingestion payload + fn create_test_payload(user_id: &str) -> IngestionPayload { + IngestionPayload::Text { + text: "Test content".to_string(), + instructions: "Test instructions".to_string(), + category: "Test category".to_string(), + user_id: user_id.to_string(), + } + } + + #[tokio::test] + async fn test_new_ingestion_task() { + let user_id = "user123"; + let payload = create_test_payload(user_id); + + let task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + + // Verify task properties + assert_eq!(task.user_id, user_id); + assert_eq!(task.content, payload); + assert!(matches!(task.status, IngestionTaskStatus::Created)); + assert!(!task.id.is_empty()); + } + + #[tokio::test] + async fn test_create_and_add_to_db() { + // Setup in-memory database + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + let user_id = "user123"; + let payload = create_test_payload(user_id); + + // Create and store task + IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db) + .await + .expect("Failed to create and add task to db"); + + // Query to verify task was stored + let query = format!( + "SELECT * FROM {} WHERE user_id = '{}'", + IngestionTask::table_name(), + user_id + ); + let mut result = db.query(query).await.expect("Query failed"); + let tasks: Vec = result.take(0).unwrap_or_default(); + + // Verify task is in the database + assert!(!tasks.is_empty(), "Task should exist in the database"); + let stored_task = &tasks[0]; + assert_eq!(stored_task.user_id, user_id); + assert!(matches!(stored_task.status, IngestionTaskStatus::Created)); + } + + #[tokio::test] + async fn test_update_status() { + // Setup in-memory database + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + let user_id = "user123"; + let payload = create_test_payload(user_id); + + // Create task manually + let task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + let task_id = task.id.clone(); + + // Store task + db.store_item(task).await.expect("Failed to store task"); + + // Update status to InProgress + let now = Utc::now(); + let new_status = IngestionTaskStatus::InProgress { + attempts: 1, + last_attempt: now, + }; + + IngestionTask::update_status(&task_id, new_status.clone(), &db) + .await + .expect("Failed to update status"); + + // Verify status updated + let updated_task: Option = db + .get_item::(&task_id) + .await + .expect("Failed to get updated task"); + + assert!(updated_task.is_some()); + let updated_task = updated_task.unwrap(); + + match updated_task.status { + IngestionTaskStatus::InProgress { attempts, .. } => { + assert_eq!(attempts, 1); + } + _ => panic!("Expected InProgress status"), + } + } + + #[tokio::test] + async fn test_get_unfinished_tasks() { + // Setup in-memory database + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + let user_id = "user123"; + let payload = create_test_payload(user_id); + + // Create tasks with different statuses + let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + + let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + in_progress_task.status = IngestionTaskStatus::InProgress { + attempts: 1, + last_attempt: Utc::now(), + }; + + let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + max_attempts_task.status = IngestionTaskStatus::InProgress { + attempts: MAX_ATTEMPTS, + last_attempt: Utc::now(), + }; + + let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + completed_task.status = IngestionTaskStatus::Completed; + + let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await; + error_task.status = IngestionTaskStatus::Error("Test error".to_string()); + + // Store all tasks + db.store_item(created_task) + .await + .expect("Failed to store created task"); + db.store_item(in_progress_task) + .await + .expect("Failed to store in-progress task"); + db.store_item(max_attempts_task) + .await + .expect("Failed to store max-attempts task"); + db.store_item(completed_task) + .await + .expect("Failed to store completed task"); + db.store_item(error_task) + .await + .expect("Failed to store error task"); + + // Get unfinished tasks + let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db) + .await + .expect("Failed to get unfinished tasks"); + + // Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned + assert_eq!(unfinished_tasks.len(), 2); + + let statuses: Vec<_> = unfinished_tasks + .iter() + .map(|task| match &task.status { + IngestionTaskStatus::Created => "Created", + IngestionTaskStatus::InProgress { attempts, .. } => { + if *attempts < MAX_ATTEMPTS { + "InProgress=MAX" + } + } + IngestionTaskStatus::Completed => "Completed", + IngestionTaskStatus::Error(_) => "Error", + IngestionTaskStatus::Cancelled => "Cancelled", + }) + .collect(); + + assert!(statuses.contains(&"Created")); + assert!(statuses.contains(&"InProgress=MAX")); + assert!(!statuses.contains(&"Completed")); + assert!(!statuses.contains(&"Error")); + assert!(!statuses.contains(&"Cancelled")); + } +} diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 5707129..fb2a38e 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -5,7 +5,7 @@ use crate::{ use async_openai::{config::OpenAIConfig, Client}; use uuid::Uuid; -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub enum KnowledgeEntityType { Idea, Project, @@ -119,3 +119,198 @@ impl KnowledgeEntity { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[tokio::test] + async fn test_knowledge_entity_creation() { + // Create basic test entity + let source_id = "source123".to_string(); + let name = "Test Entity".to_string(); + let description = "Test Description".to_string(); + let entity_type = KnowledgeEntityType::Document; + let metadata = Some(json!({"key": "value"})); + let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; + let user_id = "user123".to_string(); + + let entity = KnowledgeEntity::new( + source_id.clone(), + name.clone(), + description.clone(), + entity_type.clone(), + metadata.clone(), + embedding.clone(), + user_id.clone(), + ); + + // Verify all fields are set correctly + assert_eq!(entity.source_id, source_id); + assert_eq!(entity.name, name); + assert_eq!(entity.description, description); + assert_eq!(entity.entity_type, entity_type); + assert_eq!(entity.metadata, metadata); + assert_eq!(entity.embedding, embedding); + assert_eq!(entity.user_id, user_id); + assert!(!entity.id.is_empty()); + } + + #[tokio::test] + async fn test_knowledge_entity_type_from_string() { + // Test conversion from String to KnowledgeEntityType + assert_eq!( + KnowledgeEntityType::from("idea".to_string()), + KnowledgeEntityType::Idea + ); + assert_eq!( + KnowledgeEntityType::from("Idea".to_string()), + KnowledgeEntityType::Idea + ); + assert_eq!( + KnowledgeEntityType::from("IDEA".to_string()), + KnowledgeEntityType::Idea + ); + + assert_eq!( + KnowledgeEntityType::from("project".to_string()), + KnowledgeEntityType::Project + ); + assert_eq!( + KnowledgeEntityType::from("document".to_string()), + KnowledgeEntityType::Document + ); + assert_eq!( + KnowledgeEntityType::from("page".to_string()), + KnowledgeEntityType::Page + ); + assert_eq!( + KnowledgeEntityType::from("textsnippet".to_string()), + KnowledgeEntityType::TextSnippet + ); + + // Test default case + assert_eq!( + KnowledgeEntityType::from("unknown".to_string()), + KnowledgeEntityType::Document + ); + } + + #[tokio::test] + async fn test_knowledge_entity_variants() { + let variants = KnowledgeEntityType::variants(); + assert_eq!(variants.len(), 5); + assert!(variants.contains(&"Idea")); + assert!(variants.contains(&"Project")); + assert!(variants.contains(&"Document")); + assert!(variants.contains(&"Page")); + assert!(variants.contains(&"TextSnippet")); + } + + #[tokio::test] + async fn test_delete_by_source_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create two entities with the same source_id + let source_id = "source123".to_string(); + let entity_type = KnowledgeEntityType::Document; + let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; + let user_id = "user123".to_string(); + + let entity1 = KnowledgeEntity::new( + source_id.clone(), + "Entity 1".to_string(), + "Description 1".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + let entity2 = KnowledgeEntity::new( + source_id.clone(), + "Entity 2".to_string(), + "Description 2".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Create an entity with a different source_id + let different_source_id = "different_source".to_string(); + let different_entity = KnowledgeEntity::new( + different_source_id.clone(), + "Different Entity".to_string(), + "Different Description".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Store the entities + db.store_item(entity1) + .await + .expect("Failed to store entity 1"); + db.store_item(entity2) + .await + .expect("Failed to store entity 2"); + db.store_item(different_entity.clone()) + .await + .expect("Failed to store different entity"); + + // Delete by source_id + KnowledgeEntity::delete_by_source_id(&source_id, &db) + .await + .expect("Failed to delete entities by source_id"); + + // Verify all entities with the specified source_id are deleted + let query = format!( + "SELECT * FROM {} WHERE source_id = '{}'", + KnowledgeEntity::table_name(), + source_id + ); + let remaining: Vec = db + .client + .query(query) + .await + .expect("Query failed") + .take(0) + .expect("Failed to get query results"); + assert_eq!( + remaining.len(), + 0, + "All entities with the source_id should be deleted" + ); + + // Verify the entity with a different source_id still exists + let different_query = format!( + "SELECT * FROM {} WHERE source_id = '{}'", + KnowledgeEntity::table_name(), + different_source_id + ); + let different_remaining: Vec = db + .client + .query(different_query) + .await + .expect("Query failed") + .take(0) + .expect("Failed to get query results"); + assert_eq!( + different_remaining.len(), + 1, + "Entity with different source_id should still exist" + ); + assert_eq!(different_remaining[0].id, different_entity.id); + } + + // Note: We can't easily test the patch method without mocking the OpenAI client + // and the generate_embedding function. This would require more complex setup. +} diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index d8ceecf..c40799b 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -84,3 +84,258 @@ impl KnowledgeRelationship { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; + + // Helper function to create a test knowledge entity for the relationship tests + async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String { + let source_id = "source123".to_string(); + let description = format!("Description for {}", name); + let entity_type = KnowledgeEntityType::Document; + let embedding = vec![0.1, 0.2, 0.3]; + let user_id = "user123".to_string(); + + let entity = KnowledgeEntity::new( + source_id, + name.to_string(), + description, + entity_type, + None, + embedding, + user_id, + ); + + let stored: Option = db_client + .store_item(entity) + .await + .expect("Failed to store entity"); + stored.unwrap().id + } + + #[tokio::test] + async fn test_relationship_creation() { + let in_id = "entity1".to_string(); + let out_id = "entity2".to_string(); + let user_id = "user123".to_string(); + let source_id = "source123".to_string(); + let relationship_type = "references".to_string(); + + let relationship = KnowledgeRelationship::new( + in_id.clone(), + out_id.clone(), + user_id.clone(), + source_id.clone(), + relationship_type.clone(), + ); + + // Verify fields are correctly set + assert_eq!(relationship.in_, in_id); + assert_eq!(relationship.out, out_id); + assert_eq!(relationship.metadata.user_id, user_id); + assert_eq!(relationship.metadata.source_id, source_id); + assert_eq!(relationship.metadata.relationship_type, relationship_type); + assert!(!relationship.id.is_empty()); + } + + #[tokio::test] + async fn test_store_relationship() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create two entities to relate + let entity1_id = create_test_entity("Entity 1", &db).await; + let entity2_id = create_test_entity("Entity 2", &db).await; + + // Create relationship + let user_id = "user123".to_string(); + let source_id = "source123".to_string(); + let relationship_type = "references".to_string(); + + let relationship = KnowledgeRelationship::new( + entity1_id.clone(), + entity2_id.clone(), + user_id, + source_id.clone(), + relationship_type, + ); + + // Store the relationship + relationship + .store_relationship(&db) + .await + .expect("Failed to store relationship"); + + // Query to verify the relationship exists by checking for relationships with our source_id + // This approach is more reliable than trying to look up by ID + let check_query = format!( + "SELECT * FROM relates_to WHERE metadata.source_id = '{}'", + source_id + ); + let mut check_result = db.query(check_query).await.expect("Check query failed"); + let check_results: Vec = check_result.take(0).unwrap_or_default(); + + // Just verify that a relationship was created + assert!( + !check_results.is_empty(), + "Relationship should exist in the database" + ); + } + + #[tokio::test] + async fn test_delete_relationship_by_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create two entities to relate + let entity1_id = create_test_entity("Entity 1", &db).await; + let entity2_id = create_test_entity("Entity 2", &db).await; + + // Create relationship + let user_id = "user123".to_string(); + let source_id = "source123".to_string(); + let relationship_type = "references".to_string(); + + let relationship = KnowledgeRelationship::new( + entity1_id.clone(), + entity2_id.clone(), + user_id, + source_id.clone(), + relationship_type, + ); + + // Store the relationship + relationship + .store_relationship(&db) + .await + .expect("Failed to store relationship"); + + // Delete the relationship by ID + KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db) + .await + .expect("Failed to delete relationship by ID"); + + // Query to verify the relationship was deleted + let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id); + let mut result = db.query(query).await.expect("Query failed"); + let results: Vec = result.take(0).unwrap_or_default(); + + // Verify the relationship no longer exists + assert!(results.is_empty(), "Relationship should be deleted"); + } + + #[tokio::test] + async fn test_delete_relationships_by_source_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create entities to relate + let entity1_id = create_test_entity("Entity 1", &db).await; + let entity2_id = create_test_entity("Entity 2", &db).await; + let entity3_id = create_test_entity("Entity 3", &db).await; + + // Create relationships with the same source_id + let user_id = "user123".to_string(); + let source_id = "source123".to_string(); + let different_source_id = "different_source".to_string(); + + // Create two relationships with the same source_id + let relationship1 = KnowledgeRelationship::new( + entity1_id.clone(), + entity2_id.clone(), + user_id.clone(), + source_id.clone(), + "references".to_string(), + ); + + let relationship2 = KnowledgeRelationship::new( + entity2_id.clone(), + entity3_id.clone(), + user_id.clone(), + source_id.clone(), + "contains".to_string(), + ); + + // Create a relationship with a different source_id + let different_relationship = KnowledgeRelationship::new( + entity1_id.clone(), + entity3_id.clone(), + user_id.clone(), + different_source_id.clone(), + "mentions".to_string(), + ); + + // Store all relationships + relationship1 + .store_relationship(&db) + .await + .expect("Failed to store relationship 1"); + relationship2 + .store_relationship(&db) + .await + .expect("Failed to store relationship 2"); + different_relationship + .store_relationship(&db) + .await + .expect("Failed to store different relationship"); + + // Delete relationships by source_id + KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db) + .await + .expect("Failed to delete relationships by source_id"); + + // Query to verify the relationships with source_id were deleted + let query1 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship1.id); + let query2 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship2.id); + let different_query = format!( + "SELECT * FROM relates_to WHERE id = '{}'", + different_relationship.id + ); + + let mut result1 = db.query(query1).await.expect("Query 1 failed"); + let results1: Vec = result1.take(0).unwrap_or_default(); + + let mut result2 = db.query(query2).await.expect("Query 2 failed"); + let results2: Vec = result2.take(0).unwrap_or_default(); + + let mut different_result = db + .query(different_query) + .await + .expect("Different query failed"); + let _different_results: Vec = + different_result.take(0).unwrap_or_default(); + + // Verify relationships with the source_id are deleted + assert!(results1.is_empty(), "Relationship 1 should be deleted"); + assert!(results2.is_empty(), "Relationship 2 should be deleted"); + + // For the relationship with different source ID, we need to check differently + // Let's just verify we have a relationship where the source_id matches different_source_id + let check_query = format!( + "SELECT * FROM relates_to WHERE metadata.source_id = '{}'", + different_source_id + ); + let mut check_result = db.query(check_query).await.expect("Check query failed"); + let check_results: Vec = check_result.take(0).unwrap_or_default(); + + // Verify the relationship with a different source_id still exists + assert!( + !check_results.is_empty(), + "Relationship with different source_id should still exist" + ); + } +} diff --git a/common/src/storage/types/message.rs b/common/src/storage/types/message.rs index f3a2160..3ddf937 100644 --- a/common/src/storage/types/message.rs +++ b/common/src/storage/types/message.rs @@ -2,7 +2,7 @@ use uuid::Uuid; use crate::stored_object; -#[derive(Deserialize, Debug, Clone, Serialize)] +#[derive(Deserialize, Debug, Clone, Serialize, PartialEq)] pub enum MessageRole { User, AI, @@ -60,3 +60,128 @@ pub fn format_history(history: &[Message]) -> String { .collect::>() .join("\n") } + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::db::SurrealDbClient; + + #[tokio::test] + async fn test_message_creation() { + // Test basic message creation + let conversation_id = "test_conversation"; + let content = "This is a test message"; + let role = MessageRole::User; + let references = Some(vec!["ref1".to_string(), "ref2".to_string()]); + + let message = Message::new( + conversation_id.to_string(), + role.clone(), + content.to_string(), + references.clone(), + ); + + // Verify message properties + assert_eq!(message.conversation_id, conversation_id); + assert_eq!(message.content, content); + assert_eq!(message.role, role); + assert_eq!(message.references, references); + assert!(!message.id.is_empty()); + } + + #[tokio::test] + async fn test_message_persistence() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &uuid::Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create and store a message + let conversation_id = "test_conversation"; + let message = Message::new( + conversation_id.to_string(), + MessageRole::User, + "Hello world".to_string(), + None, + ); + let message_id = message.id.clone(); + + // Store the message + db.store_item(message.clone()) + .await + .expect("Failed to store message"); + + // Retrieve the message + let retrieved: Option = db + .get_item(&message_id) + .await + .expect("Failed to retrieve message"); + + assert!(retrieved.is_some()); + let retrieved = retrieved.unwrap(); + + // Verify retrieved properties match original + assert_eq!(retrieved.id, message.id); + assert_eq!(retrieved.conversation_id, message.conversation_id); + assert_eq!(retrieved.role, message.role); + assert_eq!(retrieved.content, message.content); + assert_eq!(retrieved.references, message.references); + } + + #[tokio::test] + async fn test_message_role_display() { + // Test the Display implementation for MessageRole + assert_eq!(format!("{}", MessageRole::User), "User"); + assert_eq!(format!("{}", MessageRole::AI), "AI"); + assert_eq!(format!("{}", MessageRole::System), "System"); + } + + #[tokio::test] + async fn test_message_display() { + // Test the Display implementation for Message + let message = Message { + id: "test_id".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + conversation_id: "test_convo".to_string(), + role: MessageRole::User, + content: "Hello world".to_string(), + references: None, + }; + + assert_eq!(format!("{}", message), "User: Hello world"); + } + + #[tokio::test] + async fn test_format_history() { + // Create a vector of messages + let messages = vec![ + Message { + id: "1".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + conversation_id: "test_convo".to_string(), + role: MessageRole::User, + content: "Hello".to_string(), + references: None, + }, + Message { + id: "2".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + conversation_id: "test_convo".to_string(), + role: MessageRole::AI, + content: "Hi there!".to_string(), + references: None, + }, + ]; + + // Format the history + let formatted = format_history(&messages); + + // Verify the formatting + assert_eq!(formatted, "User: Hello\nAI: Hi there!"); + } +} diff --git a/common/src/storage/types/mod.rs b/common/src/storage/types/mod.rs index ce566c9..67c11ff 100644 --- a/common/src/storage/types/mod.rs +++ b/common/src/storage/types/mod.rs @@ -87,7 +87,7 @@ macro_rules! stored_object { } - #[derive(Debug, Clone, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct $name { #[serde(deserialize_with = "deserialize_flexible_id")] pub id: String, diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index c6b8447..1831ea7 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -1,8 +1,9 @@ use crate::storage::types::file_info::deserialize_flexible_id; +use axum::async_trait; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{error::AppError, storage::db::SurrealDbClient}; +use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject}; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct SystemSettings { @@ -16,41 +17,49 @@ pub struct SystemSettings { pub ingestion_system_prompt: String, } +#[async_trait] +impl StoredObject for SystemSettings { + fn table_name() -> &'static str { + "system_settings" + } + + fn get_id(&self) -> &str { + &self.id + } +} + impl SystemSettings { pub async fn ensure_initialized(db: &SurrealDbClient) -> Result { - let settings = db.select(("system_settings", "current")).await?; + let settings: Option = db.get_item("current").await?; if settings.is_none() { - let created: Option = db - .create(("system_settings", "current")) - .content(SystemSettings { - id: "current".to_string(), - registrations_enabled: true, - require_email_verification: false, - query_model: "gpt-4o-mini".to_string(), - processing_model: "gpt-4o-mini".to_string(), - query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(), - ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(), - }) - .await?; + let created_settings = SystemSettings { + id: "current".to_string(), + registrations_enabled: true, + require_email_verification: false, + query_model: "gpt-4o-mini".to_string(), + processing_model: "gpt-4o-mini".to_string(), + query_system_prompt: + crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(), + ingestion_system_prompt: + crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT + .to_string(), + }; - return created.ok_or(AppError::Validation("Failed to initialize settings".into())); - }; + let stored: Option = db.store_item(created_settings).await?; + return stored.ok_or(AppError::Validation("Failed to initialize settings".into())); + } settings.ok_or(AppError::Validation("Failed to initialize settings".into())) } pub async fn get_current(db: &SurrealDbClient) -> Result { - let settings: Option = db - .client - .query("SELECT * FROM type::thing('system_settings', 'current')") - .await? - .take(0)?; - + let settings: Option = db.get_item("current").await?; settings.ok_or(AppError::NotFound("System settings not found".into())) } pub async fn update(db: &SurrealDbClient, changes: Self) -> Result { + // We need to use a direct query for the update with MERGE let updated: Option = db .client .query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER") @@ -66,8 +75,11 @@ impl SystemSettings { pub fn new() -> Self { Self { id: Uuid::new_v4().to_string(), - query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(), - ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(), + query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT + .to_string(), + ingestion_system_prompt: + crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT + .to_string(), query_model: "gpt-4o-mini".to_string(), processing_model: "gpt-4o-mini".to_string(), registrations_enabled: true, @@ -75,3 +87,159 @@ impl SystemSettings { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_settings_initialization() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Test initialization of system settings + let settings = SystemSettings::ensure_initialized(&db) + .await + .expect("Failed to initialize system settings"); + + // Verify initial state after initialization + assert_eq!(settings.id, "current"); + assert_eq!(settings.registrations_enabled, true); + assert_eq!(settings.require_email_verification, false); + assert_eq!(settings.query_model, "gpt-4o-mini"); + assert_eq!(settings.processing_model, "gpt-4o-mini"); + assert_eq!( + settings.query_system_prompt, + crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT + ); + assert_eq!( + settings.ingestion_system_prompt, + crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT + ); + + // Test idempotency - ensure calling it again doesn't change anything + let settings_again = SystemSettings::ensure_initialized(&db) + .await + .expect("Failed to get settings after initialization"); + + assert_eq!(settings.id, settings_again.id); + assert_eq!( + settings.registrations_enabled, + settings_again.registrations_enabled + ); + assert_eq!( + settings.require_email_verification, + settings_again.require_email_verification + ); + } + + #[tokio::test] + async fn test_get_current_settings() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Initialize settings + SystemSettings::ensure_initialized(&db) + .await + .expect("Failed to initialize system settings"); + + // Test get_current method + let settings = SystemSettings::get_current(&db) + .await + .expect("Failed to get current settings"); + + assert_eq!(settings.id, "current"); + assert_eq!(settings.registrations_enabled, true); + assert_eq!(settings.require_email_verification, false); + } + + #[tokio::test] + async fn test_update_settings() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Initialize settings + SystemSettings::ensure_initialized(&db) + .await + .expect("Failed to initialize system settings"); + + // Create updated settings + let mut updated_settings = SystemSettings::new(); + updated_settings.id = "current".to_string(); + updated_settings.registrations_enabled = false; + updated_settings.require_email_verification = true; + updated_settings.query_model = "gpt-4".to_string(); + + // Test update method + let result = SystemSettings::update(&db, updated_settings) + .await + .expect("Failed to update settings"); + + assert_eq!(result.id, "current"); + assert_eq!(result.registrations_enabled, false); + assert_eq!(result.require_email_verification, true); + assert_eq!(result.query_model, "gpt-4"); + + // Verify changes persisted by getting current settings + let current = SystemSettings::get_current(&db) + .await + .expect("Failed to get current settings after update"); + + assert_eq!(current.registrations_enabled, false); + assert_eq!(current.require_email_verification, true); + assert_eq!(current.query_model, "gpt-4"); + } + + #[tokio::test] + async fn test_get_current_nonexistent() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Don't initialize settings and try to get them + let result = SystemSettings::get_current(&db).await; + + assert!(result.is_err()); + match result { + Err(AppError::NotFound(_)) => { + // Expected error + } + Err(e) => panic!("Expected NotFound error, got: {:?}", e), + Ok(_) => panic!("Expected error but got Ok"), + } + } + + #[tokio::test] + async fn test_new_method() { + let settings = SystemSettings::new(); + + assert!(settings.id.len() > 0); + assert_eq!(settings.registrations_enabled, true); + assert_eq!(settings.require_email_verification, false); + assert_eq!(settings.query_model, "gpt-4o-mini"); + assert_eq!(settings.processing_model, "gpt-4o-mini"); + assert_eq!( + settings.query_system_prompt, + crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT + ); + assert_eq!( + settings.ingestion_system_prompt, + crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT + ); + } +} diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 9b56699..cfb4a51 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -36,3 +36,175 @@ impl TextChunk { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_text_chunk_creation() { + // Test basic object creation + let source_id = "source123".to_string(); + let chunk = "This is a text chunk for testing embeddings".to_string(); + let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; + let user_id = "user123".to_string(); + + let text_chunk = TextChunk::new( + source_id.clone(), + chunk.clone(), + embedding.clone(), + user_id.clone(), + ); + + // Check that the fields are set correctly + assert_eq!(text_chunk.source_id, source_id); + assert_eq!(text_chunk.chunk, chunk); + assert_eq!(text_chunk.embedding, embedding); + assert_eq!(text_chunk.user_id, user_id); + assert!(!text_chunk.id.is_empty()); + } + + #[tokio::test] + async fn test_delete_by_source_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create test data + let source_id = "source123".to_string(); + let chunk1 = "First chunk from the same source".to_string(); + let chunk2 = "Second chunk from the same source".to_string(); + let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; + let user_id = "user123".to_string(); + + // Create two chunks with the same source_id + let text_chunk1 = TextChunk::new( + source_id.clone(), + chunk1, + embedding.clone(), + user_id.clone(), + ); + + let text_chunk2 = TextChunk::new( + source_id.clone(), + chunk2, + embedding.clone(), + user_id.clone(), + ); + + // Create a chunk with a different source_id + let different_source_id = "different_source".to_string(); + let different_chunk = TextChunk::new( + different_source_id.clone(), + "Different source chunk".to_string(), + embedding.clone(), + user_id.clone(), + ); + + // Store the chunks + db.store_item(text_chunk1) + .await + .expect("Failed to store text chunk 1"); + db.store_item(text_chunk2) + .await + .expect("Failed to store text chunk 2"); + db.store_item(different_chunk.clone()) + .await + .expect("Failed to store different chunk"); + + // Delete by source_id + TextChunk::delete_by_source_id(&source_id, &db) + .await + .expect("Failed to delete chunks by source_id"); + + // Verify all chunks with the original source_id are deleted + let query = format!( + "SELECT * FROM {} WHERE source_id = '{}'", + TextChunk::table_name(), + source_id + ); + let remaining: Vec = db + .client + .query(query) + .await + .expect("Query failed") + .take(0) + .expect("Failed to get query results"); + assert_eq!( + remaining.len(), + 0, + "All chunks with the source_id should be deleted" + ); + + // Verify the different source_id chunk still exists + let different_query = format!( + "SELECT * FROM {} WHERE source_id = '{}'", + TextChunk::table_name(), + different_source_id + ); + let different_remaining: Vec = db + .client + .query(different_query) + .await + .expect("Query failed") + .take(0) + .expect("Failed to get query results"); + assert_eq!( + different_remaining.len(), + 1, + "Chunk with different source_id should still exist" + ); + assert_eq!(different_remaining[0].id, different_chunk.id); + } + + #[tokio::test] + async fn test_delete_by_nonexistent_source_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create a chunk with a real source_id + let real_source_id = "real_source".to_string(); + let chunk = "Test chunk".to_string(); + let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; + let user_id = "user123".to_string(); + + let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id); + + // Store the chunk + db.store_item(text_chunk) + .await + .expect("Failed to store text chunk"); + + // Delete using nonexistent source_id + let nonexistent_source_id = "nonexistent_source"; + TextChunk::delete_by_source_id(nonexistent_source_id, &db) + .await + .expect("Delete operation with nonexistent source_id should not fail"); + + // Verify the real chunk still exists + let query = format!( + "SELECT * FROM {} WHERE source_id = '{}'", + TextChunk::table_name(), + real_source_id + ); + let remaining: Vec = db + .client + .query(query) + .await + .expect("Query failed") + .take(0) + .expect("Failed to get query results"); + assert_eq!( + remaining.len(), + 1, + "Chunk with real source_id should still exist" + ); + } +} diff --git a/common/src/storage/types/text_content.rs b/common/src/storage/types/text_content.rs index cd0eaae..790cdae 100644 --- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -57,3 +57,120 @@ impl TextContent { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_text_content_creation() { + // Test basic object creation + let text = "Test content text".to_string(); + let instructions = "Test instructions".to_string(); + let category = "Test category".to_string(); + let user_id = "user123".to_string(); + + let text_content = TextContent::new( + text.clone(), + instructions.clone(), + category.clone(), + None, + None, + user_id.clone(), + ); + + // Check that the fields are set correctly + assert_eq!(text_content.text, text); + assert_eq!(text_content.instructions, instructions); + assert_eq!(text_content.category, category); + assert_eq!(text_content.user_id, user_id); + assert!(text_content.file_info.is_none()); + assert!(text_content.url.is_none()); + assert!(!text_content.id.is_empty()); + } + + #[tokio::test] + async fn test_text_content_with_url() { + // Test creating with URL + let text = "Content with URL".to_string(); + let instructions = "URL instructions".to_string(); + let category = "URL category".to_string(); + let user_id = "user123".to_string(); + let url = Some("https://example.com/document.pdf".to_string()); + + let text_content = TextContent::new( + text.clone(), + instructions.clone(), + category.clone(), + None, + url.clone(), + user_id.clone(), + ); + + // Check URL field is set + assert_eq!(text_content.url, url); + } + + #[tokio::test] + async fn test_text_content_patch() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create initial text content + let initial_text = "Initial text".to_string(); + let initial_instructions = "Initial instructions".to_string(); + let initial_category = "Initial category".to_string(); + let user_id = "user123".to_string(); + + let text_content = TextContent::new( + initial_text, + initial_instructions, + initial_category, + None, + None, + user_id, + ); + + // Store the text content + let stored: Option = db + .store_item(text_content.clone()) + .await + .expect("Failed to store text content"); + assert!(stored.is_some()); + + // New values for patch + let new_instructions = "Updated instructions"; + let new_category = "Updated category"; + let new_text = "Updated text content"; + + // Apply the patch + TextContent::patch( + &text_content.id, + new_instructions, + new_category, + new_text, + &db, + ) + .await + .expect("Failed to patch text content"); + + // Retrieve the updated content + let updated: Option = db + .get_item(&text_content.id) + .await + .expect("Failed to get updated text content"); + assert!(updated.is_some()); + + let updated_content = updated.unwrap(); + + // Verify the updates + assert_eq!(updated_content.instructions, new_instructions); + assert_eq!(updated_content.category, new_category); + assert_eq!(updated_content.text, new_text); + assert!(updated_content.updated_at > text_content.updated_at); + } +} diff --git a/common/src/storage/types/user.rs b/common/src/storage/types/user.rs index 39b91bb..a35399c 100644 --- a/common/src/storage/types/user.rs +++ b/common/src/storage/types/user.rs @@ -414,3 +414,276 @@ impl User { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Helper function to set up a test database with SystemSettings + async fn setup_test_db() -> SurrealDbClient { + let namespace = "test_ns"; + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, &database) + .await + .expect("Failed to start in-memory surrealdb"); + + db.ensure_initialized() + .await + .expect("Failed to setup the systemsettings"); + + db + } + + #[tokio::test] + async fn test_user_creation() { + // Setup test database + let db = setup_test_db().await; + + // Create a user + let email = "test@example.com"; + let password = "test_password"; + let timezone = "America/New_York"; + + let user = User::create_new( + email.to_string(), + password.to_string(), + &db, + timezone.to_string(), + ) + .await + .expect("Failed to create user"); + + // Verify user properties + assert!(!user.id.is_empty()); + assert_eq!(user.email, email); + assert_ne!(user.password, password); // Password should be hashed + assert!(!user.anonymous); + assert_eq!(user.timezone, timezone); + + // Verify it can be retrieved + let retrieved: Option = db + .get_item(&user.id) + .await + .expect("Failed to retrieve user"); + assert!(retrieved.is_some()); + + let retrieved = retrieved.unwrap(); + assert_eq!(retrieved.id, user.id); + assert_eq!(retrieved.email, email); + } + + #[tokio::test] + async fn test_user_authentication() { + // Setup test database + let db = setup_test_db().await; + + // Create a user + let email = "auth_test@example.com"; + let password = "auth_password"; + + User::create_new( + email.to_string(), + password.to_string(), + &db, + "UTC".to_string(), + ) + .await + .expect("Failed to create user"); + + // Test successful authentication + let auth_result = User::authenticate(email, password, &db).await; + assert!(auth_result.is_ok()); + + // Test failed authentication with wrong password + let wrong_auth = User::authenticate(email, "wrong_password", &db).await; + assert!(wrong_auth.is_err()); + + // Test failed authentication with non-existent user + let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await; + assert!(nonexistent.is_err()); + } + + #[tokio::test] + async fn test_find_by_email() { + // Setup test database + let db = setup_test_db().await; + + // Create a user + let email = "find_test@example.com"; + let password = "find_password"; + + let created_user = User::create_new( + email.to_string(), + password.to_string(), + &db, + "UTC".to_string(), + ) + .await + .expect("Failed to create user"); + + // Test finding user by email + let found_user = User::find_by_email(email, &db) + .await + .expect("Error searching for user"); + assert!(found_user.is_some()); + let found_user = found_user.unwrap(); + assert_eq!(found_user.id, created_user.id); + assert_eq!(found_user.email, email); + + // Test finding non-existent user + let not_found = User::find_by_email("nonexistent@example.com", &db) + .await + .expect("Error searching for user"); + assert!(not_found.is_none()); + } + + #[tokio::test] + async fn test_api_key_management() { + // Setup test database + let db = setup_test_db().await; + + // Create a user + let email = "apikey_test@example.com"; + let password = "apikey_password"; + + let user = User::create_new( + email.to_string(), + password.to_string(), + &db, + "UTC".to_string(), + ) + .await + .expect("Failed to create user"); + + // Initially, user should have no API key + assert!(user.api_key.is_none()); + + // Generate API key + let api_key = User::set_api_key(&user.id, &db) + .await + .expect("Failed to set API key"); + assert!(!api_key.is_empty()); + assert!(api_key.starts_with("sk_")); + + // Verify the API key was saved + let updated_user: Option = db + .get_item(&user.id) + .await + .expect("Failed to retrieve user"); + assert!(updated_user.is_some()); + let updated_user = updated_user.unwrap(); + assert_eq!(updated_user.api_key, Some(api_key.clone())); + + // Test finding user by API key + let found_user = User::find_by_api_key(&api_key, &db) + .await + .expect("Error searching by API key"); + assert!(found_user.is_some()); + let found_user = found_user.unwrap(); + assert_eq!(found_user.id, user.id); + + // Revoke API key + User::revoke_api_key(&user.id, &db) + .await + .expect("Failed to revoke API key"); + + // Verify API key was revoked + let revoked_user: Option = db + .get_item(&user.id) + .await + .expect("Failed to retrieve user"); + assert!(revoked_user.is_some()); + let revoked_user = revoked_user.unwrap(); + assert!(revoked_user.api_key.is_none()); + + // Test searching by revoked API key + let not_found = User::find_by_api_key(&api_key, &db) + .await + .expect("Error searching by API key"); + assert!(not_found.is_none()); + } + + #[tokio::test] + async fn test_password_update() { + // Setup test database + let db = setup_test_db().await; + + // Create a user + let email = "pwd_test@example.com"; + let old_password = "old_password"; + let new_password = "new_password"; + + User::create_new( + email.to_string(), + old_password.to_string(), + &db, + "UTC".to_string(), + ) + .await + .expect("Failed to create user"); + + // Authenticate with old password + let auth_result = User::authenticate(email, old_password, &db).await; + assert!(auth_result.is_ok()); + + // Update password + User::patch_password(email, new_password, &db) + .await + .expect("Failed to update password"); + + // Old password should no longer work + let old_auth = User::authenticate(email, old_password, &db).await; + assert!(old_auth.is_err()); + + // New password should work + let new_auth = User::authenticate(email, new_password, &db).await; + assert!(new_auth.is_ok()); + } + + #[tokio::test] + async fn test_validate_timezone() { + // Valid timezones should be accepted as-is + assert_eq!(validate_timezone("America/New_York"), "America/New_York"); + assert_eq!(validate_timezone("Europe/London"), "Europe/London"); + assert_eq!(validate_timezone("Asia/Tokyo"), "Asia/Tokyo"); + assert_eq!(validate_timezone("UTC"), "UTC"); + + // Invalid timezones should be replaced with UTC + assert_eq!(validate_timezone("Invalid/Timezone"), "UTC"); + assert_eq!(validate_timezone("Not_Real"), "UTC"); + } + + #[tokio::test] + async fn test_timezone_update() { + // Setup test database + let db = setup_test_db().await; + + // Create user with default timezone + let email = "timezone_test@example.com"; + let user = User::create_new( + email.to_string(), + "password".to_string(), + &db, + "UTC".to_string(), + ) + .await + .expect("Failed to create user"); + + assert_eq!(user.timezone, "UTC"); + + // Update timezone + let new_timezone = "Europe/Paris"; + User::update_timezone(&user.id, new_timezone, &db) + .await + .expect("Failed to update timezone"); + + // Verify timezone was updated + let updated_user: Option = db + .get_item(&user.id) + .await + .expect("Failed to retrieve user"); + assert!(updated_user.is_some()); + let updated_user = updated_user.unwrap(); + assert_eq!(updated_user.timezone, new_timezone); + } +} diff --git a/composite-retrieval/Cargo.toml b/composite-retrieval/Cargo.toml index 29ac9fb..cc275ff 100644 --- a/composite-retrieval/Cargo.toml +++ b/composite-retrieval/Cargo.toml @@ -15,5 +15,7 @@ serde_json = { workspace = true } surrealdb = { workspace = true } futures = { workspace = true } async-openai = { workspace = true } + +uuid = { version = "1.10.0", features = ["v4", "serde"] } -common = { path = "../common" } +common = { path = "../common", features = ["test-utils"] } diff --git a/composite-retrieval/src/graph.rs b/composite-retrieval/src/graph.rs index 81a9712..89c9a63 100644 --- a/composite-retrieval/src/graph.rs +++ b/composite-retrieval/src/graph.rs @@ -61,3 +61,280 @@ pub async fn find_entities_by_relationship_by_id( db.query(query).await?.take(0) } + +#[cfg(test)] +mod tests { + use super::*; + use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; + use common::storage::types::knowledge_relationship::KnowledgeRelationship; + use common::storage::types::StoredObject; + use uuid::Uuid; + + #[tokio::test] + async fn test_find_entities_by_source_ids() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create some test entities with different source_ids + let source_id1 = "source123".to_string(); + let source_id2 = "source456".to_string(); + let source_id3 = "source789".to_string(); + + let entity_type = KnowledgeEntityType::Document; + let embedding = vec![0.1, 0.2, 0.3]; + let user_id = "user123".to_string(); + + // Entity with source_id1 + let entity1 = KnowledgeEntity::new( + source_id1.clone(), + "Entity 1".to_string(), + "Description 1".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Entity with source_id2 + let entity2 = KnowledgeEntity::new( + source_id2.clone(), + "Entity 2".to_string(), + "Description 2".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Another entity with source_id1 + let entity3 = KnowledgeEntity::new( + source_id1.clone(), + "Entity 3".to_string(), + "Description 3".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Entity with source_id3 + let entity4 = KnowledgeEntity::new( + source_id3.clone(), + "Entity 4".to_string(), + "Description 4".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Store all entities + db.store_item(entity1.clone()) + .await + .expect("Failed to store entity 1"); + db.store_item(entity2.clone()) + .await + .expect("Failed to store entity 2"); + db.store_item(entity3.clone()) + .await + .expect("Failed to store entity 3"); + db.store_item(entity4.clone()) + .await + .expect("Failed to store entity 4"); + + // Test finding entities by multiple source_ids + let source_ids = vec![source_id1.clone(), source_id2.clone()]; + let found_entities: Vec = + find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db) + .await + .expect("Failed to find entities by source_ids"); + + // Should find 3 entities (2 with source_id1, 1 with source_id2) + assert_eq!( + found_entities.len(), + 3, + "Should find 3 entities with the specified source_ids" + ); + + // Check that entities with source_id1 and source_id2 are found + let found_source_ids: Vec = + found_entities.iter().map(|e| e.source_id.clone()).collect(); + assert!( + found_source_ids.contains(&source_id1), + "Should find entities with source_id1" + ); + assert!( + found_source_ids.contains(&source_id2), + "Should find entities with source_id2" + ); + assert!( + !found_source_ids.contains(&source_id3), + "Should not find entities with source_id3" + ); + + // Test finding entities by a single source_id + let single_source_id = vec![source_id1.clone()]; + let found_entities: Vec = find_entities_by_source_ids( + single_source_id, + KnowledgeEntity::table_name().to_string(), + &db, + ) + .await + .expect("Failed to find entities by single source_id"); + + // Should find 2 entities with source_id1 + assert_eq!( + found_entities.len(), + 2, + "Should find 2 entities with source_id1" + ); + + // Check that all found entities have source_id1 + for entity in found_entities { + assert_eq!( + entity.source_id, source_id1, + "All found entities should have source_id1" + ); + } + + // Test finding entities with non-existent source_id + let non_existent_source_id = vec!["non_existent_source".to_string()]; + let found_entities: Vec = find_entities_by_source_ids( + non_existent_source_id, + KnowledgeEntity::table_name().to_string(), + &db, + ) + .await + .expect("Failed to find entities by non-existent source_id"); + + // Should find 0 entities + assert_eq!( + found_entities.len(), + 0, + "Should find 0 entities with non-existent source_id" + ); + } + + #[tokio::test] + async fn test_find_entities_by_relationship_by_id() { + // Setup in-memory database for testing + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + // Create some test entities + let entity_type = KnowledgeEntityType::Document; + let embedding = vec![0.1, 0.2, 0.3]; + let user_id = "user123".to_string(); + + // Create the central entity we'll query relationships for + let central_entity = KnowledgeEntity::new( + "central_source".to_string(), + "Central Entity".to_string(), + "Central Description".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Create related entities + let related_entity1 = KnowledgeEntity::new( + "related_source1".to_string(), + "Related Entity 1".to_string(), + "Related Description 1".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + let related_entity2 = KnowledgeEntity::new( + "related_source2".to_string(), + "Related Entity 2".to_string(), + "Related Description 2".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Create an unrelated entity + let unrelated_entity = KnowledgeEntity::new( + "unrelated_source".to_string(), + "Unrelated Entity".to_string(), + "Unrelated Description".to_string(), + entity_type.clone(), + None, + embedding.clone(), + user_id.clone(), + ); + + // Store all entities + let central_entity = db + .store_item(central_entity.clone()) + .await + .expect("Failed to store central entity") + .unwrap(); + let related_entity1 = db + .store_item(related_entity1.clone()) + .await + .expect("Failed to store related entity 1") + .unwrap(); + let related_entity2 = db + .store_item(related_entity2.clone()) + .await + .expect("Failed to store related entity 2") + .unwrap(); + let unrelated_entity = db + .store_item(unrelated_entity.clone()) + .await + .expect("Failed to store unrelated entity") + .unwrap(); + + // Create relationships + let source_id = "relationship_source".to_string(); + + // Create relationship 1: central -> related1 + let relationship1 = KnowledgeRelationship::new( + central_entity.id.clone(), + related_entity1.id.clone(), + user_id.clone(), + source_id.clone(), + "references".to_string(), + ); + + // Create relationship 2: central -> related2 + let relationship2 = KnowledgeRelationship::new( + central_entity.id.clone(), + related_entity2.id.clone(), + user_id.clone(), + source_id.clone(), + "contains".to_string(), + ); + + // Store relationships + relationship1 + .store_relationship(&db) + .await + .expect("Failed to store relationship 1"); + relationship2 + .store_relationship(&db) + .await + .expect("Failed to store relationship 2"); + + // Test finding entities related to the central entity + let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone()) + .await + .expect("Failed to find entities by relationship"); + + // Check that we found relationships + assert!(related_entities.len() > 0, "Should find related entities"); + } +}