clippy: adhere to pedantic clippy, uniform test error handling

This commit is contained in:
Per Stark
2026-05-26 11:43:45 +02:00
parent 6a5d631287
commit 000852c94c
68 changed files with 2468 additions and 2547 deletions
+35 -28
View File
@@ -202,6 +202,7 @@ impl SurrealDbClient {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use crate::stored_object;
use super::*;
@@ -212,19 +213,17 @@ mod tests {
});
#[tokio::test]
async fn test_initialization_and_crud() {
async fn test_initialization_and_crud() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Call your initialization
db.apply_migrations()
.await
.expect("Failed to initialize schema");
.with_context(|| "Failed to initialize schema".to_string())?;
// Test basic CRUD
let dummy = Dummy {
id: "abc".to_string(),
name: "first".to_string(),
@@ -232,50 +231,50 @@ mod tests {
updated_at: Utc::now(),
};
// Store
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
let stored = db
.store_item(dummy.clone())
.await
.with_context(|| "Failed to store".to_string())?;
assert!(stored.is_some());
// Read
let fetched = db
.get_item::<Dummy>(&dummy.id)
.await
.expect("Failed to fetch");
.with_context(|| "Failed to fetch".to_string())?;
assert_eq!(fetched, Some(dummy.clone()));
// Read all
let all = db
.get_all_stored_items::<Dummy>()
.await
.expect("Failed to fetch all");
.with_context(|| "Failed to fetch all".to_string())?;
assert!(all.contains(&dummy));
// Delete
let deleted = db
.delete_item::<Dummy>(&dummy.id)
.await
.expect("Failed to delete");
.with_context(|| "Failed to delete".to_string())?;
assert_eq!(deleted, Some(dummy));
// After delete, should not be present
let fetch_post = db
.get_item::<Dummy>("abc")
.await
.expect("Failed fetch post delete");
.with_context(|| "Failed fetch post delete".to_string())?;
assert!(fetch_post.is_none());
Ok(())
}
#[tokio::test]
async fn upsert_item_overwrites_existing_records() {
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to initialize schema");
.with_context(|| "Failed to initialize schema".to_string())?;
let mut dummy = Dummy {
id: "abc".to_string(),
@@ -286,17 +285,21 @@ mod tests {
db.store_item(dummy.clone())
.await
.expect("Failed to store initial record");
.with_context(|| "Failed to store initial record".to_string())?;
dummy.name = "updated".to_string();
let upserted = db
.upsert_item(dummy.clone())
.await
.expect("Failed to upsert record");
.with_context(|| "Failed to upsert record".to_string())?;
assert!(upserted.is_some());
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert");
assert_eq!(fetched.unwrap().name, "updated");
let fetched: Option<Dummy> = db
.get_item(&dummy.id)
.await
.with_context(|| "fetch after upsert".to_string())?;
let fetched = fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?;
assert_eq!(fetched.name, "updated");
let new_record = Dummy {
id: "def".to_string(),
@@ -306,25 +309,29 @@ mod tests {
};
db.upsert_item(new_record.clone())
.await
.expect("Failed to upsert new record");
.with_context(|| "Failed to upsert new record".to_string())?;
let fetched_new: Option<Dummy> = db
.get_item(&new_record.id)
.await
.expect("fetch inserted via upsert");
.with_context(|| "fetch inserted via upsert".to_string())?;
assert_eq!(fetched_new, Some(new_record));
Ok(())
}
#[tokio::test]
async fn test_applying_migrations() {
async fn test_applying_migrations() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to build indexes");
.with_context(|| "Failed to build indexes".to_string())?;
Ok(())
}
}
+54 -57
View File
@@ -159,23 +159,23 @@ impl FtsIndexSpec {
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
pub async fn ensure_runtime_indexes(
pub async fn ensure_runtime(
db: &SurrealDbClient,
embedding_dimension: usize,
) -> Result<(), AppError> {
ensure_runtime_indexes_inner(db, embedding_dimension)
ensure_runtime_inner(db, embedding_dimension)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_indexes_inner(db)
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_inner(db)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}
async fn ensure_runtime_indexes_inner(
async fn ensure_runtime_inner(
db: &SurrealDbClient,
embedding_dimension: usize,
) -> Result<()> {
@@ -262,9 +262,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
.context("checking index status")?;
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
let info = match info {
Some(i) => i,
None => return Ok("unknown".to_string()),
let Some(info) = info else {
return Ok("unknown".to_string());
};
let building = info.get("building");
@@ -277,7 +276,7 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
Ok(status)
}
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
debug!("Rebuilding indexes with concurrent definitions");
create_fts_analyzer(db).await?;
@@ -385,10 +384,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
// an existing analyzer definition.
let snowball_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);",
analyzer = FTS_ANALYZER_NAME
FILTERS lowercase, ascii, snowball(english);"
);
match db.client.query(snowball_query).await {
@@ -410,10 +408,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
}
let fallback_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
TOKENIZERS class
FILTERS lowercase, ascii;",
analyzer = FTS_ANALYZER_NAME
FILTERS lowercase, ascii;"
);
let res = db
@@ -446,6 +443,7 @@ async fn create_index_with_polling(
table: &str,
progress_table: Option<&str>,
) -> Result<()> {
const MAX_ATTEMPTS: usize = 3;
let expected_total = match progress_table {
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
format!("counting rows in {table} for index {index_name} progress")
@@ -453,10 +451,9 @@ async fn create_index_with_polling(
None => None,
};
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 3;
let mut attempts: usize = 0;
loop {
attempts += 1;
attempts = attempts.saturating_add(1);
let res = db
.client
.query(definition.clone())
@@ -527,8 +524,8 @@ async fn poll_index_build_status(
break;
};
match snapshot.progress_pct {
Some(pct) => debug!(
if let Some(pct) = snapshot.progress_pct {
debug!(
index = %index_name,
table = %table,
status = snapshot.status,
@@ -539,8 +536,9 @@ async fn poll_index_build_status(
total = snapshot.total_rows,
progress_pct = format_args!("{pct:.1}"),
"Index build status"
),
None => debug!(
);
} else {
debug!(
index = %index_name,
table = %table,
status = snapshot.status,
@@ -549,7 +547,7 @@ async fn poll_index_build_status(
updated = snapshot.updated,
processed = snapshot.processed,
"Index build status"
),
);
}
if snapshot.is_ready() {
@@ -611,17 +609,17 @@ fn parse_index_build_info(
let initial = building
.and_then(|b| b.get("initial"))
.and_then(|v| v.as_u64())
.and_then(serde_json::Value::as_u64)
.unwrap_or(0);
let pending = building
.and_then(|b| b.get("pending"))
.and_then(|v| v.as_u64())
.and_then(serde_json::Value::as_u64)
.unwrap_or(0);
let updated = building
.and_then(|b| b.get("updated"))
.and_then(|v| v.as_u64())
.and_then(serde_json::Value::as_u64)
.unwrap_or(0);
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
@@ -631,7 +629,7 @@ fn parse_index_build_info(
if total == 0 {
0.0
} else {
((processed as f64 / total as f64).min(1.0)) * 100.0
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX)) / f64::from(u32::try_from(total).unwrap_or(1))).min(1.0)) * 100.0
}
});
@@ -673,7 +671,7 @@ async fn table_index_definitions(
.client
.query(info_query)
.await
.with_context(|| format!("fetching table info for {}", table))?;
.with_context(|| format!("fetching table info for {table}"))?;
let info: surrealdb::Value = response
.take(0)
@@ -700,12 +698,15 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
#[cfg(test)]
mod tests {
use super::*;
use anyhow::{self, Context};
use crate::storage::db::SurrealDbClient;
use serde_json::json;
use uuid::Uuid;
use super::*;
#[test]
fn parse_index_build_info_reports_progress() {
fn parse_index_build_info_reports_progress() -> anyhow::Result<()> {
let info = json!({
"building": {
"initial": 56894,
@@ -715,7 +716,8 @@ mod tests {
}
});
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
let snapshot = parse_index_build_info(Some(info), Some(61081))
.context("snapshot")?;
assert_eq!(
snapshot,
IndexBuildSnapshot {
@@ -729,16 +731,19 @@ mod tests {
}
);
assert!(!snapshot.is_ready());
Ok(())
}
#[test]
fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> {
// Surreal returns `{}` when the index exists but isn't building.
let info = json!({});
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
let snapshot = parse_index_build_info(Some(info), Some(10))
.context("snapshot")?;
assert!(snapshot.is_ready());
assert_eq!(snapshot.processed, 0);
assert_eq!(snapshot.progress_pct, Some(0.0));
Ok(())
}
#[test]
@@ -748,48 +753,40 @@ mod tests {
}
#[tokio::test]
async fn ensure_runtime_indexes_is_idempotent() {
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
let namespace = "indexes_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("in-memory db");
.context("in-memory db")?;
db.apply_migrations()
.await
.expect("migrations should succeed");
.context("migrations should succeed")?;
// First run creates everything
ensure_runtime_indexes(&db, 1536)
.await
.expect("initial index creation");
// Second run should be a no-op and still succeed
ensure_runtime_indexes(&db, 1536)
.await
.expect("second index creation");
ensure_runtime(&db, 1536).await
.context("first call should succeed")?;
ensure_runtime(&db, 1536).await
.context("second index creation")?;
Ok(())
}
#[tokio::test]
async fn ensure_hnsw_index_overwrites_dimension() {
async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
let namespace = "indexes_dim";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("in-memory db");
.context("in-memory db")?;
db.apply_migrations()
.await
.expect("migrations should succeed");
.context("migrations should succeed")?;
// Create initial index with default dimension
ensure_runtime_indexes(&db, 1536)
.await
.expect("initial index creation");
// Change dimension and ensure overwrite path is exercised
ensure_runtime_indexes(&db, 128)
.await
.expect("overwritten index creation");
ensure_runtime(&db, 1536).await
.context("initial index creation")?;
ensure_runtime(&db, 128).await
.context("overwritten index creation")?;
Ok(())
}
}
+142 -108
View File
@@ -13,13 +13,13 @@ use object_store::{path::Path as ObjPath, ObjectStore};
use crate::utils::config::{AppConfig, StorageKind};
pub type DynStore = Arc<dyn ObjectStore>;
pub type DynStorage = Arc<dyn ObjectStore>;
/// Storage manager with persistent state and proper lifecycle management.
#[derive(Clone)]
pub struct StorageManager {
// Store from objectstore wrapped as dyn
store: DynStore,
store: DynStorage,
// Simple enum to track which kind
backend_kind: StorageKind,
// Where on disk
@@ -46,7 +46,7 @@ impl StorageManager {
///
/// This method is useful for testing scenarios where you want to inject
/// a specific storage backend.
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
pub fn with_backend(store: DynStorage, backend_kind: StorageKind) -> Self {
Self {
store,
backend_kind,
@@ -216,7 +216,7 @@ impl StorageManager {
/// storage backends with proper error handling and validation.
async fn create_storage_backend(
cfg: &AppConfig,
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
match cfg.storage {
StorageKind::Local => {
let base = resolve_base_dir(cfg);
@@ -261,9 +261,7 @@ async fn create_storage_backend(
builder = builder.with_endpoint(endpoint);
}
if let Some(region) = &cfg.s3_region {
builder = builder.with_region(region);
}
builder = builder.with_region(&cfg.s3_region);
let store = builder.build()?;
Ok((Arc::new(store), None))
@@ -342,7 +340,7 @@ pub mod testing {
surrealdb_password: "test".into(),
surrealdb_namespace: "test".into(),
surrealdb_database: "test".into(),
data_dir: base.into(),
data_dir: base,
http_port: 0,
openai_base_url: "..".into(),
storage: StorageKind::Local,
@@ -382,7 +380,7 @@ pub mod testing {
#[derive(Clone)]
pub struct TestStorageManager {
storage: StorageManager,
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
}
impl TestStorageManager {
@@ -396,7 +394,7 @@ pub mod testing {
Ok(Self {
storage,
_temp_dir: None,
temp_dir: None,
})
}
@@ -413,7 +411,7 @@ pub mod testing {
Ok(Self {
storage,
_temp_dir: resolved,
temp_dir: resolved,
})
}
@@ -437,7 +435,7 @@ pub mod testing {
Ok(Self {
storage,
_temp_dir: None,
temp_dir: None,
})
}
@@ -454,7 +452,7 @@ pub mod testing {
Ok(Self {
storage,
_temp_dir: temp_dir,
temp_dir,
})
}
@@ -508,7 +506,7 @@ pub mod testing {
impl Drop for TestStorageManager {
fn drop(&mut self) {
// Clean up temporary directories for local storage
if let Some((_, path)) = &self._temp_dir {
if let Some((_, path)) = &self.temp_dir {
if path.exists() {
let _ = std::fs::remove_dir_all(path);
}
@@ -584,6 +582,7 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Context;
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
use bytes::Bytes;
use uuid::Uuid;
@@ -623,11 +622,11 @@ mod tests {
}
#[tokio::test]
async fn test_storage_manager_memory_basic_operations() {
async fn test_storage_manager_memory_basic_operations() -> anyhow::Result<()> {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
.with_context(|| "create storage manager".to_string())?;
assert!(storage.local_base_path().is_none());
let location = "test/data/file.txt";
@@ -637,31 +636,33 @@ mod tests {
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
let retrieved = storage.get(location).await.expect("get");
.with_context(|| "put".to_string())?;
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
// Test exists
assert!(storage.exists(location).await.expect("exists check"));
assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
// Test delete
storage.delete_prefix("test/data/").await.expect("delete");
storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
assert!(!storage
.exists(location)
.await
.expect("exists check after delete"));
.with_context(|| "exists check after delete".to_string())?);
Ok(())
}
#[tokio::test]
async fn test_storage_manager_local_basic_operations() {
async fn test_storage_manager_local_basic_operations() -> anyhow::Result<()> {
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
let cfg = test_config(&base);
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
.with_context(|| "create storage manager".to_string())?;
let resolved_base = storage
.local_base_path()
.expect("resolved base dir")
.with_context(|| "resolved base dir".to_string())?
.to_path_buf();
assert_eq!(resolved_base, PathBuf::from(&base));
@@ -672,42 +673,44 @@ mod tests {
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
let retrieved = storage.get(location).await.expect("get");
.with_context(|| "put".to_string())?;
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
let object_dir = resolved_base.join("test/data");
tokio::fs::metadata(&object_dir)
.await
.expect("object directory exists after write");
.with_context(|| "object directory exists after write".to_string())?;
// Test exists
assert!(storage.exists(location).await.expect("exists check"));
assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
// Test delete
storage.delete_prefix("test/data/").await.expect("delete");
storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
assert!(!storage
.exists(location)
.await
.expect("exists check after delete"));
.with_context(|| "exists check after delete".to_string())?);
assert!(
tokio::fs::metadata(&object_dir).await.is_err(),
"object directory should be removed"
);
tokio::fs::metadata(&resolved_base)
.await
.expect("base directory remains intact");
.with_context(|| "base directory remains intact".to_string())?;
// Clean up
let _ = tokio::fs::remove_dir_all(&base).await;
Ok(())
}
#[tokio::test]
async fn test_storage_manager_memory_persistence() {
async fn test_storage_manager_memory_persistence() -> anyhow::Result<()> {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
.with_context(|| "create storage manager".to_string())?;
let location = "persistence/test.txt";
let data1 = b"first data";
@@ -717,32 +720,34 @@ mod tests {
storage
.put(location, Bytes::from(data1.to_vec()))
.await
.expect("put first");
.with_context(|| "put first".to_string())?;
// Retrieve and verify first data
let retrieved1 = storage.get(location).await.expect("get first");
let retrieved1 = storage.get(location).await.with_context(|| "get first".to_string())?;
assert_eq!(retrieved1.as_ref(), data1);
// Overwrite with second data
storage
.put(location, Bytes::from(data2.to_vec()))
.await
.expect("put second");
.with_context(|| "put second".to_string())?;
// Retrieve and verify second data
let retrieved2 = storage.get(location).await.expect("get second");
let retrieved2 = storage.get(location).await.with_context(|| "get second".to_string())?;
assert_eq!(retrieved2.as_ref(), data2);
// Data persists across multiple operations using the same StorageManager
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
Ok(())
}
#[tokio::test]
async fn test_storage_manager_list_operations() {
async fn test_storage_manager_list_operations() -> anyhow::Result<()> {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
.with_context(|| "create storage manager".to_string())?;
// Create multiple files
let files = vec![
@@ -755,15 +760,15 @@ mod tests {
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
.with_context(|| "put".to_string())?;
}
// Test listing without prefix
let all_files = storage.list(None).await.expect("list all");
let all_files = storage.list(None).await.with_context(|| "list all".to_string())?;
assert_eq!(all_files.len(), 3);
// Test listing with prefix
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
let dir1_files = storage.list(Some("dir1/")).await.with_context(|| "list dir1".to_string())?;
assert_eq!(dir1_files.len(), 2);
assert!(dir1_files
.iter()
@@ -776,16 +781,18 @@ mod tests {
let empty_files = storage
.list(Some("nonexistent/"))
.await
.expect("list nonexistent");
.with_context(|| "list nonexistent".to_string())?;
assert_eq!(empty_files.len(), 0);
Ok(())
}
#[tokio::test]
async fn test_storage_manager_stream_operations() {
async fn test_storage_manager_stream_operations() -> anyhow::Result<()> {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
.with_context(|| "create storage manager".to_string())?;
let location = "stream/test.bin";
let content = vec![42u8; 1024 * 64]; // 64KB of data
@@ -794,22 +801,24 @@ mod tests {
storage
.put(location, Bytes::from(content.clone()))
.await
.expect("put large data");
.with_context(|| "put large data".to_string())?;
// Get as stream
let mut stream = storage.get_stream(location).await.expect("get stream");
let mut stream = storage.get_stream(location).await.with_context(|| "get stream".to_string())?;
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk.expect("stream chunk");
let chunk = chunk.with_context(|| "stream chunk".to_string())?;
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, content);
Ok(())
}
#[tokio::test]
async fn test_storage_manager_with_custom_backend() {
async fn test_storage_manager_with_custom_backend() -> anyhow::Result<()> {
use object_store::memory::InMemory;
// Create custom memory backend
@@ -823,20 +832,22 @@ mod tests {
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
let retrieved = storage.get(location).await.expect("get");
.with_context(|| "put".to_string())?;
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
assert!(storage.exists(location).await.expect("exists"));
assert!(storage.exists(location).await.with_context(|| "exists".to_string())?);
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
Ok(())
}
#[tokio::test]
async fn test_storage_manager_error_handling() {
async fn test_storage_manager_error_handling() -> anyhow::Result<()> {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
.with_context(|| "create storage manager".to_string())?;
// Test getting non-existent file
let result = storage.get("nonexistent.txt").await;
@@ -846,124 +857,136 @@ mod tests {
let exists = storage
.exists("nonexistent.txt")
.await
.expect("exists check");
.with_context(|| "exists check".to_string())?;
assert!(!exists);
// Test listing with invalid location (should not panic)
let _result = storage.get("").await;
// This may or may not error depending on the backend implementation
// The important thing is that it doesn't panic
Ok(())
}
// TestStorageManager tests
#[tokio::test]
async fn test_test_storage_manager_memory() {
async fn test_test_storage_manager_memory() -> anyhow::Result<()> {
let test_storage = testing::TestStorageManager::new_memory()
.await
.expect("create test storage");
.with_context(|| "create test storage".to_string())?;
let location = "test/storage/file.txt";
let data = b"test data with TestStorageManager";
// Test put and get
test_storage.put(location, data).await.expect("put");
let retrieved = test_storage.get(location).await.expect("get");
test_storage.put(location, data).await.with_context(|| "put".to_string())?;
let retrieved = test_storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
// Test existence check
assert!(test_storage.exists(location).await.expect("exists"));
assert!(test_storage.exists(location).await.with_context(|| "exists".to_string())?);
// Test list
let files = test_storage
.list(Some("test/storage/"))
.await
.expect("list");
.with_context(|| "list".to_string())?;
assert_eq!(files.len(), 1);
// Test delete
test_storage
.delete_prefix("test/storage/")
.await
.expect("delete");
.with_context(|| "delete".to_string())?;
assert!(!test_storage
.exists(location)
.await
.expect("exists after delete"));
.with_context(|| "exists after delete".to_string())?);
Ok(())
}
#[tokio::test]
async fn test_test_storage_manager_local() {
async fn test_test_storage_manager_local() -> anyhow::Result<()> {
let test_storage = testing::TestStorageManager::new_local()
.await
.expect("create test storage");
.with_context(|| "create test storage".to_string())?;
let location = "test/local/file.txt";
let data = b"test data with local TestStorageManager";
// Test put and get
test_storage.put(location, data).await.expect("put");
let retrieved = test_storage.get(location).await.expect("get");
test_storage.put(location, data).await
.with_context(|| "put".to_string())?;
let retrieved = test_storage.get(location).await
.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
// Test existence check
assert!(test_storage.exists(location).await.expect("exists"));
assert!(test_storage.exists(location).await
.with_context(|| "exists".to_string())?);
// The storage should be automatically cleaned up when test_storage is dropped
Ok(())
}
#[tokio::test]
async fn test_test_storage_manager_isolation() {
async fn test_test_storage_manager_isolation() -> anyhow::Result<()> {
let storage1 = testing::TestStorageManager::new_memory()
.await
.expect("create test storage 1");
.with_context(|| "create test storage 1".to_string())?;
let storage2 = testing::TestStorageManager::new_memory()
.await
.expect("create test storage 2");
.with_context(|| "create test storage 2".to_string())?;
let location = "isolation/test.txt";
let data1 = b"storage 1 data";
let data2 = b"storage 2 data";
// Put different data in each storage
storage1.put(location, data1).await.expect("put storage 1");
storage2.put(location, data2).await.expect("put storage 2");
storage1.put(location, data1).await
.with_context(|| "put storage 1".to_string())?;
storage2.put(location, data2).await
.with_context(|| "put storage 2".to_string())?;
// Verify isolation
let retrieved1 = storage1.get(location).await.expect("get storage 1");
let retrieved2 = storage2.get(location).await.expect("get storage 2");
let retrieved1 = storage1.get(location).await
.with_context(|| "get storage 1".to_string())?;
let retrieved2 = storage2.get(location).await
.with_context(|| "get storage 2".to_string())?;
assert_eq!(retrieved1.as_ref(), data1);
assert_eq!(retrieved2.as_ref(), data2);
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
Ok(())
}
#[tokio::test]
async fn test_test_storage_manager_config() {
async fn test_test_storage_manager_config() -> anyhow::Result<()> {
let cfg = testing::test_config_memory();
let test_storage = testing::TestStorageManager::with_config(&cfg)
.await
.expect("create test storage with config");
.with_context(|| "create test storage with config".to_string())?;
let location = "config/test.txt";
let data = b"test data with custom config";
test_storage.put(location, data).await.expect("put");
let retrieved = test_storage.get(location).await.expect("get");
test_storage.put(location, data).await
.with_context(|| "put".to_string())?;
let retrieved = test_storage.get(location).await
.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
// Verify it's using memory backend
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
Ok(())
}
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
#[tokio::test]
async fn test_storage_manager_s3_basic_operations() {
async fn test_storage_manager_s3_basic_operations() -> anyhow::Result<()> {
// Skip if S3 connection fails (e.g. no MinIO)
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
eprintln!("Skipping S3 test (setup failed)");
return;
return Ok(());
};
let prefix = format!("test-basic-{}", Uuid::new_v4());
@@ -973,31 +996,33 @@ mod tests {
// Test put
if let Err(e) = storage.put(&location, data).await {
eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
return;
return Ok(());
}
// Test get
let retrieved = storage.get(&location).await.expect("get");
let retrieved = storage.get(&location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data);
// Test exists
assert!(storage.exists(&location).await.expect("exists"));
assert!(storage.exists(&location).await.with_context(|| "exists".to_string())?);
// Test delete
storage
.delete_prefix(&format!("{prefix}/"))
.await
.expect("delete");
.with_context(|| "delete".to_string())?;
assert!(!storage
.exists(&location)
.await
.expect("exists after delete"));
.with_context(|| "exists after delete".to_string())?);
Ok(())
}
#[tokio::test]
async fn test_storage_manager_s3_list_operations() {
async fn test_storage_manager_s3_list_operations() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return;
return Ok(());
};
let prefix = format!("test-list-{}", Uuid::new_v4());
@@ -1009,23 +1034,25 @@ mod tests {
for (loc, data) in &files {
if storage.put(loc, *data).await.is_err() {
return; // Abort if put fails
return Ok(()); // Abort if put fails
}
}
// List with prefix
let list_prefix = format!("{prefix}/");
let items = storage.list(Some(&list_prefix)).await.expect("list");
let items = storage.list(Some(&list_prefix)).await.with_context(|| "list".to_string())?;
assert_eq!(items.len(), 3);
// Cleanup
storage.delete_prefix(&list_prefix).await.expect("cleanup");
storage.delete_prefix(&list_prefix).await.with_context(|| "cleanup".to_string())?;
Ok(())
}
#[tokio::test]
async fn test_storage_manager_s3_stream_operations() {
async fn test_storage_manager_s3_stream_operations() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return;
return Ok(());
};
let prefix = format!("test-stream-{}", Uuid::new_v4());
@@ -1033,38 +1060,45 @@ mod tests {
let content = vec![42u8; 1024 * 10]; // 10KB
if storage.put(&location, &content).await.is_err() {
return;
return Ok(());
}
let mut stream = storage.get_stream(&location).await.expect("get stream");
let mut stream = storage.get_stream(&location).await.with_context(|| "get stream".to_string())?;
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await {
collected.extend_from_slice(&chunk.expect("chunk"));
collected.extend_from_slice(&chunk.with_context(|| "chunk".to_string())?);
}
assert_eq!(collected, content);
storage
.delete_prefix(&format!("{prefix}/"))
.await
.expect("cleanup");
.with_context(|| "cleanup".to_string())?;
Ok(())
}
#[tokio::test]
async fn test_storage_manager_s3_backend_kind() {
async fn test_storage_manager_s3_backend_kind() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return;
return Ok(());
};
assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
Ok(())
}
#[tokio::test]
async fn test_storage_manager_s3_error_handling() {
async fn test_storage_manager_s3_error_handling() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return;
return Ok(());
};
let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
assert!(storage.get(&location).await.is_err());
assert!(!storage.exists(&location).await.expect("exists check"));
// exists may fail if S3 is unavailable; treat error as false
assert!(!storage.exists(&location).await.unwrap_or(false));
Ok(())
}
}
+45 -73
View File
@@ -90,6 +90,7 @@ impl Analytics {
mod tests {
use super::*;
use crate::stored_object;
use anyhow::{self};
use uuid::Uuid;
stored_object!(TestUser, "user", {
@@ -99,18 +100,14 @@ mod tests {
});
#[tokio::test]
async fn test_analytics_initialization() {
async fn test_analytics_initialization() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
// Test initialization of analytics
let analytics = Analytics::ensure_initialized(&db)
.await
.expect("Failed to initialize analytics");
let analytics = Analytics::ensure_initialized(&db).await?;
// Verify initial state after initialization
assert_eq!(analytics.id, "current");
@@ -118,159 +115,134 @@ mod tests {
assert_eq!(analytics.visitors, 0);
// Test idempotency - ensure calling it again doesn't change anything
let analytics_again = Analytics::ensure_initialized(&db)
.await
.expect("Failed to get analytics after initialization");
let analytics_again = Analytics::ensure_initialized(&db).await?;
assert_eq!(analytics.id, analytics_again.id);
assert_eq!(analytics.page_loads, analytics_again.page_loads);
assert_eq!(analytics.visitors, analytics_again.visitors);
Ok(())
}
#[tokio::test]
async fn test_get_current_analytics() {
async fn test_get_current_analytics() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
// Initialize analytics
Analytics::ensure_initialized(&db)
.await
.expect("Failed to initialize analytics");
Analytics::ensure_initialized(&db).await?;
// Test get_current method
let analytics = Analytics::get_current(&db)
.await
.expect("Failed to get current analytics");
let analytics = Analytics::get_current(&db).await?;
assert_eq!(analytics.id, "current");
assert_eq!(analytics.page_loads, 0);
assert_eq!(analytics.visitors, 0);
Ok(())
}
#[tokio::test]
async fn test_increment_visitors() {
async fn test_increment_visitors() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
// Initialize analytics
Analytics::ensure_initialized(&db)
.await
.expect("Failed to initialize analytics");
Analytics::ensure_initialized(&db).await?;
// Test increment_visitors method
let analytics = Analytics::increment_visitors(&db)
.await
.expect("Failed to increment visitors");
let analytics = Analytics::increment_visitors(&db).await?;
assert_eq!(analytics.visitors, 1);
assert_eq!(analytics.page_loads, 0);
// Increment again and check
let analytics = Analytics::increment_visitors(&db)
.await
.expect("Failed to increment visitors again");
let analytics = Analytics::increment_visitors(&db).await?;
assert_eq!(analytics.visitors, 2);
assert_eq!(analytics.page_loads, 0);
Ok(())
}
#[tokio::test]
async fn test_increment_page_loads() {
async fn test_increment_page_loads() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
// Initialize analytics
Analytics::ensure_initialized(&db)
.await
.expect("Failed to initialize analytics");
Analytics::ensure_initialized(&db).await?;
// Test increment_page_loads method
let analytics = Analytics::increment_page_loads(&db)
.await
.expect("Failed to increment page loads");
let analytics = Analytics::increment_page_loads(&db).await?;
assert_eq!(analytics.visitors, 0);
assert_eq!(analytics.page_loads, 1);
// Increment again and check
let analytics = Analytics::increment_page_loads(&db)
.await
.expect("Failed to increment page loads again");
let analytics = Analytics::increment_page_loads(&db).await?;
assert_eq!(analytics.visitors, 0);
assert_eq!(analytics.page_loads, 2);
Ok(())
}
#[tokio::test]
async fn test_get_users_amount() {
async fn test_get_users_amount() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
// Test with no users
let count = Analytics::get_users_amount(&db)
.await
.expect("Failed to get users amount");
let count = Analytics::get_users_amount(&db).await?;
assert_eq!(count, 0);
// Create a few test users
for i in 0..3 {
let user = TestUser {
id: format!("user{}", i),
email: format!("user{}@example.com", i),
id: format!("user{i}"),
email: format!("user{i}@example.com"),
password: "password".to_string(),
user_id: format!("uid{}", i),
user_id: format!("uid{i}"),
created_at: Utc::now(),
updated_at: Utc::now(),
};
db.store_item(user)
.await
.expect("Failed to create test user");
db.store_item(user).await?;
}
// Test users amount after adding users
let count = Analytics::get_users_amount(&db)
.await
.expect("Failed to get users amount after adding users");
let count = Analytics::get_users_amount(&db).await?;
assert_eq!(count, 3);
Ok(())
}
#[tokio::test]
async fn test_get_current_nonexistent() {
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = SurrealDbClient::memory(namespace, database).await?;
// Don't initialize analytics and try to get it
let result = Analytics::get_current(&db).await;
assert!(result.is_err());
if let Err(err) = result {
match err {
AppError::NotFound(_) => {
// Expected error
}
_ => panic!("Expected NotFound error, got: {:?}", err),
}
match result {
Ok(_) => anyhow::bail!("Expected NotFound error, got success"),
Err(AppError::NotFound(_)) => {}
Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"),
}
Ok(())
}
}
+53 -60
View File
@@ -144,76 +144,71 @@ impl Conversation {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use crate::storage::types::message::MessageRole;
use super::*;
#[tokio::test]
async fn test_create_conversation() {
// Setup in-memory database for testing
async fn test_create_conversation() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create a new conversation
let user_id = "test_user";
let title = "Test Conversation";
let conversation = Conversation::new(user_id.to_string(), title.to_string());
// Verify conversation properties
assert_eq!(conversation.user_id, user_id);
assert_eq!(conversation.title, title);
assert!(!conversation.id.is_empty());
// Store the conversation
let result = db.store_item(conversation.clone()).await;
assert!(result.is_ok());
// Verify it can be retrieved
let retrieved: Option<Conversation> = db
.get_item(&conversation.id)
.await
.expect("Failed to retrieve conversation");
assert!(retrieved.is_some());
.with_context(|| "Failed to retrieve conversation".to_string())?;
let retrieved = retrieved.unwrap();
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?;
assert_eq!(retrieved.id, conversation.id);
assert_eq!(retrieved.user_id, user_id);
assert_eq!(retrieved.title, title);
Ok(())
}
#[tokio::test]
async fn test_get_complete_conversation_not_found() {
// Setup in-memory database for testing
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to get a conversation that doesn't exist
let result =
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
assert!(result.is_err());
match result {
Err(AppError::NotFound(_)) => { /* expected error */ }
_ => panic!("Expected NotFound error"),
Err(AppError::NotFound(_)) => {}
_ => anyhow::bail!("Expected NotFound error"),
}
Ok(())
}
#[tokio::test]
async fn test_get_complete_conversation_unauthorized() {
// Setup in-memory database for testing
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create and store a conversation for user_id_1
let user_id_1 = "user_1";
let conversation =
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
@@ -221,27 +216,28 @@ mod tests {
db.store_item(conversation)
.await
.expect("Failed to store conversation");
.with_context(|| "Failed to store conversation".to_string())?;
// Try to access with a different user
let user_id_2 = "user_2";
let result =
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
assert!(result.is_err());
match result {
Err(AppError::Auth(_)) => { /* expected error */ }
_ => panic!("Expected Auth error"),
Err(AppError::Auth(_)) => {}
_ => anyhow::bail!("Expected Auth error"),
}
Ok(())
}
#[tokio::test]
async fn test_patch_title_success() {
async fn test_patch_title_success() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let user_id = "user_1";
let original_title = "Original Title";
@@ -250,49 +246,50 @@ mod tests {
db.store_item(conversation)
.await
.expect("Failed to store conversation");
.with_context(|| "Failed to store conversation".to_string())?;
let new_title = "Updated Title";
// Patch title successfully
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
assert!(result.is_ok());
// Retrieve from DB to verify
let updated_conversation = db
.get_item::<Conversation>(&conversation_id)
.await
.expect("Failed to get conversation")
.expect("Conversation missing");
.with_context(|| "Failed to get conversation".to_string())?
.ok_or_else(|| anyhow::anyhow!("Conversation missing"))?;
assert_eq!(updated_conversation.title, new_title);
assert_eq!(updated_conversation.user_id, user_id);
Ok(())
}
#[tokio::test]
async fn test_patch_title_not_found() {
async fn test_patch_title_not_found() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to patch non-existing conversation
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
assert!(result.is_err());
match result {
Err(AppError::NotFound(_)) => {}
_ => panic!("Expected NotFound error"),
_ => anyhow::bail!("Expected NotFound error"),
}
Ok(())
}
#[tokio::test]
async fn test_patch_title_unauthorized() {
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let owner_id = "owner";
let other_user_id = "intruder";
@@ -301,17 +298,18 @@ mod tests {
db.store_item(conversation)
.await
.expect("Failed to store conversation");
.with_context(|| "Failed to store conversation".to_string())?;
// Attempt patch with unauthorized user
let result =
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
assert!(result.is_err());
match result {
Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"),
_ => anyhow::bail!("Expected Auth error"),
}
Ok(())
}
#[tokio::test]
@@ -405,24 +403,21 @@ mod tests {
}
#[tokio::test]
async fn test_get_complete_conversation_with_messages() {
// Setup in-memory database for testing
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create and store a conversation for user_id_1
let user_id_1 = "user_1";
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
let conversation_id = conversation.id.clone();
db.store_item(conversation)
.await
.expect("Failed to store conversation");
.with_context(|| "Failed to store conversation".to_string())?;
// Create messages
let message1 = Message::new(
conversation_id.clone(),
MessageRole::User,
@@ -442,46 +437,44 @@ mod tests {
None,
);
// Store messages
db.store_item(message1)
.await
.expect("Failed to store message1");
.with_context(|| "Failed to store message1".to_string())?;
db.store_item(message2)
.await
.expect("Failed to store message2");
.with_context(|| "Failed to store message2".to_string())?;
db.store_item(message3)
.await
.expect("Failed to store message3");
.with_context(|| "Failed to store message3".to_string())?;
// Retrieve the complete conversation
let result =
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
assert!(result.is_ok(), "Failed to retrieve complete conversation");
let (retrieved_conversation, messages) = result.unwrap();
let (retrieved_conversation, retrieved_messages) = result
.with_context(|| "Failed to retrieve complete conversation".to_string())?;
// Verify conversation data
assert_eq!(retrieved_conversation.id, conversation_id);
assert_eq!(retrieved_conversation.user_id, user_id_1);
assert_eq!(retrieved_conversation.title, "Conversation");
// Verify messages
assert_eq!(messages.len(), 3);
assert_eq!(retrieved_messages.len(), 3);
// Verify messages are sorted by updated_at
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
let message_contents: Vec<&str> =
retrieved_messages.iter().map(|m| m.content.as_str()).collect();
assert!(message_contents.contains(&"Hello, AI!"));
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
assert!(message_contents.contains(&"Tell me about Rust programming."));
// Make sure we can't access with different user
let user_id_2 = "user_2";
let unauthorized_result =
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
assert!(unauthorized_result.is_err());
match unauthorized_result {
Err(AppError::Auth(_)) => { /* expected error */ }
_ => panic!("Expected Auth error"),
Err(AppError::Auth(_)) => {}
_ => anyhow::bail!("Expected Auth error"),
}
Ok(())
}
}
+152 -145
View File
@@ -320,6 +320,8 @@ impl FileInfo {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::store::testing::TestStorageManager;
use axum::http::HeaderMap;
@@ -328,11 +330,11 @@ mod tests {
use tempfile::NamedTempFile;
/// Creates a test temporary file with the given content
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
fn create_test_file(content: &[u8], file_name: &str) -> anyhow::Result<FieldData<NamedTempFile>> {
let mut temp_file = NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?;
temp_file
.write_all(content)
.expect("Failed to write to temp file");
.with_context(|| "Failed to write to temp file".to_string())?;
let metadata = FieldMetadata {
name: Some("file".to_string()),
@@ -341,31 +343,29 @@ mod tests {
headers: HeaderMap::default(),
};
let field_data = FieldData {
Ok(FieldData {
metadata,
contents: temp_file,
};
field_data
})
}
#[tokio::test]
async fn test_fileinfo_create_read_delete_with_storage_manager() {
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager operations";
let file_name = "storage_manager_test.txt";
let field_data = create_test_file(content, file_name);
let field_data = create_test_file(content, file_name)?;
// Create test storage manager (memory backend)
let test_storage = store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test storage manager");
.with_context(|| "Failed to create test storage manager".to_string())?;
// Create a FileInfo instance with storage manager
let user_id = "test_user";
@@ -374,20 +374,20 @@ mod tests {
let file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await
.expect("Failed to create file with StorageManager");
.with_context(|| "Failed to create file with StorageManager".to_string())?;
assert_eq!(file_info.file_name, file_name);
// Verify the file exists via StorageManager and has correct content
let bytes = file_info
.get_content_with_storage(test_storage.storage())
.await
.expect("Failed to read file content via StorageManager");
.with_context(|| "Failed to read file content via StorageManager".to_string())?;
assert_eq!(bytes.as_ref(), content);
// Test file reading
let retrieved = FileInfo::get_by_id(&file_info.id, &db)
.await
.expect("Failed to retrieve file info");
.with_context(|| "Failed to retrieve file info".to_string())?;
assert_eq!(retrieved.id, file_info.id);
assert_eq!(retrieved.sha256, file_info.sha256);
assert_eq!(retrieved.file_name, file_name);
@@ -395,65 +395,65 @@ mod tests {
// Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage())
.await
.expect("Failed to delete file with StorageManager");
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
let deleted_result = file_info
.get_content_with_storage(test_storage.storage())
.await;
assert!(deleted_result.is_err(), "File should be deleted");
// No cleanup needed - TestStorageManager handles it automatically
Ok(())
}
#[tokio::test]
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() {
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"filename sanitization";
let original_name = "Complex name (1).txt";
let expected_sanitized = "Complex_name__1_.txt";
let field_data = create_test_file(content, original_name);
let field_data = create_test_file(content, original_name)?;
let test_storage = store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test storage manager");
.with_context(|| "Failed to create test storage manager".to_string())?;
let file_info =
FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage())
.await
.expect("Failed to create file via storage manager");
.with_context(|| "Failed to create file via storage manager".to_string())?;
assert_eq!(file_info.file_name, original_name);
let stored_name = Path::new(&file_info.path)
.file_name()
.and_then(|name| name.to_str())
.expect("stored name");
.with_context(|| "stored name".to_string())?;
assert_eq!(stored_name, expected_sanitized);
Ok(())
}
#[tokio::test]
async fn test_fileinfo_duplicate_detection_with_storage_manager() {
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager duplicate detection";
let file_name = "storage_manager_duplicate.txt";
let field_data = create_test_file(content, file_name);
let field_data = create_test_file(content, file_name)?;
// Create test storage manager
let test_storage = store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test storage manager");
.with_context(|| "Failed to create test storage manager".to_string())?;
// Create a FileInfo instance with storage manager
let user_id = "test_user";
@@ -462,17 +462,17 @@ mod tests {
let original_file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await
.expect("Failed to create original file with StorageManager");
.with_context(|| "Failed to create original file with StorageManager".to_string())?;
// Create another file with the same content but different name
let duplicate_name = "storage_manager_duplicate_2.txt";
let field_data2 = create_test_file(content, duplicate_name);
let field_data2 = create_test_file(content, duplicate_name)?;
// The system should detect it's the same file and return the original FileInfo
let duplicate_file_info =
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
.await
.expect("Failed to process duplicate file with StorageManager");
.with_context(|| "Failed to process duplicate file with StorageManager".to_string())?;
// Verify duplicate detection worked
assert_eq!(duplicate_file_info.id, original_file_info.id);
@@ -484,46 +484,44 @@ mod tests {
let original_content = original_file_info
.get_content_with_storage(test_storage.storage())
.await
.unwrap();
.with_context(|| "get original content".to_string())?;
let duplicate_content = duplicate_file_info
.get_content_with_storage(test_storage.storage())
.await
.unwrap();
.with_context(|| "get duplicate content".to_string())?;
assert_eq!(original_content.as_ref(), content);
assert_eq!(duplicate_content.as_ref(), content);
// Clean up
FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage())
.await
.expect("Failed to delete original file with StorageManager");
.with_context(|| "Failed to delete original file with StorageManager".to_string())?;
Ok(())
}
#[tokio::test]
async fn test_file_creation() {
async fn test_file_creation() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let content = b"This is a test file content";
let file_name = "test_file.txt";
let field_data = create_test_file(content, file_name);
let field_data = create_test_file(content, file_name)?;
// Create a FileInfo instance with StorageManager
let user_id = "test_user";
let test_storage = TestStorageManager::new_memory()
.await
.expect("create test storage manager");
.with_context(|| "create test storage manager".to_string())?;
let file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()).await;
// Verify the FileInfo was created successfully
assert!(file_info.is_ok());
let file_info = file_info.unwrap();
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await?;
// Check essential properties
assert!(!file_info.id.is_empty());
@@ -533,32 +531,32 @@ mod tests {
// path should be logical: "user_id/uuid/file_name"
let parts: Vec<&str> = file_info.path.split('/').collect();
assert_eq!(parts.len(), 3);
assert_eq!(parts[0], user_id);
assert_eq!(parts[2], file_name);
assert_eq!(parts.first(), Some(&user_id));
assert_eq!(parts.get(2), Some(&file_name));
assert!(file_info.mime_type.contains("text/plain"));
// Verify it's in the database
let stored: Option<FileInfo> = db
.get_item(&file_info.id)
let stored = db
.get_item::<FileInfo>(&file_info.id)
.await
.expect("Failed to retrieve file info");
assert!(stored.is_some());
let stored = stored.unwrap();
.with_context(|| "Failed to retrieve file info".to_string())?
.with_context(|| "expected stored file".to_string())?;
assert_eq!(stored.id, file_info.id);
assert_eq!(stored.file_name, file_info.file_name);
assert_eq!(stored.sha256, file_info.sha256);
Ok(())
}
#[tokio::test]
async fn test_file_duplicate_detection() {
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
// First, store a file with known content
let content = b"This is a test file for duplicate detection";
@@ -567,23 +565,23 @@ mod tests {
let test_storage = TestStorageManager::new_memory()
.await
.expect("create test storage manager");
.with_context(|| "create test storage manager".to_string())?;
let field_data1 = create_test_file(content, file_name);
let field_data1 = create_test_file(content, file_name)?;
let original_file_info =
FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage())
.await
.expect("Failed to create original file");
.with_context(|| "Failed to create original file".to_string())?;
// Now try to store another file with the same content but different name
let duplicate_name = "duplicate.txt";
let field_data2 = create_test_file(content, duplicate_name);
let field_data2 = create_test_file(content, duplicate_name)?;
// The system should detect it's the same file and return the original FileInfo
let duplicate_file_info =
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
.await
.expect("Failed to process duplicate file");
.with_context(|| "Failed to process duplicate file".to_string())?;
// The returned FileInfo should match the original
assert_eq!(duplicate_file_info.id, original_file_info.id);
@@ -592,10 +590,11 @@ mod tests {
// But it should retain the original file name, not the duplicate's name
assert_eq!(duplicate_file_info.file_name, file_name);
assert_ne!(duplicate_file_info.file_name, duplicate_name);
Ok(())
}
#[tokio::test]
async fn test_guess_mime_type() {
async fn test_guess_mime_type() -> anyhow::Result<()> {
// Test common file extensions
assert_eq!(
FileInfo::guess_mime_type(Path::new("test.txt")),
@@ -619,10 +618,11 @@ mod tests {
FileInfo::guess_mime_type(Path::new("unknown.929yz")),
"application/octet-stream".to_string()
);
Ok(())
}
#[tokio::test]
async fn test_sanitize_file_name() {
async fn test_sanitize_file_name() -> anyhow::Result<()> {
// Safe characters should remain unchanged
assert_eq!(
FileInfo::sanitize_file_name("normal_file.txt"),
@@ -647,26 +647,26 @@ mod tests {
FileInfo::sanitize_file_name("../dangerous.txt"),
"___dangerous.txt"
);
Ok(())
}
#[tokio::test]
async fn test_get_by_sha_not_found() {
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to find a file with a SHA that doesn't exist
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
assert!(result.is_err());
match result {
Err(FileError::FileNotFound(_)) => {
// Expected error
}
_ => panic!("Expected FileNotFound error"),
Err(FileError::FileNotFound(_)) => {}
_ => anyhow::bail!("Expected FileNotFound error"),
}
Ok(())
}
#[tokio::test]
@@ -705,7 +705,7 @@ mod tests {
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create a FileInfo instance directly
let now = Utc::now();
@@ -725,40 +725,39 @@ mod tests {
assert!(result.is_ok());
// Verify it can be retrieved
let retrieved: Option<FileInfo> = db
.get_item(&file_info.id)
let retrieved = db
.get_item::<FileInfo>(&file_info.id)
.await
.expect("Failed to retrieve file info");
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
.with_context(|| "Failed to retrieve file info".to_string())?
.with_context(|| "expected file".to_string())?;
assert_eq!(retrieved.id, file_info.id);
assert_eq!(retrieved.sha256, file_info.sha256);
assert_eq!(retrieved.file_name, file_info.file_name);
assert_eq!(retrieved.path, file_info.path);
assert_eq!(retrieved.mime_type, file_info.mime_type);
Ok(())
}
#[tokio::test]
async fn test_delete_by_id() {
async fn test_delete_by_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
// Create and persist a test file via FileInfo::new_with_storage
let user_id = "user123";
let test_storage = TestStorageManager::new_memory()
.await
.expect("create test storage manager");
let temp = create_test_file(b"test content", "test_file.txt");
.with_context(|| "create test storage manager".to_string())?;
let temp = create_test_file(b"test content", "test_file.txt")?;
let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage())
.await
.expect("create file");
.with_context(|| "create file".to_string())?;
// Delete the file using StorageManager
let delete_result =
@@ -767,15 +766,14 @@ mod tests {
// Delete should be successful
assert!(
delete_result.is_ok(),
"Failed to delete file: {:?}",
delete_result
"Failed to delete file: {delete_result:?}"
);
// Verify the file is removed from the database
let retrieved: Option<FileInfo> = db
.get_item(&file_info.id)
.await
.expect("Failed to query database");
.with_context(|| "Failed to query database".to_string())?;
assert!(
retrieved.is_none(),
"FileInfo should be deleted from the database"
@@ -783,32 +781,37 @@ mod tests {
// Verify content no longer retrievable from storage
assert!(test_storage.storage().get(&file_info.path).await.is_err());
Ok(())
}
#[tokio::test]
async fn test_delete_by_id_not_found() {
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to delete a file that doesn't exist
let test_storage = TestStorageManager::new_memory().await.unwrap();
let test_storage = TestStorageManager::new_memory()
.await
.with_context(|| "create test storage manager".to_string())?;
let result =
FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage())
.await;
// Should succeed even if the file record does not exist
assert!(result.is_ok());
Ok(())
}
#[tokio::test]
async fn test_get_by_id() {
async fn test_get_by_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create a FileInfo instance directly
let now = Utc::now();
@@ -827,28 +830,27 @@ mod tests {
// Store it in the database
db.store_item(original_file_info.clone())
.await
.expect("Failed to store item for get_by_id test");
.with_context(|| "Failed to store item for get_by_id test".to_string())?;
// Retrieve it using get_by_id
let result = FileInfo::get_by_id(&file_id, &db).await;
// Assert success and content match
assert!(result.is_ok());
let retrieved_info = result.unwrap();
let retrieved_info = FileInfo::get_by_id(&file_id, &db)
.await
.with_context(|| "get_by_id".to_string())?;
assert_eq!(retrieved_info.id, original_file_info.id);
assert_eq!(retrieved_info.sha256, original_file_info.sha256);
assert_eq!(retrieved_info.file_name, original_file_info.file_name);
assert_eq!(retrieved_info.path, original_file_info.path);
assert_eq!(retrieved_info.mime_type, original_file_info.mime_type);
Ok(())
}
#[tokio::test]
async fn test_get_by_id_not_found() {
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to retrieve a non-existent ID
let non_existent_id = "non-existent-file-id";
@@ -862,33 +864,34 @@ mod tests {
Err(FileError::FileNotFound(id)) => {
assert_eq!(id, non_existent_id);
}
Err(e) => panic!("Expected FileNotFound error, but got {:?}", e),
Ok(_) => panic!("Expected an error, but got Ok"),
Err(e) => anyhow::bail!("Expected FileNotFound error, but got {e:?}"),
Ok(_) => anyhow::bail!("Expected an error, but got Ok"),
}
Ok(())
}
// StorageManager-based tests
#[tokio::test]
async fn test_file_info_new_with_storage_memory() {
async fn test_file_info_new_with_storage_memory() -> anyhow::Result<()> {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory")
.await
.unwrap();
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager";
let field_data = create_test_file(content, "test_storage.txt");
let field_data = create_test_file(content, "test_storage.txt")?;
let user_id = "test_user";
// Create test storage manager
let storage = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
.with_context(|| "create test storage".to_string())?;
// Test file creation with StorageManager
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await
.expect("Failed to create file with StorageManager");
.with_context(|| "Failed to create file with StorageManager".to_string())?;
// Verify the file was created correctly
assert_eq!(file_info.user_id, user_id);
@@ -900,40 +903,41 @@ mod tests {
let retrieved_content = file_info
.get_content_with_storage(storage.storage())
.await
.expect("Failed to get file content with StorageManager");
.with_context(|| "Failed to get file content with StorageManager".to_string())?;
assert_eq!(retrieved_content.as_ref(), content);
// Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
.await
.expect("Failed to delete file with StorageManager");
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
// Verify file is deleted
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
assert!(deleted_content_result.is_err());
Ok(())
}
#[tokio::test]
async fn test_file_info_new_with_storage_local() {
async fn test_file_info_new_with_storage_local() -> anyhow::Result<()> {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_storage_local")
.await
.unwrap();
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager with local storage";
let field_data = create_test_file(content, "test_local.txt");
let field_data = create_test_file(content, "test_local.txt")?;
let user_id = "test_user";
// Create test storage manager with local backend
let storage = store::testing::TestStorageManager::new_local()
.await
.unwrap();
.with_context(|| "create test storage".to_string())?;
// Test file creation with StorageManager
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await
.expect("Failed to create file with StorageManager");
.with_context(|| "Failed to create file with StorageManager".to_string())?;
// Verify the file was created correctly
assert_eq!(file_info.user_id, user_id);
@@ -945,50 +949,51 @@ mod tests {
let retrieved_content = file_info
.get_content_with_storage(storage.storage())
.await
.expect("Failed to get file content with StorageManager");
.with_context(|| "Failed to get file content with StorageManager".to_string())?;
assert_eq!(retrieved_content.as_ref(), content);
// Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
.await
.expect("Failed to delete file with StorageManager");
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
// Verify file is deleted
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
assert!(deleted_content_result.is_err());
Ok(())
}
#[tokio::test]
async fn test_file_info_storage_manager_persistence() {
async fn test_file_info_storage_manager_persistence() -> anyhow::Result<()> {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_persistence")
.await
.unwrap();
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"Test content for persistence";
let field_data = create_test_file(content, "persistence_test.txt");
let field_data = create_test_file(content, "persistence_test.txt")?;
let user_id = "test_user";
// Create test storage manager
let storage = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
.with_context(|| "create test storage".to_string())?;
// Create file
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await
.expect("Failed to create file");
.with_context(|| "Failed to create file".to_string())?;
// Test that data persists across multiple operations with the same StorageManager
let retrieved_content_1 = file_info
.get_content_with_storage(storage.storage())
.await
.unwrap();
.with_context(|| "get content 1".to_string())?;
let retrieved_content_2 = file_info
.get_content_with_storage(storage.storage())
.await
.unwrap();
.with_context(|| "get content 2".to_string())?;
assert_eq!(retrieved_content_1.as_ref(), content);
assert_eq!(retrieved_content_2.as_ref(), content);
@@ -996,68 +1001,70 @@ mod tests {
// Test that different StorageManager instances don't share data (memory storage isolation)
let storage2 = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
.with_context(|| "create second storage".to_string())?;
let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await;
assert!(
isolated_content_result.is_err(),
"Different StorageManager should not have access to same data"
);
Ok(())
}
#[tokio::test]
async fn test_file_info_storage_manager_equivalence() {
async fn test_file_info_storage_manager_equivalence() -> anyhow::Result<()> {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_equivalence")
.await
.unwrap();
db.apply_migrations().await.unwrap();
.with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"Test content for equivalence testing";
let field_data1 = create_test_file(content, "equivalence_test_1.txt");
let field_data2 = create_test_file(content, "equivalence_test_2.txt");
let field_data1 = create_test_file(content, "equivalence_test_1.txt")?;
let field_data2 = create_test_file(content, "equivalence_test_2.txt")?;
let user_id = "test_user";
// Create single storage manager and reuse it
let storage_manager = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
.with_context(|| "create storage".to_string())?;
let storage = storage_manager.storage();
// Create multiple files with the same storage manager
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, &storage)
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, storage)
.await
.expect("Failed to create file 1");
.with_context(|| "Failed to create file 1".to_string())?;
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, &storage)
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, storage)
.await
.expect("Failed to create file 2");
.with_context(|| "Failed to create file 2".to_string())?;
// Test that both files can be retrieved with the same storage backend
let content_1 = file_info_1
.get_content_with_storage(&storage)
.get_content_with_storage(storage)
.await
.unwrap();
.with_context(|| "get file 1 content".to_string())?;
let content_2 = file_info_2
.get_content_with_storage(&storage)
.get_content_with_storage(storage)
.await
.unwrap();
.with_context(|| "get file 2 content".to_string())?;
assert_eq!(content_1.as_ref(), content);
assert_eq!(content_2.as_ref(), content);
// Test that files can be deleted with the same storage manager
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, &storage)
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, storage)
.await
.unwrap();
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, &storage)
.with_context(|| "delete file 1".to_string())?;
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, storage)
.await
.unwrap();
.with_context(|| "delete file 2".to_string())?;
// Verify files are deleted
let deleted_content_1 = file_info_1.get_content_with_storage(&storage).await;
let deleted_content_2 = file_info_2.get_content_with_storage(&storage).await;
let deleted_content_1 = file_info_1.get_content_with_storage(storage).await;
let deleted_content_2 = file_info_2.get_content_with_storage(storage).await;
assert!(deleted_content_1.is_err());
assert!(deleted_content_2.is_err());
Ok(())
}
}
+31 -24
View File
@@ -103,6 +103,7 @@ impl IngestionPayload {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use chrono::Utc;
use super::*;
@@ -131,7 +132,7 @@ mod tests {
}
#[test]
fn test_create_ingestion_payload_with_url() {
fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> {
let url = "https://example.com";
let context = "Process this URL";
let category = "websites";
@@ -145,10 +146,10 @@ mod tests {
files,
user_id,
)
.unwrap();
.with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1);
match &result[0] {
match result.first().context("expected one result")? {
IngestionPayload::Url {
url: payload_url,
context: payload_context,
@@ -156,17 +157,18 @@ mod tests {
user_id: payload_user_id,
} => {
// URL parser may normalize the URL by adding a trailing slash
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
assert_eq!(payload_context, &context);
assert_eq!(payload_category, &category);
assert_eq!(payload_user_id, &user_id);
}
_ => panic!("Expected Url variant"),
_ => anyhow::bail!("Expected Url variant"),
}
Ok(())
}
#[test]
fn test_create_ingestion_payload_with_text() {
fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> {
let text = "This is some text content";
let context = "Process this text";
let category = "notes";
@@ -180,10 +182,10 @@ mod tests {
files,
user_id,
)
.unwrap();
.with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1);
match &result[0] {
match result.first().context("expected one result")? {
IngestionPayload::Text {
text: payload_text,
context: payload_context,
@@ -195,12 +197,13 @@ mod tests {
assert_eq!(payload_category, category);
assert_eq!(payload_user_id, user_id);
}
_ => panic!("Expected Text variant"),
_ => anyhow::bail!("Expected Text variant"),
}
Ok(())
}
#[test]
fn test_create_ingestion_payload_with_file() {
fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> {
let context = "Process this file";
let category = "documents";
let user_id = "user123";
@@ -220,10 +223,10 @@ mod tests {
files,
user_id,
)
.unwrap();
.with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1);
match &result[0] {
match result.first().context("expected one result")? {
IngestionPayload::File {
file_info: payload_file_info,
context: payload_context,
@@ -235,12 +238,13 @@ mod tests {
assert_eq!(payload_category, category);
assert_eq!(payload_user_id, user_id);
}
_ => panic!("Expected File variant"),
_ => anyhow::bail!("Expected File variant"),
}
Ok(())
}
#[test]
fn test_create_ingestion_payload_with_url_and_file() {
fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> {
let url = "https://example.com";
let context = "Process this data";
let category = "mixed";
@@ -261,35 +265,36 @@ mod tests {
files,
user_id,
)
.unwrap();
.with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 2);
// Check first item is URL
match &result[0] {
match result.first().context("expected first item")? {
IngestionPayload::Url {
url: payload_url, ..
} => {
// URL parser may normalize the URL by adding a trailing slash
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
}
_ => panic!("Expected first item to be Url variant"),
_ => anyhow::bail!("Expected first item to be Url variant"),
}
// Check second item is File
match &result[1] {
match result.get(1).context("expected second item")? {
IngestionPayload::File {
file_info: payload_file_info,
..
} => {
assert_eq!(payload_file_info.id, file_info.id);
}
_ => panic!("Expected second item to be File variant"),
_ => anyhow::bail!("Expected second item to be File variant"),
}
Ok(())
}
#[test]
fn test_create_ingestion_payload_empty_input() {
fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> {
let context = "Process something";
let category = "empty";
let user_id = "user123";
@@ -308,12 +313,13 @@ mod tests {
Err(AppError::NotFound(msg)) => {
assert_eq!(msg, "No valid content or files provided");
}
_ => panic!("Expected NotFound error"),
_ => anyhow::bail!("Expected NotFound error"),
}
Ok(())
}
#[test]
fn test_create_ingestion_payload_with_empty_text() {
fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> {
let text = ""; // Empty text
let context = "Process this";
let category = "notes";
@@ -333,7 +339,8 @@ mod tests {
Err(AppError::NotFound(msg)) => {
assert_eq!(msg, "No valid content or files provided");
}
_ => panic!("Expected NotFound error"),
_ => anyhow::bail!("Expected NotFound error"),
}
Ok(())
}
}
+44 -35
View File
@@ -529,6 +529,8 @@ impl IngestionTask {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::types::ingestion_payload::IngestionPayload;
@@ -541,16 +543,16 @@ mod tests {
}
}
async fn memory_db() -> SurrealDbClient {
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
SurrealDbClient::memory(namespace, &database)
.await
.expect("in-memory surrealdb")
.with_context(|| "in-memory surrealdb".to_string())
}
#[tokio::test]
async fn test_new_task_defaults() {
async fn test_new_task_defaults() -> anyhow::Result<()> {
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string());
@@ -562,73 +564,76 @@ mod tests {
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
assert!(task.locked_at.is_none());
assert!(task.worker_id.is_none());
Ok(())
}
#[tokio::test]
async fn test_create_and_store_task() {
let db = memory_db().await;
async fn test_create_and_store_task() -> anyhow::Result<()> {
let db = memory_db().await?;
let user_id = "user123";
let payload = create_payload(user_id);
let created =
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
.await
.expect("store");
.with_context(|| "store".to_string())?;
let stored: Option<IngestionTask> = db
.get_item::<IngestionTask>(&created.id)
.await
.expect("fetch");
.with_context(|| "fetch".to_string())?;
let stored = stored.expect("task exists");
let stored = stored.with_context(|| "task exists".to_string())?;
assert_eq!(stored.id, created.id);
assert_eq!(stored.state, TaskState::Pending);
assert_eq!(stored.attempts, 0);
Ok(())
}
#[tokio::test]
async fn test_claim_and_transition() {
let db = memory_db().await;
async fn test_claim_and_transition() -> anyhow::Result<()> {
let db = memory_db().await?;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload, user_id.to_string());
db.store_item(task.clone()).await.expect("store");
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let worker_id = "worker-1";
let now = chrono::Utc::now();
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
.await
.expect("claim");
.with_context(|| "claim".to_string())?
.with_context(|| "task claimed".to_string())?;
let claimed = claimed.expect("task claimed");
assert_eq!(claimed.state, TaskState::Reserved);
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
let processing = claimed.mark_processing(&db).await.expect("processing");
let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
assert_eq!(processing.state, TaskState::Processing);
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
let succeeded = processing.mark_succeeded(&db).await.with_context(|| "succeeded".to_string())?;
assert_eq!(succeeded.state, TaskState::Succeeded);
assert!(succeeded.worker_id.is_none());
assert!(succeeded.locked_at.is_none());
Ok(())
}
#[tokio::test]
async fn test_fail_and_dead_letter() {
let db = memory_db().await;
async fn test_fail_and_dead_letter() -> anyhow::Result<()> {
let db = memory_db().await?;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload, user_id.to_string());
db.store_item(task.clone()).await.expect("store");
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let worker_id = "worker-dead";
let now = chrono::Utc::now();
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
.await
.expect("claim")
.expect("claimed");
.with_context(|| "claim".to_string())?
.with_context(|| "claimed".to_string())?;
let processing = claimed.mark_processing(&db).await.expect("processing");
let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
let error_info = TaskErrorInfo {
code: Some("pipeline_error".into()),
@@ -638,7 +643,7 @@ mod tests {
let failed = processing
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
.await
.expect("failed update");
.with_context(|| "failed update".to_string())?;
assert_eq!(failed.state, TaskState::Failed);
assert_eq!(failed.error_message.as_deref(), Some("failed"));
assert!(failed.worker_id.is_none());
@@ -648,19 +653,20 @@ mod tests {
let dead = failed
.mark_dead_letter(error_info.clone(), &db)
.await
.expect("dead letter");
.with_context(|| "dead letter".to_string())?;
assert_eq!(dead.state, TaskState::DeadLetter);
assert_eq!(dead.error_message.as_deref(), Some("failed"));
Ok(())
}
#[tokio::test]
async fn test_mark_processing_requires_reservation() {
let db = memory_db().await;
async fn test_mark_processing_requires_reservation() -> anyhow::Result<()> {
let db = memory_db().await?;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(task.clone()).await.expect("store");
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let err = task
.mark_processing(&db)
@@ -674,18 +680,19 @@ mod tests {
"unexpected message: {message}"
);
}
other => panic!("expected validation error, got {other:?}"),
other => anyhow::bail!("expected validation error, got {other:?}"),
}
Ok(())
}
#[tokio::test]
async fn test_mark_failed_requires_processing() {
let db = memory_db().await;
async fn test_mark_failed_requires_processing() -> anyhow::Result<()> {
let db = memory_db().await?;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(task.clone()).await.expect("store");
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let err = task
.mark_failed(
@@ -706,18 +713,19 @@ mod tests {
"unexpected message: {message}"
);
}
other => panic!("expected validation error, got {other:?}"),
other => anyhow::bail!("expected validation error, got {other:?}"),
}
Ok(())
}
#[tokio::test]
async fn test_release_requires_reservation() {
let db = memory_db().await;
async fn test_release_requires_reservation() -> anyhow::Result<()> {
let db = memory_db().await?;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(task.clone()).await.expect("store");
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let err = task
.release(&db)
@@ -731,7 +739,8 @@ mod tests {
"unexpected message: {message}"
);
}
other => panic!("expected validation error, got {other:?}"),
other => anyhow::bail!("expected validation error, got {other:?}"),
}
Ok(())
}
}
+98 -82
View File
@@ -5,7 +5,6 @@
clippy::format_push_string,
clippy::uninlined_format_args,
clippy::explicit_iter_loop,
clippy::items_after_statements,
clippy::get_first,
clippy::redundant_closure_for_method_calls
)]
@@ -317,6 +316,11 @@ impl KnowledgeEntity {
}
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let mut response = db_client
.client
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
@@ -324,10 +328,6 @@ impl KnowledgeEntity {
.bind(("id", id.to_string()))
.await
.map_err(AppError::Database)?;
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.get(0)
.map(|r| r.user_id.clone())
@@ -497,7 +497,6 @@ impl KnowledgeEntity {
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
}
info!("Successfully generated all new embeddings.");
info!("Successfully generated all new embeddings.");
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
info!("Removing old index and clearing embeddings...");
@@ -572,14 +571,14 @@ impl KnowledgeEntity {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use serde_json::json;
use uuid::Uuid;
#[tokio::test]
async fn test_knowledge_entity_creation() {
// Create basic test entity
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
let source_id = "source123".to_string();
let name = "Test Entity".to_string();
let description = "Test Description".to_string();
@@ -596,7 +595,6 @@ mod tests {
user_id.clone(),
);
// Verify all fields are set correctly
assert_eq!(entity.source_id, source_id);
assert_eq!(entity.name, name);
assert_eq!(entity.description, description);
@@ -604,11 +602,12 @@ mod tests {
assert_eq!(entity.metadata, metadata);
assert_eq!(entity.user_id, user_id);
assert!(!entity.id.is_empty());
Ok(())
}
#[tokio::test]
async fn test_knowledge_entity_type_from_string() {
// Test conversion from String to KnowledgeEntityType
async fn test_knowledge_entity_type_from_string() -> anyhow::Result<()> {
assert_eq!(
KnowledgeEntityType::from("idea".to_string()),
KnowledgeEntityType::Idea
@@ -639,15 +638,16 @@ mod tests {
KnowledgeEntityType::TextSnippet
);
// Test default case
assert_eq!(
KnowledgeEntityType::from("unknown".to_string()),
KnowledgeEntityType::Document
);
Ok(())
}
#[tokio::test]
async fn test_knowledge_entity_variants() {
async fn test_knowledge_entity_variants() -> anyhow::Result<()> {
let variants = KnowledgeEntityType::variants();
assert_eq!(variants.len(), 5);
assert!(variants.contains(&"Idea"));
@@ -655,28 +655,28 @@ mod tests {
assert!(variants.contains(&"Document"));
assert!(variants.contains(&"Page"));
assert!(variants.contains(&"TextSnippet"));
Ok(())
}
#[tokio::test]
async fn test_delete_by_source_id() {
// Setup in-memory database for testing
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
// Create two entities with the same source_id
let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let entity1 = KnowledgeEntity::new(
source_id.clone(),
@@ -696,7 +696,6 @@ mod tests {
user_id.clone(),
);
// Create an entity with a different source_id
let different_source_id = "different_source".to_string();
let different_entity = KnowledgeEntity::new(
different_source_id.clone(),
@@ -708,23 +707,20 @@ mod tests {
);
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
// Store the entities
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
.await
.expect("Failed to store entity 1");
.with_context(|| "Failed to store entity 1".to_string())?;
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
.await
.expect("Failed to store entity 2");
.with_context(|| "Failed to store entity 2".to_string())?;
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
.await
.expect("Failed to store different entity");
.with_context(|| "Failed to store different entity".to_string())?;
// Delete by source_id
KnowledgeEntity::delete_by_source_id(&source_id, &db)
.await
.expect("Failed to delete entities by source_id");
.with_context(|| "Failed to delete entities by source_id".to_string())?;
// Verify all entities with the specified source_id are deleted
let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
KnowledgeEntity::table_name(),
@@ -734,16 +730,11 @@ mod tests {
.client
.query(query)
.await
.expect("Query failed")
.with_context(|| "Query failed".to_string())?
.take(0)
.expect("Failed to get query results");
assert_eq!(
remaining.len(),
0,
"All entities with the source_id should be deleted"
);
.with_context(|| "Failed to get query results".to_string())?;
assert!(remaining.is_empty(), "All entities with the source_id should be deleted");
// Verify the entity with a different source_id still exists
let different_query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
KnowledgeEntity::table_name(),
@@ -753,15 +744,20 @@ mod tests {
.client
.query(different_query)
.await
.expect("Query failed")
.with_context(|| "Query failed".to_string())?
.take(0)
.expect("Failed to get query results");
.with_context(|| "Failed to get query results".to_string())?;
assert_eq!(
different_remaining.len(),
1,
"Entity with different source_id should still exist"
);
assert_eq!(different_remaining[0].id, different_entity.id);
assert_eq!(
different_remaining.first().context("Expected entity to exist")?.id,
different_entity.id
);
Ok(())
}
#[tokio::test]
@@ -833,35 +829,37 @@ mod tests {
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.expect("vector search");
.with_context(|| "vector search".to_string())?;
assert!(results.is_empty());
Ok(())
}
#[tokio::test]
async fn test_vector_search_single_result() {
async fn test_vector_search_single_result() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let user_id = "user".to_string();
let source_id = "src".to_string();
@@ -876,9 +874,12 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store entity with embedding");
.with_context(|| "store entity with embedding".to_string())?;
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
let stored_entity: Option<KnowledgeEntity> = db
.get_item(&entity.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
assert!(stored_entity.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
@@ -888,42 +889,44 @@ mod tests {
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("query embeddings")
.with_context(|| "query embeddings".to_string())?
.take(0)
.expect("take embeddings");
.with_context(|| "take embeddings".to_string())?;
assert_eq!(stored_embeddings.len(), 1);
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
.await
.expect("fetch embedding");
.with_context(|| "fetch embedding".to_string())?;
assert!(fetched_emb.is_some());
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.expect("vector search");
.with_context(|| "vector search".to_string())?;
assert_eq!(results.len(), 1);
let res = &results[0];
let res = results.first().context("Expected at least one result")?;
assert_eq!(res.entity.id, entity.id);
assert_eq!(res.entity.source_id, source_id);
assert_eq!(res.entity.name, "hello");
Ok(())
}
#[tokio::test]
async fn test_vector_search_orders_by_similarity() {
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let user_id = "user".to_string();
let e1 = KnowledgeEntity::new(
@@ -945,13 +948,19 @@ mod tests {
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
.await
.expect("store e1");
.with_context(|| "store e1".to_string())?;
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
.await
.expect("store e2");
.with_context(|| "store e2".to_string())?;
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
let stored_e1: Option<KnowledgeEntity> = db
.get_item(&e1.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
let stored_e2: Option<KnowledgeEntity> = db
.get_item(&e2.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
assert!(stored_e1.is_some() && stored_e2.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
@@ -961,45 +970,53 @@ mod tests {
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("query embeddings")
.with_context(|| "query embeddings".to_string())?
.take(0)
.expect("take embeddings");
.with_context(|| "take embeddings".to_string())?;
assert_eq!(stored_embeddings.len(), 2);
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
.await
.unwrap()
.with_context(|| "get embedding e1".to_string())?
.is_some());
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
.await
.unwrap()
.with_context(|| "get embedding e2".to_string())?
.is_some());
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.expect("vector search");
.with_context(|| "vector search".to_string())?;
assert_eq!(results.len(), 2);
assert_eq!(results[0].entity.id, e2.id);
assert_eq!(results[1].entity.id, e1.id);
assert_eq!(
results.first().context("Expected at least one result")?.entity.id,
e2.id
);
assert_eq!(
results.get(1).context("Expected at least two results")?.entity.id,
e1.id
);
Ok(())
}
#[tokio::test]
async fn test_vector_search_with_orphaned_embedding() {
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
let namespace = "test_ns_orphan";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let user_id = "user".to_string();
let source_id = "src".to_string();
@@ -1014,21 +1031,20 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store entity with embedding");
.with_context(|| "store entity with embedding".to_string())?;
// Manually delete the entity to create an orphan
let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id);
db.client.query(query).await.expect("delete entity");
db.client
.query(query)
.await
.with_context(|| "delete entity".to_string())?;
// Now search
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.expect("search should succeed even with orphans");
.with_context(|| "search should succeed even with orphans".to_string())?;
assert!(
results.is_empty(),
"Should return empty result for orphan, got: {:?}",
results
);
assert!(results.is_empty(), "Should return empty result for orphan, got: {:?}", results);
Ok(())
}
}
@@ -110,11 +110,15 @@ impl KnowledgeEntityEmbedding {
}
/// Delete embeddings by source_id (via joining to knowledge_entity table)
#[allow(clippy::items_after_statements)]
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
let mut res = db
.client
@@ -122,11 +126,6 @@ impl KnowledgeEntityEmbedding {
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?;
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
for row in ids {
@@ -138,6 +137,7 @@ impl KnowledgeEntityEmbedding {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::db::SurrealDbClient;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
@@ -145,18 +145,18 @@ mod tests {
use surrealdb::Value as SurrealValue;
use uuid::Uuid;
async fn setup_test_db() -> SurrealDbClient {
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
db
Ok(db)
}
fn build_knowledge_entity_with_id(
@@ -178,11 +178,11 @@ mod tests {
}
#[tokio::test]
async fn test_create_and_get_by_entity_id() {
let db = setup_test_db().await;
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke";
let entity_key = "entity-1";
let source_id = "source-ke";
@@ -192,26 +192,28 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to get embedding by entity_id")
.expect("Expected embedding to exist");
.with_context(|| "Failed to get embedding by entity_id".to_string())?
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.entity_id, entity_rid);
assert_eq!(fetched.embedding, embedding_vec);
Ok(())
}
#[tokio::test]
async fn test_delete_by_entity_id() {
let db = setup_test_db().await;
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke";
let entity_key = "entity-delete";
let source_id = "source-del";
@@ -220,61 +222,67 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to get embedding before delete");
.with_context(|| "Failed to get embedding before delete".to_string())?;
assert!(existing.is_some());
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to delete by entity_id");
.with_context(|| "Failed to delete by entity_id".to_string())?;
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to get embedding after delete");
.with_context(|| "Failed to get embedding after delete".to_string())?;
assert!(after.is_none());
Ok(())
}
#[tokio::test]
async fn test_store_with_embedding_creates_entity_and_embedding() {
let db = setup_test_db().await;
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "user_store";
let source_id = "source_store";
let embedding = vec![0.2_f32, 0.3, 0.4];
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
.await
.expect("set test index dimension");
.with_context(|| "set test index dimension".to_string())?;
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
let stored_entity: Option<KnowledgeEntity> = db
.get_item(&entity.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
assert!(stored_entity.is_some());
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to fetch embedding");
assert!(stored_embedding.is_some());
let stored_embedding = stored_embedding.unwrap();
.with_context(|| "Failed to fetch embedding".to_string())?;
let stored_embedding = stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(stored_embedding.user_id, user_id);
assert_eq!(stored_embedding.entity_id, entity_rid);
Ok(())
}
#[tokio::test]
async fn test_delete_by_source_id() {
let db = setup_test_db().await;
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke";
let source_id = "shared-ke";
let other_source = "other-ke";
@@ -285,13 +293,13 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
@@ -299,59 +307,74 @@ mod tests {
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
.await
.expect("Failed to delete by source_id");
.with_context(|| "Failed to delete by source_id".to_string())?;
assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
.await
.unwrap()
.with_context(|| "get entity1 embedding after delete".to_string())?
.is_none()
);
assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
.await
.unwrap()
.with_context(|| "get entity2 embedding after delete".to_string())?
.is_none()
);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
.await
.unwrap()
.with_context(|| "get other embedding after delete".to_string())?
.is_some());
Ok(())
}
#[tokio::test]
async fn test_redefine_hnsw_index_updates_dimension() {
let db = setup_test_db().await;
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
.await
.expect("failed to redefine index");
.with_context(|| "failed to redefine index".to_string())?;
let mut info_res = db
.client
.query("INFO FOR TABLE knowledge_entity_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_knowledge_entity_embedding"]["Strand"]
.as_str()
.with_context(|| "info query failed".to_string())?;
let info: SurrealValue = info_res
.take(0)
.with_context(|| "failed to take info result".to_string())?;
let info_json: serde_json::Value = serde_json::to_value(info)
.with_context(|| "failed to convert info to json".to_string())?;
let idx_sql = info_json
.get("Object")
.and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.get("idx_embedding_knowledge_entity_embedding"))
.and_then(|v| v.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_default();
assert!(
idx_sql.contains("DIMENSION 16"),
"expected index definition to contain new dimension, got: {idx_sql}"
);
Ok(())
}
#[tokio::test]
async fn test_fetch_entity_via_record_id() {
let db = setup_test_db().await;
async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> {
#[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
}
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
.with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke";
let entity_key = "entity-fetch";
let source_id = "source-fetch";
@@ -359,15 +382,10 @@ mod tests {
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
.await
.expect("Failed to store entity with embedding");
.with_context(|| "Failed to store entity with embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
#[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
}
let mut res = db
.client
.query(
@@ -375,13 +393,17 @@ mod tests {
)
.bind(("id", entity_rid.clone()))
.await
.expect("failed to fetch embedding with FETCH");
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
.with_context(|| "failed to fetch embedding with FETCH".to_string())?;
let rows: Vec<Row> = res
.take(0)
.with_context(|| "failed to deserialize fetch rows".to_string())?;
assert_eq!(rows.len(), 1);
let fetched_entity = &rows[0].entity_id;
let fetched_entity = &rows.first().context("Expected at least one result")?.entity_id;
assert_eq!(fetched_entity.id, entity_key);
assert_eq!(fetched_entity.name, "Test entity");
assert_eq!(fetched_entity.user_id, user_id);
Ok(())
}
}
@@ -124,6 +124,7 @@ impl KnowledgeRelationship {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
@@ -155,10 +156,9 @@ mod tests {
result.take(0).expect("failed to take relationship by id")
}
// Helper function to create a test knowledge entity for the relationship tests
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> {
let source_id = "source123".to_string();
let description = format!("Description for {}", name);
let description = format!("Description for {name}");
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
@@ -174,12 +174,14 @@ mod tests {
let stored: Option<KnowledgeEntity> = db_client
.store_item(entity)
.await
.expect("Failed to store entity");
stored.unwrap().id
.with_context(|| "Failed to store entity".to_string())?;
stored
.ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some"))
.map(|e| e.id)
}
#[tokio::test]
async fn test_relationship_creation() {
async fn test_relationship_creation() -> anyhow::Result<()> {
let in_id = "entity1".to_string();
let out_id = "entity2".to_string();
let user_id = "user123".to_string();
@@ -194,25 +196,23 @@ mod tests {
relationship_type.clone(),
);
// Verify fields are correctly set
assert_eq!(relationship.in_, in_id);
assert_eq!(relationship.out, out_id);
assert_eq!(relationship.metadata.user_id, user_id);
assert_eq!(relationship.metadata.source_id, source_id);
assert_eq!(relationship.metadata.relationship_type, relationship_type);
assert!(!relationship.id.is_empty());
Ok(())
}
#[tokio::test]
async fn test_store_and_verify_by_source_id() {
// Setup in-memory database for testing
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await;
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
// Create relationship
let user_id = "user123".to_string();
let source_id = "source123".to_string();
let relationship_type = "references".to_string();
@@ -225,11 +225,10 @@ mod tests {
relationship_type,
);
// Store the relationship
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
.with_context(|| "Failed to store relationship".to_string())?;
let persisted = get_relationship_by_id(&relationship.id, &db)
.await
@@ -239,8 +238,6 @@ mod tests {
assert_eq!(persisted.metadata.user_id, user_id);
assert_eq!(persisted.metadata.source_id, source_id);
// Query to verify the relationship exists by checking for relationships with our source_id
// This approach is more reliable than trying to look up by ID
let mut check_result = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone()))
@@ -253,14 +250,16 @@ mod tests {
1,
"Expected one relationship for source_id"
);
Ok(())
}
#[tokio::test]
async fn test_store_relationship_resists_query_injection() {
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let relationship = KnowledgeRelationship::new(
entity1_id,
@@ -288,18 +287,17 @@ mod tests {
rows[0].metadata.source_id,
"source123'; DELETE FROM relates_to; --"
);
Ok(())
}
#[tokio::test]
async fn test_store_and_delete_relationship() {
// Setup in-memory database for testing
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
let db = setup_test_db().await;
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
// Create relationship
let user_id = "user123".to_string();
let source_id = "source123".to_string();
let relationship_type = "references".to_string();
@@ -312,52 +310,44 @@ mod tests {
relationship_type,
);
// Store relationship
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
.with_context(|| "Failed to store relationship".to_string())?;
// Ensure relationship exists before deletion attempt
let mut existing_before_delete = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
user_id, source_id
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
))
.await
.expect("Query failed");
.with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> =
existing_before_delete.take(0).unwrap_or_default();
assert!(
!before_results.is_empty(),
"Relationship should exist before deletion"
);
assert!(!before_results.is_empty(), "Relationship should exist before deletion");
// Delete relationship by ID
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
.await
.expect("Failed to delete relationship by ID");
.with_context(|| "Failed to delete relationship by ID".to_string())?;
// Query to verify relationship was deleted
let mut result = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
user_id, source_id
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
))
.await
.expect("Query failed");
.with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
// Verify relationship no longer exists
assert!(results.is_empty(), "Relationship should be deleted");
Ok(())
}
#[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() {
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let owner_user_id = "owner-user".to_string();
let source_id = "source123".to_string();
@@ -373,20 +363,16 @@ mod tests {
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
.with_context(|| "Failed to store relationship".to_string())?;
let mut before_attempt = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
owner_user_id
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
))
.await
.expect("Query failed");
.with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
assert!(
!before_results.is_empty(),
"Relationship should exist before unauthorized delete attempt"
);
assert!(!before_results.is_empty(), "Relationship should exist before unauthorized delete attempt");
let result = KnowledgeRelationship::delete_relationship_by_id(
&relationship.id,
@@ -397,40 +383,34 @@ mod tests {
match result {
Err(AppError::Auth(_)) => {}
_ => panic!("Expected authorization error when deleting someone else's relationship"),
_ => anyhow::bail!("Expected authorization error when deleting someone else's relationship"),
}
let mut after_attempt = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
owner_user_id
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
))
.await
.expect("Query failed");
.with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
assert!(
!results.is_empty(),
"Relationship should still exist after unauthorized delete attempt"
);
assert!(!results.is_empty(), "Relationship should still exist after unauthorized delete attempt");
Ok(())
}
#[tokio::test]
async fn test_store_relationship_exists() {
// Setup in-memory database for testing
async fn test_store_relationship_exists() -> anyhow::Result<()> {
let db = setup_test_db().await;
// Create entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity3_id = create_test_entity("Entity 3", &db).await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await?;
// Create relationships with the same source_id
let user_id = "user123".to_string();
let source_id = "source123".to_string();
let different_source_id = "different_source".to_string();
// Create two relationships with the same source_id
let relationship1 = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
@@ -447,7 +427,6 @@ mod tests {
"contains".to_string(),
);
// Create a relationship with a different source_id
let different_relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity3_id.clone(),
@@ -456,21 +435,19 @@ mod tests {
"mentions".to_string(),
);
// Store all relationships
relationship1
.store_relationship(&db)
.await
.expect("Failed to store relationship 1");
.with_context(|| "Failed to store relationship 1".to_string())?;
relationship2
.store_relationship(&db)
.await
.expect("Failed to store relationship 2");
.with_context(|| "Failed to store relationship 2".to_string())?;
different_relationship
.store_relationship(&db)
.await
.expect("Failed to store different relationship");
.with_context(|| "Failed to store different relationship".to_string())?;
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
let mut before_delete = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone()))
@@ -489,31 +466,30 @@ mod tests {
before_delete_different.take(0).unwrap_or_default();
assert_eq!(before_delete_different_rows.len(), 1);
// Delete relationships by source_id
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
.await
.expect("Failed to delete relationships by source_id");
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
// Query to verify the specific relationships with source_id were deleted.
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
// Verify relationships with the source_id are deleted
assert!(result1.is_none(), "Relationship 1 should be deleted");
assert!(result2.is_none(), "Relationship 2 should be deleted");
let remaining =
different_result.expect("Relationship with different source_id should remain");
assert_eq!(remaining.metadata.source_id, different_source_id);
Ok(())
}
#[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() {
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity3_id = create_test_entity("Entity 3", &db).await;
let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await?;
let safe_relationship = KnowledgeRelationship::new(
entity1_id.clone(),
@@ -552,5 +528,7 @@ mod tests {
remaining_other.is_some(),
"Other relationship should remain"
);
Ok(())
}
}
+21 -23
View File
@@ -66,12 +66,12 @@ pub fn format_history(history: &[Message]) -> String {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::db::SurrealDbClient;
#[tokio::test]
async fn test_message_creation() {
// Test basic message creation
async fn test_message_creation() -> anyhow::Result<()> {
let conversation_id = "test_conversation";
let content = "This is a test message";
let role = MessageRole::User;
@@ -84,24 +84,23 @@ mod tests {
references.clone(),
);
// Verify message properties
assert_eq!(message.conversation_id, conversation_id);
assert_eq!(message.content, content);
assert_eq!(message.role, role);
assert_eq!(message.references, references);
assert!(!message.id.is_empty());
Ok(())
}
#[tokio::test]
async fn test_message_persistence() {
// Setup in-memory database for testing
async fn test_message_persistence() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &uuid::Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create and store a message
let conversation_id = "test_conversation";
let message = Message::new(
conversation_id.to_string(),
@@ -111,39 +110,37 @@ mod tests {
);
let message_id = message.id.clone();
// Store the message
db.store_item(message.clone())
.await
.expect("Failed to store message");
.with_context(|| "Failed to store message".to_string())?;
// Retrieve the message
let retrieved: Option<Message> = db
.get_item(&message_id)
.await
.expect("Failed to retrieve message");
.with_context(|| "Failed to retrieve message".to_string())?;
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?;
// Verify retrieved properties match original
assert_eq!(retrieved.id, message.id);
assert_eq!(retrieved.conversation_id, message.conversation_id);
assert_eq!(retrieved.role, message.role);
assert_eq!(retrieved.content, message.content);
assert_eq!(retrieved.references, message.references);
Ok(())
}
#[tokio::test]
async fn test_message_role_display() {
// Test the Display implementation for MessageRole
async fn test_message_role_display() -> anyhow::Result<()> {
assert_eq!(format!("{}", MessageRole::User), "User");
assert_eq!(format!("{}", MessageRole::AI), "AI");
assert_eq!(format!("{}", MessageRole::System), "System");
Ok(())
}
#[tokio::test]
async fn test_message_display() {
// Test the Display implementation for Message
async fn test_message_display() -> anyhow::Result<()> {
let message = Message {
id: "test_id".to_string(),
created_at: Utc::now(),
@@ -154,12 +151,13 @@ mod tests {
references: None,
};
assert_eq!(format!("{}", message), "User: Hello world");
assert_eq!(format!("{message}"), "User: Hello world");
Ok(())
}
#[tokio::test]
async fn test_format_history() {
// Create a vector of messages
async fn test_format_history() -> anyhow::Result<()> {
let messages = vec![
Message {
id: "1".to_string(),
@@ -181,10 +179,10 @@ mod tests {
},
];
// Format the history
let formatted = format_history(&messages);
// Verify the formatting
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
Ok(())
}
}
+64 -54
View File
@@ -216,20 +216,22 @@ impl Scratchpad {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
#[tokio::test]
async fn test_create_scratchpad() {
async fn test_create_scratchpad() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
// Create a new scratchpad
let user_id = "test_user";
@@ -254,29 +256,28 @@ mod tests {
let retrieved: Option<Scratchpad> = db
.get_item(&scratchpad.id)
.await
.expect("Failed to retrieve scratchpad");
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
.with_context(|| "Failed to retrieve scratchpad".to_string())?;
let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?;
assert_eq!(retrieved.id, scratchpad.id);
assert_eq!(retrieved.user_id, user_id);
assert_eq!(retrieved.title, title);
assert!(!retrieved.is_archived);
assert!(retrieved.archived_at.is_none());
assert!(retrieved.ingested_at.is_none());
Ok(())
}
#[tokio::test]
async fn test_get_by_user() {
async fn test_get_by_user() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user";
@@ -288,19 +289,21 @@ mod tests {
// Store them
let scratchpad1_id = scratchpad1.id.clone();
let scratchpad2_id = scratchpad2.id.clone();
db.store_item(scratchpad1).await.unwrap();
db.store_item(scratchpad2).await.unwrap();
db.store_item(scratchpad3).await.unwrap();
db.store_item(scratchpad1).await.with_context(|| "store scratchpad1".to_string())?;
db.store_item(scratchpad2).await.with_context(|| "store scratchpad2".to_string())?;
db.store_item(scratchpad3).await.with_context(|| "store scratchpad3".to_string())?;
// Archive one of the user's scratchpads
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
.await
.unwrap();
.with_context(|| "archive".to_string())?;
// Get scratchpads for user_id
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap();
let user_scratchpads = Scratchpad::get_by_user(user_id, &db)
.await
.with_context(|| "get_by_user".to_string())?;
assert_eq!(user_scratchpads.len(), 1);
assert_eq!(user_scratchpads[0].id, scratchpad1_id);
assert_eq!(user_scratchpads.first().map(|s| &s.id), Some(&scratchpad1_id));
// Verify they belong to the user
for scratchpad in &user_scratchpads {
@@ -309,177 +312,183 @@ mod tests {
let archived = Scratchpad::get_archived_by_user(user_id, &db)
.await
.unwrap();
.with_context(|| "get_archived_by_user".to_string())?;
assert_eq!(archived.len(), 1);
assert_eq!(archived[0].id, scratchpad2_id);
assert!(archived[0].is_archived);
assert!(archived[0].ingested_at.is_none());
assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id));
assert!(archived.first().is_some_and(|s| s.is_archived));
assert!(archived.first().is_some_and(|s| s.ingested_at.is_none()));
Ok(())
}
#[tokio::test]
async fn test_archive_and_restore() {
async fn test_archive_and_restore() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user";
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
.await
.expect("Failed to archive");
.with_context(|| "Failed to archive".to_string())?;
assert!(archived.is_archived);
assert!(archived.archived_at.is_some());
assert!(archived.ingested_at.is_some());
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
.await
.expect("Failed to restore");
.with_context(|| "Failed to restore".to_string())?;
assert!(!restored.is_archived);
assert!(restored.archived_at.is_none());
assert!(restored.ingested_at.is_none());
Ok(())
}
#[tokio::test]
async fn test_update_content() {
async fn test_update_content() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user";
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let new_content = "Updated content";
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
.await
.unwrap();
.with_context(|| "update_content".to_string())?;
assert_eq!(updated.content, new_content);
assert!(!updated.is_dirty);
Ok(())
}
#[tokio::test]
async fn test_update_content_unauthorized() {
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let owner_id = "owner";
let other_user = "other_user";
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
assert!(result.is_err());
match result {
Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"),
_ => anyhow::bail!("Expected Auth error"),
}
Ok(())
}
#[tokio::test]
async fn test_delete_scratchpad() {
async fn test_delete_scratchpad() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user";
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
// Delete should succeed
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
assert!(result.is_ok());
// Verify it's gone
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
assert!(retrieved.is_none());
Ok(())
}
#[tokio::test]
async fn test_delete_unauthorized() {
async fn test_delete_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let owner_id = "owner";
let other_user = "other_user";
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
assert!(result.is_err());
match result {
Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"),
_ => anyhow::bail!("Expected Auth error"),
}
// Verify it still exists
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
assert!(retrieved.is_some());
Ok(())
}
#[tokio::test]
async fn test_timezone_aware_scratchpad_conversion() {
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
.await
.expect("Failed to create test database");
.with_context(|| "Failed to create test database".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user_123";
let scratchpad =
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
.await
.unwrap();
.with_context(|| "get_by_id".to_string())?;
// Test that datetime fields are preserved and can be used for timezone formatting
assert!(retrieved.created_at.timestamp() > 0);
@@ -493,10 +502,11 @@ mod tests {
// Archive the scratchpad to test optional datetime handling
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
.await
.unwrap();
.with_context(|| "archive".to_string())?;
assert!(archived.archived_at.is_some());
assert!(archived.archived_at.unwrap().timestamp() > 0);
assert!(archived.archived_at.with_context(|| "expected archived_at".to_string())?.timestamp() > 0);
assert!(archived.ingested_at.is_none());
Ok(())
}
}
+109 -91
View File
@@ -64,7 +64,14 @@ impl SystemSettings {
let mut needs_update = false;
let backend_label = provider.backend_label().to_string();
let provider_dimensions = provider.dimension() as u32;
let provider_dimensions = u32::try_from(provider.dimension())
.unwrap_or_else(|_| {
tracing::warn!(
"Provider dimension {} exceeds u32 max; falling back to 0",
provider.dimension()
);
0u32
});
let provider_model = provider.model_code();
// Sync backend label
@@ -107,7 +114,8 @@ impl SystemSettings {
#[cfg(test)]
mod tests {
use crate::storage::indexes::ensure_runtime_indexes;
use anyhow::{self, Context};
use crate::storage::indexes::ensure_runtime;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
use async_openai::Client;
@@ -118,68 +126,102 @@ mod tests {
db: &SurrealDbClient,
table_name: &str,
index_name: &str,
) -> u32 {
) -> anyhow::Result<u32> {
let query = format!("INFO FOR TABLE {table_name};");
let mut response = db
.client
.query(query)
.await
.expect("Failed to fetch table info");
.with_context(|| "Failed to fetch table info".to_string())?;
let info: surrealdb::Value = response
.take(0)
.expect("Failed to extract table info response");
.with_context(|| "Failed to extract table info response".to_string())?;
let info_json: serde_json::Value =
serde_json::to_value(info).expect("Failed to convert info to json");
serde_json::to_value(info).with_context(|| "Failed to convert info to json".to_string())?;
let indexes = info_json["Object"]["indexes"]["Object"]
.as_object()
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
let indexes = info_json
.get("Object")
.and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.as_object())
.with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?;
let definition = indexes
.get(index_name)
.and_then(|definition| definition.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
.with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?;
let dimension_part = definition
.split("DIMENSION")
.nth(1)
.expect("Index definition missing DIMENSION clause");
.with_context(|| "Index definition missing DIMENSION clause".to_string())?;
let dimension_token = dimension_part
.split_whitespace()
.next()
.expect("Dimension value missing in definition")
.with_context(|| "Dimension value missing in definition".to_string())?
.trim_end_matches(';');
dimension_token
.parse::<u32>()
.expect("Dimension value is not a valid number")
.with_context(|| "Dimension value is not a valid number".to_string())
}
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) -> anyhow::Result<()> {
db.query(
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
)
.await
.with_context(|| "remove index".to_string())?;
let define_index_query = format!(
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};"
);
db.query(define_index_query)
.await
.with_context(|| "Re-defining index should succeed".to_string())?;
let new_embedding = vec![0.5; target_dimension];
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
db.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("user_id", initial_chunk.user_id.clone()))
.bind(("embedding", new_embedding))
.await
.with_context(|| "upsert embedding".to_string())?;
Ok(())
}
#[tokio::test]
async fn test_settings_initialization() {
async fn test_settings_initialization() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Test initialization of system settings
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let settings = SystemSettings::get_current(&db)
.await
.expect("Failed to get system settings");
.with_context(|| "Failed to get system settings".to_string())?;
// Verify initial state after initialization
assert_eq!(settings.id, "current");
assert_eq!(settings.registrations_enabled, true);
assert_eq!(settings.require_email_verification, false);
assert!(settings.registrations_enabled);
assert!(!settings.require_email_verification);
assert_eq!(settings.query_model, "gpt-4o-mini");
assert_eq!(settings.processing_model, "gpt-4o-mini");
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
@@ -196,10 +238,10 @@ mod tests {
// Test idempotency - ensure calling it again doesn't change anything
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
let settings_again = SystemSettings::get_current(&db)
.await
.expect("Failed to get settings after initialization");
.with_context(|| "Failed to get settings after initialization".to_string())?;
assert_eq!(settings.id, settings_again.id);
assert_eq!(
@@ -210,48 +252,52 @@ mod tests {
settings.require_email_verification,
settings_again.require_email_verification
);
Ok(())
}
#[tokio::test]
async fn test_get_current_settings() {
async fn test_get_current_settings() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Initialize settings
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
// Test get_current method
let settings = SystemSettings::get_current(&db)
.await
.expect("Failed to get current settings");
.with_context(|| "Failed to get current settings".to_string())?;
assert_eq!(settings.id, "current");
assert_eq!(settings.registrations_enabled, true);
assert_eq!(settings.require_email_verification, false);
assert!(settings.registrations_enabled);
assert!(!settings.require_email_verification);
Ok(())
}
#[tokio::test]
async fn test_update_settings() {
async fn test_update_settings() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Initialize settings
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
// Create updated settings
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
let mut updated_settings = SystemSettings::get_current(&db)
.await
.with_context(|| "get_current".to_string())?;
updated_settings.id = "current".to_string();
updated_settings.registrations_enabled = false;
updated_settings.require_email_verification = true;
@@ -260,31 +306,32 @@ mod tests {
// Test update method
let result = SystemSettings::update(&db, updated_settings)
.await
.expect("Failed to update settings");
.with_context(|| "Failed to update settings".to_string())?;
assert_eq!(result.id, "current");
assert_eq!(result.registrations_enabled, false);
assert_eq!(result.require_email_verification, true);
assert!(!result.registrations_enabled);
assert!(result.require_email_verification);
assert_eq!(result.query_model, "gpt-4");
// Verify changes persisted by getting current settings
let current = SystemSettings::get_current(&db)
.await
.expect("Failed to get current settings after update");
.with_context(|| "Failed to get current settings after update".to_string())?;
assert_eq!(current.registrations_enabled, false);
assert_eq!(current.require_email_verification, true);
assert!(!current.registrations_enabled);
assert!(current.require_email_verification);
assert_eq!(current.query_model, "gpt-4");
Ok(())
}
#[tokio::test]
async fn test_get_current_nonexistent() {
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Don't initialize settings and try to get them
let result = SystemSettings::get_current(&db).await;
@@ -294,21 +341,22 @@ mod tests {
Err(AppError::NotFound(_)) => {
// Expected error
}
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
Ok(_) => panic!("Expected error but got Ok"),
Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"),
Ok(_) => anyhow::bail!("Expected error but got Ok"),
}
Ok(())
}
#[tokio::test]
async fn test_migration_after_changing_embedding_length() {
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.expect("Failed to start DB");
.with_context(|| "Failed to start DB".to_string())?;
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations()
.await
.expect("Initial migration failed");
.with_context(|| "Initial migration failed".to_string())?;
let initial_chunk = TextChunk::new(
"source1".into(),
@@ -318,43 +366,11 @@ mod tests {
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
.await
.expect("Failed to store initial chunk with embedding");
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
db.query(
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
)
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
let new_embedding = vec![0.5; target_dimension];
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("user_id", initial_chunk.user_id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
let target_dimension = 1536usize;
simulate_reembedding(&db, target_dimension, initial_chunk).await;
simulate_reembedding(&db, target_dimension, initial_chunk).await?;
let migration_result = db.apply_migrations().await;
@@ -363,34 +379,35 @@ mod tests {
"Migrations should not fail: {:?}",
migration_result.err()
);
Ok(())
}
#[tokio::test]
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
async fn test_should_change_embedding_length_on_indexes_when_switching_length() -> anyhow::Result<()> {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.expect("Failed to start DB");
.with_context(|| "Failed to start DB".to_string())?;
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations()
.await
.expect("Initial migration failed");
.with_context(|| "Initial migration failed".to_string())?;
let mut current_settings = SystemSettings::get_current(&db)
.await
.expect("Failed to load current settings");
.with_context(|| "Failed to load current settings".to_string())?;
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize)
ensure_runtime(&db, current_settings.embedding_dimensions as usize)
.await
.expect("failed to build runtime indexes");
.with_context(|| "failed to build runtime indexes".to_string())?;
let initial_chunk_dimension = get_hnsw_index_dimension(
&db,
"text_chunk_embedding",
"idx_embedding_text_chunk_embedding",
)
.await;
.await?;
assert_eq!(
initial_chunk_dimension, current_settings.embedding_dimensions,
@@ -405,7 +422,7 @@ mod tests {
let updated_settings = SystemSettings::update(&db, current_settings)
.await
.expect("Failed to update settings");
.with_context(|| "Failed to update settings".to_string())?;
assert_eq!(
updated_settings.embedding_dimensions, new_dimension,
@@ -416,23 +433,23 @@ mod tests {
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
.await
.expect("TextChunk re-embedding should succeed on fresh DB");
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
.await
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
.with_context(|| "KnowledgeEntity re-embedding should succeed on fresh DB".to_string())?;
let text_chunk_dimension = get_hnsw_index_dimension(
&db,
"text_chunk_embedding",
"idx_embedding_text_chunk_embedding",
)
.await;
.await?;
let knowledge_dimension = get_hnsw_index_dimension(
&db,
"knowledge_entity_embedding",
"idx_embedding_knowledge_entity_embedding",
)
.await;
.await?;
assert_eq!(
text_chunk_dimension, new_dimension,
@@ -445,10 +462,11 @@ mod tests {
let persisted_settings = SystemSettings::get_current(&db)
.await
.expect("Failed to reload updated settings");
.with_context(|| "Failed to reload updated settings".to_string())?;
assert_eq!(
persisted_settings.embedding_dimensions, new_dimension,
"Settings should persist new embedding dimension"
);
Ok(())
}
}
+137 -122
View File
@@ -1,4 +1,4 @@
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
#![allow(clippy::missing_docs_in_private_items)]
use std::collections::HashMap;
use std::fmt::Write;
@@ -237,10 +237,7 @@ impl TextChunk {
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all text chunks. New dimensions: {}",
new_dimensions
);
info!("Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}");
// Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
@@ -252,7 +249,7 @@ impl TextChunk {
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);
info!("Found {total_chunks} chunks to process.");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
@@ -276,7 +273,7 @@ impl TextChunk {
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
error!("{err_msg}");
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(
@@ -300,6 +297,7 @@ impl TextChunk {
.join(",")
);
// Use the chunk id as the embedding record id to keep a 1:1 mapping
let embedding = embedding_str;
write!(
&mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
@@ -309,18 +307,13 @@ impl TextChunk {
user_id = '{user_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
}
write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
@@ -377,7 +370,7 @@ impl TextChunk {
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
error!("{err_msg}");
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(
@@ -422,6 +415,7 @@ impl TextChunk {
.collect::<Vec<_>>()
.join(",")
);
let embedding = embedding_str;
write!(
&mut transaction_query,
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
@@ -431,18 +425,13 @@ impl TextChunk {
user_id = '{user_id}', \
created_at = time::now(), \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
}
write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
@@ -462,20 +451,21 @@ impl TextChunk {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
use crate::storage::indexes::{ensure_runtime, rebuild};
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId;
use uuid::Uuid;
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
let snowball_sql = r#"
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
"#;
if let Err(err) = db.client.query(snowball_sql).await {
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
let fallback_sql = r#"
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
@@ -484,12 +474,13 @@ mod tests {
db.client
.query(fallback_sql)
.await
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}"));
.with_context(|| format!("define chunk fts index fallback: {err}"))?;
}
Ok(())
}
#[tokio::test]
async fn test_text_chunk_creation() {
async fn test_text_chunk_creation() -> anyhow::Result<()> {
let source_id = "source123".to_string();
let chunk = "This is a text chunk for testing embeddings".to_string();
let user_id = "user123".to_string();
@@ -500,22 +491,23 @@ mod tests {
assert_eq!(text_chunk.chunk, chunk);
assert_eq!(text_chunk.user_id, user_id);
assert!(!text_chunk.id.is_empty());
Ok(())
}
#[tokio::test]
async fn test_delete_by_source_id() {
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let source_id = "source123".to_string();
let user_id = "user123".to_string();
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
.with_context(|| "redefine index".to_string())?;
let chunk1 = TextChunk::new(
source_id.clone(),
@@ -535,61 +527,63 @@ mod tests {
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("store chunk1");
.with_context(|| "store chunk1".to_string())?;
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("store chunk2");
.with_context(|| "store chunk2".to_string())?;
TextChunk::store_with_embedding(
different_chunk.clone(),
vec![0.1, 0.2, 0.3, 0.4, 0.5],
&db,
)
.await
.expect("store different chunk");
.with_context(|| "store different chunk".to_string())?;
TextChunk::delete_by_source_id(&source_id, &db)
.await
.expect("Failed to delete chunks by source_id");
.with_context(|| "Failed to delete chunks by source_id".to_string())?;
let remaining: Vec<TextChunk> = db
.client
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
"SELECT * FROM {} WHERE source_id = '{source_id}'",
TextChunk::table_name(),
source_id
))
.await
.expect("Query failed")
.with_context(|| "Query failed".to_string())?
.take(0)
.expect("Failed to get query results");
.with_context(|| "Failed to get query results".to_string())?;
assert_eq!(remaining.len(), 0);
let different_remaining: Vec<TextChunk> = db
.client
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
"SELECT * FROM {} WHERE source_id = 'different_source'",
TextChunk::table_name(),
"different_source"
))
.await
.expect("Query failed")
.with_context(|| "Query failed".to_string())?
.take(0)
.expect("Failed to get query results");
.with_context(|| "Failed to get query results".to_string())?;
assert_eq!(different_remaining.len(), 1);
assert_eq!(different_remaining[0].id, different_chunk.id);
assert_eq!(
different_remaining.first().map(|r| &r.id),
Some(&different_chunk.id)
);
Ok(())
}
#[tokio::test]
async fn test_delete_by_nonexistent_source_id() {
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
.with_context(|| "redefine index".to_string())?;
let real_source_id = "real_source".to_string();
let chunk = TextChunk::new(
@@ -600,24 +594,24 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("store chunk");
.with_context(|| "store chunk".to_string())?;
TextChunk::delete_by_source_id("nonexistent_source", &db)
.await
.expect("Delete should succeed");
.with_context(|| "Delete should succeed".to_string())?;
let remaining: Vec<TextChunk> = db
.client
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
"SELECT * FROM {} WHERE source_id = '{real_source_id}'",
TextChunk::table_name(),
real_source_id
))
.await
.expect("Query failed")
.with_context(|| "Query failed".to_string())?
.take(0)
.expect("Failed to get query results");
.with_context(|| "Failed to get query results".to_string())?;
assert_eq!(remaining.len(), 1);
Ok(())
}
#[tokio::test]
@@ -672,13 +666,13 @@ mod tests {
}
#[tokio::test]
async fn test_store_with_embedding_creates_both_records() {
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let source_id = "store-src".to_string();
let user_id = "user_store".to_string();
@@ -686,43 +680,43 @@ mod tests {
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
.with_context(|| "redefine index".to_string())?;
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
.with_context(|| "store with embedding".to_string())?;
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some());
let stored_chunk = stored_chunk.unwrap();
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
.await
.with_context(|| "get_item".to_string())?;
let stored_chunk = stored_chunk.with_context(|| "expected stored chunk".to_string())?;
assert_eq!(stored_chunk.source_id, source_id);
assert_eq!(stored_chunk.user_id, user_id);
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some());
let embedding = embedding.unwrap();
.with_context(|| "get embedding".to_string())?
.with_context(|| "expected embedding".to_string())?;
assert_eq!(embedding.chunk_id, rid);
assert_eq!(embedding.user_id, user_id);
assert_eq!(embedding.source_id, source_id);
Ok(())
}
#[tokio::test]
async fn test_store_with_embedding_with_runtime_indexes() {
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
let namespace = "test_ns_runtime";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
// Ensure runtime indexes are built with the expected dimension.
let embedding_dimension = 3usize;
ensure_runtime_indexes(&db, embedding_dimension)
ensure_runtime(&db, embedding_dimension)
.await
.expect("ensure runtime indexes");
.with_context(|| "ensure runtime indexes".to_string())?;
let chunk = TextChunk::new(
"runtime_src".to_string(),
@@ -732,55 +726,60 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
.with_context(|| "store with embedding".to_string())?;
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some(), "chunk should be stored");
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
.await
.with_context(|| "get_item".to_string())?;
let stored_chunk = stored_chunk.with_context(|| "chunk should be stored".to_string())?;
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some(), "embedding should exist");
.with_context(|| "get embedding".to_string())?
.with_context(|| "embedding should exist".to_string())?;
assert_eq!(
embedding.unwrap().embedding.len(),
embedding.embedding.len(),
embedding_dimension,
"embedding dimension should match runtime index"
);
Ok(())
}
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
.with_context(|| "redefine index".to_string())?;
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.unwrap();
.with_context(|| "vector_search".to_string())?;
assert!(results.is_empty());
Ok(())
}
#[tokio::test]
async fn test_vector_search_single_result() {
async fn test_vector_search_single_result() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
.with_context(|| "redefine index".to_string())?;
let source_id = "src".to_string();
let user_id = "user".to_string();
@@ -792,32 +791,33 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store");
.with_context(|| "store".to_string())?;
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.unwrap();
.with_context(|| "vector_search".to_string())?;
assert_eq!(results.len(), 1);
let res = &results[0];
let res = results.first().context("expected first result")?;
assert_eq!(res.chunk.id, chunk.id);
assert_eq!(res.chunk.source_id, source_id);
assert_eq!(res.chunk.chunk, "hello world");
Ok(())
}
#[tokio::test]
async fn test_vector_search_orders_by_similarity() {
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
.with_context(|| "redefine index".to_string())?;
let user_id = "user".to_string();
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
@@ -825,49 +825,59 @@ mod tests {
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
.await
.expect("store chunk1");
.with_context(|| "store chunk1".to_string())?;
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
.await
.expect("store chunk2");
.with_context(|| "store chunk2".to_string())?;
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.unwrap();
.with_context(|| "vector_search".to_string())?;
assert_eq!(results.len(), 2);
assert_eq!(results[0].chunk.id, chunk2.id);
assert_eq!(results[1].chunk.id, chunk1.id);
assert!(results[0].score >= results[1].score);
assert_eq!(
results.first().map(|r| &r.chunk.id),
Some(&chunk2.id)
);
assert_eq!(
results.get(1).map(|r| &r.chunk.id),
Some(&chunk1.id)
);
let r0 = results.first().context("expected first result")?;
let r1 = results.get(1).context("expected second result")?;
assert!(r0.score >= r1.score);
Ok(())
}
#[tokio::test]
async fn test_fts_search_returns_empty_when_no_chunks() {
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
let namespace = "fts_chunk_ns_empty";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
rebuild_indexes(&db).await.expect("rebuild indexes");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
ensure_chunk_fts_index(&db).await?;
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
let results = TextChunk::fts_search(5, "hello", &db, "user")
.await
.expect("fts search");
.with_context(|| "fts search".to_string())?;
assert!(results.is_empty());
Ok(())
}
#[tokio::test]
async fn test_fts_search_single_result() {
async fn test_fts_search_single_result() -> anyhow::Result<()> {
let namespace = "fts_chunk_ns_single";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
ensure_chunk_fts_index(&db).await?;
let user_id = "fts_user";
let chunk = TextChunk::new(
@@ -875,27 +885,29 @@ mod tests {
"rustaceans love rust".to_string(),
user_id.to_string(),
);
db.store_item(chunk.clone()).await.expect("store chunk");
rebuild_indexes(&db).await.expect("rebuild indexes");
db.store_item(chunk.clone()).await.with_context(|| "store chunk".to_string())?;
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
let results = TextChunk::fts_search(3, "rust", &db, user_id)
.await
.expect("fts search");
.with_context(|| "fts search".to_string())?;
assert_eq!(results.len(), 1);
assert_eq!(results[0].chunk.id, chunk.id);
assert!(results[0].score.is_finite(), "expected a finite FTS score");
let r0 = results.first().context("expected first result")?;
assert_eq!(r0.chunk.id, chunk.id);
assert!(r0.score.is_finite(), "expected a finite FTS score");
Ok(())
}
#[tokio::test]
async fn test_fts_search_orders_by_score_and_filters_user() {
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
let namespace = "fts_chunk_ns_order";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
ensure_chunk_fts_index(&db).await?;
let user_id = "fts_user_order";
let high_score_chunk = TextChunk::new(
@@ -916,18 +928,18 @@ mod tests {
db.store_item(high_score_chunk.clone())
.await
.expect("store high score chunk");
.with_context(|| "store high score chunk".to_string())?;
db.store_item(low_score_chunk.clone())
.await
.expect("store low score chunk");
.with_context(|| "store low score chunk".to_string())?;
db.store_item(other_user_chunk)
.await
.expect("store other user chunk");
rebuild_indexes(&db).await.expect("rebuild indexes");
.with_context(|| "store other user chunk".to_string())?;
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
let results = TextChunk::fts_search(3, "apple", &db, user_id)
.await
.expect("fts search");
.with_context(|| "fts search".to_string())?;
assert_eq!(results.len(), 2);
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
@@ -936,9 +948,12 @@ mod tests {
&& ids.contains(&low_score_chunk.id.as_str()),
"expected only the two chunks for the same user"
);
let r0 = results.first().context("expected first result")?;
let r1 = results.get(1).context("expected second result")?;
assert!(
results[0].score >= results[1].score,
r0.score >= r1.score,
"expected results ordered by descending score"
);
Ok(())
}
}
@@ -126,24 +126,26 @@ impl TextChunkEmbedding {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::db::SurrealDbClient;
use surrealdb::Value as SurrealValue;
use uuid::Uuid;
/// Helper to create an in-memory DB and apply migrations
async fn setup_test_db() -> SurrealDbClient {
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to apply migrations");
.with_context(|| "Failed to apply migrations".to_string())?;
db
Ok(db)
}
/// Helper: create a text_chunk with a known key, return its RecordId
@@ -152,7 +154,7 @@ mod tests {
key: &str,
source_id: &str,
user_id: &str,
) -> RecordId {
) -> anyhow::Result<RecordId> {
let chunk = TextChunk {
id: key.to_owned(),
created_at: Utc::now(),
@@ -164,21 +166,42 @@ mod tests {
db.store_item(chunk)
.await
.expect("Failed to create text_chunk");
.with_context(|| "Failed to create text_chunk".to_string())?;
RecordId::from_table_key(TextChunk::table_name(), key)
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
}
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.with_context(|| "info query failed".to_string())?;
let info: SurrealValue = info_res.take(0).with_context(|| "failed to take info result".to_string())?;
let info_json: serde_json::Value =
serde_json::to_value(info).with_context(|| "failed to convert info to json".to_string())?;
let idx_sql = info_json
.get("Object")
.and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
.and_then(|v| v.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
Ok(idx_sql)
}
#[tokio::test]
async fn test_create_and_get_by_chunk_id() {
let db = setup_test_db().await;
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "user_a";
let chunk_key = "chunk-123";
let source_id = "source-1";
// 1) Create a text_chunk with a known key
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
// 2) Create and store an embedding for that chunk
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
@@ -191,39 +214,37 @@ mod tests {
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let _: Option<TextChunkEmbedding> = db
.client
.create(TextChunkEmbedding::table_name())
.content(emb)
.await
.expect("Failed to store embedding")
.take()
.expect("Failed to deserialize stored embedding");
.with_context(|| "Failed to store embedding".to_string())?
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
// 3) Fetch it via get_by_chunk_id
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to get embedding by chunk_id");
assert!(fetched.is_some(), "Expected an embedding to be found");
let fetched = fetched.unwrap();
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
.with_context(|| "Expected an embedding to be found".to_string())?;
assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.chunk_id, chunk_rid);
assert_eq!(fetched.embedding, embedding_vec);
Ok(())
}
#[tokio::test]
async fn test_delete_by_chunk_id() {
let db = setup_test_db().await;
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "user_b";
let chunk_key = "chunk-delete";
let source_id = "source-del";
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
let emb = TextChunkEmbedding::new(
chunk_key,
@@ -234,50 +255,50 @@ mod tests {
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
let _: Option<TextChunkEmbedding> = db
.client
.create(TextChunkEmbedding::table_name())
.content(emb)
.await
.expect("Failed to store embedding")
.take()
.expect("Failed to deserialize stored embedding");
.with_context(|| "Failed to store embedding".to_string())?
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
// Ensure it exists
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to get embedding before delete");
.with_context(|| "Failed to get embedding before delete".to_string())?;
assert!(existing.is_some(), "Embedding should exist before delete");
// Delete by chunk_id
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to delete by chunk_id");
.with_context(|| "Failed to delete by chunk_id".to_string())?;
// Ensure it no longer exists
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to get embedding after delete");
.with_context(|| "Failed to get embedding after delete".to_string())?;
assert!(after.is_none(), "Embedding should have been deleted");
Ok(())
}
#[tokio::test]
async fn test_delete_by_source_id() {
let db = setup_test_db().await;
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "user_c";
let source_id = "shared-source";
let other_source = "other-source";
// Two chunks with the same source_id
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?;
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?;
// One chunk with a different source_id
let chunk_other_rid =
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?;
// Create embeddings for all three
let emb1 = TextChunkEmbedding::new(
@@ -302,7 +323,7 @@ mod tests {
// Update length on index
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
.await
.expect("Failed to redefine index length");
.with_context(|| "Failed to redefine index length".to_string())?;
for emb in [emb1, emb2, emb3] {
let _: Option<TextChunkEmbedding> = db
@@ -310,102 +331,82 @@ mod tests {
.create(TextChunkEmbedding::table_name())
.content(emb)
.await
.expect("Failed to store embedding")
.take()
.expect("Failed to deserialize stored embedding");
.with_context(|| "Failed to store embedding".to_string())?
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
}
// Sanity check: they all exist
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
.await
.unwrap()
.with_context(|| "get chunk1".to_string())?
.is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
.await
.unwrap()
.with_context(|| "get chunk2".to_string())?
.is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
.await
.unwrap()
.with_context(|| "get chunk_other".to_string())?
.is_some());
// Delete embeddings by source_id (shared-source)
TextChunkEmbedding::delete_by_source_id(source_id, &db)
.await
.expect("Failed to delete by source_id");
.with_context(|| "Failed to delete by source_id".to_string())?;
// Chunks from shared-source should have no embeddings
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
.await
.unwrap()
.with_context(|| "check chunk1".to_string())?
.is_none());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
.await
.unwrap()
.with_context(|| "check chunk2".to_string())?
.is_none());
// The other chunk should still have its embedding
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
.await
.unwrap()
.with_context(|| "check chunk_other".to_string())?
.is_some());
Ok(())
}
#[tokio::test]
async fn test_redefine_hnsw_index_updates_dimension() {
let db = setup_test_db().await;
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
let db = setup_test_db().await?;
// Change the index dimension from default (1536) to a smaller test value.
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
.await
.expect("failed to redefine index");
.with_context(|| "failed to redefine index".to_string())?;
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_text_chunk_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
let idx_sql = get_idx_sql(&db).await?;
assert!(
idx_sql.contains("DIMENSION 8"),
"expected index definition to contain new dimension, got: {idx_sql}"
);
Ok(())
}
#[tokio::test]
async fn test_redefine_hnsw_index_is_idempotent() {
let db = setup_test_db().await;
async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> {
let db = setup_test_db().await?;
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
.await
.expect("first redefine failed");
.with_context(|| "first redefine failed".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
.await
.expect("second redefine failed");
.with_context(|| "second redefine failed".to_string())?;
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_text_chunk_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
let idx_sql = get_idx_sql(&db).await?;
assert!(
idx_sql.contains("DIMENSION 4"),
"expected index definition to retain dimension 4, got: {idx_sql}"
);
Ok(())
}
}
+25 -21
View File
@@ -185,10 +185,12 @@ impl TextContent {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
#[tokio::test]
async fn test_text_content_creation() {
async fn test_text_content_creation() -> anyhow::Result<()> {
// Test basic object creation
let text = "Test content text".to_string();
let context = "Test context".to_string();
@@ -212,10 +214,11 @@ mod tests {
assert!(text_content.file_info.is_none());
assert!(text_content.url_info.is_none());
assert!(!text_content.id.is_empty());
Ok(())
}
#[tokio::test]
async fn test_text_content_with_url() {
async fn test_text_content_with_url() -> anyhow::Result<()> {
// Test creating with URL
let text = "Content with URL".to_string();
let context = "URL context".to_string();
@@ -232,26 +235,27 @@ mod tests {
});
let text_content = TextContent::new(
text.clone(),
Some(context.clone()),
category.clone(),
text,
Some(context),
category,
None,
url_info.clone(),
user_id.clone(),
user_id,
);
// Check URL field is set
assert_eq!(text_content.url_info, url_info);
Ok(())
}
#[tokio::test]
async fn test_text_content_patch() {
async fn test_text_content_patch() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create initial text content
let initial_text = "Initial text".to_string();
@@ -272,7 +276,7 @@ mod tests {
let stored: Option<TextContent> = db
.store_item(text_content.clone())
.await
.expect("Failed to store text content");
.with_context(|| "Failed to store text content".to_string())?;
assert!(stored.is_some());
// New values for patch
@@ -283,31 +287,30 @@ mod tests {
// Apply the patch
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
.await
.expect("Failed to patch text content");
.with_context(|| "Failed to patch text content".to_string())?;
// Retrieve the updated content
let updated: Option<TextContent> = db
.get_item(&text_content.id)
.await
.expect("Failed to get updated text content");
assert!(updated.is_some());
let updated_content = updated.unwrap();
.with_context(|| "Failed to get updated text content".to_string())?;
let updated_content = updated.with_context(|| "expected updated content".to_string())?;
// Verify the updates
assert_eq!(updated_content.context, Some(new_context.to_string()));
assert_eq!(updated_content.category, new_category);
assert_eq!(updated_content.text, new_text);
assert!(updated_content.updated_at > text_content.updated_at);
Ok(())
}
#[tokio::test]
async fn test_has_other_with_file_detects_shared_usage() {
async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let user_id = "user123".to_string();
let file_info = FileInfo {
@@ -340,24 +343,25 @@ mod tests {
db.store_item(content_a.clone())
.await
.expect("Failed to store first content");
.with_context(|| "Failed to store first content".to_string())?;
db.store_item(content_b.clone())
.await
.expect("Failed to store second content");
.with_context(|| "Failed to store second content".to_string())?;
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
.await
.expect("Failed to check for shared file usage");
.with_context(|| "Failed to check for shared file usage".to_string())?;
assert!(has_other);
let _removed: Option<TextContent> = db
.delete_item(&content_b.id)
.await
.expect("Failed to delete second content");
.with_context(|| "Failed to delete second content".to_string())?;
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
.await
.expect("Failed to check shared usage after delete");
.with_context(|| "Failed to check shared usage after delete".to_string())?;
assert!(!has_other_after);
Ok(())
}
}
+108 -95
View File
@@ -723,30 +723,32 @@ impl User {
#[cfg(test)]
mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::types::ingestion_payload::IngestionPayload;
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
use std::collections::HashSet;
// Helper function to set up a test database with SystemSettings
async fn setup_test_db() -> SurrealDbClient {
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("Failed to start in-memory surrealdb");
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.expect("Failed to setup the migrations");
.with_context(|| "Failed to setup the migrations".to_string())?;
db
Ok(db)
}
#[tokio::test]
async fn test_user_creation() {
async fn test_user_creation() -> anyhow::Result<()> {
// Setup test database
let db = setup_test_db().await;
let db = setup_test_db().await?;
// Create a user
let email = "test@example.com";
@@ -761,7 +763,7 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
// Verify user properties
assert!(!user.id.is_empty());
@@ -774,18 +776,17 @@ mod tests {
let retrieved: Option<User> = db
.get_item(&user.id)
.await
.expect("Failed to retrieve user");
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
.with_context(|| "Failed to retrieve user".to_string())?;
let retrieved = retrieved.with_context(|| "expected user to exist".to_string())?;
assert_eq!(retrieved.id, user.id);
assert_eq!(retrieved.email, email);
Ok(())
}
#[tokio::test]
async fn test_user_authentication() {
async fn test_user_authentication() -> anyhow::Result<()> {
// Setup test database
let db = setup_test_db().await;
let db = setup_test_db().await?;
// Create a user
let email = "auth_test@example.com";
@@ -799,7 +800,7 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
// Test successful authentication
let auth_result = User::authenticate(email, password, &db).await;
@@ -812,11 +813,12 @@ mod tests {
// Test failed authentication with non-existent user
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
assert!(nonexistent.is_err());
Ok(())
}
#[tokio::test]
async fn test_get_unfinished_ingestion_tasks_filters_correctly() {
let db = setup_test_db().await;
async fn test_get_unfinished_ingestion_tasks_filters_correctly() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "unfinished_user";
let other_user_id = "other_user";
@@ -830,14 +832,14 @@ mod tests {
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(created_task.clone())
.await
.expect("Failed to store created task");
.with_context(|| "Failed to store created task".to_string())?;
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
processing_task.state = TaskState::Processing;
processing_task.attempts = 1;
db.store_item(processing_task.clone())
.await
.expect("Failed to store processing task");
.with_context(|| "Failed to store processing task".to_string())?;
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
failed_retry_task.state = TaskState::Failed;
@@ -845,7 +847,7 @@ mod tests {
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
db.store_item(failed_retry_task.clone())
.await
.expect("Failed to store retryable failed task");
.with_context(|| "Failed to store retryable failed task".to_string())?;
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
failed_blocked_task.state = TaskState::Failed;
@@ -853,13 +855,13 @@ mod tests {
failed_blocked_task.error_message = Some("Too many failures".into());
db.store_item(failed_blocked_task.clone())
.await
.expect("Failed to store blocked task");
.with_context(|| "Failed to store blocked task".to_string())?;
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
completed_task.state = TaskState::Succeeded;
db.store_item(completed_task.clone())
.await
.expect("Failed to store completed task");
.with_context(|| "Failed to store completed task".to_string())?;
let other_payload = IngestionPayload::Text {
text: "Other".to_string(),
@@ -870,11 +872,11 @@ mod tests {
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
db.store_item(other_task)
.await
.expect("Failed to store other user task");
.with_context(|| "Failed to store other user task".to_string())?;
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
.await
.expect("Failed to fetch unfinished tasks");
.with_context(|| "Failed to fetch unfinished tasks".to_string())?;
let unfinished_ids: HashSet<String> =
unfinished.iter().map(|task| task.id.clone()).collect();
@@ -885,11 +887,12 @@ mod tests {
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
assert!(!unfinished_ids.contains(&completed_task.id));
assert_eq!(unfinished_ids.len(), 3);
Ok(())
}
#[tokio::test]
async fn test_get_all_ingestion_tasks_returns_sorted() {
let db = setup_test_db().await;
async fn test_get_all_ingestion_tasks_returns_sorted() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "archive_user";
let other_user_id = "other_user";
@@ -902,15 +905,15 @@ mod tests {
// Oldest task
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
first.created_at = first.created_at - chrono::Duration::minutes(1);
first.created_at -= chrono::Duration::minutes(1);
first.updated_at = first.created_at;
first.state = TaskState::Succeeded;
db.store_item(first.clone()).await.expect("store first");
db.store_item(first.clone()).await.with_context(|| "store first".to_string())?;
// Latest task
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
second.state = TaskState::Processing;
db.store_item(second.clone()).await.expect("store second");
db.store_item(second.clone()).await.with_context(|| "store second".to_string())?;
let other_payload = IngestionPayload::Text {
text: "Other".to_string(),
@@ -919,21 +922,22 @@ mod tests {
user_id: other_user_id.to_string(),
};
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
db.store_item(other_task).await.expect("store other");
db.store_item(other_task).await.with_context(|| "store other".to_string())?;
let tasks = User::get_all_ingestion_tasks(user_id, &db)
.await
.expect("fetch all tasks");
.with_context(|| "fetch all tasks".to_string())?;
assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0].id, second.id); // newest first
assert_eq!(tasks[1].id, first.id);
assert_eq!(tasks.first().map(|t| &t.id), Some(&second.id)); // newest first
assert_eq!(tasks.get(1).map(|t| &t.id), Some(&first.id));
Ok(())
}
#[tokio::test]
async fn test_find_by_email() {
async fn test_find_by_email() -> anyhow::Result<()> {
// Setup test database
let db = setup_test_db().await;
let db = setup_test_db().await?;
// Create a user
let email = "find_test@example.com";
@@ -947,28 +951,28 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
// Test finding user by email
let found_user = User::find_by_email(email, &db)
.await
.expect("Error searching for user");
assert!(found_user.is_some());
let found_user = found_user.unwrap();
.with_context(|| "Error searching for user".to_string())?
.with_context(|| "expected user to exist".to_string())?;
assert_eq!(found_user.id, created_user.id);
assert_eq!(found_user.email, email);
// Test finding non-existent user
let not_found = User::find_by_email("nonexistent@example.com", &db)
.await
.expect("Error searching for user");
.with_context(|| "Error searching for user".to_string())?;
assert!(not_found.is_none());
Ok(())
}
#[tokio::test]
async fn test_api_key_management() {
async fn test_api_key_management() -> anyhow::Result<()> {
// Setup test database
let db = setup_test_db().await;
let db = setup_test_db().await?;
// Create a user
let email = "apikey_test@example.com";
@@ -982,7 +986,7 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
// Initially, user should have no API key
assert!(user.api_key.is_none());
@@ -990,7 +994,7 @@ mod tests {
// Generate API key
let api_key = User::set_api_key(&user.id, &db)
.await
.expect("Failed to set API key");
.with_context(|| "Failed to set API key".to_string())?;
assert!(!api_key.is_empty());
assert!(api_key.starts_with("sk_"));
@@ -998,38 +1002,36 @@ mod tests {
let updated_user: Option<User> = db
.get_item(&user.id)
.await
.expect("Failed to retrieve user");
assert!(updated_user.is_some());
let updated_user = updated_user.unwrap();
.with_context(|| "Failed to retrieve user".to_string())?;
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
assert_eq!(updated_user.api_key, Some(api_key.clone()));
// Test finding user by API key
let found_user = User::find_by_api_key(&api_key, &db)
.await
.expect("Error searching by API key");
assert!(found_user.is_some());
let found_user = found_user.unwrap();
.with_context(|| "Error searching by API key".to_string())?
.with_context(|| "expected user found by api key".to_string())?;
assert_eq!(found_user.id, user.id);
// Revoke API key
User::revoke_api_key(&user.id, &db)
.await
.expect("Failed to revoke API key");
.with_context(|| "Failed to revoke API key".to_string())?;
// Verify API key was revoked
let revoked_user: Option<User> = db
.get_item(&user.id)
.await
.expect("Failed to retrieve user");
assert!(revoked_user.is_some());
let revoked_user = revoked_user.unwrap();
.with_context(|| "Failed to retrieve user".to_string())?;
let revoked_user = revoked_user.with_context(|| "expected revoked user".to_string())?;
assert!(revoked_user.api_key.is_none());
// Test searching by revoked API key
let not_found = User::find_by_api_key(&api_key, &db)
.await
.expect("Error searching by API key");
.with_context(|| "Error searching by API key".to_string())?;
assert!(not_found.is_none());
Ok(())
}
#[tokio::test]
@@ -1069,9 +1071,9 @@ mod tests {
}
#[tokio::test]
async fn test_password_update() {
async fn test_password_update() -> anyhow::Result<()> {
// Setup test database
let db = setup_test_db().await;
let db = setup_test_db().await?;
// Create a user
let email = "pwd_test@example.com";
@@ -1086,7 +1088,7 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
// Authenticate with old password
let auth_result = User::authenticate(email, old_password, &db).await;
@@ -1095,7 +1097,7 @@ mod tests {
// Update password
User::patch_password(email, new_password, &db)
.await
.expect("Failed to update password");
.with_context(|| "Failed to update password".to_string())?;
// Old password should no longer work
let old_auth = User::authenticate(email, old_password, &db).await;
@@ -1104,10 +1106,11 @@ mod tests {
// New password should work
let new_auth = User::authenticate(email, new_password, &db).await;
assert!(new_auth.is_ok());
Ok(())
}
#[tokio::test]
async fn test_validate_timezone() {
async fn test_validate_timezone() -> anyhow::Result<()> {
// Valid timezones should be accepted as-is
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
@@ -1117,12 +1120,13 @@ mod tests {
// Invalid timezones should be replaced with UTC
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
assert_eq!(validate_timezone("Not_Real"), "UTC");
Ok(())
}
#[tokio::test]
async fn test_timezone_update() {
async fn test_timezone_update() -> anyhow::Result<()> {
// Setup test database
let db = setup_test_db().await;
let db = setup_test_db().await?;
// Create user with default timezone
let email = "timezone_test@example.com";
@@ -1134,7 +1138,7 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
assert_eq!(user.timezone, "UTC");
@@ -1142,58 +1146,61 @@ mod tests {
let new_timezone = "Europe/Paris";
User::update_timezone(&user.id, new_timezone, &db)
.await
.expect("Failed to update timezone");
.with_context(|| "Failed to update timezone".to_string())?;
// Verify timezone was updated
let updated_user: Option<User> = db
.get_item(&user.id)
.await
.expect("Failed to retrieve user");
assert!(updated_user.is_some());
let updated_user = updated_user.unwrap();
.with_context(|| "Failed to retrieve user".to_string())?;
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
assert_eq!(updated_user.timezone, new_timezone);
Ok(())
}
#[tokio::test]
async fn test_conversations_order() {
let db = setup_test_db().await;
async fn test_conversations_order() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "user_order_test";
// Create conversations with varying updated_at timestamps
let mut conversations = Vec::new();
for i in 0..5 {
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {}", i));
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {i}"));
// Fake updated_at i minutes apart
conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i);
db.store_item(conv.clone())
.await
.expect("Failed to store conversation");
.with_context(|| "Failed to store conversation".to_string())?;
conversations.push(conv);
}
// Retrieve via get_user_conversations - should be ordered by updated_at DESC
let retrieved = User::get_user_conversations(user_id, &db)
.await
.expect("Failed to get conversations");
.with_context(|| "Failed to get conversations".to_string())?;
assert_eq!(retrieved.len(), conversations.len());
for window in retrieved.windows(2) {
// Assert each earlier conversation has updated_at >= later conversation
for pair in retrieved.windows(2) {
let a = pair.first().context("expected first in pair")?;
let b = pair.get(1).context("expected second in pair")?;
assert!(
window[0].created_at >= window[1].created_at,
a.created_at >= b.created_at,
"Conversations not ordered descending by created_at"
);
}
// Check first conversation title matches the most recently updated
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap();
assert_eq!(retrieved[0].id, most_recent.id);
let most_recent = conversations.iter().max_by_key(|c| c.created_at).context("expected most recent")?;
let r0 = retrieved.first().context("expected first result")?;
assert_eq!(r0.id, most_recent.id);
Ok(())
}
#[tokio::test]
async fn test_get_latest_text_contents_returns_last_five() {
let db = setup_test_db().await;
async fn test_get_latest_text_contents_returns_last_five() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "latest_text_user";
let mut inserted_ids = Vec::new();
@@ -1201,8 +1208,8 @@ mod tests {
for i in 0..12 {
let mut item = TextContent::new(
format!("Text {}", i),
Some(format!("Context {}", i)),
format!("Text {i}"),
Some(format!("Context {i}")),
"Category".to_string(),
None,
None,
@@ -1215,18 +1222,19 @@ mod tests {
db.store_item(item.clone())
.await
.expect("Failed to store text content");
.with_context(|| "Failed to store text content".to_string())?;
inserted_ids.push(item.id.clone());
}
let latest = User::get_latest_text_contents(user_id, &db)
.await
.expect("Failed to fetch latest text contents");
.with_context(|| "Failed to fetch latest text contents".to_string())?;
assert_eq!(latest.len(), 5, "Expected exactly five items");
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec();
let start = inserted_ids.len().saturating_sub(5);
let mut expected_ids = inserted_ids.get(start..).unwrap_or_default().to_vec();
expected_ids.reverse();
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
@@ -1235,25 +1243,29 @@ mod tests {
"Latest items did not match expectation"
);
for window in latest.windows(2) {
for pair in latest.windows(2) {
let a = pair.first().context("expected first in pair")?;
let b = pair.get(1).context("expected second in pair")?;
assert!(
window[0].created_at >= window[1].created_at,
a.created_at >= b.created_at,
"Results are not ordered by created_at descending"
);
}
Ok(())
}
#[tokio::test]
async fn test_validate_theme() {
async fn test_validate_theme() -> anyhow::Result<()> {
assert_eq!(validate_theme("light"), Theme::Light);
assert_eq!(validate_theme("dark"), Theme::Dark);
assert_eq!(validate_theme("system"), Theme::System);
assert_eq!(validate_theme("invalid"), Theme::System);
Ok(())
}
#[tokio::test]
async fn test_theme_update() {
let db = setup_test_db().await;
async fn test_theme_update() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let email = "theme_test@example.com";
let user = User::create_new(
email.to_string(),
@@ -1263,30 +1275,31 @@ mod tests {
"system".to_string(),
)
.await
.expect("Failed to create user");
.with_context(|| "Failed to create user".to_string())?;
assert_eq!(user.theme, Theme::System);
User::update_theme(&user.id, "dark", &db)
.await
.expect("update theme");
.with_context(|| "update theme".to_string())?;
let updated = db
.get_item::<User>(&user.id)
.await
.expect("get user")
.unwrap();
.with_context(|| "get user".to_string())?
.with_context(|| "expected user".to_string())?;
assert_eq!(updated.theme, Theme::Dark);
// Invalid theme should default to system (but update_theme calls validate_theme)
User::update_theme(&user.id, "invalid", &db)
.await
.expect("update theme invalid");
.with_context(|| "update theme invalid".to_string())?;
let updated2 = db
.get_item::<User>(&user.id)
.await
.expect("get user")
.unwrap();
.with_context(|| "get user".to_string())?
.with_context(|| "expected user".to_string())?;
assert_eq!(updated2.theme, Theme::System);
Ok(())
}
}