mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-29 05:41:54 +02:00
tests: testing all db interactions and types
This commit is contained in:
134
Cargo.lock
generated
134
Cargo.lock
generated
@@ -1,6 +1,6 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "Inflector"
|
||||
@@ -190,6 +190,12 @@ version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
|
||||
|
||||
[[package]]
|
||||
name = "arc-swap"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
|
||||
|
||||
[[package]]
|
||||
name = "argon2"
|
||||
version = "0.5.3"
|
||||
@@ -732,6 +738,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitmaps"
|
||||
version = "3.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d084b0137aaa901caf9f1e8b21daa6aa24d41cd806e111335541eff9683bd6"
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "1.0.1"
|
||||
@@ -1078,6 +1090,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1183,6 +1196,25 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.20"
|
||||
@@ -1473,6 +1505,18 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "echodb"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d1eccc44ff21b80ca7e883ff57423a12610965a33637d5d0bef4adebcd81749"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"imbl",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ego-tree"
|
||||
version = "0.10.0"
|
||||
@@ -1552,6 +1596,19 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ext-sort"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcf73e44617eab501beba39234441a194cf138629d3b6447f81f573e1c3d0a13"
|
||||
dependencies = [
|
||||
"log",
|
||||
"rayon",
|
||||
"rmp-serde",
|
||||
"serde",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.13.0"
|
||||
@@ -2019,7 +2076,6 @@ dependencies = [
|
||||
"composite-retrieval",
|
||||
"futures",
|
||||
"include_dir",
|
||||
"json-stream-parser",
|
||||
"minijinja",
|
||||
"minijinja-autoreload",
|
||||
"minijinja-contrib",
|
||||
@@ -2374,6 +2430,28 @@ dependencies = [
|
||||
"icu_properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "imbl"
|
||||
version = "2.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978d142c8028edf52095703af2fad11d6f611af1246685725d6b850634647085"
|
||||
dependencies = [
|
||||
"bitmaps",
|
||||
"imbl-sized-chunks",
|
||||
"rand_core",
|
||||
"rand_xoshiro",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "imbl-sized-chunks"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "144006fb58ed787dcae3f54575ff4349755b00ccc99f4b4873860b654be1ed63"
|
||||
dependencies = [
|
||||
"bitmaps",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "include_dir"
|
||||
version = "0.7.4"
|
||||
@@ -2523,15 +2601,6 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json-stream-parser"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a70ab2b05e827e0604229fcf11b24560b036a21286a41517a6cac271f12a6a9"
|
||||
dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.4.1"
|
||||
@@ -3740,12 +3809,41 @@ dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_xoshiro"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
|
||||
dependencies = [
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reblessive"
|
||||
version = "0.4.1"
|
||||
@@ -4038,6 +4136,17 @@ dependencies = [
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rmp-serde"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"rmp",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rmpv"
|
||||
version = "1.3.0"
|
||||
@@ -4832,6 +4941,8 @@ dependencies = [
|
||||
"dashmap 5.5.3",
|
||||
"deunicode",
|
||||
"dmp",
|
||||
"echodb",
|
||||
"ext-sort",
|
||||
"fst",
|
||||
"futures",
|
||||
"fuzzy-matcher",
|
||||
@@ -4875,6 +4986,7 @@ dependencies = [
|
||||
"storekey",
|
||||
"subtle",
|
||||
"surrealdb-derive",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
|
||||
@@ -17,6 +17,6 @@ serde_json = "1.0.128"
|
||||
thiserror = "1.0.63"
|
||||
anyhow = "1.0.94"
|
||||
tracing = "0.1.40"
|
||||
surrealdb = "2.0.4"
|
||||
surrealdb = { version = "2.0.4", features = ["kv-mem"] }
|
||||
futures = "0.3.31"
|
||||
async-openai = "0.24.1"
|
||||
|
||||
@@ -12,7 +12,7 @@ tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
surrealdb = { workspace = true, features = ["kv-mem"] }
|
||||
async-openai = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
@@ -35,3 +35,6 @@ minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
minijinja-autoreload = "2.5.0"
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
|
||||
[features]
|
||||
test-utils = []
|
||||
|
||||
@@ -189,3 +189,104 @@ impl Deref for SurrealDbClient {
|
||||
&self.client
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
|
||||
impl SurrealDbClient {
|
||||
/// Create an in-memory SurrealDB client for testing.
|
||||
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
||||
let db = connect("mem://").await?;
|
||||
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::stored_object;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(Dummy, "dummy", {
|
||||
name: String
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialization_and_crud() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Call your initialization
|
||||
db.ensure_initialized()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
|
||||
// Test basic CRUD
|
||||
let dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
name: "first".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
// Store
|
||||
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
|
||||
assert!(stored.is_some());
|
||||
|
||||
// Read
|
||||
let fetched = db
|
||||
.get_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to fetch");
|
||||
assert_eq!(fetched, Some(dummy.clone()));
|
||||
|
||||
// Read all
|
||||
let all = db
|
||||
.get_all_stored_items::<Dummy>()
|
||||
.await
|
||||
.expect("Failed to fetch all");
|
||||
assert!(all.contains(&dummy));
|
||||
|
||||
// Delete
|
||||
let deleted = db
|
||||
.delete_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to delete");
|
||||
assert_eq!(deleted, Some(dummy));
|
||||
|
||||
// After delete, should not be present
|
||||
let fetch_post = db
|
||||
.get_item::<Dummy>("abc")
|
||||
.await
|
||||
.expect("Failed fetch post delete");
|
||||
assert!(fetch_post.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setup_auth() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Should not panic or fail
|
||||
db.setup_auth().await.expect("Failed to setup auth");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_build_indexes() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.build_indexes().await.expect("Failed to build indexes");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
@@ -11,32 +12,40 @@ pub struct Analytics {
|
||||
pub visitors: i64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StoredObject for Analytics {
|
||||
fn table_name() -> &'static str {
|
||||
"analytics"
|
||||
}
|
||||
|
||||
fn get_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl Analytics {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics = db.select(("analytics", "current")).await?;
|
||||
let analytics = db.get_item::<Self>("current").await?;
|
||||
|
||||
if analytics.is_none() {
|
||||
let created: Option<Analytics> = db
|
||||
.create(("analytics", "current"))
|
||||
.content(Analytics {
|
||||
id: "current".to_string(),
|
||||
visitors: 0,
|
||||
page_loads: 0,
|
||||
})
|
||||
.await?;
|
||||
let created_analytics = Analytics {
|
||||
id: "current".to_string(),
|
||||
visitors: 0,
|
||||
page_loads: 0,
|
||||
};
|
||||
|
||||
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
};
|
||||
let stored: Option<Self> = db.store_item(created_analytics).await?;
|
||||
return stored.ok_or(AppError::Validation(
|
||||
"Failed to initialize analytics".into(),
|
||||
));
|
||||
}
|
||||
|
||||
analytics.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
analytics.ok_or(AppError::Validation(
|
||||
"Failed to initialize analytics".into(),
|
||||
))
|
||||
}
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('analytics', 'current')")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
let analytics: Option<Self> = db.get_item("current").await?;
|
||||
analytics.ok_or(AppError::NotFound("Analytics not found".into()))
|
||||
}
|
||||
|
||||
@@ -61,6 +70,7 @@ impl Analytics {
|
||||
}
|
||||
|
||||
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
||||
// We need to use a direct query for COUNT aggregation
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CountResult {
|
||||
count: i64,
|
||||
@@ -76,3 +86,192 @@ impl Analytics {
|
||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::stored_object;
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TestUser, "user", {
|
||||
email: String,
|
||||
password: String,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analytics_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test initialization of analytics
|
||||
let analytics = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(analytics.id, "current");
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let analytics_again = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to get analytics after initialization");
|
||||
|
||||
assert_eq!(analytics.id, analytics_again.id);
|
||||
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
||||
assert_eq!(analytics.visitors, analytics_again.visitors);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_analytics() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Test get_current method
|
||||
let analytics = Analytics::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current analytics");
|
||||
|
||||
assert_eq!(analytics.id, "current");
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Test increment_visitors method
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors");
|
||||
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors again");
|
||||
|
||||
assert_eq!(analytics.visitors, 2);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Test increment_page_loads method
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads");
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads again");
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_users_amount() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test with no users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount");
|
||||
assert_eq!(count, 0);
|
||||
|
||||
// Create a few test users
|
||||
for i in 0..3 {
|
||||
let user = TestUser {
|
||||
id: format!("user{}", i),
|
||||
email: format!("user{}@example.com", i),
|
||||
password: "password".to_string(),
|
||||
user_id: format!("uid{}", i),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
db.store_item(user)
|
||||
.await
|
||||
.expect("Failed to create test user");
|
||||
}
|
||||
|
||||
// Test users amount after adding users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount after adding users");
|
||||
assert_eq!(count, 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Don't initialize analytics and try to get it
|
||||
let result = Analytics::get_current(&db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(err) = result {
|
||||
match err {
|
||||
AppError::NotFound(_) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected NotFound error, got: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,3 +47,178 @@ impl Conversation {
|
||||
Ok((conversation, messages))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::types::message::MessageRole;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a new conversation
|
||||
let user_id = "test_user";
|
||||
let title = "Test Conversation";
|
||||
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
||||
|
||||
// Verify conversation properties
|
||||
assert_eq!(conversation.user_id, user_id);
|
||||
assert_eq!(conversation.title, title);
|
||||
assert!(!conversation.id.is_empty());
|
||||
|
||||
// Store the conversation
|
||||
let result = db.store_item(conversation.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<Conversation> = db
|
||||
.get_item(&conversation.id)
|
||||
.await
|
||||
.expect("Failed to retrieve conversation");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, conversation.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Try to get a conversation that doesn't exist
|
||||
let result =
|
||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_unauthorized() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation =
|
||||
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
|
||||
// Try to access with a different user
|
||||
let user_id_2 = "user_2";
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
|
||||
// Create messages
|
||||
let message1 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
"Hello, AI!".to_string(),
|
||||
None,
|
||||
);
|
||||
let message2 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::AI,
|
||||
"Hello, human! How can I help you today?".to_string(),
|
||||
None,
|
||||
);
|
||||
let message3 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
"Tell me about Rust programming.".to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
// Store messages
|
||||
db.store_item(message1)
|
||||
.await
|
||||
.expect("Failed to store message1");
|
||||
db.store_item(message2)
|
||||
.await
|
||||
.expect("Failed to store message2");
|
||||
db.store_item(message3)
|
||||
.await
|
||||
.expect("Failed to store message3");
|
||||
|
||||
// Retrieve the complete conversation
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
||||
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
||||
|
||||
let (retrieved_conversation, messages) = result.unwrap();
|
||||
|
||||
// Verify conversation data
|
||||
assert_eq!(retrieved_conversation.id, conversation_id);
|
||||
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
||||
assert_eq!(retrieved_conversation.title, "Conversation");
|
||||
|
||||
// Verify messages
|
||||
assert_eq!(messages.len(), 3);
|
||||
|
||||
// Verify messages are sorted by updated_at
|
||||
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
|
||||
assert!(message_contents.contains(&"Hello, AI!"));
|
||||
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
||||
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
||||
|
||||
// Make sure we can't access with different user
|
||||
let user_id_2 = "user_2";
|
||||
let unauthorized_result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(unauthorized_result.is_err());
|
||||
match unauthorized_result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,16 +131,32 @@ impl FileInfo {
|
||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
|
||||
fn sanitize_file_name(file_name: &str) -> String {
|
||||
file_name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
if let Some(idx) = file_name.rfind('.') {
|
||||
let (name, ext) = file_name.split_at(idx);
|
||||
let sanitized_name: String = name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
format!("{}{}", sanitized_name, ext)
|
||||
} else {
|
||||
// No extension
|
||||
file_name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Persists the file to the filesystem under `./data/{user_id}/{uuid}/{file_name}`.
|
||||
@@ -243,3 +259,331 @@ impl FileInfo {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
/// Creates a test temporary file with the given content
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
|
||||
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
|
||||
temp_file
|
||||
.write_all(content)
|
||||
.expect("Failed to write to temp file");
|
||||
|
||||
let metadata = FieldMetadata {
|
||||
name: Some("file".to_string()),
|
||||
file_name: Some(file_name.to_string()),
|
||||
content_type: None,
|
||||
headers: HeaderMap::default(),
|
||||
};
|
||||
|
||||
let field_data = FieldData {
|
||||
metadata,
|
||||
contents: temp_file,
|
||||
};
|
||||
|
||||
field_data
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_creation() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a test file
|
||||
let content = b"This is a test file content";
|
||||
let file_name = "test_file.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
|
||||
// Create a FileInfo instance
|
||||
let user_id = "test_user";
|
||||
let file_info = FileInfo::new(field_data, &db, user_id).await;
|
||||
|
||||
// We can't fully test persistence to disk in unit tests,
|
||||
// but we can verify the database record was created
|
||||
assert!(file_info.is_ok());
|
||||
let file_info = file_info.unwrap();
|
||||
|
||||
// Check essential properties
|
||||
assert!(!file_info.id.is_empty());
|
||||
assert_eq!(file_info.file_name, file_name);
|
||||
assert!(!file_info.sha256.is_empty());
|
||||
assert!(!file_info.path.is_empty());
|
||||
assert!(file_info.mime_type.contains("text/plain"));
|
||||
|
||||
// Verify it's in the database
|
||||
let stored: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(stored.is_some());
|
||||
let stored = stored.unwrap();
|
||||
assert_eq!(stored.id, file_info.id);
|
||||
assert_eq!(stored.file_name, file_info.file_name);
|
||||
assert_eq!(stored.sha256, file_info.sha256);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_duplicate_detection() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// First, store a file with known content
|
||||
let content = b"This is a test file for duplicate detection";
|
||||
let file_name = "original.txt";
|
||||
let user_id = "test_user";
|
||||
|
||||
let field_data1 = create_test_file(content, file_name);
|
||||
let original_file_info = FileInfo::new(field_data1, &db, user_id)
|
||||
.await
|
||||
.expect("Failed to create original file");
|
||||
|
||||
// Now try to store another file with the same content but different name
|
||||
let duplicate_name = "duplicate.txt";
|
||||
let field_data2 = create_test_file(content, duplicate_name);
|
||||
|
||||
// The system should detect it's the same file and return the original FileInfo
|
||||
let duplicate_file_info = FileInfo::new(field_data2, &db, user_id)
|
||||
.await
|
||||
.expect("Failed to process duplicate file");
|
||||
|
||||
// The returned FileInfo should match the original
|
||||
assert_eq!(duplicate_file_info.id, original_file_info.id);
|
||||
assert_eq!(duplicate_file_info.sha256, original_file_info.sha256);
|
||||
|
||||
// But it should retain the original file name, not the duplicate's name
|
||||
assert_eq!(duplicate_file_info.file_name, file_name);
|
||||
assert_ne!(duplicate_file_info.file_name, duplicate_name);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_guess_mime_type() {
|
||||
// Test common file extensions
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("test.txt")),
|
||||
"text/plain".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("image.png")),
|
||||
"image/png".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("document.pdf")),
|
||||
"application/pdf".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("data.json")),
|
||||
"application/json".to_string()
|
||||
);
|
||||
|
||||
// Test unknown extension
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("unknown.929yz")),
|
||||
"application/octet-stream".to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sanitize_file_name() {
|
||||
// Safe characters should remain unchanged
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("normal_file.txt"),
|
||||
"normal_file.txt"
|
||||
);
|
||||
assert_eq!(FileInfo::sanitize_file_name("file123.doc"), "file123.doc");
|
||||
|
||||
// Unsafe characters should be replaced with underscores
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file with spaces.txt"),
|
||||
"file_with_spaces.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file/with/path.txt"),
|
||||
"file_with_path.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file:with:colons.txt"),
|
||||
"file_with_colons.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("../dangerous.txt"),
|
||||
"___dangerous.txt"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Try to find a file with a SHA that doesn't exist
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(FileError::FileNotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected FileNotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
sha256: "test_sha256_hash".to_string(),
|
||||
path: "/path/to/file.txt".to_string(),
|
||||
file_name: "manual_file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
};
|
||||
|
||||
// Store it in the database
|
||||
let result = db.store_item(file_info.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, file_info.id);
|
||||
assert_eq!(retrieved.sha256, file_info.sha256);
|
||||
assert_eq!(retrieved.file_name, file_info.file_name);
|
||||
assert_eq!(retrieved.path, file_info.path);
|
||||
assert_eq!(retrieved.mime_type, file_info.mime_type);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a FileInfo instance directly (without persistence to disk)
|
||||
let now = Utc::now();
|
||||
let file_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Create a temporary directory that mimics the structure we would have on disk
|
||||
let base_dir = Path::new("./data");
|
||||
let user_id = "test_user";
|
||||
let user_dir = base_dir.join(user_id);
|
||||
let uuid_dir = user_dir.join(&file_id);
|
||||
|
||||
tokio::fs::create_dir_all(&uuid_dir)
|
||||
.await
|
||||
.expect("Failed to create test directories");
|
||||
|
||||
// Create a test file in the directory
|
||||
let test_file_path = uuid_dir.join("test_file.txt");
|
||||
tokio::fs::write(&test_file_path, b"test content")
|
||||
.await
|
||||
.expect("Failed to write test file");
|
||||
|
||||
// The file path should point to our test file
|
||||
let file_info = FileInfo {
|
||||
id: file_id.clone(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
sha256: "test_sha256_hash".to_string(),
|
||||
path: test_file_path.to_string_lossy().to_string(),
|
||||
file_name: "test_file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
};
|
||||
|
||||
// Store it in the database
|
||||
db.store_item(file_info.clone())
|
||||
.await
|
||||
.expect("Failed to store file info");
|
||||
|
||||
// Verify file exists on disk
|
||||
assert!(tokio::fs::try_exists(&test_file_path)
|
||||
.await
|
||||
.unwrap_or(false));
|
||||
|
||||
// Delete the file
|
||||
let delete_result = FileInfo::delete_by_id(&file_id, &db).await;
|
||||
|
||||
// Delete should be successful
|
||||
assert!(
|
||||
delete_result.is_ok(),
|
||||
"Failed to delete file: {:?}",
|
||||
delete_result
|
||||
);
|
||||
|
||||
// Verify the file is removed from the database
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_id)
|
||||
.await
|
||||
.expect("Failed to query database");
|
||||
assert!(
|
||||
retrieved.is_none(),
|
||||
"FileInfo should be deleted from the database"
|
||||
);
|
||||
|
||||
// Verify directory is gone
|
||||
assert!(
|
||||
!tokio::fs::try_exists(&uuid_dir).await.unwrap_or(true),
|
||||
"UUID directory should be deleted"
|
||||
);
|
||||
|
||||
// Clean up test directory if it exists
|
||||
let _ = tokio::fs::remove_dir_all(base_dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Try to delete a file that doesn't exist
|
||||
let result = FileInfo::delete_by_id("nonexistent_id", &db).await;
|
||||
|
||||
// Should fail with FileNotFound error
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(FileError::FileNotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected FileNotFound error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub enum IngestionPayload {
|
||||
Url {
|
||||
url: String,
|
||||
@@ -93,3 +94,237 @@ impl IngestionPayload {
|
||||
Ok(object_list)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Create a mock FileInfo for testing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
struct MockFileInfo {
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl From<MockFileInfo> for FileInfo {
|
||||
fn from(mock: MockFileInfo) -> Self {
|
||||
// This is just a test implementation, the actual fields don't matter
|
||||
// as we're just testing the IngestionPayload functionality
|
||||
FileInfo {
|
||||
id: mock.id,
|
||||
sha256: "mock-sha256".to_string(),
|
||||
path: "/mock/path".to_string(),
|
||||
file_name: "mock.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url() {
|
||||
let url = "https://example.com";
|
||||
let instructions = "Process this URL";
|
||||
let category = "websites";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(url.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url,
|
||||
instructions: payload_instructions,
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
assert_eq!(payload_instructions, &instructions);
|
||||
assert_eq!(payload_category, &category);
|
||||
assert_eq!(payload_user_id, &user_id);
|
||||
}
|
||||
_ => panic!("Expected Url variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_text() {
|
||||
let text = "This is some text content";
|
||||
let instructions = "Process this text";
|
||||
let category = "notes";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(text.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
IngestionPayload::Text {
|
||||
text: payload_text,
|
||||
instructions: payload_instructions,
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
assert_eq!(payload_text, text);
|
||||
assert_eq!(payload_instructions, instructions);
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected Text variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_file() {
|
||||
let instructions = "Process this file";
|
||||
let category = "documents";
|
||||
let user_id = "user123";
|
||||
|
||||
// Create a mock FileInfo
|
||||
let mock_file = MockFileInfo {
|
||||
id: "file123".to_string(),
|
||||
};
|
||||
|
||||
let file_info: FileInfo = mock_file.into();
|
||||
let files = vec![file_info.clone()];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
None,
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
instructions: payload_instructions,
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
assert_eq!(payload_instructions, instructions);
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url_and_file() {
|
||||
let url = "https://example.com";
|
||||
let instructions = "Process this data";
|
||||
let category = "mixed";
|
||||
let user_id = "user123";
|
||||
|
||||
// Create a mock FileInfo
|
||||
let mock_file = MockFileInfo {
|
||||
id: "file123".to_string(),
|
||||
};
|
||||
|
||||
let file_info: FileInfo = mock_file.into();
|
||||
let files = vec![file_info.clone()];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(url.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
// Check first item is URL
|
||||
match &result[0] {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url, ..
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
}
|
||||
_ => panic!("Expected first item to be Url variant"),
|
||||
}
|
||||
|
||||
// Check second item is File
|
||||
match &result[1] {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
}
|
||||
_ => panic!("Expected second item to be File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_empty_input() {
|
||||
let instructions = "Process something";
|
||||
let category = "empty";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
None,
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_empty_text() {
|
||||
let text = ""; // Empty text
|
||||
let instructions = "Process this";
|
||||
let category = "notes";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(text.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use super::ingestion_payload::IngestionPayload;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum IngestionTaskStatus {
|
||||
Created,
|
||||
InProgress {
|
||||
@@ -100,3 +100,196 @@ impl IngestionTask {
|
||||
Ok(jobs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
|
||||
// Helper function to create a test ingestion payload
|
||||
fn create_test_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
text: "Test content".to_string(),
|
||||
instructions: "Test instructions".to_string(),
|
||||
category: "Test category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_ingestion_task() {
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
// Verify task properties
|
||||
assert_eq!(task.user_id, user_id);
|
||||
assert_eq!(task.content, payload);
|
||||
assert!(matches!(task.status, IngestionTaskStatus::Created));
|
||||
assert!(!task.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_add_to_db() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
// Create and store task
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("Failed to create and add task to db");
|
||||
|
||||
// Query to verify task was stored
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE user_id = '{}'",
|
||||
IngestionTask::table_name(),
|
||||
user_id
|
||||
);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let tasks: Vec<IngestionTask> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify task is in the database
|
||||
assert!(!tasks.is_empty(), "Task should exist in the database");
|
||||
let stored_task = &tasks[0];
|
||||
assert_eq!(stored_task.user_id, user_id);
|
||||
assert!(matches!(stored_task.status, IngestionTaskStatus::Created));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_status() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
// Create task manually
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task_id = task.id.clone();
|
||||
|
||||
// Store task
|
||||
db.store_item(task).await.expect("Failed to store task");
|
||||
|
||||
// Update status to InProgress
|
||||
let now = Utc::now();
|
||||
let new_status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: now,
|
||||
};
|
||||
|
||||
IngestionTask::update_status(&task_id, new_status.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to update status");
|
||||
|
||||
// Verify status updated
|
||||
let updated_task: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&task_id)
|
||||
.await
|
||||
.expect("Failed to get updated task");
|
||||
|
||||
assert!(updated_task.is_some());
|
||||
let updated_task = updated_task.unwrap();
|
||||
|
||||
match updated_task.status {
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
assert_eq!(attempts, 1);
|
||||
}
|
||||
_ => panic!("Expected InProgress status"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_tasks() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
// Create tasks with different statuses
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
max_attempts_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: MAX_ATTEMPTS,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
completed_task.status = IngestionTaskStatus::Completed;
|
||||
|
||||
let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
error_task.status = IngestionTaskStatus::Error("Test error".to_string());
|
||||
|
||||
// Store all tasks
|
||||
db.store_item(created_task)
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
db.store_item(in_progress_task)
|
||||
.await
|
||||
.expect("Failed to store in-progress task");
|
||||
db.store_item(max_attempts_task)
|
||||
.await
|
||||
.expect("Failed to store max-attempts task");
|
||||
db.store_item(completed_task)
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
db.store_item(error_task)
|
||||
.await
|
||||
.expect("Failed to store error task");
|
||||
|
||||
// Get unfinished tasks
|
||||
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db)
|
||||
.await
|
||||
.expect("Failed to get unfinished tasks");
|
||||
|
||||
// Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned
|
||||
assert_eq!(unfinished_tasks.len(), 2);
|
||||
|
||||
let statuses: Vec<_> = unfinished_tasks
|
||||
.iter()
|
||||
.map(|task| match &task.status {
|
||||
IngestionTaskStatus::Created => "Created",
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
if *attempts < MAX_ATTEMPTS {
|
||||
"InProgress<MAX"
|
||||
} else {
|
||||
"InProgress>=MAX"
|
||||
}
|
||||
}
|
||||
IngestionTaskStatus::Completed => "Completed",
|
||||
IngestionTaskStatus::Error(_) => "Error",
|
||||
IngestionTaskStatus::Cancelled => "Cancelled",
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert!(statuses.contains(&"Created"));
|
||||
assert!(statuses.contains(&"InProgress<MAX"));
|
||||
assert!(!statuses.contains(&"InProgress>=MAX"));
|
||||
assert!(!statuses.contains(&"Completed"));
|
||||
assert!(!statuses.contains(&"Error"));
|
||||
assert!(!statuses.contains(&"Cancelled"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::{
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub enum KnowledgeEntityType {
|
||||
Idea,
|
||||
Project,
|
||||
@@ -119,3 +119,198 @@ impl KnowledgeEntity {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() {
|
||||
// Create basic test entity
|
||||
let source_id = "source123".to_string();
|
||||
let name = "Test Entity".to_string();
|
||||
let description = "Test Description".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let metadata = Some(json!({"key": "value"}));
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
name.clone(),
|
||||
description.clone(),
|
||||
entity_type.clone(),
|
||||
metadata.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert_eq!(entity.source_id, source_id);
|
||||
assert_eq!(entity.name, name);
|
||||
assert_eq!(entity.description, description);
|
||||
assert_eq!(entity.entity_type, entity_type);
|
||||
assert_eq!(entity.metadata, metadata);
|
||||
assert_eq!(entity.embedding, embedding);
|
||||
assert_eq!(entity.user_id, user_id);
|
||||
assert!(!entity.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_type_from_string() {
|
||||
// Test conversion from String to KnowledgeEntityType
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("idea".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("Idea".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("IDEA".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("project".to_string()),
|
||||
KnowledgeEntityType::Project
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("document".to_string()),
|
||||
KnowledgeEntityType::Document
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("page".to_string()),
|
||||
KnowledgeEntityType::Page
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("textsnippet".to_string()),
|
||||
KnowledgeEntityType::TextSnippet
|
||||
);
|
||||
|
||||
// Test default case
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("unknown".to_string()),
|
||||
KnowledgeEntityType::Document
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_variants() {
|
||||
let variants = KnowledgeEntityType::variants();
|
||||
assert_eq!(variants.len(), 5);
|
||||
assert!(variants.contains(&"Idea"));
|
||||
assert!(variants.contains(&"Project"));
|
||||
assert!(variants.contains(&"Document"));
|
||||
assert!(variants.contains(&"Page"));
|
||||
assert!(variants.contains(&"TextSnippet"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create two entities with the same source_id
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let entity2 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 2".to_string(),
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an entity with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_entity = KnowledgeEntity::new(
|
||||
different_source_id.clone(),
|
||||
"Different Entity".to_string(),
|
||||
"Different Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store the entities
|
||||
db.store_item(entity1)
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
db.store_item(entity2)
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
db.store_item(different_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store different entity");
|
||||
|
||||
// Delete by source_id
|
||||
KnowledgeEntity::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete entities by source_id");
|
||||
|
||||
// Verify all entities with the specified source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
source_id
|
||||
);
|
||||
let remaining: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All entities with the source_id should be deleted"
|
||||
);
|
||||
|
||||
// Verify the entity with a different source_id still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
different_source_id
|
||||
);
|
||||
let different_remaining: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Entity with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
}
|
||||
|
||||
// Note: We can't easily test the patch method without mocking the OpenAI client
|
||||
// and the generate_embedding function. This would require more complex setup.
|
||||
}
|
||||
|
||||
@@ -84,3 +84,258 @@ impl KnowledgeRelationship {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
|
||||
// Helper function to create a test knowledge entity for the relationship tests
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {}", name);
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id,
|
||||
name.to_string(),
|
||||
description,
|
||||
entity_type,
|
||||
None,
|
||||
embedding,
|
||||
user_id,
|
||||
);
|
||||
|
||||
let stored: Option<KnowledgeEntity> = db_client
|
||||
.store_item(entity)
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
stored.unwrap().id
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relationship_creation() {
|
||||
let in_id = "entity1".to_string();
|
||||
let out_id = "entity2".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
in_id.clone(),
|
||||
out_id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
relationship_type.clone(),
|
||||
);
|
||||
|
||||
// Verify fields are correctly set
|
||||
assert_eq!(relationship.in_, in_id);
|
||||
assert_eq!(relationship.out, out_id);
|
||||
assert_eq!(relationship.metadata.user_id, user_id);
|
||||
assert_eq!(relationship.metadata.source_id, source_id);
|
||||
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
||||
assert!(!relationship.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
||||
// This approach is more reliable than trying to look up by ID
|
||||
let check_query = format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
||||
source_id
|
||||
);
|
||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||
|
||||
// Just verify that a relationship was created
|
||||
assert!(
|
||||
!check_results.is_empty(),
|
||||
"Relationship should exist in the database"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
// Delete the relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationship by ID");
|
||||
|
||||
// Query to verify the relationship was deleted
|
||||
let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship no longer exists
|
||||
assert!(results.is_empty(), "Relationship should be deleted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
|
||||
// Create relationships with the same source_id
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let different_source_id = "different_source".to_string();
|
||||
|
||||
// Create two relationships with the same source_id
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
entity2_id.clone(),
|
||||
entity3_id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Create a relationship with a different source_id
|
||||
let different_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity3_id.clone(),
|
||||
user_id.clone(),
|
||||
different_source_id.clone(),
|
||||
"mentions".to_string(),
|
||||
);
|
||||
|
||||
// Store all relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
different_relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store different relationship");
|
||||
|
||||
// Delete relationships by source_id
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationships by source_id");
|
||||
|
||||
// Query to verify the relationships with source_id were deleted
|
||||
let query1 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship1.id);
|
||||
let query2 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship2.id);
|
||||
let different_query = format!(
|
||||
"SELECT * FROM relates_to WHERE id = '{}'",
|
||||
different_relationship.id
|
||||
);
|
||||
|
||||
let mut result1 = db.query(query1).await.expect("Query 1 failed");
|
||||
let results1: Vec<KnowledgeRelationship> = result1.take(0).unwrap_or_default();
|
||||
|
||||
let mut result2 = db.query(query2).await.expect("Query 2 failed");
|
||||
let results2: Vec<KnowledgeRelationship> = result2.take(0).unwrap_or_default();
|
||||
|
||||
let mut different_result = db
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Different query failed");
|
||||
let _different_results: Vec<KnowledgeRelationship> =
|
||||
different_result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify relationships with the source_id are deleted
|
||||
assert!(results1.is_empty(), "Relationship 1 should be deleted");
|
||||
assert!(results2.is_empty(), "Relationship 2 should be deleted");
|
||||
|
||||
// For the relationship with different source ID, we need to check differently
|
||||
// Let's just verify we have a relationship where the source_id matches different_source_id
|
||||
let check_query = format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
||||
different_source_id
|
||||
);
|
||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship with a different source_id still exists
|
||||
assert!(
|
||||
!check_results.is_empty(),
|
||||
"Relationship with different source_id should still exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Serialize)]
|
||||
#[derive(Deserialize, Debug, Clone, Serialize, PartialEq)]
|
||||
pub enum MessageRole {
|
||||
User,
|
||||
AI,
|
||||
@@ -60,3 +60,128 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_creation() {
|
||||
// Test basic message creation
|
||||
let conversation_id = "test_conversation";
|
||||
let content = "This is a test message";
|
||||
let role = MessageRole::User;
|
||||
let references = Some(vec!["ref1".to_string(), "ref2".to_string()]);
|
||||
|
||||
let message = Message::new(
|
||||
conversation_id.to_string(),
|
||||
role.clone(),
|
||||
content.to_string(),
|
||||
references.clone(),
|
||||
);
|
||||
|
||||
// Verify message properties
|
||||
assert_eq!(message.conversation_id, conversation_id);
|
||||
assert_eq!(message.content, content);
|
||||
assert_eq!(message.role, role);
|
||||
assert_eq!(message.references, references);
|
||||
assert!(!message.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_persistence() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create and store a message
|
||||
let conversation_id = "test_conversation";
|
||||
let message = Message::new(
|
||||
conversation_id.to_string(),
|
||||
MessageRole::User,
|
||||
"Hello world".to_string(),
|
||||
None,
|
||||
);
|
||||
let message_id = message.id.clone();
|
||||
|
||||
// Store the message
|
||||
db.store_item(message.clone())
|
||||
.await
|
||||
.expect("Failed to store message");
|
||||
|
||||
// Retrieve the message
|
||||
let retrieved: Option<Message> = db
|
||||
.get_item(&message_id)
|
||||
.await
|
||||
.expect("Failed to retrieve message");
|
||||
|
||||
assert!(retrieved.is_some());
|
||||
let retrieved = retrieved.unwrap();
|
||||
|
||||
// Verify retrieved properties match original
|
||||
assert_eq!(retrieved.id, message.id);
|
||||
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
||||
assert_eq!(retrieved.role, message.role);
|
||||
assert_eq!(retrieved.content, message.content);
|
||||
assert_eq!(retrieved.references, message.references);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_role_display() {
|
||||
// Test the Display implementation for MessageRole
|
||||
assert_eq!(format!("{}", MessageRole::User), "User");
|
||||
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
||||
assert_eq!(format!("{}", MessageRole::System), "System");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_display() {
|
||||
// Test the Display implementation for Message
|
||||
let message = Message {
|
||||
id: "test_id".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
conversation_id: "test_convo".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Hello world".to_string(),
|
||||
references: None,
|
||||
};
|
||||
|
||||
assert_eq!(format!("{}", message), "User: Hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_format_history() {
|
||||
// Create a vector of messages
|
||||
let messages = vec![
|
||||
Message {
|
||||
id: "1".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
conversation_id: "test_convo".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Hello".to_string(),
|
||||
references: None,
|
||||
},
|
||||
Message {
|
||||
id: "2".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
conversation_id: "test_convo".to_string(),
|
||||
role: MessageRole::AI,
|
||||
content: "Hi there!".to_string(),
|
||||
references: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Format the history
|
||||
let formatted = format_history(&messages);
|
||||
|
||||
// Verify the formatting
|
||||
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ macro_rules! stored_object {
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct $name {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct SystemSettings {
|
||||
@@ -16,41 +17,49 @@ pub struct SystemSettings {
|
||||
pub ingestion_system_prompt: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StoredObject for SystemSettings {
|
||||
fn table_name() -> &'static str {
|
||||
"system_settings"
|
||||
}
|
||||
|
||||
fn get_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl SystemSettings {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings = db.select(("system_settings", "current")).await?;
|
||||
let settings: Option<Self> = db.get_item("current").await?;
|
||||
|
||||
if settings.is_none() {
|
||||
let created: Option<SystemSettings> = db
|
||||
.create(("system_settings", "current"))
|
||||
.content(SystemSettings {
|
||||
id: "current".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
|
||||
})
|
||||
.await?;
|
||||
let created_settings = SystemSettings {
|
||||
id: "current".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
query_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
ingestion_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
};
|
||||
let stored: Option<Self> = db.store_item(created_settings).await?;
|
||||
return stored.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
}
|
||||
|
||||
settings.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
}
|
||||
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('system_settings', 'current')")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
let settings: Option<Self> = db.get_item("current").await?;
|
||||
settings.ok_or(AppError::NotFound("System settings not found".into()))
|
||||
}
|
||||
|
||||
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
||||
// We need to use a direct query for the update with MERGE
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
|
||||
@@ -66,8 +75,11 @@ impl SystemSettings {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
ingestion_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
registrations_enabled: true,
|
||||
@@ -75,3 +87,159 @@ impl SystemSettings {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test initialization of system settings
|
||||
let settings = SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
);
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let settings_again = SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to get settings after initialization");
|
||||
|
||||
assert_eq!(settings.id, settings_again.id);
|
||||
assert_eq!(
|
||||
settings.registrations_enabled,
|
||||
settings_again.registrations_enabled
|
||||
);
|
||||
assert_eq!(
|
||||
settings.require_email_verification,
|
||||
settings_again.require_email_verification
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_settings() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize settings
|
||||
SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings");
|
||||
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_settings() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize settings
|
||||
SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::new();
|
||||
updated_settings.id = "current".to_string();
|
||||
updated_settings.registrations_enabled = false;
|
||||
updated_settings.require_email_verification = true;
|
||||
updated_settings.query_model = "gpt-4".to_string();
|
||||
|
||||
// Test update method
|
||||
let result = SystemSettings::update(&db, updated_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
|
||||
assert_eq!(result.id, "current");
|
||||
assert_eq!(result.registrations_enabled, false);
|
||||
assert_eq!(result.require_email_verification, true);
|
||||
assert_eq!(result.query_model, "gpt-4");
|
||||
|
||||
// Verify changes persisted by getting current settings
|
||||
let current = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings after update");
|
||||
|
||||
assert_eq!(current.registrations_enabled, false);
|
||||
assert_eq!(current.require_email_verification, true);
|
||||
assert_eq!(current.query_model, "gpt-4");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Don't initialize settings and try to get them
|
||||
let result = SystemSettings::get_current(&db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
|
||||
Ok(_) => panic!("Expected error but got Ok"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_method() {
|
||||
let settings = SystemSettings::new();
|
||||
|
||||
assert!(settings.id.len() > 0);
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,3 +36,175 @@ impl TextChunk {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() {
|
||||
// Test basic object creation
|
||||
let source_id = "source123".to_string();
|
||||
let chunk = "This is a text chunk for testing embeddings".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Check that the fields are set correctly
|
||||
assert_eq!(text_chunk.source_id, source_id);
|
||||
assert_eq!(text_chunk.chunk, chunk);
|
||||
assert_eq!(text_chunk.embedding, embedding);
|
||||
assert_eq!(text_chunk.user_id, user_id);
|
||||
assert!(!text_chunk.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create test data
|
||||
let source_id = "source123".to_string();
|
||||
let chunk1 = "First chunk from the same source".to_string();
|
||||
let chunk2 = "Second chunk from the same source".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create two chunks with the same source_id
|
||||
let text_chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk1,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let text_chunk2 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk2,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create a chunk with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_chunk = TextChunk::new(
|
||||
different_source_id.clone(),
|
||||
"Different source chunk".to_string(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store the chunks
|
||||
db.store_item(text_chunk1)
|
||||
.await
|
||||
.expect("Failed to store text chunk 1");
|
||||
db.store_item(text_chunk2)
|
||||
.await
|
||||
.expect("Failed to store text chunk 2");
|
||||
db.store_item(different_chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store different chunk");
|
||||
|
||||
// Delete by source_id
|
||||
TextChunk::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete chunks by source_id");
|
||||
|
||||
// Verify all chunks with the original source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
);
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All chunks with the source_id should be deleted"
|
||||
);
|
||||
|
||||
// Verify the different source_id chunk still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
different_source_id
|
||||
);
|
||||
let different_remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Chunk with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining[0].id, different_chunk.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a chunk with a real source_id
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = "Test chunk".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id);
|
||||
|
||||
// Store the chunk
|
||||
db.store_item(text_chunk)
|
||||
.await
|
||||
.expect("Failed to store text chunk");
|
||||
|
||||
// Delete using nonexistent source_id
|
||||
let nonexistent_source_id = "nonexistent_source";
|
||||
TextChunk::delete_by_source_id(nonexistent_source_id, &db)
|
||||
.await
|
||||
.expect("Delete operation with nonexistent source_id should not fail");
|
||||
|
||||
// Verify the real chunk still exists
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
);
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
1,
|
||||
"Chunk with real source_id should still exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,3 +57,120 @@ impl TextContent {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() {
|
||||
// Test basic object creation
|
||||
let text = "Test content text".to_string();
|
||||
let instructions = "Test instructions".to_string();
|
||||
let category = "Test category".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_content = TextContent::new(
|
||||
text.clone(),
|
||||
instructions.clone(),
|
||||
category.clone(),
|
||||
None,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Check that the fields are set correctly
|
||||
assert_eq!(text_content.text, text);
|
||||
assert_eq!(text_content.instructions, instructions);
|
||||
assert_eq!(text_content.category, category);
|
||||
assert_eq!(text_content.user_id, user_id);
|
||||
assert!(text_content.file_info.is_none());
|
||||
assert!(text_content.url.is_none());
|
||||
assert!(!text_content.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_with_url() {
|
||||
// Test creating with URL
|
||||
let text = "Content with URL".to_string();
|
||||
let instructions = "URL instructions".to_string();
|
||||
let category = "URL category".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
let url = Some("https://example.com/document.pdf".to_string());
|
||||
|
||||
let text_content = TextContent::new(
|
||||
text.clone(),
|
||||
instructions.clone(),
|
||||
category.clone(),
|
||||
None,
|
||||
url.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Check URL field is set
|
||||
assert_eq!(text_content.url, url);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_patch() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create initial text content
|
||||
let initial_text = "Initial text".to_string();
|
||||
let initial_instructions = "Initial instructions".to_string();
|
||||
let initial_category = "Initial category".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_content = TextContent::new(
|
||||
initial_text,
|
||||
initial_instructions,
|
||||
initial_category,
|
||||
None,
|
||||
None,
|
||||
user_id,
|
||||
);
|
||||
|
||||
// Store the text content
|
||||
let stored: Option<TextContent> = db
|
||||
.store_item(text_content.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
assert!(stored.is_some());
|
||||
|
||||
// New values for patch
|
||||
let new_instructions = "Updated instructions";
|
||||
let new_category = "Updated category";
|
||||
let new_text = "Updated text content";
|
||||
|
||||
// Apply the patch
|
||||
TextContent::patch(
|
||||
&text_content.id,
|
||||
new_instructions,
|
||||
new_category,
|
||||
new_text,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to patch text content");
|
||||
|
||||
// Retrieve the updated content
|
||||
let updated: Option<TextContent> = db
|
||||
.get_item(&text_content.id)
|
||||
.await
|
||||
.expect("Failed to get updated text content");
|
||||
assert!(updated.is_some());
|
||||
|
||||
let updated_content = updated.unwrap();
|
||||
|
||||
// Verify the updates
|
||||
assert_eq!(updated_content.instructions, new_instructions);
|
||||
assert_eq!(updated_content.category, new_category);
|
||||
assert_eq!(updated_content.text, new_text);
|
||||
assert!(updated_content.updated_at > text_content.updated_at);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,3 +414,276 @@ impl User {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.ensure_initialized()
|
||||
.await
|
||||
.expect("Failed to setup the systemsettings");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_creation() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "test@example.com";
|
||||
let password = "test_password";
|
||||
let timezone = "America/New_York";
|
||||
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
timezone.to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Verify user properties
|
||||
assert!(!user.id.is_empty());
|
||||
assert_eq!(user.email, email);
|
||||
assert_ne!(user.password, password); // Password should be hashed
|
||||
assert!(!user.anonymous);
|
||||
assert_eq!(user.timezone, timezone);
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, user.id);
|
||||
assert_eq!(retrieved.email, email);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_authentication() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "auth_test@example.com";
|
||||
let password = "auth_password";
|
||||
|
||||
User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Test successful authentication
|
||||
let auth_result = User::authenticate(email, password, &db).await;
|
||||
assert!(auth_result.is_ok());
|
||||
|
||||
// Test failed authentication with wrong password
|
||||
let wrong_auth = User::authenticate(email, "wrong_password", &db).await;
|
||||
assert!(wrong_auth.is_err());
|
||||
|
||||
// Test failed authentication with non-existent user
|
||||
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
|
||||
assert!(nonexistent.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_by_email() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "find_test@example.com";
|
||||
let password = "find_password";
|
||||
|
||||
let created_user = User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Test finding user by email
|
||||
let found_user = User::find_by_email(email, &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
assert_eq!(found_user.id, created_user.id);
|
||||
assert_eq!(found_user.email, email);
|
||||
|
||||
// Test finding non-existent user
|
||||
let not_found = User::find_by_email("nonexistent@example.com", &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
assert!(not_found.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_key_management() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "apikey_test@example.com";
|
||||
let password = "apikey_password";
|
||||
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Initially, user should have no API key
|
||||
assert!(user.api_key.is_none());
|
||||
|
||||
// Generate API key
|
||||
let api_key = User::set_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to set API key");
|
||||
assert!(!api_key.is_empty());
|
||||
assert!(api_key.starts_with("sk_"));
|
||||
|
||||
// Verify the API key was saved
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
assert_eq!(updated_user.api_key, Some(api_key.clone()));
|
||||
|
||||
// Test finding user by API key
|
||||
let found_user = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
assert_eq!(found_user.id, user.id);
|
||||
|
||||
// Revoke API key
|
||||
User::revoke_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to revoke API key");
|
||||
|
||||
// Verify API key was revoked
|
||||
let revoked_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(revoked_user.is_some());
|
||||
let revoked_user = revoked_user.unwrap();
|
||||
assert!(revoked_user.api_key.is_none());
|
||||
|
||||
// Test searching by revoked API key
|
||||
let not_found = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
assert!(not_found.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_password_update() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "pwd_test@example.com";
|
||||
let old_password = "old_password";
|
||||
let new_password = "new_password";
|
||||
|
||||
User::create_new(
|
||||
email.to_string(),
|
||||
old_password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Authenticate with old password
|
||||
let auth_result = User::authenticate(email, old_password, &db).await;
|
||||
assert!(auth_result.is_ok());
|
||||
|
||||
// Update password
|
||||
User::patch_password(email, new_password, &db)
|
||||
.await
|
||||
.expect("Failed to update password");
|
||||
|
||||
// Old password should no longer work
|
||||
let old_auth = User::authenticate(email, old_password, &db).await;
|
||||
assert!(old_auth.is_err());
|
||||
|
||||
// New password should work
|
||||
let new_auth = User::authenticate(email, new_password, &db).await;
|
||||
assert!(new_auth.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_timezone() {
|
||||
// Valid timezones should be accepted as-is
|
||||
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
|
||||
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
|
||||
assert_eq!(validate_timezone("Asia/Tokyo"), "Asia/Tokyo");
|
||||
assert_eq!(validate_timezone("UTC"), "UTC");
|
||||
|
||||
// Invalid timezones should be replaced with UTC
|
||||
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
|
||||
assert_eq!(validate_timezone("Not_Real"), "UTC");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_update() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create user with default timezone
|
||||
let email = "timezone_test@example.com";
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
"password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
assert_eq!(user.timezone, "UTC");
|
||||
|
||||
// Update timezone
|
||||
let new_timezone = "Europe/Paris";
|
||||
User::update_timezone(&user.id, new_timezone, &db)
|
||||
.await
|
||||
.expect("Failed to update timezone");
|
||||
|
||||
// Verify timezone was updated
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
assert_eq!(updated_user.timezone, new_timezone);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,5 +15,7 @@ serde_json = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
|
||||
common = { path = "../common" }
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
@@ -61,3 +61,280 @@ pub async fn find_entities_by_relationship_by_id(
|
||||
|
||||
db.query(query).await?.take(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use common::storage::types::StoredObject;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_source_ids() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create some test entities with different source_ids
|
||||
let source_id1 = "source123".to_string();
|
||||
let source_id2 = "source456".to_string();
|
||||
let source_id3 = "source789".to_string();
|
||||
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Entity with source_id1
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id1.clone(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Entity with source_id2
|
||||
let entity2 = KnowledgeEntity::new(
|
||||
source_id2.clone(),
|
||||
"Entity 2".to_string(),
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Another entity with source_id1
|
||||
let entity3 = KnowledgeEntity::new(
|
||||
source_id1.clone(),
|
||||
"Entity 3".to_string(),
|
||||
"Description 3".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Entity with source_id3
|
||||
let entity4 = KnowledgeEntity::new(
|
||||
source_id3.clone(),
|
||||
"Entity 4".to_string(),
|
||||
"Description 4".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store all entities
|
||||
db.store_item(entity1.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
db.store_item(entity2.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
db.store_item(entity3.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 3");
|
||||
db.store_item(entity4.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 4");
|
||||
|
||||
// Test finding entities by multiple source_ids
|
||||
let source_ids = vec![source_id1.clone(), source_id2.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db)
|
||||
.await
|
||||
.expect("Failed to find entities by source_ids");
|
||||
|
||||
// Should find 3 entities (2 with source_id1, 1 with source_id2)
|
||||
assert_eq!(
|
||||
found_entities.len(),
|
||||
3,
|
||||
"Should find 3 entities with the specified source_ids"
|
||||
);
|
||||
|
||||
// Check that entities with source_id1 and source_id2 are found
|
||||
let found_source_ids: Vec<String> =
|
||||
found_entities.iter().map(|e| e.source_id.clone()).collect();
|
||||
assert!(
|
||||
found_source_ids.contains(&source_id1),
|
||||
"Should find entities with source_id1"
|
||||
);
|
||||
assert!(
|
||||
found_source_ids.contains(&source_id2),
|
||||
"Should find entities with source_id2"
|
||||
);
|
||||
assert!(
|
||||
!found_source_ids.contains(&source_id3),
|
||||
"Should not find entities with source_id3"
|
||||
);
|
||||
|
||||
// Test finding entities by a single source_id
|
||||
let single_source_id = vec![source_id1.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
single_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to find entities by single source_id");
|
||||
|
||||
// Should find 2 entities with source_id1
|
||||
assert_eq!(
|
||||
found_entities.len(),
|
||||
2,
|
||||
"Should find 2 entities with source_id1"
|
||||
);
|
||||
|
||||
// Check that all found entities have source_id1
|
||||
for entity in found_entities {
|
||||
assert_eq!(
|
||||
entity.source_id, source_id1,
|
||||
"All found entities should have source_id1"
|
||||
);
|
||||
}
|
||||
|
||||
// Test finding entities with non-existent source_id
|
||||
let non_existent_source_id = vec!["non_existent_source".to_string()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
non_existent_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to find entities by non-existent source_id");
|
||||
|
||||
// Should find 0 entities
|
||||
assert_eq!(
|
||||
found_entities.len(),
|
||||
0,
|
||||
"Should find 0 entities with non-existent source_id"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_relationship_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create some test entities
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create the central entity we'll query relationships for
|
||||
let central_entity = KnowledgeEntity::new(
|
||||
"central_source".to_string(),
|
||||
"Central Entity".to_string(),
|
||||
"Central Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create related entities
|
||||
let related_entity1 = KnowledgeEntity::new(
|
||||
"related_source1".to_string(),
|
||||
"Related Entity 1".to_string(),
|
||||
"Related Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let related_entity2 = KnowledgeEntity::new(
|
||||
"related_source2".to_string(),
|
||||
"Related Entity 2".to_string(),
|
||||
"Related Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an unrelated entity
|
||||
let unrelated_entity = KnowledgeEntity::new(
|
||||
"unrelated_source".to_string(),
|
||||
"Unrelated Entity".to_string(),
|
||||
"Unrelated Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store all entities
|
||||
let central_entity = db
|
||||
.store_item(central_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store central entity")
|
||||
.unwrap();
|
||||
let related_entity1 = db
|
||||
.store_item(related_entity1.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 1")
|
||||
.unwrap();
|
||||
let related_entity2 = db
|
||||
.store_item(related_entity2.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 2")
|
||||
.unwrap();
|
||||
let unrelated_entity = db
|
||||
.store_item(unrelated_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store unrelated entity")
|
||||
.unwrap();
|
||||
|
||||
// Create relationships
|
||||
let source_id = "relationship_source".to_string();
|
||||
|
||||
// Create relationship 1: central -> related1
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity1.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
// Create relationship 2: central -> related2
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity2.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Store relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
|
||||
// Test finding entities related to the central entity
|
||||
let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone())
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
|
||||
// Check that we found relationships
|
||||
assert!(related_entities.len() > 0, "Should find related entities");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user