diff --git a/Cargo.toml b/Cargo.toml index 6006788..c619e74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,11 +106,11 @@ missing_errors_doc = "allow" missing_panics_doc = "warn" module_name_repetitions = "warn" wildcard_dependencies = "warn" -missing_docs_in_private_items = "warn" +missing_docs_in_private_items = "allow" # Allow noisy lints that don't add value for this project needless_raw_string_hashes = "allow" multiple_bound_locations = "allow" cargo_common_metadata = "allow" multiple-crate-versions = "allow" -module_name_repetition = "allow" + diff --git a/api-router/src/api_state.rs b/api-router/src/api_state.rs index c9ba09b..9cdcd32 100644 --- a/api-router/src/api_state.rs +++ b/api-router/src/api_state.rs @@ -31,7 +31,7 @@ impl ApiState { surreal_db_client.apply_migrations().await?; let app_state = Self { - db: surreal_db_client.clone(), + db: Arc::clone(&surreal_db_client), config: config.clone(), storage, }; diff --git a/api-router/src/error.rs b/api-router/src/error.rs index f127752..b7ca986 100644 --- a/api-router/src/error.rs +++ b/api-router/src/error.rs @@ -8,7 +8,7 @@ use serde::Serialize; use thiserror::Error; #[derive(Error, Debug, Serialize, Clone)] -pub enum ApiError { +pub enum ApiErr { #[error("Internal server error")] InternalError(String), @@ -25,7 +25,7 @@ pub enum ApiError { PayloadTooLarge(String), } -impl From for ApiError { +impl From for ApiErr { fn from(err: AppError) -> Self { match err { AppError::Database(_) | AppError::OpenAI(_) => { @@ -39,7 +39,7 @@ impl From for ApiError { } } } -impl IntoResponse for ApiError { +impl IntoResponse for ApiErr { fn into_response(self) -> Response { let (status, error_response) = match self { Self::InternalError(message) => ( @@ -94,6 +94,7 @@ mod tests { use super::*; use common::error::AppError; use std::fmt::Debug; + use std::io; // Helper to check status code fn assert_status_code(response: T, expected_status: StatusCode) { @@ -105,42 +106,42 @@ mod tests { fn 
test_app_error_to_api_error_conversion() { // Test NotFound error conversion let not_found = AppError::NotFound("resource not found".to_string()); - let api_error = ApiError::from(not_found); - assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found")); + let api_error = ApiErr::from(not_found); + assert!(matches!(api_error, ApiErr::NotFound(msg) if msg == "resource not found")); // Test Validation error conversion let validation = AppError::Validation("invalid input".to_string()); - let api_error = ApiError::from(validation); - assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input")); + let api_error = ApiErr::from(validation); + assert!(matches!(api_error, ApiErr::ValidationError(msg) if msg == "invalid input")); // Test Auth error conversion let auth = AppError::Auth("unauthorized".to_string()); - let api_error = ApiError::from(auth); - assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized")); + let api_error = ApiErr::from(auth); + assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized")); // Test for internal errors - create a mock error that doesn't require surrealdb let internal_error = - AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error")); - let api_error = ApiError::from(internal_error); - assert!(matches!(api_error, ApiError::InternalError(_))); + AppError::Io(io::Error::other("io error")); + let api_error = ApiErr::from(internal_error); + assert!(matches!(api_error, ApiErr::InternalError(_))); } #[test] fn test_api_error_response_status_codes() { // Test internal error status - let error = ApiError::InternalError("server error".to_string()); + let error = ApiErr::InternalError("server error".to_string()); assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR); // Test not found status - let error = ApiError::NotFound("not found".to_string()); + let error = ApiErr::NotFound("not found".to_string()); assert_status_code(error, 
StatusCode::NOT_FOUND); // Test validation error status - let error = ApiError::ValidationError("invalid input".to_string()); + let error = ApiErr::ValidationError("invalid input".to_string()); assert_status_code(error, StatusCode::BAD_REQUEST); // Test unauthorized status - let error = ApiError::Unauthorized("not allowed".to_string()); + let error = ApiErr::Unauthorized("not allowed".to_string()); assert_status_code(error, StatusCode::UNAUTHORIZED); // Test payload too large status @@ -153,15 +154,15 @@ mod tests { fn test_error_messages() { // For validation errors let message = "invalid data format"; - let error = ApiError::ValidationError(message.to_string()); + let error = ApiErr::ValidationError(message.to_string()); // Check that the error itself contains the message - assert_eq!(error.to_string(), format!("Validation error: {}", message)); + assert_eq!(error.to_string(), format!("Validation error: {message}")); // For not found errors let message = "user not found"; - let error = ApiError::NotFound(message.to_string()); - assert_eq!(error.to_string(), format!("Not found: {}", message)); + let error = ApiErr::NotFound(message.to_string()); + assert_eq!(error.to_string(), format!("Not found: {message}")); } // Alternative approach for internal error test @@ -170,8 +171,8 @@ mod tests { // Create a sensitive error message let sensitive_info = "db password incorrect"; - // Create ApiError with sensitive info - let api_error = ApiError::InternalError(sensitive_info.to_string()); + // Create ApiErr with sensitive info + let api_error = ApiErr::InternalError(sensitive_info.to_string()); // Check the error message is correctly set assert_eq!(api_error.to_string(), "Internal server error"); diff --git a/api-router/src/middleware_api_auth.rs b/api-router/src/middleware_api_auth.rs index 6c58697..111d42d 100644 --- a/api-router/src/middleware_api_auth.rs +++ b/api-router/src/middleware_api_auth.rs @@ -6,19 +6,19 @@ use axum::{ use common::storage::types::user::User; 
-use crate::{api_state::ApiState, error::ApiError}; +use crate::{api_state::ApiState, error::ApiErr}; pub async fn api_auth( State(state): State, mut request: Request, next: Next, -) -> Result { +) -> Result { let api_key = extract_api_key(&request) - .ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?; + .ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?; let user = User::find_by_api_key(&api_key, &state.db).await?; let user = - user.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?; + user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?; request.extensions_mut().insert(user); diff --git a/api-router/src/routes/categories.rs b/api-router/src/routes/categories.rs index 4035c1e..e04b0bb 100644 --- a/api-router/src/routes/categories.rs +++ b/api-router/src/routes/categories.rs @@ -1,12 +1,12 @@ use axum::{extract::State, response::IntoResponse, Extension, Json}; use common::storage::types::user::User; -use crate::{api_state::ApiState, error::ApiError}; +use crate::{api_state::ApiState, error::ApiErr}; -pub async fn get_categories( +pub async fn list( State(state): State, Extension(user): Extension, -) -> Result { +) -> Result { let categories = User::get_user_categories(&user.id, &state.db).await?; Ok(Json(categories)) diff --git a/api-router/src/routes/ingest.rs b/api-router/src/routes/ingest.rs index 8587dbb..a68f9ca 100644 --- a/api-router/src/routes/ingest.rs +++ b/api-router/src/routes/ingest.rs @@ -13,7 +13,7 @@ use serde_json::json; use tempfile::NamedTempFile; use tracing::info; -use crate::{api_state::ApiState, error::ApiError}; +use crate::{api_state::ApiState, error::ApiErr}; #[derive(Debug, TryFromMultipart)] pub struct IngestParams { diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index 30b15b4..f91135a 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -202,6 +202,7 @@ impl SurrealDbClient 
{ #[cfg(test)] mod tests { + use anyhow::{self, Context}; use crate::stored_object; use super::*; @@ -212,19 +213,17 @@ mod tests { }); #[tokio::test] - async fn test_initialization_and_crud() { + async fn test_initialization_and_crud() -> anyhow::Result<()> { let namespace = "test_ns"; - let database = &Uuid::new_v4().to_string(); // ensures isolation per test run + let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Call your initialization db.apply_migrations() .await - .expect("Failed to initialize schema"); + .with_context(|| "Failed to initialize schema".to_string())?; - // Test basic CRUD let dummy = Dummy { id: "abc".to_string(), name: "first".to_string(), @@ -232,50 +231,50 @@ mod tests { updated_at: Utc::now(), }; - // Store - let stored = db.store_item(dummy.clone()).await.expect("Failed to store"); + let stored = db + .store_item(dummy.clone()) + .await + .with_context(|| "Failed to store".to_string())?; assert!(stored.is_some()); - // Read let fetched = db .get_item::(&dummy.id) .await - .expect("Failed to fetch"); + .with_context(|| "Failed to fetch".to_string())?; assert_eq!(fetched, Some(dummy.clone())); - // Read all let all = db .get_all_stored_items::() .await - .expect("Failed to fetch all"); + .with_context(|| "Failed to fetch all".to_string())?; assert!(all.contains(&dummy)); - // Delete let deleted = db .delete_item::(&dummy.id) .await - .expect("Failed to delete"); + .with_context(|| "Failed to delete".to_string())?; assert_eq!(deleted, Some(dummy)); - // After delete, should not be present let fetch_post = db .get_item::("abc") .await - .expect("Failed fetch post delete"); + .with_context(|| "Failed fetch post delete".to_string())?; assert!(fetch_post.is_none()); + + Ok(()) } #[tokio::test] - async fn upsert_item_overwrites_existing_records() { + async fn 
upsert_item_overwrites_existing_records() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to initialize schema"); + .with_context(|| "Failed to initialize schema".to_string())?; let mut dummy = Dummy { id: "abc".to_string(), @@ -286,17 +285,21 @@ mod tests { db.store_item(dummy.clone()) .await - .expect("Failed to store initial record"); + .with_context(|| "Failed to store initial record".to_string())?; dummy.name = "updated".to_string(); let upserted = db .upsert_item(dummy.clone()) .await - .expect("Failed to upsert record"); + .with_context(|| "Failed to upsert record".to_string())?; assert!(upserted.is_some()); - let fetched: Option = db.get_item(&dummy.id).await.expect("fetch after upsert"); - assert_eq!(fetched.unwrap().name, "updated"); + let fetched: Option = db + .get_item(&dummy.id) + .await + .with_context(|| "fetch after upsert".to_string())?; + let fetched = fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?; + assert_eq!(fetched.name, "updated"); let new_record = Dummy { id: "def".to_string(), @@ -306,25 +309,29 @@ mod tests { }; db.upsert_item(new_record.clone()) .await - .expect("Failed to upsert new record"); + .with_context(|| "Failed to upsert new record".to_string())?; let fetched_new: Option = db .get_item(&new_record.id) .await - .expect("fetch inserted via upsert"); + .with_context(|| "fetch inserted via upsert".to_string())?; assert_eq!(fetched_new, Some(new_record)); + + Ok(()) } #[tokio::test] - async fn test_applying_migrations() { + async fn test_applying_migrations() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - 
.expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to build indexes"); + .with_context(|| "Failed to build indexes".to_string())?; + + Ok(()) } } diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index 79e7ad0..1b9b682 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -159,23 +159,23 @@ impl FtsIndexSpec { /// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling. /// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes. -pub async fn ensure_runtime_indexes( +pub async fn ensure_runtime( db: &SurrealDbClient, embedding_dimension: usize, ) -> Result<(), AppError> { - ensure_runtime_indexes_inner(db, embedding_dimension) + ensure_runtime_inner(db, embedding_dimension) .await .map_err(|err| AppError::InternalError(err.to_string())) } /// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined. 
-pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> { - rebuild_indexes_inner(db) +pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> { + rebuild_inner(db) .await .map_err(|err| AppError::InternalError(err.to_string())) } -async fn ensure_runtime_indexes_inner( +async fn ensure_runtime_inner( db: &SurrealDbClient, embedding_dimension: usize, ) -> Result<()> { @@ -262,9 +262,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) - .context("checking index status")?; let info: Option = info_res.take(0).context("failed to take info result")?; - let info = match info { - Some(i) => i, - None => return Ok("unknown".to_string()), + let Some(info) = info else { + return Ok("unknown".to_string()); }; let building = info.get("building"); @@ -277,7 +276,7 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) - Ok(status) } -async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> { +async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> { debug!("Rebuilding indexes with concurrent definitions"); create_fts_analyzer(db).await?; @@ -385,10 +384,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> { // is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering // an existing analyzer definition. 
let snowball_query = format!( - "DEFINE ANALYZER IF NOT EXISTS {analyzer} + "DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME} TOKENIZERS class - FILTERS lowercase, ascii, snowball(english);", - analyzer = FTS_ANALYZER_NAME + FILTERS lowercase, ascii, snowball(english);" ); match db.client.query(snowball_query).await { @@ -410,10 +408,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> { } let fallback_query = format!( - "DEFINE ANALYZER IF NOT EXISTS {analyzer} + "DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME} TOKENIZERS class - FILTERS lowercase, ascii;", - analyzer = FTS_ANALYZER_NAME + FILTERS lowercase, ascii;" ); let res = db @@ -446,6 +443,7 @@ async fn create_index_with_polling( table: &str, progress_table: Option<&str>, ) -> Result<()> { + const MAX_ATTEMPTS: usize = 3; let expected_total = match progress_table { Some(table) => Some(count_table_rows(db, table).await.with_context(|| { format!("counting rows in {table} for index {index_name} progress") @@ -453,10 +451,9 @@ async fn create_index_with_polling( None => None, }; - let mut attempts = 0; - const MAX_ATTEMPTS: usize = 3; + let mut attempts: usize = 0; loop { - attempts += 1; + attempts = attempts.saturating_add(1); let res = db .client .query(definition.clone()) @@ -527,8 +524,8 @@ async fn poll_index_build_status( break; }; - match snapshot.progress_pct { - Some(pct) => debug!( + if let Some(pct) = snapshot.progress_pct { + debug!( index = %index_name, table = %table, status = snapshot.status, @@ -539,8 +536,9 @@ async fn poll_index_build_status( total = snapshot.total_rows, progress_pct = format_args!("{pct:.1}"), "Index build status" - ), - None => debug!( + ); + } else { + debug!( index = %index_name, table = %table, status = snapshot.status, @@ -549,7 +547,7 @@ async fn poll_index_build_status( updated = snapshot.updated, processed = snapshot.processed, "Index build status" - ), + ); } if snapshot.is_ready() { @@ -611,17 +609,17 @@ fn parse_index_build_info( let initial = 
building .and_then(|b| b.get("initial")) - .and_then(|v| v.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let pending = building .and_then(|b| b.get("pending")) - .and_then(|v| v.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let updated = building .and_then(|b| b.get("updated")) - .and_then(|v| v.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); // `initial` is the number of rows seen when the build started; `updated` accounts for later writes. @@ -631,7 +629,7 @@ fn parse_index_build_info( if total == 0 { 0.0 } else { - ((processed as f64 / total as f64).min(1.0)) * 100.0 + ((processed as f64 / total as f64).min(1.0)) * 100.0 // `as f64` on purpose: narrowing to u32 would pin tables with > u32::MAX rows at 100% } }); @@ -673,7 +671,7 @@ async fn table_index_definitions( .client .query(info_query) .await - .with_context(|| format!("fetching table info for {}", table))?; + .with_context(|| format!("fetching table info for {table}"))?; let info: surrealdb::Value = response .take(0) @@ -700,12 +698,15 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re #[cfg(test)] mod tests { - use super::*; + use anyhow::{self, Context}; + use crate::storage::db::SurrealDbClient; use serde_json::json; use uuid::Uuid; + use super::*; + #[test] - fn parse_index_build_info_reports_progress() { + fn parse_index_build_info_reports_progress() -> anyhow::Result<()> { let info = json!({ "building": { "initial": 56894, @@ -715,7 +716,8 @@ mod tests { } }); - let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot"); + let snapshot = parse_index_build_info(Some(info), Some(61081)) + .context("snapshot")?; assert_eq!( snapshot, IndexBuildSnapshot { @@ -729,16 +731,19 @@ mod tests { } ); assert!(!snapshot.is_ready()); + Ok(()) } #[test] - fn parse_index_build_info_defaults_to_ready_when_no_building_block() { + fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> { //
Surreal returns `{}` when the index exists but isn't building. let info = json!({}); - let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot"); + let snapshot = parse_index_build_info(Some(info), Some(10)) + .context("snapshot")?; assert!(snapshot.is_ready()); assert_eq!(snapshot.processed, 0); assert_eq!(snapshot.progress_pct, Some(0.0)); + Ok(()) } #[test] @@ -748,48 +753,40 @@ mod tests { } #[tokio::test] - async fn ensure_runtime_indexes_is_idempotent() { + async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> { let namespace = "indexes_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("in-memory db"); + .context("in-memory db")?; db.apply_migrations() .await - .expect("migrations should succeed"); + .context("migrations should succeed")?; - // First run creates everything - ensure_runtime_indexes(&db, 1536) - .await - .expect("initial index creation"); - - // Second run should be a no-op and still succeed - ensure_runtime_indexes(&db, 1536) - .await - .expect("second index creation"); + ensure_runtime(&db, 1536).await + .context("first call should succeed")?; + ensure_runtime(&db, 1536).await + .context("second index creation")?; + Ok(()) } #[tokio::test] - async fn ensure_hnsw_index_overwrites_dimension() { + async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> { let namespace = "indexes_dim"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("in-memory db"); + .context("in-memory db")?; db.apply_migrations() .await - .expect("migrations should succeed"); + .context("migrations should succeed")?; - // Create initial index with default dimension - ensure_runtime_indexes(&db, 1536) - .await - .expect("initial index creation"); - - // Change dimension and ensure overwrite path is exercised - ensure_runtime_indexes(&db, 128) - .await - .expect("overwritten index creation"); + 
ensure_runtime(&db, 1536).await + .context("initial index creation")?; + ensure_runtime(&db, 128).await + .context("overwritten index creation")?; + Ok(()) } } diff --git a/common/src/storage/store.rs b/common/src/storage/store.rs index bb78034..154b543 100644 --- a/common/src/storage/store.rs +++ b/common/src/storage/store.rs @@ -13,13 +13,13 @@ use object_store::{path::Path as ObjPath, ObjectStore}; use crate::utils::config::{AppConfig, StorageKind}; -pub type DynStore = Arc; +pub type DynStorage = Arc; /// Storage manager with persistent state and proper lifecycle management. #[derive(Clone)] pub struct StorageManager { // Store from objectstore wrapped as dyn - store: DynStore, + store: DynStorage, // Simple enum to track which kind backend_kind: StorageKind, // Where on disk @@ -46,7 +46,7 @@ impl StorageManager { /// /// This method is useful for testing scenarios where you want to inject /// a specific storage backend. - pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self { + pub fn with_backend(store: DynStorage, backend_kind: StorageKind) -> Self { Self { store, backend_kind, @@ -216,7 +216,7 @@ impl StorageManager { /// storage backends with proper error handling and validation. 
async fn create_storage_backend( cfg: &AppConfig, -) -> object_store::Result<(DynStore, Option)> { +) -> object_store::Result<(DynStorage, Option)> { match cfg.storage { StorageKind::Local => { let base = resolve_base_dir(cfg); @@ -261,9 +261,7 @@ async fn create_storage_backend( builder = builder.with_endpoint(endpoint); } - if let Some(region) = &cfg.s3_region { - builder = builder.with_region(region); - } + builder = builder.with_region(&cfg.s3_region); let store = builder.build()?; Ok((Arc::new(store), None)) @@ -342,7 +340,7 @@ pub mod testing { surrealdb_password: "test".into(), surrealdb_namespace: "test".into(), surrealdb_database: "test".into(), - data_dir: base.into(), + data_dir: base, http_port: 0, openai_base_url: "..".into(), storage: StorageKind::Local, @@ -382,7 +380,7 @@ pub mod testing { #[derive(Clone)] pub struct TestStorageManager { storage: StorageManager, - _temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup + temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup } impl TestStorageManager { @@ -396,7 +394,7 @@ pub mod testing { Ok(Self { storage, - _temp_dir: None, + temp_dir: None, }) } @@ -413,7 +411,7 @@ pub mod testing { Ok(Self { storage, - _temp_dir: resolved, + temp_dir: resolved, }) } @@ -437,7 +435,7 @@ pub mod testing { Ok(Self { storage, - _temp_dir: None, + temp_dir: None, }) } @@ -454,7 +452,7 @@ pub mod testing { Ok(Self { storage, - _temp_dir: temp_dir, + temp_dir, }) } @@ -508,7 +506,7 @@ pub mod testing { impl Drop for TestStorageManager { fn drop(&mut self) { // Clean up temporary directories for local storage - if let Some((_, path)) = &self._temp_dir { + if let Some((_, path)) = &self.temp_dir { if path.exists() { let _ = std::fs::remove_dir_all(path); } @@ -584,6 +582,7 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> { #[cfg(test)] mod tests { use super::*; + use anyhow::Context; use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind}; use 
bytes::Bytes; use uuid::Uuid; @@ -623,11 +622,11 @@ mod tests { } #[tokio::test] - async fn test_storage_manager_memory_basic_operations() { + async fn test_storage_manager_memory_basic_operations() -> anyhow::Result<()> { let cfg = test_config_memory(); let storage = StorageManager::new(&cfg) .await - .expect("create storage manager"); + .with_context(|| "create storage manager".to_string())?; assert!(storage.local_base_path().is_none()); let location = "test/data/file.txt"; @@ -637,31 +636,33 @@ mod tests { storage .put(location, Bytes::from(data.to_vec())) .await - .expect("put"); - let retrieved = storage.get(location).await.expect("get"); + .with_context(|| "put".to_string())?; + let retrieved = storage.get(location).await.with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); // Test exists - assert!(storage.exists(location).await.expect("exists check")); + assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?); // Test delete - storage.delete_prefix("test/data/").await.expect("delete"); + storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?; assert!(!storage .exists(location) .await - .expect("exists check after delete")); + .with_context(|| "exists check after delete".to_string())?); + + Ok(()) } #[tokio::test] - async fn test_storage_manager_local_basic_operations() { + async fn test_storage_manager_local_basic_operations() -> anyhow::Result<()> { let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4()); let cfg = test_config(&base); let storage = StorageManager::new(&cfg) .await - .expect("create storage manager"); + .with_context(|| "create storage manager".to_string())?; let resolved_base = storage .local_base_path() - .expect("resolved base dir") + .with_context(|| "resolved base dir".to_string())? 
.to_path_buf(); assert_eq!(resolved_base, PathBuf::from(&base)); @@ -672,42 +673,44 @@ mod tests { storage .put(location, Bytes::from(data.to_vec())) .await - .expect("put"); - let retrieved = storage.get(location).await.expect("get"); + .with_context(|| "put".to_string())?; + let retrieved = storage.get(location).await.with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); let object_dir = resolved_base.join("test/data"); tokio::fs::metadata(&object_dir) .await - .expect("object directory exists after write"); + .with_context(|| "object directory exists after write".to_string())?; // Test exists - assert!(storage.exists(location).await.expect("exists check")); + assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?); // Test delete - storage.delete_prefix("test/data/").await.expect("delete"); + storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?; assert!(!storage .exists(location) .await - .expect("exists check after delete")); + .with_context(|| "exists check after delete".to_string())?); assert!( tokio::fs::metadata(&object_dir).await.is_err(), "object directory should be removed" ); tokio::fs::metadata(&resolved_base) .await - .expect("base directory remains intact"); + .with_context(|| "base directory remains intact".to_string())?; // Clean up let _ = tokio::fs::remove_dir_all(&base).await; + + Ok(()) } #[tokio::test] - async fn test_storage_manager_memory_persistence() { + async fn test_storage_manager_memory_persistence() -> anyhow::Result<()> { let cfg = test_config_memory(); let storage = StorageManager::new(&cfg) .await - .expect("create storage manager"); + .with_context(|| "create storage manager".to_string())?; let location = "persistence/test.txt"; let data1 = b"first data"; @@ -717,32 +720,34 @@ mod tests { storage .put(location, Bytes::from(data1.to_vec())) .await - .expect("put first"); + .with_context(|| "put first".to_string())?; // Retrieve and verify first data - let 
retrieved1 = storage.get(location).await.expect("get first"); + let retrieved1 = storage.get(location).await.with_context(|| "get first".to_string())?; assert_eq!(retrieved1.as_ref(), data1); // Overwrite with second data storage .put(location, Bytes::from(data2.to_vec())) .await - .expect("put second"); + .with_context(|| "put second".to_string())?; // Retrieve and verify second data - let retrieved2 = storage.get(location).await.expect("get second"); + let retrieved2 = storage.get(location).await.with_context(|| "get second".to_string())?; assert_eq!(retrieved2.as_ref(), data2); // Data persists across multiple operations using the same StorageManager assert_ne!(retrieved1.as_ref(), retrieved2.as_ref()); + + Ok(()) } #[tokio::test] - async fn test_storage_manager_list_operations() { + async fn test_storage_manager_list_operations() -> anyhow::Result<()> { let cfg = test_config_memory(); let storage = StorageManager::new(&cfg) .await - .expect("create storage manager"); + .with_context(|| "create storage manager".to_string())?; // Create multiple files let files = vec![ @@ -755,15 +760,15 @@ mod tests { storage .put(location, Bytes::from(data.to_vec())) .await - .expect("put"); + .with_context(|| "put".to_string())?; } // Test listing without prefix - let all_files = storage.list(None).await.expect("list all"); + let all_files = storage.list(None).await.with_context(|| "list all".to_string())?; assert_eq!(all_files.len(), 3); // Test listing with prefix - let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1"); + let dir1_files = storage.list(Some("dir1/")).await.with_context(|| "list dir1".to_string())?; assert_eq!(dir1_files.len(), 2); assert!(dir1_files .iter() @@ -776,16 +781,18 @@ mod tests { let empty_files = storage .list(Some("nonexistent/")) .await - .expect("list nonexistent"); + .with_context(|| "list nonexistent".to_string())?; assert_eq!(empty_files.len(), 0); + + Ok(()) } #[tokio::test] - async fn 
test_storage_manager_stream_operations() { + async fn test_storage_manager_stream_operations() -> anyhow::Result<()> { let cfg = test_config_memory(); let storage = StorageManager::new(&cfg) .await - .expect("create storage manager"); + .with_context(|| "create storage manager".to_string())?; let location = "stream/test.bin"; let content = vec![42u8; 1024 * 64]; // 64KB of data @@ -794,22 +801,24 @@ mod tests { storage .put(location, Bytes::from(content.clone())) .await - .expect("put large data"); + .with_context(|| "put large data".to_string())?; // Get as stream - let mut stream = storage.get_stream(location).await.expect("get stream"); + let mut stream = storage.get_stream(location).await.with_context(|| "get stream".to_string())?; let mut collected = Vec::new(); while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk"); + let chunk = chunk.with_context(|| "stream chunk".to_string())?; collected.extend_from_slice(&chunk); } assert_eq!(collected, content); + + Ok(()) } #[tokio::test] - async fn test_storage_manager_with_custom_backend() { + async fn test_storage_manager_with_custom_backend() -> anyhow::Result<()> { use object_store::memory::InMemory; // Create custom memory backend @@ -823,20 +832,22 @@ mod tests { storage .put(location, Bytes::from(data.to_vec())) .await - .expect("put"); - let retrieved = storage.get(location).await.expect("get"); + .with_context(|| "put".to_string())?; + let retrieved = storage.get(location).await.with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); - assert!(storage.exists(location).await.expect("exists")); + assert!(storage.exists(location).await.with_context(|| "exists".to_string())?); assert_eq!(*storage.backend_kind(), StorageKind::Memory); + + Ok(()) } #[tokio::test] - async fn test_storage_manager_error_handling() { + async fn test_storage_manager_error_handling() -> anyhow::Result<()> { let cfg = test_config_memory(); let storage = StorageManager::new(&cfg) .await - 
.expect("create storage manager"); + .with_context(|| "create storage manager".to_string())?; // Test getting non-existent file let result = storage.get("nonexistent.txt").await; @@ -846,124 +857,136 @@ mod tests { let exists = storage .exists("nonexistent.txt") .await - .expect("exists check"); + .with_context(|| "exists check".to_string())?; assert!(!exists); // Test listing with invalid location (should not panic) let _result = storage.get("").await; // This may or may not error depending on the backend implementation // The important thing is that it doesn't panic + + Ok(()) } // TestStorageManager tests #[tokio::test] - async fn test_test_storage_manager_memory() { + async fn test_test_storage_manager_memory() -> anyhow::Result<()> { let test_storage = testing::TestStorageManager::new_memory() .await - .expect("create test storage"); + .with_context(|| "create test storage".to_string())?; let location = "test/storage/file.txt"; let data = b"test data with TestStorageManager"; // Test put and get - test_storage.put(location, data).await.expect("put"); - let retrieved = test_storage.get(location).await.expect("get"); + test_storage.put(location, data).await.with_context(|| "put".to_string())?; + let retrieved = test_storage.get(location).await.with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); // Test existence check - assert!(test_storage.exists(location).await.expect("exists")); + assert!(test_storage.exists(location).await.with_context(|| "exists".to_string())?); // Test list let files = test_storage .list(Some("test/storage/")) .await - .expect("list"); + .with_context(|| "list".to_string())?; assert_eq!(files.len(), 1); // Test delete test_storage .delete_prefix("test/storage/") .await - .expect("delete"); + .with_context(|| "delete".to_string())?; assert!(!test_storage .exists(location) .await - .expect("exists after delete")); + .with_context(|| "exists after delete".to_string())?); + + Ok(()) } #[tokio::test] - async fn 
test_test_storage_manager_local() { + async fn test_test_storage_manager_local() -> anyhow::Result<()> { let test_storage = testing::TestStorageManager::new_local() .await - .expect("create test storage"); + .with_context(|| "create test storage".to_string())?; let location = "test/local/file.txt"; let data = b"test data with local TestStorageManager"; - // Test put and get - test_storage.put(location, data).await.expect("put"); - let retrieved = test_storage.get(location).await.expect("get"); + test_storage.put(location, data).await + .with_context(|| "put".to_string())?; + let retrieved = test_storage.get(location).await + .with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); - // Test existence check - assert!(test_storage.exists(location).await.expect("exists")); + assert!(test_storage.exists(location).await + .with_context(|| "exists".to_string())?); - // The storage should be automatically cleaned up when test_storage is dropped + Ok(()) } #[tokio::test] - async fn test_test_storage_manager_isolation() { + async fn test_test_storage_manager_isolation() -> anyhow::Result<()> { let storage1 = testing::TestStorageManager::new_memory() .await - .expect("create test storage 1"); + .with_context(|| "create test storage 1".to_string())?; let storage2 = testing::TestStorageManager::new_memory() .await - .expect("create test storage 2"); + .with_context(|| "create test storage 2".to_string())?; let location = "isolation/test.txt"; let data1 = b"storage 1 data"; let data2 = b"storage 2 data"; - // Put different data in each storage - storage1.put(location, data1).await.expect("put storage 1"); - storage2.put(location, data2).await.expect("put storage 2"); + storage1.put(location, data1).await + .with_context(|| "put storage 1".to_string())?; + storage2.put(location, data2).await + .with_context(|| "put storage 2".to_string())?; - // Verify isolation - let retrieved1 = storage1.get(location).await.expect("get storage 1"); - let retrieved2 = 
storage2.get(location).await.expect("get storage 2"); + let retrieved1 = storage1.get(location).await + .with_context(|| "get storage 1".to_string())?; + let retrieved2 = storage2.get(location).await + .with_context(|| "get storage 2".to_string())?; assert_eq!(retrieved1.as_ref(), data1); assert_eq!(retrieved2.as_ref(), data2); assert_ne!(retrieved1.as_ref(), retrieved2.as_ref()); + + Ok(()) } #[tokio::test] - async fn test_test_storage_manager_config() { + async fn test_test_storage_manager_config() -> anyhow::Result<()> { let cfg = testing::test_config_memory(); let test_storage = testing::TestStorageManager::with_config(&cfg) .await - .expect("create test storage with config"); + .with_context(|| "create test storage with config".to_string())?; let location = "config/test.txt"; let data = b"test data with custom config"; - test_storage.put(location, data).await.expect("put"); - let retrieved = test_storage.get(location).await.expect("get"); + test_storage.put(location, data).await + .with_context(|| "put".to_string())?; + let retrieved = test_storage.get(location).await + .with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); - // Verify it's using memory backend assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory); + + Ok(()) } // S3 Tests - Require a reachable MinIO endpoint and test bucket. // `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable. #[tokio::test] - async fn test_storage_manager_s3_basic_operations() { + async fn test_storage_manager_s3_basic_operations() -> anyhow::Result<()> { // Skip if S3 connection fails (e.g. 
no MinIO) let Ok(storage) = testing::TestStorageManager::new_s3().await else { eprintln!("Skipping S3 test (setup failed)"); - return; + return Ok(()); }; let prefix = format!("test-basic-{}", Uuid::new_v4()); @@ -973,31 +996,33 @@ mod tests { // Test put if let Err(e) = storage.put(&location, data).await { eprintln!("Skipping S3 test (put failed - bucket missing?): {e}"); - return; + return Ok(()); } // Test get - let retrieved = storage.get(&location).await.expect("get"); + let retrieved = storage.get(&location).await.with_context(|| "get".to_string())?; assert_eq!(retrieved.as_ref(), data); // Test exists - assert!(storage.exists(&location).await.expect("exists")); + assert!(storage.exists(&location).await.with_context(|| "exists".to_string())?); // Test delete storage .delete_prefix(&format!("{prefix}/")) .await - .expect("delete"); + .with_context(|| "delete".to_string())?; assert!(!storage .exists(&location) .await - .expect("exists after delete")); + .with_context(|| "exists after delete".to_string())?); + + Ok(()) } #[tokio::test] - async fn test_storage_manager_s3_list_operations() { + async fn test_storage_manager_s3_list_operations() -> anyhow::Result<()> { let Ok(storage) = testing::TestStorageManager::new_s3().await else { - return; + return Ok(()); }; let prefix = format!("test-list-{}", Uuid::new_v4()); @@ -1009,23 +1034,25 @@ mod tests { for (loc, data) in &files { if storage.put(loc, *data).await.is_err() { - return; // Abort if put fails + return Ok(()); // Abort if put fails } } // List with prefix let list_prefix = format!("{prefix}/"); - let items = storage.list(Some(&list_prefix)).await.expect("list"); + let items = storage.list(Some(&list_prefix)).await.with_context(|| "list".to_string())?; assert_eq!(items.len(), 3); // Cleanup - storage.delete_prefix(&list_prefix).await.expect("cleanup"); + storage.delete_prefix(&list_prefix).await.with_context(|| "cleanup".to_string())?; + + Ok(()) } #[tokio::test] - async fn 
test_storage_manager_s3_stream_operations() { + async fn test_storage_manager_s3_stream_operations() -> anyhow::Result<()> { let Ok(storage) = testing::TestStorageManager::new_s3().await else { - return; + return Ok(()); }; let prefix = format!("test-stream-{}", Uuid::new_v4()); @@ -1033,38 +1060,45 @@ mod tests { let content = vec![42u8; 1024 * 10]; // 10KB if storage.put(&location, &content).await.is_err() { - return; + return Ok(()); } - let mut stream = storage.get_stream(&location).await.expect("get stream"); + let mut stream = storage.get_stream(&location).await.with_context(|| "get stream".to_string())?; let mut collected = Vec::new(); while let Some(chunk) = stream.next().await { - collected.extend_from_slice(&chunk.expect("chunk")); + collected.extend_from_slice(&chunk.with_context(|| "chunk".to_string())?); } assert_eq!(collected, content); storage .delete_prefix(&format!("{prefix}/")) .await - .expect("cleanup"); + .with_context(|| "cleanup".to_string())?; + + Ok(()) } #[tokio::test] - async fn test_storage_manager_s3_backend_kind() { + async fn test_storage_manager_s3_backend_kind() -> anyhow::Result<()> { let Ok(storage) = testing::TestStorageManager::new_s3().await else { - return; + return Ok(()); }; assert_eq!(*storage.storage().backend_kind(), StorageKind::S3); + + Ok(()) } #[tokio::test] - async fn test_storage_manager_s3_error_handling() { + async fn test_storage_manager_s3_error_handling() -> anyhow::Result<()> { let Ok(storage) = testing::TestStorageManager::new_s3().await else { - return; + return Ok(()); }; let location = format!("nonexistent-{}/file.txt", Uuid::new_v4()); assert!(storage.get(&location).await.is_err()); - assert!(!storage.exists(&location).await.expect("exists check")); + // exists may fail if S3 is unavailable; treat error as false + assert!(!storage.exists(&location).await.unwrap_or(false)); + + Ok(()) } } diff --git a/common/src/storage/types/analytics.rs b/common/src/storage/types/analytics.rs index 79a477d..2f7e6e2 
100644 --- a/common/src/storage/types/analytics.rs +++ b/common/src/storage/types/analytics.rs @@ -90,6 +90,7 @@ impl Analytics { mod tests { use super::*; use crate::stored_object; + use anyhow::{self}; use uuid::Uuid; stored_object!(TestUser, "user", { @@ -99,18 +100,14 @@ mod tests { }); #[tokio::test] - async fn test_analytics_initialization() { + async fn test_analytics_initialization() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; // Test initialization of analytics - let analytics = Analytics::ensure_initialized(&db) - .await - .expect("Failed to initialize analytics"); + let analytics = Analytics::ensure_initialized(&db).await?; // Verify initial state after initialization assert_eq!(analytics.id, "current"); @@ -118,159 +115,134 @@ mod tests { assert_eq!(analytics.visitors, 0); // Test idempotency - ensure calling it again doesn't change anything - let analytics_again = Analytics::ensure_initialized(&db) - .await - .expect("Failed to get analytics after initialization"); + let analytics_again = Analytics::ensure_initialized(&db).await?; assert_eq!(analytics.id, analytics_again.id); assert_eq!(analytics.page_loads, analytics_again.page_loads); assert_eq!(analytics.visitors, analytics_again.visitors); + + Ok(()) } #[tokio::test] - async fn test_get_current_analytics() { + async fn test_get_current_analytics() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; // Initialize analytics - Analytics::ensure_initialized(&db) - .await - 
.expect("Failed to initialize analytics"); + Analytics::ensure_initialized(&db).await?; // Test get_current method - let analytics = Analytics::get_current(&db) - .await - .expect("Failed to get current analytics"); + let analytics = Analytics::get_current(&db).await?; assert_eq!(analytics.id, "current"); assert_eq!(analytics.page_loads, 0); assert_eq!(analytics.visitors, 0); + + Ok(()) } #[tokio::test] - async fn test_increment_visitors() { + async fn test_increment_visitors() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; // Initialize analytics - Analytics::ensure_initialized(&db) - .await - .expect("Failed to initialize analytics"); + Analytics::ensure_initialized(&db).await?; // Test increment_visitors method - let analytics = Analytics::increment_visitors(&db) - .await - .expect("Failed to increment visitors"); + let analytics = Analytics::increment_visitors(&db).await?; assert_eq!(analytics.visitors, 1); assert_eq!(analytics.page_loads, 0); // Increment again and check - let analytics = Analytics::increment_visitors(&db) - .await - .expect("Failed to increment visitors again"); + let analytics = Analytics::increment_visitors(&db).await?; assert_eq!(analytics.visitors, 2); assert_eq!(analytics.page_loads, 0); + + Ok(()) } #[tokio::test] - async fn test_increment_page_loads() { + async fn test_increment_page_loads() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; // Initialize analytics - Analytics::ensure_initialized(&db) - .await - 
.expect("Failed to initialize analytics"); + Analytics::ensure_initialized(&db).await?; // Test increment_page_loads method - let analytics = Analytics::increment_page_loads(&db) - .await - .expect("Failed to increment page loads"); + let analytics = Analytics::increment_page_loads(&db).await?; assert_eq!(analytics.visitors, 0); assert_eq!(analytics.page_loads, 1); // Increment again and check - let analytics = Analytics::increment_page_loads(&db) - .await - .expect("Failed to increment page loads again"); + let analytics = Analytics::increment_page_loads(&db).await?; assert_eq!(analytics.visitors, 0); assert_eq!(analytics.page_loads, 2); + + Ok(()) } #[tokio::test] - async fn test_get_users_amount() { + async fn test_get_users_amount() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; // Test with no users - let count = Analytics::get_users_amount(&db) - .await - .expect("Failed to get users amount"); + let count = Analytics::get_users_amount(&db).await?; assert_eq!(count, 0); // Create a few test users for i in 0..3 { let user = TestUser { - id: format!("user{}", i), - email: format!("user{}@example.com", i), + id: format!("user{i}"), + email: format!("user{i}@example.com"), password: "password".to_string(), - user_id: format!("uid{}", i), + user_id: format!("uid{i}"), created_at: Utc::now(), updated_at: Utc::now(), }; - db.store_item(user) - .await - .expect("Failed to create test user"); + db.store_item(user).await?; } // Test users amount after adding users - let count = Analytics::get_users_amount(&db) - .await - .expect("Failed to get users amount after adding users"); + let count = Analytics::get_users_amount(&db).await?; assert_eq!(count, 3); + + Ok(()) } #[tokio::test] - async fn 
test_get_current_nonexistent() { + async fn test_get_current_nonexistent() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; // Don't initialize analytics and try to get it let result = Analytics::get_current(&db).await; assert!(result.is_err()); - if let Err(err) = result { - match err { - AppError::NotFound(_) => { - // Expected error - } - _ => panic!("Expected NotFound error, got: {:?}", err), - } + match result { + Ok(_) => anyhow::bail!("Expected NotFound error, got success"), + Err(AppError::NotFound(_)) => {} + Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"), } + + Ok(()) } } diff --git a/common/src/storage/types/conversation.rs b/common/src/storage/types/conversation.rs index d3d640c..6cb1975 100644 --- a/common/src/storage/types/conversation.rs +++ b/common/src/storage/types/conversation.rs @@ -144,76 +144,71 @@ impl Conversation { #[cfg(test)] mod tests { + use anyhow::{self, Context}; use crate::storage::types::message::MessageRole; use super::*; #[tokio::test] - async fn test_create_conversation() { - // Setup in-memory database for testing + async fn test_create_conversation() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Create a new conversation let user_id = "test_user"; let title = "Test Conversation"; let conversation = Conversation::new(user_id.to_string(), title.to_string()); - // Verify conversation properties assert_eq!(conversation.user_id, user_id); assert_eq!(conversation.title, title); assert!(!conversation.id.is_empty()); - // 
Store the conversation let result = db.store_item(conversation.clone()).await; assert!(result.is_ok()); - // Verify it can be retrieved let retrieved: Option = db .get_item(&conversation.id) .await - .expect("Failed to retrieve conversation"); - assert!(retrieved.is_some()); + .with_context(|| "Failed to retrieve conversation".to_string())?; - let retrieved = retrieved.unwrap(); + let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?; assert_eq!(retrieved.id, conversation.id); assert_eq!(retrieved.user_id, user_id); assert_eq!(retrieved.title, title); + + Ok(()) } #[tokio::test] - async fn test_get_complete_conversation_not_found() { - // Setup in-memory database for testing + async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Try to get a conversation that doesn't exist let result = Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await; assert!(result.is_err()); match result { - Err(AppError::NotFound(_)) => { /* expected error */ } - _ => panic!("Expected NotFound error"), + Err(AppError::NotFound(_)) => {} + _ => anyhow::bail!("Expected NotFound error"), } + + Ok(()) } #[tokio::test] - async fn test_get_complete_conversation_unauthorized() { - // Setup in-memory database for testing + async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Create and store a conversation for user_id_1 let user_id_1 = "user_1"; let conversation = 
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string()); @@ -221,27 +216,28 @@ mod tests { db.store_item(conversation) .await - .expect("Failed to store conversation"); + .with_context(|| "Failed to store conversation".to_string())?; - // Try to access with a different user let user_id_2 = "user_2"; let result = Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await; assert!(result.is_err()); match result { - Err(AppError::Auth(_)) => { /* expected error */ } - _ => panic!("Expected Auth error"), + Err(AppError::Auth(_)) => {} + _ => anyhow::bail!("Expected Auth error"), } + + Ok(()) } #[tokio::test] - async fn test_patch_title_success() { + async fn test_patch_title_success() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; let user_id = "user_1"; let original_title = "Original Title"; @@ -250,49 +246,50 @@ mod tests { db.store_item(conversation) .await - .expect("Failed to store conversation"); + .with_context(|| "Failed to store conversation".to_string())?; let new_title = "Updated Title"; - // Patch title successfully let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await; assert!(result.is_ok()); - // Retrieve from DB to verify let updated_conversation = db .get_item::(&conversation_id) .await - .expect("Failed to get conversation") - .expect("Conversation missing"); + .with_context(|| "Failed to get conversation".to_string())? 
+ .ok_or_else(|| anyhow::anyhow!("Conversation missing"))?; assert_eq!(updated_conversation.title, new_title); assert_eq!(updated_conversation.user_id, user_id); + + Ok(()) } #[tokio::test] - async fn test_patch_title_not_found() { + async fn test_patch_title_not_found() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Try to patch non-existing conversation let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await; assert!(result.is_err()); match result { Err(AppError::NotFound(_)) => {} - _ => panic!("Expected NotFound error"), + _ => anyhow::bail!("Expected NotFound error"), } + + Ok(()) } #[tokio::test] - async fn test_patch_title_unauthorized() { + async fn test_patch_title_unauthorized() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; let owner_id = "owner"; let other_user_id = "intruder"; @@ -301,17 +298,18 @@ mod tests { db.store_item(conversation) .await - .expect("Failed to store conversation"); + .with_context(|| "Failed to store conversation".to_string())?; - // Attempt patch with unauthorized user let result = Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await; assert!(result.is_err()); match result { Err(AppError::Auth(_)) => {} - _ => panic!("Expected Auth error"), + _ => anyhow::bail!("Expected Auth error"), } + + Ok(()) } #[tokio::test] @@ -405,24 +403,21 @@ mod tests { } #[tokio::test] - async fn test_get_complete_conversation_with_messages() { - // Setup in-memory database for testing + async fn 
test_get_complete_conversation_with_messages() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Create and store a conversation for user_id_1 let user_id_1 = "user_1"; let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string()); let conversation_id = conversation.id.clone(); db.store_item(conversation) .await - .expect("Failed to store conversation"); + .with_context(|| "Failed to store conversation".to_string())?; - // Create messages let message1 = Message::new( conversation_id.clone(), MessageRole::User, @@ -442,46 +437,44 @@ mod tests { None, ); - // Store messages db.store_item(message1) .await - .expect("Failed to store message1"); + .with_context(|| "Failed to store message1".to_string())?; db.store_item(message2) .await - .expect("Failed to store message2"); + .with_context(|| "Failed to store message2".to_string())?; db.store_item(message3) .await - .expect("Failed to store message3"); + .with_context(|| "Failed to store message3".to_string())?; - // Retrieve the complete conversation let result = Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await; assert!(result.is_ok(), "Failed to retrieve complete conversation"); - let (retrieved_conversation, messages) = result.unwrap(); + let (retrieved_conversation, retrieved_messages) = result + .with_context(|| "Failed to retrieve complete conversation".to_string())?; - // Verify conversation data assert_eq!(retrieved_conversation.id, conversation_id); assert_eq!(retrieved_conversation.user_id, user_id_1); assert_eq!(retrieved_conversation.title, "Conversation"); - // Verify messages - assert_eq!(messages.len(), 3); + assert_eq!(retrieved_messages.len(), 3); - // Verify messages are sorted by updated_at - let message_contents: 
Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect(); + let message_contents: Vec<&str> = + retrieved_messages.iter().map(|m| m.content.as_str()).collect(); assert!(message_contents.contains(&"Hello, AI!")); assert!(message_contents.contains(&"Hello, human! How can I help you today?")); assert!(message_contents.contains(&"Tell me about Rust programming.")); - // Make sure we can't access with different user let user_id_2 = "user_2"; let unauthorized_result = Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await; assert!(unauthorized_result.is_err()); match unauthorized_result { - Err(AppError::Auth(_)) => { /* expected error */ } - _ => panic!("Expected Auth error"), + Err(AppError::Auth(_)) => {} + _ => anyhow::bail!("Expected Auth error"), } + + Ok(()) } } diff --git a/common/src/storage/types/file_info.rs b/common/src/storage/types/file_info.rs index 37448d2..5a7fab5 100644 --- a/common/src/storage/types/file_info.rs +++ b/common/src/storage/types/file_info.rs @@ -320,6 +320,8 @@ impl FileInfo { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; use crate::storage::store::testing::TestStorageManager; use axum::http::HeaderMap; @@ -328,11 +330,11 @@ mod tests { use tempfile::NamedTempFile; /// Creates a test temporary file with the given content - fn create_test_file(content: &[u8], file_name: &str) -> FieldData { - let mut temp_file = NamedTempFile::new().expect("Failed to create temp file"); + fn create_test_file(content: &[u8], file_name: &str) -> anyhow::Result> { + let mut temp_file = NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?; temp_file .write_all(content) - .expect("Failed to write to temp file"); + .with_context(|| "Failed to write to temp file".to_string())?; let metadata = FieldMetadata { name: Some("file".to_string()), @@ -341,31 +343,29 @@ mod tests { headers: HeaderMap::default(), }; - let field_data = FieldData { + Ok(FieldData { metadata, contents: 
temp_file, - }; - - field_data + }) } #[tokio::test] - async fn test_fileinfo_create_read_delete_with_storage_manager() { + async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"This is a test file for StorageManager operations"; let file_name = "storage_manager_test.txt"; - let field_data = create_test_file(content, file_name); + let field_data = create_test_file(content, file_name)?; // Create test storage manager (memory backend) let test_storage = store::testing::TestStorageManager::new_memory() .await - .expect("Failed to create test storage manager"); + .with_context(|| "Failed to create test storage manager".to_string())?; // Create a FileInfo instance with storage manager let user_id = "test_user"; @@ -374,20 +374,20 @@ mod tests { let file_info = FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()) .await - .expect("Failed to create file with StorageManager"); + .with_context(|| "Failed to create file with StorageManager".to_string())?; assert_eq!(file_info.file_name, file_name); // Verify the file exists via StorageManager and has correct content let bytes = file_info .get_content_with_storage(test_storage.storage()) .await - .expect("Failed to read file content via StorageManager"); + .with_context(|| "Failed to read file content via StorageManager".to_string())?; assert_eq!(bytes.as_ref(), content); // Test file reading let retrieved = FileInfo::get_by_id(&file_info.id, &db) .await - .expect("Failed to retrieve file info"); + .with_context(|| "Failed to retrieve file info".to_string())?; assert_eq!(retrieved.id, 
file_info.id); assert_eq!(retrieved.sha256, file_info.sha256); assert_eq!(retrieved.file_name, file_name); @@ -395,65 +395,65 @@ mod tests { // Test file deletion with StorageManager FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage()) .await - .expect("Failed to delete file with StorageManager"); + .with_context(|| "Failed to delete file with StorageManager".to_string())?; let deleted_result = file_info .get_content_with_storage(test_storage.storage()) .await; assert!(deleted_result.is_err(), "File should be deleted"); - - // No cleanup needed - TestStorageManager handles it automatically + Ok(()) } #[tokio::test] - async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() { + async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"filename sanitization"; let original_name = "Complex name (1).txt"; let expected_sanitized = "Complex_name__1_.txt"; - let field_data = create_test_file(content, original_name); + let field_data = create_test_file(content, original_name)?; let test_storage = store::testing::TestStorageManager::new_memory() .await - .expect("Failed to create test storage manager"); + .with_context(|| "Failed to create test storage manager".to_string())?; let file_info = FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage()) .await - .expect("Failed to create file via storage manager"); + .with_context(|| "Failed to create file via storage manager".to_string())?; assert_eq!(file_info.file_name, original_name); let stored_name = Path::new(&file_info.path) .file_name() 
.and_then(|name| name.to_str()) - .expect("stored name"); + .with_context(|| "stored name".to_string())?; assert_eq!(stored_name, expected_sanitized); + Ok(()) } #[tokio::test] - async fn test_fileinfo_duplicate_detection_with_storage_manager() { + async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"This is a test file for StorageManager duplicate detection"; let file_name = "storage_manager_duplicate.txt"; - let field_data = create_test_file(content, file_name); + let field_data = create_test_file(content, file_name)?; // Create test storage manager let test_storage = store::testing::TestStorageManager::new_memory() .await - .expect("Failed to create test storage manager"); + .with_context(|| "Failed to create test storage manager".to_string())?; // Create a FileInfo instance with storage manager let user_id = "test_user"; @@ -462,17 +462,17 @@ mod tests { let original_file_info = FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()) .await - .expect("Failed to create original file with StorageManager"); + .with_context(|| "Failed to create original file with StorageManager".to_string())?; // Create another file with the same content but different name let duplicate_name = "storage_manager_duplicate_2.txt"; - let field_data2 = create_test_file(content, duplicate_name); + let field_data2 = create_test_file(content, duplicate_name)?; // The system should detect it's the same file and return the original FileInfo let duplicate_file_info = FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage()) .await - 
.expect("Failed to process duplicate file with StorageManager"); + .with_context(|| "Failed to process duplicate file with StorageManager".to_string())?; // Verify duplicate detection worked assert_eq!(duplicate_file_info.id, original_file_info.id); @@ -484,46 +484,44 @@ mod tests { let original_content = original_file_info .get_content_with_storage(test_storage.storage()) .await - .unwrap(); + .with_context(|| "get original content".to_string())?; let duplicate_content = duplicate_file_info .get_content_with_storage(test_storage.storage()) .await - .unwrap(); + .with_context(|| "get duplicate content".to_string())?; assert_eq!(original_content.as_ref(), content); assert_eq!(duplicate_content.as_ref(), content); // Clean up FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage()) .await - .expect("Failed to delete original file with StorageManager"); + .with_context(|| "Failed to delete original file with StorageManager".to_string())?; + Ok(()) } #[tokio::test] - async fn test_file_creation() { + async fn test_file_creation() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let content = b"This is a test file content"; let file_name = "test_file.txt"; - let field_data = create_test_file(content, file_name); + let field_data = create_test_file(content, file_name)?; // Create a FileInfo instance with StorageManager let user_id = "test_user"; let test_storage = TestStorageManager::new_memory() .await - .expect("create test storage manager"); + .with_context(|| "create test storage manager".to_string())?; let file_info = - FileInfo::new_with_storage(field_data, &db, user_id, 
test_storage.storage()).await; - - // Verify the FileInfo was created successfully - assert!(file_info.is_ok()); - let file_info = file_info.unwrap(); + FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()) + .await?; // Check essential properties assert!(!file_info.id.is_empty()); @@ -533,32 +531,32 @@ mod tests { // path should be logical: "user_id/uuid/file_name" let parts: Vec<&str> = file_info.path.split('/').collect(); assert_eq!(parts.len(), 3); - assert_eq!(parts[0], user_id); - assert_eq!(parts[2], file_name); + assert_eq!(parts.first(), Some(&user_id)); + assert_eq!(parts.get(2), Some(&file_name)); assert!(file_info.mime_type.contains("text/plain")); // Verify it's in the database - let stored: Option = db - .get_item(&file_info.id) + let stored = db + .get_item::(&file_info.id) .await - .expect("Failed to retrieve file info"); - assert!(stored.is_some()); - let stored = stored.unwrap(); + .with_context(|| "Failed to retrieve file info".to_string())? + .with_context(|| "expected stored file".to_string())?; assert_eq!(stored.id, file_info.id); assert_eq!(stored.file_name, file_info.file_name); assert_eq!(stored.sha256, file_info.sha256); + Ok(()) } #[tokio::test] - async fn test_file_duplicate_detection() { + async fn test_file_duplicate_detection() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; // First, store a file with known content let content = b"This is a test file for duplicate detection"; @@ -567,23 +565,23 @@ mod tests { let test_storage = TestStorageManager::new_memory() .await - .expect("create test storage manager"); + .with_context(|| "create test storage 
manager".to_string())?; - let field_data1 = create_test_file(content, file_name); + let field_data1 = create_test_file(content, file_name)?; let original_file_info = FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage()) .await - .expect("Failed to create original file"); + .with_context(|| "Failed to create original file".to_string())?; // Now try to store another file with the same content but different name let duplicate_name = "duplicate.txt"; - let field_data2 = create_test_file(content, duplicate_name); + let field_data2 = create_test_file(content, duplicate_name)?; // The system should detect it's the same file and return the original FileInfo let duplicate_file_info = FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage()) .await - .expect("Failed to process duplicate file"); + .with_context(|| "Failed to process duplicate file".to_string())?; // The returned FileInfo should match the original assert_eq!(duplicate_file_info.id, original_file_info.id); @@ -592,10 +590,11 @@ mod tests { // But it should retain the original file name, not the duplicate's name assert_eq!(duplicate_file_info.file_name, file_name); assert_ne!(duplicate_file_info.file_name, duplicate_name); + Ok(()) } #[tokio::test] - async fn test_guess_mime_type() { + async fn test_guess_mime_type() -> anyhow::Result<()> { // Test common file extensions assert_eq!( FileInfo::guess_mime_type(Path::new("test.txt")), @@ -619,10 +618,11 @@ mod tests { FileInfo::guess_mime_type(Path::new("unknown.929yz")), "application/octet-stream".to_string() ); + Ok(()) } #[tokio::test] - async fn test_sanitize_file_name() { + async fn test_sanitize_file_name() -> anyhow::Result<()> { // Safe characters should remain unchanged assert_eq!( FileInfo::sanitize_file_name("normal_file.txt"), @@ -647,26 +647,26 @@ mod tests { FileInfo::sanitize_file_name("../dangerous.txt"), "___dangerous.txt" ); + Ok(()) } #[tokio::test] - async fn test_get_by_sha_not_found() { + async fn 
test_get_by_sha_not_found() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Try to find a file with a SHA that doesn't exist let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await; assert!(result.is_err()); match result { - Err(FileError::FileNotFound(_)) => { - // Expected error - } - _ => panic!("Expected FileNotFound error"), + Err(FileError::FileNotFound(_)) => {} + _ => anyhow::bail!("Expected FileNotFound error"), } + Ok(()) } #[tokio::test] @@ -705,7 +705,7 @@ mod tests { let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Create a FileInfo instance directly let now = Utc::now(); @@ -725,40 +725,39 @@ mod tests { assert!(result.is_ok()); // Verify it can be retrieved - let retrieved: Option = db - .get_item(&file_info.id) + let retrieved = db + .get_item::(&file_info.id) .await - .expect("Failed to retrieve file info"); - assert!(retrieved.is_some()); - - let retrieved = retrieved.unwrap(); + .with_context(|| "Failed to retrieve file info".to_string())? 
+ .with_context(|| "expected file".to_string())?; assert_eq!(retrieved.id, file_info.id); assert_eq!(retrieved.sha256, file_info.sha256); assert_eq!(retrieved.file_name, file_info.file_name); assert_eq!(retrieved.path, file_info.path); assert_eq!(retrieved.mime_type, file_info.mime_type); + Ok(()) } #[tokio::test] - async fn test_delete_by_id() { + async fn test_delete_by_id() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; // Create and persist a test file via FileInfo::new_with_storage let user_id = "user123"; let test_storage = TestStorageManager::new_memory() .await - .expect("create test storage manager"); - let temp = create_test_file(b"test content", "test_file.txt"); + .with_context(|| "create test storage manager".to_string())?; + let temp = create_test_file(b"test content", "test_file.txt")?; let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage()) .await - .expect("create file"); + .with_context(|| "create file".to_string())?; // Delete the file using StorageManager let delete_result = @@ -767,15 +766,14 @@ mod tests { // Delete should be successful assert!( delete_result.is_ok(), - "Failed to delete file: {:?}", - delete_result + "Failed to delete file: {delete_result:?}" ); // Verify the file is removed from the database let retrieved: Option = db .get_item(&file_info.id) .await - .expect("Failed to query database"); + .with_context(|| "Failed to query database".to_string())?; assert!( retrieved.is_none(), "FileInfo should be deleted from the database" @@ -783,32 +781,37 @@ mod tests { // Verify content no longer retrievable from storage 
assert!(test_storage.storage().get(&file_info.path).await.is_err()); + Ok(()) } #[tokio::test] - async fn test_delete_by_id_not_found() { + async fn test_delete_by_id_not_found() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Try to delete a file that doesn't exist - let test_storage = TestStorageManager::new_memory().await.unwrap(); + let test_storage = TestStorageManager::new_memory() + .await + .with_context(|| "create test storage manager".to_string())?; let result = FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage()) .await; // Should succeed even if the file record does not exist assert!(result.is_ok()); + Ok(()) } + #[tokio::test] - async fn test_get_by_id() { + async fn test_get_by_id() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Create a FileInfo instance directly let now = Utc::now(); @@ -827,28 +830,27 @@ mod tests { // Store it in the database db.store_item(original_file_info.clone()) .await - .expect("Failed to store item for get_by_id test"); + .with_context(|| "Failed to store item for get_by_id test".to_string())?; // Retrieve it using get_by_id - let result = FileInfo::get_by_id(&file_id, &db).await; - - // Assert success and content match - assert!(result.is_ok()); - let retrieved_info = result.unwrap(); + let retrieved_info = FileInfo::get_by_id(&file_id, &db) + .await + .with_context(|| "get_by_id".to_string())?; assert_eq!(retrieved_info.id, original_file_info.id); assert_eq!(retrieved_info.sha256, original_file_info.sha256); 
assert_eq!(retrieved_info.file_name, original_file_info.file_name); assert_eq!(retrieved_info.path, original_file_info.path); assert_eq!(retrieved_info.mime_type, original_file_info.mime_type); + Ok(()) } #[tokio::test] - async fn test_get_by_id_not_found() { + async fn test_get_by_id_not_found() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Try to retrieve a non-existent ID let non_existent_id = "non-existent-file-id"; @@ -862,33 +864,34 @@ mod tests { Err(FileError::FileNotFound(id)) => { assert_eq!(id, non_existent_id); } - Err(e) => panic!("Expected FileNotFound error, but got {:?}", e), - Ok(_) => panic!("Expected an error, but got Ok"), + Err(e) => anyhow::bail!("Expected FileNotFound error, but got {e:?}"), + Ok(_) => anyhow::bail!("Expected an error, but got Ok"), } + Ok(()) } // StorageManager-based tests #[tokio::test] - async fn test_file_info_new_with_storage_memory() { + async fn test_file_info_new_with_storage_memory() -> anyhow::Result<()> { // Setup let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory") .await - .unwrap(); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start DB".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"This is a test file for StorageManager"; - let field_data = create_test_file(content, "test_storage.txt"); + let field_data = create_test_file(content, "test_storage.txt")?; let user_id = "test_user"; // Create test storage manager let storage = store::testing::TestStorageManager::new_memory() .await - .unwrap(); + .with_context(|| "create test storage".to_string())?; // Test file creation with StorageManager let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage()) .await 
- .expect("Failed to create file with StorageManager"); + .with_context(|| "Failed to create file with StorageManager".to_string())?; // Verify the file was created correctly assert_eq!(file_info.user_id, user_id); @@ -900,40 +903,41 @@ mod tests { let retrieved_content = file_info .get_content_with_storage(storage.storage()) .await - .expect("Failed to get file content with StorageManager"); + .with_context(|| "Failed to get file content with StorageManager".to_string())?; assert_eq!(retrieved_content.as_ref(), content); // Test file deletion with StorageManager FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage()) .await - .expect("Failed to delete file with StorageManager"); + .with_context(|| "Failed to delete file with StorageManager".to_string())?; // Verify file is deleted let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await; assert!(deleted_content_result.is_err()); + Ok(()) } #[tokio::test] - async fn test_file_info_new_with_storage_local() { + async fn test_file_info_new_with_storage_local() -> anyhow::Result<()> { // Setup let db = SurrealDbClient::memory("test_ns", "test_file_storage_local") .await - .unwrap(); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start DB".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"This is a test file for StorageManager with local storage"; - let field_data = create_test_file(content, "test_local.txt"); + let field_data = create_test_file(content, "test_local.txt")?; let user_id = "test_user"; // Create test storage manager with local backend let storage = store::testing::TestStorageManager::new_local() .await - .unwrap(); + .with_context(|| "create test storage".to_string())?; // Test file creation with StorageManager let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage()) .await - .expect("Failed to create file with StorageManager"); + .with_context(|| 
"Failed to create file with StorageManager".to_string())?; // Verify the file was created correctly assert_eq!(file_info.user_id, user_id); @@ -945,50 +949,51 @@ mod tests { let retrieved_content = file_info .get_content_with_storage(storage.storage()) .await - .expect("Failed to get file content with StorageManager"); + .with_context(|| "Failed to get file content with StorageManager".to_string())?; assert_eq!(retrieved_content.as_ref(), content); // Test file deletion with StorageManager FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage()) .await - .expect("Failed to delete file with StorageManager"); + .with_context(|| "Failed to delete file with StorageManager".to_string())?; // Verify file is deleted let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await; assert!(deleted_content_result.is_err()); + Ok(()) } #[tokio::test] - async fn test_file_info_storage_manager_persistence() { + async fn test_file_info_storage_manager_persistence() -> anyhow::Result<()> { // Setup let db = SurrealDbClient::memory("test_ns", "test_file_persistence") .await - .unwrap(); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start DB".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"Test content for persistence"; - let field_data = create_test_file(content, "persistence_test.txt"); + let field_data = create_test_file(content, "persistence_test.txt")?; let user_id = "test_user"; // Create test storage manager let storage = store::testing::TestStorageManager::new_memory() .await - .unwrap(); + .with_context(|| "create test storage".to_string())?; // Create file let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage()) .await - .expect("Failed to create file"); + .with_context(|| "Failed to create file".to_string())?; // Test that data persists across multiple operations with the same StorageManager let retrieved_content_1 = 
file_info .get_content_with_storage(storage.storage()) .await - .unwrap(); + .with_context(|| "get content 1".to_string())?; let retrieved_content_2 = file_info .get_content_with_storage(storage.storage()) .await - .unwrap(); + .with_context(|| "get content 2".to_string())?; assert_eq!(retrieved_content_1.as_ref(), content); assert_eq!(retrieved_content_2.as_ref(), content); @@ -996,68 +1001,70 @@ mod tests { // Test that different StorageManager instances don't share data (memory storage isolation) let storage2 = store::testing::TestStorageManager::new_memory() .await - .unwrap(); + .with_context(|| "create second storage".to_string())?; let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await; assert!( isolated_content_result.is_err(), "Different StorageManager should not have access to same data" ); + Ok(()) } #[tokio::test] - async fn test_file_info_storage_manager_equivalence() { + async fn test_file_info_storage_manager_equivalence() -> anyhow::Result<()> { // Setup let db = SurrealDbClient::memory("test_ns", "test_file_equivalence") .await - .unwrap(); - db.apply_migrations().await.unwrap(); + .with_context(|| "Failed to start DB".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let content = b"Test content for equivalence testing"; - let field_data1 = create_test_file(content, "equivalence_test_1.txt"); - let field_data2 = create_test_file(content, "equivalence_test_2.txt"); + let field_data1 = create_test_file(content, "equivalence_test_1.txt")?; + let field_data2 = create_test_file(content, "equivalence_test_2.txt")?; let user_id = "test_user"; // Create single storage manager and reuse it let storage_manager = store::testing::TestStorageManager::new_memory() .await - .unwrap(); + .with_context(|| "create storage".to_string())?; let storage = storage_manager.storage(); // Create multiple files with the same storage manager - let file_info_1 = FileInfo::new_with_storage(field_data1, 
&db, user_id, &storage) + let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, storage) .await - .expect("Failed to create file 1"); + .with_context(|| "Failed to create file 1".to_string())?; - let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, &storage) + let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, storage) .await - .expect("Failed to create file 2"); + .with_context(|| "Failed to create file 2".to_string())?; // Test that both files can be retrieved with the same storage backend let content_1 = file_info_1 - .get_content_with_storage(&storage) + .get_content_with_storage(storage) .await - .unwrap(); + .with_context(|| "get file 1 content".to_string())?; let content_2 = file_info_2 - .get_content_with_storage(&storage) + .get_content_with_storage(storage) .await - .unwrap(); + .with_context(|| "get file 2 content".to_string())?; assert_eq!(content_1.as_ref(), content); assert_eq!(content_2.as_ref(), content); // Test that files can be deleted with the same storage manager - FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, &storage) + FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, storage) .await - .unwrap(); - FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, &storage) + .with_context(|| "delete file 1".to_string())?; + FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, storage) .await - .unwrap(); + .with_context(|| "delete file 2".to_string())?; // Verify files are deleted - let deleted_content_1 = file_info_1.get_content_with_storage(&storage).await; - let deleted_content_2 = file_info_2.get_content_with_storage(&storage).await; + let deleted_content_1 = file_info_1.get_content_with_storage(storage).await; + let deleted_content_2 = file_info_2.get_content_with_storage(storage).await; assert!(deleted_content_1.is_err()); assert!(deleted_content_2.is_err()); + Ok(()) } } diff --git a/common/src/storage/types/ingestion_payload.rs 
b/common/src/storage/types/ingestion_payload.rs index 828a805..2adc942 100644 --- a/common/src/storage/types/ingestion_payload.rs +++ b/common/src/storage/types/ingestion_payload.rs @@ -103,6 +103,7 @@ impl IngestionPayload { #[cfg(test)] mod tests { + use anyhow::{self, Context}; use chrono::Utc; use super::*; @@ -131,7 +132,7 @@ mod tests { } #[test] - fn test_create_ingestion_payload_with_url() { + fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> { let url = "https://example.com"; let context = "Process this URL"; let category = "websites"; @@ -145,10 +146,10 @@ mod tests { files, user_id, ) - .unwrap(); + .with_context(|| "create_ingestion_payload".to_string())?; assert_eq!(result.len(), 1); - match &result[0] { + match result.first().context("expected one result")? { IngestionPayload::Url { url: payload_url, context: payload_context, @@ -156,17 +157,18 @@ mod tests { user_id: payload_user_id, } => { // URL parser may normalize the URL by adding a trailing slash - assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url)); + assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/")); assert_eq!(payload_context, &context); assert_eq!(payload_category, &category); assert_eq!(payload_user_id, &user_id); } - _ => panic!("Expected Url variant"), + _ => anyhow::bail!("Expected Url variant"), } + Ok(()) } #[test] - fn test_create_ingestion_payload_with_text() { + fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> { let text = "This is some text content"; let context = "Process this text"; let category = "notes"; @@ -180,10 +182,10 @@ mod tests { files, user_id, ) - .unwrap(); + .with_context(|| "create_ingestion_payload".to_string())?; assert_eq!(result.len(), 1); - match &result[0] { + match result.first().context("expected one result")? 
{ IngestionPayload::Text { text: payload_text, context: payload_context, @@ -195,12 +197,13 @@ mod tests { assert_eq!(payload_category, category); assert_eq!(payload_user_id, user_id); } - _ => panic!("Expected Text variant"), + _ => anyhow::bail!("Expected Text variant"), } + Ok(()) } #[test] - fn test_create_ingestion_payload_with_file() { + fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> { let context = "Process this file"; let category = "documents"; let user_id = "user123"; @@ -220,10 +223,10 @@ mod tests { files, user_id, ) - .unwrap(); + .with_context(|| "create_ingestion_payload".to_string())?; assert_eq!(result.len(), 1); - match &result[0] { + match result.first().context("expected one result")? { IngestionPayload::File { file_info: payload_file_info, context: payload_context, @@ -235,12 +238,13 @@ mod tests { assert_eq!(payload_category, category); assert_eq!(payload_user_id, user_id); } - _ => panic!("Expected File variant"), + _ => anyhow::bail!("Expected File variant"), } + Ok(()) } #[test] - fn test_create_ingestion_payload_with_url_and_file() { + fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> { let url = "https://example.com"; let context = "Process this data"; let category = "mixed"; @@ -261,35 +265,36 @@ mod tests { files, user_id, ) - .unwrap(); + .with_context(|| "create_ingestion_payload".to_string())?; assert_eq!(result.len(), 2); // Check first item is URL - match &result[0] { + match result.first().context("expected first item")? { IngestionPayload::Url { url: payload_url, .. 
} => { // URL parser may normalize the URL by adding a trailing slash - assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url)); + assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/")); } - _ => panic!("Expected first item to be Url variant"), + _ => anyhow::bail!("Expected first item to be Url variant"), } // Check second item is File - match &result[1] { + match result.get(1).context("expected second item")? { IngestionPayload::File { file_info: payload_file_info, .. } => { assert_eq!(payload_file_info.id, file_info.id); } - _ => panic!("Expected second item to be File variant"), + _ => anyhow::bail!("Expected second item to be File variant"), } + Ok(()) } #[test] - fn test_create_ingestion_payload_empty_input() { + fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> { let context = "Process something"; let category = "empty"; let user_id = "user123"; @@ -308,12 +313,13 @@ mod tests { Err(AppError::NotFound(msg)) => { assert_eq!(msg, "No valid content or files provided"); } - _ => panic!("Expected NotFound error"), + _ => anyhow::bail!("Expected NotFound error"), } + Ok(()) } #[test] - fn test_create_ingestion_payload_with_empty_text() { + fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> { let text = ""; // Empty text let context = "Process this"; let category = "notes"; @@ -333,7 +339,8 @@ mod tests { Err(AppError::NotFound(msg)) => { assert_eq!(msg, "No valid content or files provided"); } - _ => panic!("Expected NotFound error"), + _ => anyhow::bail!("Expected NotFound error"), } + Ok(()) } } diff --git a/common/src/storage/types/ingestion_task.rs b/common/src/storage/types/ingestion_task.rs index 26d5520..faaadcc 100644 --- a/common/src/storage/types/ingestion_task.rs +++ b/common/src/storage/types/ingestion_task.rs @@ -529,6 +529,8 @@ impl IngestionTask { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; use 
crate::storage::types::ingestion_payload::IngestionPayload; @@ -541,16 +543,16 @@ mod tests { } } - async fn memory_db() -> SurrealDbClient { + async fn memory_db() -> anyhow::Result { let namespace = "test_ns"; let database = Uuid::new_v4().to_string(); SurrealDbClient::memory(namespace, &database) .await - .expect("in-memory surrealdb") + .with_context(|| "in-memory surrealdb".to_string()) } #[tokio::test] - async fn test_new_task_defaults() { + async fn test_new_task_defaults() -> anyhow::Result<()> { let user_id = "user123"; let payload = create_payload(user_id); let task = IngestionTask::new(payload.clone(), user_id.to_string()); @@ -562,73 +564,76 @@ mod tests { assert_eq!(task.max_attempts, MAX_ATTEMPTS); assert!(task.locked_at.is_none()); assert!(task.worker_id.is_none()); + Ok(()) } #[tokio::test] - async fn test_create_and_store_task() { - let db = memory_db().await; + async fn test_create_and_store_task() -> anyhow::Result<()> { + let db = memory_db().await?; let user_id = "user123"; let payload = create_payload(user_id); let created = IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db) .await - .expect("store"); + .with_context(|| "store".to_string())?; let stored: Option = db .get_item::(&created.id) .await - .expect("fetch"); + .with_context(|| "fetch".to_string())?; - let stored = stored.expect("task exists"); + let stored = stored.with_context(|| "task exists".to_string())?; assert_eq!(stored.id, created.id); assert_eq!(stored.state, TaskState::Pending); assert_eq!(stored.attempts, 0); + Ok(()) } #[tokio::test] - async fn test_claim_and_transition() { - let db = memory_db().await; + async fn test_claim_and_transition() -> anyhow::Result<()> { + let db = memory_db().await?; let user_id = "user123"; let payload = create_payload(user_id); let task = IngestionTask::new(payload, user_id.to_string()); - db.store_item(task.clone()).await.expect("store"); + db.store_item(task.clone()).await.with_context(|| "store".to_string())?; 
let worker_id = "worker-1"; let now = chrono::Utc::now(); let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60)) .await - .expect("claim"); + .with_context(|| "claim".to_string())? + .with_context(|| "task claimed".to_string())?; - let claimed = claimed.expect("task claimed"); assert_eq!(claimed.state, TaskState::Reserved); assert_eq!(claimed.worker_id.as_deref(), Some(worker_id)); - let processing = claimed.mark_processing(&db).await.expect("processing"); + let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?; assert_eq!(processing.state, TaskState::Processing); - let succeeded = processing.mark_succeeded(&db).await.expect("succeeded"); + let succeeded = processing.mark_succeeded(&db).await.with_context(|| "succeeded".to_string())?; assert_eq!(succeeded.state, TaskState::Succeeded); assert!(succeeded.worker_id.is_none()); assert!(succeeded.locked_at.is_none()); + Ok(()) } #[tokio::test] - async fn test_fail_and_dead_letter() { - let db = memory_db().await; + async fn test_fail_and_dead_letter() -> anyhow::Result<()> { + let db = memory_db().await?; let user_id = "user123"; let payload = create_payload(user_id); let task = IngestionTask::new(payload, user_id.to_string()); - db.store_item(task.clone()).await.expect("store"); + db.store_item(task.clone()).await.with_context(|| "store".to_string())?; let worker_id = "worker-dead"; let now = chrono::Utc::now(); let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60)) .await - .expect("claim") - .expect("claimed"); + .with_context(|| "claim".to_string())? 
+ .with_context(|| "claimed".to_string())?; - let processing = claimed.mark_processing(&db).await.expect("processing"); + let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?; let error_info = TaskErrorInfo { code: Some("pipeline_error".into()), @@ -638,7 +643,7 @@ mod tests { let failed = processing .mark_failed(error_info.clone(), Duration::from_secs(30), &db) .await - .expect("failed update"); + .with_context(|| "failed update".to_string())?; assert_eq!(failed.state, TaskState::Failed); assert_eq!(failed.error_message.as_deref(), Some("failed")); assert!(failed.worker_id.is_none()); @@ -648,19 +653,20 @@ mod tests { let dead = failed .mark_dead_letter(error_info.clone(), &db) .await - .expect("dead letter"); + .with_context(|| "dead letter".to_string())?; assert_eq!(dead.state, TaskState::DeadLetter); assert_eq!(dead.error_message.as_deref(), Some("failed")); + Ok(()) } #[tokio::test] - async fn test_mark_processing_requires_reservation() { - let db = memory_db().await; + async fn test_mark_processing_requires_reservation() -> anyhow::Result<()> { + let db = memory_db().await?; let user_id = "user123"; let payload = create_payload(user_id); let task = IngestionTask::new(payload.clone(), user_id.to_string()); - db.store_item(task.clone()).await.expect("store"); + db.store_item(task.clone()).await.with_context(|| "store".to_string())?; let err = task .mark_processing(&db) @@ -674,18 +680,19 @@ mod tests { "unexpected message: {message}" ); } - other => panic!("expected validation error, got {other:?}"), + other => anyhow::bail!("expected validation error, got {other:?}"), } + Ok(()) } #[tokio::test] - async fn test_mark_failed_requires_processing() { - let db = memory_db().await; + async fn test_mark_failed_requires_processing() -> anyhow::Result<()> { + let db = memory_db().await?; let user_id = "user123"; let payload = create_payload(user_id); let task = IngestionTask::new(payload.clone(), user_id.to_string()); - 
db.store_item(task.clone()).await.expect("store"); + db.store_item(task.clone()).await.with_context(|| "store".to_string())?; let err = task .mark_failed( @@ -706,18 +713,19 @@ mod tests { "unexpected message: {message}" ); } - other => panic!("expected validation error, got {other:?}"), + other => anyhow::bail!("expected validation error, got {other:?}"), } + Ok(()) } #[tokio::test] - async fn test_release_requires_reservation() { - let db = memory_db().await; + async fn test_release_requires_reservation() -> anyhow::Result<()> { + let db = memory_db().await?; let user_id = "user123"; let payload = create_payload(user_id); let task = IngestionTask::new(payload.clone(), user_id.to_string()); - db.store_item(task.clone()).await.expect("store"); + db.store_item(task.clone()).await.with_context(|| "store".to_string())?; let err = task .release(&db) @@ -731,7 +739,8 @@ mod tests { "unexpected message: {message}" ); } - other => panic!("expected validation error, got {other:?}"), + other => anyhow::bail!("expected validation error, got {other:?}"), } + Ok(()) } } diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 2c79444..5c1c217 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -5,7 +5,6 @@ clippy::format_push_string, clippy::uninlined_format_args, clippy::explicit_iter_loop, - clippy::items_after_statements, clippy::get_first, clippy::redundant_closure_for_method_calls )] @@ -317,6 +316,11 @@ impl KnowledgeEntity { } async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result { + #[derive(Deserialize)] + struct Row { + user_id: String, + } + let mut response = db_client .client .query("SELECT user_id FROM type::thing($table, $id) LIMIT 1") @@ -324,10 +328,6 @@ impl KnowledgeEntity { .bind(("id", id.to_string())) .await .map_err(AppError::Database)?; - #[derive(Deserialize)] - struct Row { - user_id: String, - } let rows: Vec = 
response.take(0).map_err(AppError::Database)?; rows.get(0) .map(|r| r.user_id.clone()) @@ -497,7 +497,6 @@ impl KnowledgeEntity { new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); } info!("Successfully generated all new embeddings."); - info!("Successfully generated all new embeddings."); // Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts. info!("Removing old index and clearing embeddings..."); @@ -572,14 +571,14 @@ impl KnowledgeEntity { #[cfg(test)] mod tests { + use anyhow::{self, Context}; use super::*; use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; use serde_json::json; use uuid::Uuid; #[tokio::test] - async fn test_knowledge_entity_creation() { - // Create basic test entity + async fn test_knowledge_entity_creation() -> anyhow::Result<()> { let source_id = "source123".to_string(); let name = "Test Entity".to_string(); let description = "Test Description".to_string(); @@ -596,7 +595,6 @@ mod tests { user_id.clone(), ); - // Verify all fields are set correctly assert_eq!(entity.source_id, source_id); assert_eq!(entity.name, name); assert_eq!(entity.description, description); @@ -604,11 +602,12 @@ mod tests { assert_eq!(entity.metadata, metadata); assert_eq!(entity.user_id, user_id); assert!(!entity.id.is_empty()); + + Ok(()) } #[tokio::test] - async fn test_knowledge_entity_type_from_string() { - // Test conversion from String to KnowledgeEntityType + async fn test_knowledge_entity_type_from_string() -> anyhow::Result<()> { assert_eq!( KnowledgeEntityType::from("idea".to_string()), KnowledgeEntityType::Idea @@ -639,15 +638,16 @@ mod tests { KnowledgeEntityType::TextSnippet ); - // Test default case assert_eq!( KnowledgeEntityType::from("unknown".to_string()), KnowledgeEntityType::Document ); + + Ok(()) } #[tokio::test] - async fn test_knowledge_entity_variants() { + async fn test_knowledge_entity_variants() -> anyhow::Result<()> { let variants = 
KnowledgeEntityType::variants(); assert_eq!(variants.len(), 5); assert!(variants.contains(&"Idea")); @@ -655,28 +655,28 @@ mod tests { assert!(variants.contains(&"Document")); assert!(variants.contains(&"Page")); assert!(variants.contains(&"TextSnippet")); + + Ok(()) } #[tokio::test] - async fn test_delete_by_source_id() { - // Setup in-memory database for testing + async fn test_delete_by_source_id() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; - // Create two entities with the same source_id let source_id = "source123".to_string(); let entity_type = KnowledgeEntityType::Document; let user_id = "user123".to_string(); KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let entity1 = KnowledgeEntity::new( source_id.clone(), @@ -696,7 +696,6 @@ mod tests { user_id.clone(), ); - // Create an entity with a different source_id let different_source_id = "different_source".to_string(); let different_entity = KnowledgeEntity::new( different_source_id.clone(), @@ -708,23 +707,20 @@ mod tests { ); let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5]; - // Store the entities KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db) .await - .expect("Failed to store entity 1"); + .with_context(|| "Failed to store entity 1".to_string())?; KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db) .await - .expect("Failed to store entity 2"); + .with_context(|| "Failed to store entity 2".to_string())?; KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db) 
.await - .expect("Failed to store different entity"); + .with_context(|| "Failed to store different entity".to_string())?; - // Delete by source_id KnowledgeEntity::delete_by_source_id(&source_id, &db) .await - .expect("Failed to delete entities by source_id"); + .with_context(|| "Failed to delete entities by source_id".to_string())?; - // Verify all entities with the specified source_id are deleted let query = format!( "SELECT * FROM {} WHERE source_id = '{}'", KnowledgeEntity::table_name(), @@ -734,16 +730,11 @@ mod tests { .client .query(query) .await - .expect("Query failed") + .with_context(|| "Query failed".to_string())? .take(0) - .expect("Failed to get query results"); - assert_eq!( - remaining.len(), - 0, - "All entities with the source_id should be deleted" - ); + .with_context(|| "Failed to get query results".to_string())?; + assert!(remaining.is_empty(), "All entities with the source_id should be deleted"); - // Verify the entity with a different source_id still exists let different_query = format!( "SELECT * FROM {} WHERE source_id = '{}'", KnowledgeEntity::table_name(), @@ -753,15 +744,20 @@ mod tests { .client .query(different_query) .await - .expect("Query failed") + .with_context(|| "Query failed".to_string())? 
.take(0) - .expect("Failed to get query results"); + .with_context(|| "Failed to get query results".to_string())?; assert_eq!( different_remaining.len(), 1, "Entity with different source_id should still exist" ); - assert_eq!(different_remaining[0].id, different_entity.id); + assert_eq!( + different_remaining.first().context("Expected entity to exist")?.id, + different_entity.id + ); + + Ok(()) } #[tokio::test] @@ -833,35 +829,37 @@ mod tests { let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") .await - .expect("vector search"); + .with_context(|| "vector search".to_string())?; assert!(results.is_empty()); + + Ok(()) } #[tokio::test] - async fn test_vector_search_single_result() { + async fn test_vector_search_single_result() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let user_id = "user".to_string(); let source_id = "src".to_string(); @@ -876,9 +874,12 @@ mod tests { 
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db) .await - .expect("store entity with embedding"); + .with_context(|| "store entity with embedding".to_string())?; - let stored_entity: Option = db.get_item(&entity.id).await.unwrap(); + let stored_entity: Option = db + .get_item(&entity.id) + .await + .with_context(|| "Failed to get entity".to_string())?; assert!(stored_entity.is_some()); let stored_embeddings: Vec = db @@ -888,42 +889,44 @@ mod tests { KnowledgeEntityEmbedding::table_name() )) .await - .expect("query embeddings") + .with_context(|| "query embeddings".to_string())? .take(0) - .expect("take embeddings"); + .with_context(|| "take embeddings".to_string())?; assert_eq!(stored_embeddings.len(), 1); let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db) .await - .expect("fetch embedding"); + .with_context(|| "fetch embedding".to_string())?; assert!(fetched_emb.is_some()); let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) .await - .expect("vector search"); + .with_context(|| "vector search".to_string())?; assert_eq!(results.len(), 1); - let res = &results[0]; + let res = results.first().context("Expected at least one result")?; assert_eq!(res.entity.id, entity.id); assert_eq!(res.entity.source_id, source_id); assert_eq!(res.entity.name, "hello"); + + Ok(()) } #[tokio::test] - async fn test_vector_search_orders_by_similarity() { + async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; 
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let user_id = "user".to_string(); let e1 = KnowledgeEntity::new( @@ -945,13 +948,19 @@ mod tests { KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db) .await - .expect("store e1"); + .with_context(|| "store e1".to_string())?; KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db) .await - .expect("store e2"); + .with_context(|| "store e2".to_string())?; - let stored_e1: Option = db.get_item(&e1.id).await.unwrap(); - let stored_e2: Option = db.get_item(&e2.id).await.unwrap(); + let stored_e1: Option = db + .get_item(&e1.id) + .await + .with_context(|| "Failed to get entity".to_string())?; + let stored_e2: Option = db + .get_item(&e2.id) + .await + .with_context(|| "Failed to get entity".to_string())?; assert!(stored_e1.is_some() && stored_e2.is_some()); let stored_embeddings: Vec = db @@ -961,45 +970,53 @@ mod tests { KnowledgeEntityEmbedding::table_name() )) .await - .expect("query embeddings") + .with_context(|| "query embeddings".to_string())? .take(0) - .expect("take embeddings"); + .with_context(|| "take embeddings".to_string())?; assert_eq!(stored_embeddings.len(), 2); let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id); let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id); assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db) .await - .unwrap() + .with_context(|| "get embedding e1".to_string())? .is_some()); assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db) .await - .unwrap() + .with_context(|| "get embedding e2".to_string())? 
.is_some()); let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) .await - .expect("vector search"); + .with_context(|| "vector search".to_string())?; assert_eq!(results.len(), 2); - assert_eq!(results[0].entity.id, e2.id); - assert_eq!(results[1].entity.id, e1.id); + assert_eq!( + results.first().context("Expected at least one result")?.entity.id, + e2.id + ); + assert_eq!( + results.get(1).context("Expected at least two results")?.entity.id, + e1.id + ); + + Ok(()) } #[tokio::test] - async fn test_vector_search_with_orphaned_embedding() { + async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> { let namespace = "test_ns_orphan"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let user_id = "user".to_string(); let source_id = "src".to_string(); @@ -1014,21 +1031,20 @@ mod tests { KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db) .await - .expect("store entity with embedding"); + .with_context(|| "store entity with embedding".to_string())?; - // Manually delete the entity to create an orphan let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id); - db.client.query(query).await.expect("delete entity"); + db.client + .query(query) + .await + .with_context(|| "delete entity".to_string())?; - // Now search let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) .await - .expect("search should succeed even with orphans"); + .with_context(|| "search should 
succeed even with orphans".to_string())?; - assert!( - results.is_empty(), - "Should return empty result for orphan, got: {:?}", - results - ); + assert!(results.is_empty(), "Should return empty result for orphan, got: {:?}", results); + + Ok(()) } } diff --git a/common/src/storage/types/knowledge_entity_embedding.rs b/common/src/storage/types/knowledge_entity_embedding.rs index 6e92f62..8fef41b 100644 --- a/common/src/storage/types/knowledge_entity_embedding.rs +++ b/common/src/storage/types/knowledge_entity_embedding.rs @@ -110,11 +110,15 @@ impl KnowledgeEntityEmbedding { } /// Delete embeddings by source_id (via joining to knowledge_entity table) - #[allow(clippy::items_after_statements)] pub async fn delete_by_source_id( source_id: &str, db: &SurrealDbClient, ) -> Result<(), AppError> { + #[derive(Deserialize)] + struct IdRow { + id: RecordId, + } + let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id"; let mut res = db .client @@ -122,11 +126,6 @@ impl KnowledgeEntityEmbedding { .bind(("source_id", source_id.to_owned())) .await .map_err(AppError::Database)?; - #[allow(clippy::missing_docs_in_private_items)] - #[derive(Deserialize)] - struct IdRow { - id: RecordId, - } let ids: Vec = res.take(0).map_err(AppError::Database)?; for row in ids { @@ -138,6 +137,7 @@ impl KnowledgeEntityEmbedding { #[cfg(test)] mod tests { + use anyhow::{self, Context}; use super::*; use crate::storage::db::SurrealDbClient; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; @@ -145,18 +145,18 @@ mod tests { use surrealdb::Value as SurrealValue; use uuid::Uuid; - async fn setup_test_db() -> SurrealDbClient { + async fn setup_test_db() -> anyhow::Result { let namespace = "test_ns"; let database = Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, &database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() 
.await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; - db + Ok(db) } fn build_knowledge_entity_with_id( @@ -178,11 +178,11 @@ mod tests { } #[tokio::test] - async fn test_create_and_get_by_entity_id() { - let db = setup_test_db().await; + async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> { + let db = setup_test_db().await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("set test index dimension"); + .with_context(|| "set test index dimension".to_string())?; let user_id = "user_ke"; let entity_key = "entity-1"; let source_id = "source-ke"; @@ -192,26 +192,28 @@ mod tests { KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) .await - .expect("Failed to get embedding by entity_id") - .expect("Expected embedding to exist"); + .with_context(|| "Failed to get embedding by entity_id".to_string())? 
+ .ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; assert_eq!(fetched.user_id, user_id); assert_eq!(fetched.entity_id, entity_rid); assert_eq!(fetched.embedding, embedding_vec); + + Ok(()) } #[tokio::test] - async fn test_delete_by_entity_id() { - let db = setup_test_db().await; + async fn test_delete_by_entity_id() -> anyhow::Result<()> { + let db = setup_test_db().await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("set test index dimension"); + .with_context(|| "set test index dimension".to_string())?; let user_id = "user_ke"; let entity_key = "entity-delete"; let source_id = "source-del"; @@ -220,61 +222,67 @@ mod tests { KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) .await - .expect("Failed to get embedding before delete"); + .with_context(|| "Failed to get embedding before delete".to_string())?; assert!(existing.is_some()); KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db) .await - .expect("Failed to delete by entity_id"); + .with_context(|| "Failed to delete by entity_id".to_string())?; let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) .await - .expect("Failed to get embedding after delete"); + .with_context(|| "Failed to get embedding after delete".to_string())?; assert!(after.is_none()); + + Ok(()) } #[tokio::test] - async fn test_store_with_embedding_creates_entity_and_embedding() { - let db = setup_test_db().await; + async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "user_store"; let source_id = "source_store"; let embedding = vec![0.2_f32, 0.3, 0.4]; 
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len()) .await - .expect("set test index dimension"); + .with_context(|| "set test index dimension".to_string())?; let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id); KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; - let stored_entity: Option = db.get_item(&entity.id).await.unwrap(); + let stored_entity: Option = db + .get_item(&entity.id) + .await + .with_context(|| "Failed to get entity".to_string())?; assert!(stored_entity.is_some()); let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) .await - .expect("Failed to fetch embedding"); - assert!(stored_embedding.is_some()); - let stored_embedding = stored_embedding.unwrap(); + .with_context(|| "Failed to fetch embedding".to_string())?; + let stored_embedding = stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; assert_eq!(stored_embedding.user_id, user_id); assert_eq!(stored_embedding.entity_id, entity_rid); + + Ok(()) } #[tokio::test] - async fn test_delete_by_source_id() { - let db = setup_test_db().await; + async fn test_delete_by_source_id() -> anyhow::Result<()> { + let db = setup_test_db().await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("set test index dimension"); + .with_context(|| "set test index dimension".to_string())?; let user_id = "user_ke"; let source_id = "shared-ke"; let other_source = "other-ke"; @@ -285,13 +293,13 @@ mod tests { KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; 
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id); let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id); @@ -299,59 +307,74 @@ mod tests { KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db) .await - .expect("Failed to delete by source_id"); + .with_context(|| "Failed to delete by source_id".to_string())?; assert!( KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db) .await - .unwrap() + .with_context(|| "get entity1 embedding after delete".to_string())? .is_none() ); assert!( KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db) .await - .unwrap() + .with_context(|| "get entity2 embedding after delete".to_string())? .is_none() ); assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db) .await - .unwrap() + .with_context(|| "get other embedding after delete".to_string())? 
.is_some()); + + Ok(()) } #[tokio::test] - async fn test_redefine_hnsw_index_updates_dimension() { - let db = setup_test_db().await; + async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> { + let db = setup_test_db().await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16) .await - .expect("failed to redefine index"); + .with_context(|| "failed to redefine index".to_string())?; let mut info_res = db .client .query("INFO FOR TABLE knowledge_entity_embedding;") .await - .expect("info query failed"); - let info: SurrealValue = info_res.take(0).expect("failed to take info result"); - let info_json: serde_json::Value = - serde_json::to_value(info).expect("failed to convert info to json"); - let idx_sql = info_json["Object"]["indexes"]["Object"] - ["idx_embedding_knowledge_entity_embedding"]["Strand"] - .as_str() + .with_context(|| "info query failed".to_string())?; + let info: SurrealValue = info_res + .take(0) + .with_context(|| "failed to take info result".to_string())?; + let info_json: serde_json::Value = serde_json::to_value(info) + .with_context(|| "failed to convert info to json".to_string())?; + let idx_sql = info_json + .get("Object") + .and_then(|v| v.get("indexes")) + .and_then(|v| v.get("Object")) + .and_then(|v| v.get("idx_embedding_knowledge_entity_embedding")) + .and_then(|v| v.get("Strand")) + .and_then(|v| v.as_str()) .unwrap_or_default(); assert!( idx_sql.contains("DIMENSION 16"), "expected index definition to contain new dimension, got: {idx_sql}" ); + + Ok(()) } #[tokio::test] - async fn test_fetch_entity_via_record_id() { - let db = setup_test_db().await; + async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> { + #[derive(Deserialize)] + struct Row { + entity_id: KnowledgeEntity, + } + + let db = setup_test_db().await?; KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("set test index dimension"); + .with_context(|| "set test index dimension".to_string())?; let user_id = "user_ke"; let 
entity_key = "entity-fetch"; let source_id = "source-fetch"; @@ -359,15 +382,10 @@ mod tests { let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id); KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db) .await - .expect("Failed to store entity with embedding"); + .with_context(|| "Failed to store entity with embedding".to_string())?; let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); - #[derive(Deserialize)] - struct Row { - entity_id: KnowledgeEntity, - } - let mut res = db .client .query( @@ -375,13 +393,17 @@ mod tests { ) .bind(("id", entity_rid.clone())) .await - .expect("failed to fetch embedding with FETCH"); - let rows: Vec = res.take(0).expect("failed to deserialize fetch rows"); + .with_context(|| "failed to fetch embedding with FETCH".to_string())?; + let rows: Vec = res + .take(0) + .with_context(|| "failed to deserialize fetch rows".to_string())?; assert_eq!(rows.len(), 1); - let fetched_entity = &rows[0].entity_id; + let fetched_entity = &rows.first().context("Expected at least one result")?.entity_id; assert_eq!(fetched_entity.id, entity_key); assert_eq!(fetched_entity.name, "Test entity"); assert_eq!(fetched_entity.user_id, user_id); + + Ok(()) } } diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index f63c8c7..f7bf164 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -124,6 +124,7 @@ impl KnowledgeRelationship { #[cfg(test)] mod tests { + use anyhow::{self, Context}; use super::*; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; @@ -155,10 +156,9 @@ mod tests { result.take(0).expect("failed to take relationship by id") } - // Helper function to create a test knowledge entity for the relationship tests - async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String { + 
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result { let source_id = "source123".to_string(); - let description = format!("Description for {}", name); + let description = format!("Description for {name}"); let entity_type = KnowledgeEntityType::Document; let user_id = "user123".to_string(); @@ -174,12 +174,14 @@ mod tests { let stored: Option = db_client .store_item(entity) .await - .expect("Failed to store entity"); - stored.unwrap().id + .with_context(|| "Failed to store entity".to_string())?; + stored + .ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some")) + .map(|e| e.id) } #[tokio::test] - async fn test_relationship_creation() { + async fn test_relationship_creation() -> anyhow::Result<()> { let in_id = "entity1".to_string(); let out_id = "entity2".to_string(); let user_id = "user123".to_string(); @@ -194,25 +196,23 @@ mod tests { relationship_type.clone(), ); - // Verify fields are correctly set assert_eq!(relationship.in_, in_id); assert_eq!(relationship.out, out_id); assert_eq!(relationship.metadata.user_id, user_id); assert_eq!(relationship.metadata.source_id, source_id); assert_eq!(relationship.metadata.relationship_type, relationship_type); assert!(!relationship.id.is_empty()); + + Ok(()) } #[tokio::test] - async fn test_store_and_verify_by_source_id() { - // Setup in-memory database for testing + async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> { let db = setup_test_db().await; - // Create two entities to relate - let entity1_id = create_test_entity("Entity 1", &db).await; - let entity2_id = create_test_entity("Entity 2", &db).await; + let entity1_id = create_test_entity("Entity 1", &db).await?; + let entity2_id = create_test_entity("Entity 2", &db).await?; - // Create relationship let user_id = "user123".to_string(); let source_id = "source123".to_string(); let relationship_type = "references".to_string(); @@ -225,11 +225,10 @@ mod tests { relationship_type, ); - // Store the 
relationship relationship .store_relationship(&db) .await - .expect("Failed to store relationship"); + .with_context(|| "Failed to store relationship".to_string())?; let persisted = get_relationship_by_id(&relationship.id, &db) .await @@ -239,8 +238,6 @@ mod tests { assert_eq!(persisted.metadata.user_id, user_id); assert_eq!(persisted.metadata.source_id, source_id); - // Query to verify the relationship exists by checking for relationships with our source_id - // This approach is more reliable than trying to look up by ID let mut check_result = db .query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id") .bind(("source_id", source_id.clone())) @@ -253,14 +250,16 @@ mod tests { 1, "Expected one relationship for source_id" ); + + Ok(()) } #[tokio::test] - async fn test_store_relationship_resists_query_injection() { + async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> { let db = setup_test_db().await; - let entity1_id = create_test_entity("Entity 1", &db).await; - let entity2_id = create_test_entity("Entity 2", &db).await; + let entity1_id = create_test_entity("Entity 1", &db).await?; + let entity2_id = create_test_entity("Entity 2", &db).await?; let relationship = KnowledgeRelationship::new( entity1_id, @@ -288,18 +287,17 @@ mod tests { rows[0].metadata.source_id, "source123'; DELETE FROM relates_to; --" ); + + Ok(()) } #[tokio::test] - async fn test_store_and_delete_relationship() { - // Setup in-memory database for testing + async fn test_store_and_delete_relationship() -> anyhow::Result<()> { let db = setup_test_db().await; - // Create two entities to relate - let entity1_id = create_test_entity("Entity 1", &db).await; - let entity2_id = create_test_entity("Entity 2", &db).await; + let entity1_id = create_test_entity("Entity 1", &db).await?; + let entity2_id = create_test_entity("Entity 2", &db).await?; - // Create relationship let user_id = "user123".to_string(); let source_id = "source123".to_string(); let 
relationship_type = "references".to_string(); @@ -312,52 +310,44 @@ mod tests { relationship_type, ); - // Store relationship relationship .store_relationship(&db) .await - .expect("Failed to store relationship"); + .with_context(|| "Failed to store relationship".to_string())?; - // Ensure relationship exists before deletion attempt let mut existing_before_delete = db .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'", - user_id, source_id + "SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'" )) .await - .expect("Query failed"); + .with_context(|| "Query failed".to_string())?; let before_results: Vec = existing_before_delete.take(0).unwrap_or_default(); - assert!( - !before_results.is_empty(), - "Relationship should exist before deletion" - ); + assert!(!before_results.is_empty(), "Relationship should exist before deletion"); - // Delete relationship by ID KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db) .await - .expect("Failed to delete relationship by ID"); + .with_context(|| "Failed to delete relationship by ID".to_string())?; - // Query to verify relationship was deleted let mut result = db .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'", - user_id, source_id + "SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'" )) .await - .expect("Query failed"); + .with_context(|| "Query failed".to_string())?; let results: Vec = result.take(0).unwrap_or_default(); - // Verify relationship no longer exists assert!(results.is_empty(), "Relationship should be deleted"); + + Ok(()) } #[tokio::test] - async fn test_delete_relationship_by_id_unauthorized() { + async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> { let db = setup_test_db().await; - let entity1_id = create_test_entity("Entity 1", &db).await; - let 
entity2_id = create_test_entity("Entity 2", &db).await; + let entity1_id = create_test_entity("Entity 1", &db).await?; + let entity2_id = create_test_entity("Entity 2", &db).await?; let owner_user_id = "owner-user".to_string(); let source_id = "source123".to_string(); @@ -373,20 +363,16 @@ mod tests { relationship .store_relationship(&db) .await - .expect("Failed to store relationship"); + .with_context(|| "Failed to store relationship".to_string())?; let mut before_attempt = db .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{}'", - owner_user_id + "SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'" )) .await - .expect("Query failed"); + .with_context(|| "Query failed".to_string())?; let before_results: Vec = before_attempt.take(0).unwrap_or_default(); - assert!( - !before_results.is_empty(), - "Relationship should exist before unauthorized delete attempt" - ); + assert!(!before_results.is_empty(), "Relationship should exist before unauthorized delete attempt"); let result = KnowledgeRelationship::delete_relationship_by_id( &relationship.id, @@ -397,40 +383,34 @@ mod tests { match result { Err(AppError::Auth(_)) => {} - _ => panic!("Expected authorization error when deleting someone else's relationship"), + _ => anyhow::bail!("Expected authorization error when deleting someone else's relationship"), } let mut after_attempt = db .query(format!( - "SELECT * FROM relates_to WHERE metadata.user_id = '{}'", - owner_user_id + "SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'" )) .await - .expect("Query failed"); + .with_context(|| "Query failed".to_string())?; let results: Vec = after_attempt.take(0).unwrap_or_default(); - assert!( - !results.is_empty(), - "Relationship should still exist after unauthorized delete attempt" - ); + assert!(!results.is_empty(), "Relationship should still exist after unauthorized delete attempt"); + + Ok(()) } #[tokio::test] - async fn test_store_relationship_exists() { - // Setup 
in-memory database for testing + async fn test_store_relationship_exists() -> anyhow::Result<()> { let db = setup_test_db().await; - // Create entities to relate - let entity1_id = create_test_entity("Entity 1", &db).await; - let entity2_id = create_test_entity("Entity 2", &db).await; - let entity3_id = create_test_entity("Entity 3", &db).await; + let entity1_id = create_test_entity("Entity 1", &db).await?; + let entity2_id = create_test_entity("Entity 2", &db).await?; + let entity3_id = create_test_entity("Entity 3", &db).await?; - // Create relationships with the same source_id let user_id = "user123".to_string(); let source_id = "source123".to_string(); let different_source_id = "different_source".to_string(); - // Create two relationships with the same source_id let relationship1 = KnowledgeRelationship::new( entity1_id.clone(), entity2_id.clone(), @@ -447,7 +427,6 @@ mod tests { "contains".to_string(), ); - // Create a relationship with a different source_id let different_relationship = KnowledgeRelationship::new( entity1_id.clone(), entity3_id.clone(), @@ -456,21 +435,19 @@ mod tests { "mentions".to_string(), ); - // Store all relationships relationship1 .store_relationship(&db) .await - .expect("Failed to store relationship 1"); + .with_context(|| "Failed to store relationship 1".to_string())?; relationship2 .store_relationship(&db) .await - .expect("Failed to store relationship 2"); + .with_context(|| "Failed to store relationship 2".to_string())?; different_relationship .store_relationship(&db) .await - .expect("Failed to store different relationship"); + .with_context(|| "Failed to store different relationship".to_string())?; - // Sanity-check setup: exactly two relationships use source_id and one uses different_source_id. 
let mut before_delete = db .query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id") .bind(("source_id", source_id.clone())) @@ -489,31 +466,30 @@ mod tests { before_delete_different.take(0).unwrap_or_default(); assert_eq!(before_delete_different_rows.len(), 1); - // Delete relationships by source_id KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db) .await - .expect("Failed to delete relationships by source_id"); + .with_context(|| "Failed to delete relationships by source_id".to_string())?; - // Query to verify the specific relationships with source_id were deleted. let result1 = get_relationship_by_id(&relationship1.id, &db).await; let result2 = get_relationship_by_id(&relationship2.id, &db).await; let different_result = get_relationship_by_id(&different_relationship.id, &db).await; - // Verify relationships with the source_id are deleted assert!(result1.is_none(), "Relationship 1 should be deleted"); assert!(result2.is_none(), "Relationship 2 should be deleted"); let remaining = different_result.expect("Relationship with different source_id should remain"); assert_eq!(remaining.metadata.source_id, different_source_id); + + Ok(()) } #[tokio::test] - async fn test_delete_relationships_by_source_id_resists_query_injection() { + async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> { let db = setup_test_db().await; - let entity1_id = create_test_entity("Entity 1", &db).await; - let entity2_id = create_test_entity("Entity 2", &db).await; - let entity3_id = create_test_entity("Entity 3", &db).await; + let entity1_id = create_test_entity("Entity 1", &db).await?; + let entity2_id = create_test_entity("Entity 2", &db).await?; + let entity3_id = create_test_entity("Entity 3", &db).await?; let safe_relationship = KnowledgeRelationship::new( entity1_id.clone(), @@ -552,5 +528,7 @@ mod tests { remaining_other.is_some(), "Other relationship should remain" ); + + Ok(()) } } diff --git 
a/common/src/storage/types/message.rs b/common/src/storage/types/message.rs index a120816..da0a612 100644 --- a/common/src/storage/types/message.rs +++ b/common/src/storage/types/message.rs @@ -66,12 +66,12 @@ pub fn format_history(history: &[Message]) -> String { #[cfg(test)] mod tests { + use anyhow::{self, Context}; use super::*; use crate::storage::db::SurrealDbClient; #[tokio::test] - async fn test_message_creation() { - // Test basic message creation + async fn test_message_creation() -> anyhow::Result<()> { let conversation_id = "test_conversation"; let content = "This is a test message"; let role = MessageRole::User; @@ -84,24 +84,23 @@ mod tests { references.clone(), ); - // Verify message properties assert_eq!(message.conversation_id, conversation_id); assert_eq!(message.content, content); assert_eq!(message.role, role); assert_eq!(message.references, references); assert!(!message.id.is_empty()); + + Ok(()) } #[tokio::test] - async fn test_message_persistence() { - // Setup in-memory database for testing + async fn test_message_persistence() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &uuid::Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Create and store a message let conversation_id = "test_conversation"; let message = Message::new( conversation_id.to_string(), @@ -111,39 +110,37 @@ mod tests { ); let message_id = message.id.clone(); - // Store the message db.store_item(message.clone()) .await - .expect("Failed to store message"); + .with_context(|| "Failed to store message".to_string())?; - // Retrieve the message let retrieved: Option = db .get_item(&message_id) .await - .expect("Failed to retrieve message"); + .with_context(|| "Failed to retrieve message".to_string())?; - assert!(retrieved.is_some()); - let retrieved = retrieved.unwrap(); + let retrieved = 
retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?; - // Verify retrieved properties match original assert_eq!(retrieved.id, message.id); assert_eq!(retrieved.conversation_id, message.conversation_id); assert_eq!(retrieved.role, message.role); assert_eq!(retrieved.content, message.content); assert_eq!(retrieved.references, message.references); + + Ok(()) } #[tokio::test] - async fn test_message_role_display() { - // Test the Display implementation for MessageRole + async fn test_message_role_display() -> anyhow::Result<()> { assert_eq!(format!("{}", MessageRole::User), "User"); assert_eq!(format!("{}", MessageRole::AI), "AI"); assert_eq!(format!("{}", MessageRole::System), "System"); + + Ok(()) } #[tokio::test] - async fn test_message_display() { - // Test the Display implementation for Message + async fn test_message_display() -> anyhow::Result<()> { let message = Message { id: "test_id".to_string(), created_at: Utc::now(), @@ -154,12 +151,13 @@ mod tests { references: None, }; - assert_eq!(format!("{}", message), "User: Hello world"); + assert_eq!(format!("{message}"), "User: Hello world"); + + Ok(()) } #[tokio::test] - async fn test_format_history() { - // Create a vector of messages + async fn test_format_history() -> anyhow::Result<()> { let messages = vec![ Message { id: "1".to_string(), @@ -181,10 +179,10 @@ mod tests { }, ]; - // Format the history let formatted = format_history(&messages); - // Verify the formatting assert_eq!(formatted, "User: Hello\nAI: Hi there!"); + + Ok(()) } } diff --git a/common/src/storage/types/scratchpad.rs b/common/src/storage/types/scratchpad.rs index ae3fedb..f166f6f 100644 --- a/common/src/storage/types/scratchpad.rs +++ b/common/src/storage/types/scratchpad.rs @@ -216,20 +216,22 @@ impl Scratchpad { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; #[tokio::test] - async fn test_create_scratchpad() { + async fn test_create_scratchpad() -> anyhow::Result<()> { // Setup in-memory 
database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; // Create a new scratchpad let user_id = "test_user"; @@ -254,29 +256,28 @@ mod tests { let retrieved: Option = db .get_item(&scratchpad.id) .await - .expect("Failed to retrieve scratchpad"); - assert!(retrieved.is_some()); - - let retrieved = retrieved.unwrap(); + .with_context(|| "Failed to retrieve scratchpad".to_string())?; + let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?; assert_eq!(retrieved.id, scratchpad.id); assert_eq!(retrieved.user_id, user_id); assert_eq!(retrieved.title, title); assert!(!retrieved.is_archived); assert!(retrieved.archived_at.is_none()); assert!(retrieved.ingested_at.is_none()); + Ok(()) } #[tokio::test] - async fn test_get_by_user() { + async fn test_get_by_user() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let user_id = "test_user"; @@ -288,19 +289,21 @@ mod tests { // Store them let scratchpad1_id = scratchpad1.id.clone(); let scratchpad2_id = scratchpad2.id.clone(); - db.store_item(scratchpad1).await.unwrap(); - db.store_item(scratchpad2).await.unwrap(); - db.store_item(scratchpad3).await.unwrap(); + db.store_item(scratchpad1).await.with_context(|| "store scratchpad1".to_string())?; + db.store_item(scratchpad2).await.with_context(|| 
"store scratchpad2".to_string())?; + db.store_item(scratchpad3).await.with_context(|| "store scratchpad3".to_string())?; // Archive one of the user's scratchpads Scratchpad::archive(&scratchpad2_id, user_id, &db, false) .await - .unwrap(); + .with_context(|| "archive".to_string())?; // Get scratchpads for user_id - let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap(); + let user_scratchpads = Scratchpad::get_by_user(user_id, &db) + .await + .with_context(|| "get_by_user".to_string())?; assert_eq!(user_scratchpads.len(), 1); - assert_eq!(user_scratchpads[0].id, scratchpad1_id); + assert_eq!(user_scratchpads.first().map(|s| &s.id), Some(&scratchpad1_id)); // Verify they belong to the user for scratchpad in &user_scratchpads { @@ -309,177 +312,183 @@ mod tests { let archived = Scratchpad::get_archived_by_user(user_id, &db) .await - .unwrap(); + .with_context(|| "get_archived_by_user".to_string())?; assert_eq!(archived.len(), 1); - assert_eq!(archived[0].id, scratchpad2_id); - assert!(archived[0].is_archived); - assert!(archived[0].ingested_at.is_none()); + assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id)); + assert!(archived.first().is_some_and(|s| s.is_archived)); + assert!(archived.first().is_some_and(|s| s.ingested_at.is_none())); + Ok(()) } #[tokio::test] - async fn test_archive_and_restore() { + async fn test_archive_and_restore() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let user_id = "test_user"; let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string()); let scratchpad_id = scratchpad.id.clone(); - db.store_item(scratchpad).await.unwrap(); + 
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?; let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true) .await - .expect("Failed to archive"); + .with_context(|| "Failed to archive".to_string())?; assert!(archived.is_archived); assert!(archived.archived_at.is_some()); assert!(archived.ingested_at.is_some()); let restored = Scratchpad::restore(&scratchpad_id, user_id, &db) .await - .expect("Failed to restore"); + .with_context(|| "Failed to restore".to_string())?; assert!(!restored.is_archived); assert!(restored.archived_at.is_none()); assert!(restored.ingested_at.is_none()); + Ok(()) } #[tokio::test] - async fn test_update_content() { + async fn test_update_content() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let user_id = "test_user"; let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string()); let scratchpad_id = scratchpad.id.clone(); - db.store_item(scratchpad).await.unwrap(); + db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?; let new_content = "Updated content"; let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db) .await - .unwrap(); + .with_context(|| "update_content".to_string())?; assert_eq!(updated.content, new_content); assert!(!updated.is_dirty); + Ok(()) } #[tokio::test] - async fn test_update_content_unauthorized() { + async fn test_update_content_unauthorized() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + 
.with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let owner_id = "owner"; let other_user = "other_user"; let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string()); let scratchpad_id = scratchpad.id.clone(); - db.store_item(scratchpad).await.unwrap(); + db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?; let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await; assert!(result.is_err()); match result { Err(AppError::Auth(_)) => {} - _ => panic!("Expected Auth error"), + _ => anyhow::bail!("Expected Auth error"), } + Ok(()) } #[tokio::test] - async fn test_delete_scratchpad() { + async fn test_delete_scratchpad() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let user_id = "test_user"; let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string()); let scratchpad_id = scratchpad.id.clone(); - db.store_item(scratchpad).await.unwrap(); + db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?; // Delete should succeed let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await; assert!(result.is_ok()); // Verify it's gone - let retrieved: Option = db.get_item(&scratchpad_id).await.unwrap(); + let retrieved: Option = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?; assert!(retrieved.is_none()); + Ok(()) } #[tokio::test] - async fn test_delete_unauthorized() { + async fn test_delete_unauthorized() -> 
anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let owner_id = "owner"; let other_user = "other_user"; let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string()); let scratchpad_id = scratchpad.id.clone(); - db.store_item(scratchpad).await.unwrap(); + db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?; let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await; assert!(result.is_err()); match result { Err(AppError::Auth(_)) => {} - _ => panic!("Expected Auth error"), + _ => anyhow::bail!("Expected Auth error"), } // Verify it still exists - let retrieved: Option = db.get_item(&scratchpad_id).await.unwrap(); + let retrieved: Option = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?; assert!(retrieved.is_some()); + Ok(()) } #[tokio::test] - async fn test_timezone_aware_scratchpad_conversion() { + async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> { let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()) .await - .expect("Failed to create test database"); + .with_context(|| "Failed to create test database".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let user_id = "test_user_123"; let scratchpad = Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string()); let scratchpad_id = scratchpad.id.clone(); - db.store_item(scratchpad).await.unwrap(); + db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?; let retrieved = 
Scratchpad::get_by_id(&scratchpad_id, user_id, &db) .await - .unwrap(); + .with_context(|| "get_by_id".to_string())?; // Test that datetime fields are preserved and can be used for timezone formatting assert!(retrieved.created_at.timestamp() > 0); @@ -493,10 +502,11 @@ mod tests { // Archive the scratchpad to test optional datetime handling let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false) .await - .unwrap(); + .with_context(|| "archive".to_string())?; assert!(archived.archived_at.is_some()); - assert!(archived.archived_at.unwrap().timestamp() > 0); + assert!(archived.archived_at.with_context(|| "expected archived_at".to_string())?.timestamp() > 0); assert!(archived.ingested_at.is_none()); + Ok(()) } } diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index f387c61..57a1029 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -64,7 +64,14 @@ impl SystemSettings { let mut needs_update = false; let backend_label = provider.backend_label().to_string(); - let provider_dimensions = provider.dimension() as u32; + let provider_dimensions = u32::try_from(provider.dimension()) + .unwrap_or_else(|_| { + tracing::warn!( + "Provider dimension {} exceeds u32 max; falling back to 0", + provider.dimension() + ); + 0u32 + }); let provider_model = provider.model_code(); // Sync backend label @@ -107,7 +114,8 @@ impl SystemSettings { #[cfg(test)] mod tests { - use crate::storage::indexes::ensure_runtime_indexes; + use anyhow::{self, Context}; + use crate::storage::indexes::ensure_runtime; use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}; use async_openai::Client; @@ -118,68 +126,102 @@ mod tests { db: &SurrealDbClient, table_name: &str, index_name: &str, - ) -> u32 { + ) -> anyhow::Result { let query = format!("INFO FOR TABLE {table_name};"); let mut response = db .client .query(query) .await - .expect("Failed to 
fetch table info"); + .with_context(|| "Failed to fetch table info".to_string())?; let info: surrealdb::Value = response .take(0) - .expect("Failed to extract table info response"); + .with_context(|| "Failed to extract table info response".to_string())?; let info_json: serde_json::Value = - serde_json::to_value(info).expect("Failed to convert info to json"); + serde_json::to_value(info).with_context(|| "Failed to convert info to json".to_string())?; - let indexes = info_json["Object"]["indexes"]["Object"] - .as_object() - .unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}")); + let indexes = info_json + .get("Object") + .and_then(|v| v.get("indexes")) + .and_then(|v| v.get("Object")) + .and_then(|v| v.as_object()) + .with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?; let definition = indexes .get(index_name) .and_then(|definition| definition.get("Strand")) .and_then(|v| v.as_str()) - .unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}")); + .with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?; let dimension_part = definition .split("DIMENSION") .nth(1) - .expect("Index definition missing DIMENSION clause"); + .with_context(|| "Index definition missing DIMENSION clause".to_string())?; let dimension_token = dimension_part .split_whitespace() .next() - .expect("Dimension value missing in definition") + .with_context(|| "Dimension value missing in definition".to_string())? 
.trim_end_matches(';'); dimension_token .parse::() - .expect("Dimension value is not a valid number") + .with_context(|| "Dimension value is not a valid number".to_string()) + } + + async fn simulate_reembedding( + db: &SurrealDbClient, + target_dimension: usize, + initial_chunk: TextChunk, + ) -> anyhow::Result<()> { + db.query( + "REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;", + ) + .await + .with_context(|| "remove index".to_string())?; + let define_index_query = format!( + "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};" + ); + db.query(define_index_query) + .await + .with_context(|| "Re-defining index should succeed".to_string())?; + + let new_embedding = vec![0.5; target_dimension]; + let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;"; + + db.client + .query(sql) + .bind(("id", initial_chunk.id.clone())) + .bind(("user_id", initial_chunk.user_id.clone())) + .bind(("embedding", new_embedding)) + .await + .with_context(|| "upsert embedding".to_string())?; + + Ok(()) } #[tokio::test] - async fn test_settings_initialization() { + async fn test_settings_initialization() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Test initialization of system settings db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let settings = SystemSettings::get_current(&db) .await - .expect("Failed to get system settings"); + .with_context(|| "Failed to get system settings".to_string())?; // Verify initial state 
after initialization assert_eq!(settings.id, "current"); - assert_eq!(settings.registrations_enabled, true); - assert_eq!(settings.require_email_verification, false); + assert!(settings.registrations_enabled); + assert!(!settings.require_email_verification); assert_eq!(settings.query_model, "gpt-4o-mini"); assert_eq!(settings.processing_model, "gpt-4o-mini"); assert_eq!(settings.image_processing_model, "gpt-4o-mini"); @@ -196,10 +238,10 @@ mod tests { // Test idempotency - ensure calling it again doesn't change anything db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; let settings_again = SystemSettings::get_current(&db) .await - .expect("Failed to get settings after initialization"); + .with_context(|| "Failed to get settings after initialization".to_string())?; assert_eq!(settings.id, settings_again.id); assert_eq!( @@ -210,48 +252,52 @@ mod tests { settings.require_email_verification, settings_again.require_email_verification ); + Ok(()) } #[tokio::test] - async fn test_get_current_settings() { + async fn test_get_current_settings() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Initialize settings db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; // Test get_current method let settings = SystemSettings::get_current(&db) .await - .expect("Failed to get current settings"); + .with_context(|| "Failed to get current settings".to_string())?; assert_eq!(settings.id, "current"); - assert_eq!(settings.registrations_enabled, true); - assert_eq!(settings.require_email_verification, false); + assert!(settings.registrations_enabled); + 
assert!(!settings.require_email_verification); + Ok(()) } #[tokio::test] - async fn test_update_settings() { + async fn test_update_settings() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Initialize settings db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; // Create updated settings - let mut updated_settings = SystemSettings::get_current(&db).await.unwrap(); + let mut updated_settings = SystemSettings::get_current(&db) + .await + .with_context(|| "get_current".to_string())?; updated_settings.id = "current".to_string(); updated_settings.registrations_enabled = false; updated_settings.require_email_verification = true; @@ -260,31 +306,32 @@ mod tests { // Test update method let result = SystemSettings::update(&db, updated_settings) .await - .expect("Failed to update settings"); + .with_context(|| "Failed to update settings".to_string())?; assert_eq!(result.id, "current"); - assert_eq!(result.registrations_enabled, false); - assert_eq!(result.require_email_verification, true); + assert!(!result.registrations_enabled); + assert!(result.require_email_verification); assert_eq!(result.query_model, "gpt-4"); // Verify changes persisted by getting current settings let current = SystemSettings::get_current(&db) .await - .expect("Failed to get current settings after update"); + .with_context(|| "Failed to get current settings after update".to_string())?; - assert_eq!(current.registrations_enabled, false); - assert_eq!(current.require_email_verification, true); + assert!(!current.registrations_enabled); + assert!(current.require_email_verification); assert_eq!(current.query_model, "gpt-4"); + Ok(()) } #[tokio::test] - 
async fn test_get_current_nonexistent() { + async fn test_get_current_nonexistent() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Don't initialize settings and try to get them let result = SystemSettings::get_current(&db).await; @@ -294,21 +341,22 @@ mod tests { Err(AppError::NotFound(_)) => { // Expected error } - Err(e) => panic!("Expected NotFound error, got: {:?}", e), - Ok(_) => panic!("Expected error but got Ok"), + Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"), + Ok(_) => anyhow::bail!("Expected error but got Ok"), } + Ok(()) } #[tokio::test] - async fn test_migration_after_changing_embedding_length() { + async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> { let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string()) .await - .expect("Failed to start DB"); + .with_context(|| "Failed to start DB".to_string())?; // Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536. 
db.apply_migrations() .await - .expect("Initial migration failed"); + .with_context(|| "Initial migration failed".to_string())?; let initial_chunk = TextChunk::new( "source1".into(), @@ -318,43 +366,11 @@ mod tests { TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db) .await - .expect("Failed to store initial chunk with embedding"); - - async fn simulate_reembedding( - db: &SurrealDbClient, - target_dimension: usize, - initial_chunk: TextChunk, - ) { - db.query( - "REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;", - ) - .await - .unwrap(); - let define_index_query = format!( - "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", - target_dimension - ); - db.query(define_index_query) - .await - .expect("Re-defining index should succeed"); - - let new_embedding = vec![0.5; target_dimension]; - let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;"; - - let update_result = db - .client - .query(sql) - .bind(("id", initial_chunk.id.clone())) - .bind(("user_id", initial_chunk.user_id.clone())) - .bind(("embedding", new_embedding)) - .await; - - assert!(update_result.is_ok()); - } + .with_context(|| "Failed to store initial chunk with embedding".to_string())?; // Re-embed with the existing configured dimension to ensure migrations remain idempotent. 
let target_dimension = 1536usize; - simulate_reembedding(&db, target_dimension, initial_chunk).await; + simulate_reembedding(&db, target_dimension, initial_chunk).await?; let migration_result = db.apply_migrations().await; @@ -363,34 +379,35 @@ mod tests { "Migrations should not fail: {:?}", migration_result.err() ); + Ok(()) } #[tokio::test] - async fn test_should_change_embedding_length_on_indexes_when_switching_length() { + async fn test_should_change_embedding_length_on_indexes_when_switching_length() -> anyhow::Result<()> { let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string()) .await - .expect("Failed to start DB"); + .with_context(|| "Failed to start DB".to_string())?; // Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536. db.apply_migrations() .await - .expect("Initial migration failed"); + .with_context(|| "Initial migration failed".to_string())?; let mut current_settings = SystemSettings::get_current(&db) .await - .expect("Failed to load current settings"); + .with_context(|| "Failed to load current settings".to_string())?; // Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed. 
- ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize) + ensure_runtime(&db, current_settings.embedding_dimensions as usize) .await - .expect("failed to build runtime indexes"); + .with_context(|| "failed to build runtime indexes".to_string())?; let initial_chunk_dimension = get_hnsw_index_dimension( &db, "text_chunk_embedding", "idx_embedding_text_chunk_embedding", ) - .await; + .await?; assert_eq!( initial_chunk_dimension, current_settings.embedding_dimensions, @@ -405,7 +422,7 @@ mod tests { let updated_settings = SystemSettings::update(&db, current_settings) .await - .expect("Failed to update settings"); + .with_context(|| "Failed to update settings".to_string())?; assert_eq!( updated_settings.embedding_dimensions, new_dimension, @@ -416,23 +433,23 @@ mod tests { TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) .await - .expect("TextChunk re-embedding should succeed on fresh DB"); + .with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?; KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) .await - .expect("KnowledgeEntity re-embedding should succeed on fresh DB"); + .with_context(|| "KnowledgeEntity re-embedding should succeed on fresh DB".to_string())?; let text_chunk_dimension = get_hnsw_index_dimension( &db, "text_chunk_embedding", "idx_embedding_text_chunk_embedding", ) - .await; + .await?; let knowledge_dimension = get_hnsw_index_dimension( &db, "knowledge_entity_embedding", "idx_embedding_knowledge_entity_embedding", ) - .await; + .await?; assert_eq!( text_chunk_dimension, new_dimension, @@ -445,10 +462,11 @@ mod tests { let persisted_settings = SystemSettings::get_current(&db) .await - .expect("Failed to reload updated settings"); + .with_context(|| "Failed to reload updated settings".to_string())?; assert_eq!( persisted_settings.embedding_dimensions, new_dimension, "Settings should persist new embedding dimension" ); + Ok(()) } } diff 
--git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 805c931..c336813 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -1,4 +1,4 @@ -#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)] +#![allow(clippy::missing_docs_in_private_items)] use std::collections::HashMap; use std::fmt::Write; @@ -237,10 +237,7 @@ impl TextChunk { new_model: &str, new_dimensions: u32, ) -> Result<(), AppError> { - info!( - "Starting re-embedding process for all text chunks. New dimensions: {}", - new_dimensions - ); + info!("Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}"); // Fetch all chunks first let all_chunks: Vec = db.select(Self::table_name()).await?; @@ -252,7 +249,7 @@ impl TextChunk { return Ok(()); } - info!("Found {} chunks to process.", total_chunks); + info!("Found {total_chunks} chunks to process."); // Generate all new embeddings in memory let mut new_embeddings: HashMap, String, String)> = HashMap::new(); @@ -276,7 +273,7 @@ impl TextChunk { "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. 
Aborting.", chunk.id, embedding.len(), new_dimensions ); - error!("{}", err_msg); + error!("{err_msg}"); return Err(AppError::InternalError(err_msg)); } new_embeddings.insert( @@ -300,6 +297,7 @@ impl TextChunk { .join(",") ); // Use the chunk id as the embedding record id to keep a 1:1 mapping + let embedding = embedding_str; write!( &mut transaction_query, "UPSERT type::thing('text_chunk_embedding', '{id}') SET \ @@ -309,18 +307,13 @@ impl TextChunk { user_id = '{user_id}', \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ updated_at = time::now();", - id = id, - embedding = embedding_str, - user_id = user_id, - source_id = source_id ) .map_err(|e| AppError::InternalError(e.to_string()))?; } write!( &mut transaction_query, - "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", - new_dimensions + "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", ) .map_err(|e| AppError::InternalError(e.to_string()))?; @@ -377,7 +370,7 @@ impl TextChunk { "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. 
Aborting.", chunk.id, embedding.len(), new_dimensions ); - error!("{}", err_msg); + error!("{err_msg}"); return Err(AppError::InternalError(err_msg)); } new_embeddings.insert( @@ -422,6 +415,7 @@ impl TextChunk { .collect::>() .join(",") ); + let embedding = embedding_str; write!( &mut transaction_query, "CREATE type::thing('text_chunk_embedding', '{id}') SET \ @@ -431,18 +425,13 @@ impl TextChunk { user_id = '{user_id}', \ created_at = time::now(), \ updated_at = time::now();", - id = id, - embedding = embedding_str, - user_id = user_id, - source_id = source_id ) .map_err(|e| AppError::InternalError(e.to_string()))?; } write!( &mut transaction_query, - "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", - new_dimensions + "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", ) .map_err(|e| AppError::InternalError(e.to_string()))?; @@ -462,20 +451,21 @@ impl TextChunk { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; - use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes}; + use crate::storage::indexes::{ensure_runtime, rebuild}; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use surrealdb::RecordId; use uuid::Uuid; - async fn ensure_chunk_fts_index(db: &SurrealDbClient) { + async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> { let snowball_sql = r#" DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25; "#; if let Err(err) = db.client.query(snowball_sql).await { - // Fall back to ascii-only analyzer when snowball is unavailable in the build. 
let fallback_sql = r#" DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii; DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25; @@ -484,12 +474,13 @@ mod tests { db.client .query(fallback_sql) .await - .unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}")); + .with_context(|| format!("define chunk fts index fallback: {err}"))?; } + Ok(()) } #[tokio::test] - async fn test_text_chunk_creation() { + async fn test_text_chunk_creation() -> anyhow::Result<()> { let source_id = "source123".to_string(); let chunk = "This is a text chunk for testing embeddings".to_string(); let user_id = "user123".to_string(); @@ -500,22 +491,23 @@ mod tests { assert_eq!(text_chunk.chunk, chunk); assert_eq!(text_chunk.user_id, user_id); assert!(!text_chunk.id.is_empty()); + Ok(()) } #[tokio::test] - async fn test_delete_by_source_id() { + async fn test_delete_by_source_id() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let source_id = "source123".to_string(); let user_id = "user123".to_string(); TextChunkEmbedding::redefine_hnsw_index(&db, 5) .await - .expect("redefine index"); + .with_context(|| "redefine index".to_string())?; let chunk1 = TextChunk::new( source_id.clone(), @@ -535,61 +527,63 @@ mod tests { TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) .await - .expect("store chunk1"); + .with_context(|| "store chunk1".to_string())?; TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) .await - .expect("store chunk2"); + .with_context(|| 
"store chunk2".to_string())?; TextChunk::store_with_embedding( different_chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db, ) .await - .expect("store different chunk"); + .with_context(|| "store different chunk".to_string())?; TextChunk::delete_by_source_id(&source_id, &db) .await - .expect("Failed to delete chunks by source_id"); + .with_context(|| "Failed to delete chunks by source_id".to_string())?; let remaining: Vec = db .client .query(format!( - "SELECT * FROM {} WHERE source_id = '{}'", + "SELECT * FROM {} WHERE source_id = '{source_id}'", TextChunk::table_name(), - source_id )) .await - .expect("Query failed") + .with_context(|| "Query failed".to_string())? .take(0) - .expect("Failed to get query results"); + .with_context(|| "Failed to get query results".to_string())?; assert_eq!(remaining.len(), 0); let different_remaining: Vec = db .client .query(format!( - "SELECT * FROM {} WHERE source_id = '{}'", + "SELECT * FROM {} WHERE source_id = 'different_source'", TextChunk::table_name(), - "different_source" )) .await - .expect("Query failed") + .with_context(|| "Query failed".to_string())? 
.take(0) - .expect("Failed to get query results"); + .with_context(|| "Failed to get query results".to_string())?; assert_eq!(different_remaining.len(), 1); - assert_eq!(different_remaining[0].id, different_chunk.id); + assert_eq!( + different_remaining.first().map(|r| &r.id), + Some(&different_chunk.id) + ); + Ok(()) } #[tokio::test] - async fn test_delete_by_nonexistent_source_id() { + async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; TextChunkEmbedding::redefine_hnsw_index(&db, 5) .await - .expect("redefine index"); + .with_context(|| "redefine index".to_string())?; let real_source_id = "real_source".to_string(); let chunk = TextChunk::new( @@ -600,24 +594,24 @@ mod tests { TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) .await - .expect("store chunk"); + .with_context(|| "store chunk".to_string())?; TextChunk::delete_by_source_id("nonexistent_source", &db) .await - .expect("Delete should succeed"); + .with_context(|| "Delete should succeed".to_string())?; let remaining: Vec = db .client .query(format!( - "SELECT * FROM {} WHERE source_id = '{}'", + "SELECT * FROM {} WHERE source_id = '{real_source_id}'", TextChunk::table_name(), - real_source_id )) .await - .expect("Query failed") + .with_context(|| "Query failed".to_string())? 
.take(0) - .expect("Failed to get query results"); + .with_context(|| "Failed to get query results".to_string())?; assert_eq!(remaining.len(), 1); + Ok(()) } #[tokio::test] @@ -672,13 +666,13 @@ mod tests { } #[tokio::test] - async fn test_store_with_embedding_creates_both_records() { + async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; let source_id = "store-src".to_string(); let user_id = "user_store".to_string(); @@ -686,43 +680,43 @@ mod tests { TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("redefine index"); + .with_context(|| "redefine index".to_string())?; TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) .await - .expect("store with embedding"); + .with_context(|| "store with embedding".to_string())?; - let stored_chunk: Option = db.get_item(&chunk.id).await.unwrap(); - assert!(stored_chunk.is_some()); - let stored_chunk = stored_chunk.unwrap(); + let stored_chunk: Option = db.get_item(&chunk.id) + .await + .with_context(|| "get_item".to_string())?; + let stored_chunk = stored_chunk.with_context(|| "expected stored chunk".to_string())?; assert_eq!(stored_chunk.source_id, source_id); assert_eq!(stored_chunk.user_id, user_id); let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) .await - .expect("get embedding"); - assert!(embedding.is_some()); - let embedding = embedding.unwrap(); + .with_context(|| "get embedding".to_string())? 
+ .with_context(|| "expected embedding".to_string())?; assert_eq!(embedding.chunk_id, rid); assert_eq!(embedding.user_id, user_id); assert_eq!(embedding.source_id, source_id); + Ok(()) } #[tokio::test] - async fn test_store_with_embedding_with_runtime_indexes() { + async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> { let namespace = "test_ns_runtime"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; - // Ensure runtime indexes are built with the expected dimension. let embedding_dimension = 3usize; - ensure_runtime_indexes(&db, embedding_dimension) + ensure_runtime(&db, embedding_dimension) .await - .expect("ensure runtime indexes"); + .with_context(|| "ensure runtime indexes".to_string())?; let chunk = TextChunk::new( "runtime_src".to_string(), @@ -732,55 +726,60 @@ mod tests { TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) .await - .expect("store with embedding"); + .with_context(|| "store with embedding".to_string())?; - let stored_chunk: Option = db.get_item(&chunk.id).await.unwrap(); - assert!(stored_chunk.is_some(), "chunk should be stored"); + let stored_chunk: Option = db.get_item(&chunk.id) + .await + .with_context(|| "get_item".to_string())?; + let stored_chunk = stored_chunk.with_context(|| "chunk should be stored".to_string())?; + assert!(stored_chunk.id == chunk.id, "chunk should be stored"); let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) .await - .expect("get embedding"); - assert!(embedding.is_some(), "embedding should exist"); + .with_context(|| "get embedding".to_string())? 
+ .with_context(|| "embedding should exist".to_string())?; assert_eq!( - embedding.unwrap().embedding.len(), + embedding.embedding.len(), embedding_dimension, "embedding dimension should match runtime index" ); + Ok(()) } #[tokio::test] - async fn test_vector_search_returns_empty_when_no_embeddings() { + async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("redefine index"); + .with_context(|| "redefine index".to_string())?; let results: Vec = TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") .await - .unwrap(); + .with_context(|| "vector_search".to_string())?; assert!(results.is_empty()); + Ok(()) } #[tokio::test] - async fn test_vector_search_single_result() { + async fn test_vector_search_single_result() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("redefine index"); + .with_context(|| "redefine index".to_string())?; let source_id = "src".to_string(); let user_id = "user".to_string(); @@ -792,32 +791,33 @@ mod tests { TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) .await - .expect("store"); + .with_context(|| "store".to_string())?; let 
results: Vec = TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) .await - .unwrap(); + .with_context(|| "vector_search".to_string())?; assert_eq!(results.len(), 1); - let res = &results[0]; + let res = results.first().context("expected first result")?; assert_eq!(res.chunk.id, chunk.id); assert_eq!(res.chunk.source_id, source_id); assert_eq!(res.chunk.chunk, "hello world"); + Ok(()) } #[tokio::test] - async fn test_vector_search_orders_by_similarity() { + async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; TextChunkEmbedding::redefine_hnsw_index(&db, 3) .await - .expect("redefine index"); + .with_context(|| "redefine index".to_string())?; let user_id = "user".to_string(); let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone()); @@ -825,49 +825,59 @@ mod tests { TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db) .await - .expect("store chunk1"); + .with_context(|| "store chunk1".to_string())?; TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db) .await - .expect("store chunk2"); + .with_context(|| "store chunk2".to_string())?; let results: Vec = TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) .await - .unwrap(); + .with_context(|| "vector_search".to_string())?; assert_eq!(results.len(), 2); - assert_eq!(results[0].chunk.id, chunk2.id); - assert_eq!(results[1].chunk.id, chunk1.id); - assert!(results[0].score >= results[1].score); + assert_eq!( + results.first().map(|r| &r.chunk.id), + Some(&chunk2.id) + ); + assert_eq!( + results.get(1).map(|r| &r.chunk.id), + 
Some(&chunk1.id) + ); + let r0 = results.first().context("expected first result")?; + let r1 = results.get(1).context("expected second result")?; + assert!(r0.score >= r1.score); + Ok(()) } #[tokio::test] - async fn test_fts_search_returns_empty_when_no_chunks() { + async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> { let namespace = "fts_chunk_ns_empty"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); - ensure_chunk_fts_index(&db).await; - rebuild_indexes(&db).await.expect("rebuild indexes"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; + ensure_chunk_fts_index(&db).await?; + rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?; let results = TextChunk::fts_search(5, "hello", &db, "user") .await - .expect("fts search"); + .with_context(|| "fts search".to_string())?; assert!(results.is_empty()); + Ok(()) } #[tokio::test] - async fn test_fts_search_single_result() { + async fn test_fts_search_single_result() -> anyhow::Result<()> { let namespace = "fts_chunk_ns_single"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); - ensure_chunk_fts_index(&db).await; + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; + ensure_chunk_fts_index(&db).await?; let user_id = "fts_user"; let chunk = TextChunk::new( @@ -875,27 +885,29 @@ mod tests { "rustaceans love rust".to_string(), user_id.to_string(), ); - db.store_item(chunk.clone()).await.expect("store chunk"); - rebuild_indexes(&db).await.expect("rebuild indexes"); + 
db.store_item(chunk.clone()).await.with_context(|| "store chunk".to_string())?; + rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?; let results = TextChunk::fts_search(3, "rust", &db, user_id) .await - .expect("fts search"); + .with_context(|| "fts search".to_string())?; assert_eq!(results.len(), 1); - assert_eq!(results[0].chunk.id, chunk.id); - assert!(results[0].score.is_finite(), "expected a finite FTS score"); + let r0 = results.first().context("expected first result")?; + assert_eq!(r0.chunk.id, chunk.id); + assert!(r0.score.is_finite(), "expected a finite FTS score"); + Ok(()) } #[tokio::test] - async fn test_fts_search_orders_by_score_and_filters_user() { + async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> { let namespace = "fts_chunk_ns_order"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); - db.apply_migrations().await.expect("migrations"); - ensure_chunk_fts_index(&db).await; + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations().await.with_context(|| "migrations".to_string())?; + ensure_chunk_fts_index(&db).await?; let user_id = "fts_user_order"; let high_score_chunk = TextChunk::new( @@ -916,18 +928,18 @@ mod tests { db.store_item(high_score_chunk.clone()) .await - .expect("store high score chunk"); + .with_context(|| "store high score chunk".to_string())?; db.store_item(low_score_chunk.clone()) .await - .expect("store low score chunk"); + .with_context(|| "store low score chunk".to_string())?; db.store_item(other_user_chunk) .await - .expect("store other user chunk"); - rebuild_indexes(&db).await.expect("rebuild indexes"); + .with_context(|| "store other user chunk".to_string())?; + rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?; let results = TextChunk::fts_search(3, "apple", &db, user_id) .await - .expect("fts search"); + 
.with_context(|| "fts search".to_string())?; assert_eq!(results.len(), 2); let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect(); @@ -936,9 +948,12 @@ mod tests { && ids.contains(&low_score_chunk.id.as_str()), "expected only the two chunks for the same user" ); + let r0 = results.first().context("expected first result")?; + let r1 = results.get(1).context("expected second result")?; assert!( - results[0].score >= results[1].score, + r0.score >= r1.score, "expected results ordered by descending score" ); + Ok(()) } } diff --git a/common/src/storage/types/text_chunk_embedding.rs b/common/src/storage/types/text_chunk_embedding.rs index 90a3da6..13caee1 100644 --- a/common/src/storage/types/text_chunk_embedding.rs +++ b/common/src/storage/types/text_chunk_embedding.rs @@ -126,24 +126,26 @@ impl TextChunkEmbedding { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; use crate::storage::db::SurrealDbClient; use surrealdb::Value as SurrealValue; use uuid::Uuid; /// Helper to create an in-memory DB and apply migrations - async fn setup_test_db() -> SurrealDbClient { + async fn setup_test_db() -> anyhow::Result { let namespace = "test_ns"; let database = Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, &database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to apply migrations"); + .with_context(|| "Failed to apply migrations".to_string())?; - db + Ok(db) } /// Helper: create a text_chunk with a known key, return its RecordId @@ -152,7 +154,7 @@ mod tests { key: &str, source_id: &str, user_id: &str, - ) -> RecordId { + ) -> anyhow::Result { let chunk = TextChunk { id: key.to_owned(), created_at: Utc::now(), @@ -164,21 +166,42 @@ mod tests { db.store_item(chunk) .await - .expect("Failed to create text_chunk"); + .with_context(|| "Failed to create text_chunk".to_string())?; - 
RecordId::from_table_key(TextChunk::table_name(), key) + Ok(RecordId::from_table_key(TextChunk::table_name(), key)) + } + + async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result { + let mut info_res = db + .client + .query("INFO FOR TABLE text_chunk_embedding;") + .await + .with_context(|| "info query failed".to_string())?; + let info: SurrealValue = info_res.take(0).with_context(|| "failed to take info result".to_string())?; + let info_json: serde_json::Value = + serde_json::to_value(info).with_context(|| "failed to convert info to json".to_string())?; + let idx_sql = info_json + .get("Object") + .and_then(|v| v.get("indexes")) + .and_then(|v| v.get("Object")) + .and_then(|v| v.get("idx_embedding_text_chunk_embedding")) + .and_then(|v| v.get("Strand")) + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + Ok(idx_sql) } #[tokio::test] - async fn test_create_and_get_by_chunk_id() { - let db = setup_test_db().await; + async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "user_a"; let chunk_key = "chunk-123"; let source_id = "source-1"; // 1) Create a text_chunk with a known key - let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await; + let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?; // 2) Create and store an embedding for that chunk let embedding_vec = vec![0.1_f32, 0.2, 0.3]; @@ -191,39 +214,37 @@ mod tests { TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len()) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let _: Option = db .client .create(TextChunkEmbedding::table_name()) .content(emb) .await - .expect("Failed to store embedding") - .take() - .expect("Failed to deserialize stored embedding"); + .with_context(|| "Failed to store embedding".to_string())? 
+ .with_context(|| "Failed to deserialize stored embedding".to_string())?; // 3) Fetch it via get_by_chunk_id let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) .await - .expect("Failed to get embedding by chunk_id"); - - assert!(fetched.is_some(), "Expected an embedding to be found"); - let fetched = fetched.unwrap(); + .with_context(|| "Failed to get embedding by chunk_id".to_string())? + .with_context(|| "Expected an embedding to be found".to_string())?; assert_eq!(fetched.user_id, user_id); assert_eq!(fetched.chunk_id, chunk_rid); assert_eq!(fetched.embedding, embedding_vec); + Ok(()) } #[tokio::test] - async fn test_delete_by_chunk_id() { - let db = setup_test_db().await; + async fn test_delete_by_chunk_id() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "user_b"; let chunk_key = "chunk-delete"; let source_id = "source-del"; - let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await; + let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?; let emb = TextChunkEmbedding::new( chunk_key, @@ -234,50 +255,50 @@ mod tests { TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len()) .await - .expect("Failed to redefine index length"); + .with_context(|| "Failed to redefine index length".to_string())?; let _: Option = db .client .create(TextChunkEmbedding::table_name()) .content(emb) .await - .expect("Failed to store embedding") - .take() - .expect("Failed to deserialize stored embedding"); + .with_context(|| "Failed to store embedding".to_string())? 
+ .with_context(|| "Failed to deserialize stored embedding".to_string())?; // Ensure it exists let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) .await - .expect("Failed to get embedding before delete"); + .with_context(|| "Failed to get embedding before delete".to_string())?; assert!(existing.is_some(), "Embedding should exist before delete"); // Delete by chunk_id TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db) .await - .expect("Failed to delete by chunk_id"); + .with_context(|| "Failed to delete by chunk_id".to_string())?; // Ensure it no longer exists let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) .await - .expect("Failed to get embedding after delete"); + .with_context(|| "Failed to get embedding after delete".to_string())?; assert!(after.is_none(), "Embedding should have been deleted"); + Ok(()) } #[tokio::test] - async fn test_delete_by_source_id() { - let db = setup_test_db().await; + async fn test_delete_by_source_id() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "user_c"; let source_id = "shared-source"; let other_source = "other-source"; // Two chunks with the same source_id - let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await; - let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await; + let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?; + let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?; // One chunk with a different source_id let chunk_other_rid = - create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await; + create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?; // Create embeddings for all three let emb1 = TextChunkEmbedding::new( @@ -302,7 +323,7 @@ mod tests { // Update length on index TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len()) .await - .expect("Failed to redefine index 
length"); + .with_context(|| "Failed to redefine index length".to_string())?; for emb in [emb1, emb2, emb3] { let _: Option = db @@ -310,102 +331,82 @@ mod tests { .create(TextChunkEmbedding::table_name()) .content(emb) .await - .expect("Failed to store embedding") - .take() - .expect("Failed to deserialize stored embedding"); + .with_context(|| "Failed to store embedding".to_string())? + .with_context(|| "Failed to deserialize stored embedding".to_string())?; } // Sanity check: they all exist assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) .await - .unwrap() + .with_context(|| "get chunk1".to_string())? .is_some()); assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db) .await - .unwrap() + .with_context(|| "get chunk2".to_string())? .is_some()); assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) .await - .unwrap() + .with_context(|| "get chunk_other".to_string())? .is_some()); // Delete embeddings by source_id (shared-source) TextChunkEmbedding::delete_by_source_id(source_id, &db) .await - .expect("Failed to delete by source_id"); + .with_context(|| "Failed to delete by source_id".to_string())?; // Chunks from shared-source should have no embeddings assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) .await - .unwrap() + .with_context(|| "check chunk1".to_string())? .is_none()); assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db) .await - .unwrap() + .with_context(|| "check chunk2".to_string())? .is_none()); // The other chunk should still have its embedding assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) .await - .unwrap() + .with_context(|| "check chunk_other".to_string())? 
.is_some()); + Ok(()) } #[tokio::test] - async fn test_redefine_hnsw_index_updates_dimension() { - let db = setup_test_db().await; + async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> { + let db = setup_test_db().await?; // Change the index dimension from default (1536) to a smaller test value. TextChunkEmbedding::redefine_hnsw_index(&db, 8) .await - .expect("failed to redefine index"); + .with_context(|| "failed to redefine index".to_string())?; - let mut info_res = db - .client - .query("INFO FOR TABLE text_chunk_embedding;") - .await - .expect("info query failed"); - let info: SurrealValue = info_res.take(0).expect("failed to take info result"); - let info_json: serde_json::Value = - serde_json::to_value(info).expect("failed to convert info to json"); - let idx_sql = info_json["Object"]["indexes"]["Object"] - ["idx_embedding_text_chunk_embedding"]["Strand"] - .as_str() - .unwrap_or_default(); + let idx_sql = get_idx_sql(&db).await?; assert!( idx_sql.contains("DIMENSION 8"), "expected index definition to contain new dimension, got: {idx_sql}" ); + Ok(()) } #[tokio::test] - async fn test_redefine_hnsw_index_is_idempotent() { - let db = setup_test_db().await; + async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> { + let db = setup_test_db().await?; TextChunkEmbedding::redefine_hnsw_index(&db, 4) .await - .expect("first redefine failed"); + .with_context(|| "first redefine failed".to_string())?; TextChunkEmbedding::redefine_hnsw_index(&db, 4) .await - .expect("second redefine failed"); + .with_context(|| "second redefine failed".to_string())?; - let mut info_res = db - .client - .query("INFO FOR TABLE text_chunk_embedding;") - .await - .expect("info query failed"); - let info: SurrealValue = info_res.take(0).expect("failed to take info result"); - let info_json: serde_json::Value = - serde_json::to_value(info).expect("failed to convert info to json"); - let idx_sql = info_json["Object"]["indexes"]["Object"] - 
["idx_embedding_text_chunk_embedding"]["Strand"] - .as_str() - .unwrap_or_default(); + let idx_sql = get_idx_sql(&db).await?; assert!( idx_sql.contains("DIMENSION 4"), "expected index definition to retain dimension 4, got: {idx_sql}" ); + Ok(()) } } diff --git a/common/src/storage/types/text_content.rs b/common/src/storage/types/text_content.rs index 02f42a8..fab4c0c 100644 --- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -185,10 +185,12 @@ impl TextContent { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; #[tokio::test] - async fn test_text_content_creation() { + async fn test_text_content_creation() -> anyhow::Result<()> { // Test basic object creation let text = "Test content text".to_string(); let context = "Test context".to_string(); @@ -212,10 +214,11 @@ mod tests { assert!(text_content.file_info.is_none()); assert!(text_content.url_info.is_none()); assert!(!text_content.id.is_empty()); + Ok(()) } #[tokio::test] - async fn test_text_content_with_url() { + async fn test_text_content_with_url() -> anyhow::Result<()> { // Test creating with URL let text = "Content with URL".to_string(); let context = "URL context".to_string(); @@ -232,26 +235,27 @@ mod tests { }); let text_content = TextContent::new( - text.clone(), - Some(context.clone()), - category.clone(), + text, + Some(context), + category, None, url_info.clone(), - user_id.clone(), + user_id, ); // Check URL field is set assert_eq!(text_content.url_info, url_info); + Ok(()) } #[tokio::test] - async fn test_text_content_patch() { + async fn test_text_content_patch() -> anyhow::Result<()> { // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; // Create initial text content let initial_text = "Initial 
text".to_string(); @@ -272,7 +276,7 @@ mod tests { let stored: Option = db .store_item(text_content.clone()) .await - .expect("Failed to store text content"); + .with_context(|| "Failed to store text content".to_string())?; assert!(stored.is_some()); // New values for patch @@ -283,31 +287,30 @@ mod tests { // Apply the patch TextContent::patch(&text_content.id, new_context, new_category, new_text, &db) .await - .expect("Failed to patch text content"); + .with_context(|| "Failed to patch text content".to_string())?; // Retrieve the updated content let updated: Option = db .get_item(&text_content.id) .await - .expect("Failed to get updated text content"); - assert!(updated.is_some()); - - let updated_content = updated.unwrap(); + .with_context(|| "Failed to get updated text content".to_string())?; + let updated_content = updated.with_context(|| "expected updated content".to_string())?; // Verify the updates assert_eq!(updated_content.context, Some(new_context.to_string())); assert_eq!(updated_content.category, new_category); assert_eq!(updated_content.text, new_text); assert!(updated_content.updated_at > text_content.updated_at); + Ok(()) } #[tokio::test] - async fn test_has_other_with_file_detects_shared_usage() { + async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; let user_id = "user123".to_string(); let file_info = FileInfo { @@ -340,24 +343,25 @@ mod tests { db.store_item(content_a.clone()) .await - .expect("Failed to store first content"); + .with_context(|| "Failed to store first content".to_string())?; db.store_item(content_b.clone()) .await - .expect("Failed to store second content"); + .with_context(|| "Failed to store second content".to_string())?; let has_other = 
TextContent::has_other_with_file(&file_info.id, &content_a.id, &db) .await - .expect("Failed to check for shared file usage"); + .with_context(|| "Failed to check for shared file usage".to_string())?; assert!(has_other); let _removed: Option = db .delete_item(&content_b.id) .await - .expect("Failed to delete second content"); + .with_context(|| "Failed to delete second content".to_string())?; let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db) .await - .expect("Failed to check shared usage after delete"); + .with_context(|| "Failed to check shared usage after delete".to_string())?; assert!(!has_other_after); + Ok(()) } } diff --git a/common/src/storage/types/user.rs b/common/src/storage/types/user.rs index 64d7574..769d462 100644 --- a/common/src/storage/types/user.rs +++ b/common/src/storage/types/user.rs @@ -723,30 +723,32 @@ impl User { #[cfg(test)] mod tests { + use anyhow::{self, Context}; + use super::*; use crate::storage::types::ingestion_payload::IngestionPayload; use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS}; use std::collections::HashSet; // Helper function to set up a test database with SystemSettings - async fn setup_test_db() -> SurrealDbClient { + async fn setup_test_db() -> anyhow::Result { let namespace = "test_ns"; let database = Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, &database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; db.apply_migrations() .await - .expect("Failed to setup the migrations"); + .with_context(|| "Failed to setup the migrations".to_string())?; - db + Ok(db) } #[tokio::test] - async fn test_user_creation() { + async fn test_user_creation() -> anyhow::Result<()> { // Setup test database - let db = setup_test_db().await; + let db = setup_test_db().await?; // Create a user let email = "test@example.com"; @@ -761,7 +763,7 @@ mod tests { 
"system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; // Verify user properties assert!(!user.id.is_empty()); @@ -774,18 +776,17 @@ mod tests { let retrieved: Option = db .get_item(&user.id) .await - .expect("Failed to retrieve user"); - assert!(retrieved.is_some()); - - let retrieved = retrieved.unwrap(); + .with_context(|| "Failed to retrieve user".to_string())?; + let retrieved = retrieved.with_context(|| "expected user to exist".to_string())?; assert_eq!(retrieved.id, user.id); assert_eq!(retrieved.email, email); + Ok(()) } #[tokio::test] - async fn test_user_authentication() { + async fn test_user_authentication() -> anyhow::Result<()> { // Setup test database - let db = setup_test_db().await; + let db = setup_test_db().await?; // Create a user let email = "auth_test@example.com"; @@ -799,7 +800,7 @@ mod tests { "system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; // Test successful authentication let auth_result = User::authenticate(email, password, &db).await; @@ -812,11 +813,12 @@ mod tests { // Test failed authentication with non-existent user let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await; assert!(nonexistent.is_err()); + Ok(()) } #[tokio::test] - async fn test_get_unfinished_ingestion_tasks_filters_correctly() { - let db = setup_test_db().await; + async fn test_get_unfinished_ingestion_tasks_filters_correctly() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "unfinished_user"; let other_user_id = "other_user"; @@ -830,14 +832,14 @@ mod tests { let created_task = IngestionTask::new(payload.clone(), user_id.to_string()); db.store_item(created_task.clone()) .await - .expect("Failed to store created task"); + .with_context(|| "Failed to store created task".to_string())?; let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string()); 
processing_task.state = TaskState::Processing; processing_task.attempts = 1; db.store_item(processing_task.clone()) .await - .expect("Failed to store processing task"); + .with_context(|| "Failed to store processing task".to_string())?; let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string()); failed_retry_task.state = TaskState::Failed; @@ -845,7 +847,7 @@ mod tests { failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5); db.store_item(failed_retry_task.clone()) .await - .expect("Failed to store retryable failed task"); + .with_context(|| "Failed to store retryable failed task".to_string())?; let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string()); failed_blocked_task.state = TaskState::Failed; @@ -853,13 +855,13 @@ mod tests { failed_blocked_task.error_message = Some("Too many failures".into()); db.store_item(failed_blocked_task.clone()) .await - .expect("Failed to store blocked task"); + .with_context(|| "Failed to store blocked task".to_string())?; let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()); completed_task.state = TaskState::Succeeded; db.store_item(completed_task.clone()) .await - .expect("Failed to store completed task"); + .with_context(|| "Failed to store completed task".to_string())?; let other_payload = IngestionPayload::Text { text: "Other".to_string(), @@ -870,11 +872,11 @@ mod tests { let other_task = IngestionTask::new(other_payload, other_user_id.to_string()); db.store_item(other_task) .await - .expect("Failed to store other user task"); + .with_context(|| "Failed to store other user task".to_string())?; let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db) .await - .expect("Failed to fetch unfinished tasks"); + .with_context(|| "Failed to fetch unfinished tasks".to_string())?; let unfinished_ids: HashSet = unfinished.iter().map(|task| task.id.clone()).collect(); @@ -885,11 +887,12 @@ mod tests { 
assert!(!unfinished_ids.contains(&failed_blocked_task.id)); assert!(!unfinished_ids.contains(&completed_task.id)); assert_eq!(unfinished_ids.len(), 3); + Ok(()) } #[tokio::test] - async fn test_get_all_ingestion_tasks_returns_sorted() { - let db = setup_test_db().await; + async fn test_get_all_ingestion_tasks_returns_sorted() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "archive_user"; let other_user_id = "other_user"; @@ -902,15 +905,15 @@ mod tests { // Oldest task let mut first = IngestionTask::new(payload.clone(), user_id.to_string()); - first.created_at = first.created_at - chrono::Duration::minutes(1); + first.created_at -= chrono::Duration::minutes(1); first.updated_at = first.created_at; first.state = TaskState::Succeeded; - db.store_item(first.clone()).await.expect("store first"); + db.store_item(first.clone()).await.with_context(|| "store first".to_string())?; // Latest task let mut second = IngestionTask::new(payload.clone(), user_id.to_string()); second.state = TaskState::Processing; - db.store_item(second.clone()).await.expect("store second"); + db.store_item(second.clone()).await.with_context(|| "store second".to_string())?; let other_payload = IngestionPayload::Text { text: "Other".to_string(), @@ -919,21 +922,22 @@ mod tests { user_id: other_user_id.to_string(), }; let other_task = IngestionTask::new(other_payload, other_user_id.to_string()); - db.store_item(other_task).await.expect("store other"); + db.store_item(other_task).await.with_context(|| "store other".to_string())?; let tasks = User::get_all_ingestion_tasks(user_id, &db) .await - .expect("fetch all tasks"); + .with_context(|| "fetch all tasks".to_string())?; assert_eq!(tasks.len(), 2); - assert_eq!(tasks[0].id, second.id); // newest first - assert_eq!(tasks[1].id, first.id); + assert_eq!(tasks.first().map(|t| &t.id), Some(&second.id)); // newest first + assert_eq!(tasks.get(1).map(|t| &t.id), Some(&first.id)); + Ok(()) } #[tokio::test] - async fn 
test_find_by_email() { + async fn test_find_by_email() -> anyhow::Result<()> { // Setup test database - let db = setup_test_db().await; + let db = setup_test_db().await?; // Create a user let email = "find_test@example.com"; @@ -947,28 +951,28 @@ mod tests { "system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; // Test finding user by email let found_user = User::find_by_email(email, &db) .await - .expect("Error searching for user"); - assert!(found_user.is_some()); - let found_user = found_user.unwrap(); + .with_context(|| "Error searching for user".to_string())? + .with_context(|| "expected user to exist".to_string())?; assert_eq!(found_user.id, created_user.id); assert_eq!(found_user.email, email); // Test finding non-existent user let not_found = User::find_by_email("nonexistent@example.com", &db) .await - .expect("Error searching for user"); + .with_context(|| "Error searching for user".to_string())?; assert!(not_found.is_none()); + Ok(()) } #[tokio::test] - async fn test_api_key_management() { + async fn test_api_key_management() -> anyhow::Result<()> { // Setup test database - let db = setup_test_db().await; + let db = setup_test_db().await?; // Create a user let email = "apikey_test@example.com"; @@ -982,7 +986,7 @@ mod tests { "system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; // Initially, user should have no API key assert!(user.api_key.is_none()); @@ -990,7 +994,7 @@ mod tests { // Generate API key let api_key = User::set_api_key(&user.id, &db) .await - .expect("Failed to set API key"); + .with_context(|| "Failed to set API key".to_string())?; assert!(!api_key.is_empty()); assert!(api_key.starts_with("sk_")); @@ -998,38 +1002,36 @@ mod tests { let updated_user: Option = db .get_item(&user.id) .await - .expect("Failed to retrieve user"); - assert!(updated_user.is_some()); - let updated_user = 
updated_user.unwrap(); + .with_context(|| "Failed to retrieve user".to_string())?; + let updated_user = updated_user.with_context(|| "expected updated user".to_string())?; assert_eq!(updated_user.api_key, Some(api_key.clone())); // Test finding user by API key let found_user = User::find_by_api_key(&api_key, &db) .await - .expect("Error searching by API key"); - assert!(found_user.is_some()); - let found_user = found_user.unwrap(); + .with_context(|| "Error searching by API key".to_string())? + .with_context(|| "expected user found by api key".to_string())?; assert_eq!(found_user.id, user.id); // Revoke API key User::revoke_api_key(&user.id, &db) .await - .expect("Failed to revoke API key"); + .with_context(|| "Failed to revoke API key".to_string())?; // Verify API key was revoked let revoked_user: Option = db .get_item(&user.id) .await - .expect("Failed to retrieve user"); - assert!(revoked_user.is_some()); - let revoked_user = revoked_user.unwrap(); + .with_context(|| "Failed to retrieve user".to_string())?; + let revoked_user = revoked_user.with_context(|| "expected revoked user".to_string())?; assert!(revoked_user.api_key.is_none()); // Test searching by revoked API key let not_found = User::find_by_api_key(&api_key, &db) .await - .expect("Error searching by API key"); + .with_context(|| "Error searching by API key".to_string())?; assert!(not_found.is_none()); + Ok(()) } #[tokio::test] @@ -1069,9 +1071,9 @@ mod tests { } #[tokio::test] - async fn test_password_update() { + async fn test_password_update() -> anyhow::Result<()> { // Setup test database - let db = setup_test_db().await; + let db = setup_test_db().await?; // Create a user let email = "pwd_test@example.com"; @@ -1086,7 +1088,7 @@ mod tests { "system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; // Authenticate with old password let auth_result = User::authenticate(email, old_password, &db).await; @@ -1095,7 +1097,7 @@ mod tests { 
// Update password User::patch_password(email, new_password, &db) .await - .expect("Failed to update password"); + .with_context(|| "Failed to update password".to_string())?; // Old password should no longer work let old_auth = User::authenticate(email, old_password, &db).await; @@ -1104,10 +1106,11 @@ mod tests { // New password should work let new_auth = User::authenticate(email, new_password, &db).await; assert!(new_auth.is_ok()); + Ok(()) } #[tokio::test] - async fn test_validate_timezone() { + async fn test_validate_timezone() -> anyhow::Result<()> { // Valid timezones should be accepted as-is assert_eq!(validate_timezone("America/New_York"), "America/New_York"); assert_eq!(validate_timezone("Europe/London"), "Europe/London"); @@ -1117,12 +1120,13 @@ mod tests { // Invalid timezones should be replaced with UTC assert_eq!(validate_timezone("Invalid/Timezone"), "UTC"); assert_eq!(validate_timezone("Not_Real"), "UTC"); + Ok(()) } #[tokio::test] - async fn test_timezone_update() { + async fn test_timezone_update() -> anyhow::Result<()> { // Setup test database - let db = setup_test_db().await; + let db = setup_test_db().await?; // Create user with default timezone let email = "timezone_test@example.com"; @@ -1134,7 +1138,7 @@ mod tests { "system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; assert_eq!(user.timezone, "UTC"); @@ -1142,58 +1146,61 @@ mod tests { let new_timezone = "Europe/Paris"; User::update_timezone(&user.id, new_timezone, &db) .await - .expect("Failed to update timezone"); + .with_context(|| "Failed to update timezone".to_string())?; // Verify timezone was updated let updated_user: Option = db .get_item(&user.id) .await - .expect("Failed to retrieve user"); - assert!(updated_user.is_some()); - let updated_user = updated_user.unwrap(); + .with_context(|| "Failed to retrieve user".to_string())?; + let updated_user = updated_user.with_context(|| "expected updated 
user".to_string())?; assert_eq!(updated_user.timezone, new_timezone); + Ok(()) } #[tokio::test] - async fn test_conversations_order() { - let db = setup_test_db().await; + async fn test_conversations_order() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "user_order_test"; // Create conversations with varying updated_at timestamps let mut conversations = Vec::new(); for i in 0..5 { - let mut conv = Conversation::new(user_id.to_string(), format!("Conv {}", i)); + let mut conv = Conversation::new(user_id.to_string(), format!("Conv {i}")); // Fake updated_at i minutes apart conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i); db.store_item(conv.clone()) .await - .expect("Failed to store conversation"); + .with_context(|| "Failed to store conversation".to_string())?; conversations.push(conv); } // Retrieve via get_user_conversations - should be ordered by updated_at DESC let retrieved = User::get_user_conversations(user_id, &db) .await - .expect("Failed to get conversations"); + .with_context(|| "Failed to get conversations".to_string())?; assert_eq!(retrieved.len(), conversations.len()); - for window in retrieved.windows(2) { - // Assert each earlier conversation has updated_at >= later conversation + for pair in retrieved.windows(2) { + let a = pair.first().context("expected first in pair")?; + let b = pair.get(1).context("expected second in pair")?; assert!( - window[0].created_at >= window[1].created_at, + a.created_at >= b.created_at, "Conversations not ordered descending by created_at" ); } // Check first conversation title matches the most recently updated - let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap(); - assert_eq!(retrieved[0].id, most_recent.id); + let most_recent = conversations.iter().max_by_key(|c| c.created_at).context("expected most recent")?; + let r0 = retrieved.first().context("expected first result")?; + assert_eq!(r0.id, most_recent.id); + Ok(()) } #[tokio::test] - async fn 
test_get_latest_text_contents_returns_last_five() { - let db = setup_test_db().await; + async fn test_get_latest_text_contents_returns_last_five() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "latest_text_user"; let mut inserted_ids = Vec::new(); @@ -1201,8 +1208,8 @@ mod tests { for i in 0..12 { let mut item = TextContent::new( - format!("Text {}", i), - Some(format!("Context {}", i)), + format!("Text {i}"), + Some(format!("Context {i}")), "Category".to_string(), None, None, @@ -1215,18 +1222,19 @@ mod tests { db.store_item(item.clone()) .await - .expect("Failed to store text content"); + .with_context(|| "Failed to store text content".to_string())?; inserted_ids.push(item.id.clone()); } let latest = User::get_latest_text_contents(user_id, &db) .await - .expect("Failed to fetch latest text contents"); + .with_context(|| "Failed to fetch latest text contents".to_string())?; assert_eq!(latest.len(), 5, "Expected exactly five items"); - let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec(); + let start = inserted_ids.len().saturating_sub(5); + let mut expected_ids = inserted_ids.get(start..).unwrap_or_default().to_vec(); expected_ids.reverse(); let returned_ids: Vec = latest.iter().map(|item| item.id.clone()).collect(); @@ -1235,25 +1243,29 @@ mod tests { "Latest items did not match expectation" ); - for window in latest.windows(2) { + for pair in latest.windows(2) { + let a = pair.first().context("expected first in pair")?; + let b = pair.get(1).context("expected second in pair")?; assert!( - window[0].created_at >= window[1].created_at, + a.created_at >= b.created_at, "Results are not ordered by created_at descending" ); } + Ok(()) } #[tokio::test] - async fn test_validate_theme() { + async fn test_validate_theme() -> anyhow::Result<()> { assert_eq!(validate_theme("light"), Theme::Light); assert_eq!(validate_theme("dark"), Theme::Dark); assert_eq!(validate_theme("system"), Theme::System); 
assert_eq!(validate_theme("invalid"), Theme::System); + Ok(()) } #[tokio::test] - async fn test_theme_update() { - let db = setup_test_db().await; + async fn test_theme_update() -> anyhow::Result<()> { + let db = setup_test_db().await?; let email = "theme_test@example.com"; let user = User::create_new( email.to_string(), @@ -1263,30 +1275,31 @@ mod tests { "system".to_string(), ) .await - .expect("Failed to create user"); + .with_context(|| "Failed to create user".to_string())?; assert_eq!(user.theme, Theme::System); User::update_theme(&user.id, "dark", &db) .await - .expect("update theme"); + .with_context(|| "update theme".to_string())?; let updated = db .get_item::(&user.id) .await - .expect("get user") - .unwrap(); + .with_context(|| "get user".to_string())? + .with_context(|| "expected user".to_string())?; assert_eq!(updated.theme, Theme::Dark); // Invalid theme should default to system (but update_theme calls validate_theme) User::update_theme(&user.id, "invalid", &db) .await - .expect("update theme invalid"); + .with_context(|| "update theme invalid".to_string())?; let updated2 = db .get_item::(&user.id) .await - .expect("get user") - .unwrap(); + .with_context(|| "get user".to_string())? + .with_context(|| "expected user".to_string())?; assert_eq!(updated2.theme, Theme::System); + Ok(()) } } diff --git a/common/src/utils/config.rs b/common/src/utils/config.rs index 26b761c..95aa13d 100644 --- a/common/src/utils/config.rs +++ b/common/src/utils/config.rs @@ -28,8 +28,8 @@ fn default_storage_kind() -> StorageKind { StorageKind::Local } -fn default_s3_region() -> Option { - Some("us-east-1".to_string()) +fn default_s3_region() -> String { + "us-east-1".to_string() } /// Selects the strategy used for PDF ingestion. 
@@ -69,7 +69,7 @@ pub struct AppConfig { #[serde(default)] pub s3_endpoint: Option, #[serde(default = "default_s3_region")] - pub s3_region: Option, + pub s3_region: String, #[serde(default = "default_pdf_ingest_mode")] pub pdf_ingest_mode: PdfIngestMode, #[serde(default = "default_reranking_enabled")] diff --git a/evaluations/src/corpus/orchestrator.rs b/evaluations/src/corpus/orchestrator.rs index 413b143..38a9e44 100644 --- a/evaluations/src/corpus/orchestrator.rs +++ b/evaluations/src/corpus/orchestrator.rs @@ -14,7 +14,7 @@ use common::utils::config::get_config; use common::{ storage::{ db::SurrealDbClient, - store::{DynStore, StorageManager}, + store::{DynStorage, StorageManager}, types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject}, }, utils::config::{AppConfig, StorageKind}, @@ -432,7 +432,7 @@ async fn ingest_paragraph_batch( storage: StorageKind::Memory, ..Default::default() }; - let backend: DynStore = Arc::new(InMemory::new()); + let backend: DynStorage = Arc::new(InMemory::new()); let storage = StorageManager::with_backend(backend, StorageKind::Memory); let pipeline_config = ingestion_config.clone(); diff --git a/evaluations/src/corpus/store.rs b/evaluations/src/corpus/store.rs index 5f31cf9..d2d841b 100644 --- a/evaluations/src/corpus/store.rs +++ b/evaluations/src/corpus/store.rs @@ -861,7 +861,7 @@ mod tests { let question = CorpusQuestion { question_id: "q1".to_string(), paragraph_id: paragraph_one.paragraph_id.clone(), - text_content_id: text_content_id, + text_content_id, question_text: "What is this?".to_string(), answers: vec!["Hello".to_string()], is_impossible: false, diff --git a/evaluations/src/db_helpers.rs b/evaluations/src/db_helpers.rs index c703631..47a426e 100644 --- a/evaluations/src/db_helpers.rs +++ b/evaluations/src/db_helpers.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Result}; -use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes}; +use 
common::storage::{db::SurrealDbClient, indexes::ensure_runtime}; use tracing::info; // Helper functions for index management during namespace reseed @@ -11,7 +11,7 @@ pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> { pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> { info!("Recreating ALL indexes after namespace reseed via shared runtime helper"); - ensure_runtime_indexes(db, dimension) + ensure_runtime(db, dimension) .await .context("creating runtime indexes") } diff --git a/evaluations/src/pipeline/context.rs b/evaluations/src/pipeline/context.rs index 4bf02ed..0e12c45 100644 --- a/evaluations/src/pipeline/context.rs +++ b/evaluations/src/pipeline/context.rs @@ -13,7 +13,7 @@ use common::{ utils::embedding::EmbeddingProvider, }; use retrieval_pipeline::{ - pipeline::{PipelineStageTimings, RetrievalConfig}, + pipeline::{StageTimings, RetrievalConfig}, reranking::RerankerPool, }; @@ -56,7 +56,7 @@ pub(super) struct EvaluationContext<'a> { pub corpus_handle: Option, pub cases: Vec, pub filtered_questions: usize, - pub stage_latency_samples: Vec, + pub stage_latency_samples: Vec, pub latencies: Vec, pub diagnostics_output: Vec, pub query_summaries: Vec, diff --git a/evaluations/src/pipeline/stages/run_queries.rs b/evaluations/src/pipeline/stages/run_queries.rs index bb25523..04d5d88 100644 --- a/evaluations/src/pipeline/stages/run_queries.rs +++ b/evaluations/src/pipeline/stages/run_queries.rs @@ -10,7 +10,7 @@ use crate::eval::{ CaseSummary, RetrievedSummary, }; use retrieval_pipeline::{ - pipeline::{self, PipelineStageTimings, RetrievalConfig}, + pipeline::{self, StageTimings, RetrievalConfig}, reranking::RerankerPool, }; use tokio::sync::Semaphore; @@ -75,10 +75,10 @@ pub(crate) async fn run_queries( retrieval_config.tuning.chunk_rrf_fts_weight = value; } if let Some(value) = config.retrieval.chunk_rrf_use_vector { - retrieval_config.tuning.chunk_rrf_use_vector = value; + 
retrieval_config.tuning.flags.chunk_rrf_use_vector = value.into(); } if let Some(value) = config.retrieval.chunk_rrf_use_fts { - retrieval_config.tuning.chunk_rrf_use_fts = value; + retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into(); } if let Some(value) = config.retrieval.chunk_avg_chars_per_token { retrieval_config.tuning.avg_chars_per_token = value; @@ -113,8 +113,8 @@ pub(crate) async fn run_queries( chunk_rrf_k = active_tuning.chunk_rrf_k, chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight, chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight, - chunk_rrf_use_vector = active_tuning.chunk_rrf_use_vector, - chunk_rrf_use_fts = active_tuning.chunk_rrf_use_fts, + chunk_rrf_use_vector = active_tuning.flags.chunk_rrf_use_vector.as_bool(), + chunk_rrf_use_fts = active_tuning.flags.chunk_rrf_use_fts.as_bool(), embedding_backend = ctx.embedding_provider().backend_label(), embedding_model = ctx .embedding_provider() @@ -181,35 +181,32 @@ pub(crate) async fn run_queries( embedding_provider.embed(&question).await.with_context(|| { format!("generating embedding for question {}", question_id) })?; - let reranker = match &rerank_pool { - Some(pool) => Some(pool.checkout().await), + let reranker = match rerank_pool.as_ref() { + Some(pool) => pool.checkout().await, None => None, }; + let params = pipeline::StrategyParams { + db_client: &db, + openai_client: &openai_client, + embedding_provider: Some(&embedding_provider), + input_text: &question, + user_id: &user_id, + config: (*retrieval_config).clone(), + reranker, + }; let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled { let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics( - &db, - &openai_client, - Some(&embedding_provider), + params, query_embedding, - &question, - &user_id, - (*retrieval_config).clone(), - reranker, ) .await .with_context(|| format!("running pipeline for question {}", question_id))?; (outcome.results, outcome.diagnostics, 
outcome.stage_timings) } else { let outcome = pipeline::run_pipeline_with_embedding_with_metrics( - &db, - &openai_client, - Some(&embedding_provider), + params, query_embedding, - &question, - &user_id, - (*retrieval_config).clone(), - reranker, ) .await .with_context(|| format!("running pipeline for question {}", question_id))?; @@ -327,7 +324,7 @@ pub(crate) async fn run_queries( usize, CaseSummary, Option, - PipelineStageTimings, + StageTimings, ), anyhow::Error, >((idx, summary, diagnostics, stage_timings)) diff --git a/evaluations/src/pipeline/stages/summarize.rs b/evaluations/src/pipeline/stages/summarize.rs index f3d03b3..a17d64f 100644 --- a/evaluations/src/pipeline/stages/summarize.rs +++ b/evaluations/src/pipeline/stages/summarize.rs @@ -205,8 +205,8 @@ pub(crate) async fn summarize( chunk_rrf_k: active_tuning.chunk_rrf_k, chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight, chunk_rrf_fts_weight: active_tuning.chunk_rrf_fts_weight, - chunk_rrf_use_vector: active_tuning.chunk_rrf_use_vector, - chunk_rrf_use_fts: active_tuning.chunk_rrf_use_fts, + chunk_rrf_use_vector: active_tuning.flags.chunk_rrf_use_vector.as_bool(), + chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(), ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens, ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens, ingest_chunks_only: config.ingest.ingest_chunks_only, diff --git a/evaluations/src/slice.rs b/evaluations/src/slice.rs index 41ce029..270e7f1 100644 --- a/evaluations/src/slice.rs +++ b/evaluations/src/slice.rs @@ -1037,6 +1037,31 @@ fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> { Ok(()) } +use crate::args::Config; + +impl<'a> From<&'a Config> for SliceConfig<'a> { + fn from(config: &'a Config) -> Self { + slice_config_with_limit(config, None) + } +} + +pub fn slice_config_with_limit<'a>( + config: &'a Config, + limit_override: Option, +) -> SliceConfig<'a> { + SliceConfig { + cache_dir: 
config.cache_dir.as_path(), + force_convert: config.force_convert, + explicit_slice: config.slice.as_deref(), + limit: limit_override.or(config.limit), + corpus_limit: config.corpus_limit, + slice_seed: config.slice_seed, + llm_mode: config.llm_mode, + negative_multiplier: config.negative_multiplier, + require_verified_chunks: config.retrieval.require_verified_chunks, + } +} + #[cfg(test)] mod tests { use super::*; @@ -1214,30 +1239,3 @@ mod tests { Ok(()) } } - -// MARK: - Config integration (merged from slice.rs) - -use crate::args::Config; - -impl<'a> From<&'a Config> for SliceConfig<'a> { - fn from(config: &'a Config) -> Self { - slice_config_with_limit(config, None) - } -} - -pub fn slice_config_with_limit<'a>( - config: &'a Config, - limit_override: Option, -) -> SliceConfig<'a> { - SliceConfig { - cache_dir: config.cache_dir.as_path(), - force_convert: config.force_convert, - explicit_slice: config.slice.as_deref(), - limit: limit_override.or(config.limit), - corpus_limit: config.corpus_limit, - slice_seed: config.slice_seed, - llm_mode: config.llm_mode, - negative_multiplier: config.negative_multiplier, - require_verified_chunks: config.retrieval.require_verified_chunks, - } -} diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index 084f9f2..c4bcdee 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -39,30 +39,33 @@ const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30); const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024; const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64; +pub struct StateResources { + pub db: Arc, + pub openai_client: Arc, + pub session_store: Arc, + pub storage: StorageManager, + pub config: AppConfig, + pub reranker_pool: Option>, + pub embedding_provider: Arc, + pub template_engine: Option>, +} + impl HtmlState { - pub async fn new_with_resources( - db: Arc, - openai_client: Arc, - session_store: Arc, - storage: StorageManager, - config: 
AppConfig, - reranker_pool: Option>, - embedding_provider: Arc, - template_engine: Option>, - ) -> Self { - let templates = - template_engine.unwrap_or_else(|| Arc::new(create_template_engine!("templates"))); + pub fn new_with_resources(resources: StateResources) -> Self { + let templates = resources + .template_engine + .unwrap_or_else(|| Arc::new(create_template_engine!("templates"))); debug!("Template engine configured for html_router."); Self { - db, - openai_client, - session_store, + db: resources.db, + openai_client: resources.openai_client, templates, - config, - storage, - reranker_pool, - embedding_provider, + session_store: resources.session_store, + config: resources.config, + storage: resources.storage, + reranker_pool: resources.reranker_pool, + embedding_provider: resources.embedding_provider, conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())), conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)), } @@ -210,18 +213,16 @@ mod tests { EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"), ); - HtmlState::new_with_resources( + HtmlState::new_with_resources(StateResources { db, - Arc::new(async_openai::Client::new()), + openai_client: Arc::new(async_openai::Client::new()), session_store, storage, config, - None, + reranker_pool: None, embedding_provider, - None, - ) - .await - .expect("Failed to create HtmlState") + template_engine: None, + }) } #[tokio::test] diff --git a/html-router/src/middlewares/compression.rs b/html-router/src/middlewares/compression.rs index b4aad1d..a8aa825 100644 --- a/html-router/src/middlewares/compression.rs +++ b/html-router/src/middlewares/compression.rs @@ -2,6 +2,6 @@ use tower_http::compression::CompressionLayer; /// Provides a default compression layer that negotiates encoding based on the /// `Accept-Encoding` header of the incoming request. 
-pub fn compression_layer() -> CompressionLayer { +pub fn layer() -> CompressionLayer { CompressionLayer::new() } diff --git a/html-router/src/middlewares/response_middleware.rs b/html-router/src/middlewares/response_middleware.rs index 5ec5691..7374005 100644 --- a/html-router/src/middlewares/response_middleware.rs +++ b/html-router/src/middlewares/response_middleware.rs @@ -10,7 +10,7 @@ use axum::{ use axum_htmx::{HxRequest, HX_TRIGGER}; use common::{ error::AppError, - utils::template_engine::{ProvidesTemplateEngine, Value}, + utils::template_engine::{ProvidesTemplateEngine, TemplateEngine, Value}, }; use minijinja::context; use serde::Serialize; @@ -146,6 +146,40 @@ struct ContextWrapper<'a> { context: HashMap, } +const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"]; + +fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) { + for &header_name in HTMX_HEADERS_TO_FORWARD { + if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) { + if let Some(value) = from.get(&name) { + to.insert(name.clone(), value.clone()); + } + } + } +} + +fn context_to_map( + value: &Value, +) -> Result, minijinja::value::ValueKind> { + match value.kind() { + minijinja::value::ValueKind::Map => { + let mut map = HashMap::new(); + if let Ok(keys) = value.try_iter() { + for key in keys { + if let Ok(val) = value.get_item(&key) { + map.insert(key.to_string(), val); + } + } + } + Ok(map) + } + minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => { + Ok(HashMap::new()) + } + other => Err(other), + } +} + pub async fn with_template_response( State(state): State, HxRequest(is_htmx): HxRequest, @@ -158,14 +192,12 @@ where let mut user_theme = Theme::System.as_str(); let mut initial_theme = Theme::System.initial_theme(); let mut is_authenticated = false; - let mut current_user_id = None; let mut current_user = None; { if let Some(auth) = req.extensions().get::() { if let Some(user) = &auth.current_user { 
is_authenticated = true; - current_user_id = Some(user.id.clone()); user_theme = user.theme.as_str(); initial_theme = user.theme.initial_theme(); current_user = Some(TemplateUser::from(user)); @@ -175,9 +207,6 @@ where let response = next.run(req).await; - // Headers to forward from the original response - const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"]; - if let Some(template_response) = response.extensions().get::().cloned() { let template_engine = state.template_engine(); @@ -187,56 +216,23 @@ where matches!(&template_response.template_kind, TemplateKind::Full(_)); if should_load_conversation_archive { - if let Some(user_id) = current_user_id { + if let Some(user_id) = current_user.as_ref().map(|u| &u.id) { let html_state = state.html_state(); if let Some(cached_archive) = - html_state.get_cached_conversation_archive(&user_id).await + html_state.get_cached_conversation_archive(user_id).await { conversation_archive = cached_archive; } else if let Ok(archive) = - Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await + Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await { html_state - .set_cached_conversation_archive(&user_id, archive.clone()) + .set_cached_conversation_archive(user_id, archive.clone()) .await; conversation_archive = archive; } } } - fn context_to_map( - value: &Value, - ) -> Result, minijinja::value::ValueKind> { - match value.kind() { - minijinja::value::ValueKind::Map => { - let mut map = HashMap::new(); - if let Ok(keys) = value.try_iter() { - for key in keys { - if let Ok(val) = value.get_item(&key) { - map.insert(key.to_string(), val); - } - } - } - Ok(map) - } - minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => { - Ok(HashMap::new()) - } - other => Err(other), - } - } - - // Helper to forward relevant headers - fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) { - for &header_name in 
HTMX_HEADERS_TO_FORWARD { - if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) { - if let Some(value) = from.get(&name) { - to.insert(name.clone(), value.clone()); - } - } - } - } - let context_map = match context_to_map(&template_response.context) { Ok(map) => map, Err(kind) => { @@ -290,18 +286,17 @@ where } TemplateKind::Error(status) => { if is_htmx { - // HTMX request: Send 204 + HX-Trigger for toast let title = template_response .context .get_attr("title") .ok() - .and_then(|v| v.as_str().map(String::from)) + .and_then(|v| v.as_str().map(|s| s.to_string())) .unwrap_or_else(|| "Error".to_string()); let description = template_response .context .get_attr("description") .ok() - .and_then(|v| v.as_str().map(String::from)) + .and_then(|v| v.as_str().map(|s| s.to_string())) .unwrap_or_else(|| "An error occurred.".to_string()); let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}}); @@ -312,14 +307,12 @@ where }); (StatusCode::NO_CONTENT, [(HX_TRIGGER, trigger_value)], "").into_response() } else { - // Non-HTMX request: Render the full errors/error.html page match template_engine .render("errors/error.html", &Value::from_serialize(&context)) { Ok(html) => (*status, Html(html)).into_response(), Err(e) => { error!("Critical: Failed to render 'errors/error.html': {:?}", e); - // Fallback HTML, but use the intended status code (*status, Html(fallback_error())).into_response() } } diff --git a/html-router/src/router_factory.rs b/html-router/src/router_factory.rs index eccb011..3e54c86 100644 --- a/html-router/src/router_factory.rs +++ b/html-router/src/router_factory.rs @@ -9,7 +9,7 @@ use crate::{ html_state::HtmlState, middlewares::{ analytics_middleware::analytics_middleware, auth_middleware::require_auth, - compression::compression_layer, response_middleware::with_template_response, + compression, response_middleware::with_template_response, }, }; @@ -71,6 +71,7 @@ where } // Add a serving of assets + 
#[must_use] pub fn with_public_assets(mut self, path: &str, directory: &str) -> Self { self.public_assets_config = Some(AssetsConfig { path: path.to_string(), @@ -80,24 +81,28 @@ where } // Add a public router that will be merged at the root level + #[must_use] pub fn add_public_routes(mut self, routes: Router) -> Self { self.public_routers.push(routes); self } // Add a protected router that will be merged at the root level + #[must_use] pub fn add_protected_routes(mut self, routes: Router) -> Self { self.protected_routers.push(routes); self } // Nest a public router under a path prefix + #[must_use] pub fn nest_public_routes(mut self, path: &str, routes: Router) -> Self { self.nested_routes.push((path.to_string(), routes)); self } // Nest a protected router under a path prefix + #[must_use] pub fn nest_protected_routes(mut self, path: &str, routes: Router) -> Self { self.nested_protected_routes .push((path.to_string(), routes)); @@ -105,6 +110,7 @@ where } // Add custom middleware to be applied before the standard ones + #[must_use] pub fn with_middleware(mut self, middleware_fn: F) -> Self where F: FnOnce(Router) -> Router + Send + 'static, @@ -114,6 +120,7 @@ where } /// Enables response compression when building the router. 
+ #[must_use] pub const fn with_compression(mut self) -> Self { self.compression_enabled = true; self @@ -191,7 +198,7 @@ where // Apply Global Middleware (Compression) if self.compression_enabled { - final_router = final_router.layer(compression_layer()); + final_router = final_router.layer(compression::layer()); } final_router diff --git a/html-router/src/routes/account/handlers.rs b/html-router/src/routes/account/handlers.rs index 99100fa..f189b3c 100644 --- a/html-router/src/routes/account/handlers.rs +++ b/html-router/src/routes/account/handlers.rs @@ -62,7 +62,7 @@ pub async fn set_api_key( let api_key = User::set_api_key(&user.id, &state.db).await?; // Clear the cache so new requests have access to the user with api key - auth.cache_clear_user(user.id.to_string()); + auth.cache_clear_user(user.id.clone()); // Render the API key section block Ok(TemplateResponse::new_partial( @@ -106,7 +106,7 @@ pub async fn update_timezone( User::update_timezone(&user.id, &form.timezone, &state.db).await?; // Clear the cache - auth.cache_clear_user(user.id.to_string()); + auth.cache_clear_user(user.id.clone()); let timezones = TZ_VARIANTS .iter() @@ -141,7 +141,7 @@ pub async fn update_theme( User::update_theme(&user.id, &form.theme, &state.db).await?; // Clear the cache - auth.cache_clear_user(user.id.to_string()); + auth.cache_clear_user(user.id.clone()); let theme_options = vec![ Theme::Light.as_str().to_string(), diff --git a/html-router/src/routes/admin/handlers.rs b/html-router/src/routes/admin/handlers.rs index 1355f30..f6cb9e6 100644 --- a/html-router/src/routes/admin/handlers.rs +++ b/html-router/src/routes/admin/handlers.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_openai::types::ListModelResponse; use axum::{ extract::{Query, State}, @@ -37,18 +39,14 @@ pub struct AdminPanelData { current_section: AdminSection, } -#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Default)] #[serde(rename_all = 
"snake_case")] pub enum AdminSection { + #[default] Overview, Models, } -impl Default for AdminSection { - fn default() -> Self { - Self::Overview - } -} #[derive(Deserialize)] pub struct AdminPanelQuery { @@ -107,10 +105,7 @@ fn checkbox_to_bool<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - match String::deserialize(deserializer) { - Ok(string) => Ok(string == "on"), - Err(_) => Ok(false), - } + String::deserialize(deserializer).map(|s| s == "on") } #[derive(Deserialize)] @@ -219,8 +214,8 @@ pub async fn update_model_settings( if reembedding_needed { info!("Embedding dimensions changed. Spawning background re-embedding task..."); - let db_for_task = state.db.clone(); - let openai_for_task = state.openai_client.clone(); + let db_for_task = Arc::clone(&state.db); + let openai_for_task = Arc::clone(&state.openai_client); let new_model_for_task = new_settings.embedding_model.clone(); let new_dims_for_task = new_settings.embedding_dimensions; diff --git a/html-router/src/routes/auth/signup.rs b/html-router/src/routes/auth/signup.rs index ba0a538..c08ae3c 100644 --- a/html-router/src/routes/auth/signup.rs +++ b/html-router/src/routes/auth/signup.rs @@ -11,7 +11,7 @@ use crate::{ }; #[derive(Deserialize, Serialize)] -pub struct SignupParams { +pub struct Params { pub email: String, pub password: String, pub timezone: String, @@ -39,7 +39,7 @@ pub async fn show_signup_form( pub async fn process_signup_and_show_verification( State(state): State, auth: AuthSessionType, - Form(form): Form, + Form(form): Form, ) -> Result { let user = match User::create_new( form.email, diff --git a/html-router/src/routes/chat/chat_handlers.rs b/html-router/src/routes/chat/chat_handlers.rs index c45ec38..aca01b9 100644 --- a/html-router/src/routes/chat/chat_handlers.rs +++ b/html-router/src/routes/chat/chat_handlers.rs @@ -49,6 +49,8 @@ pub struct ChatPageData { conversation: Option, } +/// # Panics +/// Panics if the HX-Push header value cannot be parsed. 
pub async fn show_initialized_chat( State(state): State, RequireUser(user): RequireUser, @@ -57,14 +59,14 @@ pub async fn show_initialized_chat( let conversation = Conversation::new(user.id.clone(), "Test".to_owned()); let user_message = Message::new( - conversation.id.to_string(), + conversation.id.clone(), MessageRole::User, form.user_query, None, ); let ai_message = Message::new( - conversation.id.to_string(), + conversation.id.clone(), MessageRole::AI, form.llm_response, Some(form.references), @@ -86,10 +88,9 @@ pub async fn show_initialized_chat( ) .into_response(); - response.headers_mut().insert( - "HX-Push", - HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), - ); + if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) { + response.headers_mut().insert("HX-Push", header_value); + } Ok(response) } @@ -130,12 +131,19 @@ pub async fn show_existing_chat( )) } +/// # Panics +/// Panics if the HX-Push header value cannot be parsed. pub async fn new_user_message( Path(conversation_id): Path, State(state): State, RequireUser(user): RequireUser, Form(form): Form, ) -> Result { + #[derive(Serialize)] + struct SSEResponseInitData { + user_message: Message, + } + let conversation: Conversation = state .db .get_item(&conversation_id) @@ -150,33 +158,34 @@ pub async fn new_user_message( state.db.store_item(user_message.clone()).await?; - #[derive(Serialize)] - struct SSEResponseInitData { - user_message: Message, - } - let mut response = TemplateResponse::new_template( "chat/streaming_response.html", SSEResponseInitData { user_message }, ) .into_response(); - response.headers_mut().insert( - "HX-Push", - HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), - ); + if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) { + response.headers_mut().insert("HX-Push", header_value); + } Ok(response) } +/// # Panics +/// Panics if the HX-Push header value cannot be parsed. 
pub async fn new_chat_user_message( State(state): State, auth: AuthSession, Surreal>, Form(form): Form, ) -> Result { - let user = match auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/").into_response()), + #[derive(Serialize)] + struct SSEResponseInitData { + user_message: Message, + conversation: Conversation, + } + + let Some(user) = auth.current_user else { + return Ok(Redirect::to("/").into_response()); }; let conversation = Conversation::new(user.id.clone(), "New chat".to_string()); @@ -191,11 +200,6 @@ pub async fn new_chat_user_message( state.db.store_item(user_message.clone()).await?; state.invalidate_conversation_archive_cache(&user.id).await; - #[derive(Serialize)] - struct SSEResponseInitData { - user_message: Message, - conversation: Conversation, - } let mut response = TemplateResponse::new_template( "chat/new_chat_first_response.html", SSEResponseInitData { @@ -205,10 +209,9 @@ pub async fn new_chat_user_message( ) .into_response(); - response.headers_mut().insert( - "HX-Push", - HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), - ); + if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) { + response.headers_mut().insert("HX-Push", header_value); + } Ok(response.into_response()) } diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index c6e9aca..e54b16d 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -53,26 +53,22 @@ fn sse_with_keep_alive(stream: EventStream) -> SseResponse { ) } -// Error handling function fn create_error_stream(message: impl Into) -> EventStream { let message = message.into(); stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed() } -// Helper function to get message and user async fn get_message_and_user( db: &SurrealDbClient, current_user: Option, message_id: 
&str, ) -> Result<(Message, User, Conversation, Vec, Option), SseResponse> { - // Check authentication let Some(user) = current_user else { return Err(sse_with_keep_alive(create_error_stream( "You must be signed in to use this feature", ))); }; - // Retrieve message let message = match db.get_item::(message_id).await { Ok(Some(message)) => message, Ok(None) => { @@ -88,7 +84,6 @@ async fn get_message_and_user( } }; - // Get conversation history let (conversation, history) = match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await { @@ -209,7 +204,6 @@ pub async fn get_response_stream( auth: AuthSessionType, Query(params): Query, ) -> SseResponse { - // 1. Authentication and initial data validation let (user_message, user, _conversation, history, existing_ai_response) = match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await { Ok((user_message, user, conversation, history, existing_ai_response)) => ( @@ -226,9 +220,123 @@ pub async fn get_response_stream( return create_replayed_response_stream(&state, existing_ai_message); } - // 2. 
Retrieve knowledge entities + let (request, allowed_reference_ids) = match prepare_chat_request(&state, &user_message, &user, &history).await { + Ok(result) => result, + Err(sse) => return sse, + }; + + let openai_stream = match state.openai_client.chat().create_stream(request).await { + Ok(stream) => stream, + Err(_e) => { + return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream")); + } + }; + + build_chat_event_stream(state, openai_stream, &user_message, user.id.clone(), allowed_reference_ids) +} + +fn build_chat_event_stream( + state: HtmlState, + openai_stream: impl Stream> + Send + 'static, + user_message: &Message, + user_id: String, + allowed_reference_ids: Vec, +) -> SseResponse { + let (tx, rx) = channel::(1000); + let (tx_final, mut rx_final) = channel::(1); + + spawn_storage_task(Arc::clone(&state.db), rx, tx_final, user_message, user_id, allowed_reference_ids); + + let json_state = Arc::new(Mutex::new(StreamParserState::new())); + + let event_stream = openai_stream + .map_err(|e| Box::new(e) as Box) + .map(move |result| { + let tx_storage = tx.clone(); + let json_state = Arc::clone(&json_state); + + stream! 
{ + match result { + Ok(response) => { + let content = response + .choices + .first() + .and_then(|choice| choice.delta.content.clone()) + .unwrap_or_default(); + + if !content.is_empty() { + let _ = tx_storage.send(content.clone()).await; + + let mut state = json_state.lock().await; + let display_content = state.process_chunk(&content); + drop(state); + if !display_content.is_empty() { + yield Ok(Event::default() + .event("chat_message") + .data(display_content)); + } + } + } + Err(e) => { + yield Ok(Event::default() + .event("error") + .data(format!("Stream error: {e}"))); + } + } + } + }) + .flatten() + .chain(stream::once(async move { + #[derive(Serialize)] + struct LocalReferenceData { + message: Message, + } + + if let Some(message) = rx_final.recv().await { + if message + .references + .as_ref() + .is_some_and(std::vec::Vec::is_empty) + { + return Ok(Event::default().event("empty")); + } + + match state.templates.render( + "chat/reference_list.html", + &Value::from_serialize(LocalReferenceData { message }), + ) { + Ok(html) => Ok(Event::default().event("references").data(html)), + Err(_) => Ok(Event::default() + .event("error") + .data("Failed to render references")), + } + } else { + Ok(Event::default() + .event("error") + .data("Failed to retrieve references")) + } + })) + .chain(once(async { + Ok(Event::default() + .event("close_stream") + .data("Stream complete")) + })) + .boxed(); + + sse_with_keep_alive(event_stream) +} + +async fn prepare_chat_request( + state: &HtmlState, + user_message: &Message, + user: &User, + history: &[Message], +) -> Result< + (async_openai::types::CreateChatCompletionRequest, Vec), + Sse> + Send>>>, +> { let rerank_lease = match state.reranker_pool.as_ref() { - Some(pool) => Some(pool.checkout().await), + Some(pool) => pool.checkout().await, None => None, }; @@ -248,59 +356,49 @@ pub async fn get_response_stream( { Ok(result) => result, Err(_e) => { - return sse_with_keep_alive(create_error_stream("Failed to retrieve 
knowledge")); + return Err(Sse::new(create_error_stream("Failed to retrieve knowledge"))); } }; let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result); - // 3. Create the OpenAI request with appropriate context format - let context_json = match &retrieval_result { - retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks), + let context_json = match retrieval_result { + retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks), retrieval_pipeline::StrategyOutput::Entities(entities) => { retrieved_entities_to_json(entities) } retrieval_pipeline::StrategyOutput::Search(search_result) => { - // For chat, use chunks from the search result chunks_to_chat_context(&search_result.chunks) } }; let formatted_user_message = - create_user_message_with_history(&context_json, &history, &user_message.content); + create_user_message_with_history(&context_json, history, &user_message.content); let Ok(settings) = SystemSettings::get_current(&state.db).await else { - return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings")); + return Err(Sse::new(create_error_stream("Failed to retrieve system settings"))); }; let Ok(request) = create_chat_request(formatted_user_message, &settings) else { - return sse_with_keep_alive(create_error_stream("Failed to create chat request")); + return Err(Sse::new(create_error_stream("Failed to create chat request"))); }; - // 4. Set up the OpenAI stream - let openai_stream = match state.openai_client.chat().create_stream(request).await { - Ok(stream) => stream, - Err(_e) => { - return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream")); - } - }; + Ok((request, allowed_reference_ids)) +} - // 5. 
Create channel for collecting complete response - let (tx, mut rx) = channel::(1000); - let tx_clone = tx.clone(); - let (tx_final, mut rx_final) = channel::(1); +fn spawn_storage_task( + db_client: Arc, + mut rx: tokio::sync::mpsc::Receiver, + tx_final: tokio::sync::mpsc::Sender, + user_message: &Message, + user_id: String, + allowed_reference_ids: Vec, +) { + let conversation_id = user_message.conversation_id.clone(); - // 6. Set up the collection task for DB storage - let db_client = Arc::clone(&state.db); - let user_id = user.id.clone(); - let allowed_reference_ids = allowed_reference_ids.clone(); tokio::spawn(async move { - drop(tx); // Close sender when no longer needed - - // Collect full response let mut full_json = String::new(); while let Some(chunk) = rx.recv().await { full_json.push_str(&chunk); } - // Try to extract structured data if let Ok(response) = from_str::(&full_json) { let raw_references = extract_reference_strings(&response); let answer = response.answer; @@ -347,7 +445,7 @@ pub async fn get_response_stream( ); let ai_message = Message::new( - user_message.conversation_id, + conversation_id, MessageRole::AI, answer, Some(initial_validation.valid_refs), @@ -362,104 +460,11 @@ pub async fn get_response_stream( } else { error!("Failed to parse LLM response as structured format"); - // Fallback - store raw response - let ai_message = Message::new( - user_message.conversation_id, - MessageRole::AI, - full_json, - None, - ); + let ai_message = Message::new(conversation_id, MessageRole::AI, full_json, None); let _ = db_client.store_item(ai_message).await; } }); - - // Create a shared state for tracking the JSON parsing - let json_state = Arc::new(Mutex::new(StreamParserState::new())); - - // 7. Create the response event stream - let event_stream = openai_stream - .map_err(|e| Box::new(e) as Box) - .map(move |result| { - let tx_storage = tx_clone.clone(); - let json_state = Arc::clone(&json_state); - - stream! 
{ - match result { - Ok(response) => { - let content = response - .choices - .first() - .and_then(|choice| choice.delta.content.clone()) - .unwrap_or_default(); - - if !content.is_empty() { - // Always send raw content to storage - let _ = tx_storage.send(content.clone()).await; - - // Process through JSON parser - let mut state = json_state.lock().await; - let display_content = state.process_chunk(&content); - drop(state); - if !display_content.is_empty() { - yield Ok(Event::default() - .event("chat_message") - .data(display_content)); - } - // If display_content is empty, don't yield anything - } - // If content is empty, don't yield anything - } - Err(e) => { - yield Ok(Event::default() - .event("error") - .data(format!("Stream error: {e}"))); - } - } - } - }) - .flatten() - .chain(stream::once(async move { - if let Some(message) = rx_final.recv().await { - // Don't send any event if references is empty - if message - .references - .as_ref() - .is_some_and(std::vec::Vec::is_empty) - { - return Ok(Event::default().event("empty")); // This event won't be sent - } - - // Render template with references - match state.templates.render( - "chat/reference_list.html", - &Value::from_serialize(ReferenceData { message }), - ) { - Ok(html) => { - // Return the rendered HTML - Ok(Event::default().event("references").data(html)) - } - Err(_) => { - // Handle template rendering error - Ok(Event::default() - .event("error") - .data("Failed to render references")) - } - } - } else { - // Handle case where no references were received - Ok(Event::default() - .event("error") - .data("Failed to retrieve references")) - } - })) - .chain(once(async { - Ok(Event::default() - .event("close_stream") - .data("Stream complete")) - })); - - sse_with_keep_alive(event_stream.boxed()) } struct StreamParserState { @@ -478,23 +483,18 @@ impl StreamParserState { } fn process_chunk(&mut self, chunk: &str) -> String { - // Feed all characters into the parser for c in chunk.chars() { let _ = 
self.parser.add_char(c); } - // Get the current state of the JSON let json = self.parser.get_result(); - // Check if we're in the answer field if let Some(obj) = json.as_object() { if let Some(answer) = obj.get("answer") { self.in_answer_field = true; - // Get current answer content let current_content = answer.as_str().unwrap_or_default().to_string(); - // Calculate difference to send only new content if current_content.len() > self.last_answer_content.len() { let new_content = current_content[self.last_answer_content.len()..].to_string(); self.last_answer_content = current_content; @@ -503,7 +503,6 @@ impl StreamParserState { } } - // No new content to return String::new() } } diff --git a/html-router/src/routes/chat/mod.rs b/html-router/src/routes/chat/mod.rs index b1bcd48..f9dfb3a 100644 --- a/html-router/src/routes/chat/mod.rs +++ b/html-router/src/routes/chat/mod.rs @@ -10,8 +10,9 @@ use axum::{ }; pub use chat_handlers::{ delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title, - reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat, - show_initialized_chat, + reload_sidebar, show_conversation_editing_title, + show_chat_base as show_base, show_existing_chat as show_existing, + show_initialized_chat as show_initialized, }; use message_response_stream::get_response_stream; use references::show_reference_tooltip; @@ -24,10 +25,10 @@ where HtmlState: FromRef, { Router::new() - .route("/chat", get(show_chat_base).post(new_chat_user_message)) + .route("/chat", get(show_base).post(new_chat_user_message)) .route( "/chat/{id}", - get(show_existing_chat) + get(show_existing) .post(new_user_message) .delete(delete_conversation), ) @@ -36,7 +37,7 @@ where get(show_conversation_editing_title).patch(patch_conversation_title), ) .route("/chat/sidebar", get(reload_sidebar)) - .route("/initialized-chat", post(show_initialized_chat)) + .route("/initialized-chat", post(show_initialized)) 
.route("/chat/response-stream", get(get_response_stream)) .route("/chat/reference/{id}", get(show_reference_tooltip)) } diff --git a/html-router/src/routes/content/handlers.rs b/html-router/src/routes/content/handlers.rs index 77d526a..57f258a 100644 --- a/html-router/src/routes/content/handlers.rs +++ b/html-router/src/routes/content/handlers.rs @@ -102,13 +102,13 @@ pub async fn show_text_content_edit_form( RequireUser(user): RequireUser, Path(id): Path, ) -> Result { - let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?; - #[derive(Serialize)] pub struct TextContentEditModal { pub text_content: TextContent, } + let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?; + Ok(TemplateResponse::new_template( "content/edit_text_content_modal.html", TextContentEditModal { text_content }, @@ -214,13 +214,14 @@ pub async fn show_content_read_modal( RequireUser(user): RequireUser, Path(id): Path, ) -> Result { - // Get and validate the text content - let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?; #[derive(Serialize)] pub struct TextContentReadModalData { pub text_content: TextContent, } + // Get and validate the text content + let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?; + Ok(TemplateResponse::new_template( "content/read_content_modal.html", TextContentReadModalData { text_content }, diff --git a/html-router/src/routes/index/handlers.rs b/html-router/src/routes/index/handlers.rs index d38312d..60de178 100644 --- a/html-router/src/routes/index/handlers.rs +++ b/html-router/src/routes/index/handlers.rs @@ -226,7 +226,7 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) { ("Text".to_string(), truncate_summary(text, 80)) } common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. 
} => { - ("URL".to_string(), url.to_string()) + ("URL".to_string(), url.clone()) } common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => { ("File".to_string(), file_info.file_name.clone()) @@ -248,18 +248,16 @@ pub async fn serve_file( RequireUser(user): RequireUser, Path(file_id): Path, ) -> Result { - let file_info = match FileInfo::get_by_id(&file_id, &state.db).await { - Ok(info) => info, - _ => return Ok(TemplateResponse::not_found().into_response()), + let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else { + return Ok(TemplateResponse::not_found().into_response()); }; if file_info.user_id != user.id { return Ok(TemplateResponse::unauthorized().into_response()); } - let stream = match state.storage.get_stream(&file_info.path).await { - Ok(s) => s, - Err(_) => return Ok(TemplateResponse::server_error().into_response()), + let Ok(stream) = state.storage.get_stream(&file_info.path).await else { + return Ok(TemplateResponse::server_error().into_response()); }; let body = Body::from_stream(stream); diff --git a/html-router/src/routes/ingestion/handlers.rs b/html-router/src/routes/ingestion/handlers.rs index c5f8df5..d79c18d 100644 --- a/html-router/src/routes/ingestion/handlers.rs +++ b/html-router/src/routes/ingestion/handlers.rs @@ -1,4 +1,4 @@ -use std::{pin::Pin, time::Duration}; +use std::{pin::Pin, sync::Arc, time::Duration}; use axum::{ extract::{Query, State}, @@ -51,13 +51,13 @@ pub async fn show_ingest_form( State(state): State, RequireUser(user): RequireUser, ) -> Result { - let user_categories = User::get_user_categories(&user.id, &state.db).await?; - #[derive(Serialize)] pub struct ShowIngestFormData { user_categories: Vec, } + let user_categories = User::get_user_categories(&user.id, &state.db).await?; + Ok(TemplateResponse::new_template( "ingestion_modal.html", ShowIngestFormData { user_categories }, @@ -180,7 +180,7 @@ pub async fn get_task_updates_stream( Query(params): Query, ) -> TaskSse { let 
task_id = params.task_id.clone(); - let db = state.db.clone(); + let db = Arc::clone(&state.db); // 1. Check for authenticated user let Some(current_user) = auth.current_user else { @@ -198,7 +198,7 @@ pub async fn get_task_updates_stream( } let sse_stream = async_stream::stream! { - let mut consecutive_db_errors = 0; + let mut consecutive_db_errors: u32 = 0; let max_consecutive_db_errors = 3; loop { @@ -263,7 +263,7 @@ pub async fn get_task_updates_stream( } Err(db_err) => { error!("Database error while fetching task '{}': {:?}", task_id, db_err); - consecutive_db_errors += 1; + consecutive_db_errors = consecutive_db_errors.saturating_add(1); yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors})."))); if consecutive_db_errors >= max_consecutive_db_errors { diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 9456b12..226e983 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -39,7 +39,7 @@ use url::form_urlencoded; const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12; const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"]; -const DEFAULT_RELATIONSHIP_TYPE: &str = RELATIONSHIP_TYPE_OPTIONS[0]; +const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo"; const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10; const SUGGESTION_MIN_SCORE: f32 = 0.5; @@ -61,15 +61,15 @@ fn canonicalize_relationship_type(value: &str) -> String { let key: String = trimmed .chars() - .filter(|c| c.is_ascii_alphanumeric()) - .flat_map(|c| c.to_lowercase()) + .filter(char::is_ascii_alphanumeric) + .flat_map(char::to_lowercase) .collect(); for option in RELATIONSHIP_TYPE_OPTIONS { let option_key: String = option .chars() - .filter(|c| c.is_ascii_alphanumeric()) - .flat_map(|c| c.to_lowercase()) + .filter(char::is_ascii_alphanumeric) + .flat_map(char::to_lowercase) .collect(); if 
option_key == key { return (*option).to_string(); @@ -141,7 +141,7 @@ pub async fn show_new_knowledge_entity_form( ) -> Result { let entity_types: Vec = KnowledgeEntityType::variants() .iter() - .map(|&s| s.to_owned()) + .map(ToString::to_string) .collect(); let existing_entities = User::get_knowledge_entities(&user.id, &state.db).await?; @@ -278,7 +278,7 @@ pub async fn suggest_knowledge_relationships( if !query_parts.is_empty() { let query = query_parts.join(" "); let rerank_lease = match state.reranker_pool.as_ref() { - Some(pool) => Some(pool.checkout().await), + Some(pool) => pool.checkout().await, None => None, }; @@ -406,9 +406,10 @@ fn build_relationship_table_data( .map(|relationship| { let relationship_type_label = canonicalize_relationship_type(&relationship.metadata.relationship_type); - *frequency + let count = frequency .entry(relationship_type_label.clone()) - .or_insert(0) += 1; + .or_insert(0); + *count = count.saturating_add(1); RelationshipTableRow { relationship, relationship_type_label, @@ -417,9 +418,7 @@ fn build_relationship_table_data( .collect(); let default_relationship_type = frequency .into_iter() - .max_by_key(|(_, count)| *count) - .map(|(label, _)| label) - .unwrap_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string()); + .max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label); RelationshipTableData { entities, @@ -800,8 +799,10 @@ pub async fn get_knowledge_graph_json( for rel in &relationships { if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) { // undirected counting for degree - *degree_count.entry(rel.in_.clone()).or_insert(0) += 1; - *degree_count.entry(rel.out.clone()).or_insert(0) += 1; + let count = degree_count.entry(rel.in_.clone()).or_insert(0); + *count = count.saturating_add(1); + let count = degree_count.entry(rel.out.clone()).or_insert(0); + *count = count.saturating_add(1); links.push(GraphLink { source: rel.out.clone(), target: rel.in_.clone(), @@ 
-836,11 +837,11 @@ fn normalize_filter(input: Option) -> Option { fn trim_matching_quotes(value: &str) -> &str { let bytes = value.as_bytes(); - if bytes.len() >= 2 { - let first = bytes[0]; - let last = bytes[bytes.len() - 1]; - if (first == b'"' && last == b'"') || (first == b'\'' && last == b'\'') { - return &value[1..value.len() - 1]; + if let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) { + if bytes.len() >= 2 + && ((first == b'"' && last == b'"') || (first == b'\'' && last == b'\'')) + { + return &value[1..value.len().saturating_sub(1)]; } } value @@ -860,7 +861,7 @@ pub async fn show_edit_knowledge_entity_form( // Get entity types let entity_types: Vec = KnowledgeEntityType::variants() .iter() - .map(|&s| s.to_owned()) + .map(ToString::to_string) .collect(); // Get the entity and validate ownership diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs index 244590f..85f05e4 100644 --- a/html-router/src/routes/search/handlers.rs +++ b/html-router/src/routes/search/handlers.rs @@ -11,6 +11,7 @@ use axum::{ use common::storage::types::{ serde_helpers::deserialize_flexible_id, text_content::TextContent, + user::User, StoredObject, }; use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput}; @@ -46,13 +47,11 @@ fn source_id_suffix(source_id: &str) -> String { fn truncate_label(value: &str, max_chars: usize) -> String { let mut end = None; - let mut count = 0; - for (idx, _) in value.char_indices() { + for (count, (idx, _)) in value.char_indices().enumerate() { if count == max_chars { end = Some(idx); break; } - count += 1; } match end { @@ -174,165 +173,31 @@ struct KnowledgeEntityForTemplate { score: f32, } +#[derive(Serialize)] +struct SearchResultForTemplate { + result_type: String, + score: f32, + #[serde(skip_serializing_if = "Option::is_none")] + text_chunk: Option, + #[serde(skip_serializing_if = "Option::is_none")] + knowledge_entity: Option, +} + 
+#[derive(Serialize)] +pub struct AnswerData { + search_result: Vec, + query_param: String, +} + pub async fn search_result_handler( State(state): State, Query(params): Query, RequireUser(user): RequireUser, ) -> Result { - #[derive(Serialize)] - struct SearchResultForTemplate { - result_type: String, - score: f32, - #[serde(skip_serializing_if = "Option::is_none")] - text_chunk: Option, - #[serde(skip_serializing_if = "Option::is_none")] - knowledge_entity: Option, - } - - #[derive(Serialize)] - pub struct AnswerData { - search_result: Vec, - query_param: String, - } - let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) = params.query { - let trimmed_query = actual_query.trim(); - if trimmed_query.is_empty() { - (Vec::::new(), String::new()) - } else { - // Use retrieval pipeline Search strategy - let config = RetrievalConfig::for_search(SearchTarget::Both); - - // Checkout a reranker lease if pool is available - let reranker_lease = match &state.reranker_pool { - Some(pool) => Some(pool.checkout().await), - None => None, - }; - - let result = retrieval_pipeline::pipeline::run_pipeline( - &state.db, - &state.openai_client, - Some(&state.embedding_provider), - trimmed_query, - &user.id, - config, - reranker_lease, - ) - .await?; - - let search_result = match result { - StrategyOutput::Search(sr) => sr, - _ => SearchResult::new(vec![], vec![]), - }; - - let mut source_ids = HashSet::new(); - for chunk_result in &search_result.chunks { - source_ids.insert(chunk_result.chunk.source_id.clone()); - } - for entity_result in &search_result.entities { - source_ids.insert(entity_result.entity.source_id.clone()); - } - - let source_label_map = if source_ids.is_empty() { - HashMap::new() - } else { - let record_ids: Vec = source_ids - .iter() - .filter_map(|id| { - if id.contains(':') { - RecordId::from_str(id).ok() - } else { - Some(RecordId::from_table_key(TextContent::table_name(), id)) - } - }) - .collect(); - let mut response = 
state - .db - .client - .query( - "SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids", - ) - .bind(("table_name", TextContent::table_name())) - .bind(("user_id", user.id.clone())) - .bind(("record_ids", record_ids)) - .await?; - let contents: Vec = response.take(0)?; - - tracing::debug!( - source_id_count = source_ids.len(), - label_row_count = contents.len(), - "Resolved search source labels" - ); - - let mut labels = HashMap::new(); - for content in contents { - let label = build_source_label(&content); - labels.insert(content.id.clone(), label.clone()); - labels.insert( - format!("{}:{}", TextContent::table_name(), content.id), - label, - ); - } - - labels - }; - - let mut combined_results: Vec = - Vec::with_capacity(search_result.chunks.len() + search_result.entities.len()); - - // Add chunk results - for chunk_result in search_result.chunks { - let source_label = source_label_map - .get(&chunk_result.chunk.source_id) - .cloned() - .unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id)); - combined_results.push(SearchResultForTemplate { - result_type: "text_chunk".to_string(), - score: chunk_result.score, - text_chunk: Some(TextChunkForTemplate { - id: chunk_result.chunk.id, - source_id: chunk_result.chunk.source_id, - source_label, - chunk: chunk_result.chunk.chunk, - score: chunk_result.score, - }), - knowledge_entity: None, - }); - } - - // Add entity results - for entity_result in search_result.entities { - let source_label = source_label_map - .get(&entity_result.entity.source_id) - .cloned() - .unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id)); - combined_results.push(SearchResultForTemplate { - result_type: "knowledge_entity".to_string(), - score: entity_result.score, - text_chunk: None, - knowledge_entity: Some(KnowledgeEntityForTemplate { - id: entity_result.entity.id, - name: entity_result.entity.name, - description: 
entity_result.entity.description, - entity_type: format!("{:?}", entity_result.entity.entity_type), - source_id: entity_result.entity.source_id, - source_label, - score: entity_result.score, - }), - }); - } - - // Sort by score descending - combined_results.sort_by(|a, b| b.score.total_cmp(&a.score)); - - // Limit results - const TOTAL_LIMIT: usize = 10; - combined_results.truncate(TOTAL_LIMIT); - - (combined_results, trimmed_query.to_string()) - } + perform_search(&state, &user, actual_query).await? } else { (Vec::::new(), String::new()) }; @@ -345,3 +210,147 @@ pub async fn search_result_handler( }, )) } + +async fn perform_search( + state: &HtmlState, + user: &User, + query: String, +) -> Result<(Vec, String), HtmlError> { + const TOTAL_LIMIT: usize = 10; + + let trimmed_query = query.trim(); + if trimmed_query.is_empty() { + return Ok((Vec::new(), String::new())); + } + + let config = RetrievalConfig::for_search(SearchTarget::Both); + + let reranker_lease = match &state.reranker_pool { + Some(pool) => pool.checkout().await, + None => None, + }; + + let params = retrieval_pipeline::pipeline::StrategyParams { + db_client: &state.db, + openai_client: &state.openai_client, + embedding_provider: Some(&state.embedding_provider), + input_text: trimmed_query, + user_id: &user.id, + config, + reranker: reranker_lease, + }; + let result = retrieval_pipeline::pipeline::execute(params).await?; + + let search_result = match result { + StrategyOutput::Search(sr) => sr, + _ => SearchResult::new(vec![], vec![]), + }; + + let source_label_map = resolve_source_labels(state, user, &search_result).await?; + + let mut combined_results: Vec = + Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len())); + + for chunk_result in search_result.chunks { + let source_label = source_label_map + .get(&chunk_result.chunk.source_id) + .cloned() + .unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id)); + 
combined_results.push(SearchResultForTemplate { + result_type: "text_chunk".to_string(), + score: chunk_result.score, + text_chunk: Some(TextChunkForTemplate { + id: chunk_result.chunk.id, + source_id: chunk_result.chunk.source_id, + source_label, + chunk: chunk_result.chunk.chunk, + score: chunk_result.score, + }), + knowledge_entity: None, + }); + } + + for entity_result in search_result.entities { + let source_label = source_label_map + .get(&entity_result.entity.source_id) + .cloned() + .unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id)); + combined_results.push(SearchResultForTemplate { + result_type: "knowledge_entity".to_string(), + score: entity_result.score, + text_chunk: None, + knowledge_entity: Some(KnowledgeEntityForTemplate { + id: entity_result.entity.id, + name: entity_result.entity.name, + description: entity_result.entity.description, + entity_type: format!("{:?}", entity_result.entity.entity_type), + source_id: entity_result.entity.source_id, + source_label, + score: entity_result.score, + }), + }); + } + + combined_results.sort_by(|a, b| b.score.total_cmp(&a.score)); + combined_results.truncate(TOTAL_LIMIT); + + Ok((combined_results, trimmed_query.to_string())) +} + +async fn resolve_source_labels( + state: &HtmlState, + user: &User, + search_result: &SearchResult, +) -> Result, HtmlError> { + let mut source_ids = HashSet::new(); + for chunk_result in &search_result.chunks { + source_ids.insert(chunk_result.chunk.source_id.clone()); + } + for entity_result in &search_result.entities { + source_ids.insert(entity_result.entity.source_id.clone()); + } + + if source_ids.is_empty() { + return Ok(HashMap::new()); + } + + let record_ids: Vec = source_ids + .iter() + .filter_map(|id| { + if id.contains(':') { + RecordId::from_str(id).ok() + } else { + Some(RecordId::from_table_key(TextContent::table_name(), id)) + } + }) + .collect(); + let mut response = state + .db + .client + .query( + "SELECT id, url_info, file_info, context, 
category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids", + ) + .bind(("table_name", TextContent::table_name())) + .bind(("user_id", user.id.clone())) + .bind(("record_ids", record_ids)) + .await?; + let contents: Vec = response.take(0)?; + + tracing::debug!( + source_id_count = source_ids.len(), + label_row_count = contents.len(), + "Resolved search source labels" + ); + + let mut labels = HashMap::new(); + for content in contents { + let label = build_source_label(&content); + labels.insert(content.id.clone(), label.clone()); + labels.insert( + format!("{}:{}", TextContent::table_name(), content.id), + label, + ); + } + + Ok(labels) +} diff --git a/html-router/src/routes/search/mod.rs b/html-router/src/routes/search/mod.rs index e61e2bc..cbd93b0 100644 --- a/html-router/src/routes/search/mod.rs +++ b/html-router/src/routes/search/mod.rs @@ -1,7 +1,10 @@ mod handlers; use axum::{extract::FromRef, routing::get, Router}; -pub use handlers::{search_result_handler, SearchParams}; +#[allow(clippy::module_name_repetitions)] +pub use handlers::{ + search_result_handler as result_handler, SearchParams as SearchQueryParams, +}; use crate::html_state::HtmlState; @@ -10,5 +13,5 @@ where S: Clone + Send + Sync + 'static, HtmlState: FromRef, { - Router::new().route("/search", get(search_result_handler)) + Router::new().route("/search", get(result_handler)) } diff --git a/html-router/src/utils/pagination.rs b/html-router/src/utils/pagination.rs index adcf49c..54a8c6f 100644 --- a/html-router/src/utils/pagination.rs +++ b/html-router/src/utils/pagination.rs @@ -31,8 +31,8 @@ impl Pagination { } else { 0 }; - let start_index = if page_len == 0 { 0 } else { offset + 1 }; - let end_index = if page_len == 0 { 0 } else { offset + page_len }; + let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) }; + let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) }; Self { current_page, @@ -42,12 +42,12 @@ impl 
Pagination { has_previous, has_next, previous_page: if has_previous { - Some(current_page - 1) + Some(current_page.saturating_sub(1)) } else { None }, next_page: if has_next { - Some(current_page + 1) + Some(current_page.saturating_add(1)) } else { None }, @@ -68,7 +68,7 @@ pub fn paginate_items( let total_pages = if total_items == 0 { 0 } else { - ((total_items - 1) / per_page) + 1 + total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1) }; let mut current_page = requested_page.unwrap_or(1); @@ -84,7 +84,7 @@ pub fn paginate_items( let offset = if total_pages == 0 { 0 } else { - per_page.saturating_mul(current_page - 1) + per_page.saturating_mul(current_page.saturating_sub(1)) }; let page_items: Vec = items.into_iter().skip(offset).take(per_page).collect(); @@ -136,8 +136,8 @@ mod tests { assert_eq!(page, vec![5]); assert_eq!(meta.current_page, 3); assert_eq!(meta.total_pages, 3); - assert_eq!(meta.has_next, false); - assert_eq!(meta.has_previous, true); + assert!(!meta.has_next, "expected no next page"); + assert!(meta.has_previous, "expected previous page"); assert_eq!(meta.start_index, 5); assert_eq!(meta.end_index, 5); } diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 294bb0e..ccd036e 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -180,7 +180,7 @@ impl PipelineServices for DefaultPipelineServices { ); let rerank_lease = match &self.reranker_pool { - Some(pool) => Some(pool.checkout().await), + Some(pool) => pool.checkout().await, None => None, }; diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs b/ingestion-pipeline/src/pipeline/stages/mod.rs index 20ceb6d..4583db3 100644 --- a/ingestion-pipeline/src/pipeline/stages/mod.rs +++ b/ingestion-pipeline/src/pipeline/stages/mod.rs @@ -4,7 +4,7 @@ use common::{ error::AppError, storage::{ db::SurrealDbClient, - indexes::rebuild_indexes, + indexes::rebuild, 
types::{ ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, @@ -191,7 +191,7 @@ pub async fn persist( ctx.db.store_item(text_content).await?; debug!("stored item"); - rebuild_indexes(ctx.db).await?; + rebuild(ctx.db).await?; debug!( task_id = %ctx.task_id, @@ -301,8 +301,8 @@ async fn store_chunk_batch( for embedded in batch { TextChunk::store_with_embedding( - embedded.chunk.to_owned(), - embedded.embedding.to_owned(), + embedded.chunk.clone(), + embedded.embedding.clone(), db, ) .await?; diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 4c9d7f3..52d26de 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use anyhow::{self, Context}; use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use async_trait::async_trait; use chrono::{Duration as ChronoDuration, Utc}; @@ -265,16 +266,12 @@ impl PipelineServices for ValidationServices { } } -async fn setup_db() -> SurrealDbClient { +async fn setup_db() -> anyhow::Result { let namespace = "pipeline_test"; let database = Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, &database) - .await - .expect("Failed to create in-memory SurrealDB"); - db.apply_migrations() - .await - .expect("Failed to apply migrations"); - db + let db = SurrealDbClient::memory(namespace, &database).await?; + db.apply_migrations().await?; + Ok(db) } fn pipeline_config() -> IngestionConfig { @@ -295,26 +292,28 @@ async fn reserve_task( worker_id: &str, payload: IngestionPayload, user_id: &str, -) -> IngestionTask { - let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db) - .await - .expect("task created"); +) -> anyhow::Result { + let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db).await?; let lease = task.lease_duration(); - 
IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease) - .await - .expect("claim succeeds") - .expect("task claimed") + let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease) + .await? + .context("task claimed")?; + Ok(claimed) } #[tokio::test] -async fn ingestion_pipeline_happy_path_persists_entities() { - let db = setup_db().await; +async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()> +{ + let db = setup_db().await?; let worker_id = "worker-happy"; let user_id = "user-123"; let services = Arc::new(MockServices::new(user_id)); - let pipeline = - IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services.clone()) - .expect("pipeline"); + let services_clone: Arc = Arc::::clone(&services); + let pipeline = IngestionPipeline::with_services( + Arc::new(db.clone()), + pipeline_config(), + services_clone, + )?; let task = reserve_task( &db, @@ -327,30 +326,22 @@ async fn ingestion_pipeline_happy_path_persists_entities() { }, user_id, ) - .await; + .await?; - pipeline - .process_task(task.clone()) - .await - .expect("pipeline succeeds"); + pipeline.process_task(task.clone()).await?; let stored_task: IngestionTask = db .get_item(&task.id) - .await - .expect("retrieve task") - .expect("task present"); + .await? 
+ .context("task present")?; assert_eq!(stored_task.state, TaskState::Succeeded); let stored_entities: Vec = db .get_all_stored_items::() - .await - .expect("entities stored"); + .await?; assert!(!stored_entities.is_empty(), "entities should be stored"); - let stored_chunks: Vec = db - .get_all_stored_items::() - .await - .expect("chunks stored"); + let stored_chunks: Vec = db.get_all_stored_items::().await?; assert!( !stored_chunks.is_empty(), "chunks should be stored for ingestion text" @@ -362,22 +353,29 @@ async fn ingestion_pipeline_happy_path_persists_entities() { "expected at least one chunk embedding call" ); assert_eq!( - &call_log[0..4], - ["prepare", "retrieve", "enrich", "convert"] + call_log.get(0..4), + Some(&["prepare", "retrieve", "enrich", "convert"][..]) ); - assert!(call_log[4..].iter().all(|entry| *entry == "chunk")); + assert!( + call_log.get(4..).is_some_and(|tail| tail.iter().all(|entry| *entry == "chunk")) + ); + Ok(()) } #[tokio::test] -async fn ingestion_pipeline_chunk_only_skips_analysis() { - let db = setup_db().await; +async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> { + let db = setup_db().await?; let worker_id = "worker-chunk-only"; let user_id = "user-999"; let services = Arc::new(MockServices::new(user_id)); + let services_clone: Arc = Arc::::clone(&services); let mut config = pipeline_config(); config.chunk_only = true; - let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone()) - .expect("pipeline"); + let pipeline = IngestionPipeline::with_services( + Arc::new(db.clone()), + config, + services_clone, + )?; let task = reserve_task( &db, @@ -390,17 +388,13 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() { }, user_id, ) - .await; + .await?; - pipeline - .process_task(task.clone()) - .await - .expect("pipeline succeeds"); + pipeline.process_task(task.clone()).await?; let stored_entities: Vec = db .get_all_stored_items::() - .await - .expect("entities 
stored"); + .await?; assert!( stored_entities.is_empty(), "chunk-only ingestion should not persist entities" @@ -408,8 +402,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() { let relationship_count: Option = db .client .query("SELECT count() as count FROM relates_to;") - .await - .expect("query relationships") + .await? .take::>(0) .unwrap_or_default(); assert_eq!( @@ -417,10 +410,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() { 0, "chunk-only ingestion should not persist relationships" ); - let stored_chunks: Vec = db - .get_all_stored_items::() - .await - .expect("chunks stored"); + let stored_chunks: Vec = db.get_all_stored_items::().await?; assert!( !stored_chunks.is_empty(), "chunk-only ingestion should still persist chunks" @@ -428,19 +418,19 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() { let call_log = services.calls.lock().await.clone(); assert_eq!(call_log, vec!["prepare", "chunk"]); + Ok(()) } #[tokio::test] -async fn ingestion_pipeline_failure_marks_retry() { - let db = setup_db().await; +async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> { + let db = setup_db().await?; let worker_id = "worker-fail"; let user_id = "user-456"; let services = Arc::new(FailingServices { inner: MockServices::new(user_id), }); let pipeline = - IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services) - .expect("pipeline"); + IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?; let task = reserve_task( &db, @@ -453,7 +443,7 @@ async fn ingestion_pipeline_failure_marks_retry() { }, user_id, ) - .await; + .await?; let result = pipeline.process_task(task.clone()).await; assert!( @@ -463,38 +453,38 @@ async fn ingestion_pipeline_failure_marks_retry() { let stored_task: IngestionTask = db .get_item(&task.id) - .await - .expect("retrieve task") - .expect("task present"); + .await? 
+ .context("task present")?; assert_eq!(stored_task.state, TaskState::Failed); assert!( stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5), "failed task should schedule retry in the future" ); + Ok(()) } #[tokio::test] -async fn ingestion_pipeline_validation_failure_dead_letters_task() { - let db = setup_db().await; +async fn ingestion_pipeline_validation_failure_dead_letters_task( +) -> anyhow::Result<()> { + let db = setup_db().await?; let worker_id = "worker-validation"; let user_id = "user-789"; let services = Arc::new(ValidationServices); let pipeline = - IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services) - .expect("pipeline"); + IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?; let task = reserve_task( &db, worker_id, IngestionPayload::Text { text: "irrelevant".into(), - context: "".into(), + context: String::new(), category: "notes".into(), user_id: user_id.into(), }, user_id, ) - .await; + .await?; let result = pipeline.process_task(task.clone()).await; assert!( @@ -504,8 +494,8 @@ async fn ingestion_pipeline_validation_failure_dead_letters_task() { let stored_task: IngestionTask = db .get_item(&task.id) - .await - .expect("retrieve task") - .expect("task present"); + .await? 
+ .context("task present")?; assert_eq!(stored_task.state, TaskState::DeadLetter); + Ok(()) } diff --git a/ingestion-pipeline/src/utils/file_text_extraction.rs b/ingestion-pipeline/src/utils/file_text_extraction.rs index 1ae88c8..74bf6eb 100644 --- a/ingestion-pipeline/src/utils/file_text_extraction.rs +++ b/ingestion-pipeline/src/utils/file_text_extraction.rs @@ -155,21 +155,20 @@ mod tests { }; #[tokio::test] - async fn extracts_text_using_memory_storage_backend() { - let mut config = AppConfig::default(); - config.storage = StorageKind::Memory; + async fn extracts_text_using_memory_storage_backend() -> anyhow::Result<()> { + let config = AppConfig { + storage: StorageKind::Memory, + ..Default::default() + }; - let storage = StorageManager::new(&config) - .await - .expect("create storage manager"); + let storage = StorageManager::new(&config).await?; let location = "user/test/file.txt"; let contents = b"hello from memory storage"; storage .put(location, Bytes::from(contents.as_slice().to_vec())) - .await - .expect("write object"); + .await?; let now = Utc::now(); let file_info = FileInfo { @@ -185,16 +184,14 @@ mod tests { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("create surreal memory"); + let db = SurrealDbClient::memory(namespace, database).await?; let openai_client = Client::with_config(OpenAIConfig::default()); let text = extract_text_from_file(&file_info, &db, &openai_client, &config, &storage) - .await - .expect("extract text"); + .await?; assert_eq!(text, String::from_utf8_lossy(contents)); + Ok(()) } } diff --git a/ingestion-pipeline/src/utils/pdf_ingestion.rs b/ingestion-pipeline/src/utils/pdf_ingestion.rs index 5be106c..bc39b89 100644 --- a/ingestion-pipeline/src/utils/pdf_ingestion.rs +++ b/ingestion-pipeline/src/utils/pdf_ingestion.rs @@ -715,6 +715,7 @@ const fn prompt_for_attempt(attempt: usize, base_prompt: &str) -> &str { #[cfg(test)] mod 
tests { use super::*; + use anyhow::{self}; #[test] fn test_looks_good_enough_short_text() { @@ -737,15 +738,16 @@ mod tests { } #[test] - fn test_debug_dump_directory_env_var() { + fn test_debug_dump_directory_env_var() -> anyhow::Result<()> { std::env::remove_var(DEBUG_IMAGE_ENV_VAR); assert!(debug_dump_directory().is_none()); std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug"); - let dir = debug_dump_directory().expect("expected debug directory"); + let dir = debug_dump_directory().ok_or_else(|| anyhow::anyhow!("expected debug directory"))?; assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug")); std::env::remove_var(DEBUG_IMAGE_ENV_VAR); + Ok(()) } #[test] diff --git a/ingestion-pipeline/src/utils/url_text_retrieval.rs b/ingestion-pipeline/src/utils/url_text_retrieval.rs index a76bfca..a6ea67e 100644 --- a/ingestion-pipeline/src/utils/url_text_retrieval.rs +++ b/ingestion-pipeline/src/utils/url_text_retrieval.rs @@ -142,29 +142,34 @@ fn ensure_ingestion_url_allowed(url: &url::Url) -> Result { #[cfg(test)] mod tests { use super::*; + use anyhow::{self}; #[test] - fn rejects_unsupported_scheme() { - let url = url::Url::parse("ftp://example.com").expect("url"); + fn rejects_unsupported_scheme() -> anyhow::Result<()> { + let url = url::Url::parse("ftp://example.com")?; assert!(ensure_ingestion_url_allowed(&url).is_err()); + Ok(()) } #[test] - fn rejects_localhost() { - let url = url::Url::parse("http://localhost/resource").expect("url"); + fn rejects_localhost() -> anyhow::Result<()> { + let url = url::Url::parse("http://localhost/resource")?; assert!(ensure_ingestion_url_allowed(&url).is_err()); + Ok(()) } #[test] - fn rejects_private_ipv4() { - let url = url::Url::parse("http://192.168.1.10/index.html").expect("url"); + fn rejects_private_ipv4() -> anyhow::Result<()> { + let url = url::Url::parse("http://192.168.1.10/index.html")?; assert!(ensure_ingestion_url_allowed(&url).is_err()); + Ok(()) } #[test] - fn allows_public_domain_and_sanitizes() { - let 
url = url::Url::parse("https://sub.example.com/path").expect("url"); - let sanitized = ensure_ingestion_url_allowed(&url).expect("allowed"); + fn allows_public_domain_and_sanitizes() -> anyhow::Result<()> { + let url = url::Url::parse("https://sub.example.com/path")?; + let sanitized = ensure_ingestion_url_allowed(&url)?; assert_eq!(sanitized, "sub_example_com"); + Ok(()) } } diff --git a/main/src/main.rs b/main/src/main.rs index 44d0c8b..3d1018e 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -3,7 +3,7 @@ use axum::{extract::FromRef, Router}; use common::{ storage::{ db::SurrealDbClient, - indexes::ensure_runtime_indexes, + indexes::ensure_runtime, store::StorageManager, types::{ knowledge_entity::KnowledgeEntity, system_settings::SystemSettings, @@ -12,7 +12,10 @@ use common::{ }, utils::{config::get_config, embedding::EmbeddingProvider}, }; -use html_router::{html_routes, html_state::HtmlState}; +use html_router::{ + html_routes, + html_state::{HtmlState, StateResources}, +}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use retrieval_pipeline::reranking::RerankerPool; use std::sync::Arc; @@ -21,19 +24,77 @@ use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use tokio::task::LocalSet; +fn spawn_server_thread( + listener: tokio::net::TcpListener, + app: Router, +) -> std::thread::JoinHandle<()> { + std::thread::spawn(move || { + let rt = match tokio::runtime::Runtime::new() { + Ok(rt) => rt, + Err(e) => { + error!("Failed to create server runtime: {e}"); + return; + } + }; + rt.block_on(async { + if let Err(e) = axum::serve(listener, app).await { + error!("Server error: {}", e); + } + }); + }) +} + +async fn run_worker( + config: common::utils::config::AppConfig, + reranker_pool: Option>, + storage: StorageManager, +) -> anyhow::Result<()> { + let worker_db = Arc::new( + SurrealDbClient::new( + &config.surrealdb_address, + &config.surrealdb_username, + &config.surrealdb_password, + &config.surrealdb_namespace, + 
&config.surrealdb_database, + ) + .await?, + ); + + let openai_client = Arc::new(async_openai::Client::with_config( + async_openai::config::OpenAIConfig::new() + .with_api_key(&config.openai_api_key) + .with_api_base(&config.openai_base_url), + )); + + let embedding_provider = Arc::new( + EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?, + ); + + let ingestion_pipeline = Arc::new( + IngestionPipeline::new( + Arc::clone(&worker_db), + openai_client, + config, + reranker_pool, + storage, + embedding_provider, + )?, + ); + + info!("Starting worker process"); + run_worker_loop(worker_db, ingestion_pipeline).await +} + #[tokio::main] async fn main() -> anyhow::Result<()> { - // Set up tracing tracing_subscriber::registry() .with(fmt::layer().with_writer(std::io::stderr)) .with(EnvFilter::from_default_env()) .try_init() .ok(); - // Get config let config = get_config()?; - // Set up router states let db = Arc::new( SurrealDbClient::new( &config.surrealdb_address, @@ -45,7 +106,6 @@ async fn main() -> anyhow::Result<()> { .await?, ); - // Ensure db is initialized db.apply_migrations().await?; let session_store = Arc::new(db.create_session_store().await?); @@ -55,27 +115,23 @@ async fn main() -> anyhow::Result<()> { .with_api_base(&config.openai_base_url), )); - // Create embedding provider based on config before syncing settings. let embedding_provider = - Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?); + Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?); info!( embedding_backend = ?config.embedding_backend, embedding_dimension = embedding_provider.dimension(), "Embedding provider initialized" ); - // Sync SystemSettings with provider's dimensions/model/backend let (settings, dimensions_changed) = SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?; - // If dimensions changed, re-embed existing data to keep queries working. 
if dimensions_changed { warn!( new_dimensions = settings.embedding_dimensions, "Embedding configuration changed; re-embedding existing data" ); - // Re-embed text chunks info!("Re-embedding TextChunks"); if let Err(e) = TextChunk::update_all_embeddings_with_provider(&db, &embedding_provider).await @@ -86,7 +142,6 @@ async fn main() -> anyhow::Result<()> { ); } - // Re-embed knowledge entities info!("Re-embedding KnowledgeEntities"); if let Err(e) = KnowledgeEntity::update_all_embeddings_with_provider(&db, &embedding_provider).await @@ -100,29 +155,25 @@ async fn main() -> anyhow::Result<()> { info!("Re-embedding complete."); } - // Now ensure runtime indexes with the correct (synced) dimensions - ensure_runtime_indexes(&db, settings.embedding_dimensions as usize).await?; + ensure_runtime(&db, settings.embedding_dimensions as usize).await?; let reranker_pool = RerankerPool::maybe_from_config(&config)?; - // Create global storage manager let storage = StorageManager::new(&config).await?; - let html_state = HtmlState::new_with_resources( + let html_state = HtmlState::new_with_resources(StateResources { db, openai_client, session_store, - storage.clone(), - config.clone(), - reranker_pool.clone(), - embedding_provider.clone(), - None, - ) - .await; + storage: storage.clone(), + config: config.clone(), + reranker_pool: reranker_pool.clone(), + embedding_provider: Arc::clone(&embedding_provider), + template_engine: None, + }); let api_state = ApiState::new(&config, storage.clone()).await?; - // Create Axum router let app = Router::new() .nest("/api/v1", api_routes_v1(&api_state)) .merge(html_routes(&html_state)) @@ -135,72 +186,16 @@ async fn main() -> anyhow::Result<()> { let serve_address = format!("0.0.0.0:{}", config.http_port); let listener = tokio::net::TcpListener::bind(serve_address).await?; - // Start the server in a separate OS thread with its own runtime - let server_handle = std::thread::spawn(move || { - let rt = tokio::runtime::Runtime::new().unwrap(); - 
rt.block_on(async { - if let Err(e) = axum::serve(listener, app).await { - error!("Server error: {}", e); - } - }); - }); + let server_handle = spawn_server_thread(listener, app); - // Create a LocalSet for the worker let local = LocalSet::new(); - - // Use a clone of the config for the worker - let worker_config = config.clone(); - - // Run the worker in the local set local.spawn_local(async move { - // Create worker db connection - let worker_db = Arc::new( - SurrealDbClient::new( - &worker_config.surrealdb_address, - &worker_config.surrealdb_username, - &worker_config.surrealdb_password, - &worker_config.surrealdb_namespace, - &worker_config.surrealdb_database, - ) - .await - .unwrap(), - ); - - // Initialize worker components - let openai_client = Arc::new(async_openai::Client::with_config( - async_openai::config::OpenAIConfig::new() - .with_api_key(&config.openai_api_key) - .with_api_base(&config.openai_base_url), - )); - - // Create embedding provider based on config - let embedding_provider = Arc::new( - EmbeddingProvider::from_config(&config, Some(openai_client.clone())) - .await - .expect("failed to create embedding provider"), - ); - let ingestion_pipeline = Arc::new( - IngestionPipeline::new( - worker_db.clone(), - openai_client.clone(), - config.clone(), - reranker_pool.clone(), - storage.clone(), - embedding_provider, - ) - .unwrap(), - ); - - info!("Starting worker process"); - if let Err(e) = run_worker_loop(worker_db, ingestion_pipeline).await { - error!("Worker process error: {}", e); + if let Err(e) = run_worker(config, reranker_pool, storage).await { + error!("Worker error: {}", e); } }); - - // Run the local set on the main thread local.await; - // Wait for the server thread to finish (this likely won't be reached) if let Err(e) = server_handle.join() { error!("Server thread panicked: {:?}", e); } @@ -253,52 +248,39 @@ mod tests { let namespace = "test_ns"; let database = format!("test_db_{}", Uuid::new_v4()); let data_dir = 
std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4())); - - tokio::fs::create_dir_all(&data_dir) - .await + tokio::fs::create_dir_all(&data_dir).await .expect("failed to create temp data directory"); let config = smoke_test_config(namespace, &database, &data_dir); - let db = Arc::new( - SurrealDbClient::memory(namespace, &database) - .await - .expect("failed to start in-memory surrealdb"), - ); - db.apply_migrations() - .await - .expect("failed to apply migrations"); + let db = Arc::new(SurrealDbClient::memory(namespace, &database).await?); + db.apply_migrations().await?; - let session_store = Arc::new(db.create_session_store().await.expect("session store")); + let session_store = Arc::new(db.create_session_store().await?); let openai_client = Arc::new(async_openai::Client::with_config( async_openai::config::OpenAIConfig::new() .with_api_key(&config.openai_api_key) .with_api_base(&config.openai_base_url), )); - let storage = StorageManager::new(&config) - .await - .expect("failed to build storage manager"); + let storage = StorageManager::new(&config).await?; - // Use hashed embeddings for tests to avoid external dependencies let embedding_provider = Arc::new( - common::utils::embedding::EmbeddingProvider::new_hashed(384) - .expect("failed to create hashed embedding provider"), + common::utils::embedding::EmbeddingProvider::new_hashed(384)?, ); - let html_state = HtmlState::new_with_resources( - db.clone(), + let html_state = HtmlState::new_with_resources(StateResources { + db: Arc::clone(&db), openai_client, session_store, - storage.clone(), - config.clone(), - None, + storage: storage.clone(), + config: config.clone(), + reranker_pool: None, embedding_provider, - None, - ) - .await; + template_engine: None, + }); let api_state = ApiState { - db: db.clone(), + db: Arc::clone(&db), config: config.clone(), storage, }; @@ -376,25 +358,22 @@ mod tests { .oneshot( Request::builder() .uri("/api/v1/live") - .body(Body::empty()) - .expect("request"), + 
.body(Body::empty())?, ) - .await - .expect("router response"); + .await?; assert_eq!(response.status(), StatusCode::OK); let ready_response = app .oneshot( Request::builder() .uri("/api/v1/ready") - .body(Body::empty()) - .expect("request"), + .body(Body::empty())?, ) - .await - .expect("ready response"); + .await?; assert_eq!(ready_response.status(), StatusCode::OK); tokio::fs::remove_dir_all(&data_dir).await.ok(); + Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] diff --git a/main/src/server.rs b/main/src/server.rs index 540f1a2..1a1e047 100644 --- a/main/src/server.rs +++ b/main/src/server.rs @@ -6,7 +6,10 @@ use common::{ storage::{db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettings}, utils::{config::get_config, embedding::EmbeddingProvider}, }; -use html_router::{html_routes, html_state::HtmlState}; +use html_router::{ + html_routes, + html_state::{HtmlState, StateResources}, +}; use retrieval_pipeline::reranking::RerankerPool; use tracing::info; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -52,7 +55,7 @@ async fn main() -> anyhow::Result<()> { // Create embedding provider based on config let embedding_provider = - Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?); + Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?); info!( embedding_backend = ?config.embedding_backend, embedding_dimension = embedding_provider.dimension(), @@ -63,17 +66,16 @@ async fn main() -> anyhow::Result<()> { let (_settings, _dimensions_changed) = SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?; - let html_state = HtmlState::new_with_resources( + let html_state = HtmlState::new_with_resources(StateResources { db, openai_client, session_store, - storage.clone(), - config.clone(), + storage: storage.clone(), + config: config.clone(), reranker_pool, embedding_provider, - None, - ) - .await; + template_engine: None, + }); 
let api_state = ApiState::new(&config, storage).await?; diff --git a/main/src/worker.rs b/main/src/worker.rs index 9a29a56..1e076c6 100644 --- a/main/src/worker.rs +++ b/main/src/worker.rs @@ -42,7 +42,7 @@ async fn main() -> anyhow::Result<()> { // Create embedding provider based on config let embedding_provider = - Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?); + Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?); info!( embedding_backend = ?config.embedding_backend, "Embedding provider initialized for worker" @@ -52,8 +52,8 @@ async fn main() -> anyhow::Result<()> { let storage = StorageManager::new(&config).await?; let ingestion_pipeline = Arc::new(IngestionPipeline::new( - db.clone(), - openai_client.clone(), + Arc::clone(&db), + Arc::clone(&openai_client), config, reranker_pool, storage, diff --git a/retrieval-pipeline/src/answer_retrieval.rs b/retrieval-pipeline/src/answer_retrieval.rs index 7ad625f..7eb7fc3 100644 --- a/retrieval-pipeline/src/answer_retrieval.rs +++ b/retrieval-pipeline/src/answer_retrieval.rs @@ -118,18 +118,16 @@ pub fn create_chat_request( } pub fn process_llm_response( - response: CreateChatCompletionResponse, -) -> Result { + response: &CreateChatCompletionResponse, +) -> Result> { response .choices .first() .and_then(|choice| choice.message.content.as_ref()) - .ok_or(AppError::LLMParsing( - "No content found in LLM response".into(), - )) + .ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into()))) .and_then(|content| { serde_json::from_str::(content).map_err(|e| { - AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")) + Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))) }) }) } diff --git a/retrieval-pipeline/src/graph.rs b/retrieval-pipeline/src/graph.rs index edbc200..6b210f7 100644 --- a/retrieval-pipeline/src/graph.rs +++ b/retrieval-pipeline/src/graph.rs 
@@ -20,7 +20,6 @@ use common::storage::{ /// * `entity_id` - ID of the entity to find neighbors for /// * `user_id` - User ID for access control /// * `limit` - Maximum number of neighbors to return - pub async fn find_entities_by_relationship_by_id( db: &SurrealDbClient, entity_id: &str, @@ -113,25 +112,23 @@ pub async fn find_entities_by_relationship_by_id( #[cfg(test)] mod tests { + use anyhow::{self, Context}; use super::*; use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use common::storage::types::knowledge_relationship::KnowledgeRelationship; use uuid::Uuid; #[tokio::test] - async fn test_find_entities_by_relationship_by_id() { - // Setup in-memory database for testing + async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await - .expect("Failed to start in-memory surrealdb"); + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - // Create some test entities let entity_type = KnowledgeEntityType::Document; let user_id = "user123".to_string(); - // Create the central entity we'll query relationships for let central_entity = KnowledgeEntity::new( "central_source".to_string(), "Central Entity".to_string(), @@ -141,7 +138,6 @@ mod tests { user_id.clone(), ); - // Create related entities let related_entity1 = KnowledgeEntity::new( "related_source1".to_string(), "Related Entity 1".to_string(), @@ -160,7 +156,6 @@ mod tests { user_id.clone(), ); - // Create an unrelated entity let unrelated_entity = KnowledgeEntity::new( "unrelated_source".to_string(), "Unrelated Entity".to_string(), @@ -170,32 +165,29 @@ mod tests { user_id.clone(), ); - // Store all entities let central_entity = db .store_item(central_entity.clone()) .await - .expect("Failed to store central entity") - .unwrap(); + .with_context(|| "Failed to store central entity".to_string())? 
+ .ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?; let related_entity1 = db .store_item(related_entity1.clone()) .await - .expect("Failed to store related entity 1") - .unwrap(); + .with_context(|| "Failed to store related entity 1".to_string())? + .ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?; let related_entity2 = db .store_item(related_entity2.clone()) .await - .expect("Failed to store related entity 2") - .unwrap(); + .with_context(|| "Failed to store related entity 2".to_string())? + .ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?; let _unrelated_entity = db .store_item(unrelated_entity.clone()) .await - .expect("Failed to store unrelated entity") - .unwrap(); + .with_context(|| "Failed to store unrelated entity".to_string())? + .ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?; - // Create relationships let source_id = "relationship_source".to_string(); - // Create relationship 1: central -> related1 let relationship1 = KnowledgeRelationship::new( central_entity.id.clone(), related_entity1.id.clone(), @@ -204,7 +196,6 @@ mod tests { "references".to_string(), ); - // Create relationship 2: central -> related2 let relationship2 = KnowledgeRelationship::new( central_entity.id.clone(), related_entity2.id.clone(), @@ -213,26 +204,25 @@ mod tests { "contains".to_string(), ); - // Store relationships relationship1 .store_relationship(&db) .await - .expect("Failed to store relationship 1"); + .with_context(|| "Failed to store relationship 1".to_string())?; relationship2 .store_relationship(&db) .await - .expect("Failed to store relationship 2"); + .with_context(|| "Failed to store relationship 2".to_string())?; - // Test finding entities related to the central entity let related_entities = find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX) .await - .expect("Failed to find entities by relationship"); + .with_context(|| 
"Failed to find entities by relationship".to_string())?; - // Check that we found relationships assert!( related_entities.len() >= 2, "Should find related entities in both directions" ); + + Ok(()) } } diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 9bc2b75..00595e6 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -42,10 +42,14 @@ impl SearchResult { } pub use pipeline::{ - retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig, - RetrievalStrategy, RetrievalTuning, SearchTarget, + retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig, + RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget, }; +// Backward-compatible type aliases for external consumers +pub type PipelineDiagnostics = Diagnostics; +pub type PipelineStageTimings = StageTimings; + // Captures a supporting chunk plus its fused retrieval score for downstream prompts. #[derive(Debug, Clone)] pub struct RetrievedChunk { @@ -61,7 +65,7 @@ pub struct RetrievedEntity { pub chunks: Vec, } -/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text +/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text` #[instrument(skip_all, fields(user_id))] pub async fn retrieve_entities( db_client: &SurrealDbClient, @@ -72,7 +76,7 @@ pub async fn retrieve_entities( config: RetrievalConfig, reranker: Option, ) -> Result { - pipeline::run_pipeline( + let params = pipeline::StrategyParams { db_client, openai_client, embedding_provider, @@ -80,17 +84,16 @@ pub async fn retrieve_entities( user_id, config, reranker, - ) - .await + }; + pipeline::execute(params).await } #[cfg(test)] mod tests { use super::*; + use anyhow::{self}; use async_openai::Client; - use common::storage::indexes::ensure_runtime_indexes; - use common::storage::types::text_chunk::TextChunk; - use pipeline::{RetrievalConfig, 
RetrievalStrategy}; + use common::storage::indexes::ensure_runtime; use uuid::Uuid; fn test_embedding() -> Vec { @@ -105,27 +108,21 @@ mod tests { vec![0.2, 0.8, 0.0] } - async fn setup_test_db() -> SurrealDbClient { + async fn setup_test_db() -> anyhow::Result { let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); + let db = SurrealDbClient::memory(namespace, database).await?; - db.apply_migrations() - .await - .expect("Failed to apply migrations"); + db.apply_migrations().await?; - ensure_runtime_indexes(&db, 3) - .await - .expect("failed to build runtime indexes"); + ensure_runtime(&db, 3).await?; - db + Ok(db) } #[tokio::test] - async fn test_default_strategy_retrieves_chunks() { - let db = setup_test_db().await; + async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "test_user"; let chunk = TextChunk::new( "source_1".into(), @@ -133,39 +130,38 @@ mod tests { user_id.into(), ); - TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) - .await - .expect("Failed to store chunk"); + TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; let openai_client = Client::new(); - let results = pipeline::run_pipeline_with_embedding( - &db, - &openai_client, - None, - test_embedding(), - "Rust concurrency async tasks", + let params = pipeline::StrategyParams { + db_client: &db, + openai_client: &openai_client, + embedding_provider: None, + input_text: "Rust concurrency async tasks", user_id, - RetrievalConfig::default(), - None, - ) - .await - .expect("Default strategy retrieval failed"); + config: RetrievalConfig::default(), + reranker: None, + }; + let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) + .await?; let chunks = match results { StrategyOutput::Chunks(items) => items, - other => 
panic!("expected chunk results, got {:?}", other), + other => anyhow::bail!("expected chunk results, got {other:?}"), }; assert!(!chunks.is_empty(), "Expected at least one retrieval result"); assert!( - chunks[0].chunk.chunk.contains("Tokio"), + chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")), "Expected chunk about Tokio" ); + Ok(()) } #[tokio::test] - async fn test_default_strategy_returns_chunks_from_multiple_sources() { - let db = setup_test_db().await; + async fn test_default_strategy_returns_chunks_from_multiple_sources( + ) -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "multi_source_user"; let primary_chunk = TextChunk::new( @@ -179,30 +175,25 @@ mod tests { user_id.into(), ); - TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db) - .await - .expect("Failed to store primary chunk"); - TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db) - .await - .expect("Failed to store secondary chunk"); + TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?; + TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?; let openai_client = Client::new(); - let results = pipeline::run_pipeline_with_embedding( - &db, - &openai_client, - None, - test_embedding(), - "Rust concurrency async tasks", + let params = pipeline::StrategyParams { + db_client: &db, + openai_client: &openai_client, + embedding_provider: None, + input_text: "Rust concurrency async tasks", user_id, - RetrievalConfig::default(), - None, - ) - .await - .expect("Default strategy retrieval failed"); + config: RetrievalConfig::default(), + reranker: None, + }; + let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) + .await?; let chunks = match results { StrategyOutput::Chunks(items) => items, - other => panic!("expected chunk results, got {:?}", other), + other => anyhow::bail!("expected chunk results, got {other:?}"), }; 
assert!(chunks.len() >= 2, "Expected chunks from multiple sources"); @@ -216,11 +207,12 @@ mod tests { .any(|c| c.chunk.source_id == "secondary_source"), "Should include secondary source chunk" ); + Ok(()) } #[tokio::test] - async fn test_revised_strategy_returns_chunks() { - let db = setup_test_db().await; + async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "chunk_user"; let chunk_one = TextChunk::new( "src_alpha".into(), @@ -233,31 +225,26 @@ mod tests { user_id.into(), ); - TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db) - .await - .expect("Failed to store chunk one"); - TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db) - .await - .expect("Failed to store chunk two"); + TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?; + TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?; let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default); let openai_client = Client::new(); - let results = pipeline::run_pipeline_with_embedding( - &db, - &openai_client, - None, - test_embedding(), - "tokio runtime worker behavior", + let params = pipeline::StrategyParams { + db_client: &db, + openai_client: &openai_client, + embedding_provider: None, + input_text: "tokio runtime worker behavior", user_id, config, - None, - ) - .await - .expect("Revised retrieval failed"); + reranker: None, + }; + let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) + .await?; let chunks = match results { StrategyOutput::Chunks(items) => items, - other => panic!("expected chunk output, got {:?}", other), + other => anyhow::bail!("expected chunk results, got {other:?}"), }; assert!( @@ -270,11 +257,12 @@ mod tests { .any(|entry| entry.chunk.chunk.contains("Tokio")), "Chunk results should contain relevant snippets" ); + Ok(()) } #[tokio::test] - async 
fn test_search_strategy_returns_search_result() { - let db = setup_test_db().await; + async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> { + let db = setup_test_db().await?; let user_id = "search_user"; let chunk = TextChunk::new( "search_src".into(), @@ -282,33 +270,24 @@ mod tests { user_id.into(), ); - TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) - .await - .expect("Failed to store chunk"); + TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both); let openai_client = Client::new(); - let results = pipeline::run_pipeline_with_embedding( - &db, - &openai_client, - None, - test_embedding(), - "async rust programming", + let params = pipeline::StrategyParams { + db_client: &db, + openai_client: &openai_client, + embedding_provider: None, + input_text: "async rust programming", user_id, config, - None, - ) - .await - .expect("Search strategy retrieval failed"); + reranker: None, + }; + let results = pipeline::run_pipeline_with_embedding(params, test_embedding()) + .await?; - assert!( - matches!(results, StrategyOutput::Search(_)), - "expected Search output, got {:?}", - results - ); - let search_result = match results { - StrategyOutput::Search(sr) => sr, - _ => unreachable!(), + let StrategyOutput::Search(search_result) = results else { + anyhow::bail!("expected Search output"); }; // Should return chunks (entities may be empty if none stored) @@ -323,5 +302,6 @@ mod tests { .any(|c| c.chunk.chunk.contains("Tokio")), "Search results should contain relevant chunks" ); + Ok(()) } } diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index 12d3320..733c2d5 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -1,12 +1,13 @@ -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, 
Serializer}; use std::fmt; use crate::scoring::FusionWeights; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "snake_case")] pub enum RetrievalStrategy { /// Primary hybrid chunk retrieval for search/chat (formerly Revised) + #[default] Default, /// Entity retrieval for suggesting relationships when creating manual entities RelationshipSuggestion, @@ -29,12 +30,6 @@ pub enum SearchTarget { Both, } -impl Default for RetrievalStrategy { - fn default() -> Self { - Self::Default - } -} - impl std::str::FromStr for RetrievalStrategy { type Err = String; @@ -70,6 +65,91 @@ impl fmt::Display for RetrievalStrategy { } } +/// Two-variant flag that serializes as a bool for backward compatibility. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum BoolFlag { + #[default] + Disabled, + Enabled, +} + +impl BoolFlag { + pub const fn as_bool(self) -> bool { + matches!(self, BoolFlag::Enabled) + } +} + +impl From for BoolFlag { + fn from(value: bool) -> Self { + if value { + BoolFlag::Enabled + } else { + BoolFlag::Disabled + } + } +} + +impl Serialize for BoolFlag { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_bool(self.as_bool()) + } +} + +impl<'de> Deserialize<'de> for BoolFlag { + fn deserialize>(deserializer: D) -> Result { + bool::deserialize(deserializer).map(|b| { + if b { + BoolFlag::Enabled + } else { + BoolFlag::Disabled + } + }) + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct RetrievalTuningFlags { + pub rerank_scores_only: BoolFlag, + pub normalize_vector_scores: BoolFlag, + pub normalize_fts_scores: BoolFlag, + pub chunk_rrf_use_vector: BoolFlag, + pub chunk_rrf_use_fts: BoolFlag, +} + +impl RetrievalTuningFlags { + pub const fn rerank_scores_only(&self) -> bool { + self.rerank_scores_only.as_bool() + } + + pub const fn normalize_vector_scores(&self) -> bool { + 
self.normalize_vector_scores.as_bool() + } + + pub const fn normalize_fts_scores(&self) -> bool { + self.normalize_fts_scores.as_bool() + } + + pub const fn chunk_rrf_use_vector(&self) -> bool { + self.chunk_rrf_use_vector.as_bool() + } + + pub const fn chunk_rrf_use_fts(&self) -> bool { + self.chunk_rrf_use_fts.as_bool() + } +} + +impl Default for RetrievalTuningFlags { + fn default() -> Self { + Self { + rerank_scores_only: BoolFlag::Disabled, + normalize_vector_scores: BoolFlag::Disabled, + normalize_fts_scores: BoolFlag::Enabled, + chunk_rrf_use_vector: BoolFlag::Enabled, + chunk_rrf_use_fts: BoolFlag::Enabled, + } + } +} + /// Tunable parameters that govern each retrieval stage. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RetrievalTuning { @@ -89,15 +169,11 @@ pub struct RetrievalTuning { pub graph_seed_min_score: f32, pub graph_vector_inheritance: f32, pub rerank_blend_weight: f32, - pub rerank_scores_only: bool, + pub flags: RetrievalTuningFlags, pub rerank_keep_top: usize, pub chunk_result_cap: usize, /// Optional fusion weights for hybrid search. If None, uses default weights. pub fusion_weights: Option, - /// Normalize vector similarity scores before fusion (default: true) - pub normalize_vector_scores: bool, - /// Normalize FTS (BM25) scores before fusion (default: true) - pub normalize_fts_scores: bool, /// Reciprocal rank fusion k value for chunk merging in Revised strategy. #[serde(default = "default_chunk_rrf_k")] pub chunk_rrf_k: f32, @@ -107,12 +183,6 @@ pub struct RetrievalTuning { /// Weight applied to chunk FTS ranks in RRF. #[serde(default = "default_chunk_rrf_fts_weight")] pub chunk_rrf_fts_weight: f32, - /// Whether to include vector rankings in RRF. - #[serde(default = "default_chunk_rrf_use_vector")] - pub chunk_rrf_use_vector: bool, - /// Whether to include chunk FTS rankings in RRF. 
- #[serde(default = "default_chunk_rrf_use_fts")] - pub chunk_rrf_use_fts: bool, } impl Default for RetrievalTuning { @@ -134,26 +204,19 @@ impl Default for RetrievalTuning { graph_seed_min_score: 0.4, graph_vector_inheritance: 0.6, rerank_blend_weight: 0.65, - rerank_scores_only: false, + flags: RetrievalTuningFlags::default(), rerank_keep_top: 8, chunk_result_cap: 5, fusion_weights: None, - // Vector scores (cosine similarity) are already in [0,1] range - // Normalization only helps when there's significant variation - normalize_vector_scores: false, - // FTS scores (BM25) are unbounded, normalization helps more - normalize_fts_scores: true, chunk_rrf_k: default_chunk_rrf_k(), chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(), chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(), - chunk_rrf_use_vector: default_chunk_rrf_use_vector(), - chunk_rrf_use_fts: default_chunk_rrf_use_fts(), } } } /// Wrapper containing tuning plus future flags for per-request overrides. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct RetrievalConfig { pub strategy: RetrievalStrategy, pub tuning: RetrievalTuning, @@ -211,16 +274,6 @@ impl RetrievalConfig { } } -impl Default for RetrievalConfig { - fn default() -> Self { - Self { - strategy: RetrievalStrategy::default(), - tuning: RetrievalTuning::default(), - search_target: SearchTarget::default(), - } - } -} - const fn default_chunk_rrf_k() -> f32 { 60.0 } @@ -233,10 +286,4 @@ const fn default_chunk_rrf_fts_weight() -> f32 { 1.0 } -const fn default_chunk_rrf_use_vector() -> bool { - true -} -const fn default_chunk_rrf_use_fts() -> bool { - true -} diff --git a/retrieval-pipeline/src/pipeline/diagnostics.rs b/retrieval-pipeline/src/pipeline/diagnostics.rs index 67c14d2..8fd3495 100644 --- a/retrieval-pipeline/src/pipeline/diagnostics.rs +++ b/retrieval-pipeline/src/pipeline/diagnostics.rs @@ -2,7 +2,7 @@ use serde::Serialize; /// Captures instrumentation for each hybrid retrieval stage when diagnostics are 
enabled. #[derive(Debug, Clone, Default, Serialize)] -pub struct PipelineDiagnostics { +pub struct Diagnostics { pub collect_candidates: Option, pub enrich_chunks_from_entities: Option, pub assemble: Option, diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index 1be08c4..738dd81 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -3,10 +3,11 @@ mod diagnostics; mod stages; mod strategies; -pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning, SearchTarget}; +pub use config::{ + RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget, +}; pub use diagnostics::{ - AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, - PipelineDiagnostics, + AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics, }; use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput}; @@ -37,13 +38,13 @@ pub enum StageKind { // Pipeline stage trait #[async_trait] -pub trait PipelineStage: Send + Sync { +pub trait Stage: Send + Sync { fn kind(&self) -> StageKind; async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>; } // Type alias for boxed stages -pub type BoxedStage = Box; +pub type BoxedStage = Box; // Strategy driver trait #[async_trait] @@ -51,16 +52,16 @@ pub trait StrategyDriver: Send + Sync { type Output; fn stages(&self) -> Vec; - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result; + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result>; } // Pipeline stage timings tracker #[derive(Debug, Default, Clone)] -pub struct PipelineStageTimings { +pub struct StageTimings { timings: Vec<(StageKind, Duration)>, } -impl PipelineStageTimings { +impl StageTimings { pub fn record(&mut self, kind: StageKind, duration: Duration) { self.timings.push((kind, duration)); } @@ -74,8 +75,7 @@ impl PipelineStageTimings { self.timings .iter() 
.find(|(k, _)| *k == kind) - .map(|(_, d)| d.as_millis()) - .unwrap_or(0) + .map_or(0, |(_, d)| d.as_millis()) } pub fn embed_ms(&self) -> u128 { @@ -103,228 +103,100 @@ impl PipelineStageTimings { } } -pub struct PipelineRunOutput { +pub struct RunOutput { pub results: T, - pub diagnostics: Option, - pub stage_timings: PipelineStageTimings, + pub diagnostics: Option, + pub stage_timings: StageTimings, } -pub async fn run_pipeline( - db_client: &SurrealDbClient, - openai_client: &Client, - embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result { - let input_chars = input_text.chars().count(); - let input_preview: String = input_text.chars().take(120).collect(); +pub async fn execute(params: StrategyParams<'_>) -> Result { + let input_chars = params.input_text.chars().count(); + let input_preview: String = params.input_text.chars().take(120).collect(); let input_preview_clean = input_preview.replace('\n', " "); let preview_len = input_preview_clean.chars().count(); info!( - %user_id, + user_id = %params.user_id, input_chars, preview_truncated = input_chars > preview_len, preview = %input_preview_clean, - strategy = %config.strategy, + strategy = %params.config.strategy, "Starting retrieval pipeline" ); - match config.strategy { + let strategy = params.config.strategy; + let search_target = params.config.search_target; + + match strategy { RetrievalStrategy::Default => { let driver = DefaultStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, None, false).await?; Ok(StrategyOutput::Chunks(run.results)) } RetrievalStrategy::RelationshipSuggestion => { let driver = RelationshipSuggestionDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - 
embedding_provider, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, None, false).await?; Ok(StrategyOutput::Entities(run.results)) } RetrievalStrategy::Ingestion => { let driver = IngestionDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, None, false).await?; Ok(StrategyOutput::Entities(run.results)) } RetrievalStrategy::Search => { - let search_target = config.search_target; let driver = SearchStrategyDriver::new(search_target); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, None, false).await?; Ok(StrategyOutput::Search(run.results)) } } } pub async fn run_pipeline_with_embedding( - db_client: &SurrealDbClient, - openai_client: &Client, - embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, + params: StrategyParams<'_>, query_embedding: Vec, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, ) -> Result { - match config.strategy { + let strategy = params.config.strategy; + let search_target = params.config.search_target; + + match strategy { RetrievalStrategy::Default => { let driver = DefaultStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, Some(query_embedding), false).await?; Ok(StrategyOutput::Chunks(run.results)) } RetrievalStrategy::RelationshipSuggestion => { let driver = RelationshipSuggestionDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - 
embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, Some(query_embedding), false).await?; Ok(StrategyOutput::Entities(run.results)) } RetrievalStrategy::Ingestion => { let driver = IngestionDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, Some(query_embedding), false).await?; Ok(StrategyOutput::Entities(run.results)) } RetrievalStrategy::Search => { - let search_target = config.search_target; let driver = SearchStrategyDriver::new(search_target); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; + let run = execute_strategy(driver, params, Some(query_embedding), false).await?; Ok(StrategyOutput::Search(run.results)) } } } -// Note: The metrics/diagnostics variants would follow the same pattern, -// but for brevity I'm only updating the main ones used by callers. -// If metrics/diagnostics are needed for non-chat strategies, they should be updated too. -// For now, I'll update them to support at least Initial/Revised as before. 
- pub async fn run_pipeline_with_embedding_with_metrics( - db_client: &SurrealDbClient, - openai_client: &Client, - embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, + params: StrategyParams<'_>, query_embedding: Vec, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result, AppError> { - match config.strategy { +) -> Result, AppError> { + let strategy = params.config.strategy; + + match strategy { RetrievalStrategy::Default => { let driver = DefaultStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(PipelineRunOutput { + let run = execute_strategy(driver, params, Some(query_embedding), false).await?; + Ok(RunOutput { results: StrategyOutput::Chunks(run.results), diagnostics: run.diagnostics, stage_timings: run.stage_timings, }) } - // Fallback for others if needed, or error. For now assuming metrics mainly for chat. 
_ => Err(AppError::InternalError( "Metrics not supported for this strategy".into(), )), @@ -332,32 +204,16 @@ pub async fn run_pipeline_with_embedding_with_metrics( } pub async fn run_pipeline_with_embedding_with_diagnostics( - db_client: &SurrealDbClient, - openai_client: &Client, - embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, + params: StrategyParams<'_>, query_embedding: Vec, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, -) -> Result, AppError> { - match config.strategy { +) -> Result, AppError> { + let strategy = params.config.strategy; + + match strategy { RetrievalStrategy::Default => { let driver = DefaultStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - true, - ) - .await?; - Ok(PipelineRunOutput { + let run = execute_strategy(driver, params, Some(query_embedding), true).await?; + Ok(RunOutput { results: StrategyOutput::Chunks(run.results), diagnostics: run.diagnostics, stage_timings: run.stage_timings, @@ -391,38 +247,25 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V .collect::>()) } +pub struct StrategyParams<'a> { + pub db_client: &'a SurrealDbClient, + pub openai_client: &'a Client, + pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>, + pub input_text: &'a str, + pub user_id: &'a str, + pub config: RetrievalConfig, + pub reranker: Option, +} + async fn execute_strategy( driver: D, - db_client: &SurrealDbClient, - openai_client: &Client, - embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, + params: StrategyParams<'_>, query_embedding: Option>, - input_text: &str, - user_id: &str, - config: RetrievalConfig, - reranker: Option, capture_diagnostics: bool, -) -> Result, AppError> { +) -> Result, AppError> { let ctx = match query_embedding { - Some(embedding) => 
PipelineContext::with_embedding( - db_client, - openai_client, - embedding_provider, - embedding, - input_text.to_owned(), - user_id.to_owned(), - config, - reranker, - ), - None => PipelineContext::new( - db_client, - openai_client, - embedding_provider, - input_text.to_owned(), - user_id.to_owned(), - config, - reranker, - ), + Some(embedding) => PipelineContext::with_embedding(params, embedding), + None => PipelineContext::new(params), }; run_with_driver(driver, ctx, capture_diagnostics).await @@ -432,7 +275,7 @@ async fn run_with_driver( driver: D, mut ctx: PipelineContext<'_>, capture_diagnostics: bool, -) -> Result, AppError> { +) -> Result, AppError> { if capture_diagnostics { ctx.enable_diagnostics(); } @@ -445,9 +288,9 @@ async fn run_with_driver( let diagnostics = ctx.take_diagnostics(); let stage_timings = ctx.take_stage_timings(); - let results = driver.finalize(&mut ctx)?; + let results = driver.finalize(&mut ctx).map_err(|e| *e)?; - Ok(PipelineRunOutput { + Ok(RunOutput { results, diagnostics, stage_timings, diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs index e7db7be..6926666 100644 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -27,9 +27,9 @@ use super::{ config::{RetrievalConfig, RetrievalTuning}, diagnostics::{ AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, - PipelineDiagnostics, + Diagnostics, }, - PipelineStage, PipelineStageTimings, StageKind, + StageTimings, Stage, StageKind, StrategyParams, }; pub struct PipelineContext<'a> { @@ -45,76 +45,51 @@ pub struct PipelineContext<'a> { pub chunk_values: Vec>, pub revised_chunk_values: Vec>, pub reranker: Option, - pub diagnostics: Option, + pub diagnostics: Option, pub entity_results: Vec, pub chunk_results: Vec, - stage_timings: PipelineStageTimings, + stage_timings: StageTimings, } impl<'a> PipelineContext<'a> { - pub fn new( - db_client: &'a 
SurrealDbClient, - openai_client: &'a Client, - embedding_provider: Option<&'a EmbeddingProvider>, - input_text: String, - user_id: String, - config: RetrievalConfig, - reranker: Option, - ) -> Self { + pub fn new(params: StrategyParams<'a>) -> Self { Self { - db_client, - openai_client, - embedding_provider, - input_text, - user_id, - config, + db_client: params.db_client, + openai_client: params.openai_client, + embedding_provider: params.embedding_provider, + input_text: params.input_text.to_owned(), + user_id: params.user_id.to_owned(), + config: params.config, query_embedding: None, entity_candidates: HashMap::new(), filtered_entities: Vec::new(), chunk_values: Vec::new(), revised_chunk_values: Vec::new(), - reranker, + reranker: params.reranker, diagnostics: None, entity_results: Vec::new(), chunk_results: Vec::new(), - stage_timings: PipelineStageTimings::default(), + stage_timings: StageTimings::default(), } } - pub fn with_embedding( - db_client: &'a SurrealDbClient, - openai_client: &'a Client, - embedding_provider: Option<&'a EmbeddingProvider>, - query_embedding: Vec, - input_text: String, - user_id: String, - config: RetrievalConfig, - reranker: Option, - ) -> Self { - let mut ctx = Self::new( - db_client, - openai_client, - embedding_provider, - input_text, - user_id, - config, - reranker, - ); + pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec) -> Self { + let mut ctx = Self::new(params); ctx.query_embedding = Some(query_embedding); ctx } - fn ensure_embedding(&self) -> Result<&Vec, AppError> { + fn ensure_embedding(&self) -> Result<&Vec, Box> { self.query_embedding.as_ref().ok_or_else(|| { - AppError::InternalError( + Box::new(AppError::InternalError( "query embedding missing before candidate collection".to_string(), - ) + )) }) } pub fn enable_diagnostics(&mut self) { if self.diagnostics.is_none() { - self.diagnostics = Some(PipelineDiagnostics::default()); + self.diagnostics = Some(Diagnostics::default()); } } @@ -140,11 
+115,11 @@ impl<'a> PipelineContext<'a> { } } - pub fn take_diagnostics(&mut self) -> Option { + pub fn take_diagnostics(&mut self) -> Option { self.diagnostics.take() } - pub fn take_stage_timings(&mut self) -> PipelineStageTimings { + pub fn take_stage_timings(&mut self) -> StageTimings { std::mem::take(&mut self.stage_timings) } @@ -165,7 +140,7 @@ impl<'a> PipelineContext<'a> { pub struct EmbedStage; #[async_trait] -impl PipelineStage for EmbedStage { +impl Stage for EmbedStage { fn kind(&self) -> StageKind { StageKind::Embed } @@ -179,7 +154,7 @@ impl PipelineStage for EmbedStage { pub struct CollectCandidatesStage; #[async_trait] -impl PipelineStage for CollectCandidatesStage { +impl Stage for CollectCandidatesStage { fn kind(&self) -> StageKind { StageKind::CollectCandidates } @@ -193,7 +168,7 @@ impl PipelineStage for CollectCandidatesStage { pub struct GraphExpansionStage; #[async_trait] -impl PipelineStage for GraphExpansionStage { +impl Stage for GraphExpansionStage { fn kind(&self) -> StageKind { StageKind::GraphExpansion } @@ -207,7 +182,7 @@ impl PipelineStage for GraphExpansionStage { pub struct RerankStage; #[async_trait] -impl PipelineStage for RerankStage { +impl Stage for RerankStage { fn kind(&self) -> StageKind { StageKind::Rerank } @@ -221,7 +196,7 @@ impl PipelineStage for RerankStage { pub struct AssembleEntitiesStage; #[async_trait] -impl PipelineStage for AssembleEntitiesStage { +impl Stage for AssembleEntitiesStage { fn kind(&self) -> StageKind { StageKind::Assemble } @@ -235,7 +210,7 @@ impl PipelineStage for AssembleEntitiesStage { pub struct ChunkVectorStage; #[async_trait] -impl PipelineStage for ChunkVectorStage { +impl Stage for ChunkVectorStage { fn kind(&self) -> StageKind { StageKind::CollectCandidates } @@ -249,7 +224,7 @@ impl PipelineStage for ChunkVectorStage { pub struct ChunkRerankStage; #[async_trait] -impl PipelineStage for ChunkRerankStage { +impl Stage for ChunkRerankStage { fn kind(&self) -> StageKind { 
StageKind::Rerank } @@ -263,7 +238,7 @@ impl PipelineStage for ChunkRerankStage { pub struct ChunkAssembleStage; #[async_trait] -impl PipelineStage for ChunkAssembleStage { +impl Stage for ChunkAssembleStage { fn kind(&self) -> StageKind { StageKind::Assemble } @@ -283,8 +258,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { let embedding = if let Some(provider) = ctx.embedding_provider { provider.embed(&ctx.input_text).await.map_err(|e| { AppError::InternalError(format!( - "Failed to generate embedding with provider: {}", - e + "Failed to generate embedding with provider: {e}", )) })? } else { @@ -299,7 +273,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { #[instrument(level = "trace", skip_all)] pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Collecting initial candidates via vector and FTS search"); - let embedding = ctx.ensure_embedding()?.clone(); + let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); let tuning = &ctx.config.tuning; let weights = FusionWeights::default(); @@ -487,11 +461,11 @@ pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { #[instrument(level = "trace", skip_all)] pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Collecting vector chunk candidates for revised strategy"); - let embedding = ctx.ensure_embedding()?.clone(); + let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); let tuning = &ctx.config.tuning; let fts_take = tuning.chunk_fts_take; let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text); - let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty(); + let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty(); let (vector_rows, fts_rows) = tokio::try_join!( TextChunk::vector_search( @@ -532,8 +506,8 @@ pub async fn collect_vector_chunks(ctx: &mut 
PipelineContext<'_>) -> Result<(), k: tuning.chunk_rrf_k, vector_weight: tuning.chunk_rrf_vector_weight, fts_weight, - use_vector: tuning.chunk_rrf_use_vector, - use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0, + use_vector: tuning.flags.chunk_rrf_use_vector(), + use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0, }; let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config); @@ -715,7 +689,7 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { let mut per_entity_count = 0; for candidate in candidates.iter() { if let Some(trace) = entity_trace.as_mut() { - trace.inspected_candidates += 1; + trace.inspected_candidates = trace.inspected_candidates.saturating_add(1); } if per_entity_count >= tuning.max_chunks_per_entity { break; @@ -723,17 +697,17 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { let estimated_tokens = estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token); if estimated_tokens > token_budget_remaining { - chunks_skipped_due_budget += 1; + chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1); if let Some(trace) = entity_trace.as_mut() { - trace.skipped_due_budget += 1; + trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1); } continue; } token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); - tokens_spent += estimated_tokens; - per_entity_count += 1; - chunks_selected += 1; + tokens_spent = tokens_spent.saturating_add(estimated_tokens); + per_entity_count = per_entity_count.saturating_add(1); + chunks_selected = chunks_selected.saturating_add(1); selected_chunks.push(RetrievedChunk { chunk: candidate.item.clone(), @@ -780,14 +754,14 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { const SCORE_SAMPLE_LIMIT: usize = 8; -fn sample_scores(items: &[Scored], mut extractor: F) -> Vec +fn sample_scores(items: &[Scored], extractor: F) -> Vec where F: FnMut(&Scored) 
-> f32, { items .iter() .take(SCORE_SAMPLE_LIMIT) - .map(|item| extractor(item)) + .map(extractor) .collect() } @@ -912,7 +886,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec = results.iter().map(|r| r.score).collect(); let normalized_scores = min_max_normalize(&raw_scores); - let use_only = ctx.config.tuning.rerank_scores_only; + let use_only = ctx.config.tuning.flags.rerank_scores_only(); let blend = if use_only { 1.0 } else { @@ -942,11 +916,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec = results.iter().map(|r| r.score).collect(); let normalized_scores = min_max_normalize(&raw_scores); - let use_only = tuning.rerank_scores_only; + let use_only = tuning.flags.rerank_scores_only(); let blend = if use_only { 1.0 } else { @@ -1001,11 +971,7 @@ fn apply_chunk_rerank_results( } } - for slot in remaining.into_iter() { - if let Some(candidate) = slot { - reranked.push(candidate); - } - } + reranked.extend(remaining.into_iter().flatten()); let keep_top = tuning.rerank_keep_top; if keep_top > 0 && reranked.len() > keep_top { @@ -1017,7 +983,7 @@ fn apply_chunk_rerank_results( fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize { let chars = text.chars().count().max(1); - (chars / avg_chars_per_token).max(1) + chars.checked_div(avg_chars_per_token).map_or(1, |v| v.max(1)) } fn rank_chunks_by_combined_score( @@ -1053,13 +1019,20 @@ fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 { return 0.0; } let lower = haystack.to_ascii_lowercase(); - let mut matches = 0usize; + let mut matches: u32 = 0; for term in terms { if lower.contains(term) { - matches += 1; + matches = matches.saturating_add(1); } } - (matches as f32) / (terms.len() as f32) + let total = u32::try_from(terms.len()).unwrap_or(u32::MAX); + if total == 0 { + return 0.0; + } + let num = matches.min(total); + let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX); + let den_f32 = 
u16::try_from(total).map(f32::from).unwrap_or(f32::MAX); + num_f32 / den_f32 } #[derive(Clone)] diff --git a/retrieval-pipeline/src/pipeline/strategies.rs b/retrieval-pipeline/src/pipeline/strategies.rs index 28df122..0b1e6bb 100644 --- a/retrieval-pipeline/src/pipeline/strategies.rs +++ b/retrieval-pipeline/src/pipeline/strategies.rs @@ -28,7 +28,7 @@ impl StrategyDriver for DefaultStrategyDriver { ] } - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { Ok(ctx.take_chunk_results()) } } @@ -55,7 +55,7 @@ impl StrategyDriver for RelationshipSuggestionDriver { ] } - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { Ok(ctx.take_entity_results()) } } @@ -82,7 +82,7 @@ impl StrategyDriver for IngestionDriver { ] } - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { Ok(ctx.take_entity_results()) } } @@ -134,7 +134,7 @@ impl StrategyDriver for SearchStrategyDriver { } } - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { + fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result> { let chunks = match self.target { SearchTarget::EntitiesOnly => Vec::new(), _ => ctx.take_chunk_results(), diff --git a/retrieval-pipeline/src/reranking/mod.rs b/retrieval-pipeline/src/reranking/mod.rs index c919631..705af40 100644 --- a/retrieval-pipeline/src/reranking/mod.rs +++ b/retrieval-pipeline/src/reranking/mod.rs @@ -17,7 +17,7 @@ static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0); fn pick_engine_index(pool_len: usize) -> usize { let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed); - n % pool_len + n.checked_rem(pool_len).unwrap_or(0) } pub struct RerankerPool { @@ -28,30 +28,30 @@ pub struct RerankerPool { impl RerankerPool { /// Build the pool at startup. /// `pool_size` controls max parallel reranks. 
- pub fn new(pool_size: usize) -> Result, AppError> { - Self::new_with_options( - pool_size, - RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn), - ) + pub fn new(pool_size: usize) -> Result, Box> { + let init_options = + RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn); + Self::new_with_options(pool_size, &init_options) } fn new_with_options( pool_size: usize, - init_options: RerankInitOptions, - ) -> Result, AppError> { + init_options: &RerankInitOptions, + ) -> Result, Box> { if pool_size == 0 { - return Err(AppError::Validation( + return Err(Box::new(AppError::Validation( "RERANKING_POOL_SIZE must be greater than zero".to_string(), - )); + ))); } - fs::create_dir_all(&init_options.cache_dir)?; + fs::create_dir_all(&init_options.cache_dir) + .map_err(|e| Box::new(AppError::from(e)))?; let mut engines = Vec::with_capacity(pool_size); for x in 0..pool_size { debug!("Creating reranking engine: {x}"); let model = TextRerank::try_new(init_options.clone()) - .map_err(|e| AppError::InternalError(e.to_string()))?; + .map_err(|e| Box::new(AppError::InternalError(e.to_string())))?; engines.push(Arc::new(Mutex::new(model))); } @@ -62,7 +62,7 @@ impl RerankerPool { } /// Initialize a pool using application configuration. - pub fn maybe_from_config(config: &AppConfig) -> Result>, AppError> { + pub fn maybe_from_config(config: &AppConfig) -> Result>, Box> { if !config.reranking_enabled { return Ok(None); } @@ -70,30 +70,28 @@ impl RerankerPool { let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size); let init_options = build_rerank_init_options(config)?; - Self::new_with_options(pool_size, init_options).map(Some) + Self::new_with_options(pool_size, &init_options).map(Some) } /// Check out capacity + pick an engine. - /// This returns a lease that can perform rerank(). - pub async fn checkout(self: &Arc) -> RerankerLease { + /// This returns a lease that can perform `rerank()`. 
+ pub async fn checkout(self: &Arc) -> Option { // Acquire a permit. This enforces backpressure. - let permit = self - .semaphore - .clone() + let permit = Arc::clone(&self.semaphore) .acquire_owned() .await - .expect("semaphore closed"); + .ok()?; // Pick an engine. // This is naive: just pick based on a simple modulo counter. // We use an atomic counter to avoid always choosing index 0. let idx = pick_engine_index(self.engines.len()); - let engine = self.engines[idx].clone(); + let engine = self.engines.get(idx).map(Arc::clone)?; - RerankerLease { + Some(RerankerLease { _permit: permit, engine, - } + }) } } @@ -111,7 +109,7 @@ fn is_truthy(value: &str) -> bool { ) } -fn build_rerank_init_options(config: &AppConfig) -> Result { +fn build_rerank_init_options(config: &AppConfig) -> Result> { let mut options = RerankInitOptions::default(); let cache_dir = config @@ -125,7 +123,7 @@ fn build_rerank_init_options(config: &AppConfig) -> Result Option { env::var(key).ok().map(|value| is_truthy(&value)) } -/// Active lease on a single TextRerank instance. +/// Active lease on a single `TextRerank` instance. pub struct RerankerLease { // When this drops the semaphore permit is released. 
_permit: OwnedSemaphorePermit, diff --git a/retrieval-pipeline/src/scoring.rs b/retrieval-pipeline/src/scoring.rs index 8fce2e7..5b463e8 100644 --- a/retrieval-pipeline/src/scoring.rs +++ b/retrieval-pipeline/src/scoring.rs @@ -28,16 +28,19 @@ impl Scored { } } + #[must_use] pub const fn with_vector_score(mut self, score: f32) -> Self { self.scores.vector = Some(score); self } + #[must_use] pub const fn with_fts_score(mut self, score: f32) -> Self { self.scores.fts = Some(score); self } + #[must_use] pub const fn with_graph_score(mut self, score: f32) -> Self { self.scores.graph = Some(score); self @@ -168,7 +171,7 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 { if scores.vector.is_some() && scores.fts.is_some() { // Multiplicative boost: multiply by (1 + bonus) to scale with the base score // This ensures high-scoring golden chunks get boosted more than low-scoring ones - fused = fused * (1.0 + weights.multi_bonus); + fused *= 1.0 + weights.multi_bonus; } else { // For other multi-signal combinations (e.g., vector + graph), use additive bonus fused += weights.multi_bonus; @@ -178,8 +181,8 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 { clamp_unit(fused) } -pub fn merge_scored_by_id( - target: &mut std::collections::HashMap>, +pub fn merge_scored_by_id( + target: &mut std::collections::HashMap, S>, incoming: Vec>, ) where T: StoredObject + Clone, @@ -263,7 +266,10 @@ where } } entry.item = candidate.item; - entry.fused += vector_weight / (k + rank as f32 + 1.0); + let rank_f32: f32 = u16::try_from(rank) + .map(f32::from) + .unwrap_or(f32::MAX); + entry.fused += vector_weight / (k + rank_f32 + 1.0); } } @@ -290,7 +296,10 @@ where } } entry.item = candidate.item; - entry.fused += fts_weight / (k + rank as f32 + 1.0); + let rank_f32: f32 = u16::try_from(rank) + .map(f32::from) + .unwrap_or(f32::MAX); + entry.fused += fts_weight / (k + rank_f32 + 1.0); } }