mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-28 10:29:30 +02:00
clippy: adhere to pedantic clippy, uniform test error handling
This commit is contained in:
+35
-28
@@ -202,6 +202,7 @@ impl SurrealDbClient {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use crate::stored_object;
|
||||
|
||||
use super::*;
|
||||
@@ -212,19 +213,17 @@ mod tests {
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialization_and_crud() {
|
||||
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Call your initialization
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
|
||||
// Test basic CRUD
|
||||
let dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
name: "first".to_string(),
|
||||
@@ -232,50 +231,50 @@ mod tests {
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
// Store
|
||||
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
|
||||
let stored = db
|
||||
.store_item(dummy.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store".to_string())?;
|
||||
assert!(stored.is_some());
|
||||
|
||||
// Read
|
||||
let fetched = db
|
||||
.get_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to fetch");
|
||||
.with_context(|| "Failed to fetch".to_string())?;
|
||||
assert_eq!(fetched, Some(dummy.clone()));
|
||||
|
||||
// Read all
|
||||
let all = db
|
||||
.get_all_stored_items::<Dummy>()
|
||||
.await
|
||||
.expect("Failed to fetch all");
|
||||
.with_context(|| "Failed to fetch all".to_string())?;
|
||||
assert!(all.contains(&dummy));
|
||||
|
||||
// Delete
|
||||
let deleted = db
|
||||
.delete_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to delete");
|
||||
.with_context(|| "Failed to delete".to_string())?;
|
||||
assert_eq!(deleted, Some(dummy));
|
||||
|
||||
// After delete, should not be present
|
||||
let fetch_post = db
|
||||
.get_item::<Dummy>("abc")
|
||||
.await
|
||||
.expect("Failed fetch post delete");
|
||||
.with_context(|| "Failed fetch post delete".to_string())?;
|
||||
assert!(fetch_post.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_item_overwrites_existing_records() {
|
||||
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
|
||||
let mut dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
@@ -286,17 +285,21 @@ mod tests {
|
||||
|
||||
db.store_item(dummy.clone())
|
||||
.await
|
||||
.expect("Failed to store initial record");
|
||||
.with_context(|| "Failed to store initial record".to_string())?;
|
||||
|
||||
dummy.name = "updated".to_string();
|
||||
let upserted = db
|
||||
.upsert_item(dummy.clone())
|
||||
.await
|
||||
.expect("Failed to upsert record");
|
||||
.with_context(|| "Failed to upsert record".to_string())?;
|
||||
assert!(upserted.is_some());
|
||||
|
||||
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert");
|
||||
assert_eq!(fetched.unwrap().name, "updated");
|
||||
let fetched: Option<Dummy> = db
|
||||
.get_item(&dummy.id)
|
||||
.await
|
||||
.with_context(|| "fetch after upsert".to_string())?;
|
||||
let fetched = fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?;
|
||||
assert_eq!(fetched.name, "updated");
|
||||
|
||||
let new_record = Dummy {
|
||||
id: "def".to_string(),
|
||||
@@ -306,25 +309,29 @@ mod tests {
|
||||
};
|
||||
db.upsert_item(new_record.clone())
|
||||
.await
|
||||
.expect("Failed to upsert new record");
|
||||
.with_context(|| "Failed to upsert new record".to_string())?;
|
||||
|
||||
let fetched_new: Option<Dummy> = db
|
||||
.get_item(&new_record.id)
|
||||
.await
|
||||
.expect("fetch inserted via upsert");
|
||||
.with_context(|| "fetch inserted via upsert".to_string())?;
|
||||
assert_eq!(fetched_new, Some(new_record));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_applying_migrations() {
|
||||
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to build indexes");
|
||||
.with_context(|| "Failed to build indexes".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,23 +159,23 @@ impl FtsIndexSpec {
|
||||
|
||||
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
||||
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
||||
pub async fn ensure_runtime_indexes(
|
||||
pub async fn ensure_runtime(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
ensure_runtime_indexes_inner(db, embedding_dimension)
|
||||
ensure_runtime_inner(db, embedding_dimension)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_indexes_inner(db)
|
||||
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
async fn ensure_runtime_indexes_inner(
|
||||
async fn ensure_runtime_inner(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<()> {
|
||||
@@ -262,9 +262,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
||||
.context("checking index status")?;
|
||||
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
|
||||
|
||||
let info = match info {
|
||||
Some(i) => i,
|
||||
None => return Ok("unknown".to_string()),
|
||||
let Some(info) = info else {
|
||||
return Ok("unknown".to_string());
|
||||
};
|
||||
|
||||
let building = info.get("building");
|
||||
@@ -277,7 +276,7 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
debug!("Rebuilding indexes with concurrent definitions");
|
||||
create_fts_analyzer(db).await?;
|
||||
|
||||
@@ -385,10 +384,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
||||
// an existing analyzer definition.
|
||||
let snowball_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
FILTERS lowercase, ascii, snowball(english);"
|
||||
);
|
||||
|
||||
match db.client.query(snowball_query).await {
|
||||
@@ -410,10 +408,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||
}
|
||||
|
||||
let fallback_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii;",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
FILTERS lowercase, ascii;"
|
||||
);
|
||||
|
||||
let res = db
|
||||
@@ -446,6 +443,7 @@ async fn create_index_with_polling(
|
||||
table: &str,
|
||||
progress_table: Option<&str>,
|
||||
) -> Result<()> {
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
let expected_total = match progress_table {
|
||||
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
||||
format!("counting rows in {table} for index {index_name} progress")
|
||||
@@ -453,10 +451,9 @@ async fn create_index_with_polling(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
let mut attempts: usize = 0;
|
||||
loop {
|
||||
attempts += 1;
|
||||
attempts = attempts.saturating_add(1);
|
||||
let res = db
|
||||
.client
|
||||
.query(definition.clone())
|
||||
@@ -527,8 +524,8 @@ async fn poll_index_build_status(
|
||||
break;
|
||||
};
|
||||
|
||||
match snapshot.progress_pct {
|
||||
Some(pct) => debug!(
|
||||
if let Some(pct) = snapshot.progress_pct {
|
||||
debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
@@ -539,8 +536,9 @@ async fn poll_index_build_status(
|
||||
total = snapshot.total_rows,
|
||||
progress_pct = format_args!("{pct:.1}"),
|
||||
"Index build status"
|
||||
),
|
||||
None => debug!(
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
@@ -549,7 +547,7 @@ async fn poll_index_build_status(
|
||||
updated = snapshot.updated,
|
||||
processed = snapshot.processed,
|
||||
"Index build status"
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if snapshot.is_ready() {
|
||||
@@ -611,17 +609,17 @@ fn parse_index_build_info(
|
||||
|
||||
let initial = building
|
||||
.and_then(|b| b.get("initial"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let pending = building
|
||||
.and_then(|b| b.get("pending"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let updated = building
|
||||
.and_then(|b| b.get("updated"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
|
||||
@@ -631,7 +629,7 @@ fn parse_index_build_info(
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
((processed as f64 / total as f64).min(1.0)) * 100.0
|
||||
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX)) / f64::from(u32::try_from(total).unwrap_or(1))).min(1.0)) * 100.0
|
||||
}
|
||||
});
|
||||
|
||||
@@ -673,7 +671,7 @@ async fn table_index_definitions(
|
||||
.client
|
||||
.query(info_query)
|
||||
.await
|
||||
.with_context(|| format!("fetching table info for {}", table))?;
|
||||
.with_context(|| format!("fetching table info for {table}"))?;
|
||||
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
@@ -700,12 +698,15 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self, Context};
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_reports_progress() {
|
||||
fn parse_index_build_info_reports_progress() -> anyhow::Result<()> {
|
||||
let info = json!({
|
||||
"building": {
|
||||
"initial": 56894,
|
||||
@@ -715,7 +716,8 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
|
||||
let snapshot = parse_index_build_info(Some(info), Some(61081))
|
||||
.context("snapshot")?;
|
||||
assert_eq!(
|
||||
snapshot,
|
||||
IndexBuildSnapshot {
|
||||
@@ -729,16 +731,19 @@ mod tests {
|
||||
}
|
||||
);
|
||||
assert!(!snapshot.is_ready());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
|
||||
fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> {
|
||||
// Surreal returns `{}` when the index exists but isn't building.
|
||||
let info = json!({});
|
||||
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
|
||||
let snapshot = parse_index_build_info(Some(info), Some(10))
|
||||
.context("snapshot")?;
|
||||
assert!(snapshot.is_ready());
|
||||
assert_eq!(snapshot.processed, 0);
|
||||
assert_eq!(snapshot.progress_pct, Some(0.0));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -748,48 +753,40 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_runtime_indexes_is_idempotent() {
|
||||
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db");
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should succeed");
|
||||
.context("migrations should succeed")?;
|
||||
|
||||
// First run creates everything
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("initial index creation");
|
||||
|
||||
// Second run should be a no-op and still succeed
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("second index creation");
|
||||
ensure_runtime(&db, 1536).await
|
||||
.context("first call should succeed")?;
|
||||
ensure_runtime(&db, 1536).await
|
||||
.context("second index creation")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_hnsw_index_overwrites_dimension() {
|
||||
async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_dim";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db");
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should succeed");
|
||||
.context("migrations should succeed")?;
|
||||
|
||||
// Create initial index with default dimension
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("initial index creation");
|
||||
|
||||
// Change dimension and ensure overwrite path is exercised
|
||||
ensure_runtime_indexes(&db, 128)
|
||||
.await
|
||||
.expect("overwritten index creation");
|
||||
ensure_runtime(&db, 1536).await
|
||||
.context("initial index creation")?;
|
||||
ensure_runtime(&db, 128).await
|
||||
.context("overwritten index creation")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+142
-108
@@ -13,13 +13,13 @@ use object_store::{path::Path as ObjPath, ObjectStore};
|
||||
|
||||
use crate::utils::config::{AppConfig, StorageKind};
|
||||
|
||||
pub type DynStore = Arc<dyn ObjectStore>;
|
||||
pub type DynStorage = Arc<dyn ObjectStore>;
|
||||
|
||||
/// Storage manager with persistent state and proper lifecycle management.
|
||||
#[derive(Clone)]
|
||||
pub struct StorageManager {
|
||||
// Store from objectstore wrapped as dyn
|
||||
store: DynStore,
|
||||
store: DynStorage,
|
||||
// Simple enum to track which kind
|
||||
backend_kind: StorageKind,
|
||||
// Where on disk
|
||||
@@ -46,7 +46,7 @@ impl StorageManager {
|
||||
///
|
||||
/// This method is useful for testing scenarios where you want to inject
|
||||
/// a specific storage backend.
|
||||
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
|
||||
pub fn with_backend(store: DynStorage, backend_kind: StorageKind) -> Self {
|
||||
Self {
|
||||
store,
|
||||
backend_kind,
|
||||
@@ -216,7 +216,7 @@ impl StorageManager {
|
||||
/// storage backends with proper error handling and validation.
|
||||
async fn create_storage_backend(
|
||||
cfg: &AppConfig,
|
||||
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
|
||||
) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
|
||||
match cfg.storage {
|
||||
StorageKind::Local => {
|
||||
let base = resolve_base_dir(cfg);
|
||||
@@ -261,9 +261,7 @@ async fn create_storage_backend(
|
||||
builder = builder.with_endpoint(endpoint);
|
||||
}
|
||||
|
||||
if let Some(region) = &cfg.s3_region {
|
||||
builder = builder.with_region(region);
|
||||
}
|
||||
builder = builder.with_region(&cfg.s3_region);
|
||||
|
||||
let store = builder.build()?;
|
||||
Ok((Arc::new(store), None))
|
||||
@@ -342,7 +340,7 @@ pub mod testing {
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: base.into(),
|
||||
data_dir: base,
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
@@ -382,7 +380,7 @@ pub mod testing {
|
||||
#[derive(Clone)]
|
||||
pub struct TestStorageManager {
|
||||
storage: StorageManager,
|
||||
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||
temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||
}
|
||||
|
||||
impl TestStorageManager {
|
||||
@@ -396,7 +394,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
temp_dir: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -413,7 +411,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: resolved,
|
||||
temp_dir: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -437,7 +435,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
temp_dir: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -454,7 +452,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: temp_dir,
|
||||
temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -508,7 +506,7 @@ pub mod testing {
|
||||
impl Drop for TestStorageManager {
|
||||
fn drop(&mut self) {
|
||||
// Clean up temporary directories for local storage
|
||||
if let Some((_, path)) = &self._temp_dir {
|
||||
if let Some((_, path)) = &self.temp_dir {
|
||||
if path.exists() {
|
||||
let _ = std::fs::remove_dir_all(path);
|
||||
}
|
||||
@@ -584,6 +582,7 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Context;
|
||||
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
||||
use bytes::Bytes;
|
||||
use uuid::Uuid;
|
||||
@@ -623,11 +622,11 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_basic_operations() {
|
||||
async fn test_storage_manager_memory_basic_operations() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
assert!(storage.local_base_path().is_none());
|
||||
|
||||
let location = "test/data/file.txt";
|
||||
@@ -637,31 +636,33 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
.with_context(|| "exists check after delete".to_string())?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_local_basic_operations() {
|
||||
async fn test_storage_manager_local_basic_operations() -> anyhow::Result<()> {
|
||||
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
|
||||
let cfg = test_config(&base);
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
let resolved_base = storage
|
||||
.local_base_path()
|
||||
.expect("resolved base dir")
|
||||
.with_context(|| "resolved base dir".to_string())?
|
||||
.to_path_buf();
|
||||
assert_eq!(resolved_base, PathBuf::from(&base));
|
||||
|
||||
@@ -672,42 +673,44 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
let object_dir = resolved_base.join("test/data");
|
||||
tokio::fs::metadata(&object_dir)
|
||||
.await
|
||||
.expect("object directory exists after write");
|
||||
.with_context(|| "object directory exists after write".to_string())?;
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
.with_context(|| "exists check after delete".to_string())?);
|
||||
assert!(
|
||||
tokio::fs::metadata(&object_dir).await.is_err(),
|
||||
"object directory should be removed"
|
||||
);
|
||||
tokio::fs::metadata(&resolved_base)
|
||||
.await
|
||||
.expect("base directory remains intact");
|
||||
.with_context(|| "base directory remains intact".to_string())?;
|
||||
|
||||
// Clean up
|
||||
let _ = tokio::fs::remove_dir_all(&base).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_persistence() {
|
||||
async fn test_storage_manager_memory_persistence() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
let location = "persistence/test.txt";
|
||||
let data1 = b"first data";
|
||||
@@ -717,32 +720,34 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data1.to_vec()))
|
||||
.await
|
||||
.expect("put first");
|
||||
.with_context(|| "put first".to_string())?;
|
||||
|
||||
// Retrieve and verify first data
|
||||
let retrieved1 = storage.get(location).await.expect("get first");
|
||||
let retrieved1 = storage.get(location).await.with_context(|| "get first".to_string())?;
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
|
||||
// Overwrite with second data
|
||||
storage
|
||||
.put(location, Bytes::from(data2.to_vec()))
|
||||
.await
|
||||
.expect("put second");
|
||||
.with_context(|| "put second".to_string())?;
|
||||
|
||||
// Retrieve and verify second data
|
||||
let retrieved2 = storage.get(location).await.expect("get second");
|
||||
let retrieved2 = storage.get(location).await.with_context(|| "get second".to_string())?;
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
|
||||
// Data persists across multiple operations using the same StorageManager
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_list_operations() {
|
||||
async fn test_storage_manager_list_operations() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
// Create multiple files
|
||||
let files = vec![
|
||||
@@ -755,15 +760,15 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
.with_context(|| "put".to_string())?;
|
||||
}
|
||||
|
||||
// Test listing without prefix
|
||||
let all_files = storage.list(None).await.expect("list all");
|
||||
let all_files = storage.list(None).await.with_context(|| "list all".to_string())?;
|
||||
assert_eq!(all_files.len(), 3);
|
||||
|
||||
// Test listing with prefix
|
||||
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
|
||||
let dir1_files = storage.list(Some("dir1/")).await.with_context(|| "list dir1".to_string())?;
|
||||
assert_eq!(dir1_files.len(), 2);
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
@@ -776,16 +781,18 @@ mod tests {
|
||||
let empty_files = storage
|
||||
.list(Some("nonexistent/"))
|
||||
.await
|
||||
.expect("list nonexistent");
|
||||
.with_context(|| "list nonexistent".to_string())?;
|
||||
assert_eq!(empty_files.len(), 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_stream_operations() {
|
||||
async fn test_storage_manager_stream_operations() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
let location = "stream/test.bin";
|
||||
let content = vec![42u8; 1024 * 64]; // 64KB of data
|
||||
@@ -794,22 +801,24 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(content.clone()))
|
||||
.await
|
||||
.expect("put large data");
|
||||
.with_context(|| "put large data".to_string())?;
|
||||
|
||||
// Get as stream
|
||||
let mut stream = storage.get_stream(location).await.expect("get stream");
|
||||
let mut stream = storage.get_stream(location).await.with_context(|| "get stream".to_string())?;
|
||||
let mut collected = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.expect("stream chunk");
|
||||
let chunk = chunk.with_context(|| "stream chunk".to_string())?;
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
assert_eq!(collected, content);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_with_custom_backend() {
|
||||
async fn test_storage_manager_with_custom_backend() -> anyhow::Result<()> {
|
||||
use object_store::memory::InMemory;
|
||||
|
||||
// Create custom memory backend
|
||||
@@ -823,20 +832,22 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
assert!(storage.exists(location).await.expect("exists"));
|
||||
assert!(storage.exists(location).await.with_context(|| "exists".to_string())?);
|
||||
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_error_handling() {
|
||||
async fn test_storage_manager_error_handling() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
// Test getting non-existent file
|
||||
let result = storage.get("nonexistent.txt").await;
|
||||
@@ -846,124 +857,136 @@ mod tests {
|
||||
let exists = storage
|
||||
.exists("nonexistent.txt")
|
||||
.await
|
||||
.expect("exists check");
|
||||
.with_context(|| "exists check".to_string())?;
|
||||
assert!(!exists);
|
||||
|
||||
// Test listing with invalid location (should not panic)
|
||||
let _result = storage.get("").await;
|
||||
// This may or may not error depending on the backend implementation
|
||||
// The important thing is that it doesn't panic
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TestStorageManager tests
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_memory() {
|
||||
async fn test_test_storage_manager_memory() -> anyhow::Result<()> {
|
||||
let test_storage = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
let location = "test/storage/file.txt";
|
||||
let data = b"test data with TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
test_storage.put(location, data).await.with_context(|| "put".to_string())?;
|
||||
let retrieved = test_storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
assert!(test_storage.exists(location).await.with_context(|| "exists".to_string())?);
|
||||
|
||||
// Test list
|
||||
let files = test_storage
|
||||
.list(Some("test/storage/"))
|
||||
.await
|
||||
.expect("list");
|
||||
.with_context(|| "list".to_string())?;
|
||||
assert_eq!(files.len(), 1);
|
||||
|
||||
// Test delete
|
||||
test_storage
|
||||
.delete_prefix("test/storage/")
|
||||
.await
|
||||
.expect("delete");
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists after delete"));
|
||||
.with_context(|| "exists after delete".to_string())?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_local() {
|
||||
async fn test_test_storage_manager_local() -> anyhow::Result<()> {
|
||||
let test_storage = testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
let location = "test/local/file.txt";
|
||||
let data = b"test data with local TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
test_storage.put(location, data).await
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = test_storage.get(location).await
|
||||
.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
assert!(test_storage.exists(location).await
|
||||
.with_context(|| "exists".to_string())?);
|
||||
|
||||
// The storage should be automatically cleaned up when test_storage is dropped
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_isolation() {
|
||||
async fn test_test_storage_manager_isolation() -> anyhow::Result<()> {
|
||||
let storage1 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 1");
|
||||
.with_context(|| "create test storage 1".to_string())?;
|
||||
let storage2 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 2");
|
||||
.with_context(|| "create test storage 2".to_string())?;
|
||||
|
||||
let location = "isolation/test.txt";
|
||||
let data1 = b"storage 1 data";
|
||||
let data2 = b"storage 2 data";
|
||||
|
||||
// Put different data in each storage
|
||||
storage1.put(location, data1).await.expect("put storage 1");
|
||||
storage2.put(location, data2).await.expect("put storage 2");
|
||||
storage1.put(location, data1).await
|
||||
.with_context(|| "put storage 1".to_string())?;
|
||||
storage2.put(location, data2).await
|
||||
.with_context(|| "put storage 2".to_string())?;
|
||||
|
||||
// Verify isolation
|
||||
let retrieved1 = storage1.get(location).await.expect("get storage 1");
|
||||
let retrieved2 = storage2.get(location).await.expect("get storage 2");
|
||||
let retrieved1 = storage1.get(location).await
|
||||
.with_context(|| "get storage 1".to_string())?;
|
||||
let retrieved2 = storage2.get(location).await
|
||||
.with_context(|| "get storage 2".to_string())?;
|
||||
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_config() {
|
||||
async fn test_test_storage_manager_config() -> anyhow::Result<()> {
|
||||
let cfg = testing::test_config_memory();
|
||||
let test_storage = testing::TestStorageManager::with_config(&cfg)
|
||||
.await
|
||||
.expect("create test storage with config");
|
||||
.with_context(|| "create test storage with config".to_string())?;
|
||||
|
||||
let location = "config/test.txt";
|
||||
let data = b"test data with custom config";
|
||||
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
test_storage.put(location, data).await
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = test_storage.get(location).await
|
||||
.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Verify it's using memory backend
|
||||
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
|
||||
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_basic_operations() {
|
||||
async fn test_storage_manager_s3_basic_operations() -> anyhow::Result<()> {
|
||||
// Skip if S3 connection fails (e.g. no MinIO)
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
eprintln!("Skipping S3 test (setup failed)");
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let prefix = format!("test-basic-{}", Uuid::new_v4());
|
||||
@@ -973,31 +996,33 @@ mod tests {
|
||||
// Test put
|
||||
if let Err(e) = storage.put(&location, data).await {
|
||||
eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Test get
|
||||
let retrieved = storage.get(&location).await.expect("get");
|
||||
let retrieved = storage.get(&location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(&location).await.expect("exists"));
|
||||
assert!(storage.exists(&location).await.with_context(|| "exists".to_string())?);
|
||||
|
||||
// Test delete
|
||||
storage
|
||||
.delete_prefix(&format!("{prefix}/"))
|
||||
.await
|
||||
.expect("delete");
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(&location)
|
||||
.await
|
||||
.expect("exists after delete"));
|
||||
.with_context(|| "exists after delete".to_string())?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_list_operations() {
|
||||
async fn test_storage_manager_s3_list_operations() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let prefix = format!("test-list-{}", Uuid::new_v4());
|
||||
@@ -1009,23 +1034,25 @@ mod tests {
|
||||
|
||||
for (loc, data) in &files {
|
||||
if storage.put(loc, *data).await.is_err() {
|
||||
return; // Abort if put fails
|
||||
return Ok(()); // Abort if put fails
|
||||
}
|
||||
}
|
||||
|
||||
// List with prefix
|
||||
let list_prefix = format!("{prefix}/");
|
||||
let items = storage.list(Some(&list_prefix)).await.expect("list");
|
||||
let items = storage.list(Some(&list_prefix)).await.with_context(|| "list".to_string())?;
|
||||
assert_eq!(items.len(), 3);
|
||||
|
||||
// Cleanup
|
||||
storage.delete_prefix(&list_prefix).await.expect("cleanup");
|
||||
storage.delete_prefix(&list_prefix).await.with_context(|| "cleanup".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_stream_operations() {
|
||||
async fn test_storage_manager_s3_stream_operations() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let prefix = format!("test-stream-{}", Uuid::new_v4());
|
||||
@@ -1033,38 +1060,45 @@ mod tests {
|
||||
let content = vec![42u8; 1024 * 10]; // 10KB
|
||||
|
||||
if storage.put(&location, &content).await.is_err() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut stream = storage.get_stream(&location).await.expect("get stream");
|
||||
let mut stream = storage.get_stream(&location).await.with_context(|| "get stream".to_string())?;
|
||||
let mut collected = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
collected.extend_from_slice(&chunk.expect("chunk"));
|
||||
collected.extend_from_slice(&chunk.with_context(|| "chunk".to_string())?);
|
||||
}
|
||||
assert_eq!(collected, content);
|
||||
|
||||
storage
|
||||
.delete_prefix(&format!("{prefix}/"))
|
||||
.await
|
||||
.expect("cleanup");
|
||||
.with_context(|| "cleanup".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_backend_kind() {
|
||||
async fn test_storage_manager_s3_backend_kind() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_error_handling() {
|
||||
async fn test_storage_manager_s3_error_handling() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
|
||||
assert!(storage.get(&location).await.is_err());
|
||||
assert!(!storage.exists(&location).await.expect("exists check"));
|
||||
// exists may fail if S3 is unavailable; treat error as false
|
||||
assert!(!storage.exists(&location).await.unwrap_or(false));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +90,7 @@ impl Analytics {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::stored_object;
|
||||
use anyhow::{self};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TestUser, "user", {
|
||||
@@ -99,18 +100,14 @@ mod tests {
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analytics_initialization() {
|
||||
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Test initialization of analytics
|
||||
let analytics = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(analytics.id, "current");
|
||||
@@ -118,159 +115,134 @@ mod tests {
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let analytics_again = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to get analytics after initialization");
|
||||
let analytics_again = Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
assert_eq!(analytics.id, analytics_again.id);
|
||||
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
||||
assert_eq!(analytics.visitors, analytics_again.visitors);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_analytics() {
|
||||
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Test get_current method
|
||||
let analytics = Analytics::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current analytics");
|
||||
let analytics = Analytics::get_current(&db).await?;
|
||||
|
||||
assert_eq!(analytics.id, "current");
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors() {
|
||||
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Test increment_visitors method
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors");
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors again");
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 2);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads() {
|
||||
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Test increment_page_loads method
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads");
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads again");
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_users_amount() {
|
||||
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Test with no users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount");
|
||||
let count = Analytics::get_users_amount(&db).await?;
|
||||
assert_eq!(count, 0);
|
||||
|
||||
// Create a few test users
|
||||
for i in 0..3 {
|
||||
let user = TestUser {
|
||||
id: format!("user{}", i),
|
||||
email: format!("user{}@example.com", i),
|
||||
id: format!("user{i}"),
|
||||
email: format!("user{i}@example.com"),
|
||||
password: "password".to_string(),
|
||||
user_id: format!("uid{}", i),
|
||||
user_id: format!("uid{i}"),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
db.store_item(user)
|
||||
.await
|
||||
.expect("Failed to create test user");
|
||||
db.store_item(user).await?;
|
||||
}
|
||||
|
||||
// Test users amount after adding users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount after adding users");
|
||||
let count = Analytics::get_users_amount(&db).await?;
|
||||
assert_eq!(count, 3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Don't initialize analytics and try to get it
|
||||
let result = Analytics::get_current(&db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(err) = result {
|
||||
match err {
|
||||
AppError::NotFound(_) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected NotFound error, got: {:?}", err),
|
||||
}
|
||||
match result {
|
||||
Ok(_) => anyhow::bail!("Expected NotFound error, got success"),
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,76 +144,71 @@ impl Conversation {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use crate::storage::types::message::MessageRole;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create a new conversation
|
||||
let user_id = "test_user";
|
||||
let title = "Test Conversation";
|
||||
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
||||
|
||||
// Verify conversation properties
|
||||
assert_eq!(conversation.user_id, user_id);
|
||||
assert_eq!(conversation.title, title);
|
||||
assert!(!conversation.id.is_empty());
|
||||
|
||||
// Store the conversation
|
||||
let result = db.store_item(conversation.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<Conversation> = db
|
||||
.get_item(&conversation.id)
|
||||
.await
|
||||
.expect("Failed to retrieve conversation");
|
||||
assert!(retrieved.is_some());
|
||||
.with_context(|| "Failed to retrieve conversation".to_string())?;
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?;
|
||||
assert_eq!(retrieved.id, conversation.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to get a conversation that doesn't exist
|
||||
let result =
|
||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected NotFound error"),
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_unauthorized() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation =
|
||||
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
||||
@@ -221,27 +216,28 @@ mod tests {
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
// Try to access with a different user
|
||||
let user_id_2 = "user_2";
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_success() {
|
||||
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let user_id = "user_1";
|
||||
let original_title = "Original Title";
|
||||
@@ -250,49 +246,50 @@ mod tests {
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
let new_title = "Updated Title";
|
||||
|
||||
// Patch title successfully
|
||||
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Retrieve from DB to verify
|
||||
let updated_conversation = db
|
||||
.get_item::<Conversation>(&conversation_id)
|
||||
.await
|
||||
.expect("Failed to get conversation")
|
||||
.expect("Conversation missing");
|
||||
.with_context(|| "Failed to get conversation".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Conversation missing"))?;
|
||||
assert_eq!(updated_conversation.title, new_title);
|
||||
assert_eq!(updated_conversation.user_id, user_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found() {
|
||||
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to patch non-existing conversation
|
||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_unauthorized() {
|
||||
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user_id = "intruder";
|
||||
@@ -301,17 +298,18 @@ mod tests {
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
// Attempt patch with unauthorized user
|
||||
let result =
|
||||
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -405,24 +403,21 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
// Create messages
|
||||
let message1 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
@@ -442,46 +437,44 @@ mod tests {
|
||||
None,
|
||||
);
|
||||
|
||||
// Store messages
|
||||
db.store_item(message1)
|
||||
.await
|
||||
.expect("Failed to store message1");
|
||||
.with_context(|| "Failed to store message1".to_string())?;
|
||||
db.store_item(message2)
|
||||
.await
|
||||
.expect("Failed to store message2");
|
||||
.with_context(|| "Failed to store message2".to_string())?;
|
||||
db.store_item(message3)
|
||||
.await
|
||||
.expect("Failed to store message3");
|
||||
.with_context(|| "Failed to store message3".to_string())?;
|
||||
|
||||
// Retrieve the complete conversation
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
||||
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
||||
|
||||
let (retrieved_conversation, messages) = result.unwrap();
|
||||
let (retrieved_conversation, retrieved_messages) = result
|
||||
.with_context(|| "Failed to retrieve complete conversation".to_string())?;
|
||||
|
||||
// Verify conversation data
|
||||
assert_eq!(retrieved_conversation.id, conversation_id);
|
||||
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
||||
assert_eq!(retrieved_conversation.title, "Conversation");
|
||||
|
||||
// Verify messages
|
||||
assert_eq!(messages.len(), 3);
|
||||
assert_eq!(retrieved_messages.len(), 3);
|
||||
|
||||
// Verify messages are sorted by updated_at
|
||||
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
|
||||
let message_contents: Vec<&str> =
|
||||
retrieved_messages.iter().map(|m| m.content.as_str()).collect();
|
||||
assert!(message_contents.contains(&"Hello, AI!"));
|
||||
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
||||
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
||||
|
||||
// Make sure we can't access with different user
|
||||
let user_id_2 = "user_2";
|
||||
let unauthorized_result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(unauthorized_result.is_err());
|
||||
match unauthorized_result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,6 +320,8 @@ impl FileInfo {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::store::testing::TestStorageManager;
|
||||
use axum::http::HeaderMap;
|
||||
@@ -328,11 +330,11 @@ mod tests {
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
/// Creates a test temporary file with the given content
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
|
||||
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> anyhow::Result<FieldData<NamedTempFile>> {
|
||||
let mut temp_file = NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?;
|
||||
temp_file
|
||||
.write_all(content)
|
||||
.expect("Failed to write to temp file");
|
||||
.with_context(|| "Failed to write to temp file".to_string())?;
|
||||
|
||||
let metadata = FieldMetadata {
|
||||
name: Some("file".to_string()),
|
||||
@@ -341,31 +343,29 @@ mod tests {
|
||||
headers: HeaderMap::default(),
|
||||
};
|
||||
|
||||
let field_data = FieldData {
|
||||
Ok(FieldData {
|
||||
metadata,
|
||||
contents: temp_file,
|
||||
};
|
||||
|
||||
field_data
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() {
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager operations";
|
||||
let file_name = "storage_manager_test.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
|
||||
// Create test storage manager (memory backend)
|
||||
let test_storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test storage manager");
|
||||
.with_context(|| "Failed to create test storage manager".to_string())?;
|
||||
|
||||
// Create a FileInfo instance with storage manager
|
||||
let user_id = "test_user";
|
||||
@@ -374,20 +374,20 @@ mod tests {
|
||||
let file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file with StorageManager");
|
||||
.with_context(|| "Failed to create file with StorageManager".to_string())?;
|
||||
assert_eq!(file_info.file_name, file_name);
|
||||
|
||||
// Verify the file exists via StorageManager and has correct content
|
||||
let bytes = file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to read file content via StorageManager");
|
||||
.with_context(|| "Failed to read file content via StorageManager".to_string())?;
|
||||
assert_eq!(bytes.as_ref(), content);
|
||||
|
||||
// Test file reading
|
||||
let retrieved = FileInfo::get_by_id(&file_info.id, &db)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
.with_context(|| "Failed to retrieve file info".to_string())?;
|
||||
assert_eq!(retrieved.id, file_info.id);
|
||||
assert_eq!(retrieved.sha256, file_info.sha256);
|
||||
assert_eq!(retrieved.file_name, file_name);
|
||||
@@ -395,65 +395,65 @@ mod tests {
|
||||
// Test file deletion with StorageManager
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete file with StorageManager");
|
||||
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
|
||||
|
||||
let deleted_result = file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await;
|
||||
assert!(deleted_result.is_err(), "File should be deleted");
|
||||
|
||||
// No cleanup needed - TestStorageManager handles it automatically
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() {
|
||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"filename sanitization";
|
||||
let original_name = "Complex name (1).txt";
|
||||
let expected_sanitized = "Complex_name__1_.txt";
|
||||
let field_data = create_test_file(content, original_name);
|
||||
let field_data = create_test_file(content, original_name)?;
|
||||
|
||||
let test_storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test storage manager");
|
||||
.with_context(|| "Failed to create test storage manager".to_string())?;
|
||||
|
||||
let file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file via storage manager");
|
||||
.with_context(|| "Failed to create file via storage manager".to_string())?;
|
||||
|
||||
assert_eq!(file_info.file_name, original_name);
|
||||
|
||||
let stored_name = Path::new(&file_info.path)
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.expect("stored name");
|
||||
.with_context(|| "stored name".to_string())?;
|
||||
assert_eq!(stored_name, expected_sanitized);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() {
|
||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager duplicate detection";
|
||||
let file_name = "storage_manager_duplicate.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
|
||||
// Create test storage manager
|
||||
let test_storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test storage manager");
|
||||
.with_context(|| "Failed to create test storage manager".to_string())?;
|
||||
|
||||
// Create a FileInfo instance with storage manager
|
||||
let user_id = "test_user";
|
||||
@@ -462,17 +462,17 @@ mod tests {
|
||||
let original_file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create original file with StorageManager");
|
||||
.with_context(|| "Failed to create original file with StorageManager".to_string())?;
|
||||
|
||||
// Create another file with the same content but different name
|
||||
let duplicate_name = "storage_manager_duplicate_2.txt";
|
||||
let field_data2 = create_test_file(content, duplicate_name);
|
||||
let field_data2 = create_test_file(content, duplicate_name)?;
|
||||
|
||||
// The system should detect it's the same file and return the original FileInfo
|
||||
let duplicate_file_info =
|
||||
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to process duplicate file with StorageManager");
|
||||
.with_context(|| "Failed to process duplicate file with StorageManager".to_string())?;
|
||||
|
||||
// Verify duplicate detection worked
|
||||
assert_eq!(duplicate_file_info.id, original_file_info.id);
|
||||
@@ -484,46 +484,44 @@ mod tests {
|
||||
let original_content = original_file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get original content".to_string())?;
|
||||
let duplicate_content = duplicate_file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get duplicate content".to_string())?;
|
||||
assert_eq!(original_content.as_ref(), content);
|
||||
assert_eq!(duplicate_content.as_ref(), content);
|
||||
|
||||
// Clean up
|
||||
FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete original file with StorageManager");
|
||||
.with_context(|| "Failed to delete original file with StorageManager".to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_creation() {
|
||||
async fn test_file_creation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file content";
|
||||
let file_name = "test_file.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
|
||||
// Create a FileInfo instance with StorageManager
|
||||
let user_id = "test_user";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage manager");
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
let file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()).await;
|
||||
|
||||
// Verify the FileInfo was created successfully
|
||||
assert!(file_info.is_ok());
|
||||
let file_info = file_info.unwrap();
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
|
||||
.await?;
|
||||
|
||||
// Check essential properties
|
||||
assert!(!file_info.id.is_empty());
|
||||
@@ -533,32 +531,32 @@ mod tests {
|
||||
// path should be logical: "user_id/uuid/file_name"
|
||||
let parts: Vec<&str> = file_info.path.split('/').collect();
|
||||
assert_eq!(parts.len(), 3);
|
||||
assert_eq!(parts[0], user_id);
|
||||
assert_eq!(parts[2], file_name);
|
||||
assert_eq!(parts.first(), Some(&user_id));
|
||||
assert_eq!(parts.get(2), Some(&file_name));
|
||||
assert!(file_info.mime_type.contains("text/plain"));
|
||||
|
||||
// Verify it's in the database
|
||||
let stored: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
let stored = db
|
||||
.get_item::<FileInfo>(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(stored.is_some());
|
||||
let stored = stored.unwrap();
|
||||
.with_context(|| "Failed to retrieve file info".to_string())?
|
||||
.with_context(|| "expected stored file".to_string())?;
|
||||
assert_eq!(stored.id, file_info.id);
|
||||
assert_eq!(stored.file_name, file_info.file_name);
|
||||
assert_eq!(stored.sha256, file_info.sha256);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_duplicate_detection() {
|
||||
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// First, store a file with known content
|
||||
let content = b"This is a test file for duplicate detection";
|
||||
@@ -567,23 +565,23 @@ mod tests {
|
||||
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage manager");
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
|
||||
let field_data1 = create_test_file(content, file_name);
|
||||
let field_data1 = create_test_file(content, file_name)?;
|
||||
let original_file_info =
|
||||
FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create original file");
|
||||
.with_context(|| "Failed to create original file".to_string())?;
|
||||
|
||||
// Now try to store another file with the same content but different name
|
||||
let duplicate_name = "duplicate.txt";
|
||||
let field_data2 = create_test_file(content, duplicate_name);
|
||||
let field_data2 = create_test_file(content, duplicate_name)?;
|
||||
|
||||
// The system should detect it's the same file and return the original FileInfo
|
||||
let duplicate_file_info =
|
||||
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to process duplicate file");
|
||||
.with_context(|| "Failed to process duplicate file".to_string())?;
|
||||
|
||||
// The returned FileInfo should match the original
|
||||
assert_eq!(duplicate_file_info.id, original_file_info.id);
|
||||
@@ -592,10 +590,11 @@ mod tests {
|
||||
// But it should retain the original file name, not the duplicate's name
|
||||
assert_eq!(duplicate_file_info.file_name, file_name);
|
||||
assert_ne!(duplicate_file_info.file_name, duplicate_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_guess_mime_type() {
|
||||
async fn test_guess_mime_type() -> anyhow::Result<()> {
|
||||
// Test common file extensions
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("test.txt")),
|
||||
@@ -619,10 +618,11 @@ mod tests {
|
||||
FileInfo::guess_mime_type(Path::new("unknown.929yz")),
|
||||
"application/octet-stream".to_string()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sanitize_file_name() {
|
||||
async fn test_sanitize_file_name() -> anyhow::Result<()> {
|
||||
// Safe characters should remain unchanged
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("normal_file.txt"),
|
||||
@@ -647,26 +647,26 @@ mod tests {
|
||||
FileInfo::sanitize_file_name("../dangerous.txt"),
|
||||
"___dangerous.txt"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found() {
|
||||
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to find a file with a SHA that doesn't exist
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(FileError::FileNotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected FileNotFound error"),
|
||||
Err(FileError::FileNotFound(_)) => {}
|
||||
_ => anyhow::bail!("Expected FileNotFound error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -705,7 +705,7 @@ mod tests {
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
@@ -725,40 +725,39 @@ mod tests {
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
let retrieved = db
|
||||
.get_item::<FileInfo>(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
.with_context(|| "Failed to retrieve file info".to_string())?
|
||||
.with_context(|| "expected file".to_string())?;
|
||||
assert_eq!(retrieved.id, file_info.id);
|
||||
assert_eq!(retrieved.sha256, file_info.sha256);
|
||||
assert_eq!(retrieved.file_name, file_info.file_name);
|
||||
assert_eq!(retrieved.path, file_info.path);
|
||||
assert_eq!(retrieved.mime_type, file_info.mime_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id() {
|
||||
async fn test_delete_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create and persist a test file via FileInfo::new_with_storage
|
||||
let user_id = "user123";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage manager");
|
||||
let temp = create_test_file(b"test content", "test_file.txt");
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
let temp = create_test_file(b"test content", "test_file.txt")?;
|
||||
let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("create file");
|
||||
.with_context(|| "create file".to_string())?;
|
||||
|
||||
// Delete the file using StorageManager
|
||||
let delete_result =
|
||||
@@ -767,15 +766,14 @@ mod tests {
|
||||
// Delete should be successful
|
||||
assert!(
|
||||
delete_result.is_ok(),
|
||||
"Failed to delete file: {:?}",
|
||||
delete_result
|
||||
"Failed to delete file: {delete_result:?}"
|
||||
);
|
||||
|
||||
// Verify the file is removed from the database
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to query database");
|
||||
.with_context(|| "Failed to query database".to_string())?;
|
||||
assert!(
|
||||
retrieved.is_none(),
|
||||
"FileInfo should be deleted from the database"
|
||||
@@ -783,32 +781,37 @@ mod tests {
|
||||
|
||||
// Verify content no longer retrievable from storage
|
||||
assert!(test_storage.storage().get(&file_info.path).await.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id_not_found() {
|
||||
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to delete a file that doesn't exist
|
||||
let test_storage = TestStorageManager::new_memory().await.unwrap();
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
let result =
|
||||
FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage())
|
||||
.await;
|
||||
|
||||
// Should succeed even if the file record does not exist
|
||||
assert!(result.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id() {
|
||||
async fn test_get_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
@@ -827,28 +830,27 @@ mod tests {
|
||||
// Store it in the database
|
||||
db.store_item(original_file_info.clone())
|
||||
.await
|
||||
.expect("Failed to store item for get_by_id test");
|
||||
.with_context(|| "Failed to store item for get_by_id test".to_string())?;
|
||||
|
||||
// Retrieve it using get_by_id
|
||||
let result = FileInfo::get_by_id(&file_id, &db).await;
|
||||
|
||||
// Assert success and content match
|
||||
assert!(result.is_ok());
|
||||
let retrieved_info = result.unwrap();
|
||||
let retrieved_info = FileInfo::get_by_id(&file_id, &db)
|
||||
.await
|
||||
.with_context(|| "get_by_id".to_string())?;
|
||||
assert_eq!(retrieved_info.id, original_file_info.id);
|
||||
assert_eq!(retrieved_info.sha256, original_file_info.sha256);
|
||||
assert_eq!(retrieved_info.file_name, original_file_info.file_name);
|
||||
assert_eq!(retrieved_info.path, original_file_info.path);
|
||||
assert_eq!(retrieved_info.mime_type, original_file_info.mime_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id_not_found() {
|
||||
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to retrieve a non-existent ID
|
||||
let non_existent_id = "non-existent-file-id";
|
||||
@@ -862,33 +864,34 @@ mod tests {
|
||||
Err(FileError::FileNotFound(id)) => {
|
||||
assert_eq!(id, non_existent_id);
|
||||
}
|
||||
Err(e) => panic!("Expected FileNotFound error, but got {:?}", e),
|
||||
Ok(_) => panic!("Expected an error, but got Ok"),
|
||||
Err(e) => anyhow::bail!("Expected FileNotFound error, but got {e:?}"),
|
||||
Ok(_) => anyhow::bail!("Expected an error, but got Ok"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// StorageManager-based tests
|
||||
#[tokio::test]
|
||||
async fn test_file_info_new_with_storage_memory() {
|
||||
async fn test_file_info_new_with_storage_memory() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager";
|
||||
let field_data = create_test_file(content, "test_storage.txt");
|
||||
let field_data = create_test_file(content, "test_storage.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create test storage manager
|
||||
let storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
// Test file creation with StorageManager
|
||||
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file with StorageManager");
|
||||
.with_context(|| "Failed to create file with StorageManager".to_string())?;
|
||||
|
||||
// Verify the file was created correctly
|
||||
assert_eq!(file_info.user_id, user_id);
|
||||
@@ -900,40 +903,41 @@ mod tests {
|
||||
let retrieved_content = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.expect("Failed to get file content with StorageManager");
|
||||
.with_context(|| "Failed to get file content with StorageManager".to_string())?;
|
||||
assert_eq!(retrieved_content.as_ref(), content);
|
||||
|
||||
// Test file deletion with StorageManager
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete file with StorageManager");
|
||||
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
|
||||
|
||||
// Verify file is deleted
|
||||
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
|
||||
assert!(deleted_content_result.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_info_new_with_storage_local() {
|
||||
async fn test_file_info_new_with_storage_local() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_storage_local")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager with local storage";
|
||||
let field_data = create_test_file(content, "test_local.txt");
|
||||
let field_data = create_test_file(content, "test_local.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create test storage manager with local backend
|
||||
let storage = store::testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
// Test file creation with StorageManager
|
||||
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file with StorageManager");
|
||||
.with_context(|| "Failed to create file with StorageManager".to_string())?;
|
||||
|
||||
// Verify the file was created correctly
|
||||
assert_eq!(file_info.user_id, user_id);
|
||||
@@ -945,50 +949,51 @@ mod tests {
|
||||
let retrieved_content = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.expect("Failed to get file content with StorageManager");
|
||||
.with_context(|| "Failed to get file content with StorageManager".to_string())?;
|
||||
assert_eq!(retrieved_content.as_ref(), content);
|
||||
|
||||
// Test file deletion with StorageManager
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete file with StorageManager");
|
||||
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
|
||||
|
||||
// Verify file is deleted
|
||||
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
|
||||
assert!(deleted_content_result.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_info_storage_manager_persistence() {
|
||||
async fn test_file_info_storage_manager_persistence() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_persistence")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"Test content for persistence";
|
||||
let field_data = create_test_file(content, "persistence_test.txt");
|
||||
let field_data = create_test_file(content, "persistence_test.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create test storage manager
|
||||
let storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
// Create file
|
||||
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file");
|
||||
.with_context(|| "Failed to create file".to_string())?;
|
||||
|
||||
// Test that data persists across multiple operations with the same StorageManager
|
||||
let retrieved_content_1 = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get content 1".to_string())?;
|
||||
let retrieved_content_2 = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get content 2".to_string())?;
|
||||
|
||||
assert_eq!(retrieved_content_1.as_ref(), content);
|
||||
assert_eq!(retrieved_content_2.as_ref(), content);
|
||||
@@ -996,68 +1001,70 @@ mod tests {
|
||||
// Test that different StorageManager instances don't share data (memory storage isolation)
|
||||
let storage2 = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create second storage".to_string())?;
|
||||
let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await;
|
||||
assert!(
|
||||
isolated_content_result.is_err(),
|
||||
"Different StorageManager should not have access to same data"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_info_storage_manager_equivalence() {
|
||||
async fn test_file_info_storage_manager_equivalence() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_equivalence")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"Test content for equivalence testing";
|
||||
let field_data1 = create_test_file(content, "equivalence_test_1.txt");
|
||||
let field_data2 = create_test_file(content, "equivalence_test_2.txt");
|
||||
let field_data1 = create_test_file(content, "equivalence_test_1.txt")?;
|
||||
let field_data2 = create_test_file(content, "equivalence_test_2.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create single storage manager and reuse it
|
||||
let storage_manager = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create storage".to_string())?;
|
||||
let storage = storage_manager.storage();
|
||||
|
||||
// Create multiple files with the same storage manager
|
||||
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, &storage)
|
||||
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, storage)
|
||||
.await
|
||||
.expect("Failed to create file 1");
|
||||
.with_context(|| "Failed to create file 1".to_string())?;
|
||||
|
||||
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, &storage)
|
||||
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, storage)
|
||||
.await
|
||||
.expect("Failed to create file 2");
|
||||
.with_context(|| "Failed to create file 2".to_string())?;
|
||||
|
||||
// Test that both files can be retrieved with the same storage backend
|
||||
let content_1 = file_info_1
|
||||
.get_content_with_storage(&storage)
|
||||
.get_content_with_storage(storage)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get file 1 content".to_string())?;
|
||||
let content_2 = file_info_2
|
||||
.get_content_with_storage(&storage)
|
||||
.get_content_with_storage(storage)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get file 2 content".to_string())?;
|
||||
|
||||
assert_eq!(content_1.as_ref(), content);
|
||||
assert_eq!(content_2.as_ref(), content);
|
||||
|
||||
// Test that files can be deleted with the same storage manager
|
||||
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, &storage)
|
||||
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, storage)
|
||||
.await
|
||||
.unwrap();
|
||||
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, &storage)
|
||||
.with_context(|| "delete file 1".to_string())?;
|
||||
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, storage)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "delete file 2".to_string())?;
|
||||
|
||||
// Verify files are deleted
|
||||
let deleted_content_1 = file_info_1.get_content_with_storage(&storage).await;
|
||||
let deleted_content_2 = file_info_2.get_content_with_storage(&storage).await;
|
||||
let deleted_content_1 = file_info_1.get_content_with_storage(storage).await;
|
||||
let deleted_content_2 = file_info_2.get_content_with_storage(storage).await;
|
||||
|
||||
assert!(deleted_content_1.is_err());
|
||||
assert!(deleted_content_2.is_err());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,7 @@ impl IngestionPayload {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use chrono::Utc;
|
||||
|
||||
use super::*;
|
||||
@@ -131,7 +132,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url() {
|
||||
fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> {
|
||||
let url = "https://example.com";
|
||||
let context = "Process this URL";
|
||||
let category = "websites";
|
||||
@@ -145,10 +146,10 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
match result.first().context("expected one result")? {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url,
|
||||
context: payload_context,
|
||||
@@ -156,17 +157,18 @@ mod tests {
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||
assert_eq!(payload_context, &context);
|
||||
assert_eq!(payload_category, &category);
|
||||
assert_eq!(payload_user_id, &user_id);
|
||||
}
|
||||
_ => panic!("Expected Url variant"),
|
||||
_ => anyhow::bail!("Expected Url variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_text() {
|
||||
fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> {
|
||||
let text = "This is some text content";
|
||||
let context = "Process this text";
|
||||
let category = "notes";
|
||||
@@ -180,10 +182,10 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
match result.first().context("expected one result")? {
|
||||
IngestionPayload::Text {
|
||||
text: payload_text,
|
||||
context: payload_context,
|
||||
@@ -195,12 +197,13 @@ mod tests {
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected Text variant"),
|
||||
_ => anyhow::bail!("Expected Text variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_file() {
|
||||
fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> {
|
||||
let context = "Process this file";
|
||||
let category = "documents";
|
||||
let user_id = "user123";
|
||||
@@ -220,10 +223,10 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
match result.first().context("expected one result")? {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
context: payload_context,
|
||||
@@ -235,12 +238,13 @@ mod tests {
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
_ => anyhow::bail!("Expected File variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url_and_file() {
|
||||
fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> {
|
||||
let url = "https://example.com";
|
||||
let context = "Process this data";
|
||||
let category = "mixed";
|
||||
@@ -261,35 +265,36 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
// Check first item is URL
|
||||
match &result[0] {
|
||||
match result.first().context("expected first item")? {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url, ..
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||
}
|
||||
_ => panic!("Expected first item to be Url variant"),
|
||||
_ => anyhow::bail!("Expected first item to be Url variant"),
|
||||
}
|
||||
|
||||
// Check second item is File
|
||||
match &result[1] {
|
||||
match result.get(1).context("expected second item")? {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
}
|
||||
_ => panic!("Expected second item to be File variant"),
|
||||
_ => anyhow::bail!("Expected second item to be File variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_empty_input() {
|
||||
fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> {
|
||||
let context = "Process something";
|
||||
let category = "empty";
|
||||
let user_id = "user123";
|
||||
@@ -308,12 +313,13 @@ mod tests {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_empty_text() {
|
||||
fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> {
|
||||
let text = ""; // Empty text
|
||||
let context = "Process this";
|
||||
let category = "notes";
|
||||
@@ -333,7 +339,8 @@ mod tests {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -529,6 +529,8 @@ impl IngestionTask {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
|
||||
@@ -541,16 +543,16 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn memory_db() -> SurrealDbClient {
|
||||
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("in-memory surrealdb")
|
||||
.with_context(|| "in-memory surrealdb".to_string())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_task_defaults() {
|
||||
async fn test_new_task_defaults() -> anyhow::Result<()> {
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
@@ -562,73 +564,76 @@ mod tests {
|
||||
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
||||
assert!(task.locked_at.is_none());
|
||||
assert!(task.worker_id.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_store_task() {
|
||||
let db = memory_db().await;
|
||||
async fn test_create_and_store_task() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let created =
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("store");
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
let stored: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&created.id)
|
||||
.await
|
||||
.expect("fetch");
|
||||
.with_context(|| "fetch".to_string())?;
|
||||
|
||||
let stored = stored.expect("task exists");
|
||||
let stored = stored.with_context(|| "task exists".to_string())?;
|
||||
assert_eq!(stored.id, created.id);
|
||||
assert_eq!(stored.state, TaskState::Pending);
|
||||
assert_eq!(stored.attempts, 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_claim_and_transition() {
|
||||
let db = memory_db().await;
|
||||
async fn test_claim_and_transition() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let worker_id = "worker-1";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim");
|
||||
.with_context(|| "claim".to_string())?
|
||||
.with_context(|| "task claimed".to_string())?;
|
||||
|
||||
let claimed = claimed.expect("task claimed");
|
||||
assert_eq!(claimed.state, TaskState::Reserved);
|
||||
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
|
||||
assert_eq!(processing.state, TaskState::Processing);
|
||||
|
||||
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
|
||||
let succeeded = processing.mark_succeeded(&db).await.with_context(|| "succeeded".to_string())?;
|
||||
assert_eq!(succeeded.state, TaskState::Succeeded);
|
||||
assert!(succeeded.worker_id.is_none());
|
||||
assert!(succeeded.locked_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fail_and_dead_letter() {
|
||||
let db = memory_db().await;
|
||||
async fn test_fail_and_dead_letter() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let worker_id = "worker-dead";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim")
|
||||
.expect("claimed");
|
||||
.with_context(|| "claim".to_string())?
|
||||
.with_context(|| "claimed".to_string())?;
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
|
||||
|
||||
let error_info = TaskErrorInfo {
|
||||
code: Some("pipeline_error".into()),
|
||||
@@ -638,7 +643,7 @@ mod tests {
|
||||
let failed = processing
|
||||
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
||||
.await
|
||||
.expect("failed update");
|
||||
.with_context(|| "failed update".to_string())?;
|
||||
assert_eq!(failed.state, TaskState::Failed);
|
||||
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
||||
assert!(failed.worker_id.is_none());
|
||||
@@ -648,19 +653,20 @@ mod tests {
|
||||
let dead = failed
|
||||
.mark_dead_letter(error_info.clone(), &db)
|
||||
.await
|
||||
.expect("dead letter");
|
||||
.with_context(|| "dead letter".to_string())?;
|
||||
assert_eq!(dead.state, TaskState::DeadLetter);
|
||||
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mark_processing_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
async fn test_mark_processing_requires_reservation() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let err = task
|
||||
.mark_processing(&db)
|
||||
@@ -674,18 +680,19 @@ mod tests {
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mark_failed_requires_processing() {
|
||||
let db = memory_db().await;
|
||||
async fn test_mark_failed_requires_processing() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let err = task
|
||||
.mark_failed(
|
||||
@@ -706,18 +713,19 @@ mod tests {
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_release_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
async fn test_release_requires_reservation() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let err = task
|
||||
.release(&db)
|
||||
@@ -731,7 +739,8 @@ mod tests {
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
clippy::format_push_string,
|
||||
clippy::uninlined_format_args,
|
||||
clippy::explicit_iter_loop,
|
||||
clippy::items_after_statements,
|
||||
clippy::get_first,
|
||||
clippy::redundant_closure_for_method_calls
|
||||
)]
|
||||
@@ -317,6 +316,11 @@ impl KnowledgeEntity {
|
||||
}
|
||||
|
||||
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
|
||||
@@ -324,10 +328,6 @@ impl KnowledgeEntity {
|
||||
.bind(("id", id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
|
||||
rows.get(0)
|
||||
.map(|r| r.user_id.clone())
|
||||
@@ -497,7 +497,6 @@ impl KnowledgeEntity {
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
|
||||
info!("Removing old index and clearing embeddings...");
|
||||
@@ -572,14 +571,14 @@ impl KnowledgeEntity {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() {
|
||||
// Create basic test entity
|
||||
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
|
||||
let source_id = "source123".to_string();
|
||||
let name = "Test Entity".to_string();
|
||||
let description = "Test Description".to_string();
|
||||
@@ -596,7 +595,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert_eq!(entity.source_id, source_id);
|
||||
assert_eq!(entity.name, name);
|
||||
assert_eq!(entity.description, description);
|
||||
@@ -604,11 +602,12 @@ mod tests {
|
||||
assert_eq!(entity.metadata, metadata);
|
||||
assert_eq!(entity.user_id, user_id);
|
||||
assert!(!entity.id.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_type_from_string() {
|
||||
// Test conversion from String to KnowledgeEntityType
|
||||
async fn test_knowledge_entity_type_from_string() -> anyhow::Result<()> {
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("idea".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
@@ -639,15 +638,16 @@ mod tests {
|
||||
KnowledgeEntityType::TextSnippet
|
||||
);
|
||||
|
||||
// Test default case
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("unknown".to_string()),
|
||||
KnowledgeEntityType::Document
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_variants() {
|
||||
async fn test_knowledge_entity_variants() -> anyhow::Result<()> {
|
||||
let variants = KnowledgeEntityType::variants();
|
||||
assert_eq!(variants.len(), 5);
|
||||
assert!(variants.contains(&"Idea"));
|
||||
@@ -655,28 +655,28 @@ mod tests {
|
||||
assert!(variants.contains(&"Document"));
|
||||
assert!(variants.contains(&"Page"));
|
||||
assert!(variants.contains(&"TextSnippet"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create two entities with the same source_id
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
@@ -696,7 +696,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an entity with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_entity = KnowledgeEntity::new(
|
||||
different_source_id.clone(),
|
||||
@@ -708,23 +707,20 @@ mod tests {
|
||||
);
|
||||
|
||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
// Store the entities
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
.with_context(|| "Failed to store entity 1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
.with_context(|| "Failed to store entity 2".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store different entity");
|
||||
.with_context(|| "Failed to store different entity".to_string())?;
|
||||
|
||||
// Delete by source_id
|
||||
KnowledgeEntity::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete entities by source_id");
|
||||
.with_context(|| "Failed to delete entities by source_id".to_string())?;
|
||||
|
||||
// Verify all entities with the specified source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
@@ -734,16 +730,11 @@ mod tests {
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All entities with the source_id should be deleted"
|
||||
);
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert!(remaining.is_empty(), "All entities with the source_id should be deleted");
|
||||
|
||||
// Verify the entity with a different source_id still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
@@ -753,15 +744,20 @@ mod tests {
|
||||
.client
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Entity with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
assert_eq!(
|
||||
different_remaining.first().context("Expected entity to exist")?.id,
|
||||
different_entity.id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -833,35 +829,37 @@ mod tests {
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.expect("vector search");
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
assert!(results.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -876,9 +874,12 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity with embedding");
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
let stored_entity: Option<KnowledgeEntity> = db
|
||||
.get_item(&entity.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
@@ -888,42 +889,44 @@ mod tests {
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.with_context(|| "query embeddings".to_string())?
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
.with_context(|| "take embeddings".to_string())?;
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
.await
|
||||
.expect("fetch embedding");
|
||||
.with_context(|| "fetch embedding".to_string())?;
|
||||
assert!(fetched_emb.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
let res = results.first().context("Expected at least one result")?;
|
||||
assert_eq!(res.entity.id, entity.id);
|
||||
assert_eq!(res.entity.source_id, source_id);
|
||||
assert_eq!(res.entity.name, "hello");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let e1 = KnowledgeEntity::new(
|
||||
@@ -945,13 +948,19 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e1");
|
||||
.with_context(|| "store e1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e2");
|
||||
.with_context(|| "store e2".to_string())?;
|
||||
|
||||
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
|
||||
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
|
||||
let stored_e1: Option<KnowledgeEntity> = db
|
||||
.get_item(&e1.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
let stored_e2: Option<KnowledgeEntity> = db
|
||||
.get_item(&e2.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
assert!(stored_e1.is_some() && stored_e2.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
@@ -961,45 +970,53 @@ mod tests {
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.with_context(|| "query embeddings".to_string())?
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
.with_context(|| "take embeddings".to_string())?;
|
||||
assert_eq!(stored_embeddings.len(), 2);
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get embedding e1".to_string())?
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].entity.id, e2.id);
|
||||
assert_eq!(results[1].entity.id, e1.id);
|
||||
assert_eq!(
|
||||
results.first().context("Expected at least one result")?.entity.id,
|
||||
e2.id
|
||||
);
|
||||
assert_eq!(
|
||||
results.get(1).context("Expected at least two results")?.entity.id,
|
||||
e1.id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_with_orphaned_embedding() {
|
||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_orphan";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -1014,21 +1031,20 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity with embedding");
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
// Manually delete the entity to create an orphan
|
||||
let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id);
|
||||
db.client.query(query).await.expect("delete entity");
|
||||
db.client
|
||||
.query(query)
|
||||
.await
|
||||
.with_context(|| "delete entity".to_string())?;
|
||||
|
||||
// Now search
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.expect("search should succeed even with orphans");
|
||||
.with_context(|| "search should succeed even with orphans".to_string())?;
|
||||
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Should return empty result for orphan, got: {:?}",
|
||||
results
|
||||
);
|
||||
assert!(results.is_empty(), "Should return empty result for orphan, got: {:?}", results);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,11 +110,15 @@ impl KnowledgeEntityEmbedding {
|
||||
}
|
||||
|
||||
/// Delete embeddings by source_id (via joining to knowledge_entity table)
|
||||
#[allow(clippy::items_after_statements)]
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
|
||||
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
|
||||
let mut res = db
|
||||
.client
|
||||
@@ -122,11 +126,6 @@ impl KnowledgeEntityEmbedding {
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
for row in ids {
|
||||
@@ -138,6 +137,7 @@ impl KnowledgeEntityEmbedding {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
@@ -145,18 +145,18 @@ mod tests {
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn build_knowledge_entity_with_id(
|
||||
@@ -178,11 +178,11 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_entity_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-1";
|
||||
let source_id = "source-ke";
|
||||
@@ -192,26 +192,28 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by entity_id")
|
||||
.expect("Expected embedding to exist");
|
||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.entity_id, entity_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_entity_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-delete";
|
||||
let source_id = "source-del";
|
||||
@@ -220,61 +222,67 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some());
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by entity_id");
|
||||
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||
|
||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user_store";
|
||||
let source_id = "source_store";
|
||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
let stored_entity: Option<KnowledgeEntity> = db
|
||||
.get_item(&entity.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to fetch embedding");
|
||||
assert!(stored_embedding.is_some());
|
||||
let stored_embedding = stored_embedding.unwrap();
|
||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||
let stored_embedding = stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
assert_eq!(stored_embedding.user_id, user_id);
|
||||
assert_eq!(stored_embedding.entity_id, entity_rid);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let source_id = "shared-ke";
|
||||
let other_source = "other-ke";
|
||||
@@ -285,13 +293,13 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||
@@ -299,59 +307,74 @@ mod tests {
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get other embedding after delete".to_string())?
|
||||
.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
.with_context(|| "failed to redefine index".to_string())?;
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE knowledge_entity_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_knowledge_entity_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_knowledge_entity_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 16"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_entity_via_record_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-fetch";
|
||||
let source_id = "source-fetch";
|
||||
@@ -359,15 +382,10 @@ mod tests {
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let mut res = db
|
||||
.client
|
||||
.query(
|
||||
@@ -375,13 +393,17 @@ mod tests {
|
||||
)
|
||||
.bind(("id", entity_rid.clone()))
|
||||
.await
|
||||
.expect("failed to fetch embedding with FETCH");
|
||||
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
|
||||
.with_context(|| "failed to fetch embedding with FETCH".to_string())?;
|
||||
let rows: Vec<Row> = res
|
||||
.take(0)
|
||||
.with_context(|| "failed to deserialize fetch rows".to_string())?;
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
let fetched_entity = &rows[0].entity_id;
|
||||
let fetched_entity = &rows.first().context("Expected at least one result")?.entity_id;
|
||||
assert_eq!(fetched_entity.id, entity_key);
|
||||
assert_eq!(fetched_entity.name, "Test entity");
|
||||
assert_eq!(fetched_entity.user_id, user_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,6 +124,7 @@ impl KnowledgeRelationship {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
|
||||
@@ -155,10 +156,9 @@ mod tests {
|
||||
result.take(0).expect("failed to take relationship by id")
|
||||
}
|
||||
|
||||
// Helper function to create a test knowledge entity for the relationship tests
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {}", name);
|
||||
let description = format!("Description for {name}");
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
@@ -174,12 +174,14 @@ mod tests {
|
||||
let stored: Option<KnowledgeEntity> = db_client
|
||||
.store_item(entity)
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
stored.unwrap().id
|
||||
.with_context(|| "Failed to store entity".to_string())?;
|
||||
stored
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some"))
|
||||
.map(|e| e.id)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relationship_creation() {
|
||||
async fn test_relationship_creation() -> anyhow::Result<()> {
|
||||
let in_id = "entity1".to_string();
|
||||
let out_id = "entity2".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
@@ -194,25 +196,23 @@ mod tests {
|
||||
relationship_type.clone(),
|
||||
);
|
||||
|
||||
// Verify fields are correctly set
|
||||
assert_eq!(relationship.in_, in_id);
|
||||
assert_eq!(relationship.out, out_id);
|
||||
assert_eq!(relationship.metadata.user_id, user_id);
|
||||
assert_eq!(relationship.metadata.source_id, source_id);
|
||||
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
||||
assert!(!relationship.id.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_verify_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
@@ -225,11 +225,10 @@ mod tests {
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let persisted = get_relationship_by_id(&relationship.id, &db)
|
||||
.await
|
||||
@@ -239,8 +238,6 @@ mod tests {
|
||||
assert_eq!(persisted.metadata.user_id, user_id);
|
||||
assert_eq!(persisted.metadata.source_id, source_id);
|
||||
|
||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
||||
// This approach is more reliable than trying to look up by ID
|
||||
let mut check_result = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
@@ -253,14 +250,16 @@ mod tests {
|
||||
1,
|
||||
"Expected one relationship for source_id"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() {
|
||||
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id,
|
||||
@@ -288,18 +287,17 @@ mod tests {
|
||||
rows[0].metadata.source_id,
|
||||
"source123'; DELETE FROM relates_to; --"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_delete_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
@@ -312,52 +310,44 @@ mod tests {
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
// Ensure relationship exists before deletion attempt
|
||||
let mut existing_before_delete = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let before_results: Vec<KnowledgeRelationship> =
|
||||
existing_before_delete.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before deletion"
|
||||
);
|
||||
assert!(!before_results.is_empty(), "Relationship should exist before deletion");
|
||||
|
||||
// Delete relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationship by ID");
|
||||
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||
|
||||
// Query to verify relationship was deleted
|
||||
let mut result = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify relationship no longer exists
|
||||
assert!(results.is_empty(), "Relationship should be deleted");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id_unauthorized() {
|
||||
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
let owner_user_id = "owner-user".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
@@ -373,20 +363,16 @@ mod tests {
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let mut before_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before unauthorized delete attempt"
|
||||
);
|
||||
assert!(!before_results.is_empty(), "Relationship should exist before unauthorized delete attempt");
|
||||
|
||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||
&relationship.id,
|
||||
@@ -397,40 +383,34 @@ mod tests {
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected authorization error when deleting someone else's relationship"),
|
||||
_ => anyhow::bail!("Expected authorization error when deleting someone else's relationship"),
|
||||
}
|
||||
|
||||
let mut after_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
!results.is_empty(),
|
||||
"Relationship should still exist after unauthorized delete attempt"
|
||||
);
|
||||
assert!(!results.is_empty(), "Relationship should still exist after unauthorized delete attempt");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_exists() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await?;
|
||||
|
||||
// Create relationships with the same source_id
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let different_source_id = "different_source".to_string();
|
||||
|
||||
// Create two relationships with the same source_id
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
@@ -447,7 +427,6 @@ mod tests {
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Create a relationship with a different source_id
|
||||
let different_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity3_id.clone(),
|
||||
@@ -456,21 +435,19 @@ mod tests {
|
||||
"mentions".to_string(),
|
||||
);
|
||||
|
||||
// Store all relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||
different_relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store different relationship");
|
||||
.with_context(|| "Failed to store different relationship".to_string())?;
|
||||
|
||||
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
|
||||
let mut before_delete = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
@@ -489,31 +466,30 @@ mod tests {
|
||||
before_delete_different.take(0).unwrap_or_default();
|
||||
assert_eq!(before_delete_different_rows.len(), 1);
|
||||
|
||||
// Delete relationships by source_id
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationships by source_id");
|
||||
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||
|
||||
// Query to verify the specific relationships with source_id were deleted.
|
||||
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
|
||||
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
|
||||
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
|
||||
|
||||
// Verify relationships with the source_id are deleted
|
||||
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||
let remaining =
|
||||
different_result.expect("Relationship with different source_id should remain");
|
||||
assert_eq!(remaining.metadata.source_id, different_source_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() {
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await?;
|
||||
|
||||
let safe_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
@@ -552,5 +528,7 @@ mod tests {
|
||||
remaining_other.is_some(),
|
||||
"Other relationship should remain"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,12 +66,12 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_creation() {
|
||||
// Test basic message creation
|
||||
async fn test_message_creation() -> anyhow::Result<()> {
|
||||
let conversation_id = "test_conversation";
|
||||
let content = "This is a test message";
|
||||
let role = MessageRole::User;
|
||||
@@ -84,24 +84,23 @@ mod tests {
|
||||
references.clone(),
|
||||
);
|
||||
|
||||
// Verify message properties
|
||||
assert_eq!(message.conversation_id, conversation_id);
|
||||
assert_eq!(message.content, content);
|
||||
assert_eq!(message.role, role);
|
||||
assert_eq!(message.references, references);
|
||||
assert!(!message.id.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_persistence() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create and store a message
|
||||
let conversation_id = "test_conversation";
|
||||
let message = Message::new(
|
||||
conversation_id.to_string(),
|
||||
@@ -111,39 +110,37 @@ mod tests {
|
||||
);
|
||||
let message_id = message.id.clone();
|
||||
|
||||
// Store the message
|
||||
db.store_item(message.clone())
|
||||
.await
|
||||
.expect("Failed to store message");
|
||||
.with_context(|| "Failed to store message".to_string())?;
|
||||
|
||||
// Retrieve the message
|
||||
let retrieved: Option<Message> = db
|
||||
.get_item(&message_id)
|
||||
.await
|
||||
.expect("Failed to retrieve message");
|
||||
.with_context(|| "Failed to retrieve message".to_string())?;
|
||||
|
||||
assert!(retrieved.is_some());
|
||||
let retrieved = retrieved.unwrap();
|
||||
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?;
|
||||
|
||||
// Verify retrieved properties match original
|
||||
assert_eq!(retrieved.id, message.id);
|
||||
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
||||
assert_eq!(retrieved.role, message.role);
|
||||
assert_eq!(retrieved.content, message.content);
|
||||
assert_eq!(retrieved.references, message.references);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_role_display() {
|
||||
// Test the Display implementation for MessageRole
|
||||
async fn test_message_role_display() -> anyhow::Result<()> {
|
||||
assert_eq!(format!("{}", MessageRole::User), "User");
|
||||
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
||||
assert_eq!(format!("{}", MessageRole::System), "System");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_display() {
|
||||
// Test the Display implementation for Message
|
||||
async fn test_message_display() -> anyhow::Result<()> {
|
||||
let message = Message {
|
||||
id: "test_id".to_string(),
|
||||
created_at: Utc::now(),
|
||||
@@ -154,12 +151,13 @@ mod tests {
|
||||
references: None,
|
||||
};
|
||||
|
||||
assert_eq!(format!("{}", message), "User: Hello world");
|
||||
assert_eq!(format!("{message}"), "User: Hello world");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_format_history() {
|
||||
// Create a vector of messages
|
||||
async fn test_format_history() -> anyhow::Result<()> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
id: "1".to_string(),
|
||||
@@ -181,10 +179,10 @@ mod tests {
|
||||
},
|
||||
];
|
||||
|
||||
// Format the history
|
||||
let formatted = format_history(&messages);
|
||||
|
||||
// Verify the formatting
|
||||
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,20 +216,22 @@ impl Scratchpad {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_scratchpad() {
|
||||
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create a new scratchpad
|
||||
let user_id = "test_user";
|
||||
@@ -254,29 +256,28 @@ mod tests {
|
||||
let retrieved: Option<Scratchpad> = db
|
||||
.get_item(&scratchpad.id)
|
||||
.await
|
||||
.expect("Failed to retrieve scratchpad");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
.with_context(|| "Failed to retrieve scratchpad".to_string())?;
|
||||
let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?;
|
||||
assert_eq!(retrieved.id, scratchpad.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
assert!(!retrieved.is_archived);
|
||||
assert!(retrieved.archived_at.is_none());
|
||||
assert!(retrieved.ingested_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_user() {
|
||||
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
|
||||
@@ -288,19 +289,21 @@ mod tests {
|
||||
// Store them
|
||||
let scratchpad1_id = scratchpad1.id.clone();
|
||||
let scratchpad2_id = scratchpad2.id.clone();
|
||||
db.store_item(scratchpad1).await.unwrap();
|
||||
db.store_item(scratchpad2).await.unwrap();
|
||||
db.store_item(scratchpad3).await.unwrap();
|
||||
db.store_item(scratchpad1).await.with_context(|| "store scratchpad1".to_string())?;
|
||||
db.store_item(scratchpad2).await.with_context(|| "store scratchpad2".to_string())?;
|
||||
db.store_item(scratchpad3).await.with_context(|| "store scratchpad3".to_string())?;
|
||||
|
||||
// Archive one of the user's scratchpads
|
||||
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "archive".to_string())?;
|
||||
|
||||
// Get scratchpads for user_id
|
||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap();
|
||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db)
|
||||
.await
|
||||
.with_context(|| "get_by_user".to_string())?;
|
||||
assert_eq!(user_scratchpads.len(), 1);
|
||||
assert_eq!(user_scratchpads[0].id, scratchpad1_id);
|
||||
assert_eq!(user_scratchpads.first().map(|s| &s.id), Some(&scratchpad1_id));
|
||||
|
||||
// Verify they belong to the user
|
||||
for scratchpad in &user_scratchpads {
|
||||
@@ -309,177 +312,183 @@ mod tests {
|
||||
|
||||
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get_archived_by_user".to_string())?;
|
||||
assert_eq!(archived.len(), 1);
|
||||
assert_eq!(archived[0].id, scratchpad2_id);
|
||||
assert!(archived[0].is_archived);
|
||||
assert!(archived[0].ingested_at.is_none());
|
||||
assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id));
|
||||
assert!(archived.first().is_some_and(|s| s.is_archived));
|
||||
assert!(archived.first().is_some_and(|s| s.ingested_at.is_none()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_archive_and_restore() {
|
||||
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
||||
.await
|
||||
.expect("Failed to archive");
|
||||
.with_context(|| "Failed to archive".to_string())?;
|
||||
assert!(archived.is_archived);
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.ingested_at.is_some());
|
||||
|
||||
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.expect("Failed to restore");
|
||||
.with_context(|| "Failed to restore".to_string())?;
|
||||
assert!(!restored.is_archived);
|
||||
assert!(restored.archived_at.is_none());
|
||||
assert!(restored.ingested_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content() {
|
||||
async fn test_update_content() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let new_content = "Updated content";
|
||||
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "update_content".to_string())?;
|
||||
|
||||
assert_eq!(updated.content, new_content);
|
||||
assert!(!updated.is_dirty);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content_unauthorized() {
|
||||
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_scratchpad() {
|
||||
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
// Delete should succeed
|
||||
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it's gone
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
|
||||
assert!(retrieved.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unauthorized() {
|
||||
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
// Verify it still exists
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
|
||||
assert!(retrieved.is_some());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_aware_scratchpad_conversion() {
|
||||
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to create test database");
|
||||
.with_context(|| "Failed to create test database".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user_123";
|
||||
let scratchpad =
|
||||
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get_by_id".to_string())?;
|
||||
|
||||
// Test that datetime fields are preserved and can be used for timezone formatting
|
||||
assert!(retrieved.created_at.timestamp() > 0);
|
||||
@@ -493,10 +502,11 @@ mod tests {
|
||||
// Archive the scratchpad to test optional datetime handling
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "archive".to_string())?;
|
||||
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.archived_at.unwrap().timestamp() > 0);
|
||||
assert!(archived.archived_at.with_context(|| "expected archived_at".to_string())?.timestamp() > 0);
|
||||
assert!(archived.ingested_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +64,14 @@ impl SystemSettings {
|
||||
let mut needs_update = false;
|
||||
|
||||
let backend_label = provider.backend_label().to_string();
|
||||
let provider_dimensions = provider.dimension() as u32;
|
||||
let provider_dimensions = u32::try_from(provider.dimension())
|
||||
.unwrap_or_else(|_| {
|
||||
tracing::warn!(
|
||||
"Provider dimension {} exceeds u32 max; falling back to 0",
|
||||
provider.dimension()
|
||||
);
|
||||
0u32
|
||||
});
|
||||
let provider_model = provider.model_code();
|
||||
|
||||
// Sync backend label
|
||||
@@ -107,7 +114,8 @@ impl SystemSettings {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::indexes::ensure_runtime_indexes;
|
||||
use anyhow::{self, Context};
|
||||
use crate::storage::indexes::ensure_runtime;
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||
use async_openai::Client;
|
||||
|
||||
@@ -118,68 +126,102 @@ mod tests {
|
||||
db: &SurrealDbClient,
|
||||
table_name: &str,
|
||||
index_name: &str,
|
||||
) -> u32 {
|
||||
) -> anyhow::Result<u32> {
|
||||
let query = format!("INFO FOR TABLE {table_name};");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Failed to fetch table info");
|
||||
.with_context(|| "Failed to fetch table info".to_string())?;
|
||||
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
.expect("Failed to extract table info response");
|
||||
.with_context(|| "Failed to extract table info response".to_string())?;
|
||||
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("Failed to convert info to json");
|
||||
serde_json::to_value(info).with_context(|| "Failed to convert info to json".to_string())?;
|
||||
|
||||
let indexes = info_json["Object"]["indexes"]["Object"]
|
||||
.as_object()
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
|
||||
let indexes = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.as_object())
|
||||
.with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?;
|
||||
|
||||
let definition = indexes
|
||||
.get(index_name)
|
||||
.and_then(|definition| definition.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
|
||||
.with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?;
|
||||
|
||||
let dimension_part = definition
|
||||
.split("DIMENSION")
|
||||
.nth(1)
|
||||
.expect("Index definition missing DIMENSION clause");
|
||||
.with_context(|| "Index definition missing DIMENSION clause".to_string())?;
|
||||
|
||||
let dimension_token = dimension_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.expect("Dimension value missing in definition")
|
||||
.with_context(|| "Dimension value missing in definition".to_string())?
|
||||
.trim_end_matches(';');
|
||||
|
||||
dimension_token
|
||||
.parse::<u32>()
|
||||
.expect("Dimension value is not a valid number")
|
||||
.with_context(|| "Dimension value is not a valid number".to_string())
|
||||
}
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) -> anyhow::Result<()> {
|
||||
db.query(
|
||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||
)
|
||||
.await
|
||||
.with_context(|| "remove index".to_string())?;
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};"
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.with_context(|| "Re-defining index should succeed".to_string())?;
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||
|
||||
db.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await
|
||||
.with_context(|| "upsert embedding".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Test initialization of system settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get system settings");
|
||||
.with_context(|| "Failed to get system settings".to_string())?;
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert!(settings.registrations_enabled);
|
||||
assert!(!settings.require_email_verification);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
||||
@@ -196,10 +238,10 @@ mod tests {
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let settings_again = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get settings after initialization");
|
||||
.with_context(|| "Failed to get settings after initialization".to_string())?;
|
||||
|
||||
assert_eq!(settings.id, settings_again.id);
|
||||
assert_eq!(
|
||||
@@ -210,48 +252,52 @@ mod tests {
|
||||
settings.require_email_verification,
|
||||
settings_again.require_email_verification
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_settings() {
|
||||
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings");
|
||||
.with_context(|| "Failed to get current settings".to_string())?;
|
||||
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert!(settings.registrations_enabled);
|
||||
assert!(!settings.require_email_verification);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_settings() {
|
||||
async fn test_update_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
|
||||
let mut updated_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.with_context(|| "get_current".to_string())?;
|
||||
updated_settings.id = "current".to_string();
|
||||
updated_settings.registrations_enabled = false;
|
||||
updated_settings.require_email_verification = true;
|
||||
@@ -260,31 +306,32 @@ mod tests {
|
||||
// Test update method
|
||||
let result = SystemSettings::update(&db, updated_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
.with_context(|| "Failed to update settings".to_string())?;
|
||||
|
||||
assert_eq!(result.id, "current");
|
||||
assert_eq!(result.registrations_enabled, false);
|
||||
assert_eq!(result.require_email_verification, true);
|
||||
assert!(!result.registrations_enabled);
|
||||
assert!(result.require_email_verification);
|
||||
assert_eq!(result.query_model, "gpt-4");
|
||||
|
||||
// Verify changes persisted by getting current settings
|
||||
let current = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings after update");
|
||||
.with_context(|| "Failed to get current settings after update".to_string())?;
|
||||
|
||||
assert_eq!(current.registrations_enabled, false);
|
||||
assert_eq!(current.require_email_verification, true);
|
||||
assert!(!current.registrations_enabled);
|
||||
assert!(current.require_email_verification);
|
||||
assert_eq!(current.query_model, "gpt-4");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Don't initialize settings and try to get them
|
||||
let result = SystemSettings::get_current(&db).await;
|
||||
@@ -294,21 +341,22 @@ mod tests {
|
||||
Err(AppError::NotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
|
||||
Ok(_) => panic!("Expected error but got Ok"),
|
||||
Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"),
|
||||
Ok(_) => anyhow::bail!("Expected error but got Ok"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_after_changing_embedding_length() {
|
||||
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
@@ -318,43 +366,11 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||
.await
|
||||
.expect("Failed to store initial chunk with embedding");
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) {
|
||||
db.query(
|
||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
target_dimension
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.expect("Re-defining index should succeed");
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||
|
||||
let update_result = db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await;
|
||||
|
||||
assert!(update_result.is_ok());
|
||||
}
|
||||
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||
|
||||
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
||||
let target_dimension = 1536usize;
|
||||
simulate_reembedding(&db, target_dimension, initial_chunk).await;
|
||||
simulate_reembedding(&db, target_dimension, initial_chunk).await?;
|
||||
|
||||
let migration_result = db.apply_migrations().await;
|
||||
|
||||
@@ -363,34 +379,35 @@ mod tests {
|
||||
"Migrations should not fail: {:?}",
|
||||
migration_result.err()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to load current settings");
|
||||
.with_context(|| "Failed to load current settings".to_string())?;
|
||||
|
||||
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
|
||||
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize)
|
||||
ensure_runtime(&db, current_settings.embedding_dimensions as usize)
|
||||
.await
|
||||
.expect("failed to build runtime indexes");
|
||||
.with_context(|| "failed to build runtime indexes".to_string())?;
|
||||
|
||||
let initial_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||
@@ -405,7 +422,7 @@ mod tests {
|
||||
|
||||
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
.with_context(|| "Failed to update settings".to_string())?;
|
||||
|
||||
assert_eq!(
|
||||
updated_settings.embedding_dimensions, new_dimension,
|
||||
@@ -416,23 +433,23 @@ mod tests {
|
||||
|
||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("TextChunk re-embedding should succeed on fresh DB");
|
||||
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
|
||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
||||
.with_context(|| "KnowledgeEntity re-embedding should succeed on fresh DB".to_string())?;
|
||||
|
||||
let text_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
let knowledge_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"knowledge_entity_embedding",
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
text_chunk_dimension, new_dimension,
|
||||
@@ -445,10 +462,11 @@ mod tests {
|
||||
|
||||
let persisted_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to reload updated settings");
|
||||
.with_context(|| "Failed to reload updated settings".to_string())?;
|
||||
assert_eq!(
|
||||
persisted_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should persist new embedding dimension"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
@@ -237,10 +237,7 @@ impl TextChunk {
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all text chunks. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
info!("Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}");
|
||||
|
||||
// Fetch all chunks first
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
@@ -252,7 +249,7 @@ impl TextChunk {
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
info!("Found {total_chunks} chunks to process.");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
@@ -276,7 +273,7 @@ impl TextChunk {
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
@@ -300,6 +297,7 @@ impl TextChunk {
|
||||
.join(",")
|
||||
);
|
||||
// Use the chunk id as the embedding record id to keep a 1:1 mapping
|
||||
let embedding = embedding_str;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
|
||||
@@ -309,18 +307,13 @@ impl TextChunk {
|
||||
user_id = '{user_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
@@ -377,7 +370,7 @@ impl TextChunk {
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
@@ -422,6 +415,7 @@ impl TextChunk {
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
let embedding = embedding_str;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
|
||||
@@ -431,18 +425,13 @@ impl TextChunk {
|
||||
user_id = '{user_id}', \
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
@@ -462,20 +451,21 @@ impl TextChunk {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
|
||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
@@ -484,12 +474,13 @@ mod tests {
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}"));
|
||||
.with_context(|| format!("define chunk fts index fallback: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() {
|
||||
async fn test_text_chunk_creation() -> anyhow::Result<()> {
|
||||
let source_id = "source123".to_string();
|
||||
let chunk = "This is a text chunk for testing embeddings".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
@@ -500,22 +491,23 @@ mod tests {
|
||||
assert_eq!(text_chunk.chunk, chunk);
|
||||
assert_eq!(text_chunk.user_id, user_id);
|
||||
assert!(!text_chunk.id.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let source_id = "source123".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
@@ -535,61 +527,63 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
TextChunk::store_with_embedding(
|
||||
different_chunk.clone(),
|
||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("store different chunk");
|
||||
.with_context(|| "store different chunk".to_string())?;
|
||||
|
||||
TextChunk::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete chunks by source_id");
|
||||
.with_context(|| "Failed to delete chunks by source_id".to_string())?;
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
"SELECT * FROM {} WHERE source_id = '{source_id}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(remaining.len(), 0);
|
||||
|
||||
let different_remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
"SELECT * FROM {} WHERE source_id = 'different_source'",
|
||||
TextChunk::table_name(),
|
||||
"different_source"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(different_remaining.len(), 1);
|
||||
assert_eq!(different_remaining[0].id, different_chunk.id);
|
||||
assert_eq!(
|
||||
different_remaining.first().map(|r| &r.id),
|
||||
Some(&different_chunk.id)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() {
|
||||
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
@@ -600,24 +594,24 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk");
|
||||
.with_context(|| "store chunk".to_string())?;
|
||||
|
||||
TextChunk::delete_by_source_id("nonexistent_source", &db)
|
||||
.await
|
||||
.expect("Delete should succeed");
|
||||
.with_context(|| "Delete should succeed".to_string())?;
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
"SELECT * FROM {} WHERE source_id = '{real_source_id}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(remaining.len(), 1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -672,13 +666,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() {
|
||||
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let source_id = "store-src".to_string();
|
||||
let user_id = "user_store".to_string();
|
||||
@@ -686,43 +680,43 @@ mod tests {
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some());
|
||||
let stored_chunk = stored_chunk.unwrap();
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
|
||||
.await
|
||||
.with_context(|| "get_item".to_string())?;
|
||||
let stored_chunk = stored_chunk.with_context(|| "expected stored chunk".to_string())?;
|
||||
assert_eq!(stored_chunk.source_id, source_id);
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some());
|
||||
let embedding = embedding.unwrap();
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "expected embedding".to_string())?;
|
||||
assert_eq!(embedding.chunk_id, rid);
|
||||
assert_eq!(embedding.user_id, user_id);
|
||||
assert_eq!(embedding.source_id, source_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_with_runtime_indexes() {
|
||||
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_runtime";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
// Ensure runtime indexes are built with the expected dimension.
|
||||
let embedding_dimension = 3usize;
|
||||
ensure_runtime_indexes(&db, embedding_dimension)
|
||||
ensure_runtime(&db, embedding_dimension)
|
||||
.await
|
||||
.expect("ensure runtime indexes");
|
||||
.with_context(|| "ensure runtime indexes".to_string())?;
|
||||
|
||||
let chunk = TextChunk::new(
|
||||
"runtime_src".to_string(),
|
||||
@@ -732,55 +726,60 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some(), "chunk should be stored");
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
|
||||
.await
|
||||
.with_context(|| "get_item".to_string())?;
|
||||
let stored_chunk = stored_chunk.with_context(|| "chunk should be stored".to_string())?;
|
||||
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some(), "embedding should exist");
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "embedding should exist".to_string())?;
|
||||
assert_eq!(
|
||||
embedding.unwrap().embedding.len(),
|
||||
embedding.embedding.len(),
|
||||
embedding_dimension,
|
||||
"embedding dimension should match runtime index"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
assert!(results.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let source_id = "src".to_string();
|
||||
let user_id = "user".to_string();
|
||||
@@ -792,32 +791,33 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store");
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
let res = results.first().context("expected first result")?;
|
||||
assert_eq!(res.chunk.id, chunk.id);
|
||||
assert_eq!(res.chunk.source_id, source_id);
|
||||
assert_eq!(res.chunk.chunk, "hello world");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||
@@ -825,49 +825,59 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].chunk.id, chunk2.id);
|
||||
assert_eq!(results[1].chunk.id, chunk1.id);
|
||||
assert!(results[0].score >= results[1].score);
|
||||
assert_eq!(
|
||||
results.first().map(|r| &r.chunk.id),
|
||||
Some(&chunk2.id)
|
||||
);
|
||||
assert_eq!(
|
||||
results.get(1).map(|r| &r.chunk.id),
|
||||
Some(&chunk1.id)
|
||||
);
|
||||
let r0 = results.first().context("expected first result")?;
|
||||
let r1 = results.get(1).context("expected second result")?;
|
||||
assert!(r0.score >= r1.score);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() {
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextChunk::fts_search(5, "hello", &db, "user")
|
||||
.await
|
||||
.expect("fts search");
|
||||
.with_context(|| "fts search".to_string())?;
|
||||
|
||||
assert!(results.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() {
|
||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let chunk = TextChunk::new(
|
||||
@@ -875,27 +885,29 @@ mod tests {
|
||||
"rustaceans love rust".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(chunk.clone()).await.expect("store chunk");
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
db.store_item(chunk.clone()).await.with_context(|| "store chunk".to_string())?;
|
||||
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextChunk::fts_search(3, "rust", &db, user_id)
|
||||
.await
|
||||
.expect("fts search");
|
||||
.with_context(|| "fts search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].chunk.id, chunk.id);
|
||||
assert!(results[0].score.is_finite(), "expected a finite FTS score");
|
||||
let r0 = results.first().context("expected first result")?;
|
||||
assert_eq!(r0.chunk.id, chunk.id);
|
||||
assert!(r0.score.is_finite(), "expected a finite FTS score");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() {
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_chunk = TextChunk::new(
|
||||
@@ -916,18 +928,18 @@ mod tests {
|
||||
|
||||
db.store_item(high_score_chunk.clone())
|
||||
.await
|
||||
.expect("store high score chunk");
|
||||
.with_context(|| "store high score chunk".to_string())?;
|
||||
db.store_item(low_score_chunk.clone())
|
||||
.await
|
||||
.expect("store low score chunk");
|
||||
.with_context(|| "store low score chunk".to_string())?;
|
||||
db.store_item(other_user_chunk)
|
||||
.await
|
||||
.expect("store other user chunk");
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
.with_context(|| "store other user chunk".to_string())?;
|
||||
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextChunk::fts_search(3, "apple", &db, user_id)
|
||||
.await
|
||||
.expect("fts search");
|
||||
.with_context(|| "fts search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
|
||||
@@ -936,9 +948,12 @@ mod tests {
|
||||
&& ids.contains(&low_score_chunk.id.as_str()),
|
||||
"expected only the two chunks for the same user"
|
||||
);
|
||||
let r0 = results.first().context("expected first result")?;
|
||||
let r1 = results.get(1).context("expected second result")?;
|
||||
assert!(
|
||||
results[0].score >= results[1].score,
|
||||
r0.score >= r1.score,
|
||||
"expected results ordered by descending score"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,24 +126,26 @@ impl TextChunkEmbedding {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Helper to create an in-memory DB and apply migrations
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Helper: create a text_chunk with a known key, return its RecordId
|
||||
@@ -152,7 +154,7 @@ mod tests {
|
||||
key: &str,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
) -> RecordId {
|
||||
) -> anyhow::Result<RecordId> {
|
||||
let chunk = TextChunk {
|
||||
id: key.to_owned(),
|
||||
created_at: Utc::now(),
|
||||
@@ -164,21 +166,42 @@ mod tests {
|
||||
|
||||
db.store_item(chunk)
|
||||
.await
|
||||
.expect("Failed to create text_chunk");
|
||||
.with_context(|| "Failed to create text_chunk".to_string())?;
|
||||
|
||||
RecordId::from_table_key(TextChunk::table_name(), key)
|
||||
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||
}
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res.take(0).with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_a";
|
||||
let chunk_key = "chunk-123";
|
||||
let source_id = "source-1";
|
||||
|
||||
// 1) Create a text_chunk with a known key
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||
|
||||
// 2) Create and store an embedding for that chunk
|
||||
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
||||
@@ -191,39 +214,37 @@ mod tests {
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
.with_context(|| "Failed to store embedding".to_string())?
|
||||
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
|
||||
|
||||
// 3) Fetch it via get_by_chunk_id
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by chunk_id");
|
||||
|
||||
assert!(fetched.is_some(), "Expected an embedding to be found");
|
||||
let fetched = fetched.unwrap();
|
||||
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.chunk_id, chunk_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_b";
|
||||
let chunk_key = "chunk-delete";
|
||||
let source_id = "source-del";
|
||||
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||
|
||||
let emb = TextChunkEmbedding::new(
|
||||
chunk_key,
|
||||
@@ -234,50 +255,50 @@ mod tests {
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
.with_context(|| "Failed to store embedding".to_string())?
|
||||
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
|
||||
|
||||
// Ensure it exists
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
// Delete by chunk_id
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by chunk_id");
|
||||
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||
|
||||
// Ensure it no longer exists
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_c";
|
||||
let source_id = "shared-source";
|
||||
let other_source = "other-source";
|
||||
|
||||
// Two chunks with the same source_id
|
||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
|
||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
|
||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?;
|
||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?;
|
||||
|
||||
// One chunk with a different source_id
|
||||
let chunk_other_rid =
|
||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
|
||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?;
|
||||
|
||||
// Create embeddings for all three
|
||||
let emb1 = TextChunkEmbedding::new(
|
||||
@@ -302,7 +323,7 @@ mod tests {
|
||||
// Update length on index
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
for emb in [emb1, emb2, emb3] {
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
@@ -310,102 +331,82 @@ mod tests {
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
.with_context(|| "Failed to store embedding".to_string())?
|
||||
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
|
||||
}
|
||||
|
||||
// Sanity check: they all exist
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get chunk1".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get chunk2".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get chunk_other".to_string())?
|
||||
.is_some());
|
||||
|
||||
// Delete embeddings by source_id (shared-source)
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
// Chunks from shared-source should have no embeddings
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "check chunk1".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "check chunk2".to_string())?
|
||||
.is_none());
|
||||
|
||||
// The other chunk should still have its embedding
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "check chunk_other".to_string())?
|
||||
.is_some());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Change the index dimension from default (1536) to a smaller test value.
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
.with_context(|| "failed to redefine index".to_string())?;
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
let idx_sql = get_idx_sql(&db).await?;
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 8"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_is_idempotent() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("first redefine failed");
|
||||
.with_context(|| "first redefine failed".to_string())?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("second redefine failed");
|
||||
.with_context(|| "second redefine failed".to_string())?;
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
let idx_sql = get_idx_sql(&db).await?;
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 4"),
|
||||
"expected index definition to retain dimension 4, got: {idx_sql}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,10 +185,12 @@ impl TextContent {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() {
|
||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||
// Test basic object creation
|
||||
let text = "Test content text".to_string();
|
||||
let context = "Test context".to_string();
|
||||
@@ -212,10 +214,11 @@ mod tests {
|
||||
assert!(text_content.file_info.is_none());
|
||||
assert!(text_content.url_info.is_none());
|
||||
assert!(!text_content.id.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_with_url() {
|
||||
async fn test_text_content_with_url() -> anyhow::Result<()> {
|
||||
// Test creating with URL
|
||||
let text = "Content with URL".to_string();
|
||||
let context = "URL context".to_string();
|
||||
@@ -232,26 +235,27 @@ mod tests {
|
||||
});
|
||||
|
||||
let text_content = TextContent::new(
|
||||
text.clone(),
|
||||
Some(context.clone()),
|
||||
category.clone(),
|
||||
text,
|
||||
Some(context),
|
||||
category,
|
||||
None,
|
||||
url_info.clone(),
|
||||
user_id.clone(),
|
||||
user_id,
|
||||
);
|
||||
|
||||
// Check URL field is set
|
||||
assert_eq!(text_content.url_info, url_info);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_patch() {
|
||||
async fn test_text_content_patch() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create initial text content
|
||||
let initial_text = "Initial text".to_string();
|
||||
@@ -272,7 +276,7 @@ mod tests {
|
||||
let stored: Option<TextContent> = db
|
||||
.store_item(text_content.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
.with_context(|| "Failed to store text content".to_string())?;
|
||||
assert!(stored.is_some());
|
||||
|
||||
// New values for patch
|
||||
@@ -283,31 +287,30 @@ mod tests {
|
||||
// Apply the patch
|
||||
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
|
||||
.await
|
||||
.expect("Failed to patch text content");
|
||||
.with_context(|| "Failed to patch text content".to_string())?;
|
||||
|
||||
// Retrieve the updated content
|
||||
let updated: Option<TextContent> = db
|
||||
.get_item(&text_content.id)
|
||||
.await
|
||||
.expect("Failed to get updated text content");
|
||||
assert!(updated.is_some());
|
||||
|
||||
let updated_content = updated.unwrap();
|
||||
.with_context(|| "Failed to get updated text content".to_string())?;
|
||||
let updated_content = updated.with_context(|| "expected updated content".to_string())?;
|
||||
|
||||
// Verify the updates
|
||||
assert_eq!(updated_content.context, Some(new_context.to_string()));
|
||||
assert_eq!(updated_content.category, new_category);
|
||||
assert_eq!(updated_content.text, new_text);
|
||||
assert!(updated_content.updated_at > text_content.updated_at);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_has_other_with_file_detects_shared_usage() {
|
||||
async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let file_info = FileInfo {
|
||||
@@ -340,24 +343,25 @@ mod tests {
|
||||
|
||||
db.store_item(content_a.clone())
|
||||
.await
|
||||
.expect("Failed to store first content");
|
||||
.with_context(|| "Failed to store first content".to_string())?;
|
||||
db.store_item(content_b.clone())
|
||||
.await
|
||||
.expect("Failed to store second content");
|
||||
.with_context(|| "Failed to store second content".to_string())?;
|
||||
|
||||
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check for shared file usage");
|
||||
.with_context(|| "Failed to check for shared file usage".to_string())?;
|
||||
assert!(has_other);
|
||||
|
||||
let _removed: Option<TextContent> = db
|
||||
.delete_item(&content_b.id)
|
||||
.await
|
||||
.expect("Failed to delete second content");
|
||||
.with_context(|| "Failed to delete second content".to_string())?;
|
||||
|
||||
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check shared usage after delete");
|
||||
.with_context(|| "Failed to check shared usage after delete".to_string())?;
|
||||
assert!(!has_other_after);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -723,30 +723,32 @@ impl User {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to setup the migrations");
|
||||
.with_context(|| "Failed to setup the migrations".to_string())?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_creation() {
|
||||
async fn test_user_creation() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "test@example.com";
|
||||
@@ -761,7 +763,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Verify user properties
|
||||
assert!(!user.id.is_empty());
|
||||
@@ -774,18 +776,17 @@ mod tests {
|
||||
let retrieved: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let retrieved = retrieved.with_context(|| "expected user to exist".to_string())?;
|
||||
assert_eq!(retrieved.id, user.id);
|
||||
assert_eq!(retrieved.email, email);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_authentication() {
|
||||
async fn test_user_authentication() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "auth_test@example.com";
|
||||
@@ -799,7 +800,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Test successful authentication
|
||||
let auth_result = User::authenticate(email, password, &db).await;
|
||||
@@ -812,11 +813,12 @@ mod tests {
|
||||
// Test failed authentication with non-existent user
|
||||
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
|
||||
assert!(nonexistent.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "unfinished_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
@@ -830,14 +832,14 @@ mod tests {
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(created_task.clone())
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
.with_context(|| "Failed to store created task".to_string())?;
|
||||
|
||||
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
processing_task.state = TaskState::Processing;
|
||||
processing_task.attempts = 1;
|
||||
db.store_item(processing_task.clone())
|
||||
.await
|
||||
.expect("Failed to store processing task");
|
||||
.with_context(|| "Failed to store processing task".to_string())?;
|
||||
|
||||
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_retry_task.state = TaskState::Failed;
|
||||
@@ -845,7 +847,7 @@ mod tests {
|
||||
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
||||
db.store_item(failed_retry_task.clone())
|
||||
.await
|
||||
.expect("Failed to store retryable failed task");
|
||||
.with_context(|| "Failed to store retryable failed task".to_string())?;
|
||||
|
||||
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_blocked_task.state = TaskState::Failed;
|
||||
@@ -853,13 +855,13 @@ mod tests {
|
||||
failed_blocked_task.error_message = Some("Too many failures".into());
|
||||
db.store_item(failed_blocked_task.clone())
|
||||
.await
|
||||
.expect("Failed to store blocked task");
|
||||
.with_context(|| "Failed to store blocked task".to_string())?;
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
completed_task.state = TaskState::Succeeded;
|
||||
db.store_item(completed_task.clone())
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
.with_context(|| "Failed to store completed task".to_string())?;
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
@@ -870,11 +872,11 @@ mod tests {
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task)
|
||||
.await
|
||||
.expect("Failed to store other user task");
|
||||
.with_context(|| "Failed to store other user task".to_string())?;
|
||||
|
||||
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch unfinished tasks");
|
||||
.with_context(|| "Failed to fetch unfinished tasks".to_string())?;
|
||||
|
||||
let unfinished_ids: HashSet<String> =
|
||||
unfinished.iter().map(|task| task.id.clone()).collect();
|
||||
@@ -885,11 +887,12 @@ mod tests {
|
||||
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
||||
assert!(!unfinished_ids.contains(&completed_task.id));
|
||||
assert_eq!(unfinished_ids.len(), 3);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "archive_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
@@ -902,15 +905,15 @@ mod tests {
|
||||
|
||||
// Oldest task
|
||||
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
first.created_at = first.created_at - chrono::Duration::minutes(1);
|
||||
first.created_at -= chrono::Duration::minutes(1);
|
||||
first.updated_at = first.created_at;
|
||||
first.state = TaskState::Succeeded;
|
||||
db.store_item(first.clone()).await.expect("store first");
|
||||
db.store_item(first.clone()).await.with_context(|| "store first".to_string())?;
|
||||
|
||||
// Latest task
|
||||
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
second.state = TaskState::Processing;
|
||||
db.store_item(second.clone()).await.expect("store second");
|
||||
db.store_item(second.clone()).await.with_context(|| "store second".to_string())?;
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
@@ -919,21 +922,22 @@ mod tests {
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task).await.expect("store other");
|
||||
db.store_item(other_task).await.with_context(|| "store other".to_string())?;
|
||||
|
||||
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("fetch all tasks");
|
||||
.with_context(|| "fetch all tasks".to_string())?;
|
||||
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0].id, second.id); // newest first
|
||||
assert_eq!(tasks[1].id, first.id);
|
||||
assert_eq!(tasks.first().map(|t| &t.id), Some(&second.id)); // newest first
|
||||
assert_eq!(tasks.get(1).map(|t| &t.id), Some(&first.id));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_by_email() {
|
||||
async fn test_find_by_email() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "find_test@example.com";
|
||||
@@ -947,28 +951,28 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Test finding user by email
|
||||
let found_user = User::find_by_email(email, &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
.with_context(|| "Error searching for user".to_string())?
|
||||
.with_context(|| "expected user to exist".to_string())?;
|
||||
assert_eq!(found_user.id, created_user.id);
|
||||
assert_eq!(found_user.email, email);
|
||||
|
||||
// Test finding non-existent user
|
||||
let not_found = User::find_by_email("nonexistent@example.com", &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
.with_context(|| "Error searching for user".to_string())?;
|
||||
assert!(not_found.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_key_management() {
|
||||
async fn test_api_key_management() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "apikey_test@example.com";
|
||||
@@ -982,7 +986,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Initially, user should have no API key
|
||||
assert!(user.api_key.is_none());
|
||||
@@ -990,7 +994,7 @@ mod tests {
|
||||
// Generate API key
|
||||
let api_key = User::set_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to set API key");
|
||||
.with_context(|| "Failed to set API key".to_string())?;
|
||||
assert!(!api_key.is_empty());
|
||||
assert!(api_key.starts_with("sk_"));
|
||||
|
||||
@@ -998,38 +1002,36 @@ mod tests {
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
|
||||
assert_eq!(updated_user.api_key, Some(api_key.clone()));
|
||||
|
||||
// Test finding user by API key
|
||||
let found_user = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
.with_context(|| "Error searching by API key".to_string())?
|
||||
.with_context(|| "expected user found by api key".to_string())?;
|
||||
assert_eq!(found_user.id, user.id);
|
||||
|
||||
// Revoke API key
|
||||
User::revoke_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to revoke API key");
|
||||
.with_context(|| "Failed to revoke API key".to_string())?;
|
||||
|
||||
// Verify API key was revoked
|
||||
let revoked_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(revoked_user.is_some());
|
||||
let revoked_user = revoked_user.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let revoked_user = revoked_user.with_context(|| "expected revoked user".to_string())?;
|
||||
assert!(revoked_user.api_key.is_none());
|
||||
|
||||
// Test searching by revoked API key
|
||||
let not_found = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
.with_context(|| "Error searching by API key".to_string())?;
|
||||
assert!(not_found.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1069,9 +1071,9 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_password_update() {
|
||||
async fn test_password_update() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "pwd_test@example.com";
|
||||
@@ -1086,7 +1088,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Authenticate with old password
|
||||
let auth_result = User::authenticate(email, old_password, &db).await;
|
||||
@@ -1095,7 +1097,7 @@ mod tests {
|
||||
// Update password
|
||||
User::patch_password(email, new_password, &db)
|
||||
.await
|
||||
.expect("Failed to update password");
|
||||
.with_context(|| "Failed to update password".to_string())?;
|
||||
|
||||
// Old password should no longer work
|
||||
let old_auth = User::authenticate(email, old_password, &db).await;
|
||||
@@ -1104,10 +1106,11 @@ mod tests {
|
||||
// New password should work
|
||||
let new_auth = User::authenticate(email, new_password, &db).await;
|
||||
assert!(new_auth.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_timezone() {
|
||||
async fn test_validate_timezone() -> anyhow::Result<()> {
|
||||
// Valid timezones should be accepted as-is
|
||||
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
|
||||
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
|
||||
@@ -1117,12 +1120,13 @@ mod tests {
|
||||
// Invalid timezones should be replaced with UTC
|
||||
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
|
||||
assert_eq!(validate_timezone("Not_Real"), "UTC");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_update() {
|
||||
async fn test_timezone_update() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create user with default timezone
|
||||
let email = "timezone_test@example.com";
|
||||
@@ -1134,7 +1138,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
assert_eq!(user.timezone, "UTC");
|
||||
|
||||
@@ -1142,58 +1146,61 @@ mod tests {
|
||||
let new_timezone = "Europe/Paris";
|
||||
User::update_timezone(&user.id, new_timezone, &db)
|
||||
.await
|
||||
.expect("Failed to update timezone");
|
||||
.with_context(|| "Failed to update timezone".to_string())?;
|
||||
|
||||
// Verify timezone was updated
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
|
||||
assert_eq!(updated_user.timezone, new_timezone);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_conversations_order() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_conversations_order() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user_order_test";
|
||||
|
||||
// Create conversations with varying updated_at timestamps
|
||||
let mut conversations = Vec::new();
|
||||
for i in 0..5 {
|
||||
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {}", i));
|
||||
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {i}"));
|
||||
// Fake updated_at i minutes apart
|
||||
conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i);
|
||||
db.store_item(conv.clone())
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
conversations.push(conv);
|
||||
}
|
||||
|
||||
// Retrieve via get_user_conversations - should be ordered by updated_at DESC
|
||||
let retrieved = User::get_user_conversations(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to get conversations");
|
||||
.with_context(|| "Failed to get conversations".to_string())?;
|
||||
|
||||
assert_eq!(retrieved.len(), conversations.len());
|
||||
|
||||
for window in retrieved.windows(2) {
|
||||
// Assert each earlier conversation has updated_at >= later conversation
|
||||
for pair in retrieved.windows(2) {
|
||||
let a = pair.first().context("expected first in pair")?;
|
||||
let b = pair.get(1).context("expected second in pair")?;
|
||||
assert!(
|
||||
window[0].created_at >= window[1].created_at,
|
||||
a.created_at >= b.created_at,
|
||||
"Conversations not ordered descending by created_at"
|
||||
);
|
||||
}
|
||||
|
||||
// Check first conversation title matches the most recently updated
|
||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap();
|
||||
assert_eq!(retrieved[0].id, most_recent.id);
|
||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).context("expected most recent")?;
|
||||
let r0 = retrieved.first().context("expected first result")?;
|
||||
assert_eq!(r0.id, most_recent.id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_latest_text_contents_returns_last_five() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_get_latest_text_contents_returns_last_five() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "latest_text_user";
|
||||
|
||||
let mut inserted_ids = Vec::new();
|
||||
@@ -1201,8 +1208,8 @@ mod tests {
|
||||
|
||||
for i in 0..12 {
|
||||
let mut item = TextContent::new(
|
||||
format!("Text {}", i),
|
||||
Some(format!("Context {}", i)),
|
||||
format!("Text {i}"),
|
||||
Some(format!("Context {i}")),
|
||||
"Category".to_string(),
|
||||
None,
|
||||
None,
|
||||
@@ -1215,18 +1222,19 @@ mod tests {
|
||||
|
||||
db.store_item(item.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
.with_context(|| "Failed to store text content".to_string())?;
|
||||
|
||||
inserted_ids.push(item.id.clone());
|
||||
}
|
||||
|
||||
let latest = User::get_latest_text_contents(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch latest text contents");
|
||||
.with_context(|| "Failed to fetch latest text contents".to_string())?;
|
||||
|
||||
assert_eq!(latest.len(), 5, "Expected exactly five items");
|
||||
|
||||
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec();
|
||||
let start = inserted_ids.len().saturating_sub(5);
|
||||
let mut expected_ids = inserted_ids.get(start..).unwrap_or_default().to_vec();
|
||||
expected_ids.reverse();
|
||||
|
||||
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
|
||||
@@ -1235,25 +1243,29 @@ mod tests {
|
||||
"Latest items did not match expectation"
|
||||
);
|
||||
|
||||
for window in latest.windows(2) {
|
||||
for pair in latest.windows(2) {
|
||||
let a = pair.first().context("expected first in pair")?;
|
||||
let b = pair.get(1).context("expected second in pair")?;
|
||||
assert!(
|
||||
window[0].created_at >= window[1].created_at,
|
||||
a.created_at >= b.created_at,
|
||||
"Results are not ordered by created_at descending"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_theme() {
|
||||
async fn test_validate_theme() -> anyhow::Result<()> {
|
||||
assert_eq!(validate_theme("light"), Theme::Light);
|
||||
assert_eq!(validate_theme("dark"), Theme::Dark);
|
||||
assert_eq!(validate_theme("system"), Theme::System);
|
||||
assert_eq!(validate_theme("invalid"), Theme::System);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_theme_update() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_theme_update() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let email = "theme_test@example.com";
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
@@ -1263,30 +1275,31 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
assert_eq!(user.theme, Theme::System);
|
||||
|
||||
User::update_theme(&user.id, "dark", &db)
|
||||
.await
|
||||
.expect("update theme");
|
||||
.with_context(|| "update theme".to_string())?;
|
||||
|
||||
let updated = db
|
||||
.get_item::<User>(&user.id)
|
||||
.await
|
||||
.expect("get user")
|
||||
.unwrap();
|
||||
.with_context(|| "get user".to_string())?
|
||||
.with_context(|| "expected user".to_string())?;
|
||||
assert_eq!(updated.theme, Theme::Dark);
|
||||
|
||||
// Invalid theme should default to system (but update_theme calls validate_theme)
|
||||
User::update_theme(&user.id, "invalid", &db)
|
||||
.await
|
||||
.expect("update theme invalid");
|
||||
.with_context(|| "update theme invalid".to_string())?;
|
||||
let updated2 = db
|
||||
.get_item::<User>(&user.id)
|
||||
.await
|
||||
.expect("get user")
|
||||
.unwrap();
|
||||
.with_context(|| "get user".to_string())?
|
||||
.with_context(|| "expected user".to_string())?;
|
||||
assert_eq!(updated2.theme, Theme::System);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user