mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-26 03:11:34 +01:00
tests: testing all db interactions and types
This commit is contained in:
@@ -189,3 +189,104 @@ impl Deref for SurrealDbClient {
|
||||
&self.client
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
|
||||
impl SurrealDbClient {
|
||||
/// Create an in-memory SurrealDB client for testing.
|
||||
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
||||
let db = connect("mem://").await?;
|
||||
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::stored_object;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(Dummy, "dummy", {
|
||||
name: String
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialization_and_crud() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Call your initialization
|
||||
db.ensure_initialized()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
|
||||
// Test basic CRUD
|
||||
let dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
name: "first".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
// Store
|
||||
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
|
||||
assert!(stored.is_some());
|
||||
|
||||
// Read
|
||||
let fetched = db
|
||||
.get_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to fetch");
|
||||
assert_eq!(fetched, Some(dummy.clone()));
|
||||
|
||||
// Read all
|
||||
let all = db
|
||||
.get_all_stored_items::<Dummy>()
|
||||
.await
|
||||
.expect("Failed to fetch all");
|
||||
assert!(all.contains(&dummy));
|
||||
|
||||
// Delete
|
||||
let deleted = db
|
||||
.delete_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to delete");
|
||||
assert_eq!(deleted, Some(dummy));
|
||||
|
||||
// After delete, should not be present
|
||||
let fetch_post = db
|
||||
.get_item::<Dummy>("abc")
|
||||
.await
|
||||
.expect("Failed fetch post delete");
|
||||
assert!(fetch_post.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setup_auth() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Should not panic or fail
|
||||
db.setup_auth().await.expect("Failed to setup auth");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_build_indexes() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.build_indexes().await.expect("Failed to build indexes");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
@@ -11,32 +12,40 @@ pub struct Analytics {
|
||||
pub visitors: i64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StoredObject for Analytics {
|
||||
fn table_name() -> &'static str {
|
||||
"analytics"
|
||||
}
|
||||
|
||||
fn get_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl Analytics {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics = db.select(("analytics", "current")).await?;
|
||||
let analytics = db.get_item::<Self>("current").await?;
|
||||
|
||||
if analytics.is_none() {
|
||||
let created: Option<Analytics> = db
|
||||
.create(("analytics", "current"))
|
||||
.content(Analytics {
|
||||
id: "current".to_string(),
|
||||
visitors: 0,
|
||||
page_loads: 0,
|
||||
})
|
||||
.await?;
|
||||
let created_analytics = Analytics {
|
||||
id: "current".to_string(),
|
||||
visitors: 0,
|
||||
page_loads: 0,
|
||||
};
|
||||
|
||||
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
};
|
||||
let stored: Option<Self> = db.store_item(created_analytics).await?;
|
||||
return stored.ok_or(AppError::Validation(
|
||||
"Failed to initialize analytics".into(),
|
||||
));
|
||||
}
|
||||
|
||||
analytics.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
analytics.ok_or(AppError::Validation(
|
||||
"Failed to initialize analytics".into(),
|
||||
))
|
||||
}
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('analytics', 'current')")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
let analytics: Option<Self> = db.get_item("current").await?;
|
||||
analytics.ok_or(AppError::NotFound("Analytics not found".into()))
|
||||
}
|
||||
|
||||
@@ -61,6 +70,7 @@ impl Analytics {
|
||||
}
|
||||
|
||||
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
||||
// We need to use a direct query for COUNT aggregation
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CountResult {
|
||||
count: i64,
|
||||
@@ -76,3 +86,192 @@ impl Analytics {
|
||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::stored_object;
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TestUser, "user", {
|
||||
email: String,
|
||||
password: String,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analytics_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test initialization of analytics
|
||||
let analytics = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(analytics.id, "current");
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let analytics_again = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to get analytics after initialization");
|
||||
|
||||
assert_eq!(analytics.id, analytics_again.id);
|
||||
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
||||
assert_eq!(analytics.visitors, analytics_again.visitors);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_analytics() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Test get_current method
|
||||
let analytics = Analytics::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current analytics");
|
||||
|
||||
assert_eq!(analytics.id, "current");
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Test increment_visitors method
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors");
|
||||
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors again");
|
||||
|
||||
assert_eq!(analytics.visitors, 2);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
|
||||
// Test increment_page_loads method
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads");
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads again");
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_users_amount() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test with no users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount");
|
||||
assert_eq!(count, 0);
|
||||
|
||||
// Create a few test users
|
||||
for i in 0..3 {
|
||||
let user = TestUser {
|
||||
id: format!("user{}", i),
|
||||
email: format!("user{}@example.com", i),
|
||||
password: "password".to_string(),
|
||||
user_id: format!("uid{}", i),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
db.store_item(user)
|
||||
.await
|
||||
.expect("Failed to create test user");
|
||||
}
|
||||
|
||||
// Test users amount after adding users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount after adding users");
|
||||
assert_eq!(count, 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Don't initialize analytics and try to get it
|
||||
let result = Analytics::get_current(&db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(err) = result {
|
||||
match err {
|
||||
AppError::NotFound(_) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected NotFound error, got: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,3 +47,178 @@ impl Conversation {
|
||||
Ok((conversation, messages))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::types::message::MessageRole;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a new conversation
|
||||
let user_id = "test_user";
|
||||
let title = "Test Conversation";
|
||||
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
||||
|
||||
// Verify conversation properties
|
||||
assert_eq!(conversation.user_id, user_id);
|
||||
assert_eq!(conversation.title, title);
|
||||
assert!(!conversation.id.is_empty());
|
||||
|
||||
// Store the conversation
|
||||
let result = db.store_item(conversation.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<Conversation> = db
|
||||
.get_item(&conversation.id)
|
||||
.await
|
||||
.expect("Failed to retrieve conversation");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, conversation.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Try to get a conversation that doesn't exist
|
||||
let result =
|
||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_unauthorized() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation =
|
||||
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
|
||||
// Try to access with a different user
|
||||
let user_id_2 = "user_2";
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
|
||||
// Create messages
|
||||
let message1 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
"Hello, AI!".to_string(),
|
||||
None,
|
||||
);
|
||||
let message2 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::AI,
|
||||
"Hello, human! How can I help you today?".to_string(),
|
||||
None,
|
||||
);
|
||||
let message3 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
"Tell me about Rust programming.".to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
// Store messages
|
||||
db.store_item(message1)
|
||||
.await
|
||||
.expect("Failed to store message1");
|
||||
db.store_item(message2)
|
||||
.await
|
||||
.expect("Failed to store message2");
|
||||
db.store_item(message3)
|
||||
.await
|
||||
.expect("Failed to store message3");
|
||||
|
||||
// Retrieve the complete conversation
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
||||
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
||||
|
||||
let (retrieved_conversation, messages) = result.unwrap();
|
||||
|
||||
// Verify conversation data
|
||||
assert_eq!(retrieved_conversation.id, conversation_id);
|
||||
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
||||
assert_eq!(retrieved_conversation.title, "Conversation");
|
||||
|
||||
// Verify messages
|
||||
assert_eq!(messages.len(), 3);
|
||||
|
||||
// Verify messages are sorted by updated_at
|
||||
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
|
||||
assert!(message_contents.contains(&"Hello, AI!"));
|
||||
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
||||
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
||||
|
||||
// Make sure we can't access with different user
|
||||
let user_id_2 = "user_2";
|
||||
let unauthorized_result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(unauthorized_result.is_err());
|
||||
match unauthorized_result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,16 +131,32 @@ impl FileInfo {
|
||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
|
||||
fn sanitize_file_name(file_name: &str) -> String {
|
||||
file_name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
if let Some(idx) = file_name.rfind('.') {
|
||||
let (name, ext) = file_name.split_at(idx);
|
||||
let sanitized_name: String = name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
format!("{}{}", sanitized_name, ext)
|
||||
} else {
|
||||
// No extension
|
||||
file_name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Persists the file to the filesystem under `./data/{user_id}/{uuid}/{file_name}`.
|
||||
@@ -243,3 +259,331 @@ impl FileInfo {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
/// Creates a test temporary file with the given content
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
|
||||
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
|
||||
temp_file
|
||||
.write_all(content)
|
||||
.expect("Failed to write to temp file");
|
||||
|
||||
let metadata = FieldMetadata {
|
||||
name: Some("file".to_string()),
|
||||
file_name: Some(file_name.to_string()),
|
||||
content_type: None,
|
||||
headers: HeaderMap::default(),
|
||||
};
|
||||
|
||||
let field_data = FieldData {
|
||||
metadata,
|
||||
contents: temp_file,
|
||||
};
|
||||
|
||||
field_data
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_creation() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a test file
|
||||
let content = b"This is a test file content";
|
||||
let file_name = "test_file.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
|
||||
// Create a FileInfo instance
|
||||
let user_id = "test_user";
|
||||
let file_info = FileInfo::new(field_data, &db, user_id).await;
|
||||
|
||||
// We can't fully test persistence to disk in unit tests,
|
||||
// but we can verify the database record was created
|
||||
assert!(file_info.is_ok());
|
||||
let file_info = file_info.unwrap();
|
||||
|
||||
// Check essential properties
|
||||
assert!(!file_info.id.is_empty());
|
||||
assert_eq!(file_info.file_name, file_name);
|
||||
assert!(!file_info.sha256.is_empty());
|
||||
assert!(!file_info.path.is_empty());
|
||||
assert!(file_info.mime_type.contains("text/plain"));
|
||||
|
||||
// Verify it's in the database
|
||||
let stored: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(stored.is_some());
|
||||
let stored = stored.unwrap();
|
||||
assert_eq!(stored.id, file_info.id);
|
||||
assert_eq!(stored.file_name, file_info.file_name);
|
||||
assert_eq!(stored.sha256, file_info.sha256);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_duplicate_detection() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// First, store a file with known content
|
||||
let content = b"This is a test file for duplicate detection";
|
||||
let file_name = "original.txt";
|
||||
let user_id = "test_user";
|
||||
|
||||
let field_data1 = create_test_file(content, file_name);
|
||||
let original_file_info = FileInfo::new(field_data1, &db, user_id)
|
||||
.await
|
||||
.expect("Failed to create original file");
|
||||
|
||||
// Now try to store another file with the same content but different name
|
||||
let duplicate_name = "duplicate.txt";
|
||||
let field_data2 = create_test_file(content, duplicate_name);
|
||||
|
||||
// The system should detect it's the same file and return the original FileInfo
|
||||
let duplicate_file_info = FileInfo::new(field_data2, &db, user_id)
|
||||
.await
|
||||
.expect("Failed to process duplicate file");
|
||||
|
||||
// The returned FileInfo should match the original
|
||||
assert_eq!(duplicate_file_info.id, original_file_info.id);
|
||||
assert_eq!(duplicate_file_info.sha256, original_file_info.sha256);
|
||||
|
||||
// But it should retain the original file name, not the duplicate's name
|
||||
assert_eq!(duplicate_file_info.file_name, file_name);
|
||||
assert_ne!(duplicate_file_info.file_name, duplicate_name);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_guess_mime_type() {
|
||||
// Test common file extensions
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("test.txt")),
|
||||
"text/plain".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("image.png")),
|
||||
"image/png".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("document.pdf")),
|
||||
"application/pdf".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("data.json")),
|
||||
"application/json".to_string()
|
||||
);
|
||||
|
||||
// Test unknown extension
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("unknown.929yz")),
|
||||
"application/octet-stream".to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sanitize_file_name() {
|
||||
// Safe characters should remain unchanged
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("normal_file.txt"),
|
||||
"normal_file.txt"
|
||||
);
|
||||
assert_eq!(FileInfo::sanitize_file_name("file123.doc"), "file123.doc");
|
||||
|
||||
// Unsafe characters should be replaced with underscores
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file with spaces.txt"),
|
||||
"file_with_spaces.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file/with/path.txt"),
|
||||
"file_with_path.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file:with:colons.txt"),
|
||||
"file_with_colons.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("../dangerous.txt"),
|
||||
"___dangerous.txt"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Try to find a file with a SHA that doesn't exist
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(FileError::FileNotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected FileNotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
sha256: "test_sha256_hash".to_string(),
|
||||
path: "/path/to/file.txt".to_string(),
|
||||
file_name: "manual_file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
};
|
||||
|
||||
// Store it in the database
|
||||
let result = db.store_item(file_info.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, file_info.id);
|
||||
assert_eq!(retrieved.sha256, file_info.sha256);
|
||||
assert_eq!(retrieved.file_name, file_info.file_name);
|
||||
assert_eq!(retrieved.path, file_info.path);
|
||||
assert_eq!(retrieved.mime_type, file_info.mime_type);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a FileInfo instance directly (without persistence to disk)
|
||||
let now = Utc::now();
|
||||
let file_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Create a temporary directory that mimics the structure we would have on disk
|
||||
let base_dir = Path::new("./data");
|
||||
let user_id = "test_user";
|
||||
let user_dir = base_dir.join(user_id);
|
||||
let uuid_dir = user_dir.join(&file_id);
|
||||
|
||||
tokio::fs::create_dir_all(&uuid_dir)
|
||||
.await
|
||||
.expect("Failed to create test directories");
|
||||
|
||||
// Create a test file in the directory
|
||||
let test_file_path = uuid_dir.join("test_file.txt");
|
||||
tokio::fs::write(&test_file_path, b"test content")
|
||||
.await
|
||||
.expect("Failed to write test file");
|
||||
|
||||
// The file path should point to our test file
|
||||
let file_info = FileInfo {
|
||||
id: file_id.clone(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
sha256: "test_sha256_hash".to_string(),
|
||||
path: test_file_path.to_string_lossy().to_string(),
|
||||
file_name: "test_file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
};
|
||||
|
||||
// Store it in the database
|
||||
db.store_item(file_info.clone())
|
||||
.await
|
||||
.expect("Failed to store file info");
|
||||
|
||||
// Verify file exists on disk
|
||||
assert!(tokio::fs::try_exists(&test_file_path)
|
||||
.await
|
||||
.unwrap_or(false));
|
||||
|
||||
// Delete the file
|
||||
let delete_result = FileInfo::delete_by_id(&file_id, &db).await;
|
||||
|
||||
// Delete should be successful
|
||||
assert!(
|
||||
delete_result.is_ok(),
|
||||
"Failed to delete file: {:?}",
|
||||
delete_result
|
||||
);
|
||||
|
||||
// Verify the file is removed from the database
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_id)
|
||||
.await
|
||||
.expect("Failed to query database");
|
||||
assert!(
|
||||
retrieved.is_none(),
|
||||
"FileInfo should be deleted from the database"
|
||||
);
|
||||
|
||||
// Verify directory is gone
|
||||
assert!(
|
||||
!tokio::fs::try_exists(&uuid_dir).await.unwrap_or(true),
|
||||
"UUID directory should be deleted"
|
||||
);
|
||||
|
||||
// Clean up test directory if it exists
|
||||
let _ = tokio::fs::remove_dir_all(base_dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Try to delete a file that doesn't exist
|
||||
let result = FileInfo::delete_by_id("nonexistent_id", &db).await;
|
||||
|
||||
// Should fail with FileNotFound error
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(FileError::FileNotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected FileNotFound error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub enum IngestionPayload {
|
||||
Url {
|
||||
url: String,
|
||||
@@ -93,3 +94,237 @@ impl IngestionPayload {
|
||||
Ok(object_list)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Create a mock FileInfo for testing
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
struct MockFileInfo {
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl From<MockFileInfo> for FileInfo {
|
||||
fn from(mock: MockFileInfo) -> Self {
|
||||
// This is just a test implementation, the actual fields don't matter
|
||||
// as we're just testing the IngestionPayload functionality
|
||||
FileInfo {
|
||||
id: mock.id,
|
||||
sha256: "mock-sha256".to_string(),
|
||||
path: "/mock/path".to_string(),
|
||||
file_name: "mock.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url() {
|
||||
let url = "https://example.com";
|
||||
let instructions = "Process this URL";
|
||||
let category = "websites";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(url.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url,
|
||||
instructions: payload_instructions,
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
assert_eq!(payload_instructions, &instructions);
|
||||
assert_eq!(payload_category, &category);
|
||||
assert_eq!(payload_user_id, &user_id);
|
||||
}
|
||||
_ => panic!("Expected Url variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_text() {
|
||||
let text = "This is some text content";
|
||||
let instructions = "Process this text";
|
||||
let category = "notes";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(text.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
IngestionPayload::Text {
|
||||
text: payload_text,
|
||||
instructions: payload_instructions,
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
assert_eq!(payload_text, text);
|
||||
assert_eq!(payload_instructions, instructions);
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected Text variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_file() {
|
||||
let instructions = "Process this file";
|
||||
let category = "documents";
|
||||
let user_id = "user123";
|
||||
|
||||
// Create a mock FileInfo
|
||||
let mock_file = MockFileInfo {
|
||||
id: "file123".to_string(),
|
||||
};
|
||||
|
||||
let file_info: FileInfo = mock_file.into();
|
||||
let files = vec![file_info.clone()];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
None,
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
instructions: payload_instructions,
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
assert_eq!(payload_instructions, instructions);
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url_and_file() {
|
||||
let url = "https://example.com";
|
||||
let instructions = "Process this data";
|
||||
let category = "mixed";
|
||||
let user_id = "user123";
|
||||
|
||||
// Create a mock FileInfo
|
||||
let mock_file = MockFileInfo {
|
||||
id: "file123".to_string(),
|
||||
};
|
||||
|
||||
let file_info: FileInfo = mock_file.into();
|
||||
let files = vec![file_info.clone()];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(url.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
// Check first item is URL
|
||||
match &result[0] {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url, ..
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
}
|
||||
_ => panic!("Expected first item to be Url variant"),
|
||||
}
|
||||
|
||||
// Check second item is File
|
||||
match &result[1] {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
}
|
||||
_ => panic!("Expected second item to be File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_empty_input() {
|
||||
let instructions = "Process something";
|
||||
let category = "empty";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
None,
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_empty_text() {
|
||||
let text = ""; // Empty text
|
||||
let instructions = "Process this";
|
||||
let category = "notes";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(text.to_string()),
|
||||
instructions.to_string(),
|
||||
category.to_string(),
|
||||
files,
|
||||
user_id,
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use super::ingestion_payload::IngestionPayload;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum IngestionTaskStatus {
|
||||
Created,
|
||||
InProgress {
|
||||
@@ -100,3 +100,196 @@ impl IngestionTask {
|
||||
Ok(jobs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
|
||||
// Helper function to create a test ingestion payload
|
||||
fn create_test_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
text: "Test content".to_string(),
|
||||
instructions: "Test instructions".to_string(),
|
||||
category: "Test category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_ingestion_task() {
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
// Verify task properties
|
||||
assert_eq!(task.user_id, user_id);
|
||||
assert_eq!(task.content, payload);
|
||||
assert!(matches!(task.status, IngestionTaskStatus::Created));
|
||||
assert!(!task.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_add_to_db() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
// Create and store task
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("Failed to create and add task to db");
|
||||
|
||||
// Query to verify task was stored
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE user_id = '{}'",
|
||||
IngestionTask::table_name(),
|
||||
user_id
|
||||
);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let tasks: Vec<IngestionTask> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify task is in the database
|
||||
assert!(!tasks.is_empty(), "Task should exist in the database");
|
||||
let stored_task = &tasks[0];
|
||||
assert_eq!(stored_task.user_id, user_id);
|
||||
assert!(matches!(stored_task.status, IngestionTaskStatus::Created));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_status() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
// Create task manually
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task_id = task.id.clone();
|
||||
|
||||
// Store task
|
||||
db.store_item(task).await.expect("Failed to store task");
|
||||
|
||||
// Update status to InProgress
|
||||
let now = Utc::now();
|
||||
let new_status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: now,
|
||||
};
|
||||
|
||||
IngestionTask::update_status(&task_id, new_status.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to update status");
|
||||
|
||||
// Verify status updated
|
||||
let updated_task: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&task_id)
|
||||
.await
|
||||
.expect("Failed to get updated task");
|
||||
|
||||
assert!(updated_task.is_some());
|
||||
let updated_task = updated_task.unwrap();
|
||||
|
||||
match updated_task.status {
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
assert_eq!(attempts, 1);
|
||||
}
|
||||
_ => panic!("Expected InProgress status"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_tasks() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
|
||||
// Create tasks with different statuses
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
max_attempts_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: MAX_ATTEMPTS,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
completed_task.status = IngestionTaskStatus::Completed;
|
||||
|
||||
let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
error_task.status = IngestionTaskStatus::Error("Test error".to_string());
|
||||
|
||||
// Store all tasks
|
||||
db.store_item(created_task)
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
db.store_item(in_progress_task)
|
||||
.await
|
||||
.expect("Failed to store in-progress task");
|
||||
db.store_item(max_attempts_task)
|
||||
.await
|
||||
.expect("Failed to store max-attempts task");
|
||||
db.store_item(completed_task)
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
db.store_item(error_task)
|
||||
.await
|
||||
.expect("Failed to store error task");
|
||||
|
||||
// Get unfinished tasks
|
||||
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db)
|
||||
.await
|
||||
.expect("Failed to get unfinished tasks");
|
||||
|
||||
// Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned
|
||||
assert_eq!(unfinished_tasks.len(), 2);
|
||||
|
||||
let statuses: Vec<_> = unfinished_tasks
|
||||
.iter()
|
||||
.map(|task| match &task.status {
|
||||
IngestionTaskStatus::Created => "Created",
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
if *attempts < MAX_ATTEMPTS {
|
||||
"InProgress<MAX"
|
||||
} else {
|
||||
"InProgress>=MAX"
|
||||
}
|
||||
}
|
||||
IngestionTaskStatus::Completed => "Completed",
|
||||
IngestionTaskStatus::Error(_) => "Error",
|
||||
IngestionTaskStatus::Cancelled => "Cancelled",
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert!(statuses.contains(&"Created"));
|
||||
assert!(statuses.contains(&"InProgress<MAX"));
|
||||
assert!(!statuses.contains(&"InProgress>=MAX"));
|
||||
assert!(!statuses.contains(&"Completed"));
|
||||
assert!(!statuses.contains(&"Error"));
|
||||
assert!(!statuses.contains(&"Cancelled"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::{
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub enum KnowledgeEntityType {
|
||||
Idea,
|
||||
Project,
|
||||
@@ -119,3 +119,198 @@ impl KnowledgeEntity {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() {
|
||||
// Create basic test entity
|
||||
let source_id = "source123".to_string();
|
||||
let name = "Test Entity".to_string();
|
||||
let description = "Test Description".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let metadata = Some(json!({"key": "value"}));
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
name.clone(),
|
||||
description.clone(),
|
||||
entity_type.clone(),
|
||||
metadata.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert_eq!(entity.source_id, source_id);
|
||||
assert_eq!(entity.name, name);
|
||||
assert_eq!(entity.description, description);
|
||||
assert_eq!(entity.entity_type, entity_type);
|
||||
assert_eq!(entity.metadata, metadata);
|
||||
assert_eq!(entity.embedding, embedding);
|
||||
assert_eq!(entity.user_id, user_id);
|
||||
assert!(!entity.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_type_from_string() {
|
||||
// Test conversion from String to KnowledgeEntityType
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("idea".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("Idea".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("IDEA".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("project".to_string()),
|
||||
KnowledgeEntityType::Project
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("document".to_string()),
|
||||
KnowledgeEntityType::Document
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("page".to_string()),
|
||||
KnowledgeEntityType::Page
|
||||
);
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("textsnippet".to_string()),
|
||||
KnowledgeEntityType::TextSnippet
|
||||
);
|
||||
|
||||
// Test default case
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("unknown".to_string()),
|
||||
KnowledgeEntityType::Document
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_variants() {
|
||||
let variants = KnowledgeEntityType::variants();
|
||||
assert_eq!(variants.len(), 5);
|
||||
assert!(variants.contains(&"Idea"));
|
||||
assert!(variants.contains(&"Project"));
|
||||
assert!(variants.contains(&"Document"));
|
||||
assert!(variants.contains(&"Page"));
|
||||
assert!(variants.contains(&"TextSnippet"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create two entities with the same source_id
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let entity2 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 2".to_string(),
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an entity with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_entity = KnowledgeEntity::new(
|
||||
different_source_id.clone(),
|
||||
"Different Entity".to_string(),
|
||||
"Different Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store the entities
|
||||
db.store_item(entity1)
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
db.store_item(entity2)
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
db.store_item(different_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store different entity");
|
||||
|
||||
// Delete by source_id
|
||||
KnowledgeEntity::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete entities by source_id");
|
||||
|
||||
// Verify all entities with the specified source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
source_id
|
||||
);
|
||||
let remaining: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All entities with the source_id should be deleted"
|
||||
);
|
||||
|
||||
// Verify the entity with a different source_id still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
different_source_id
|
||||
);
|
||||
let different_remaining: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Entity with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
}
|
||||
|
||||
// Note: We can't easily test the patch method without mocking the OpenAI client
|
||||
// and the generate_embedding function. This would require more complex setup.
|
||||
}
|
||||
|
||||
@@ -84,3 +84,258 @@ impl KnowledgeRelationship {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
|
||||
// Helper function to create a test knowledge entity for the relationship tests
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {}", name);
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id,
|
||||
name.to_string(),
|
||||
description,
|
||||
entity_type,
|
||||
None,
|
||||
embedding,
|
||||
user_id,
|
||||
);
|
||||
|
||||
let stored: Option<KnowledgeEntity> = db_client
|
||||
.store_item(entity)
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
stored.unwrap().id
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relationship_creation() {
|
||||
let in_id = "entity1".to_string();
|
||||
let out_id = "entity2".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
in_id.clone(),
|
||||
out_id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
relationship_type.clone(),
|
||||
);
|
||||
|
||||
// Verify fields are correctly set
|
||||
assert_eq!(relationship.in_, in_id);
|
||||
assert_eq!(relationship.out, out_id);
|
||||
assert_eq!(relationship.metadata.user_id, user_id);
|
||||
assert_eq!(relationship.metadata.source_id, source_id);
|
||||
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
||||
assert!(!relationship.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
||||
// This approach is more reliable than trying to look up by ID
|
||||
let check_query = format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
||||
source_id
|
||||
);
|
||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||
|
||||
// Just verify that a relationship was created
|
||||
assert!(
|
||||
!check_results.is_empty(),
|
||||
"Relationship should exist in the database"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
// Delete the relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationship by ID");
|
||||
|
||||
// Query to verify the relationship was deleted
|
||||
let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship no longer exists
|
||||
assert!(results.is_empty(), "Relationship should be deleted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
|
||||
// Create relationships with the same source_id
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let different_source_id = "different_source".to_string();
|
||||
|
||||
// Create two relationships with the same source_id
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
entity2_id.clone(),
|
||||
entity3_id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Create a relationship with a different source_id
|
||||
let different_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity3_id.clone(),
|
||||
user_id.clone(),
|
||||
different_source_id.clone(),
|
||||
"mentions".to_string(),
|
||||
);
|
||||
|
||||
// Store all relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
different_relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store different relationship");
|
||||
|
||||
// Delete relationships by source_id
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationships by source_id");
|
||||
|
||||
// Query to verify the relationships with source_id were deleted
|
||||
let query1 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship1.id);
|
||||
let query2 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship2.id);
|
||||
let different_query = format!(
|
||||
"SELECT * FROM relates_to WHERE id = '{}'",
|
||||
different_relationship.id
|
||||
);
|
||||
|
||||
let mut result1 = db.query(query1).await.expect("Query 1 failed");
|
||||
let results1: Vec<KnowledgeRelationship> = result1.take(0).unwrap_or_default();
|
||||
|
||||
let mut result2 = db.query(query2).await.expect("Query 2 failed");
|
||||
let results2: Vec<KnowledgeRelationship> = result2.take(0).unwrap_or_default();
|
||||
|
||||
let mut different_result = db
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Different query failed");
|
||||
let _different_results: Vec<KnowledgeRelationship> =
|
||||
different_result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify relationships with the source_id are deleted
|
||||
assert!(results1.is_empty(), "Relationship 1 should be deleted");
|
||||
assert!(results2.is_empty(), "Relationship 2 should be deleted");
|
||||
|
||||
// For the relationship with different source ID, we need to check differently
|
||||
// Let's just verify we have a relationship where the source_id matches different_source_id
|
||||
let check_query = format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
||||
different_source_id
|
||||
);
|
||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship with a different source_id still exists
|
||||
assert!(
|
||||
!check_results.is_empty(),
|
||||
"Relationship with different source_id should still exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Serialize)]
|
||||
#[derive(Deserialize, Debug, Clone, Serialize, PartialEq)]
|
||||
pub enum MessageRole {
|
||||
User,
|
||||
AI,
|
||||
@@ -60,3 +60,128 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_creation() {
|
||||
// Test basic message creation
|
||||
let conversation_id = "test_conversation";
|
||||
let content = "This is a test message";
|
||||
let role = MessageRole::User;
|
||||
let references = Some(vec!["ref1".to_string(), "ref2".to_string()]);
|
||||
|
||||
let message = Message::new(
|
||||
conversation_id.to_string(),
|
||||
role.clone(),
|
||||
content.to_string(),
|
||||
references.clone(),
|
||||
);
|
||||
|
||||
// Verify message properties
|
||||
assert_eq!(message.conversation_id, conversation_id);
|
||||
assert_eq!(message.content, content);
|
||||
assert_eq!(message.role, role);
|
||||
assert_eq!(message.references, references);
|
||||
assert!(!message.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_persistence() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create and store a message
|
||||
let conversation_id = "test_conversation";
|
||||
let message = Message::new(
|
||||
conversation_id.to_string(),
|
||||
MessageRole::User,
|
||||
"Hello world".to_string(),
|
||||
None,
|
||||
);
|
||||
let message_id = message.id.clone();
|
||||
|
||||
// Store the message
|
||||
db.store_item(message.clone())
|
||||
.await
|
||||
.expect("Failed to store message");
|
||||
|
||||
// Retrieve the message
|
||||
let retrieved: Option<Message> = db
|
||||
.get_item(&message_id)
|
||||
.await
|
||||
.expect("Failed to retrieve message");
|
||||
|
||||
assert!(retrieved.is_some());
|
||||
let retrieved = retrieved.unwrap();
|
||||
|
||||
// Verify retrieved properties match original
|
||||
assert_eq!(retrieved.id, message.id);
|
||||
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
||||
assert_eq!(retrieved.role, message.role);
|
||||
assert_eq!(retrieved.content, message.content);
|
||||
assert_eq!(retrieved.references, message.references);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_role_display() {
|
||||
// Test the Display implementation for MessageRole
|
||||
assert_eq!(format!("{}", MessageRole::User), "User");
|
||||
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
||||
assert_eq!(format!("{}", MessageRole::System), "System");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_display() {
|
||||
// Test the Display implementation for Message
|
||||
let message = Message {
|
||||
id: "test_id".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
conversation_id: "test_convo".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Hello world".to_string(),
|
||||
references: None,
|
||||
};
|
||||
|
||||
assert_eq!(format!("{}", message), "User: Hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_format_history() {
|
||||
// Create a vector of messages
|
||||
let messages = vec![
|
||||
Message {
|
||||
id: "1".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
conversation_id: "test_convo".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Hello".to_string(),
|
||||
references: None,
|
||||
},
|
||||
Message {
|
||||
id: "2".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
conversation_id: "test_convo".to_string(),
|
||||
role: MessageRole::AI,
|
||||
content: "Hi there!".to_string(),
|
||||
references: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Format the history
|
||||
let formatted = format_history(&messages);
|
||||
|
||||
// Verify the formatting
|
||||
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ macro_rules! stored_object {
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct $name {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct SystemSettings {
|
||||
@@ -16,41 +17,49 @@ pub struct SystemSettings {
|
||||
pub ingestion_system_prompt: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StoredObject for SystemSettings {
|
||||
fn table_name() -> &'static str {
|
||||
"system_settings"
|
||||
}
|
||||
|
||||
fn get_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl SystemSettings {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings = db.select(("system_settings", "current")).await?;
|
||||
let settings: Option<Self> = db.get_item("current").await?;
|
||||
|
||||
if settings.is_none() {
|
||||
let created: Option<SystemSettings> = db
|
||||
.create(("system_settings", "current"))
|
||||
.content(SystemSettings {
|
||||
id: "current".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
|
||||
})
|
||||
.await?;
|
||||
let created_settings = SystemSettings {
|
||||
id: "current".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
query_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
ingestion_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
};
|
||||
let stored: Option<Self> = db.store_item(created_settings).await?;
|
||||
return stored.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
}
|
||||
|
||||
settings.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
}
|
||||
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('system_settings', 'current')")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
let settings: Option<Self> = db.get_item("current").await?;
|
||||
settings.ok_or(AppError::NotFound("System settings not found".into()))
|
||||
}
|
||||
|
||||
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
||||
// We need to use a direct query for the update with MERGE
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
|
||||
@@ -66,8 +75,11 @@ impl SystemSettings {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
ingestion_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
registrations_enabled: true,
|
||||
@@ -75,3 +87,159 @@ impl SystemSettings {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test initialization of system settings
|
||||
let settings = SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
);
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let settings_again = SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to get settings after initialization");
|
||||
|
||||
assert_eq!(settings.id, settings_again.id);
|
||||
assert_eq!(
|
||||
settings.registrations_enabled,
|
||||
settings_again.registrations_enabled
|
||||
);
|
||||
assert_eq!(
|
||||
settings.require_email_verification,
|
||||
settings_again.require_email_verification
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_settings() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize settings
|
||||
SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings");
|
||||
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_settings() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize settings
|
||||
SystemSettings::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::new();
|
||||
updated_settings.id = "current".to_string();
|
||||
updated_settings.registrations_enabled = false;
|
||||
updated_settings.require_email_verification = true;
|
||||
updated_settings.query_model = "gpt-4".to_string();
|
||||
|
||||
// Test update method
|
||||
let result = SystemSettings::update(&db, updated_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
|
||||
assert_eq!(result.id, "current");
|
||||
assert_eq!(result.registrations_enabled, false);
|
||||
assert_eq!(result.require_email_verification, true);
|
||||
assert_eq!(result.query_model, "gpt-4");
|
||||
|
||||
// Verify changes persisted by getting current settings
|
||||
let current = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings after update");
|
||||
|
||||
assert_eq!(current.registrations_enabled, false);
|
||||
assert_eq!(current.require_email_verification, true);
|
||||
assert_eq!(current.query_model, "gpt-4");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Don't initialize settings and try to get them
|
||||
let result = SystemSettings::get_current(&db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
|
||||
Ok(_) => panic!("Expected error but got Ok"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_method() {
|
||||
let settings = SystemSettings::new();
|
||||
|
||||
assert!(settings.id.len() > 0);
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,3 +36,175 @@ impl TextChunk {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() {
|
||||
// Test basic object creation
|
||||
let source_id = "source123".to_string();
|
||||
let chunk = "This is a text chunk for testing embeddings".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Check that the fields are set correctly
|
||||
assert_eq!(text_chunk.source_id, source_id);
|
||||
assert_eq!(text_chunk.chunk, chunk);
|
||||
assert_eq!(text_chunk.embedding, embedding);
|
||||
assert_eq!(text_chunk.user_id, user_id);
|
||||
assert!(!text_chunk.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create test data
|
||||
let source_id = "source123".to_string();
|
||||
let chunk1 = "First chunk from the same source".to_string();
|
||||
let chunk2 = "Second chunk from the same source".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create two chunks with the same source_id
|
||||
let text_chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk1,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let text_chunk2 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk2,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create a chunk with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_chunk = TextChunk::new(
|
||||
different_source_id.clone(),
|
||||
"Different source chunk".to_string(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store the chunks
|
||||
db.store_item(text_chunk1)
|
||||
.await
|
||||
.expect("Failed to store text chunk 1");
|
||||
db.store_item(text_chunk2)
|
||||
.await
|
||||
.expect("Failed to store text chunk 2");
|
||||
db.store_item(different_chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store different chunk");
|
||||
|
||||
// Delete by source_id
|
||||
TextChunk::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete chunks by source_id");
|
||||
|
||||
// Verify all chunks with the original source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
);
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All chunks with the source_id should be deleted"
|
||||
);
|
||||
|
||||
// Verify the different source_id chunk still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
different_source_id
|
||||
);
|
||||
let different_remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Chunk with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining[0].id, different_chunk.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create a chunk with a real source_id
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = "Test chunk".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id);
|
||||
|
||||
// Store the chunk
|
||||
db.store_item(text_chunk)
|
||||
.await
|
||||
.expect("Failed to store text chunk");
|
||||
|
||||
// Delete using nonexistent source_id
|
||||
let nonexistent_source_id = "nonexistent_source";
|
||||
TextChunk::delete_by_source_id(nonexistent_source_id, &db)
|
||||
.await
|
||||
.expect("Delete operation with nonexistent source_id should not fail");
|
||||
|
||||
// Verify the real chunk still exists
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
);
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
1,
|
||||
"Chunk with real source_id should still exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,3 +57,120 @@ impl TextContent {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() {
|
||||
// Test basic object creation
|
||||
let text = "Test content text".to_string();
|
||||
let instructions = "Test instructions".to_string();
|
||||
let category = "Test category".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_content = TextContent::new(
|
||||
text.clone(),
|
||||
instructions.clone(),
|
||||
category.clone(),
|
||||
None,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Check that the fields are set correctly
|
||||
assert_eq!(text_content.text, text);
|
||||
assert_eq!(text_content.instructions, instructions);
|
||||
assert_eq!(text_content.category, category);
|
||||
assert_eq!(text_content.user_id, user_id);
|
||||
assert!(text_content.file_info.is_none());
|
||||
assert!(text_content.url.is_none());
|
||||
assert!(!text_content.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_with_url() {
|
||||
// Test creating with URL
|
||||
let text = "Content with URL".to_string();
|
||||
let instructions = "URL instructions".to_string();
|
||||
let category = "URL category".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
let url = Some("https://example.com/document.pdf".to_string());
|
||||
|
||||
let text_content = TextContent::new(
|
||||
text.clone(),
|
||||
instructions.clone(),
|
||||
category.clone(),
|
||||
None,
|
||||
url.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Check URL field is set
|
||||
assert_eq!(text_content.url, url);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_patch() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create initial text content
|
||||
let initial_text = "Initial text".to_string();
|
||||
let initial_instructions = "Initial instructions".to_string();
|
||||
let initial_category = "Initial category".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_content = TextContent::new(
|
||||
initial_text,
|
||||
initial_instructions,
|
||||
initial_category,
|
||||
None,
|
||||
None,
|
||||
user_id,
|
||||
);
|
||||
|
||||
// Store the text content
|
||||
let stored: Option<TextContent> = db
|
||||
.store_item(text_content.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
assert!(stored.is_some());
|
||||
|
||||
// New values for patch
|
||||
let new_instructions = "Updated instructions";
|
||||
let new_category = "Updated category";
|
||||
let new_text = "Updated text content";
|
||||
|
||||
// Apply the patch
|
||||
TextContent::patch(
|
||||
&text_content.id,
|
||||
new_instructions,
|
||||
new_category,
|
||||
new_text,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to patch text content");
|
||||
|
||||
// Retrieve the updated content
|
||||
let updated: Option<TextContent> = db
|
||||
.get_item(&text_content.id)
|
||||
.await
|
||||
.expect("Failed to get updated text content");
|
||||
assert!(updated.is_some());
|
||||
|
||||
let updated_content = updated.unwrap();
|
||||
|
||||
// Verify the updates
|
||||
assert_eq!(updated_content.instructions, new_instructions);
|
||||
assert_eq!(updated_content.category, new_category);
|
||||
assert_eq!(updated_content.text, new_text);
|
||||
assert!(updated_content.updated_at > text_content.updated_at);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,3 +414,276 @@ impl User {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.ensure_initialized()
|
||||
.await
|
||||
.expect("Failed to setup the systemsettings");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_creation() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "test@example.com";
|
||||
let password = "test_password";
|
||||
let timezone = "America/New_York";
|
||||
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
timezone.to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Verify user properties
|
||||
assert!(!user.id.is_empty());
|
||||
assert_eq!(user.email, email);
|
||||
assert_ne!(user.password, password); // Password should be hashed
|
||||
assert!(!user.anonymous);
|
||||
assert_eq!(user.timezone, timezone);
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, user.id);
|
||||
assert_eq!(retrieved.email, email);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_authentication() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "auth_test@example.com";
|
||||
let password = "auth_password";
|
||||
|
||||
User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Test successful authentication
|
||||
let auth_result = User::authenticate(email, password, &db).await;
|
||||
assert!(auth_result.is_ok());
|
||||
|
||||
// Test failed authentication with wrong password
|
||||
let wrong_auth = User::authenticate(email, "wrong_password", &db).await;
|
||||
assert!(wrong_auth.is_err());
|
||||
|
||||
// Test failed authentication with non-existent user
|
||||
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
|
||||
assert!(nonexistent.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_by_email() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "find_test@example.com";
|
||||
let password = "find_password";
|
||||
|
||||
let created_user = User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Test finding user by email
|
||||
let found_user = User::find_by_email(email, &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
assert_eq!(found_user.id, created_user.id);
|
||||
assert_eq!(found_user.email, email);
|
||||
|
||||
// Test finding non-existent user
|
||||
let not_found = User::find_by_email("nonexistent@example.com", &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
assert!(not_found.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_key_management() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "apikey_test@example.com";
|
||||
let password = "apikey_password";
|
||||
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Initially, user should have no API key
|
||||
assert!(user.api_key.is_none());
|
||||
|
||||
// Generate API key
|
||||
let api_key = User::set_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to set API key");
|
||||
assert!(!api_key.is_empty());
|
||||
assert!(api_key.starts_with("sk_"));
|
||||
|
||||
// Verify the API key was saved
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
assert_eq!(updated_user.api_key, Some(api_key.clone()));
|
||||
|
||||
// Test finding user by API key
|
||||
let found_user = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
assert_eq!(found_user.id, user.id);
|
||||
|
||||
// Revoke API key
|
||||
User::revoke_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to revoke API key");
|
||||
|
||||
// Verify API key was revoked
|
||||
let revoked_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(revoked_user.is_some());
|
||||
let revoked_user = revoked_user.unwrap();
|
||||
assert!(revoked_user.api_key.is_none());
|
||||
|
||||
// Test searching by revoked API key
|
||||
let not_found = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
assert!(not_found.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_password_update() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create a user
|
||||
let email = "pwd_test@example.com";
|
||||
let old_password = "old_password";
|
||||
let new_password = "new_password";
|
||||
|
||||
User::create_new(
|
||||
email.to_string(),
|
||||
old_password.to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
// Authenticate with old password
|
||||
let auth_result = User::authenticate(email, old_password, &db).await;
|
||||
assert!(auth_result.is_ok());
|
||||
|
||||
// Update password
|
||||
User::patch_password(email, new_password, &db)
|
||||
.await
|
||||
.expect("Failed to update password");
|
||||
|
||||
// Old password should no longer work
|
||||
let old_auth = User::authenticate(email, old_password, &db).await;
|
||||
assert!(old_auth.is_err());
|
||||
|
||||
// New password should work
|
||||
let new_auth = User::authenticate(email, new_password, &db).await;
|
||||
assert!(new_auth.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_timezone() {
|
||||
// Valid timezones should be accepted as-is
|
||||
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
|
||||
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
|
||||
assert_eq!(validate_timezone("Asia/Tokyo"), "Asia/Tokyo");
|
||||
assert_eq!(validate_timezone("UTC"), "UTC");
|
||||
|
||||
// Invalid timezones should be replaced with UTC
|
||||
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
|
||||
assert_eq!(validate_timezone("Not_Real"), "UTC");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_update() {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create user with default timezone
|
||||
let email = "timezone_test@example.com";
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
"password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
|
||||
assert_eq!(user.timezone, "UTC");
|
||||
|
||||
// Update timezone
|
||||
let new_timezone = "Europe/Paris";
|
||||
User::update_timezone(&user.id, new_timezone, &db)
|
||||
.await
|
||||
.expect("Failed to update timezone");
|
||||
|
||||
// Verify timezone was updated
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
assert_eq!(updated_user.timezone, new_timezone);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user