clippy: adhere to pedantic clippy, uniform test error handling

This commit is contained in:
Per Stark
2026-05-26 11:43:45 +02:00
parent 6a5d631287
commit 000852c94c
68 changed files with 2468 additions and 2547 deletions
+2 -2
View File
@@ -106,11 +106,11 @@ missing_errors_doc = "allow"
missing_panics_doc = "warn" missing_panics_doc = "warn"
module_name_repetitions = "warn" module_name_repetitions = "warn"
wildcard_dependencies = "warn" wildcard_dependencies = "warn"
missing_docs_in_private_items = "warn" missing_docs_in_private_items = "allow"
# Allow noisy lints that don't add value for this project # Allow noisy lints that don't add value for this project
needless_raw_string_hashes = "allow" needless_raw_string_hashes = "allow"
multiple_bound_locations = "allow" multiple_bound_locations = "allow"
cargo_common_metadata = "allow" cargo_common_metadata = "allow"
multiple-crate-versions = "allow" multiple-crate-versions = "allow"
module_name_repetition = "allow"
+1 -1
View File
@@ -31,7 +31,7 @@ impl ApiState {
surreal_db_client.apply_migrations().await?; surreal_db_client.apply_migrations().await?;
let app_state = Self { let app_state = Self {
db: surreal_db_client.clone(), db: Arc::clone(&surreal_db_client),
config: config.clone(), config: config.clone(),
storage, storage,
}; };
+23 -22
View File
@@ -8,7 +8,7 @@ use serde::Serialize;
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug, Serialize, Clone)] #[derive(Error, Debug, Serialize, Clone)]
pub enum ApiError { pub enum ApiErr {
#[error("Internal server error")] #[error("Internal server error")]
InternalError(String), InternalError(String),
@@ -25,7 +25,7 @@ pub enum ApiError {
PayloadTooLarge(String), PayloadTooLarge(String),
} }
impl From<AppError> for ApiError { impl From<AppError> for ApiErr {
fn from(err: AppError) -> Self { fn from(err: AppError) -> Self {
match err { match err {
AppError::Database(_) | AppError::OpenAI(_) => { AppError::Database(_) | AppError::OpenAI(_) => {
@@ -39,7 +39,7 @@ impl From<AppError> for ApiError {
} }
} }
} }
impl IntoResponse for ApiError { impl IntoResponse for ApiErr {
fn into_response(self) -> Response { fn into_response(self) -> Response {
let (status, error_response) = match self { let (status, error_response) = match self {
Self::InternalError(message) => ( Self::InternalError(message) => (
@@ -94,6 +94,7 @@ mod tests {
use super::*; use super::*;
use common::error::AppError; use common::error::AppError;
use std::fmt::Debug; use std::fmt::Debug;
use std::io;
// Helper to check status code // Helper to check status code
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) { fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
@@ -105,42 +106,42 @@ mod tests {
fn test_app_error_to_api_error_conversion() { fn test_app_error_to_api_error_conversion() {
// Test NotFound error conversion // Test NotFound error conversion
let not_found = AppError::NotFound("resource not found".to_string()); let not_found = AppError::NotFound("resource not found".to_string());
let api_error = ApiError::from(not_found); let api_error = ApiErr::from(not_found);
assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found")); assert!(matches!(api_error, ApiErr::NotFound(msg) if msg == "resource not found"));
// Test Validation error conversion // Test Validation error conversion
let validation = AppError::Validation("invalid input".to_string()); let validation = AppError::Validation("invalid input".to_string());
let api_error = ApiError::from(validation); let api_error = ApiErr::from(validation);
assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input")); assert!(matches!(api_error, ApiErr::ValidationError(msg) if msg == "invalid input"));
// Test Auth error conversion // Test Auth error conversion
let auth = AppError::Auth("unauthorized".to_string()); let auth = AppError::Auth("unauthorized".to_string());
let api_error = ApiError::from(auth); let api_error = ApiErr::from(auth);
assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized")); assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
// Test for internal errors - create a mock error that doesn't require surrealdb // Test for internal errors - create a mock error that doesn't require surrealdb
let internal_error = let internal_error =
AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error")); AppError::Io(io::Error::other("io error"));
let api_error = ApiError::from(internal_error); let api_error = ApiErr::from(internal_error);
assert!(matches!(api_error, ApiError::InternalError(_))); assert!(matches!(api_error, ApiErr::InternalError(_)));
} }
#[test] #[test]
fn test_api_error_response_status_codes() { fn test_api_error_response_status_codes() {
// Test internal error status // Test internal error status
let error = ApiError::InternalError("server error".to_string()); let error = ApiErr::InternalError("server error".to_string());
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR); assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
// Test not found status // Test not found status
let error = ApiError::NotFound("not found".to_string()); let error = ApiErr::NotFound("not found".to_string());
assert_status_code(error, StatusCode::NOT_FOUND); assert_status_code(error, StatusCode::NOT_FOUND);
// Test validation error status // Test validation error status
let error = ApiError::ValidationError("invalid input".to_string()); let error = ApiErr::ValidationError("invalid input".to_string());
assert_status_code(error, StatusCode::BAD_REQUEST); assert_status_code(error, StatusCode::BAD_REQUEST);
// Test unauthorized status // Test unauthorized status
let error = ApiError::Unauthorized("not allowed".to_string()); let error = ApiErr::Unauthorized("not allowed".to_string());
assert_status_code(error, StatusCode::UNAUTHORIZED); assert_status_code(error, StatusCode::UNAUTHORIZED);
// Test payload too large status // Test payload too large status
@@ -153,15 +154,15 @@ mod tests {
fn test_error_messages() { fn test_error_messages() {
// For validation errors // For validation errors
let message = "invalid data format"; let message = "invalid data format";
let error = ApiError::ValidationError(message.to_string()); let error = ApiErr::ValidationError(message.to_string());
// Check that the error itself contains the message // Check that the error itself contains the message
assert_eq!(error.to_string(), format!("Validation error: {}", message)); assert_eq!(error.to_string(), format!("Validation error: {message}"));
// For not found errors // For not found errors
let message = "user not found"; let message = "user not found";
let error = ApiError::NotFound(message.to_string()); let error = ApiErr::NotFound(message.to_string());
assert_eq!(error.to_string(), format!("Not found: {}", message)); assert_eq!(error.to_string(), format!("Not found: {message}"));
} }
// Alternative approach for internal error test // Alternative approach for internal error test
@@ -170,8 +171,8 @@ mod tests {
// Create a sensitive error message // Create a sensitive error message
let sensitive_info = "db password incorrect"; let sensitive_info = "db password incorrect";
// Create ApiError with sensitive info // Create ApiErr with sensitive info
let api_error = ApiError::InternalError(sensitive_info.to_string()); let api_error = ApiErr::InternalError(sensitive_info.to_string());
// Check the error message is correctly set // Check the error message is correctly set
assert_eq!(api_error.to_string(), "Internal server error"); assert_eq!(api_error.to_string(), "Internal server error");
+4 -4
View File
@@ -6,19 +6,19 @@ use axum::{
use common::storage::types::user::User; use common::storage::types::user::User;
use crate::{api_state::ApiState, error::ApiError}; use crate::{api_state::ApiState, error::ApiErr};
pub async fn api_auth( pub async fn api_auth(
State(state): State<ApiState>, State(state): State<ApiState>,
mut request: Request, mut request: Request,
next: Next, next: Next,
) -> Result<Response, ApiError> { ) -> Result<Response, ApiErr> {
let api_key = extract_api_key(&request) let api_key = extract_api_key(&request)
.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?; .ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
let user = User::find_by_api_key(&api_key, &state.db).await?; let user = User::find_by_api_key(&api_key, &state.db).await?;
let user = let user =
user.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?; user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
request.extensions_mut().insert(user); request.extensions_mut().insert(user);
+3 -3
View File
@@ -1,12 +1,12 @@
use axum::{extract::State, response::IntoResponse, Extension, Json}; use axum::{extract::State, response::IntoResponse, Extension, Json};
use common::storage::types::user::User; use common::storage::types::user::User;
use crate::{api_state::ApiState, error::ApiError}; use crate::{api_state::ApiState, error::ApiErr};
pub async fn get_categories( pub async fn list(
State(state): State<ApiState>, State(state): State<ApiState>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, ApiErr> {
let categories = User::get_user_categories(&user.id, &state.db).await?; let categories = User::get_user_categories(&user.id, &state.db).await?;
Ok(Json(categories)) Ok(Json(categories))
+1 -1
View File
@@ -13,7 +13,7 @@ use serde_json::json;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use tracing::info; use tracing::info;
use crate::{api_state::ApiState, error::ApiError}; use crate::{api_state::ApiState, error::ApiErr};
#[derive(Debug, TryFromMultipart)] #[derive(Debug, TryFromMultipart)]
pub struct IngestParams { pub struct IngestParams {
+35 -28
View File
@@ -202,6 +202,7 @@ impl SurrealDbClient {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use crate::stored_object; use crate::stored_object;
use super::*; use super::*;
@@ -212,19 +213,17 @@ mod tests {
}); });
#[tokio::test] #[tokio::test]
async fn test_initialization_and_crud() { async fn test_initialization_and_crud() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Call your initialization
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to initialize schema"); .with_context(|| "Failed to initialize schema".to_string())?;
// Test basic CRUD
let dummy = Dummy { let dummy = Dummy {
id: "abc".to_string(), id: "abc".to_string(),
name: "first".to_string(), name: "first".to_string(),
@@ -232,50 +231,50 @@ mod tests {
updated_at: Utc::now(), updated_at: Utc::now(),
}; };
// Store let stored = db
let stored = db.store_item(dummy.clone()).await.expect("Failed to store"); .store_item(dummy.clone())
.await
.with_context(|| "Failed to store".to_string())?;
assert!(stored.is_some()); assert!(stored.is_some());
// Read
let fetched = db let fetched = db
.get_item::<Dummy>(&dummy.id) .get_item::<Dummy>(&dummy.id)
.await .await
.expect("Failed to fetch"); .with_context(|| "Failed to fetch".to_string())?;
assert_eq!(fetched, Some(dummy.clone())); assert_eq!(fetched, Some(dummy.clone()));
// Read all
let all = db let all = db
.get_all_stored_items::<Dummy>() .get_all_stored_items::<Dummy>()
.await .await
.expect("Failed to fetch all"); .with_context(|| "Failed to fetch all".to_string())?;
assert!(all.contains(&dummy)); assert!(all.contains(&dummy));
// Delete
let deleted = db let deleted = db
.delete_item::<Dummy>(&dummy.id) .delete_item::<Dummy>(&dummy.id)
.await .await
.expect("Failed to delete"); .with_context(|| "Failed to delete".to_string())?;
assert_eq!(deleted, Some(dummy)); assert_eq!(deleted, Some(dummy));
// After delete, should not be present
let fetch_post = db let fetch_post = db
.get_item::<Dummy>("abc") .get_item::<Dummy>("abc")
.await .await
.expect("Failed fetch post delete"); .with_context(|| "Failed fetch post delete".to_string())?;
assert!(fetch_post.is_none()); assert!(fetch_post.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn upsert_item_overwrites_existing_records() { async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to initialize schema"); .with_context(|| "Failed to initialize schema".to_string())?;
let mut dummy = Dummy { let mut dummy = Dummy {
id: "abc".to_string(), id: "abc".to_string(),
@@ -286,17 +285,21 @@ mod tests {
db.store_item(dummy.clone()) db.store_item(dummy.clone())
.await .await
.expect("Failed to store initial record"); .with_context(|| "Failed to store initial record".to_string())?;
dummy.name = "updated".to_string(); dummy.name = "updated".to_string();
let upserted = db let upserted = db
.upsert_item(dummy.clone()) .upsert_item(dummy.clone())
.await .await
.expect("Failed to upsert record"); .with_context(|| "Failed to upsert record".to_string())?;
assert!(upserted.is_some()); assert!(upserted.is_some());
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert"); let fetched: Option<Dummy> = db
assert_eq!(fetched.unwrap().name, "updated"); .get_item(&dummy.id)
.await
.with_context(|| "fetch after upsert".to_string())?;
let fetched = fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?;
assert_eq!(fetched.name, "updated");
let new_record = Dummy { let new_record = Dummy {
id: "def".to_string(), id: "def".to_string(),
@@ -306,25 +309,29 @@ mod tests {
}; };
db.upsert_item(new_record.clone()) db.upsert_item(new_record.clone())
.await .await
.expect("Failed to upsert new record"); .with_context(|| "Failed to upsert new record".to_string())?;
let fetched_new: Option<Dummy> = db let fetched_new: Option<Dummy> = db
.get_item(&new_record.id) .get_item(&new_record.id)
.await .await
.expect("fetch inserted via upsert"); .with_context(|| "fetch inserted via upsert".to_string())?;
assert_eq!(fetched_new, Some(new_record)); assert_eq!(fetched_new, Some(new_record));
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_applying_migrations() { async fn test_applying_migrations() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to build indexes"); .with_context(|| "Failed to build indexes".to_string())?;
Ok(())
} }
} }
+54 -57
View File
@@ -159,23 +159,23 @@ impl FtsIndexSpec {
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling. /// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes. /// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
pub async fn ensure_runtime_indexes( pub async fn ensure_runtime(
db: &SurrealDbClient, db: &SurrealDbClient,
embedding_dimension: usize, embedding_dimension: usize,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
ensure_runtime_indexes_inner(db, embedding_dimension) ensure_runtime_inner(db, embedding_dimension)
.await .await
.map_err(|err| AppError::InternalError(err.to_string())) .map_err(|err| AppError::InternalError(err.to_string()))
} }
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined. /// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> { pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_indexes_inner(db) rebuild_inner(db)
.await .await
.map_err(|err| AppError::InternalError(err.to_string())) .map_err(|err| AppError::InternalError(err.to_string()))
} }
async fn ensure_runtime_indexes_inner( async fn ensure_runtime_inner(
db: &SurrealDbClient, db: &SurrealDbClient,
embedding_dimension: usize, embedding_dimension: usize,
) -> Result<()> { ) -> Result<()> {
@@ -262,9 +262,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
.context("checking index status")?; .context("checking index status")?;
let info: Option<Value> = info_res.take(0).context("failed to take info result")?; let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
let info = match info { let Some(info) = info else {
Some(i) => i, return Ok("unknown".to_string());
None => return Ok("unknown".to_string()),
}; };
let building = info.get("building"); let building = info.get("building");
@@ -277,7 +276,7 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
Ok(status) Ok(status)
} }
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> { async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
debug!("Rebuilding indexes with concurrent definitions"); debug!("Rebuilding indexes with concurrent definitions");
create_fts_analyzer(db).await?; create_fts_analyzer(db).await?;
@@ -385,10 +384,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering // is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
// an existing analyzer definition. // an existing analyzer definition.
let snowball_query = format!( let snowball_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer} "DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
TOKENIZERS class TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);", FILTERS lowercase, ascii, snowball(english);"
analyzer = FTS_ANALYZER_NAME
); );
match db.client.query(snowball_query).await { match db.client.query(snowball_query).await {
@@ -410,10 +408,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
} }
let fallback_query = format!( let fallback_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer} "DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
TOKENIZERS class TOKENIZERS class
FILTERS lowercase, ascii;", FILTERS lowercase, ascii;"
analyzer = FTS_ANALYZER_NAME
); );
let res = db let res = db
@@ -446,6 +443,7 @@ async fn create_index_with_polling(
table: &str, table: &str,
progress_table: Option<&str>, progress_table: Option<&str>,
) -> Result<()> { ) -> Result<()> {
const MAX_ATTEMPTS: usize = 3;
let expected_total = match progress_table { let expected_total = match progress_table {
Some(table) => Some(count_table_rows(db, table).await.with_context(|| { Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
format!("counting rows in {table} for index {index_name} progress") format!("counting rows in {table} for index {index_name} progress")
@@ -453,10 +451,9 @@ async fn create_index_with_polling(
None => None, None => None,
}; };
let mut attempts = 0; let mut attempts: usize = 0;
const MAX_ATTEMPTS: usize = 3;
loop { loop {
attempts += 1; attempts = attempts.saturating_add(1);
let res = db let res = db
.client .client
.query(definition.clone()) .query(definition.clone())
@@ -527,8 +524,8 @@ async fn poll_index_build_status(
break; break;
}; };
match snapshot.progress_pct { if let Some(pct) = snapshot.progress_pct {
Some(pct) => debug!( debug!(
index = %index_name, index = %index_name,
table = %table, table = %table,
status = snapshot.status, status = snapshot.status,
@@ -539,8 +536,9 @@ async fn poll_index_build_status(
total = snapshot.total_rows, total = snapshot.total_rows,
progress_pct = format_args!("{pct:.1}"), progress_pct = format_args!("{pct:.1}"),
"Index build status" "Index build status"
), );
None => debug!( } else {
debug!(
index = %index_name, index = %index_name,
table = %table, table = %table,
status = snapshot.status, status = snapshot.status,
@@ -549,7 +547,7 @@ async fn poll_index_build_status(
updated = snapshot.updated, updated = snapshot.updated,
processed = snapshot.processed, processed = snapshot.processed,
"Index build status" "Index build status"
), );
} }
if snapshot.is_ready() { if snapshot.is_ready() {
@@ -611,17 +609,17 @@ fn parse_index_build_info(
let initial = building let initial = building
.and_then(|b| b.get("initial")) .and_then(|b| b.get("initial"))
.and_then(|v| v.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let pending = building let pending = building
.and_then(|b| b.get("pending")) .and_then(|b| b.get("pending"))
.and_then(|v| v.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let updated = building let updated = building
.and_then(|b| b.get("updated")) .and_then(|b| b.get("updated"))
.and_then(|v| v.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes. // `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
@@ -631,7 +629,7 @@ fn parse_index_build_info(
if total == 0 { if total == 0 {
0.0 0.0
} else { } else {
((processed as f64 / total as f64).min(1.0)) * 100.0 ((f64::from(u32::try_from(processed).unwrap_or(u32::MAX)) / f64::from(u32::try_from(total).unwrap_or(1))).min(1.0)) * 100.0
} }
}); });
@@ -673,7 +671,7 @@ async fn table_index_definitions(
.client .client
.query(info_query) .query(info_query)
.await .await
.with_context(|| format!("fetching table info for {}", table))?; .with_context(|| format!("fetching table info for {table}"))?;
let info: surrealdb::Value = response let info: surrealdb::Value = response
.take(0) .take(0)
@@ -700,12 +698,15 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use anyhow::{self, Context};
use crate::storage::db::SurrealDbClient;
use serde_json::json; use serde_json::json;
use uuid::Uuid; use uuid::Uuid;
use super::*;
#[test] #[test]
fn parse_index_build_info_reports_progress() { fn parse_index_build_info_reports_progress() -> anyhow::Result<()> {
let info = json!({ let info = json!({
"building": { "building": {
"initial": 56894, "initial": 56894,
@@ -715,7 +716,8 @@ mod tests {
} }
}); });
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot"); let snapshot = parse_index_build_info(Some(info), Some(61081))
.context("snapshot")?;
assert_eq!( assert_eq!(
snapshot, snapshot,
IndexBuildSnapshot { IndexBuildSnapshot {
@@ -729,16 +731,19 @@ mod tests {
} }
); );
assert!(!snapshot.is_ready()); assert!(!snapshot.is_ready());
Ok(())
} }
#[test] #[test]
fn parse_index_build_info_defaults_to_ready_when_no_building_block() { fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> {
// Surreal returns `{}` when the index exists but isn't building. // Surreal returns `{}` when the index exists but isn't building.
let info = json!({}); let info = json!({});
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot"); let snapshot = parse_index_build_info(Some(info), Some(10))
.context("snapshot")?;
assert!(snapshot.is_ready()); assert!(snapshot.is_ready());
assert_eq!(snapshot.processed, 0); assert_eq!(snapshot.processed, 0);
assert_eq!(snapshot.progress_pct, Some(0.0)); assert_eq!(snapshot.progress_pct, Some(0.0));
Ok(())
} }
#[test] #[test]
@@ -748,48 +753,40 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn ensure_runtime_indexes_is_idempotent() { async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
let namespace = "indexes_ns"; let namespace = "indexes_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("in-memory db"); .context("in-memory db")?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("migrations should succeed"); .context("migrations should succeed")?;
// First run creates everything ensure_runtime(&db, 1536).await
ensure_runtime_indexes(&db, 1536) .context("first call should succeed")?;
.await ensure_runtime(&db, 1536).await
.expect("initial index creation"); .context("second index creation")?;
Ok(())
// Second run should be a no-op and still succeed
ensure_runtime_indexes(&db, 1536)
.await
.expect("second index creation");
} }
#[tokio::test] #[tokio::test]
async fn ensure_hnsw_index_overwrites_dimension() { async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
let namespace = "indexes_dim"; let namespace = "indexes_dim";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("in-memory db"); .context("in-memory db")?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("migrations should succeed"); .context("migrations should succeed")?;
// Create initial index with default dimension ensure_runtime(&db, 1536).await
ensure_runtime_indexes(&db, 1536) .context("initial index creation")?;
.await ensure_runtime(&db, 128).await
.expect("initial index creation"); .context("overwritten index creation")?;
Ok(())
// Change dimension and ensure overwrite path is exercised
ensure_runtime_indexes(&db, 128)
.await
.expect("overwritten index creation");
} }
} }
+142 -108
View File
@@ -13,13 +13,13 @@ use object_store::{path::Path as ObjPath, ObjectStore};
use crate::utils::config::{AppConfig, StorageKind}; use crate::utils::config::{AppConfig, StorageKind};
pub type DynStore = Arc<dyn ObjectStore>; pub type DynStorage = Arc<dyn ObjectStore>;
/// Storage manager with persistent state and proper lifecycle management. /// Storage manager with persistent state and proper lifecycle management.
#[derive(Clone)] #[derive(Clone)]
pub struct StorageManager { pub struct StorageManager {
// Store from objectstore wrapped as dyn // Store from objectstore wrapped as dyn
store: DynStore, store: DynStorage,
// Simple enum to track which kind // Simple enum to track which kind
backend_kind: StorageKind, backend_kind: StorageKind,
// Where on disk // Where on disk
@@ -46,7 +46,7 @@ impl StorageManager {
/// ///
/// This method is useful for testing scenarios where you want to inject /// This method is useful for testing scenarios where you want to inject
/// a specific storage backend. /// a specific storage backend.
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self { pub fn with_backend(store: DynStorage, backend_kind: StorageKind) -> Self {
Self { Self {
store, store,
backend_kind, backend_kind,
@@ -216,7 +216,7 @@ impl StorageManager {
/// storage backends with proper error handling and validation. /// storage backends with proper error handling and validation.
async fn create_storage_backend( async fn create_storage_backend(
cfg: &AppConfig, cfg: &AppConfig,
) -> object_store::Result<(DynStore, Option<PathBuf>)> { ) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
match cfg.storage { match cfg.storage {
StorageKind::Local => { StorageKind::Local => {
let base = resolve_base_dir(cfg); let base = resolve_base_dir(cfg);
@@ -261,9 +261,7 @@ async fn create_storage_backend(
builder = builder.with_endpoint(endpoint); builder = builder.with_endpoint(endpoint);
} }
if let Some(region) = &cfg.s3_region { builder = builder.with_region(&cfg.s3_region);
builder = builder.with_region(region);
}
let store = builder.build()?; let store = builder.build()?;
Ok((Arc::new(store), None)) Ok((Arc::new(store), None))
@@ -342,7 +340,7 @@ pub mod testing {
surrealdb_password: "test".into(), surrealdb_password: "test".into(),
surrealdb_namespace: "test".into(), surrealdb_namespace: "test".into(),
surrealdb_database: "test".into(), surrealdb_database: "test".into(),
data_dir: base.into(), data_dir: base,
http_port: 0, http_port: 0,
openai_base_url: "..".into(), openai_base_url: "..".into(),
storage: StorageKind::Local, storage: StorageKind::Local,
@@ -382,7 +380,7 @@ pub mod testing {
#[derive(Clone)] #[derive(Clone)]
pub struct TestStorageManager { pub struct TestStorageManager {
storage: StorageManager, storage: StorageManager,
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
} }
impl TestStorageManager { impl TestStorageManager {
@@ -396,7 +394,7 @@ pub mod testing {
Ok(Self { Ok(Self {
storage, storage,
_temp_dir: None, temp_dir: None,
}) })
} }
@@ -413,7 +411,7 @@ pub mod testing {
Ok(Self { Ok(Self {
storage, storage,
_temp_dir: resolved, temp_dir: resolved,
}) })
} }
@@ -437,7 +435,7 @@ pub mod testing {
Ok(Self { Ok(Self {
storage, storage,
_temp_dir: None, temp_dir: None,
}) })
} }
@@ -454,7 +452,7 @@ pub mod testing {
Ok(Self { Ok(Self {
storage, storage,
_temp_dir: temp_dir, temp_dir,
}) })
} }
@@ -508,7 +506,7 @@ pub mod testing {
impl Drop for TestStorageManager { impl Drop for TestStorageManager {
fn drop(&mut self) { fn drop(&mut self) {
// Clean up temporary directories for local storage // Clean up temporary directories for local storage
if let Some((_, path)) = &self._temp_dir { if let Some((_, path)) = &self.temp_dir {
if path.exists() { if path.exists() {
let _ = std::fs::remove_dir_all(path); let _ = std::fs::remove_dir_all(path);
} }
@@ -584,6 +582,7 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use anyhow::Context;
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind}; use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
use bytes::Bytes; use bytes::Bytes;
use uuid::Uuid; use uuid::Uuid;
@@ -623,11 +622,11 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_memory_basic_operations() { async fn test_storage_manager_memory_basic_operations() -> anyhow::Result<()> {
let cfg = test_config_memory(); let cfg = test_config_memory();
let storage = StorageManager::new(&cfg) let storage = StorageManager::new(&cfg)
.await .await
.expect("create storage manager"); .with_context(|| "create storage manager".to_string())?;
assert!(storage.local_base_path().is_none()); assert!(storage.local_base_path().is_none());
let location = "test/data/file.txt"; let location = "test/data/file.txt";
@@ -637,31 +636,33 @@ mod tests {
storage storage
.put(location, Bytes::from(data.to_vec())) .put(location, Bytes::from(data.to_vec()))
.await .await
.expect("put"); .with_context(|| "put".to_string())?;
let retrieved = storage.get(location).await.expect("get"); let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
// Test exists // Test exists
assert!(storage.exists(location).await.expect("exists check")); assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
// Test delete // Test delete
storage.delete_prefix("test/data/").await.expect("delete"); storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
assert!(!storage assert!(!storage
.exists(location) .exists(location)
.await .await
.expect("exists check after delete")); .with_context(|| "exists check after delete".to_string())?);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_local_basic_operations() { async fn test_storage_manager_local_basic_operations() -> anyhow::Result<()> {
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4()); let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
let cfg = test_config(&base); let cfg = test_config(&base);
let storage = StorageManager::new(&cfg) let storage = StorageManager::new(&cfg)
.await .await
.expect("create storage manager"); .with_context(|| "create storage manager".to_string())?;
let resolved_base = storage let resolved_base = storage
.local_base_path() .local_base_path()
.expect("resolved base dir") .with_context(|| "resolved base dir".to_string())?
.to_path_buf(); .to_path_buf();
assert_eq!(resolved_base, PathBuf::from(&base)); assert_eq!(resolved_base, PathBuf::from(&base));
@@ -672,42 +673,44 @@ mod tests {
storage storage
.put(location, Bytes::from(data.to_vec())) .put(location, Bytes::from(data.to_vec()))
.await .await
.expect("put"); .with_context(|| "put".to_string())?;
let retrieved = storage.get(location).await.expect("get"); let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
let object_dir = resolved_base.join("test/data"); let object_dir = resolved_base.join("test/data");
tokio::fs::metadata(&object_dir) tokio::fs::metadata(&object_dir)
.await .await
.expect("object directory exists after write"); .with_context(|| "object directory exists after write".to_string())?;
// Test exists // Test exists
assert!(storage.exists(location).await.expect("exists check")); assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
// Test delete // Test delete
storage.delete_prefix("test/data/").await.expect("delete"); storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
assert!(!storage assert!(!storage
.exists(location) .exists(location)
.await .await
.expect("exists check after delete")); .with_context(|| "exists check after delete".to_string())?);
assert!( assert!(
tokio::fs::metadata(&object_dir).await.is_err(), tokio::fs::metadata(&object_dir).await.is_err(),
"object directory should be removed" "object directory should be removed"
); );
tokio::fs::metadata(&resolved_base) tokio::fs::metadata(&resolved_base)
.await .await
.expect("base directory remains intact"); .with_context(|| "base directory remains intact".to_string())?;
// Clean up // Clean up
let _ = tokio::fs::remove_dir_all(&base).await; let _ = tokio::fs::remove_dir_all(&base).await;
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_memory_persistence() { async fn test_storage_manager_memory_persistence() -> anyhow::Result<()> {
let cfg = test_config_memory(); let cfg = test_config_memory();
let storage = StorageManager::new(&cfg) let storage = StorageManager::new(&cfg)
.await .await
.expect("create storage manager"); .with_context(|| "create storage manager".to_string())?;
let location = "persistence/test.txt"; let location = "persistence/test.txt";
let data1 = b"first data"; let data1 = b"first data";
@@ -717,32 +720,34 @@ mod tests {
storage storage
.put(location, Bytes::from(data1.to_vec())) .put(location, Bytes::from(data1.to_vec()))
.await .await
.expect("put first"); .with_context(|| "put first".to_string())?;
// Retrieve and verify first data // Retrieve and verify first data
let retrieved1 = storage.get(location).await.expect("get first"); let retrieved1 = storage.get(location).await.with_context(|| "get first".to_string())?;
assert_eq!(retrieved1.as_ref(), data1); assert_eq!(retrieved1.as_ref(), data1);
// Overwrite with second data // Overwrite with second data
storage storage
.put(location, Bytes::from(data2.to_vec())) .put(location, Bytes::from(data2.to_vec()))
.await .await
.expect("put second"); .with_context(|| "put second".to_string())?;
// Retrieve and verify second data // Retrieve and verify second data
let retrieved2 = storage.get(location).await.expect("get second"); let retrieved2 = storage.get(location).await.with_context(|| "get second".to_string())?;
assert_eq!(retrieved2.as_ref(), data2); assert_eq!(retrieved2.as_ref(), data2);
// Data persists across multiple operations using the same StorageManager // Data persists across multiple operations using the same StorageManager
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref()); assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_list_operations() { async fn test_storage_manager_list_operations() -> anyhow::Result<()> {
let cfg = test_config_memory(); let cfg = test_config_memory();
let storage = StorageManager::new(&cfg) let storage = StorageManager::new(&cfg)
.await .await
.expect("create storage manager"); .with_context(|| "create storage manager".to_string())?;
// Create multiple files // Create multiple files
let files = vec![ let files = vec![
@@ -755,15 +760,15 @@ mod tests {
storage storage
.put(location, Bytes::from(data.to_vec())) .put(location, Bytes::from(data.to_vec()))
.await .await
.expect("put"); .with_context(|| "put".to_string())?;
} }
// Test listing without prefix // Test listing without prefix
let all_files = storage.list(None).await.expect("list all"); let all_files = storage.list(None).await.with_context(|| "list all".to_string())?;
assert_eq!(all_files.len(), 3); assert_eq!(all_files.len(), 3);
// Test listing with prefix // Test listing with prefix
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1"); let dir1_files = storage.list(Some("dir1/")).await.with_context(|| "list dir1".to_string())?;
assert_eq!(dir1_files.len(), 2); assert_eq!(dir1_files.len(), 2);
assert!(dir1_files assert!(dir1_files
.iter() .iter()
@@ -776,16 +781,18 @@ mod tests {
let empty_files = storage let empty_files = storage
.list(Some("nonexistent/")) .list(Some("nonexistent/"))
.await .await
.expect("list nonexistent"); .with_context(|| "list nonexistent".to_string())?;
assert_eq!(empty_files.len(), 0); assert_eq!(empty_files.len(), 0);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_stream_operations() { async fn test_storage_manager_stream_operations() -> anyhow::Result<()> {
let cfg = test_config_memory(); let cfg = test_config_memory();
let storage = StorageManager::new(&cfg) let storage = StorageManager::new(&cfg)
.await .await
.expect("create storage manager"); .with_context(|| "create storage manager".to_string())?;
let location = "stream/test.bin"; let location = "stream/test.bin";
let content = vec![42u8; 1024 * 64]; // 64KB of data let content = vec![42u8; 1024 * 64]; // 64KB of data
@@ -794,22 +801,24 @@ mod tests {
storage storage
.put(location, Bytes::from(content.clone())) .put(location, Bytes::from(content.clone()))
.await .await
.expect("put large data"); .with_context(|| "put large data".to_string())?;
// Get as stream // Get as stream
let mut stream = storage.get_stream(location).await.expect("get stream"); let mut stream = storage.get_stream(location).await.with_context(|| "get stream".to_string())?;
let mut collected = Vec::new(); let mut collected = Vec::new();
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {
let chunk = chunk.expect("stream chunk"); let chunk = chunk.with_context(|| "stream chunk".to_string())?;
collected.extend_from_slice(&chunk); collected.extend_from_slice(&chunk);
} }
assert_eq!(collected, content); assert_eq!(collected, content);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_with_custom_backend() { async fn test_storage_manager_with_custom_backend() -> anyhow::Result<()> {
use object_store::memory::InMemory; use object_store::memory::InMemory;
// Create custom memory backend // Create custom memory backend
@@ -823,20 +832,22 @@ mod tests {
storage storage
.put(location, Bytes::from(data.to_vec())) .put(location, Bytes::from(data.to_vec()))
.await .await
.expect("put"); .with_context(|| "put".to_string())?;
let retrieved = storage.get(location).await.expect("get"); let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
assert!(storage.exists(location).await.expect("exists")); assert!(storage.exists(location).await.with_context(|| "exists".to_string())?);
assert_eq!(*storage.backend_kind(), StorageKind::Memory); assert_eq!(*storage.backend_kind(), StorageKind::Memory);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_error_handling() { async fn test_storage_manager_error_handling() -> anyhow::Result<()> {
let cfg = test_config_memory(); let cfg = test_config_memory();
let storage = StorageManager::new(&cfg) let storage = StorageManager::new(&cfg)
.await .await
.expect("create storage manager"); .with_context(|| "create storage manager".to_string())?;
// Test getting non-existent file // Test getting non-existent file
let result = storage.get("nonexistent.txt").await; let result = storage.get("nonexistent.txt").await;
@@ -846,124 +857,136 @@ mod tests {
let exists = storage let exists = storage
.exists("nonexistent.txt") .exists("nonexistent.txt")
.await .await
.expect("exists check"); .with_context(|| "exists check".to_string())?;
assert!(!exists); assert!(!exists);
// Test listing with invalid location (should not panic) // Test listing with invalid location (should not panic)
let _result = storage.get("").await; let _result = storage.get("").await;
// This may or may not error depending on the backend implementation // This may or may not error depending on the backend implementation
// The important thing is that it doesn't panic // The important thing is that it doesn't panic
Ok(())
} }
// TestStorageManager tests // TestStorageManager tests
#[tokio::test] #[tokio::test]
async fn test_test_storage_manager_memory() { async fn test_test_storage_manager_memory() -> anyhow::Result<()> {
let test_storage = testing::TestStorageManager::new_memory() let test_storage = testing::TestStorageManager::new_memory()
.await .await
.expect("create test storage"); .with_context(|| "create test storage".to_string())?;
let location = "test/storage/file.txt"; let location = "test/storage/file.txt";
let data = b"test data with TestStorageManager"; let data = b"test data with TestStorageManager";
// Test put and get // Test put and get
test_storage.put(location, data).await.expect("put"); test_storage.put(location, data).await.with_context(|| "put".to_string())?;
let retrieved = test_storage.get(location).await.expect("get"); let retrieved = test_storage.get(location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
// Test existence check // Test existence check
assert!(test_storage.exists(location).await.expect("exists")); assert!(test_storage.exists(location).await.with_context(|| "exists".to_string())?);
// Test list // Test list
let files = test_storage let files = test_storage
.list(Some("test/storage/")) .list(Some("test/storage/"))
.await .await
.expect("list"); .with_context(|| "list".to_string())?;
assert_eq!(files.len(), 1); assert_eq!(files.len(), 1);
// Test delete // Test delete
test_storage test_storage
.delete_prefix("test/storage/") .delete_prefix("test/storage/")
.await .await
.expect("delete"); .with_context(|| "delete".to_string())?;
assert!(!test_storage assert!(!test_storage
.exists(location) .exists(location)
.await .await
.expect("exists after delete")); .with_context(|| "exists after delete".to_string())?);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_test_storage_manager_local() { async fn test_test_storage_manager_local() -> anyhow::Result<()> {
let test_storage = testing::TestStorageManager::new_local() let test_storage = testing::TestStorageManager::new_local()
.await .await
.expect("create test storage"); .with_context(|| "create test storage".to_string())?;
let location = "test/local/file.txt"; let location = "test/local/file.txt";
let data = b"test data with local TestStorageManager"; let data = b"test data with local TestStorageManager";
// Test put and get test_storage.put(location, data).await
test_storage.put(location, data).await.expect("put"); .with_context(|| "put".to_string())?;
let retrieved = test_storage.get(location).await.expect("get"); let retrieved = test_storage.get(location).await
.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
// Test existence check assert!(test_storage.exists(location).await
assert!(test_storage.exists(location).await.expect("exists")); .with_context(|| "exists".to_string())?);
// The storage should be automatically cleaned up when test_storage is dropped Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_test_storage_manager_isolation() { async fn test_test_storage_manager_isolation() -> anyhow::Result<()> {
let storage1 = testing::TestStorageManager::new_memory() let storage1 = testing::TestStorageManager::new_memory()
.await .await
.expect("create test storage 1"); .with_context(|| "create test storage 1".to_string())?;
let storage2 = testing::TestStorageManager::new_memory() let storage2 = testing::TestStorageManager::new_memory()
.await .await
.expect("create test storage 2"); .with_context(|| "create test storage 2".to_string())?;
let location = "isolation/test.txt"; let location = "isolation/test.txt";
let data1 = b"storage 1 data"; let data1 = b"storage 1 data";
let data2 = b"storage 2 data"; let data2 = b"storage 2 data";
// Put different data in each storage storage1.put(location, data1).await
storage1.put(location, data1).await.expect("put storage 1"); .with_context(|| "put storage 1".to_string())?;
storage2.put(location, data2).await.expect("put storage 2"); storage2.put(location, data2).await
.with_context(|| "put storage 2".to_string())?;
// Verify isolation let retrieved1 = storage1.get(location).await
let retrieved1 = storage1.get(location).await.expect("get storage 1"); .with_context(|| "get storage 1".to_string())?;
let retrieved2 = storage2.get(location).await.expect("get storage 2"); let retrieved2 = storage2.get(location).await
.with_context(|| "get storage 2".to_string())?;
assert_eq!(retrieved1.as_ref(), data1); assert_eq!(retrieved1.as_ref(), data1);
assert_eq!(retrieved2.as_ref(), data2); assert_eq!(retrieved2.as_ref(), data2);
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref()); assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_test_storage_manager_config() { async fn test_test_storage_manager_config() -> anyhow::Result<()> {
let cfg = testing::test_config_memory(); let cfg = testing::test_config_memory();
let test_storage = testing::TestStorageManager::with_config(&cfg) let test_storage = testing::TestStorageManager::with_config(&cfg)
.await .await
.expect("create test storage with config"); .with_context(|| "create test storage with config".to_string())?;
let location = "config/test.txt"; let location = "config/test.txt";
let data = b"test data with custom config"; let data = b"test data with custom config";
test_storage.put(location, data).await.expect("put"); test_storage.put(location, data).await
let retrieved = test_storage.get(location).await.expect("get"); .with_context(|| "put".to_string())?;
let retrieved = test_storage.get(location).await
.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
// Verify it's using memory backend
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory); assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
Ok(())
} }
// S3 Tests - Require a reachable MinIO endpoint and test bucket. // S3 Tests - Require a reachable MinIO endpoint and test bucket.
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable. // `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
#[tokio::test] #[tokio::test]
async fn test_storage_manager_s3_basic_operations() { async fn test_storage_manager_s3_basic_operations() -> anyhow::Result<()> {
// Skip if S3 connection fails (e.g. no MinIO) // Skip if S3 connection fails (e.g. no MinIO)
let Ok(storage) = testing::TestStorageManager::new_s3().await else { let Ok(storage) = testing::TestStorageManager::new_s3().await else {
eprintln!("Skipping S3 test (setup failed)"); eprintln!("Skipping S3 test (setup failed)");
return; return Ok(());
}; };
let prefix = format!("test-basic-{}", Uuid::new_v4()); let prefix = format!("test-basic-{}", Uuid::new_v4());
@@ -973,31 +996,33 @@ mod tests {
// Test put // Test put
if let Err(e) = storage.put(&location, data).await { if let Err(e) = storage.put(&location, data).await {
eprintln!("Skipping S3 test (put failed - bucket missing?): {e}"); eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
return; return Ok(());
} }
// Test get // Test get
let retrieved = storage.get(&location).await.expect("get"); let retrieved = storage.get(&location).await.with_context(|| "get".to_string())?;
assert_eq!(retrieved.as_ref(), data); assert_eq!(retrieved.as_ref(), data);
// Test exists // Test exists
assert!(storage.exists(&location).await.expect("exists")); assert!(storage.exists(&location).await.with_context(|| "exists".to_string())?);
// Test delete // Test delete
storage storage
.delete_prefix(&format!("{prefix}/")) .delete_prefix(&format!("{prefix}/"))
.await .await
.expect("delete"); .with_context(|| "delete".to_string())?;
assert!(!storage assert!(!storage
.exists(&location) .exists(&location)
.await .await
.expect("exists after delete")); .with_context(|| "exists after delete".to_string())?);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_s3_list_operations() { async fn test_storage_manager_s3_list_operations() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else { let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return; return Ok(());
}; };
let prefix = format!("test-list-{}", Uuid::new_v4()); let prefix = format!("test-list-{}", Uuid::new_v4());
@@ -1009,23 +1034,25 @@ mod tests {
for (loc, data) in &files { for (loc, data) in &files {
if storage.put(loc, *data).await.is_err() { if storage.put(loc, *data).await.is_err() {
return; // Abort if put fails return Ok(()); // Abort if put fails
} }
} }
// List with prefix // List with prefix
let list_prefix = format!("{prefix}/"); let list_prefix = format!("{prefix}/");
let items = storage.list(Some(&list_prefix)).await.expect("list"); let items = storage.list(Some(&list_prefix)).await.with_context(|| "list".to_string())?;
assert_eq!(items.len(), 3); assert_eq!(items.len(), 3);
// Cleanup // Cleanup
storage.delete_prefix(&list_prefix).await.expect("cleanup"); storage.delete_prefix(&list_prefix).await.with_context(|| "cleanup".to_string())?;
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_s3_stream_operations() { async fn test_storage_manager_s3_stream_operations() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else { let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return; return Ok(());
}; };
let prefix = format!("test-stream-{}", Uuid::new_v4()); let prefix = format!("test-stream-{}", Uuid::new_v4());
@@ -1033,38 +1060,45 @@ mod tests {
let content = vec![42u8; 1024 * 10]; // 10KB let content = vec![42u8; 1024 * 10]; // 10KB
if storage.put(&location, &content).await.is_err() { if storage.put(&location, &content).await.is_err() {
return; return Ok(());
} }
let mut stream = storage.get_stream(&location).await.expect("get stream"); let mut stream = storage.get_stream(&location).await.with_context(|| "get stream".to_string())?;
let mut collected = Vec::new(); let mut collected = Vec::new();
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {
collected.extend_from_slice(&chunk.expect("chunk")); collected.extend_from_slice(&chunk.with_context(|| "chunk".to_string())?);
} }
assert_eq!(collected, content); assert_eq!(collected, content);
storage storage
.delete_prefix(&format!("{prefix}/")) .delete_prefix(&format!("{prefix}/"))
.await .await
.expect("cleanup"); .with_context(|| "cleanup".to_string())?;
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_s3_backend_kind() { async fn test_storage_manager_s3_backend_kind() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else { let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return; return Ok(());
}; };
assert_eq!(*storage.storage().backend_kind(), StorageKind::S3); assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_storage_manager_s3_error_handling() { async fn test_storage_manager_s3_error_handling() -> anyhow::Result<()> {
let Ok(storage) = testing::TestStorageManager::new_s3().await else { let Ok(storage) = testing::TestStorageManager::new_s3().await else {
return; return Ok(());
}; };
let location = format!("nonexistent-{}/file.txt", Uuid::new_v4()); let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
assert!(storage.get(&location).await.is_err()); assert!(storage.get(&location).await.is_err());
assert!(!storage.exists(&location).await.expect("exists check")); // exists may fail if S3 is unavailable; treat error as false
assert!(!storage.exists(&location).await.unwrap_or(false));
Ok(())
} }
} }
+45 -73
View File
@@ -90,6 +90,7 @@ impl Analytics {
mod tests { mod tests {
use super::*; use super::*;
use crate::stored_object; use crate::stored_object;
use anyhow::{self};
use uuid::Uuid; use uuid::Uuid;
stored_object!(TestUser, "user", { stored_object!(TestUser, "user", {
@@ -99,18 +100,14 @@ mod tests {
}); });
#[tokio::test] #[tokio::test]
async fn test_analytics_initialization() { async fn test_analytics_initialization() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
// Test initialization of analytics // Test initialization of analytics
let analytics = Analytics::ensure_initialized(&db) let analytics = Analytics::ensure_initialized(&db).await?;
.await
.expect("Failed to initialize analytics");
// Verify initial state after initialization // Verify initial state after initialization
assert_eq!(analytics.id, "current"); assert_eq!(analytics.id, "current");
@@ -118,159 +115,134 @@ mod tests {
assert_eq!(analytics.visitors, 0); assert_eq!(analytics.visitors, 0);
// Test idempotency - ensure calling it again doesn't change anything // Test idempotency - ensure calling it again doesn't change anything
let analytics_again = Analytics::ensure_initialized(&db) let analytics_again = Analytics::ensure_initialized(&db).await?;
.await
.expect("Failed to get analytics after initialization");
assert_eq!(analytics.id, analytics_again.id); assert_eq!(analytics.id, analytics_again.id);
assert_eq!(analytics.page_loads, analytics_again.page_loads); assert_eq!(analytics.page_loads, analytics_again.page_loads);
assert_eq!(analytics.visitors, analytics_again.visitors); assert_eq!(analytics.visitors, analytics_again.visitors);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_current_analytics() { async fn test_get_current_analytics() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
// Initialize analytics // Initialize analytics
Analytics::ensure_initialized(&db) Analytics::ensure_initialized(&db).await?;
.await
.expect("Failed to initialize analytics");
// Test get_current method // Test get_current method
let analytics = Analytics::get_current(&db) let analytics = Analytics::get_current(&db).await?;
.await
.expect("Failed to get current analytics");
assert_eq!(analytics.id, "current"); assert_eq!(analytics.id, "current");
assert_eq!(analytics.page_loads, 0); assert_eq!(analytics.page_loads, 0);
assert_eq!(analytics.visitors, 0); assert_eq!(analytics.visitors, 0);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_increment_visitors() { async fn test_increment_visitors() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
// Initialize analytics // Initialize analytics
Analytics::ensure_initialized(&db) Analytics::ensure_initialized(&db).await?;
.await
.expect("Failed to initialize analytics");
// Test increment_visitors method // Test increment_visitors method
let analytics = Analytics::increment_visitors(&db) let analytics = Analytics::increment_visitors(&db).await?;
.await
.expect("Failed to increment visitors");
assert_eq!(analytics.visitors, 1); assert_eq!(analytics.visitors, 1);
assert_eq!(analytics.page_loads, 0); assert_eq!(analytics.page_loads, 0);
// Increment again and check // Increment again and check
let analytics = Analytics::increment_visitors(&db) let analytics = Analytics::increment_visitors(&db).await?;
.await
.expect("Failed to increment visitors again");
assert_eq!(analytics.visitors, 2); assert_eq!(analytics.visitors, 2);
assert_eq!(analytics.page_loads, 0); assert_eq!(analytics.page_loads, 0);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_increment_page_loads() { async fn test_increment_page_loads() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
// Initialize analytics // Initialize analytics
Analytics::ensure_initialized(&db) Analytics::ensure_initialized(&db).await?;
.await
.expect("Failed to initialize analytics");
// Test increment_page_loads method // Test increment_page_loads method
let analytics = Analytics::increment_page_loads(&db) let analytics = Analytics::increment_page_loads(&db).await?;
.await
.expect("Failed to increment page loads");
assert_eq!(analytics.visitors, 0); assert_eq!(analytics.visitors, 0);
assert_eq!(analytics.page_loads, 1); assert_eq!(analytics.page_loads, 1);
// Increment again and check // Increment again and check
let analytics = Analytics::increment_page_loads(&db) let analytics = Analytics::increment_page_loads(&db).await?;
.await
.expect("Failed to increment page loads again");
assert_eq!(analytics.visitors, 0); assert_eq!(analytics.visitors, 0);
assert_eq!(analytics.page_loads, 2); assert_eq!(analytics.page_loads, 2);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_users_amount() { async fn test_get_users_amount() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
// Test with no users // Test with no users
let count = Analytics::get_users_amount(&db) let count = Analytics::get_users_amount(&db).await?;
.await
.expect("Failed to get users amount");
assert_eq!(count, 0); assert_eq!(count, 0);
// Create a few test users // Create a few test users
for i in 0..3 { for i in 0..3 {
let user = TestUser { let user = TestUser {
id: format!("user{}", i), id: format!("user{i}"),
email: format!("user{}@example.com", i), email: format!("user{i}@example.com"),
password: "password".to_string(), password: "password".to_string(),
user_id: format!("uid{}", i), user_id: format!("uid{i}"),
created_at: Utc::now(), created_at: Utc::now(),
updated_at: Utc::now(), updated_at: Utc::now(),
}; };
db.store_item(user) db.store_item(user).await?;
.await
.expect("Failed to create test user");
} }
// Test users amount after adding users // Test users amount after adding users
let count = Analytics::get_users_amount(&db) let count = Analytics::get_users_amount(&db).await?;
.await
.expect("Failed to get users amount after adding users");
assert_eq!(count, 3); assert_eq!(count, 3);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_current_nonexistent() { async fn test_get_current_nonexistent() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
// Don't initialize analytics and try to get it // Don't initialize analytics and try to get it
let result = Analytics::get_current(&db).await; let result = Analytics::get_current(&db).await;
assert!(result.is_err()); assert!(result.is_err());
if let Err(err) = result { match result {
match err { Ok(_) => anyhow::bail!("Expected NotFound error, got success"),
AppError::NotFound(_) => { Err(AppError::NotFound(_)) => {}
// Expected error Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"),
}
_ => panic!("Expected NotFound error, got: {:?}", err),
}
} }
Ok(())
} }
} }
+54 -61
View File
@@ -144,76 +144,71 @@ impl Conversation {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use crate::storage::types::message::MessageRole; use crate::storage::types::message::MessageRole;
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn test_create_conversation() { async fn test_create_conversation() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create a new conversation
let user_id = "test_user"; let user_id = "test_user";
let title = "Test Conversation"; let title = "Test Conversation";
let conversation = Conversation::new(user_id.to_string(), title.to_string()); let conversation = Conversation::new(user_id.to_string(), title.to_string());
// Verify conversation properties
assert_eq!(conversation.user_id, user_id); assert_eq!(conversation.user_id, user_id);
assert_eq!(conversation.title, title); assert_eq!(conversation.title, title);
assert!(!conversation.id.is_empty()); assert!(!conversation.id.is_empty());
// Store the conversation
let result = db.store_item(conversation.clone()).await; let result = db.store_item(conversation.clone()).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Verify it can be retrieved
let retrieved: Option<Conversation> = db let retrieved: Option<Conversation> = db
.get_item(&conversation.id) .get_item(&conversation.id)
.await .await
.expect("Failed to retrieve conversation"); .with_context(|| "Failed to retrieve conversation".to_string())?;
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap(); let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?;
assert_eq!(retrieved.id, conversation.id); assert_eq!(retrieved.id, conversation.id);
assert_eq!(retrieved.user_id, user_id); assert_eq!(retrieved.user_id, user_id);
assert_eq!(retrieved.title, title); assert_eq!(retrieved.title, title);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_complete_conversation_not_found() { async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to get a conversation that doesn't exist
let result = let result =
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await; Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(AppError::NotFound(_)) => { /* expected error */ } Err(AppError::NotFound(_)) => {}
_ => panic!("Expected NotFound error"), _ => anyhow::bail!("Expected NotFound error"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_complete_conversation_unauthorized() { async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create and store a conversation for user_id_1
let user_id_1 = "user_1"; let user_id_1 = "user_1";
let conversation = let conversation =
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string()); Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
@@ -221,27 +216,28 @@ mod tests {
db.store_item(conversation) db.store_item(conversation)
.await .await
.expect("Failed to store conversation"); .with_context(|| "Failed to store conversation".to_string())?;
// Try to access with a different user
let user_id_2 = "user_2"; let user_id_2 = "user_2";
let result = let result =
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await; Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(AppError::Auth(_)) => { /* expected error */ } Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"), _ => anyhow::bail!("Expected Auth error"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_patch_title_success() { async fn test_patch_title_success() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let user_id = "user_1"; let user_id = "user_1";
let original_title = "Original Title"; let original_title = "Original Title";
@@ -250,49 +246,50 @@ mod tests {
db.store_item(conversation) db.store_item(conversation)
.await .await
.expect("Failed to store conversation"); .with_context(|| "Failed to store conversation".to_string())?;
let new_title = "Updated Title"; let new_title = "Updated Title";
// Patch title successfully
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await; let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Retrieve from DB to verify
let updated_conversation = db let updated_conversation = db
.get_item::<Conversation>(&conversation_id) .get_item::<Conversation>(&conversation_id)
.await .await
.expect("Failed to get conversation") .with_context(|| "Failed to get conversation".to_string())?
.expect("Conversation missing"); .ok_or_else(|| anyhow::anyhow!("Conversation missing"))?;
assert_eq!(updated_conversation.title, new_title); assert_eq!(updated_conversation.title, new_title);
assert_eq!(updated_conversation.user_id, user_id); assert_eq!(updated_conversation.user_id, user_id);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_patch_title_not_found() { async fn test_patch_title_not_found() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to patch non-existing conversation
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await; let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(AppError::NotFound(_)) => {} Err(AppError::NotFound(_)) => {}
_ => panic!("Expected NotFound error"), _ => anyhow::bail!("Expected NotFound error"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_patch_title_unauthorized() { async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let owner_id = "owner"; let owner_id = "owner";
let other_user_id = "intruder"; let other_user_id = "intruder";
@@ -301,17 +298,18 @@ mod tests {
db.store_item(conversation) db.store_item(conversation)
.await .await
.expect("Failed to store conversation"); .with_context(|| "Failed to store conversation".to_string())?;
// Attempt patch with unauthorized user
let result = let result =
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await; Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(AppError::Auth(_)) => {} Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"), _ => anyhow::bail!("Expected Auth error"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
@@ -405,24 +403,21 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_get_complete_conversation_with_messages() { async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create and store a conversation for user_id_1
let user_id_1 = "user_1"; let user_id_1 = "user_1";
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string()); let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
let conversation_id = conversation.id.clone(); let conversation_id = conversation.id.clone();
db.store_item(conversation) db.store_item(conversation)
.await .await
.expect("Failed to store conversation"); .with_context(|| "Failed to store conversation".to_string())?;
// Create messages
let message1 = Message::new( let message1 = Message::new(
conversation_id.clone(), conversation_id.clone(),
MessageRole::User, MessageRole::User,
@@ -442,46 +437,44 @@ mod tests {
None, None,
); );
// Store messages
db.store_item(message1) db.store_item(message1)
.await .await
.expect("Failed to store message1"); .with_context(|| "Failed to store message1".to_string())?;
db.store_item(message2) db.store_item(message2)
.await .await
.expect("Failed to store message2"); .with_context(|| "Failed to store message2".to_string())?;
db.store_item(message3) db.store_item(message3)
.await .await
.expect("Failed to store message3"); .with_context(|| "Failed to store message3".to_string())?;
// Retrieve the complete conversation
let result = let result =
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await; Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
assert!(result.is_ok(), "Failed to retrieve complete conversation"); assert!(result.is_ok(), "Failed to retrieve complete conversation");
let (retrieved_conversation, messages) = result.unwrap(); let (retrieved_conversation, retrieved_messages) = result
.with_context(|| "Failed to retrieve complete conversation".to_string())?;
// Verify conversation data
assert_eq!(retrieved_conversation.id, conversation_id); assert_eq!(retrieved_conversation.id, conversation_id);
assert_eq!(retrieved_conversation.user_id, user_id_1); assert_eq!(retrieved_conversation.user_id, user_id_1);
assert_eq!(retrieved_conversation.title, "Conversation"); assert_eq!(retrieved_conversation.title, "Conversation");
// Verify messages assert_eq!(retrieved_messages.len(), 3);
assert_eq!(messages.len(), 3);
// Verify messages are sorted by updated_at let message_contents: Vec<&str> =
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect(); retrieved_messages.iter().map(|m| m.content.as_str()).collect();
assert!(message_contents.contains(&"Hello, AI!")); assert!(message_contents.contains(&"Hello, AI!"));
assert!(message_contents.contains(&"Hello, human! How can I help you today?")); assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
assert!(message_contents.contains(&"Tell me about Rust programming.")); assert!(message_contents.contains(&"Tell me about Rust programming."));
// Make sure we can't access with different user
let user_id_2 = "user_2"; let user_id_2 = "user_2";
let unauthorized_result = let unauthorized_result =
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await; Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
assert!(unauthorized_result.is_err()); assert!(unauthorized_result.is_err());
match unauthorized_result { match unauthorized_result {
Err(AppError::Auth(_)) => { /* expected error */ } Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"), _ => anyhow::bail!("Expected Auth error"),
} }
Ok(())
} }
} }
+152 -145
View File
@@ -320,6 +320,8 @@ impl FileInfo {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::store::testing::TestStorageManager; use crate::storage::store::testing::TestStorageManager;
use axum::http::HeaderMap; use axum::http::HeaderMap;
@@ -328,11 +330,11 @@ mod tests {
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
/// Creates a test temporary file with the given content /// Creates a test temporary file with the given content
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> { fn create_test_file(content: &[u8], file_name: &str) -> anyhow::Result<FieldData<NamedTempFile>> {
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file"); let mut temp_file = NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?;
temp_file temp_file
.write_all(content) .write_all(content)
.expect("Failed to write to temp file"); .with_context(|| "Failed to write to temp file".to_string())?;
let metadata = FieldMetadata { let metadata = FieldMetadata {
name: Some("file".to_string()), name: Some("file".to_string()),
@@ -341,31 +343,29 @@ mod tests {
headers: HeaderMap::default(), headers: HeaderMap::default(),
}; };
let field_data = FieldData { Ok(FieldData {
metadata, metadata,
contents: temp_file, contents: temp_file,
}; })
field_data
} }
#[tokio::test] #[tokio::test]
async fn test_fileinfo_create_read_delete_with_storage_manager() { async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager operations"; let content = b"This is a test file for StorageManager operations";
let file_name = "storage_manager_test.txt"; let file_name = "storage_manager_test.txt";
let field_data = create_test_file(content, file_name); let field_data = create_test_file(content, file_name)?;
// Create test storage manager (memory backend) // Create test storage manager (memory backend)
let test_storage = store::testing::TestStorageManager::new_memory() let test_storage = store::testing::TestStorageManager::new_memory()
.await .await
.expect("Failed to create test storage manager"); .with_context(|| "Failed to create test storage manager".to_string())?;
// Create a FileInfo instance with storage manager // Create a FileInfo instance with storage manager
let user_id = "test_user"; let user_id = "test_user";
@@ -374,20 +374,20 @@ mod tests {
let file_info = let file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()) FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await .await
.expect("Failed to create file with StorageManager"); .with_context(|| "Failed to create file with StorageManager".to_string())?;
assert_eq!(file_info.file_name, file_name); assert_eq!(file_info.file_name, file_name);
// Verify the file exists via StorageManager and has correct content // Verify the file exists via StorageManager and has correct content
let bytes = file_info let bytes = file_info
.get_content_with_storage(test_storage.storage()) .get_content_with_storage(test_storage.storage())
.await .await
.expect("Failed to read file content via StorageManager"); .with_context(|| "Failed to read file content via StorageManager".to_string())?;
assert_eq!(bytes.as_ref(), content); assert_eq!(bytes.as_ref(), content);
// Test file reading // Test file reading
let retrieved = FileInfo::get_by_id(&file_info.id, &db) let retrieved = FileInfo::get_by_id(&file_info.id, &db)
.await .await
.expect("Failed to retrieve file info"); .with_context(|| "Failed to retrieve file info".to_string())?;
assert_eq!(retrieved.id, file_info.id); assert_eq!(retrieved.id, file_info.id);
assert_eq!(retrieved.sha256, file_info.sha256); assert_eq!(retrieved.sha256, file_info.sha256);
assert_eq!(retrieved.file_name, file_name); assert_eq!(retrieved.file_name, file_name);
@@ -395,65 +395,65 @@ mod tests {
// Test file deletion with StorageManager // Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage()) FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage())
.await .await
.expect("Failed to delete file with StorageManager"); .with_context(|| "Failed to delete file with StorageManager".to_string())?;
let deleted_result = file_info let deleted_result = file_info
.get_content_with_storage(test_storage.storage()) .get_content_with_storage(test_storage.storage())
.await; .await;
assert!(deleted_result.is_err(), "File should be deleted"); assert!(deleted_result.is_err(), "File should be deleted");
Ok(())
// No cleanup needed - TestStorageManager handles it automatically
} }
#[tokio::test] #[tokio::test]
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() { async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"filename sanitization"; let content = b"filename sanitization";
let original_name = "Complex name (1).txt"; let original_name = "Complex name (1).txt";
let expected_sanitized = "Complex_name__1_.txt"; let expected_sanitized = "Complex_name__1_.txt";
let field_data = create_test_file(content, original_name); let field_data = create_test_file(content, original_name)?;
let test_storage = store::testing::TestStorageManager::new_memory() let test_storage = store::testing::TestStorageManager::new_memory()
.await .await
.expect("Failed to create test storage manager"); .with_context(|| "Failed to create test storage manager".to_string())?;
let file_info = let file_info =
FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage()) FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage())
.await .await
.expect("Failed to create file via storage manager"); .with_context(|| "Failed to create file via storage manager".to_string())?;
assert_eq!(file_info.file_name, original_name); assert_eq!(file_info.file_name, original_name);
let stored_name = Path::new(&file_info.path) let stored_name = Path::new(&file_info.path)
.file_name() .file_name()
.and_then(|name| name.to_str()) .and_then(|name| name.to_str())
.expect("stored name"); .with_context(|| "stored name".to_string())?;
assert_eq!(stored_name, expected_sanitized); assert_eq!(stored_name, expected_sanitized);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_fileinfo_duplicate_detection_with_storage_manager() { async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager duplicate detection"; let content = b"This is a test file for StorageManager duplicate detection";
let file_name = "storage_manager_duplicate.txt"; let file_name = "storage_manager_duplicate.txt";
let field_data = create_test_file(content, file_name); let field_data = create_test_file(content, file_name)?;
// Create test storage manager // Create test storage manager
let test_storage = store::testing::TestStorageManager::new_memory() let test_storage = store::testing::TestStorageManager::new_memory()
.await .await
.expect("Failed to create test storage manager"); .with_context(|| "Failed to create test storage manager".to_string())?;
// Create a FileInfo instance with storage manager // Create a FileInfo instance with storage manager
let user_id = "test_user"; let user_id = "test_user";
@@ -462,17 +462,17 @@ mod tests {
let original_file_info = let original_file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()) FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await .await
.expect("Failed to create original file with StorageManager"); .with_context(|| "Failed to create original file with StorageManager".to_string())?;
// Create another file with the same content but different name // Create another file with the same content but different name
let duplicate_name = "storage_manager_duplicate_2.txt"; let duplicate_name = "storage_manager_duplicate_2.txt";
let field_data2 = create_test_file(content, duplicate_name); let field_data2 = create_test_file(content, duplicate_name)?;
// The system should detect it's the same file and return the original FileInfo // The system should detect it's the same file and return the original FileInfo
let duplicate_file_info = let duplicate_file_info =
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage()) FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
.await .await
.expect("Failed to process duplicate file with StorageManager"); .with_context(|| "Failed to process duplicate file with StorageManager".to_string())?;
// Verify duplicate detection worked // Verify duplicate detection worked
assert_eq!(duplicate_file_info.id, original_file_info.id); assert_eq!(duplicate_file_info.id, original_file_info.id);
@@ -484,46 +484,44 @@ mod tests {
let original_content = original_file_info let original_content = original_file_info
.get_content_with_storage(test_storage.storage()) .get_content_with_storage(test_storage.storage())
.await .await
.unwrap(); .with_context(|| "get original content".to_string())?;
let duplicate_content = duplicate_file_info let duplicate_content = duplicate_file_info
.get_content_with_storage(test_storage.storage()) .get_content_with_storage(test_storage.storage())
.await .await
.unwrap(); .with_context(|| "get duplicate content".to_string())?;
assert_eq!(original_content.as_ref(), content); assert_eq!(original_content.as_ref(), content);
assert_eq!(duplicate_content.as_ref(), content); assert_eq!(duplicate_content.as_ref(), content);
// Clean up // Clean up
FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage()) FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage())
.await .await
.expect("Failed to delete original file with StorageManager"); .with_context(|| "Failed to delete original file with StorageManager".to_string())?;
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_file_creation() { async fn test_file_creation() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let content = b"This is a test file content"; let content = b"This is a test file content";
let file_name = "test_file.txt"; let file_name = "test_file.txt";
let field_data = create_test_file(content, file_name); let field_data = create_test_file(content, file_name)?;
// Create a FileInfo instance with StorageManager // Create a FileInfo instance with StorageManager
let user_id = "test_user"; let user_id = "test_user";
let test_storage = TestStorageManager::new_memory() let test_storage = TestStorageManager::new_memory()
.await .await
.expect("create test storage manager"); .with_context(|| "create test storage manager".to_string())?;
let file_info = let file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()).await; FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await?;
// Verify the FileInfo was created successfully
assert!(file_info.is_ok());
let file_info = file_info.unwrap();
// Check essential properties // Check essential properties
assert!(!file_info.id.is_empty()); assert!(!file_info.id.is_empty());
@@ -533,32 +531,32 @@ mod tests {
// path should be logical: "user_id/uuid/file_name" // path should be logical: "user_id/uuid/file_name"
let parts: Vec<&str> = file_info.path.split('/').collect(); let parts: Vec<&str> = file_info.path.split('/').collect();
assert_eq!(parts.len(), 3); assert_eq!(parts.len(), 3);
assert_eq!(parts[0], user_id); assert_eq!(parts.first(), Some(&user_id));
assert_eq!(parts[2], file_name); assert_eq!(parts.get(2), Some(&file_name));
assert!(file_info.mime_type.contains("text/plain")); assert!(file_info.mime_type.contains("text/plain"));
// Verify it's in the database // Verify it's in the database
let stored: Option<FileInfo> = db let stored = db
.get_item(&file_info.id) .get_item::<FileInfo>(&file_info.id)
.await .await
.expect("Failed to retrieve file info"); .with_context(|| "Failed to retrieve file info".to_string())?
assert!(stored.is_some()); .with_context(|| "expected stored file".to_string())?;
let stored = stored.unwrap();
assert_eq!(stored.id, file_info.id); assert_eq!(stored.id, file_info.id);
assert_eq!(stored.file_name, file_info.file_name); assert_eq!(stored.file_name, file_info.file_name);
assert_eq!(stored.sha256, file_info.sha256); assert_eq!(stored.sha256, file_info.sha256);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_file_duplicate_detection() { async fn test_file_duplicate_detection() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
// First, store a file with known content // First, store a file with known content
let content = b"This is a test file for duplicate detection"; let content = b"This is a test file for duplicate detection";
@@ -567,23 +565,23 @@ mod tests {
let test_storage = TestStorageManager::new_memory() let test_storage = TestStorageManager::new_memory()
.await .await
.expect("create test storage manager"); .with_context(|| "create test storage manager".to_string())?;
let field_data1 = create_test_file(content, file_name); let field_data1 = create_test_file(content, file_name)?;
let original_file_info = let original_file_info =
FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage()) FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage())
.await .await
.expect("Failed to create original file"); .with_context(|| "Failed to create original file".to_string())?;
// Now try to store another file with the same content but different name // Now try to store another file with the same content but different name
let duplicate_name = "duplicate.txt"; let duplicate_name = "duplicate.txt";
let field_data2 = create_test_file(content, duplicate_name); let field_data2 = create_test_file(content, duplicate_name)?;
// The system should detect it's the same file and return the original FileInfo // The system should detect it's the same file and return the original FileInfo
let duplicate_file_info = let duplicate_file_info =
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage()) FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
.await .await
.expect("Failed to process duplicate file"); .with_context(|| "Failed to process duplicate file".to_string())?;
// The returned FileInfo should match the original // The returned FileInfo should match the original
assert_eq!(duplicate_file_info.id, original_file_info.id); assert_eq!(duplicate_file_info.id, original_file_info.id);
@@ -592,10 +590,11 @@ mod tests {
// But it should retain the original file name, not the duplicate's name // But it should retain the original file name, not the duplicate's name
assert_eq!(duplicate_file_info.file_name, file_name); assert_eq!(duplicate_file_info.file_name, file_name);
assert_ne!(duplicate_file_info.file_name, duplicate_name); assert_ne!(duplicate_file_info.file_name, duplicate_name);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_guess_mime_type() { async fn test_guess_mime_type() -> anyhow::Result<()> {
// Test common file extensions // Test common file extensions
assert_eq!( assert_eq!(
FileInfo::guess_mime_type(Path::new("test.txt")), FileInfo::guess_mime_type(Path::new("test.txt")),
@@ -619,10 +618,11 @@ mod tests {
FileInfo::guess_mime_type(Path::new("unknown.929yz")), FileInfo::guess_mime_type(Path::new("unknown.929yz")),
"application/octet-stream".to_string() "application/octet-stream".to_string()
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_sanitize_file_name() { async fn test_sanitize_file_name() -> anyhow::Result<()> {
// Safe characters should remain unchanged // Safe characters should remain unchanged
assert_eq!( assert_eq!(
FileInfo::sanitize_file_name("normal_file.txt"), FileInfo::sanitize_file_name("normal_file.txt"),
@@ -647,26 +647,26 @@ mod tests {
FileInfo::sanitize_file_name("../dangerous.txt"), FileInfo::sanitize_file_name("../dangerous.txt"),
"___dangerous.txt" "___dangerous.txt"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_by_sha_not_found() { async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to find a file with a SHA that doesn't exist // Try to find a file with a SHA that doesn't exist
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await; let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(FileError::FileNotFound(_)) => { Err(FileError::FileNotFound(_)) => {}
// Expected error _ => anyhow::bail!("Expected FileNotFound error"),
}
_ => panic!("Expected FileNotFound error"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
@@ -705,7 +705,7 @@ mod tests {
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create a FileInfo instance directly // Create a FileInfo instance directly
let now = Utc::now(); let now = Utc::now();
@@ -725,40 +725,39 @@ mod tests {
assert!(result.is_ok()); assert!(result.is_ok());
// Verify it can be retrieved // Verify it can be retrieved
let retrieved: Option<FileInfo> = db let retrieved = db
.get_item(&file_info.id) .get_item::<FileInfo>(&file_info.id)
.await .await
.expect("Failed to retrieve file info"); .with_context(|| "Failed to retrieve file info".to_string())?
assert!(retrieved.is_some()); .with_context(|| "expected file".to_string())?;
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.id, file_info.id); assert_eq!(retrieved.id, file_info.id);
assert_eq!(retrieved.sha256, file_info.sha256); assert_eq!(retrieved.sha256, file_info.sha256);
assert_eq!(retrieved.file_name, file_info.file_name); assert_eq!(retrieved.file_name, file_info.file_name);
assert_eq!(retrieved.path, file_info.path); assert_eq!(retrieved.path, file_info.path);
assert_eq!(retrieved.mime_type, file_info.mime_type); assert_eq!(retrieved.mime_type, file_info.mime_type);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_id() { async fn test_delete_by_id() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
// Create and persist a test file via FileInfo::new_with_storage // Create and persist a test file via FileInfo::new_with_storage
let user_id = "user123"; let user_id = "user123";
let test_storage = TestStorageManager::new_memory() let test_storage = TestStorageManager::new_memory()
.await .await
.expect("create test storage manager"); .with_context(|| "create test storage manager".to_string())?;
let temp = create_test_file(b"test content", "test_file.txt"); let temp = create_test_file(b"test content", "test_file.txt")?;
let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage()) let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage())
.await .await
.expect("create file"); .with_context(|| "create file".to_string())?;
// Delete the file using StorageManager // Delete the file using StorageManager
let delete_result = let delete_result =
@@ -767,15 +766,14 @@ mod tests {
// Delete should be successful // Delete should be successful
assert!( assert!(
delete_result.is_ok(), delete_result.is_ok(),
"Failed to delete file: {:?}", "Failed to delete file: {delete_result:?}"
delete_result
); );
// Verify the file is removed from the database // Verify the file is removed from the database
let retrieved: Option<FileInfo> = db let retrieved: Option<FileInfo> = db
.get_item(&file_info.id) .get_item(&file_info.id)
.await .await
.expect("Failed to query database"); .with_context(|| "Failed to query database".to_string())?;
assert!( assert!(
retrieved.is_none(), retrieved.is_none(),
"FileInfo should be deleted from the database" "FileInfo should be deleted from the database"
@@ -783,32 +781,37 @@ mod tests {
// Verify content no longer retrievable from storage // Verify content no longer retrievable from storage
assert!(test_storage.storage().get(&file_info.path).await.is_err()); assert!(test_storage.storage().get(&file_info.path).await.is_err());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_id_not_found() { async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to delete a file that doesn't exist // Try to delete a file that doesn't exist
let test_storage = TestStorageManager::new_memory().await.unwrap(); let test_storage = TestStorageManager::new_memory()
.await
.with_context(|| "create test storage manager".to_string())?;
let result = let result =
FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage()) FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage())
.await; .await;
// Should succeed even if the file record does not exist // Should succeed even if the file record does not exist
assert!(result.is_ok()); assert!(result.is_ok());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_by_id() { async fn test_get_by_id() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create a FileInfo instance directly // Create a FileInfo instance directly
let now = Utc::now(); let now = Utc::now();
@@ -827,28 +830,27 @@ mod tests {
// Store it in the database // Store it in the database
db.store_item(original_file_info.clone()) db.store_item(original_file_info.clone())
.await .await
.expect("Failed to store item for get_by_id test"); .with_context(|| "Failed to store item for get_by_id test".to_string())?;
// Retrieve it using get_by_id // Retrieve it using get_by_id
let result = FileInfo::get_by_id(&file_id, &db).await; let retrieved_info = FileInfo::get_by_id(&file_id, &db)
.await
// Assert success and content match .with_context(|| "get_by_id".to_string())?;
assert!(result.is_ok());
let retrieved_info = result.unwrap();
assert_eq!(retrieved_info.id, original_file_info.id); assert_eq!(retrieved_info.id, original_file_info.id);
assert_eq!(retrieved_info.sha256, original_file_info.sha256); assert_eq!(retrieved_info.sha256, original_file_info.sha256);
assert_eq!(retrieved_info.file_name, original_file_info.file_name); assert_eq!(retrieved_info.file_name, original_file_info.file_name);
assert_eq!(retrieved_info.path, original_file_info.path); assert_eq!(retrieved_info.path, original_file_info.path);
assert_eq!(retrieved_info.mime_type, original_file_info.mime_type); assert_eq!(retrieved_info.mime_type, original_file_info.mime_type);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_by_id_not_found() { async fn test_get_by_id_not_found() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Try to retrieve a non-existent ID // Try to retrieve a non-existent ID
let non_existent_id = "non-existent-file-id"; let non_existent_id = "non-existent-file-id";
@@ -862,33 +864,34 @@ mod tests {
Err(FileError::FileNotFound(id)) => { Err(FileError::FileNotFound(id)) => {
assert_eq!(id, non_existent_id); assert_eq!(id, non_existent_id);
} }
Err(e) => panic!("Expected FileNotFound error, but got {:?}", e), Err(e) => anyhow::bail!("Expected FileNotFound error, but got {e:?}"),
Ok(_) => panic!("Expected an error, but got Ok"), Ok(_) => anyhow::bail!("Expected an error, but got Ok"),
} }
Ok(())
} }
// StorageManager-based tests // StorageManager-based tests
#[tokio::test] #[tokio::test]
async fn test_file_info_new_with_storage_memory() { async fn test_file_info_new_with_storage_memory() -> anyhow::Result<()> {
// Setup // Setup
let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory") let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory")
.await .await
.unwrap(); .with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager"; let content = b"This is a test file for StorageManager";
let field_data = create_test_file(content, "test_storage.txt"); let field_data = create_test_file(content, "test_storage.txt")?;
let user_id = "test_user"; let user_id = "test_user";
// Create test storage manager // Create test storage manager
let storage = store::testing::TestStorageManager::new_memory() let storage = store::testing::TestStorageManager::new_memory()
.await .await
.unwrap(); .with_context(|| "create test storage".to_string())?;
// Test file creation with StorageManager // Test file creation with StorageManager
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage()) let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await .await
.expect("Failed to create file with StorageManager"); .with_context(|| "Failed to create file with StorageManager".to_string())?;
// Verify the file was created correctly // Verify the file was created correctly
assert_eq!(file_info.user_id, user_id); assert_eq!(file_info.user_id, user_id);
@@ -900,40 +903,41 @@ mod tests {
let retrieved_content = file_info let retrieved_content = file_info
.get_content_with_storage(storage.storage()) .get_content_with_storage(storage.storage())
.await .await
.expect("Failed to get file content with StorageManager"); .with_context(|| "Failed to get file content with StorageManager".to_string())?;
assert_eq!(retrieved_content.as_ref(), content); assert_eq!(retrieved_content.as_ref(), content);
// Test file deletion with StorageManager // Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage()) FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
.await .await
.expect("Failed to delete file with StorageManager"); .with_context(|| "Failed to delete file with StorageManager".to_string())?;
// Verify file is deleted // Verify file is deleted
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await; let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
assert!(deleted_content_result.is_err()); assert!(deleted_content_result.is_err());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_file_info_new_with_storage_local() { async fn test_file_info_new_with_storage_local() -> anyhow::Result<()> {
// Setup // Setup
let db = SurrealDbClient::memory("test_ns", "test_file_storage_local") let db = SurrealDbClient::memory("test_ns", "test_file_storage_local")
.await .await
.unwrap(); .with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"This is a test file for StorageManager with local storage"; let content = b"This is a test file for StorageManager with local storage";
let field_data = create_test_file(content, "test_local.txt"); let field_data = create_test_file(content, "test_local.txt")?;
let user_id = "test_user"; let user_id = "test_user";
// Create test storage manager with local backend // Create test storage manager with local backend
let storage = store::testing::TestStorageManager::new_local() let storage = store::testing::TestStorageManager::new_local()
.await .await
.unwrap(); .with_context(|| "create test storage".to_string())?;
// Test file creation with StorageManager // Test file creation with StorageManager
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage()) let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await .await
.expect("Failed to create file with StorageManager"); .with_context(|| "Failed to create file with StorageManager".to_string())?;
// Verify the file was created correctly // Verify the file was created correctly
assert_eq!(file_info.user_id, user_id); assert_eq!(file_info.user_id, user_id);
@@ -945,50 +949,51 @@ mod tests {
let retrieved_content = file_info let retrieved_content = file_info
.get_content_with_storage(storage.storage()) .get_content_with_storage(storage.storage())
.await .await
.expect("Failed to get file content with StorageManager"); .with_context(|| "Failed to get file content with StorageManager".to_string())?;
assert_eq!(retrieved_content.as_ref(), content); assert_eq!(retrieved_content.as_ref(), content);
// Test file deletion with StorageManager // Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage()) FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
.await .await
.expect("Failed to delete file with StorageManager"); .with_context(|| "Failed to delete file with StorageManager".to_string())?;
// Verify file is deleted // Verify file is deleted
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await; let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
assert!(deleted_content_result.is_err()); assert!(deleted_content_result.is_err());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_file_info_storage_manager_persistence() { async fn test_file_info_storage_manager_persistence() -> anyhow::Result<()> {
// Setup // Setup
let db = SurrealDbClient::memory("test_ns", "test_file_persistence") let db = SurrealDbClient::memory("test_ns", "test_file_persistence")
.await .await
.unwrap(); .with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"Test content for persistence"; let content = b"Test content for persistence";
let field_data = create_test_file(content, "persistence_test.txt"); let field_data = create_test_file(content, "persistence_test.txt")?;
let user_id = "test_user"; let user_id = "test_user";
// Create test storage manager // Create test storage manager
let storage = store::testing::TestStorageManager::new_memory() let storage = store::testing::TestStorageManager::new_memory()
.await .await
.unwrap(); .with_context(|| "create test storage".to_string())?;
// Create file // Create file
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage()) let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await .await
.expect("Failed to create file"); .with_context(|| "Failed to create file".to_string())?;
// Test that data persists across multiple operations with the same StorageManager // Test that data persists across multiple operations with the same StorageManager
let retrieved_content_1 = file_info let retrieved_content_1 = file_info
.get_content_with_storage(storage.storage()) .get_content_with_storage(storage.storage())
.await .await
.unwrap(); .with_context(|| "get content 1".to_string())?;
let retrieved_content_2 = file_info let retrieved_content_2 = file_info
.get_content_with_storage(storage.storage()) .get_content_with_storage(storage.storage())
.await .await
.unwrap(); .with_context(|| "get content 2".to_string())?;
assert_eq!(retrieved_content_1.as_ref(), content); assert_eq!(retrieved_content_1.as_ref(), content);
assert_eq!(retrieved_content_2.as_ref(), content); assert_eq!(retrieved_content_2.as_ref(), content);
@@ -996,68 +1001,70 @@ mod tests {
// Test that different StorageManager instances don't share data (memory storage isolation) // Test that different StorageManager instances don't share data (memory storage isolation)
let storage2 = store::testing::TestStorageManager::new_memory() let storage2 = store::testing::TestStorageManager::new_memory()
.await .await
.unwrap(); .with_context(|| "create second storage".to_string())?;
let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await; let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await;
assert!( assert!(
isolated_content_result.is_err(), isolated_content_result.is_err(),
"Different StorageManager should not have access to same data" "Different StorageManager should not have access to same data"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_file_info_storage_manager_equivalence() { async fn test_file_info_storage_manager_equivalence() -> anyhow::Result<()> {
// Setup // Setup
let db = SurrealDbClient::memory("test_ns", "test_file_equivalence") let db = SurrealDbClient::memory("test_ns", "test_file_equivalence")
.await .await
.unwrap(); .with_context(|| "Failed to start DB".to_string())?;
db.apply_migrations().await.unwrap(); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let content = b"Test content for equivalence testing"; let content = b"Test content for equivalence testing";
let field_data1 = create_test_file(content, "equivalence_test_1.txt"); let field_data1 = create_test_file(content, "equivalence_test_1.txt")?;
let field_data2 = create_test_file(content, "equivalence_test_2.txt"); let field_data2 = create_test_file(content, "equivalence_test_2.txt")?;
let user_id = "test_user"; let user_id = "test_user";
// Create single storage manager and reuse it // Create single storage manager and reuse it
let storage_manager = store::testing::TestStorageManager::new_memory() let storage_manager = store::testing::TestStorageManager::new_memory()
.await .await
.unwrap(); .with_context(|| "create storage".to_string())?;
let storage = storage_manager.storage(); let storage = storage_manager.storage();
// Create multiple files with the same storage manager // Create multiple files with the same storage manager
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, &storage) let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, storage)
.await .await
.expect("Failed to create file 1"); .with_context(|| "Failed to create file 1".to_string())?;
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, &storage) let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, storage)
.await .await
.expect("Failed to create file 2"); .with_context(|| "Failed to create file 2".to_string())?;
// Test that both files can be retrieved with the same storage backend // Test that both files can be retrieved with the same storage backend
let content_1 = file_info_1 let content_1 = file_info_1
.get_content_with_storage(&storage) .get_content_with_storage(storage)
.await .await
.unwrap(); .with_context(|| "get file 1 content".to_string())?;
let content_2 = file_info_2 let content_2 = file_info_2
.get_content_with_storage(&storage) .get_content_with_storage(storage)
.await .await
.unwrap(); .with_context(|| "get file 2 content".to_string())?;
assert_eq!(content_1.as_ref(), content); assert_eq!(content_1.as_ref(), content);
assert_eq!(content_2.as_ref(), content); assert_eq!(content_2.as_ref(), content);
// Test that files can be deleted with the same storage manager // Test that files can be deleted with the same storage manager
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, &storage) FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, storage)
.await .await
.unwrap(); .with_context(|| "delete file 1".to_string())?;
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, &storage) FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, storage)
.await .await
.unwrap(); .with_context(|| "delete file 2".to_string())?;
// Verify files are deleted // Verify files are deleted
let deleted_content_1 = file_info_1.get_content_with_storage(&storage).await; let deleted_content_1 = file_info_1.get_content_with_storage(storage).await;
let deleted_content_2 = file_info_2.get_content_with_storage(&storage).await; let deleted_content_2 = file_info_2.get_content_with_storage(storage).await;
assert!(deleted_content_1.is_err()); assert!(deleted_content_1.is_err());
assert!(deleted_content_2.is_err()); assert!(deleted_content_2.is_err());
Ok(())
} }
} }
+32 -25
View File
@@ -103,6 +103,7 @@ impl IngestionPayload {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use chrono::Utc; use chrono::Utc;
use super::*; use super::*;
@@ -131,7 +132,7 @@ mod tests {
} }
#[test] #[test]
fn test_create_ingestion_payload_with_url() { fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> {
let url = "https://example.com"; let url = "https://example.com";
let context = "Process this URL"; let context = "Process this URL";
let category = "websites"; let category = "websites";
@@ -145,10 +146,10 @@ mod tests {
files, files,
user_id, user_id,
) )
.unwrap(); .with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
match &result[0] { match result.first().context("expected one result")? {
IngestionPayload::Url { IngestionPayload::Url {
url: payload_url, url: payload_url,
context: payload_context, context: payload_context,
@@ -156,17 +157,18 @@ mod tests {
user_id: payload_user_id, user_id: payload_user_id,
} => { } => {
// URL parser may normalize the URL by adding a trailing slash // URL parser may normalize the URL by adding a trailing slash
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url)); assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
assert_eq!(payload_context, &context); assert_eq!(payload_context, &context);
assert_eq!(payload_category, &category); assert_eq!(payload_category, &category);
assert_eq!(payload_user_id, &user_id); assert_eq!(payload_user_id, &user_id);
} }
_ => panic!("Expected Url variant"), _ => anyhow::bail!("Expected Url variant"),
} }
Ok(())
} }
#[test] #[test]
fn test_create_ingestion_payload_with_text() { fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> {
let text = "This is some text content"; let text = "This is some text content";
let context = "Process this text"; let context = "Process this text";
let category = "notes"; let category = "notes";
@@ -180,10 +182,10 @@ mod tests {
files, files,
user_id, user_id,
) )
.unwrap(); .with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
match &result[0] { match result.first().context("expected one result")? {
IngestionPayload::Text { IngestionPayload::Text {
text: payload_text, text: payload_text,
context: payload_context, context: payload_context,
@@ -195,12 +197,13 @@ mod tests {
assert_eq!(payload_category, category); assert_eq!(payload_category, category);
assert_eq!(payload_user_id, user_id); assert_eq!(payload_user_id, user_id);
} }
_ => panic!("Expected Text variant"), _ => anyhow::bail!("Expected Text variant"),
} }
Ok(())
} }
#[test] #[test]
fn test_create_ingestion_payload_with_file() { fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> {
let context = "Process this file"; let context = "Process this file";
let category = "documents"; let category = "documents";
let user_id = "user123"; let user_id = "user123";
@@ -220,10 +223,10 @@ mod tests {
files, files,
user_id, user_id,
) )
.unwrap(); .with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
match &result[0] { match result.first().context("expected one result")? {
IngestionPayload::File { IngestionPayload::File {
file_info: payload_file_info, file_info: payload_file_info,
context: payload_context, context: payload_context,
@@ -235,12 +238,13 @@ mod tests {
assert_eq!(payload_category, category); assert_eq!(payload_category, category);
assert_eq!(payload_user_id, user_id); assert_eq!(payload_user_id, user_id);
} }
_ => panic!("Expected File variant"), _ => anyhow::bail!("Expected File variant"),
} }
Ok(())
} }
#[test] #[test]
fn test_create_ingestion_payload_with_url_and_file() { fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> {
let url = "https://example.com"; let url = "https://example.com";
let context = "Process this data"; let context = "Process this data";
let category = "mixed"; let category = "mixed";
@@ -261,35 +265,36 @@ mod tests {
files, files,
user_id, user_id,
) )
.unwrap(); .with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
// Check first item is URL // Check first item is URL
match &result[0] { match result.first().context("expected first item")? {
IngestionPayload::Url { IngestionPayload::Url {
url: payload_url, .. url: payload_url, ..
} => { } => {
// URL parser may normalize the URL by adding a trailing slash // URL parser may normalize the URL by adding a trailing slash
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url)); assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
} }
_ => panic!("Expected first item to be Url variant"), _ => anyhow::bail!("Expected first item to be Url variant"),
} }
// Check second item is File // Check second item is File
match &result[1] { match result.get(1).context("expected second item")? {
IngestionPayload::File { IngestionPayload::File {
file_info: payload_file_info, file_info: payload_file_info,
.. ..
} => { } => {
assert_eq!(payload_file_info.id, file_info.id); assert_eq!(payload_file_info.id, file_info.id);
} }
_ => panic!("Expected second item to be File variant"), _ => anyhow::bail!("Expected second item to be File variant"),
} }
Ok(())
} }
#[test] #[test]
fn test_create_ingestion_payload_empty_input() { fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> {
let context = "Process something"; let context = "Process something";
let category = "empty"; let category = "empty";
let user_id = "user123"; let user_id = "user123";
@@ -308,12 +313,13 @@ mod tests {
Err(AppError::NotFound(msg)) => { Err(AppError::NotFound(msg)) => {
assert_eq!(msg, "No valid content or files provided"); assert_eq!(msg, "No valid content or files provided");
} }
_ => panic!("Expected NotFound error"), _ => anyhow::bail!("Expected NotFound error"),
} }
Ok(())
} }
#[test] #[test]
fn test_create_ingestion_payload_with_empty_text() { fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> {
let text = ""; // Empty text let text = ""; // Empty text
let context = "Process this"; let context = "Process this";
let category = "notes"; let category = "notes";
@@ -333,7 +339,8 @@ mod tests {
Err(AppError::NotFound(msg)) => { Err(AppError::NotFound(msg)) => {
assert_eq!(msg, "No valid content or files provided"); assert_eq!(msg, "No valid content or files provided");
} }
_ => panic!("Expected NotFound error"), _ => anyhow::bail!("Expected NotFound error"),
} }
Ok(())
} }
} }
+45 -36
View File
@@ -529,6 +529,8 @@ impl IngestionTask {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::types::ingestion_payload::IngestionPayload; use crate::storage::types::ingestion_payload::IngestionPayload;
@@ -541,16 +543,16 @@ mod tests {
} }
} }
async fn memory_db() -> SurrealDbClient { async fn memory_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = Uuid::new_v4().to_string(); let database = Uuid::new_v4().to_string();
SurrealDbClient::memory(namespace, &database) SurrealDbClient::memory(namespace, &database)
.await .await
.expect("in-memory surrealdb") .with_context(|| "in-memory surrealdb".to_string())
} }
#[tokio::test] #[tokio::test]
async fn test_new_task_defaults() { async fn test_new_task_defaults() -> anyhow::Result<()> {
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string()); let task = IngestionTask::new(payload.clone(), user_id.to_string());
@@ -562,73 +564,76 @@ mod tests {
assert_eq!(task.max_attempts, MAX_ATTEMPTS); assert_eq!(task.max_attempts, MAX_ATTEMPTS);
assert!(task.locked_at.is_none()); assert!(task.locked_at.is_none());
assert!(task.worker_id.is_none()); assert!(task.worker_id.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_create_and_store_task() { async fn test_create_and_store_task() -> anyhow::Result<()> {
let db = memory_db().await; let db = memory_db().await?;
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let created = let created =
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db) IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
.await .await
.expect("store"); .with_context(|| "store".to_string())?;
let stored: Option<IngestionTask> = db let stored: Option<IngestionTask> = db
.get_item::<IngestionTask>(&created.id) .get_item::<IngestionTask>(&created.id)
.await .await
.expect("fetch"); .with_context(|| "fetch".to_string())?;
let stored = stored.expect("task exists"); let stored = stored.with_context(|| "task exists".to_string())?;
assert_eq!(stored.id, created.id); assert_eq!(stored.id, created.id);
assert_eq!(stored.state, TaskState::Pending); assert_eq!(stored.state, TaskState::Pending);
assert_eq!(stored.attempts, 0); assert_eq!(stored.attempts, 0);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_claim_and_transition() { async fn test_claim_and_transition() -> anyhow::Result<()> {
let db = memory_db().await; let db = memory_db().await?;
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let task = IngestionTask::new(payload, user_id.to_string()); let task = IngestionTask::new(payload, user_id.to_string());
db.store_item(task.clone()).await.expect("store"); db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let worker_id = "worker-1"; let worker_id = "worker-1";
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60)) let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
.await .await
.expect("claim"); .with_context(|| "claim".to_string())?
.with_context(|| "task claimed".to_string())?;
let claimed = claimed.expect("task claimed");
assert_eq!(claimed.state, TaskState::Reserved); assert_eq!(claimed.state, TaskState::Reserved);
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id)); assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
let processing = claimed.mark_processing(&db).await.expect("processing"); let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
assert_eq!(processing.state, TaskState::Processing); assert_eq!(processing.state, TaskState::Processing);
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded"); let succeeded = processing.mark_succeeded(&db).await.with_context(|| "succeeded".to_string())?;
assert_eq!(succeeded.state, TaskState::Succeeded); assert_eq!(succeeded.state, TaskState::Succeeded);
assert!(succeeded.worker_id.is_none()); assert!(succeeded.worker_id.is_none());
assert!(succeeded.locked_at.is_none()); assert!(succeeded.locked_at.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_fail_and_dead_letter() { async fn test_fail_and_dead_letter() -> anyhow::Result<()> {
let db = memory_db().await; let db = memory_db().await?;
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let task = IngestionTask::new(payload, user_id.to_string()); let task = IngestionTask::new(payload, user_id.to_string());
db.store_item(task.clone()).await.expect("store"); db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let worker_id = "worker-dead"; let worker_id = "worker-dead";
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60)) let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
.await .await
.expect("claim") .with_context(|| "claim".to_string())?
.expect("claimed"); .with_context(|| "claimed".to_string())?;
let processing = claimed.mark_processing(&db).await.expect("processing"); let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
let error_info = TaskErrorInfo { let error_info = TaskErrorInfo {
code: Some("pipeline_error".into()), code: Some("pipeline_error".into()),
@@ -638,7 +643,7 @@ mod tests {
let failed = processing let failed = processing
.mark_failed(error_info.clone(), Duration::from_secs(30), &db) .mark_failed(error_info.clone(), Duration::from_secs(30), &db)
.await .await
.expect("failed update"); .with_context(|| "failed update".to_string())?;
assert_eq!(failed.state, TaskState::Failed); assert_eq!(failed.state, TaskState::Failed);
assert_eq!(failed.error_message.as_deref(), Some("failed")); assert_eq!(failed.error_message.as_deref(), Some("failed"));
assert!(failed.worker_id.is_none()); assert!(failed.worker_id.is_none());
@@ -648,19 +653,20 @@ mod tests {
let dead = failed let dead = failed
.mark_dead_letter(error_info.clone(), &db) .mark_dead_letter(error_info.clone(), &db)
.await .await
.expect("dead letter"); .with_context(|| "dead letter".to_string())?;
assert_eq!(dead.state, TaskState::DeadLetter); assert_eq!(dead.state, TaskState::DeadLetter);
assert_eq!(dead.error_message.as_deref(), Some("failed")); assert_eq!(dead.error_message.as_deref(), Some("failed"));
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_mark_processing_requires_reservation() { async fn test_mark_processing_requires_reservation() -> anyhow::Result<()> {
let db = memory_db().await; let db = memory_db().await?;
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string()); let task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(task.clone()).await.expect("store"); db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let err = task let err = task
.mark_processing(&db) .mark_processing(&db)
@@ -674,18 +680,19 @@ mod tests {
"unexpected message: {message}" "unexpected message: {message}"
); );
} }
other => panic!("expected validation error, got {other:?}"), other => anyhow::bail!("expected validation error, got {other:?}"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_mark_failed_requires_processing() { async fn test_mark_failed_requires_processing() -> anyhow::Result<()> {
let db = memory_db().await; let db = memory_db().await?;
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string()); let task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(task.clone()).await.expect("store"); db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let err = task let err = task
.mark_failed( .mark_failed(
@@ -706,18 +713,19 @@ mod tests {
"unexpected message: {message}" "unexpected message: {message}"
); );
} }
other => panic!("expected validation error, got {other:?}"), other => anyhow::bail!("expected validation error, got {other:?}"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_release_requires_reservation() { async fn test_release_requires_reservation() -> anyhow::Result<()> {
let db = memory_db().await; let db = memory_db().await?;
let user_id = "user123"; let user_id = "user123";
let payload = create_payload(user_id); let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string()); let task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(task.clone()).await.expect("store"); db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
let err = task let err = task
.release(&db) .release(&db)
@@ -731,7 +739,8 @@ mod tests {
"unexpected message: {message}" "unexpected message: {message}"
); );
} }
other => panic!("expected validation error, got {other:?}"), other => anyhow::bail!("expected validation error, got {other:?}"),
} }
Ok(())
} }
} }
+98 -82
View File
@@ -5,7 +5,6 @@
clippy::format_push_string, clippy::format_push_string,
clippy::uninlined_format_args, clippy::uninlined_format_args,
clippy::explicit_iter_loop, clippy::explicit_iter_loop,
clippy::items_after_statements,
clippy::get_first, clippy::get_first,
clippy::redundant_closure_for_method_calls clippy::redundant_closure_for_method_calls
)] )]
@@ -317,6 +316,11 @@ impl KnowledgeEntity {
} }
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> { async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let mut response = db_client let mut response = db_client
.client .client
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1") .query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
@@ -324,10 +328,6 @@ impl KnowledgeEntity {
.bind(("id", id.to_string())) .bind(("id", id.to_string()))
.await .await
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?; let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.get(0) rows.get(0)
.map(|r| r.user_id.clone()) .map(|r| r.user_id.clone())
@@ -497,7 +497,6 @@ impl KnowledgeEntity {
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
} }
info!("Successfully generated all new embeddings."); info!("Successfully generated all new embeddings.");
info!("Successfully generated all new embeddings.");
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts. // Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
info!("Removing old index and clearing embeddings..."); info!("Removing old index and clearing embeddings...");
@@ -572,14 +571,14 @@ impl KnowledgeEntity {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use serde_json::json; use serde_json::json;
use uuid::Uuid; use uuid::Uuid;
#[tokio::test] #[tokio::test]
async fn test_knowledge_entity_creation() { async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
// Create basic test entity
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let name = "Test Entity".to_string(); let name = "Test Entity".to_string();
let description = "Test Description".to_string(); let description = "Test Description".to_string();
@@ -596,7 +595,6 @@ mod tests {
user_id.clone(), user_id.clone(),
); );
// Verify all fields are set correctly
assert_eq!(entity.source_id, source_id); assert_eq!(entity.source_id, source_id);
assert_eq!(entity.name, name); assert_eq!(entity.name, name);
assert_eq!(entity.description, description); assert_eq!(entity.description, description);
@@ -604,11 +602,12 @@ mod tests {
assert_eq!(entity.metadata, metadata); assert_eq!(entity.metadata, metadata);
assert_eq!(entity.user_id, user_id); assert_eq!(entity.user_id, user_id);
assert!(!entity.id.is_empty()); assert!(!entity.id.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_knowledge_entity_type_from_string() { async fn test_knowledge_entity_type_from_string() -> anyhow::Result<()> {
// Test conversion from String to KnowledgeEntityType
assert_eq!( assert_eq!(
KnowledgeEntityType::from("idea".to_string()), KnowledgeEntityType::from("idea".to_string()),
KnowledgeEntityType::Idea KnowledgeEntityType::Idea
@@ -639,15 +638,16 @@ mod tests {
KnowledgeEntityType::TextSnippet KnowledgeEntityType::TextSnippet
); );
// Test default case
assert_eq!( assert_eq!(
KnowledgeEntityType::from("unknown".to_string()), KnowledgeEntityType::from("unknown".to_string()),
KnowledgeEntityType::Document KnowledgeEntityType::Document
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_knowledge_entity_variants() { async fn test_knowledge_entity_variants() -> anyhow::Result<()> {
let variants = KnowledgeEntityType::variants(); let variants = KnowledgeEntityType::variants();
assert_eq!(variants.len(), 5); assert_eq!(variants.len(), 5);
assert!(variants.contains(&"Idea")); assert!(variants.contains(&"Idea"));
@@ -655,28 +655,28 @@ mod tests {
assert!(variants.contains(&"Document")); assert!(variants.contains(&"Document"));
assert!(variants.contains(&"Page")); assert!(variants.contains(&"Page"));
assert!(variants.contains(&"TextSnippet")); assert!(variants.contains(&"TextSnippet"));
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_source_id() { async fn test_delete_by_source_id() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
// Create two entities with the same source_id
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document; let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string(); let user_id = "user123".to_string();
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let entity1 = KnowledgeEntity::new( let entity1 = KnowledgeEntity::new(
source_id.clone(), source_id.clone(),
@@ -696,7 +696,6 @@ mod tests {
user_id.clone(), user_id.clone(),
); );
// Create an entity with a different source_id
let different_source_id = "different_source".to_string(); let different_source_id = "different_source".to_string();
let different_entity = KnowledgeEntity::new( let different_entity = KnowledgeEntity::new(
different_source_id.clone(), different_source_id.clone(),
@@ -708,23 +707,20 @@ mod tests {
); );
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5]; let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
// Store the entities
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db) KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
.await .await
.expect("Failed to store entity 1"); .with_context(|| "Failed to store entity 1".to_string())?;
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db) KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
.await .await
.expect("Failed to store entity 2"); .with_context(|| "Failed to store entity 2".to_string())?;
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db) KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
.await .await
.expect("Failed to store different entity"); .with_context(|| "Failed to store different entity".to_string())?;
// Delete by source_id
KnowledgeEntity::delete_by_source_id(&source_id, &db) KnowledgeEntity::delete_by_source_id(&source_id, &db)
.await .await
.expect("Failed to delete entities by source_id"); .with_context(|| "Failed to delete entities by source_id".to_string())?;
// Verify all entities with the specified source_id are deleted
let query = format!( let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'", "SELECT * FROM {} WHERE source_id = '{}'",
KnowledgeEntity::table_name(), KnowledgeEntity::table_name(),
@@ -734,16 +730,11 @@ mod tests {
.client .client
.query(query) .query(query)
.await .await
.expect("Query failed") .with_context(|| "Query failed".to_string())?
.take(0) .take(0)
.expect("Failed to get query results"); .with_context(|| "Failed to get query results".to_string())?;
assert_eq!( assert!(remaining.is_empty(), "All entities with the source_id should be deleted");
remaining.len(),
0,
"All entities with the source_id should be deleted"
);
// Verify the entity with a different source_id still exists
let different_query = format!( let different_query = format!(
"SELECT * FROM {} WHERE source_id = '{}'", "SELECT * FROM {} WHERE source_id = '{}'",
KnowledgeEntity::table_name(), KnowledgeEntity::table_name(),
@@ -753,15 +744,20 @@ mod tests {
.client .client
.query(different_query) .query(different_query)
.await .await
.expect("Query failed") .with_context(|| "Query failed".to_string())?
.take(0) .take(0)
.expect("Failed to get query results"); .with_context(|| "Failed to get query results".to_string())?;
assert_eq!( assert_eq!(
different_remaining.len(), different_remaining.len(),
1, 1,
"Entity with different source_id should still exist" "Entity with different source_id should still exist"
); );
assert_eq!(different_remaining[0].id, different_entity.id); assert_eq!(
different_remaining.first().context("Expected entity to exist")?.id,
different_entity.id
);
Ok(())
} }
#[tokio::test] #[tokio::test]
@@ -833,35 +829,37 @@ mod tests {
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await .await
.expect("vector search"); .with_context(|| "vector search".to_string())?;
assert!(results.is_empty()); assert!(results.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_vector_search_single_result() { async fn test_vector_search_single_result() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let user_id = "user".to_string(); let user_id = "user".to_string();
let source_id = "src".to_string(); let source_id = "src".to_string();
@@ -876,9 +874,12 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db) KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
.await .await
.expect("store entity with embedding"); .with_context(|| "store entity with embedding".to_string())?;
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap(); let stored_entity: Option<KnowledgeEntity> = db
.get_item(&entity.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
assert!(stored_entity.is_some()); assert!(stored_entity.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
@@ -888,42 +889,44 @@ mod tests {
KnowledgeEntityEmbedding::table_name() KnowledgeEntityEmbedding::table_name()
)) ))
.await .await
.expect("query embeddings") .with_context(|| "query embeddings".to_string())?
.take(0) .take(0)
.expect("take embeddings"); .with_context(|| "take embeddings".to_string())?;
assert_eq!(stored_embeddings.len(), 1); assert_eq!(stored_embeddings.len(), 1);
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db) let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
.await .await
.expect("fetch embedding"); .with_context(|| "fetch embedding".to_string())?;
assert!(fetched_emb.is_some()); assert!(fetched_emb.is_some());
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await .await
.expect("vector search"); .with_context(|| "vector search".to_string())?;
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
let res = &results[0]; let res = results.first().context("Expected at least one result")?;
assert_eq!(res.entity.id, entity.id); assert_eq!(res.entity.id, entity.id);
assert_eq!(res.entity.source_id, source_id); assert_eq!(res.entity.source_id, source_id);
assert_eq!(res.entity.name, "hello"); assert_eq!(res.entity.name, "hello");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_vector_search_orders_by_similarity() { async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let user_id = "user".to_string(); let user_id = "user".to_string();
let e1 = KnowledgeEntity::new( let e1 = KnowledgeEntity::new(
@@ -945,13 +948,19 @@ mod tests {
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db) KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
.await .await
.expect("store e1"); .with_context(|| "store e1".to_string())?;
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db) KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
.await .await
.expect("store e2"); .with_context(|| "store e2".to_string())?;
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap(); let stored_e1: Option<KnowledgeEntity> = db
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap(); .get_item(&e1.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
let stored_e2: Option<KnowledgeEntity> = db
.get_item(&e2.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
assert!(stored_e1.is_some() && stored_e2.is_some()); assert!(stored_e1.is_some() && stored_e2.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
@@ -961,45 +970,53 @@ mod tests {
KnowledgeEntityEmbedding::table_name() KnowledgeEntityEmbedding::table_name()
)) ))
.await .await
.expect("query embeddings") .with_context(|| "query embeddings".to_string())?
.take(0) .take(0)
.expect("take embeddings"); .with_context(|| "take embeddings".to_string())?;
assert_eq!(stored_embeddings.len(), 2); assert_eq!(stored_embeddings.len(), 2);
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id); let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id); let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db) assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
.await .await
.unwrap() .with_context(|| "get embedding e1".to_string())?
.is_some()); .is_some());
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db) assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
.await .await
.unwrap() .with_context(|| "get embedding e2".to_string())?
.is_some()); .is_some());
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await .await
.expect("vector search"); .with_context(|| "vector search".to_string())?;
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
assert_eq!(results[0].entity.id, e2.id); assert_eq!(
assert_eq!(results[1].entity.id, e1.id); results.first().context("Expected at least one result")?.entity.id,
e2.id
);
assert_eq!(
results.get(1).context("Expected at least two results")?.entity.id,
e1.id
);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_vector_search_with_orphaned_embedding() { async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
let namespace = "test_ns_orphan"; let namespace = "test_ns_orphan";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let user_id = "user".to_string(); let user_id = "user".to_string();
let source_id = "src".to_string(); let source_id = "src".to_string();
@@ -1014,21 +1031,20 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db) KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
.await .await
.expect("store entity with embedding"); .with_context(|| "store entity with embedding".to_string())?;
// Manually delete the entity to create an orphan
let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id); let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id);
db.client.query(query).await.expect("delete entity"); db.client
.query(query)
.await
.with_context(|| "delete entity".to_string())?;
// Now search
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await .await
.expect("search should succeed even with orphans"); .with_context(|| "search should succeed even with orphans".to_string())?;
assert!( assert!(results.is_empty(), "Should return empty result for orphan, got: {:?}", results);
results.is_empty(),
"Should return empty result for orphan, got: {:?}", Ok(())
results
);
} }
} }
@@ -110,11 +110,15 @@ impl KnowledgeEntityEmbedding {
} }
/// Delete embeddings by source_id (via joining to knowledge_entity table) /// Delete embeddings by source_id (via joining to knowledge_entity table)
#[allow(clippy::items_after_statements)]
pub async fn delete_by_source_id( pub async fn delete_by_source_id(
source_id: &str, source_id: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id"; let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
let mut res = db let mut res = db
.client .client
@@ -122,11 +126,6 @@ impl KnowledgeEntityEmbedding {
.bind(("source_id", source_id.to_owned())) .bind(("source_id", source_id.to_owned()))
.await .await
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?; let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
for row in ids { for row in ids {
@@ -138,6 +137,7 @@ impl KnowledgeEntityEmbedding {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::db::SurrealDbClient; use crate::storage::db::SurrealDbClient;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
@@ -145,18 +145,18 @@ mod tests {
use surrealdb::Value as SurrealValue; use surrealdb::Value as SurrealValue;
use uuid::Uuid; use uuid::Uuid;
async fn setup_test_db() -> SurrealDbClient { async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = Uuid::new_v4().to_string(); let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database) let db = SurrealDbClient::memory(namespace, &database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
db Ok(db)
} }
fn build_knowledge_entity_with_id( fn build_knowledge_entity_with_id(
@@ -178,11 +178,11 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_create_and_get_by_entity_id() { async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("set test index dimension"); .with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let entity_key = "entity-1"; let entity_key = "entity-1";
let source_id = "source-ke"; let source_id = "source-ke";
@@ -192,26 +192,28 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db) KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await .await
.expect("Failed to get embedding by entity_id") .with_context(|| "Failed to get embedding by entity_id".to_string())?
.expect("Expected embedding to exist"); .ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
assert_eq!(fetched.user_id, user_id); assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.entity_id, entity_rid); assert_eq!(fetched.entity_id, entity_rid);
assert_eq!(fetched.embedding, embedding_vec); assert_eq!(fetched.embedding, embedding_vec);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_entity_id() { async fn test_delete_by_entity_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("set test index dimension"); .with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let entity_key = "entity-delete"; let entity_key = "entity-delete";
let source_id = "source-del"; let source_id = "source-del";
@@ -220,61 +222,67 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db) KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await .await
.expect("Failed to get embedding before delete"); .with_context(|| "Failed to get embedding before delete".to_string())?;
assert!(existing.is_some()); assert!(existing.is_some());
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db) KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
.await .await
.expect("Failed to delete by entity_id"); .with_context(|| "Failed to delete by entity_id".to_string())?;
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await .await
.expect("Failed to get embedding after delete"); .with_context(|| "Failed to get embedding after delete".to_string())?;
assert!(after.is_none()); assert!(after.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_store_with_embedding_creates_entity_and_embedding() { async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "user_store"; let user_id = "user_store";
let source_id = "source_store"; let source_id = "source_store";
let embedding = vec![0.2_f32, 0.3, 0.4]; let embedding = vec![0.2_f32, 0.3, 0.4];
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len()) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
.await .await
.expect("set test index dimension"); .with_context(|| "set test index dimension".to_string())?;
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id); let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db) KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap(); let stored_entity: Option<KnowledgeEntity> = db
.get_item(&entity.id)
.await
.with_context(|| "Failed to get entity".to_string())?;
assert!(stored_entity.is_some()); assert!(stored_entity.is_some());
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await .await
.expect("Failed to fetch embedding"); .with_context(|| "Failed to fetch embedding".to_string())?;
assert!(stored_embedding.is_some()); let stored_embedding = stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
let stored_embedding = stored_embedding.unwrap();
assert_eq!(stored_embedding.user_id, user_id); assert_eq!(stored_embedding.user_id, user_id);
assert_eq!(stored_embedding.entity_id, entity_rid); assert_eq!(stored_embedding.entity_id, entity_rid);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_source_id() { async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("set test index dimension"); .with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let source_id = "shared-ke"; let source_id = "shared-ke";
let other_source = "other-ke"; let other_source = "other-ke";
@@ -285,13 +293,13 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db) KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db) KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db) KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id); let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id); let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
@@ -299,59 +307,74 @@ mod tests {
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db) KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
.await .await
.expect("Failed to delete by source_id"); .with_context(|| "Failed to delete by source_id".to_string())?;
assert!( assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db) KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
.await .await
.unwrap() .with_context(|| "get entity1 embedding after delete".to_string())?
.is_none() .is_none()
); );
assert!( assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db) KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
.await .await
.unwrap() .with_context(|| "get entity2 embedding after delete".to_string())?
.is_none() .is_none()
); );
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db) assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
.await .await
.unwrap() .with_context(|| "get other embedding after delete".to_string())?
.is_some()); .is_some());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_redefine_hnsw_index_updates_dimension() { async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
.await .await
.expect("failed to redefine index"); .with_context(|| "failed to redefine index".to_string())?;
let mut info_res = db let mut info_res = db
.client .client
.query("INFO FOR TABLE knowledge_entity_embedding;") .query("INFO FOR TABLE knowledge_entity_embedding;")
.await .await
.expect("info query failed"); .with_context(|| "info query failed".to_string())?;
let info: SurrealValue = info_res.take(0).expect("failed to take info result"); let info: SurrealValue = info_res
let info_json: serde_json::Value = .take(0)
serde_json::to_value(info).expect("failed to convert info to json"); .with_context(|| "failed to take info result".to_string())?;
let idx_sql = info_json["Object"]["indexes"]["Object"] let info_json: serde_json::Value = serde_json::to_value(info)
["idx_embedding_knowledge_entity_embedding"]["Strand"] .with_context(|| "failed to convert info to json".to_string())?;
.as_str() let idx_sql = info_json
.get("Object")
.and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.get("idx_embedding_knowledge_entity_embedding"))
.and_then(|v| v.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_default(); .unwrap_or_default();
assert!( assert!(
idx_sql.contains("DIMENSION 16"), idx_sql.contains("DIMENSION 16"),
"expected index definition to contain new dimension, got: {idx_sql}" "expected index definition to contain new dimension, got: {idx_sql}"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_fetch_entity_via_record_id() { async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> {
let db = setup_test_db().await; #[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
}
let db = setup_test_db().await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("set test index dimension"); .with_context(|| "set test index dimension".to_string())?;
let user_id = "user_ke"; let user_id = "user_ke";
let entity_key = "entity-fetch"; let entity_key = "entity-fetch";
let source_id = "source-fetch"; let source_id = "source-fetch";
@@ -359,15 +382,10 @@ mod tests {
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id); let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db) KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
.await .await
.expect("Failed to store entity with embedding"); .with_context(|| "Failed to store entity with embedding".to_string())?;
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
#[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
}
let mut res = db let mut res = db
.client .client
.query( .query(
@@ -375,13 +393,17 @@ mod tests {
) )
.bind(("id", entity_rid.clone())) .bind(("id", entity_rid.clone()))
.await .await
.expect("failed to fetch embedding with FETCH"); .with_context(|| "failed to fetch embedding with FETCH".to_string())?;
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows"); let rows: Vec<Row> = res
.take(0)
.with_context(|| "failed to deserialize fetch rows".to_string())?;
assert_eq!(rows.len(), 1); assert_eq!(rows.len(), 1);
let fetched_entity = &rows[0].entity_id; let fetched_entity = &rows.first().context("Expected at least one result")?.entity_id;
assert_eq!(fetched_entity.id, entity_key); assert_eq!(fetched_entity.id, entity_key);
assert_eq!(fetched_entity.name, "Test entity"); assert_eq!(fetched_entity.name, "Test entity");
assert_eq!(fetched_entity.user_id, user_id); assert_eq!(fetched_entity.user_id, user_id);
Ok(())
} }
} }
@@ -124,6 +124,7 @@ impl KnowledgeRelationship {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
@@ -155,10 +156,9 @@ mod tests {
result.take(0).expect("failed to take relationship by id") result.take(0).expect("failed to take relationship by id")
} }
// Helper function to create a test knowledge entity for the relationship tests async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> {
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let description = format!("Description for {}", name); let description = format!("Description for {name}");
let entity_type = KnowledgeEntityType::Document; let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string(); let user_id = "user123".to_string();
@@ -174,12 +174,14 @@ mod tests {
let stored: Option<KnowledgeEntity> = db_client let stored: Option<KnowledgeEntity> = db_client
.store_item(entity) .store_item(entity)
.await .await
.expect("Failed to store entity"); .with_context(|| "Failed to store entity".to_string())?;
stored.unwrap().id stored
.ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some"))
.map(|e| e.id)
} }
#[tokio::test] #[tokio::test]
async fn test_relationship_creation() { async fn test_relationship_creation() -> anyhow::Result<()> {
let in_id = "entity1".to_string(); let in_id = "entity1".to_string();
let out_id = "entity2".to_string(); let out_id = "entity2".to_string();
let user_id = "user123".to_string(); let user_id = "user123".to_string();
@@ -194,25 +196,23 @@ mod tests {
relationship_type.clone(), relationship_type.clone(),
); );
// Verify fields are correctly set
assert_eq!(relationship.in_, in_id); assert_eq!(relationship.in_, in_id);
assert_eq!(relationship.out, out_id); assert_eq!(relationship.out, out_id);
assert_eq!(relationship.metadata.user_id, user_id); assert_eq!(relationship.metadata.user_id, user_id);
assert_eq!(relationship.metadata.source_id, source_id); assert_eq!(relationship.metadata.source_id, source_id);
assert_eq!(relationship.metadata.relationship_type, relationship_type); assert_eq!(relationship.metadata.relationship_type, relationship_type);
assert!(!relationship.id.is_empty()); assert!(!relationship.id.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_store_and_verify_by_source_id() { async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
// Setup in-memory database for testing
let db = setup_test_db().await; let db = setup_test_db().await;
// Create two entities to relate let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity1_id = create_test_entity("Entity 1", &db).await; let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await;
// Create relationship
let user_id = "user123".to_string(); let user_id = "user123".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let relationship_type = "references".to_string(); let relationship_type = "references".to_string();
@@ -225,11 +225,10 @@ mod tests {
relationship_type, relationship_type,
); );
// Store the relationship
relationship relationship
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship"); .with_context(|| "Failed to store relationship".to_string())?;
let persisted = get_relationship_by_id(&relationship.id, &db) let persisted = get_relationship_by_id(&relationship.id, &db)
.await .await
@@ -239,8 +238,6 @@ mod tests {
assert_eq!(persisted.metadata.user_id, user_id); assert_eq!(persisted.metadata.user_id, user_id);
assert_eq!(persisted.metadata.source_id, source_id); assert_eq!(persisted.metadata.source_id, source_id);
// Query to verify the relationship exists by checking for relationships with our source_id
// This approach is more reliable than trying to look up by ID
let mut check_result = db let mut check_result = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id") .query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone())) .bind(("source_id", source_id.clone()))
@@ -253,14 +250,16 @@ mod tests {
1, 1,
"Expected one relationship for source_id" "Expected one relationship for source_id"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_store_relationship_resists_query_injection() { async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await; let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await; let entity2_id = create_test_entity("Entity 2", &db).await?;
let relationship = KnowledgeRelationship::new( let relationship = KnowledgeRelationship::new(
entity1_id, entity1_id,
@@ -288,18 +287,17 @@ mod tests {
rows[0].metadata.source_id, rows[0].metadata.source_id,
"source123'; DELETE FROM relates_to; --" "source123'; DELETE FROM relates_to; --"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_store_and_delete_relationship() { async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
// Setup in-memory database for testing
let db = setup_test_db().await; let db = setup_test_db().await;
// Create two entities to relate let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity1_id = create_test_entity("Entity 1", &db).await; let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await;
// Create relationship
let user_id = "user123".to_string(); let user_id = "user123".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let relationship_type = "references".to_string(); let relationship_type = "references".to_string();
@@ -312,52 +310,44 @@ mod tests {
relationship_type, relationship_type,
); );
// Store relationship
relationship relationship
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship"); .with_context(|| "Failed to store relationship".to_string())?;
// Ensure relationship exists before deletion attempt
let mut existing_before_delete = db let mut existing_before_delete = db
.query(format!( .query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'", "SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
user_id, source_id
)) ))
.await .await
.expect("Query failed"); .with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> = let before_results: Vec<KnowledgeRelationship> =
existing_before_delete.take(0).unwrap_or_default(); existing_before_delete.take(0).unwrap_or_default();
assert!( assert!(!before_results.is_empty(), "Relationship should exist before deletion");
!before_results.is_empty(),
"Relationship should exist before deletion"
);
// Delete relationship by ID
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db) KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
.await .await
.expect("Failed to delete relationship by ID"); .with_context(|| "Failed to delete relationship by ID".to_string())?;
// Query to verify relationship was deleted
let mut result = db let mut result = db
.query(format!( .query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'", "SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
user_id, source_id
)) ))
.await .await
.expect("Query failed"); .with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default(); let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
// Verify relationship no longer exists
assert!(results.is_empty(), "Relationship should be deleted"); assert!(results.is_empty(), "Relationship should be deleted");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() { async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await; let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await; let entity2_id = create_test_entity("Entity 2", &db).await?;
let owner_user_id = "owner-user".to_string(); let owner_user_id = "owner-user".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
@@ -373,20 +363,16 @@ mod tests {
relationship relationship
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship"); .with_context(|| "Failed to store relationship".to_string())?;
let mut before_attempt = db let mut before_attempt = db
.query(format!( .query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'", "SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
owner_user_id
)) ))
.await .await
.expect("Query failed"); .with_context(|| "Query failed".to_string())?;
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default(); let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
assert!( assert!(!before_results.is_empty(), "Relationship should exist before unauthorized delete attempt");
!before_results.is_empty(),
"Relationship should exist before unauthorized delete attempt"
);
let result = KnowledgeRelationship::delete_relationship_by_id( let result = KnowledgeRelationship::delete_relationship_by_id(
&relationship.id, &relationship.id,
@@ -397,40 +383,34 @@ mod tests {
match result { match result {
Err(AppError::Auth(_)) => {} Err(AppError::Auth(_)) => {}
_ => panic!("Expected authorization error when deleting someone else's relationship"), _ => anyhow::bail!("Expected authorization error when deleting someone else's relationship"),
} }
let mut after_attempt = db let mut after_attempt = db
.query(format!( .query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'", "SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
owner_user_id
)) ))
.await .await
.expect("Query failed"); .with_context(|| "Query failed".to_string())?;
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default(); let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
assert!( assert!(!results.is_empty(), "Relationship should still exist after unauthorized delete attempt");
!results.is_empty(),
"Relationship should still exist after unauthorized delete attempt" Ok(())
);
} }
#[tokio::test] #[tokio::test]
async fn test_store_relationship_exists() { async fn test_store_relationship_exists() -> anyhow::Result<()> {
// Setup in-memory database for testing
let db = setup_test_db().await; let db = setup_test_db().await;
// Create entities to relate let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity1_id = create_test_entity("Entity 1", &db).await; let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await; let entity3_id = create_test_entity("Entity 3", &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await;
// Create relationships with the same source_id
let user_id = "user123".to_string(); let user_id = "user123".to_string();
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let different_source_id = "different_source".to_string(); let different_source_id = "different_source".to_string();
// Create two relationships with the same source_id
let relationship1 = KnowledgeRelationship::new( let relationship1 = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity2_id.clone(), entity2_id.clone(),
@@ -447,7 +427,6 @@ mod tests {
"contains".to_string(), "contains".to_string(),
); );
// Create a relationship with a different source_id
let different_relationship = KnowledgeRelationship::new( let different_relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
entity3_id.clone(), entity3_id.clone(),
@@ -456,21 +435,19 @@ mod tests {
"mentions".to_string(), "mentions".to_string(),
); );
// Store all relationships
relationship1 relationship1
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship 1"); .with_context(|| "Failed to store relationship 1".to_string())?;
relationship2 relationship2
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship 2"); .with_context(|| "Failed to store relationship 2".to_string())?;
different_relationship different_relationship
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store different relationship"); .with_context(|| "Failed to store different relationship".to_string())?;
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
let mut before_delete = db let mut before_delete = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id") .query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone())) .bind(("source_id", source_id.clone()))
@@ -489,31 +466,30 @@ mod tests {
before_delete_different.take(0).unwrap_or_default(); before_delete_different.take(0).unwrap_or_default();
assert_eq!(before_delete_different_rows.len(), 1); assert_eq!(before_delete_different_rows.len(), 1);
// Delete relationships by source_id
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db) KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
.await .await
.expect("Failed to delete relationships by source_id"); .with_context(|| "Failed to delete relationships by source_id".to_string())?;
// Query to verify the specific relationships with source_id were deleted.
let result1 = get_relationship_by_id(&relationship1.id, &db).await; let result1 = get_relationship_by_id(&relationship1.id, &db).await;
let result2 = get_relationship_by_id(&relationship2.id, &db).await; let result2 = get_relationship_by_id(&relationship2.id, &db).await;
let different_result = get_relationship_by_id(&different_relationship.id, &db).await; let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
// Verify relationships with the source_id are deleted
assert!(result1.is_none(), "Relationship 1 should be deleted"); assert!(result1.is_none(), "Relationship 1 should be deleted");
assert!(result2.is_none(), "Relationship 2 should be deleted"); assert!(result2.is_none(), "Relationship 2 should be deleted");
let remaining = let remaining =
different_result.expect("Relationship with different source_id should remain"); different_result.expect("Relationship with different source_id should remain");
assert_eq!(remaining.metadata.source_id, different_source_id); assert_eq!(remaining.metadata.source_id, different_source_id);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() { async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await; let entity1_id = create_test_entity("Entity 1", &db).await?;
let entity2_id = create_test_entity("Entity 2", &db).await; let entity2_id = create_test_entity("Entity 2", &db).await?;
let entity3_id = create_test_entity("Entity 3", &db).await; let entity3_id = create_test_entity("Entity 3", &db).await?;
let safe_relationship = KnowledgeRelationship::new( let safe_relationship = KnowledgeRelationship::new(
entity1_id.clone(), entity1_id.clone(),
@@ -552,5 +528,7 @@ mod tests {
remaining_other.is_some(), remaining_other.is_some(),
"Other relationship should remain" "Other relationship should remain"
); );
Ok(())
} }
} }
+21 -23
View File
@@ -66,12 +66,12 @@ pub fn format_history(history: &[Message]) -> String {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::db::SurrealDbClient; use crate::storage::db::SurrealDbClient;
#[tokio::test] #[tokio::test]
async fn test_message_creation() { async fn test_message_creation() -> anyhow::Result<()> {
// Test basic message creation
let conversation_id = "test_conversation"; let conversation_id = "test_conversation";
let content = "This is a test message"; let content = "This is a test message";
let role = MessageRole::User; let role = MessageRole::User;
@@ -84,24 +84,23 @@ mod tests {
references.clone(), references.clone(),
); );
// Verify message properties
assert_eq!(message.conversation_id, conversation_id); assert_eq!(message.conversation_id, conversation_id);
assert_eq!(message.content, content); assert_eq!(message.content, content);
assert_eq!(message.role, role); assert_eq!(message.role, role);
assert_eq!(message.references, references); assert_eq!(message.references, references);
assert!(!message.id.is_empty()); assert!(!message.id.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_message_persistence() { async fn test_message_persistence() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &uuid::Uuid::new_v4().to_string(); let database = &uuid::Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create and store a message
let conversation_id = "test_conversation"; let conversation_id = "test_conversation";
let message = Message::new( let message = Message::new(
conversation_id.to_string(), conversation_id.to_string(),
@@ -111,39 +110,37 @@ mod tests {
); );
let message_id = message.id.clone(); let message_id = message.id.clone();
// Store the message
db.store_item(message.clone()) db.store_item(message.clone())
.await .await
.expect("Failed to store message"); .with_context(|| "Failed to store message".to_string())?;
// Retrieve the message
let retrieved: Option<Message> = db let retrieved: Option<Message> = db
.get_item(&message_id) .get_item(&message_id)
.await .await
.expect("Failed to retrieve message"); .with_context(|| "Failed to retrieve message".to_string())?;
assert!(retrieved.is_some()); let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?;
let retrieved = retrieved.unwrap();
// Verify retrieved properties match original
assert_eq!(retrieved.id, message.id); assert_eq!(retrieved.id, message.id);
assert_eq!(retrieved.conversation_id, message.conversation_id); assert_eq!(retrieved.conversation_id, message.conversation_id);
assert_eq!(retrieved.role, message.role); assert_eq!(retrieved.role, message.role);
assert_eq!(retrieved.content, message.content); assert_eq!(retrieved.content, message.content);
assert_eq!(retrieved.references, message.references); assert_eq!(retrieved.references, message.references);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_message_role_display() { async fn test_message_role_display() -> anyhow::Result<()> {
// Test the Display implementation for MessageRole
assert_eq!(format!("{}", MessageRole::User), "User"); assert_eq!(format!("{}", MessageRole::User), "User");
assert_eq!(format!("{}", MessageRole::AI), "AI"); assert_eq!(format!("{}", MessageRole::AI), "AI");
assert_eq!(format!("{}", MessageRole::System), "System"); assert_eq!(format!("{}", MessageRole::System), "System");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_message_display() { async fn test_message_display() -> anyhow::Result<()> {
// Test the Display implementation for Message
let message = Message { let message = Message {
id: "test_id".to_string(), id: "test_id".to_string(),
created_at: Utc::now(), created_at: Utc::now(),
@@ -154,12 +151,13 @@ mod tests {
references: None, references: None,
}; };
assert_eq!(format!("{}", message), "User: Hello world"); assert_eq!(format!("{message}"), "User: Hello world");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_format_history() { async fn test_format_history() -> anyhow::Result<()> {
// Create a vector of messages
let messages = vec![ let messages = vec![
Message { Message {
id: "1".to_string(), id: "1".to_string(),
@@ -181,10 +179,10 @@ mod tests {
}, },
]; ];
// Format the history
let formatted = format_history(&messages); let formatted = format_history(&messages);
// Verify the formatting
assert_eq!(formatted, "User: Hello\nAI: Hi there!"); assert_eq!(formatted, "User: Hello\nAI: Hi there!");
Ok(())
} }
} }
+64 -54
View File
@@ -216,20 +216,22 @@ impl Scratchpad {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn test_create_scratchpad() { async fn test_create_scratchpad() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
// Create a new scratchpad // Create a new scratchpad
let user_id = "test_user"; let user_id = "test_user";
@@ -254,29 +256,28 @@ mod tests {
let retrieved: Option<Scratchpad> = db let retrieved: Option<Scratchpad> = db
.get_item(&scratchpad.id) .get_item(&scratchpad.id)
.await .await
.expect("Failed to retrieve scratchpad"); .with_context(|| "Failed to retrieve scratchpad".to_string())?;
assert!(retrieved.is_some()); let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?;
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.id, scratchpad.id); assert_eq!(retrieved.id, scratchpad.id);
assert_eq!(retrieved.user_id, user_id); assert_eq!(retrieved.user_id, user_id);
assert_eq!(retrieved.title, title); assert_eq!(retrieved.title, title);
assert!(!retrieved.is_archived); assert!(!retrieved.is_archived);
assert!(retrieved.archived_at.is_none()); assert!(retrieved.archived_at.is_none());
assert!(retrieved.ingested_at.is_none()); assert!(retrieved.ingested_at.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_by_user() { async fn test_get_by_user() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user"; let user_id = "test_user";
@@ -288,19 +289,21 @@ mod tests {
// Store them // Store them
let scratchpad1_id = scratchpad1.id.clone(); let scratchpad1_id = scratchpad1.id.clone();
let scratchpad2_id = scratchpad2.id.clone(); let scratchpad2_id = scratchpad2.id.clone();
db.store_item(scratchpad1).await.unwrap(); db.store_item(scratchpad1).await.with_context(|| "store scratchpad1".to_string())?;
db.store_item(scratchpad2).await.unwrap(); db.store_item(scratchpad2).await.with_context(|| "store scratchpad2".to_string())?;
db.store_item(scratchpad3).await.unwrap(); db.store_item(scratchpad3).await.with_context(|| "store scratchpad3".to_string())?;
// Archive one of the user's scratchpads // Archive one of the user's scratchpads
Scratchpad::archive(&scratchpad2_id, user_id, &db, false) Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
.await .await
.unwrap(); .with_context(|| "archive".to_string())?;
// Get scratchpads for user_id // Get scratchpads for user_id
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap(); let user_scratchpads = Scratchpad::get_by_user(user_id, &db)
.await
.with_context(|| "get_by_user".to_string())?;
assert_eq!(user_scratchpads.len(), 1); assert_eq!(user_scratchpads.len(), 1);
assert_eq!(user_scratchpads[0].id, scratchpad1_id); assert_eq!(user_scratchpads.first().map(|s| &s.id), Some(&scratchpad1_id));
// Verify they belong to the user // Verify they belong to the user
for scratchpad in &user_scratchpads { for scratchpad in &user_scratchpads {
@@ -309,177 +312,183 @@ mod tests {
let archived = Scratchpad::get_archived_by_user(user_id, &db) let archived = Scratchpad::get_archived_by_user(user_id, &db)
.await .await
.unwrap(); .with_context(|| "get_archived_by_user".to_string())?;
assert_eq!(archived.len(), 1); assert_eq!(archived.len(), 1);
assert_eq!(archived[0].id, scratchpad2_id); assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id));
assert!(archived[0].is_archived); assert!(archived.first().is_some_and(|s| s.is_archived));
assert!(archived[0].ingested_at.is_none()); assert!(archived.first().is_some_and(|s| s.ingested_at.is_none()));
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_archive_and_restore() { async fn test_archive_and_restore() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user"; let user_id = "test_user";
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string()); let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone(); let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap(); db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true) let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
.await .await
.expect("Failed to archive"); .with_context(|| "Failed to archive".to_string())?;
assert!(archived.is_archived); assert!(archived.is_archived);
assert!(archived.archived_at.is_some()); assert!(archived.archived_at.is_some());
assert!(archived.ingested_at.is_some()); assert!(archived.ingested_at.is_some());
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db) let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
.await .await
.expect("Failed to restore"); .with_context(|| "Failed to restore".to_string())?;
assert!(!restored.is_archived); assert!(!restored.is_archived);
assert!(restored.archived_at.is_none()); assert!(restored.archived_at.is_none());
assert!(restored.ingested_at.is_none()); assert!(restored.ingested_at.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_update_content() { async fn test_update_content() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user"; let user_id = "test_user";
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string()); let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone(); let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap(); db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let new_content = "Updated content"; let new_content = "Updated content";
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db) let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
.await .await
.unwrap(); .with_context(|| "update_content".to_string())?;
assert_eq!(updated.content, new_content); assert_eq!(updated.content, new_content);
assert!(!updated.is_dirty); assert!(!updated.is_dirty);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_update_content_unauthorized() { async fn test_update_content_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let owner_id = "owner"; let owner_id = "owner";
let other_user = "other_user"; let other_user = "other_user";
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string()); let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone(); let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap(); db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await; let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(AppError::Auth(_)) => {} Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"), _ => anyhow::bail!("Expected Auth error"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_scratchpad() { async fn test_delete_scratchpad() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user"; let user_id = "test_user";
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string()); let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone(); let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap(); db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
// Delete should succeed // Delete should succeed
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await; let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
assert!(result.is_ok()); assert!(result.is_ok());
// Verify it's gone // Verify it's gone
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap(); let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
assert!(retrieved.is_none()); assert!(retrieved.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_unauthorized() { async fn test_delete_unauthorized() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let owner_id = "owner"; let owner_id = "owner";
let other_user = "other_user"; let other_user = "other_user";
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string()); let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
let scratchpad_id = scratchpad.id.clone(); let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap(); db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await; let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
Err(AppError::Auth(_)) => {} Err(AppError::Auth(_)) => {}
_ => panic!("Expected Auth error"), _ => anyhow::bail!("Expected Auth error"),
} }
// Verify it still exists // Verify it still exists
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap(); let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
assert!(retrieved.is_some()); assert!(retrieved.is_some());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_timezone_aware_scratchpad_conversion() { async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()) let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
.await .await
.expect("Failed to create test database"); .with_context(|| "Failed to create test database".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let user_id = "test_user_123"; let user_id = "test_user_123";
let scratchpad = let scratchpad =
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string()); Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
let scratchpad_id = scratchpad.id.clone(); let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap(); db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db) let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
.await .await
.unwrap(); .with_context(|| "get_by_id".to_string())?;
// Test that datetime fields are preserved and can be used for timezone formatting // Test that datetime fields are preserved and can be used for timezone formatting
assert!(retrieved.created_at.timestamp() > 0); assert!(retrieved.created_at.timestamp() > 0);
@@ -493,10 +502,11 @@ mod tests {
// Archive the scratchpad to test optional datetime handling // Archive the scratchpad to test optional datetime handling
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false) let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
.await .await
.unwrap(); .with_context(|| "archive".to_string())?;
assert!(archived.archived_at.is_some()); assert!(archived.archived_at.is_some());
assert!(archived.archived_at.unwrap().timestamp() > 0); assert!(archived.archived_at.with_context(|| "expected archived_at".to_string())?.timestamp() > 0);
assert!(archived.ingested_at.is_none()); assert!(archived.ingested_at.is_none());
Ok(())
} }
} }
+109 -91
View File
@@ -64,7 +64,14 @@ impl SystemSettings {
let mut needs_update = false; let mut needs_update = false;
let backend_label = provider.backend_label().to_string(); let backend_label = provider.backend_label().to_string();
let provider_dimensions = provider.dimension() as u32; let provider_dimensions = u32::try_from(provider.dimension())
.unwrap_or_else(|_| {
tracing::warn!(
"Provider dimension {} exceeds u32 max; falling back to 0",
provider.dimension()
);
0u32
});
let provider_model = provider.model_code(); let provider_model = provider.model_code();
// Sync backend label // Sync backend label
@@ -107,7 +114,8 @@ impl SystemSettings {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::storage::indexes::ensure_runtime_indexes; use anyhow::{self, Context};
use crate::storage::indexes::ensure_runtime;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}; use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
use async_openai::Client; use async_openai::Client;
@@ -118,68 +126,102 @@ mod tests {
db: &SurrealDbClient, db: &SurrealDbClient,
table_name: &str, table_name: &str,
index_name: &str, index_name: &str,
) -> u32 { ) -> anyhow::Result<u32> {
let query = format!("INFO FOR TABLE {table_name};"); let query = format!("INFO FOR TABLE {table_name};");
let mut response = db let mut response = db
.client .client
.query(query) .query(query)
.await .await
.expect("Failed to fetch table info"); .with_context(|| "Failed to fetch table info".to_string())?;
let info: surrealdb::Value = response let info: surrealdb::Value = response
.take(0) .take(0)
.expect("Failed to extract table info response"); .with_context(|| "Failed to extract table info response".to_string())?;
let info_json: serde_json::Value = let info_json: serde_json::Value =
serde_json::to_value(info).expect("Failed to convert info to json"); serde_json::to_value(info).with_context(|| "Failed to convert info to json".to_string())?;
let indexes = info_json["Object"]["indexes"]["Object"] let indexes = info_json
.as_object() .get("Object")
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}")); .and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.as_object())
.with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?;
let definition = indexes let definition = indexes
.get(index_name) .get(index_name)
.and_then(|definition| definition.get("Strand")) .and_then(|definition| definition.get("Strand"))
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}")); .with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?;
let dimension_part = definition let dimension_part = definition
.split("DIMENSION") .split("DIMENSION")
.nth(1) .nth(1)
.expect("Index definition missing DIMENSION clause"); .with_context(|| "Index definition missing DIMENSION clause".to_string())?;
let dimension_token = dimension_part let dimension_token = dimension_part
.split_whitespace() .split_whitespace()
.next() .next()
.expect("Dimension value missing in definition") .with_context(|| "Dimension value missing in definition".to_string())?
.trim_end_matches(';'); .trim_end_matches(';');
dimension_token dimension_token
.parse::<u32>() .parse::<u32>()
.expect("Dimension value is not a valid number") .with_context(|| "Dimension value is not a valid number".to_string())
}
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) -> anyhow::Result<()> {
db.query(
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
)
.await
.with_context(|| "remove index".to_string())?;
let define_index_query = format!(
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};"
);
db.query(define_index_query)
.await
.with_context(|| "Re-defining index should succeed".to_string())?;
let new_embedding = vec![0.5; target_dimension];
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
db.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("user_id", initial_chunk.user_id.clone()))
.bind(("embedding", new_embedding))
.await
.with_context(|| "upsert embedding".to_string())?;
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_settings_initialization() { async fn test_settings_initialization() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Test initialization of system settings // Test initialization of system settings
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let settings = SystemSettings::get_current(&db) let settings = SystemSettings::get_current(&db)
.await .await
.expect("Failed to get system settings"); .with_context(|| "Failed to get system settings".to_string())?;
// Verify initial state after initialization // Verify initial state after initialization
assert_eq!(settings.id, "current"); assert_eq!(settings.id, "current");
assert_eq!(settings.registrations_enabled, true); assert!(settings.registrations_enabled);
assert_eq!(settings.require_email_verification, false); assert!(!settings.require_email_verification);
assert_eq!(settings.query_model, "gpt-4o-mini"); assert_eq!(settings.query_model, "gpt-4o-mini");
assert_eq!(settings.processing_model, "gpt-4o-mini"); assert_eq!(settings.processing_model, "gpt-4o-mini");
assert_eq!(settings.image_processing_model, "gpt-4o-mini"); assert_eq!(settings.image_processing_model, "gpt-4o-mini");
@@ -196,10 +238,10 @@ mod tests {
// Test idempotency - ensure calling it again doesn't change anything // Test idempotency - ensure calling it again doesn't change anything
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
let settings_again = SystemSettings::get_current(&db) let settings_again = SystemSettings::get_current(&db)
.await .await
.expect("Failed to get settings after initialization"); .with_context(|| "Failed to get settings after initialization".to_string())?;
assert_eq!(settings.id, settings_again.id); assert_eq!(settings.id, settings_again.id);
assert_eq!( assert_eq!(
@@ -210,48 +252,52 @@ mod tests {
settings.require_email_verification, settings.require_email_verification,
settings_again.require_email_verification settings_again.require_email_verification
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_current_settings() { async fn test_get_current_settings() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Initialize settings // Initialize settings
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
// Test get_current method // Test get_current method
let settings = SystemSettings::get_current(&db) let settings = SystemSettings::get_current(&db)
.await .await
.expect("Failed to get current settings"); .with_context(|| "Failed to get current settings".to_string())?;
assert_eq!(settings.id, "current"); assert_eq!(settings.id, "current");
assert_eq!(settings.registrations_enabled, true); assert!(settings.registrations_enabled);
assert_eq!(settings.require_email_verification, false); assert!(!settings.require_email_verification);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_update_settings() { async fn test_update_settings() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Initialize settings // Initialize settings
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
// Create updated settings // Create updated settings
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap(); let mut updated_settings = SystemSettings::get_current(&db)
.await
.with_context(|| "get_current".to_string())?;
updated_settings.id = "current".to_string(); updated_settings.id = "current".to_string();
updated_settings.registrations_enabled = false; updated_settings.registrations_enabled = false;
updated_settings.require_email_verification = true; updated_settings.require_email_verification = true;
@@ -260,31 +306,32 @@ mod tests {
// Test update method // Test update method
let result = SystemSettings::update(&db, updated_settings) let result = SystemSettings::update(&db, updated_settings)
.await .await
.expect("Failed to update settings"); .with_context(|| "Failed to update settings".to_string())?;
assert_eq!(result.id, "current"); assert_eq!(result.id, "current");
assert_eq!(result.registrations_enabled, false); assert!(!result.registrations_enabled);
assert_eq!(result.require_email_verification, true); assert!(result.require_email_verification);
assert_eq!(result.query_model, "gpt-4"); assert_eq!(result.query_model, "gpt-4");
// Verify changes persisted by getting current settings // Verify changes persisted by getting current settings
let current = SystemSettings::get_current(&db) let current = SystemSettings::get_current(&db)
.await .await
.expect("Failed to get current settings after update"); .with_context(|| "Failed to get current settings after update".to_string())?;
assert_eq!(current.registrations_enabled, false); assert!(!current.registrations_enabled);
assert_eq!(current.require_email_verification, true); assert!(current.require_email_verification);
assert_eq!(current.query_model, "gpt-4"); assert_eq!(current.query_model, "gpt-4");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_current_nonexistent() { async fn test_get_current_nonexistent() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Don't initialize settings and try to get them // Don't initialize settings and try to get them
let result = SystemSettings::get_current(&db).await; let result = SystemSettings::get_current(&db).await;
@@ -294,21 +341,22 @@ mod tests {
Err(AppError::NotFound(_)) => { Err(AppError::NotFound(_)) => {
// Expected error // Expected error
} }
Err(e) => panic!("Expected NotFound error, got: {:?}", e), Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"),
Ok(_) => panic!("Expected error but got Ok"), Ok(_) => anyhow::bail!("Expected error but got Ok"),
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_migration_after_changing_embedding_length() { async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string()) let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await .await
.expect("Failed to start DB"); .with_context(|| "Failed to start DB".to_string())?;
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536. // Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Initial migration failed"); .with_context(|| "Initial migration failed".to_string())?;
let initial_chunk = TextChunk::new( let initial_chunk = TextChunk::new(
"source1".into(), "source1".into(),
@@ -318,43 +366,11 @@ mod tests {
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db) TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
.await .await
.expect("Failed to store initial chunk with embedding"); .with_context(|| "Failed to store initial chunk with embedding".to_string())?;
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
db.query(
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
)
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
let new_embedding = vec![0.5; target_dimension];
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("user_id", initial_chunk.user_id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
// Re-embed with the existing configured dimension to ensure migrations remain idempotent. // Re-embed with the existing configured dimension to ensure migrations remain idempotent.
let target_dimension = 1536usize; let target_dimension = 1536usize;
simulate_reembedding(&db, target_dimension, initial_chunk).await; simulate_reembedding(&db, target_dimension, initial_chunk).await?;
let migration_result = db.apply_migrations().await; let migration_result = db.apply_migrations().await;
@@ -363,34 +379,35 @@ mod tests {
"Migrations should not fail: {:?}", "Migrations should not fail: {:?}",
migration_result.err() migration_result.err()
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_should_change_embedding_length_on_indexes_when_switching_length() { async fn test_should_change_embedding_length_on_indexes_when_switching_length() -> anyhow::Result<()> {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string()) let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await .await
.expect("Failed to start DB"); .with_context(|| "Failed to start DB".to_string())?;
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536. // Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Initial migration failed"); .with_context(|| "Initial migration failed".to_string())?;
let mut current_settings = SystemSettings::get_current(&db) let mut current_settings = SystemSettings::get_current(&db)
.await .await
.expect("Failed to load current settings"); .with_context(|| "Failed to load current settings".to_string())?;
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed. // Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize) ensure_runtime(&db, current_settings.embedding_dimensions as usize)
.await .await
.expect("failed to build runtime indexes"); .with_context(|| "failed to build runtime indexes".to_string())?;
let initial_chunk_dimension = get_hnsw_index_dimension( let initial_chunk_dimension = get_hnsw_index_dimension(
&db, &db,
"text_chunk_embedding", "text_chunk_embedding",
"idx_embedding_text_chunk_embedding", "idx_embedding_text_chunk_embedding",
) )
.await; .await?;
assert_eq!( assert_eq!(
initial_chunk_dimension, current_settings.embedding_dimensions, initial_chunk_dimension, current_settings.embedding_dimensions,
@@ -405,7 +422,7 @@ mod tests {
let updated_settings = SystemSettings::update(&db, current_settings) let updated_settings = SystemSettings::update(&db, current_settings)
.await .await
.expect("Failed to update settings"); .with_context(|| "Failed to update settings".to_string())?;
assert_eq!( assert_eq!(
updated_settings.embedding_dimensions, new_dimension, updated_settings.embedding_dimensions, new_dimension,
@@ -416,23 +433,23 @@ mod tests {
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
.await .await
.expect("TextChunk re-embedding should succeed on fresh DB"); .with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
.await .await
.expect("KnowledgeEntity re-embedding should succeed on fresh DB"); .with_context(|| "KnowledgeEntity re-embedding should succeed on fresh DB".to_string())?;
let text_chunk_dimension = get_hnsw_index_dimension( let text_chunk_dimension = get_hnsw_index_dimension(
&db, &db,
"text_chunk_embedding", "text_chunk_embedding",
"idx_embedding_text_chunk_embedding", "idx_embedding_text_chunk_embedding",
) )
.await; .await?;
let knowledge_dimension = get_hnsw_index_dimension( let knowledge_dimension = get_hnsw_index_dimension(
&db, &db,
"knowledge_entity_embedding", "knowledge_entity_embedding",
"idx_embedding_knowledge_entity_embedding", "idx_embedding_knowledge_entity_embedding",
) )
.await; .await?;
assert_eq!( assert_eq!(
text_chunk_dimension, new_dimension, text_chunk_dimension, new_dimension,
@@ -445,10 +462,11 @@ mod tests {
let persisted_settings = SystemSettings::get_current(&db) let persisted_settings = SystemSettings::get_current(&db)
.await .await
.expect("Failed to reload updated settings"); .with_context(|| "Failed to reload updated settings".to_string())?;
assert_eq!( assert_eq!(
persisted_settings.embedding_dimensions, new_dimension, persisted_settings.embedding_dimensions, new_dimension,
"Settings should persist new embedding dimension" "Settings should persist new embedding dimension"
); );
Ok(())
} }
} }
+137 -122
View File
@@ -1,4 +1,4 @@
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)] #![allow(clippy::missing_docs_in_private_items)]
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Write; use std::fmt::Write;
@@ -237,10 +237,7 @@ impl TextChunk {
new_model: &str, new_model: &str,
new_dimensions: u32, new_dimensions: u32,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
info!( info!("Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}");
"Starting re-embedding process for all text chunks. New dimensions: {}",
new_dimensions
);
// Fetch all chunks first // Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?; let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
@@ -252,7 +249,7 @@ impl TextChunk {
return Ok(()); return Ok(());
} }
info!("Found {} chunks to process.", total_chunks); info!("Found {total_chunks} chunks to process.");
// Generate all new embeddings in memory // Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new(); let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
@@ -276,7 +273,7 @@ impl TextChunk {
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.", "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions chunk.id, embedding.len(), new_dimensions
); );
error!("{}", err_msg); error!("{err_msg}");
return Err(AppError::InternalError(err_msg)); return Err(AppError::InternalError(err_msg));
} }
new_embeddings.insert( new_embeddings.insert(
@@ -300,6 +297,7 @@ impl TextChunk {
.join(",") .join(",")
); );
// Use the chunk id as the embedding record id to keep a 1:1 mapping // Use the chunk id as the embedding record id to keep a 1:1 mapping
let embedding = embedding_str;
write!( write!(
&mut transaction_query, &mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \ "UPSERT type::thing('text_chunk_embedding', '{id}') SET \
@@ -309,18 +307,13 @@ impl TextChunk {
user_id = '{user_id}', \ user_id = '{user_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();", updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(|e| AppError::InternalError(e.to_string()))?;
} }
write!( write!(
&mut transaction_query, &mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
new_dimensions
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(|e| AppError::InternalError(e.to_string()))?;
@@ -377,7 +370,7 @@ impl TextChunk {
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.", "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions chunk.id, embedding.len(), new_dimensions
); );
error!("{}", err_msg); error!("{err_msg}");
return Err(AppError::InternalError(err_msg)); return Err(AppError::InternalError(err_msg));
} }
new_embeddings.insert( new_embeddings.insert(
@@ -422,6 +415,7 @@ impl TextChunk {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(",") .join(",")
); );
let embedding = embedding_str;
write!( write!(
&mut transaction_query, &mut transaction_query,
"CREATE type::thing('text_chunk_embedding', '{id}') SET \ "CREATE type::thing('text_chunk_embedding', '{id}') SET \
@@ -431,18 +425,13 @@ impl TextChunk {
user_id = '{user_id}', \ user_id = '{user_id}', \
created_at = time::now(), \ created_at = time::now(), \
updated_at = time::now();", updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(|e| AppError::InternalError(e.to_string()))?;
} }
write!( write!(
&mut transaction_query, &mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
new_dimensions
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(|e| AppError::InternalError(e.to_string()))?;
@@ -462,20 +451,21 @@ impl TextChunk {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes}; use crate::storage::indexes::{ensure_runtime, rebuild};
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId; use surrealdb::RecordId;
use uuid::Uuid; use uuid::Uuid;
async fn ensure_chunk_fts_index(db: &SurrealDbClient) { async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
let snowball_sql = r#" let snowball_sql = r#"
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25; DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
"#; "#;
if let Err(err) = db.client.query(snowball_sql).await { if let Err(err) = db.client.query(snowball_sql).await {
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
let fallback_sql = r#" let fallback_sql = r#"
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii; DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25; DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
@@ -484,12 +474,13 @@ mod tests {
db.client db.client
.query(fallback_sql) .query(fallback_sql)
.await .await
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}")); .with_context(|| format!("define chunk fts index fallback: {err}"))?;
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_text_chunk_creation() { async fn test_text_chunk_creation() -> anyhow::Result<()> {
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let chunk = "This is a text chunk for testing embeddings".to_string(); let chunk = "This is a text chunk for testing embeddings".to_string();
let user_id = "user123".to_string(); let user_id = "user123".to_string();
@@ -500,22 +491,23 @@ mod tests {
assert_eq!(text_chunk.chunk, chunk); assert_eq!(text_chunk.chunk, chunk);
assert_eq!(text_chunk.user_id, user_id); assert_eq!(text_chunk.user_id, user_id);
assert!(!text_chunk.id.is_empty()); assert!(!text_chunk.id.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_source_id() { async fn test_delete_by_source_id() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let source_id = "source123".to_string(); let source_id = "source123".to_string();
let user_id = "user123".to_string(); let user_id = "user123".to_string();
TextChunkEmbedding::redefine_hnsw_index(&db, 5) TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await .await
.expect("redefine index"); .with_context(|| "redefine index".to_string())?;
let chunk1 = TextChunk::new( let chunk1 = TextChunk::new(
source_id.clone(), source_id.clone(),
@@ -535,61 +527,63 @@ mod tests {
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await .await
.expect("store chunk1"); .with_context(|| "store chunk1".to_string())?;
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await .await
.expect("store chunk2"); .with_context(|| "store chunk2".to_string())?;
TextChunk::store_with_embedding( TextChunk::store_with_embedding(
different_chunk.clone(), different_chunk.clone(),
vec![0.1, 0.2, 0.3, 0.4, 0.5], vec![0.1, 0.2, 0.3, 0.4, 0.5],
&db, &db,
) )
.await .await
.expect("store different chunk"); .with_context(|| "store different chunk".to_string())?;
TextChunk::delete_by_source_id(&source_id, &db) TextChunk::delete_by_source_id(&source_id, &db)
.await .await
.expect("Failed to delete chunks by source_id"); .with_context(|| "Failed to delete chunks by source_id".to_string())?;
let remaining: Vec<TextChunk> = db let remaining: Vec<TextChunk> = db
.client .client
.query(format!( .query(format!(
"SELECT * FROM {} WHERE source_id = '{}'", "SELECT * FROM {} WHERE source_id = '{source_id}'",
TextChunk::table_name(), TextChunk::table_name(),
source_id
)) ))
.await .await
.expect("Query failed") .with_context(|| "Query failed".to_string())?
.take(0) .take(0)
.expect("Failed to get query results"); .with_context(|| "Failed to get query results".to_string())?;
assert_eq!(remaining.len(), 0); assert_eq!(remaining.len(), 0);
let different_remaining: Vec<TextChunk> = db let different_remaining: Vec<TextChunk> = db
.client .client
.query(format!( .query(format!(
"SELECT * FROM {} WHERE source_id = '{}'", "SELECT * FROM {} WHERE source_id = 'different_source'",
TextChunk::table_name(), TextChunk::table_name(),
"different_source"
)) ))
.await .await
.expect("Query failed") .with_context(|| "Query failed".to_string())?
.take(0) .take(0)
.expect("Failed to get query results"); .with_context(|| "Failed to get query results".to_string())?;
assert_eq!(different_remaining.len(), 1); assert_eq!(different_remaining.len(), 1);
assert_eq!(different_remaining[0].id, different_chunk.id); assert_eq!(
different_remaining.first().map(|r| &r.id),
Some(&different_chunk.id)
);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_nonexistent_source_id() { async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 5) TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await .await
.expect("redefine index"); .with_context(|| "redefine index".to_string())?;
let real_source_id = "real_source".to_string(); let real_source_id = "real_source".to_string();
let chunk = TextChunk::new( let chunk = TextChunk::new(
@@ -600,24 +594,24 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await .await
.expect("store chunk"); .with_context(|| "store chunk".to_string())?;
TextChunk::delete_by_source_id("nonexistent_source", &db) TextChunk::delete_by_source_id("nonexistent_source", &db)
.await .await
.expect("Delete should succeed"); .with_context(|| "Delete should succeed".to_string())?;
let remaining: Vec<TextChunk> = db let remaining: Vec<TextChunk> = db
.client .client
.query(format!( .query(format!(
"SELECT * FROM {} WHERE source_id = '{}'", "SELECT * FROM {} WHERE source_id = '{real_source_id}'",
TextChunk::table_name(), TextChunk::table_name(),
real_source_id
)) ))
.await .await
.expect("Query failed") .with_context(|| "Query failed".to_string())?
.take(0) .take(0)
.expect("Failed to get query results"); .with_context(|| "Failed to get query results".to_string())?;
assert_eq!(remaining.len(), 1); assert_eq!(remaining.len(), 1);
Ok(())
} }
#[tokio::test] #[tokio::test]
@@ -672,13 +666,13 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_store_with_embedding_creates_both_records() { async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
let source_id = "store-src".to_string(); let source_id = "store-src".to_string();
let user_id = "user_store".to_string(); let user_id = "user_store".to_string();
@@ -686,43 +680,43 @@ mod tests {
TextChunkEmbedding::redefine_hnsw_index(&db, 3) TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("redefine index"); .with_context(|| "redefine index".to_string())?;
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await .await
.expect("store with embedding"); .with_context(|| "store with embedding".to_string())?;
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap(); let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
assert!(stored_chunk.is_some()); .await
let stored_chunk = stored_chunk.unwrap(); .with_context(|| "get_item".to_string())?;
let stored_chunk = stored_chunk.with_context(|| "expected stored chunk".to_string())?;
assert_eq!(stored_chunk.source_id, source_id); assert_eq!(stored_chunk.source_id, source_id);
assert_eq!(stored_chunk.user_id, user_id); assert_eq!(stored_chunk.user_id, user_id);
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await .await
.expect("get embedding"); .with_context(|| "get embedding".to_string())?
assert!(embedding.is_some()); .with_context(|| "expected embedding".to_string())?;
let embedding = embedding.unwrap();
assert_eq!(embedding.chunk_id, rid); assert_eq!(embedding.chunk_id, rid);
assert_eq!(embedding.user_id, user_id); assert_eq!(embedding.user_id, user_id);
assert_eq!(embedding.source_id, source_id); assert_eq!(embedding.source_id, source_id);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_store_with_embedding_with_runtime_indexes() { async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
let namespace = "test_ns_runtime"; let namespace = "test_ns_runtime";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
// Ensure runtime indexes are built with the expected dimension.
let embedding_dimension = 3usize; let embedding_dimension = 3usize;
ensure_runtime_indexes(&db, embedding_dimension) ensure_runtime(&db, embedding_dimension)
.await .await
.expect("ensure runtime indexes"); .with_context(|| "ensure runtime indexes".to_string())?;
let chunk = TextChunk::new( let chunk = TextChunk::new(
"runtime_src".to_string(), "runtime_src".to_string(),
@@ -732,55 +726,60 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await .await
.expect("store with embedding"); .with_context(|| "store with embedding".to_string())?;
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap(); let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
assert!(stored_chunk.is_some(), "chunk should be stored"); .await
.with_context(|| "get_item".to_string())?;
let stored_chunk = stored_chunk.with_context(|| "chunk should be stored".to_string())?;
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await .await
.expect("get embedding"); .with_context(|| "get embedding".to_string())?
assert!(embedding.is_some(), "embedding should exist"); .with_context(|| "embedding should exist".to_string())?;
assert_eq!( assert_eq!(
embedding.unwrap().embedding.len(), embedding.embedding.len(),
embedding_dimension, embedding_dimension,
"embedding dimension should match runtime index" "embedding dimension should match runtime index"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() { async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 3) TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("redefine index"); .with_context(|| "redefine index".to_string())?;
let results: Vec<TextChunkSearchResult> = let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await .await
.unwrap(); .with_context(|| "vector_search".to_string())?;
assert!(results.is_empty()); assert!(results.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_vector_search_single_result() { async fn test_vector_search_single_result() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 3) TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("redefine index"); .with_context(|| "redefine index".to_string())?;
let source_id = "src".to_string(); let source_id = "src".to_string();
let user_id = "user".to_string(); let user_id = "user".to_string();
@@ -792,32 +791,33 @@ mod tests {
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await .await
.expect("store"); .with_context(|| "store".to_string())?;
let results: Vec<TextChunkSearchResult> = let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await .await
.unwrap(); .with_context(|| "vector_search".to_string())?;
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
let res = &results[0]; let res = results.first().context("expected first result")?;
assert_eq!(res.chunk.id, chunk.id); assert_eq!(res.chunk.id, chunk.id);
assert_eq!(res.chunk.source_id, source_id); assert_eq!(res.chunk.source_id, source_id);
assert_eq!(res.chunk.chunk, "hello world"); assert_eq!(res.chunk.chunk, "hello world");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_vector_search_orders_by_similarity() { async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 3) TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("redefine index"); .with_context(|| "redefine index".to_string())?;
let user_id = "user".to_string(); let user_id = "user".to_string();
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone()); let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
@@ -825,49 +825,59 @@ mod tests {
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db) TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
.await .await
.expect("store chunk1"); .with_context(|| "store chunk1".to_string())?;
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db) TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
.await .await
.expect("store chunk2"); .with_context(|| "store chunk2".to_string())?;
let results: Vec<TextChunkSearchResult> = let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await .await
.unwrap(); .with_context(|| "vector_search".to_string())?;
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
assert_eq!(results[0].chunk.id, chunk2.id); assert_eq!(
assert_eq!(results[1].chunk.id, chunk1.id); results.first().map(|r| &r.chunk.id),
assert!(results[0].score >= results[1].score); Some(&chunk2.id)
);
assert_eq!(
results.get(1).map(|r| &r.chunk.id),
Some(&chunk1.id)
);
let r0 = results.first().context("expected first result")?;
let r1 = results.get(1).context("expected second result")?;
assert!(r0.score >= r1.score);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_fts_search_returns_empty_when_no_chunks() { async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
let namespace = "fts_chunk_ns_empty"; let namespace = "fts_chunk_ns_empty";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
ensure_chunk_fts_index(&db).await; ensure_chunk_fts_index(&db).await?;
rebuild_indexes(&db).await.expect("rebuild indexes"); rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
let results = TextChunk::fts_search(5, "hello", &db, "user") let results = TextChunk::fts_search(5, "hello", &db, "user")
.await .await
.expect("fts search"); .with_context(|| "fts search".to_string())?;
assert!(results.is_empty()); assert!(results.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_fts_search_single_result() { async fn test_fts_search_single_result() -> anyhow::Result<()> {
let namespace = "fts_chunk_ns_single"; let namespace = "fts_chunk_ns_single";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
ensure_chunk_fts_index(&db).await; ensure_chunk_fts_index(&db).await?;
let user_id = "fts_user"; let user_id = "fts_user";
let chunk = TextChunk::new( let chunk = TextChunk::new(
@@ -875,27 +885,29 @@ mod tests {
"rustaceans love rust".to_string(), "rustaceans love rust".to_string(),
user_id.to_string(), user_id.to_string(),
); );
db.store_item(chunk.clone()).await.expect("store chunk"); db.store_item(chunk.clone()).await.with_context(|| "store chunk".to_string())?;
rebuild_indexes(&db).await.expect("rebuild indexes"); rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
let results = TextChunk::fts_search(3, "rust", &db, user_id) let results = TextChunk::fts_search(3, "rust", &db, user_id)
.await .await
.expect("fts search"); .with_context(|| "fts search".to_string())?;
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
assert_eq!(results[0].chunk.id, chunk.id); let r0 = results.first().context("expected first result")?;
assert!(results[0].score.is_finite(), "expected a finite FTS score"); assert_eq!(r0.chunk.id, chunk.id);
assert!(r0.score.is_finite(), "expected a finite FTS score");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_fts_search_orders_by_score_and_filters_user() { async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
let namespace = "fts_chunk_ns_order"; let namespace = "fts_chunk_ns_order";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.with_context(|| "migrations".to_string())?;
ensure_chunk_fts_index(&db).await; ensure_chunk_fts_index(&db).await?;
let user_id = "fts_user_order"; let user_id = "fts_user_order";
let high_score_chunk = TextChunk::new( let high_score_chunk = TextChunk::new(
@@ -916,18 +928,18 @@ mod tests {
db.store_item(high_score_chunk.clone()) db.store_item(high_score_chunk.clone())
.await .await
.expect("store high score chunk"); .with_context(|| "store high score chunk".to_string())?;
db.store_item(low_score_chunk.clone()) db.store_item(low_score_chunk.clone())
.await .await
.expect("store low score chunk"); .with_context(|| "store low score chunk".to_string())?;
db.store_item(other_user_chunk) db.store_item(other_user_chunk)
.await .await
.expect("store other user chunk"); .with_context(|| "store other user chunk".to_string())?;
rebuild_indexes(&db).await.expect("rebuild indexes"); rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
let results = TextChunk::fts_search(3, "apple", &db, user_id) let results = TextChunk::fts_search(3, "apple", &db, user_id)
.await .await
.expect("fts search"); .with_context(|| "fts search".to_string())?;
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect(); let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
@@ -936,9 +948,12 @@ mod tests {
&& ids.contains(&low_score_chunk.id.as_str()), && ids.contains(&low_score_chunk.id.as_str()),
"expected only the two chunks for the same user" "expected only the two chunks for the same user"
); );
let r0 = results.first().context("expected first result")?;
let r1 = results.get(1).context("expected second result")?;
assert!( assert!(
results[0].score >= results[1].score, r0.score >= r1.score,
"expected results ordered by descending score" "expected results ordered by descending score"
); );
Ok(())
} }
} }
@@ -126,24 +126,26 @@ impl TextChunkEmbedding {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::db::SurrealDbClient; use crate::storage::db::SurrealDbClient;
use surrealdb::Value as SurrealValue; use surrealdb::Value as SurrealValue;
use uuid::Uuid; use uuid::Uuid;
/// Helper to create an in-memory DB and apply migrations /// Helper to create an in-memory DB and apply migrations
async fn setup_test_db() -> SurrealDbClient { async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = Uuid::new_v4().to_string(); let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database) let db = SurrealDbClient::memory(namespace, &database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to apply migrations"); .with_context(|| "Failed to apply migrations".to_string())?;
db Ok(db)
} }
/// Helper: create a text_chunk with a known key, return its RecordId /// Helper: create a text_chunk with a known key, return its RecordId
@@ -152,7 +154,7 @@ mod tests {
key: &str, key: &str,
source_id: &str, source_id: &str,
user_id: &str, user_id: &str,
) -> RecordId { ) -> anyhow::Result<RecordId> {
let chunk = TextChunk { let chunk = TextChunk {
id: key.to_owned(), id: key.to_owned(),
created_at: Utc::now(), created_at: Utc::now(),
@@ -164,21 +166,42 @@ mod tests {
db.store_item(chunk) db.store_item(chunk)
.await .await
.expect("Failed to create text_chunk"); .with_context(|| "Failed to create text_chunk".to_string())?;
RecordId::from_table_key(TextChunk::table_name(), key) Ok(RecordId::from_table_key(TextChunk::table_name(), key))
}
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.with_context(|| "info query failed".to_string())?;
let info: SurrealValue = info_res.take(0).with_context(|| "failed to take info result".to_string())?;
let info_json: serde_json::Value =
serde_json::to_value(info).with_context(|| "failed to convert info to json".to_string())?;
let idx_sql = info_json
.get("Object")
.and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
.and_then(|v| v.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
Ok(idx_sql)
} }
#[tokio::test] #[tokio::test]
async fn test_create_and_get_by_chunk_id() { async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "user_a"; let user_id = "user_a";
let chunk_key = "chunk-123"; let chunk_key = "chunk-123";
let source_id = "source-1"; let source_id = "source-1";
// 1) Create a text_chunk with a known key // 1) Create a text_chunk with a known key
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await; let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
// 2) Create and store an embedding for that chunk // 2) Create and store an embedding for that chunk
let embedding_vec = vec![0.1_f32, 0.2, 0.3]; let embedding_vec = vec![0.1_f32, 0.2, 0.3];
@@ -191,39 +214,37 @@ mod tests {
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len()) TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let _: Option<TextChunkEmbedding> = db let _: Option<TextChunkEmbedding> = db
.client .client
.create(TextChunkEmbedding::table_name()) .create(TextChunkEmbedding::table_name())
.content(emb) .content(emb)
.await .await
.expect("Failed to store embedding") .with_context(|| "Failed to store embedding".to_string())?
.take() .with_context(|| "Failed to deserialize stored embedding".to_string())?;
.expect("Failed to deserialize stored embedding");
// 3) Fetch it via get_by_chunk_id // 3) Fetch it via get_by_chunk_id
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await .await
.expect("Failed to get embedding by chunk_id"); .with_context(|| "Failed to get embedding by chunk_id".to_string())?
.with_context(|| "Expected an embedding to be found".to_string())?;
assert!(fetched.is_some(), "Expected an embedding to be found");
let fetched = fetched.unwrap();
assert_eq!(fetched.user_id, user_id); assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.chunk_id, chunk_rid); assert_eq!(fetched.chunk_id, chunk_rid);
assert_eq!(fetched.embedding, embedding_vec); assert_eq!(fetched.embedding, embedding_vec);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_chunk_id() { async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "user_b"; let user_id = "user_b";
let chunk_key = "chunk-delete"; let chunk_key = "chunk-delete";
let source_id = "source-del"; let source_id = "source-del";
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await; let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
let emb = TextChunkEmbedding::new( let emb = TextChunkEmbedding::new(
chunk_key, chunk_key,
@@ -234,50 +255,50 @@ mod tests {
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len()) TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
let _: Option<TextChunkEmbedding> = db let _: Option<TextChunkEmbedding> = db
.client .client
.create(TextChunkEmbedding::table_name()) .create(TextChunkEmbedding::table_name())
.content(emb) .content(emb)
.await .await
.expect("Failed to store embedding") .with_context(|| "Failed to store embedding".to_string())?
.take() .with_context(|| "Failed to deserialize stored embedding".to_string())?;
.expect("Failed to deserialize stored embedding");
// Ensure it exists // Ensure it exists
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await .await
.expect("Failed to get embedding before delete"); .with_context(|| "Failed to get embedding before delete".to_string())?;
assert!(existing.is_some(), "Embedding should exist before delete"); assert!(existing.is_some(), "Embedding should exist before delete");
// Delete by chunk_id // Delete by chunk_id
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db) TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
.await .await
.expect("Failed to delete by chunk_id"); .with_context(|| "Failed to delete by chunk_id".to_string())?;
// Ensure it no longer exists // Ensure it no longer exists
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await .await
.expect("Failed to get embedding after delete"); .with_context(|| "Failed to get embedding after delete".to_string())?;
assert!(after.is_none(), "Embedding should have been deleted"); assert!(after.is_none(), "Embedding should have been deleted");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_delete_by_source_id() { async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "user_c"; let user_id = "user_c";
let source_id = "shared-source"; let source_id = "shared-source";
let other_source = "other-source"; let other_source = "other-source";
// Two chunks with the same source_id // Two chunks with the same source_id
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await; let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?;
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await; let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?;
// One chunk with a different source_id // One chunk with a different source_id
let chunk_other_rid = let chunk_other_rid =
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await; create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?;
// Create embeddings for all three // Create embeddings for all three
let emb1 = TextChunkEmbedding::new( let emb1 = TextChunkEmbedding::new(
@@ -302,7 +323,7 @@ mod tests {
// Update length on index // Update length on index
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len()) TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
.await .await
.expect("Failed to redefine index length"); .with_context(|| "Failed to redefine index length".to_string())?;
for emb in [emb1, emb2, emb3] { for emb in [emb1, emb2, emb3] {
let _: Option<TextChunkEmbedding> = db let _: Option<TextChunkEmbedding> = db
@@ -310,102 +331,82 @@ mod tests {
.create(TextChunkEmbedding::table_name()) .create(TextChunkEmbedding::table_name())
.content(emb) .content(emb)
.await .await
.expect("Failed to store embedding") .with_context(|| "Failed to store embedding".to_string())?
.take() .with_context(|| "Failed to deserialize stored embedding".to_string())?;
.expect("Failed to deserialize stored embedding");
} }
// Sanity check: they all exist // Sanity check: they all exist
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
.await .await
.unwrap() .with_context(|| "get chunk1".to_string())?
.is_some()); .is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db) assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
.await .await
.unwrap() .with_context(|| "get chunk2".to_string())?
.is_some()); .is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
.await .await
.unwrap() .with_context(|| "get chunk_other".to_string())?
.is_some()); .is_some());
// Delete embeddings by source_id (shared-source) // Delete embeddings by source_id (shared-source)
TextChunkEmbedding::delete_by_source_id(source_id, &db) TextChunkEmbedding::delete_by_source_id(source_id, &db)
.await .await
.expect("Failed to delete by source_id"); .with_context(|| "Failed to delete by source_id".to_string())?;
// Chunks from shared-source should have no embeddings // Chunks from shared-source should have no embeddings
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
.await .await
.unwrap() .with_context(|| "check chunk1".to_string())?
.is_none()); .is_none());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db) assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
.await .await
.unwrap() .with_context(|| "check chunk2".to_string())?
.is_none()); .is_none());
// The other chunk should still have its embedding // The other chunk should still have its embedding
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
.await .await
.unwrap() .with_context(|| "check chunk_other".to_string())?
.is_some()); .is_some());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_redefine_hnsw_index_updates_dimension() { async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
// Change the index dimension from default (1536) to a smaller test value. // Change the index dimension from default (1536) to a smaller test value.
TextChunkEmbedding::redefine_hnsw_index(&db, 8) TextChunkEmbedding::redefine_hnsw_index(&db, 8)
.await .await
.expect("failed to redefine index"); .with_context(|| "failed to redefine index".to_string())?;
let mut info_res = db let idx_sql = get_idx_sql(&db).await?;
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_text_chunk_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
assert!( assert!(
idx_sql.contains("DIMENSION 8"), idx_sql.contains("DIMENSION 8"),
"expected index definition to contain new dimension, got: {idx_sql}" "expected index definition to contain new dimension, got: {idx_sql}"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_redefine_hnsw_index_is_idempotent() { async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
TextChunkEmbedding::redefine_hnsw_index(&db, 4) TextChunkEmbedding::redefine_hnsw_index(&db, 4)
.await .await
.expect("first redefine failed"); .with_context(|| "first redefine failed".to_string())?;
TextChunkEmbedding::redefine_hnsw_index(&db, 4) TextChunkEmbedding::redefine_hnsw_index(&db, 4)
.await .await
.expect("second redefine failed"); .with_context(|| "second redefine failed".to_string())?;
let mut info_res = db let idx_sql = get_idx_sql(&db).await?;
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_text_chunk_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
assert!( assert!(
idx_sql.contains("DIMENSION 4"), idx_sql.contains("DIMENSION 4"),
"expected index definition to retain dimension 4, got: {idx_sql}" "expected index definition to retain dimension 4, got: {idx_sql}"
); );
Ok(())
} }
} }
+25 -21
View File
@@ -185,10 +185,12 @@ impl TextContent {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn test_text_content_creation() { async fn test_text_content_creation() -> anyhow::Result<()> {
// Test basic object creation // Test basic object creation
let text = "Test content text".to_string(); let text = "Test content text".to_string();
let context = "Test context".to_string(); let context = "Test context".to_string();
@@ -212,10 +214,11 @@ mod tests {
assert!(text_content.file_info.is_none()); assert!(text_content.file_info.is_none());
assert!(text_content.url_info.is_none()); assert!(text_content.url_info.is_none());
assert!(!text_content.id.is_empty()); assert!(!text_content.id.is_empty());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_text_content_with_url() { async fn test_text_content_with_url() -> anyhow::Result<()> {
// Test creating with URL // Test creating with URL
let text = "Content with URL".to_string(); let text = "Content with URL".to_string();
let context = "URL context".to_string(); let context = "URL context".to_string();
@@ -232,26 +235,27 @@ mod tests {
}); });
let text_content = TextContent::new( let text_content = TextContent::new(
text.clone(), text,
Some(context.clone()), Some(context),
category.clone(), category,
None, None,
url_info.clone(), url_info.clone(),
user_id.clone(), user_id,
); );
// Check URL field is set // Check URL field is set
assert_eq!(text_content.url_info, url_info); assert_eq!(text_content.url_info, url_info);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_text_content_patch() { async fn test_text_content_patch() -> anyhow::Result<()> {
// Setup in-memory database for testing // Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create initial text content // Create initial text content
let initial_text = "Initial text".to_string(); let initial_text = "Initial text".to_string();
@@ -272,7 +276,7 @@ mod tests {
let stored: Option<TextContent> = db let stored: Option<TextContent> = db
.store_item(text_content.clone()) .store_item(text_content.clone())
.await .await
.expect("Failed to store text content"); .with_context(|| "Failed to store text content".to_string())?;
assert!(stored.is_some()); assert!(stored.is_some());
// New values for patch // New values for patch
@@ -283,31 +287,30 @@ mod tests {
// Apply the patch // Apply the patch
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db) TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
.await .await
.expect("Failed to patch text content"); .with_context(|| "Failed to patch text content".to_string())?;
// Retrieve the updated content // Retrieve the updated content
let updated: Option<TextContent> = db let updated: Option<TextContent> = db
.get_item(&text_content.id) .get_item(&text_content.id)
.await .await
.expect("Failed to get updated text content"); .with_context(|| "Failed to get updated text content".to_string())?;
assert!(updated.is_some()); let updated_content = updated.with_context(|| "expected updated content".to_string())?;
let updated_content = updated.unwrap();
// Verify the updates // Verify the updates
assert_eq!(updated_content.context, Some(new_context.to_string())); assert_eq!(updated_content.context, Some(new_context.to_string()));
assert_eq!(updated_content.category, new_category); assert_eq!(updated_content.category, new_category);
assert_eq!(updated_content.text, new_text); assert_eq!(updated_content.text, new_text);
assert!(updated_content.updated_at > text_content.updated_at); assert!(updated_content.updated_at > text_content.updated_at);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_has_other_with_file_detects_shared_usage() { async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
let user_id = "user123".to_string(); let user_id = "user123".to_string();
let file_info = FileInfo { let file_info = FileInfo {
@@ -340,24 +343,25 @@ mod tests {
db.store_item(content_a.clone()) db.store_item(content_a.clone())
.await .await
.expect("Failed to store first content"); .with_context(|| "Failed to store first content".to_string())?;
db.store_item(content_b.clone()) db.store_item(content_b.clone())
.await .await
.expect("Failed to store second content"); .with_context(|| "Failed to store second content".to_string())?;
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db) let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
.await .await
.expect("Failed to check for shared file usage"); .with_context(|| "Failed to check for shared file usage".to_string())?;
assert!(has_other); assert!(has_other);
let _removed: Option<TextContent> = db let _removed: Option<TextContent> = db
.delete_item(&content_b.id) .delete_item(&content_b.id)
.await .await
.expect("Failed to delete second content"); .with_context(|| "Failed to delete second content".to_string())?;
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db) let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
.await .await
.expect("Failed to check shared usage after delete"); .with_context(|| "Failed to check shared usage after delete".to_string())?;
assert!(!has_other_after); assert!(!has_other_after);
Ok(())
} }
} }
+108 -95
View File
@@ -723,30 +723,32 @@ impl User {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use crate::storage::types::ingestion_payload::IngestionPayload; use crate::storage::types::ingestion_payload::IngestionPayload;
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS}; use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
use std::collections::HashSet; use std::collections::HashSet;
// Helper function to set up a test database with SystemSettings // Helper function to set up a test database with SystemSettings
async fn setup_test_db() -> SurrealDbClient { async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = Uuid::new_v4().to_string(); let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database) let db = SurrealDbClient::memory(namespace, &database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations() db.apply_migrations()
.await .await
.expect("Failed to setup the migrations"); .with_context(|| "Failed to setup the migrations".to_string())?;
db Ok(db)
} }
#[tokio::test] #[tokio::test]
async fn test_user_creation() { async fn test_user_creation() -> anyhow::Result<()> {
// Setup test database // Setup test database
let db = setup_test_db().await; let db = setup_test_db().await?;
// Create a user // Create a user
let email = "test@example.com"; let email = "test@example.com";
@@ -761,7 +763,7 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
// Verify user properties // Verify user properties
assert!(!user.id.is_empty()); assert!(!user.id.is_empty());
@@ -774,18 +776,17 @@ mod tests {
let retrieved: Option<User> = db let retrieved: Option<User> = db
.get_item(&user.id) .get_item(&user.id)
.await .await
.expect("Failed to retrieve user"); .with_context(|| "Failed to retrieve user".to_string())?;
assert!(retrieved.is_some()); let retrieved = retrieved.with_context(|| "expected user to exist".to_string())?;
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.id, user.id); assert_eq!(retrieved.id, user.id);
assert_eq!(retrieved.email, email); assert_eq!(retrieved.email, email);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_user_authentication() { async fn test_user_authentication() -> anyhow::Result<()> {
// Setup test database // Setup test database
let db = setup_test_db().await; let db = setup_test_db().await?;
// Create a user // Create a user
let email = "auth_test@example.com"; let email = "auth_test@example.com";
@@ -799,7 +800,7 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
// Test successful authentication // Test successful authentication
let auth_result = User::authenticate(email, password, &db).await; let auth_result = User::authenticate(email, password, &db).await;
@@ -812,11 +813,12 @@ mod tests {
// Test failed authentication with non-existent user // Test failed authentication with non-existent user
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await; let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
assert!(nonexistent.is_err()); assert!(nonexistent.is_err());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_unfinished_ingestion_tasks_filters_correctly() { async fn test_get_unfinished_ingestion_tasks_filters_correctly() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "unfinished_user"; let user_id = "unfinished_user";
let other_user_id = "other_user"; let other_user_id = "other_user";
@@ -830,14 +832,14 @@ mod tests {
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()); let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
db.store_item(created_task.clone()) db.store_item(created_task.clone())
.await .await
.expect("Failed to store created task"); .with_context(|| "Failed to store created task".to_string())?;
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string()); let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
processing_task.state = TaskState::Processing; processing_task.state = TaskState::Processing;
processing_task.attempts = 1; processing_task.attempts = 1;
db.store_item(processing_task.clone()) db.store_item(processing_task.clone())
.await .await
.expect("Failed to store processing task"); .with_context(|| "Failed to store processing task".to_string())?;
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string()); let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
failed_retry_task.state = TaskState::Failed; failed_retry_task.state = TaskState::Failed;
@@ -845,7 +847,7 @@ mod tests {
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5); failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
db.store_item(failed_retry_task.clone()) db.store_item(failed_retry_task.clone())
.await .await
.expect("Failed to store retryable failed task"); .with_context(|| "Failed to store retryable failed task".to_string())?;
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string()); let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
failed_blocked_task.state = TaskState::Failed; failed_blocked_task.state = TaskState::Failed;
@@ -853,13 +855,13 @@ mod tests {
failed_blocked_task.error_message = Some("Too many failures".into()); failed_blocked_task.error_message = Some("Too many failures".into());
db.store_item(failed_blocked_task.clone()) db.store_item(failed_blocked_task.clone())
.await .await
.expect("Failed to store blocked task"); .with_context(|| "Failed to store blocked task".to_string())?;
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()); let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
completed_task.state = TaskState::Succeeded; completed_task.state = TaskState::Succeeded;
db.store_item(completed_task.clone()) db.store_item(completed_task.clone())
.await .await
.expect("Failed to store completed task"); .with_context(|| "Failed to store completed task".to_string())?;
let other_payload = IngestionPayload::Text { let other_payload = IngestionPayload::Text {
text: "Other".to_string(), text: "Other".to_string(),
@@ -870,11 +872,11 @@ mod tests {
let other_task = IngestionTask::new(other_payload, other_user_id.to_string()); let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
db.store_item(other_task) db.store_item(other_task)
.await .await
.expect("Failed to store other user task"); .with_context(|| "Failed to store other user task".to_string())?;
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db) let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
.await .await
.expect("Failed to fetch unfinished tasks"); .with_context(|| "Failed to fetch unfinished tasks".to_string())?;
let unfinished_ids: HashSet<String> = let unfinished_ids: HashSet<String> =
unfinished.iter().map(|task| task.id.clone()).collect(); unfinished.iter().map(|task| task.id.clone()).collect();
@@ -885,11 +887,12 @@ mod tests {
assert!(!unfinished_ids.contains(&failed_blocked_task.id)); assert!(!unfinished_ids.contains(&failed_blocked_task.id));
assert!(!unfinished_ids.contains(&completed_task.id)); assert!(!unfinished_ids.contains(&completed_task.id));
assert_eq!(unfinished_ids.len(), 3); assert_eq!(unfinished_ids.len(), 3);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_all_ingestion_tasks_returns_sorted() { async fn test_get_all_ingestion_tasks_returns_sorted() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "archive_user"; let user_id = "archive_user";
let other_user_id = "other_user"; let other_user_id = "other_user";
@@ -902,15 +905,15 @@ mod tests {
// Oldest task // Oldest task
let mut first = IngestionTask::new(payload.clone(), user_id.to_string()); let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
first.created_at = first.created_at - chrono::Duration::minutes(1); first.created_at -= chrono::Duration::minutes(1);
first.updated_at = first.created_at; first.updated_at = first.created_at;
first.state = TaskState::Succeeded; first.state = TaskState::Succeeded;
db.store_item(first.clone()).await.expect("store first"); db.store_item(first.clone()).await.with_context(|| "store first".to_string())?;
// Latest task // Latest task
let mut second = IngestionTask::new(payload.clone(), user_id.to_string()); let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
second.state = TaskState::Processing; second.state = TaskState::Processing;
db.store_item(second.clone()).await.expect("store second"); db.store_item(second.clone()).await.with_context(|| "store second".to_string())?;
let other_payload = IngestionPayload::Text { let other_payload = IngestionPayload::Text {
text: "Other".to_string(), text: "Other".to_string(),
@@ -919,21 +922,22 @@ mod tests {
user_id: other_user_id.to_string(), user_id: other_user_id.to_string(),
}; };
let other_task = IngestionTask::new(other_payload, other_user_id.to_string()); let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
db.store_item(other_task).await.expect("store other"); db.store_item(other_task).await.with_context(|| "store other".to_string())?;
let tasks = User::get_all_ingestion_tasks(user_id, &db) let tasks = User::get_all_ingestion_tasks(user_id, &db)
.await .await
.expect("fetch all tasks"); .with_context(|| "fetch all tasks".to_string())?;
assert_eq!(tasks.len(), 2); assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0].id, second.id); // newest first assert_eq!(tasks.first().map(|t| &t.id), Some(&second.id)); // newest first
assert_eq!(tasks[1].id, first.id); assert_eq!(tasks.get(1).map(|t| &t.id), Some(&first.id));
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_find_by_email() { async fn test_find_by_email() -> anyhow::Result<()> {
// Setup test database // Setup test database
let db = setup_test_db().await; let db = setup_test_db().await?;
// Create a user // Create a user
let email = "find_test@example.com"; let email = "find_test@example.com";
@@ -947,28 +951,28 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
// Test finding user by email // Test finding user by email
let found_user = User::find_by_email(email, &db) let found_user = User::find_by_email(email, &db)
.await .await
.expect("Error searching for user"); .with_context(|| "Error searching for user".to_string())?
assert!(found_user.is_some()); .with_context(|| "expected user to exist".to_string())?;
let found_user = found_user.unwrap();
assert_eq!(found_user.id, created_user.id); assert_eq!(found_user.id, created_user.id);
assert_eq!(found_user.email, email); assert_eq!(found_user.email, email);
// Test finding non-existent user // Test finding non-existent user
let not_found = User::find_by_email("nonexistent@example.com", &db) let not_found = User::find_by_email("nonexistent@example.com", &db)
.await .await
.expect("Error searching for user"); .with_context(|| "Error searching for user".to_string())?;
assert!(not_found.is_none()); assert!(not_found.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_api_key_management() { async fn test_api_key_management() -> anyhow::Result<()> {
// Setup test database // Setup test database
let db = setup_test_db().await; let db = setup_test_db().await?;
// Create a user // Create a user
let email = "apikey_test@example.com"; let email = "apikey_test@example.com";
@@ -982,7 +986,7 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
// Initially, user should have no API key // Initially, user should have no API key
assert!(user.api_key.is_none()); assert!(user.api_key.is_none());
@@ -990,7 +994,7 @@ mod tests {
// Generate API key // Generate API key
let api_key = User::set_api_key(&user.id, &db) let api_key = User::set_api_key(&user.id, &db)
.await .await
.expect("Failed to set API key"); .with_context(|| "Failed to set API key".to_string())?;
assert!(!api_key.is_empty()); assert!(!api_key.is_empty());
assert!(api_key.starts_with("sk_")); assert!(api_key.starts_with("sk_"));
@@ -998,38 +1002,36 @@ mod tests {
let updated_user: Option<User> = db let updated_user: Option<User> = db
.get_item(&user.id) .get_item(&user.id)
.await .await
.expect("Failed to retrieve user"); .with_context(|| "Failed to retrieve user".to_string())?;
assert!(updated_user.is_some()); let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
let updated_user = updated_user.unwrap();
assert_eq!(updated_user.api_key, Some(api_key.clone())); assert_eq!(updated_user.api_key, Some(api_key.clone()));
// Test finding user by API key // Test finding user by API key
let found_user = User::find_by_api_key(&api_key, &db) let found_user = User::find_by_api_key(&api_key, &db)
.await .await
.expect("Error searching by API key"); .with_context(|| "Error searching by API key".to_string())?
assert!(found_user.is_some()); .with_context(|| "expected user found by api key".to_string())?;
let found_user = found_user.unwrap();
assert_eq!(found_user.id, user.id); assert_eq!(found_user.id, user.id);
// Revoke API key // Revoke API key
User::revoke_api_key(&user.id, &db) User::revoke_api_key(&user.id, &db)
.await .await
.expect("Failed to revoke API key"); .with_context(|| "Failed to revoke API key".to_string())?;
// Verify API key was revoked // Verify API key was revoked
let revoked_user: Option<User> = db let revoked_user: Option<User> = db
.get_item(&user.id) .get_item(&user.id)
.await .await
.expect("Failed to retrieve user"); .with_context(|| "Failed to retrieve user".to_string())?;
assert!(revoked_user.is_some()); let revoked_user = revoked_user.with_context(|| "expected revoked user".to_string())?;
let revoked_user = revoked_user.unwrap();
assert!(revoked_user.api_key.is_none()); assert!(revoked_user.api_key.is_none());
// Test searching by revoked API key // Test searching by revoked API key
let not_found = User::find_by_api_key(&api_key, &db) let not_found = User::find_by_api_key(&api_key, &db)
.await .await
.expect("Error searching by API key"); .with_context(|| "Error searching by API key".to_string())?;
assert!(not_found.is_none()); assert!(not_found.is_none());
Ok(())
} }
#[tokio::test] #[tokio::test]
@@ -1069,9 +1071,9 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_password_update() { async fn test_password_update() -> anyhow::Result<()> {
// Setup test database // Setup test database
let db = setup_test_db().await; let db = setup_test_db().await?;
// Create a user // Create a user
let email = "pwd_test@example.com"; let email = "pwd_test@example.com";
@@ -1086,7 +1088,7 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
// Authenticate with old password // Authenticate with old password
let auth_result = User::authenticate(email, old_password, &db).await; let auth_result = User::authenticate(email, old_password, &db).await;
@@ -1095,7 +1097,7 @@ mod tests {
// Update password // Update password
User::patch_password(email, new_password, &db) User::patch_password(email, new_password, &db)
.await .await
.expect("Failed to update password"); .with_context(|| "Failed to update password".to_string())?;
// Old password should no longer work // Old password should no longer work
let old_auth = User::authenticate(email, old_password, &db).await; let old_auth = User::authenticate(email, old_password, &db).await;
@@ -1104,10 +1106,11 @@ mod tests {
// New password should work // New password should work
let new_auth = User::authenticate(email, new_password, &db).await; let new_auth = User::authenticate(email, new_password, &db).await;
assert!(new_auth.is_ok()); assert!(new_auth.is_ok());
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_validate_timezone() { async fn test_validate_timezone() -> anyhow::Result<()> {
// Valid timezones should be accepted as-is // Valid timezones should be accepted as-is
assert_eq!(validate_timezone("America/New_York"), "America/New_York"); assert_eq!(validate_timezone("America/New_York"), "America/New_York");
assert_eq!(validate_timezone("Europe/London"), "Europe/London"); assert_eq!(validate_timezone("Europe/London"), "Europe/London");
@@ -1117,12 +1120,13 @@ mod tests {
// Invalid timezones should be replaced with UTC // Invalid timezones should be replaced with UTC
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC"); assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
assert_eq!(validate_timezone("Not_Real"), "UTC"); assert_eq!(validate_timezone("Not_Real"), "UTC");
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_timezone_update() { async fn test_timezone_update() -> anyhow::Result<()> {
// Setup test database // Setup test database
let db = setup_test_db().await; let db = setup_test_db().await?;
// Create user with default timezone // Create user with default timezone
let email = "timezone_test@example.com"; let email = "timezone_test@example.com";
@@ -1134,7 +1138,7 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
assert_eq!(user.timezone, "UTC"); assert_eq!(user.timezone, "UTC");
@@ -1142,58 +1146,61 @@ mod tests {
let new_timezone = "Europe/Paris"; let new_timezone = "Europe/Paris";
User::update_timezone(&user.id, new_timezone, &db) User::update_timezone(&user.id, new_timezone, &db)
.await .await
.expect("Failed to update timezone"); .with_context(|| "Failed to update timezone".to_string())?;
// Verify timezone was updated // Verify timezone was updated
let updated_user: Option<User> = db let updated_user: Option<User> = db
.get_item(&user.id) .get_item(&user.id)
.await .await
.expect("Failed to retrieve user"); .with_context(|| "Failed to retrieve user".to_string())?;
assert!(updated_user.is_some()); let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
let updated_user = updated_user.unwrap();
assert_eq!(updated_user.timezone, new_timezone); assert_eq!(updated_user.timezone, new_timezone);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_conversations_order() { async fn test_conversations_order() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "user_order_test"; let user_id = "user_order_test";
// Create conversations with varying updated_at timestamps // Create conversations with varying updated_at timestamps
let mut conversations = Vec::new(); let mut conversations = Vec::new();
for i in 0..5 { for i in 0..5 {
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {}", i)); let mut conv = Conversation::new(user_id.to_string(), format!("Conv {i}"));
// Fake updated_at i minutes apart // Fake updated_at i minutes apart
conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i); conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i);
db.store_item(conv.clone()) db.store_item(conv.clone())
.await .await
.expect("Failed to store conversation"); .with_context(|| "Failed to store conversation".to_string())?;
conversations.push(conv); conversations.push(conv);
} }
// Retrieve via get_user_conversations - should be ordered by updated_at DESC // Retrieve via get_user_conversations - should be ordered by updated_at DESC
let retrieved = User::get_user_conversations(user_id, &db) let retrieved = User::get_user_conversations(user_id, &db)
.await .await
.expect("Failed to get conversations"); .with_context(|| "Failed to get conversations".to_string())?;
assert_eq!(retrieved.len(), conversations.len()); assert_eq!(retrieved.len(), conversations.len());
for window in retrieved.windows(2) { for pair in retrieved.windows(2) {
// Assert each earlier conversation has updated_at >= later conversation let a = pair.first().context("expected first in pair")?;
let b = pair.get(1).context("expected second in pair")?;
assert!( assert!(
window[0].created_at >= window[1].created_at, a.created_at >= b.created_at,
"Conversations not ordered descending by created_at" "Conversations not ordered descending by created_at"
); );
} }
// Check first conversation title matches the most recently updated // Check first conversation title matches the most recently updated
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap(); let most_recent = conversations.iter().max_by_key(|c| c.created_at).context("expected most recent")?;
assert_eq!(retrieved[0].id, most_recent.id); let r0 = retrieved.first().context("expected first result")?;
assert_eq!(r0.id, most_recent.id);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_get_latest_text_contents_returns_last_five() { async fn test_get_latest_text_contents_returns_last_five() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "latest_text_user"; let user_id = "latest_text_user";
let mut inserted_ids = Vec::new(); let mut inserted_ids = Vec::new();
@@ -1201,8 +1208,8 @@ mod tests {
for i in 0..12 { for i in 0..12 {
let mut item = TextContent::new( let mut item = TextContent::new(
format!("Text {}", i), format!("Text {i}"),
Some(format!("Context {}", i)), Some(format!("Context {i}")),
"Category".to_string(), "Category".to_string(),
None, None,
None, None,
@@ -1215,18 +1222,19 @@ mod tests {
db.store_item(item.clone()) db.store_item(item.clone())
.await .await
.expect("Failed to store text content"); .with_context(|| "Failed to store text content".to_string())?;
inserted_ids.push(item.id.clone()); inserted_ids.push(item.id.clone());
} }
let latest = User::get_latest_text_contents(user_id, &db) let latest = User::get_latest_text_contents(user_id, &db)
.await .await
.expect("Failed to fetch latest text contents"); .with_context(|| "Failed to fetch latest text contents".to_string())?;
assert_eq!(latest.len(), 5, "Expected exactly five items"); assert_eq!(latest.len(), 5, "Expected exactly five items");
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec(); let start = inserted_ids.len().saturating_sub(5);
let mut expected_ids = inserted_ids.get(start..).unwrap_or_default().to_vec();
expected_ids.reverse(); expected_ids.reverse();
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect(); let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
@@ -1235,25 +1243,29 @@ mod tests {
"Latest items did not match expectation" "Latest items did not match expectation"
); );
for window in latest.windows(2) { for pair in latest.windows(2) {
let a = pair.first().context("expected first in pair")?;
let b = pair.get(1).context("expected second in pair")?;
assert!( assert!(
window[0].created_at >= window[1].created_at, a.created_at >= b.created_at,
"Results are not ordered by created_at descending" "Results are not ordered by created_at descending"
); );
} }
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_validate_theme() { async fn test_validate_theme() -> anyhow::Result<()> {
assert_eq!(validate_theme("light"), Theme::Light); assert_eq!(validate_theme("light"), Theme::Light);
assert_eq!(validate_theme("dark"), Theme::Dark); assert_eq!(validate_theme("dark"), Theme::Dark);
assert_eq!(validate_theme("system"), Theme::System); assert_eq!(validate_theme("system"), Theme::System);
assert_eq!(validate_theme("invalid"), Theme::System); assert_eq!(validate_theme("invalid"), Theme::System);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_theme_update() { async fn test_theme_update() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let email = "theme_test@example.com"; let email = "theme_test@example.com";
let user = User::create_new( let user = User::create_new(
email.to_string(), email.to_string(),
@@ -1263,30 +1275,31 @@ mod tests {
"system".to_string(), "system".to_string(),
) )
.await .await
.expect("Failed to create user"); .with_context(|| "Failed to create user".to_string())?;
assert_eq!(user.theme, Theme::System); assert_eq!(user.theme, Theme::System);
User::update_theme(&user.id, "dark", &db) User::update_theme(&user.id, "dark", &db)
.await .await
.expect("update theme"); .with_context(|| "update theme".to_string())?;
let updated = db let updated = db
.get_item::<User>(&user.id) .get_item::<User>(&user.id)
.await .await
.expect("get user") .with_context(|| "get user".to_string())?
.unwrap(); .with_context(|| "expected user".to_string())?;
assert_eq!(updated.theme, Theme::Dark); assert_eq!(updated.theme, Theme::Dark);
// Invalid theme should default to system (but update_theme calls validate_theme) // Invalid theme should default to system (but update_theme calls validate_theme)
User::update_theme(&user.id, "invalid", &db) User::update_theme(&user.id, "invalid", &db)
.await .await
.expect("update theme invalid"); .with_context(|| "update theme invalid".to_string())?;
let updated2 = db let updated2 = db
.get_item::<User>(&user.id) .get_item::<User>(&user.id)
.await .await
.expect("get user") .with_context(|| "get user".to_string())?
.unwrap(); .with_context(|| "expected user".to_string())?;
assert_eq!(updated2.theme, Theme::System); assert_eq!(updated2.theme, Theme::System);
Ok(())
} }
} }
+3 -3
View File
@@ -28,8 +28,8 @@ fn default_storage_kind() -> StorageKind {
StorageKind::Local StorageKind::Local
} }
fn default_s3_region() -> Option<String> { fn default_s3_region() -> String {
Some("us-east-1".to_string()) "us-east-1".to_string()
} }
/// Selects the strategy used for PDF ingestion. /// Selects the strategy used for PDF ingestion.
@@ -69,7 +69,7 @@ pub struct AppConfig {
#[serde(default)] #[serde(default)]
pub s3_endpoint: Option<String>, pub s3_endpoint: Option<String>,
#[serde(default = "default_s3_region")] #[serde(default = "default_s3_region")]
pub s3_region: Option<String>, pub s3_region: String,
#[serde(default = "default_pdf_ingest_mode")] #[serde(default = "default_pdf_ingest_mode")]
pub pdf_ingest_mode: PdfIngestMode, pub pdf_ingest_mode: PdfIngestMode,
#[serde(default = "default_reranking_enabled")] #[serde(default = "default_reranking_enabled")]
+2 -2
View File
@@ -14,7 +14,7 @@ use common::utils::config::get_config;
use common::{ use common::{
storage::{ storage::{
db::SurrealDbClient, db::SurrealDbClient,
store::{DynStore, StorageManager}, store::{DynStorage, StorageManager},
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject}, types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
}, },
utils::config::{AppConfig, StorageKind}, utils::config::{AppConfig, StorageKind},
@@ -432,7 +432,7 @@ async fn ingest_paragraph_batch(
storage: StorageKind::Memory, storage: StorageKind::Memory,
..Default::default() ..Default::default()
}; };
let backend: DynStore = Arc::new(InMemory::new()); let backend: DynStorage = Arc::new(InMemory::new());
let storage = StorageManager::with_backend(backend, StorageKind::Memory); let storage = StorageManager::with_backend(backend, StorageKind::Memory);
let pipeline_config = ingestion_config.clone(); let pipeline_config = ingestion_config.clone();
+1 -1
View File
@@ -861,7 +861,7 @@ mod tests {
let question = CorpusQuestion { let question = CorpusQuestion {
question_id: "q1".to_string(), question_id: "q1".to_string(),
paragraph_id: paragraph_one.paragraph_id.clone(), paragraph_id: paragraph_one.paragraph_id.clone(),
text_content_id: text_content_id, text_content_id,
question_text: "What is this?".to_string(), question_text: "What is this?".to_string(),
answers: vec!["Hello".to_string()], answers: vec!["Hello".to_string()],
is_impossible: false, is_impossible: false,
+2 -2
View File
@@ -1,5 +1,5 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes}; use common::storage::{db::SurrealDbClient, indexes::ensure_runtime};
use tracing::info; use tracing::info;
// Helper functions for index management during namespace reseed // Helper functions for index management during namespace reseed
@@ -11,7 +11,7 @@ pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> { pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
info!("Recreating ALL indexes after namespace reseed via shared runtime helper"); info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
ensure_runtime_indexes(db, dimension) ensure_runtime(db, dimension)
.await .await
.context("creating runtime indexes") .context("creating runtime indexes")
} }
+2 -2
View File
@@ -13,7 +13,7 @@ use common::{
utils::embedding::EmbeddingProvider, utils::embedding::EmbeddingProvider,
}; };
use retrieval_pipeline::{ use retrieval_pipeline::{
pipeline::{PipelineStageTimings, RetrievalConfig}, pipeline::{StageTimings, RetrievalConfig},
reranking::RerankerPool, reranking::RerankerPool,
}; };
@@ -56,7 +56,7 @@ pub(super) struct EvaluationContext<'a> {
pub corpus_handle: Option<corpus::CorpusHandle>, pub corpus_handle: Option<corpus::CorpusHandle>,
pub cases: Vec<SeededCase>, pub cases: Vec<SeededCase>,
pub filtered_questions: usize, pub filtered_questions: usize,
pub stage_latency_samples: Vec<PipelineStageTimings>, pub stage_latency_samples: Vec<StageTimings>,
pub latencies: Vec<u128>, pub latencies: Vec<u128>,
pub diagnostics_output: Vec<CaseDiagnostics>, pub diagnostics_output: Vec<CaseDiagnostics>,
pub query_summaries: Vec<CaseSummary>, pub query_summaries: Vec<CaseSummary>,
+19 -22
View File
@@ -10,7 +10,7 @@ use crate::eval::{
CaseSummary, RetrievedSummary, CaseSummary, RetrievedSummary,
}; };
use retrieval_pipeline::{ use retrieval_pipeline::{
pipeline::{self, PipelineStageTimings, RetrievalConfig}, pipeline::{self, StageTimings, RetrievalConfig},
reranking::RerankerPool, reranking::RerankerPool,
}; };
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
@@ -75,10 +75,10 @@ pub(crate) async fn run_queries(
retrieval_config.tuning.chunk_rrf_fts_weight = value; retrieval_config.tuning.chunk_rrf_fts_weight = value;
} }
if let Some(value) = config.retrieval.chunk_rrf_use_vector { if let Some(value) = config.retrieval.chunk_rrf_use_vector {
retrieval_config.tuning.chunk_rrf_use_vector = value; retrieval_config.tuning.flags.chunk_rrf_use_vector = value.into();
} }
if let Some(value) = config.retrieval.chunk_rrf_use_fts { if let Some(value) = config.retrieval.chunk_rrf_use_fts {
retrieval_config.tuning.chunk_rrf_use_fts = value; retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
} }
if let Some(value) = config.retrieval.chunk_avg_chars_per_token { if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value; retrieval_config.tuning.avg_chars_per_token = value;
@@ -113,8 +113,8 @@ pub(crate) async fn run_queries(
chunk_rrf_k = active_tuning.chunk_rrf_k, chunk_rrf_k = active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight, chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight,
chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight, chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight,
chunk_rrf_use_vector = active_tuning.chunk_rrf_use_vector, chunk_rrf_use_vector = active_tuning.flags.chunk_rrf_use_vector.as_bool(),
chunk_rrf_use_fts = active_tuning.chunk_rrf_use_fts, chunk_rrf_use_fts = active_tuning.flags.chunk_rrf_use_fts.as_bool(),
embedding_backend = ctx.embedding_provider().backend_label(), embedding_backend = ctx.embedding_provider().backend_label(),
embedding_model = ctx embedding_model = ctx
.embedding_provider() .embedding_provider()
@@ -181,35 +181,32 @@ pub(crate) async fn run_queries(
embedding_provider.embed(&question).await.with_context(|| { embedding_provider.embed(&question).await.with_context(|| {
format!("generating embedding for question {}", question_id) format!("generating embedding for question {}", question_id)
})?; })?;
let reranker = match &rerank_pool { let reranker = match rerank_pool.as_ref() {
Some(pool) => Some(pool.checkout().await), Some(pool) => pool.checkout().await,
None => None, None => None,
}; };
let params = pipeline::StrategyParams {
db_client: &db,
openai_client: &openai_client,
embedding_provider: Some(&embedding_provider),
input_text: &question,
user_id: &user_id,
config: (*retrieval_config).clone(),
reranker,
};
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled { let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics( let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
&db, params,
&openai_client,
Some(&embedding_provider),
query_embedding, query_embedding,
&question,
&user_id,
(*retrieval_config).clone(),
reranker,
) )
.await .await
.with_context(|| format!("running pipeline for question {}", question_id))?; .with_context(|| format!("running pipeline for question {}", question_id))?;
(outcome.results, outcome.diagnostics, outcome.stage_timings) (outcome.results, outcome.diagnostics, outcome.stage_timings)
} else { } else {
let outcome = pipeline::run_pipeline_with_embedding_with_metrics( let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
&db, params,
&openai_client,
Some(&embedding_provider),
query_embedding, query_embedding,
&question,
&user_id,
(*retrieval_config).clone(),
reranker,
) )
.await .await
.with_context(|| format!("running pipeline for question {}", question_id))?; .with_context(|| format!("running pipeline for question {}", question_id))?;
@@ -327,7 +324,7 @@ pub(crate) async fn run_queries(
usize, usize,
CaseSummary, CaseSummary,
Option<CaseDiagnostics>, Option<CaseDiagnostics>,
PipelineStageTimings, StageTimings,
), ),
anyhow::Error, anyhow::Error,
>((idx, summary, diagnostics, stage_timings)) >((idx, summary, diagnostics, stage_timings))
+2 -2
View File
@@ -205,8 +205,8 @@ pub(crate) async fn summarize(
chunk_rrf_k: active_tuning.chunk_rrf_k, chunk_rrf_k: active_tuning.chunk_rrf_k,
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight, chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
chunk_rrf_fts_weight: active_tuning.chunk_rrf_fts_weight, chunk_rrf_fts_weight: active_tuning.chunk_rrf_fts_weight,
chunk_rrf_use_vector: active_tuning.chunk_rrf_use_vector, chunk_rrf_use_vector: active_tuning.flags.chunk_rrf_use_vector.as_bool(),
chunk_rrf_use_fts: active_tuning.chunk_rrf_use_fts, chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(),
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens, ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens, ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
ingest_chunks_only: config.ingest.ingest_chunks_only, ingest_chunks_only: config.ingest.ingest_chunks_only,
+25 -27
View File
@@ -1037,6 +1037,31 @@ fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
Ok(()) Ok(())
} }
use crate::args::Config;
impl<'a> From<&'a Config> for SliceConfig<'a> {
fn from(config: &'a Config) -> Self {
slice_config_with_limit(config, None)
}
}
pub fn slice_config_with_limit<'a>(
config: &'a Config,
limit_override: Option<usize>,
) -> SliceConfig<'a> {
SliceConfig {
cache_dir: config.cache_dir.as_path(),
force_convert: config.force_convert,
explicit_slice: config.slice.as_deref(),
limit: limit_override.or(config.limit),
corpus_limit: config.corpus_limit,
slice_seed: config.slice_seed,
llm_mode: config.llm_mode,
negative_multiplier: config.negative_multiplier,
require_verified_chunks: config.retrieval.require_verified_chunks,
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -1214,30 +1239,3 @@ mod tests {
Ok(()) Ok(())
} }
} }
// MARK: - Config integration (merged from slice.rs)
use crate::args::Config;
impl<'a> From<&'a Config> for SliceConfig<'a> {
fn from(config: &'a Config) -> Self {
slice_config_with_limit(config, None)
}
}
pub fn slice_config_with_limit<'a>(
config: &'a Config,
limit_override: Option<usize>,
) -> SliceConfig<'a> {
SliceConfig {
cache_dir: config.cache_dir.as_path(),
force_convert: config.force_convert,
explicit_slice: config.slice.as_deref(),
limit: limit_override.or(config.limit),
corpus_limit: config.corpus_limit,
slice_seed: config.slice_seed,
llm_mode: config.llm_mode,
negative_multiplier: config.negative_multiplier,
require_verified_chunks: config.retrieval.require_verified_chunks,
}
}
+27 -26
View File
@@ -39,30 +39,33 @@ const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30);
const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024; const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024;
const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64; const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64;
pub struct StateResources {
pub db: Arc<SurrealDbClient>,
pub openai_client: Arc<OpenAIClientType>,
pub session_store: Arc<SessionStoreType>,
pub storage: StorageManager,
pub config: AppConfig,
pub reranker_pool: Option<Arc<RerankerPool>>,
pub embedding_provider: Arc<EmbeddingProvider>,
pub template_engine: Option<Arc<TemplateEngine>>,
}
impl HtmlState { impl HtmlState {
pub async fn new_with_resources( pub fn new_with_resources(resources: StateResources) -> Self {
db: Arc<SurrealDbClient>, let templates = resources
openai_client: Arc<OpenAIClientType>, .template_engine
session_store: Arc<SessionStoreType>, .unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
storage: StorageManager,
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
embedding_provider: Arc<EmbeddingProvider>,
template_engine: Option<Arc<TemplateEngine>>,
) -> Self {
let templates =
template_engine.unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
debug!("Template engine configured for html_router."); debug!("Template engine configured for html_router.");
Self { Self {
db, db: resources.db,
openai_client, openai_client: resources.openai_client,
session_store,
templates, templates,
config, session_store: resources.session_store,
storage, config: resources.config,
reranker_pool, storage: resources.storage,
embedding_provider, reranker_pool: resources.reranker_pool,
embedding_provider: resources.embedding_provider,
conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())), conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())),
conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)), conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)),
} }
@@ -210,18 +213,16 @@ mod tests {
EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"), EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"),
); );
HtmlState::new_with_resources( HtmlState::new_with_resources(StateResources {
db, db,
Arc::new(async_openai::Client::new()), openai_client: Arc::new(async_openai::Client::new()),
session_store, session_store,
storage, storage,
config, config,
None, reranker_pool: None,
embedding_provider, embedding_provider,
None, template_engine: None,
) })
.await
.expect("Failed to create HtmlState")
} }
#[tokio::test] #[tokio::test]
+1 -1
View File
@@ -2,6 +2,6 @@ use tower_http::compression::CompressionLayer;
/// Provides a default compression layer that negotiates encoding based on the /// Provides a default compression layer that negotiates encoding based on the
/// `Accept-Encoding` header of the incoming request. /// `Accept-Encoding` header of the incoming request.
pub fn compression_layer() -> CompressionLayer { pub fn layer() -> CompressionLayer {
CompressionLayer::new() CompressionLayer::new()
} }
@@ -10,7 +10,7 @@ use axum::{
use axum_htmx::{HxRequest, HX_TRIGGER}; use axum_htmx::{HxRequest, HX_TRIGGER};
use common::{ use common::{
error::AppError, error::AppError,
utils::template_engine::{ProvidesTemplateEngine, Value}, utils::template_engine::{ProvidesTemplateEngine, TemplateEngine, Value},
}; };
use minijinja::context; use minijinja::context;
use serde::Serialize; use serde::Serialize;
@@ -146,60 +146,14 @@ struct ContextWrapper<'a> {
context: HashMap<String, Value>, context: HashMap<String, Value>,
} }
pub async fn with_template_response<S>(
State(state): State<S>,
HxRequest(is_htmx): HxRequest,
req: Request,
next: Next,
) -> Response
where
S: ProvidesTemplateEngine + ProvidesHtmlState + Clone + Send + Sync + 'static,
{
let mut user_theme = Theme::System.as_str();
let mut initial_theme = Theme::System.initial_theme();
let mut is_authenticated = false;
let mut current_user_id = None;
let mut current_user = None;
{
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
if let Some(user) = &auth.current_user {
is_authenticated = true;
current_user_id = Some(user.id.clone());
user_theme = user.theme.as_str();
initial_theme = user.theme.initial_theme();
current_user = Some(TemplateUser::from(user));
}
}
}
let response = next.run(req).await;
// Headers to forward from the original response
const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"]; const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"];
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() { fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) {
let template_engine = state.template_engine(); for &header_name in HTMX_HEADERS_TO_FORWARD {
if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) {
let mut conversation_archive = Vec::new(); if let Some(value) = from.get(&name) {
to.insert(name.clone(), value.clone());
let should_load_conversation_archive = }
matches!(&template_response.template_kind, TemplateKind::Full(_));
if should_load_conversation_archive {
if let Some(user_id) = current_user_id {
let html_state = state.html_state();
if let Some(cached_archive) =
html_state.get_cached_conversation_archive(&user_id).await
{
conversation_archive = cached_archive;
} else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await
{
html_state
.set_cached_conversation_archive(&user_id, archive.clone())
.await;
conversation_archive = archive;
} }
} }
} }
@@ -226,15 +180,57 @@ where
} }
} }
// Helper to forward relevant headers pub async fn with_template_response<S>(
fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) { State(state): State<S>,
for &header_name in HTMX_HEADERS_TO_FORWARD { HxRequest(is_htmx): HxRequest,
if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) { req: Request,
if let Some(value) = from.get(&name) { next: Next,
to.insert(name.clone(), value.clone()); ) -> Response
where
S: ProvidesTemplateEngine + ProvidesHtmlState + Clone + Send + Sync + 'static,
{
let mut user_theme = Theme::System.as_str();
let mut initial_theme = Theme::System.initial_theme();
let mut is_authenticated = false;
let mut current_user = None;
{
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
if let Some(user) = &auth.current_user {
is_authenticated = true;
user_theme = user.theme.as_str();
initial_theme = user.theme.initial_theme();
current_user = Some(TemplateUser::from(user));
} }
} }
} }
let response = next.run(req).await;
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
let template_engine = state.template_engine();
let mut conversation_archive = Vec::new();
let should_load_conversation_archive =
matches!(&template_response.template_kind, TemplateKind::Full(_));
if should_load_conversation_archive {
if let Some(user_id) = current_user.as_ref().map(|u| &u.id) {
let html_state = state.html_state();
if let Some(cached_archive) =
html_state.get_cached_conversation_archive(user_id).await
{
conversation_archive = cached_archive;
} else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
{
html_state
.set_cached_conversation_archive(user_id, archive.clone())
.await;
conversation_archive = archive;
}
}
} }
let context_map = match context_to_map(&template_response.context) { let context_map = match context_to_map(&template_response.context) {
@@ -290,18 +286,17 @@ where
} }
TemplateKind::Error(status) => { TemplateKind::Error(status) => {
if is_htmx { if is_htmx {
// HTMX request: Send 204 + HX-Trigger for toast
let title = template_response let title = template_response
.context .context
.get_attr("title") .get_attr("title")
.ok() .ok()
.and_then(|v| v.as_str().map(String::from)) .and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "Error".to_string()); .unwrap_or_else(|| "Error".to_string());
let description = template_response let description = template_response
.context .context
.get_attr("description") .get_attr("description")
.ok() .ok()
.and_then(|v| v.as_str().map(String::from)) .and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "An error occurred.".to_string()); .unwrap_or_else(|| "An error occurred.".to_string());
let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}}); let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}});
@@ -312,14 +307,12 @@ where
}); });
(StatusCode::NO_CONTENT, [(HX_TRIGGER, trigger_value)], "").into_response() (StatusCode::NO_CONTENT, [(HX_TRIGGER, trigger_value)], "").into_response()
} else { } else {
// Non-HTMX request: Render the full errors/error.html page
match template_engine match template_engine
.render("errors/error.html", &Value::from_serialize(&context)) .render("errors/error.html", &Value::from_serialize(&context))
{ {
Ok(html) => (*status, Html(html)).into_response(), Ok(html) => (*status, Html(html)).into_response(),
Err(e) => { Err(e) => {
error!("Critical: Failed to render 'errors/error.html': {:?}", e); error!("Critical: Failed to render 'errors/error.html': {:?}", e);
// Fallback HTML, but use the intended status code
(*status, Html(fallback_error())).into_response() (*status, Html(fallback_error())).into_response()
} }
} }
+9 -2
View File
@@ -9,7 +9,7 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
analytics_middleware::analytics_middleware, auth_middleware::require_auth, analytics_middleware::analytics_middleware, auth_middleware::require_auth,
compression::compression_layer, response_middleware::with_template_response, compression, response_middleware::with_template_response,
}, },
}; };
@@ -71,6 +71,7 @@ where
} }
// Add a serving of assets // Add a serving of assets
#[must_use]
pub fn with_public_assets(mut self, path: &str, directory: &str) -> Self { pub fn with_public_assets(mut self, path: &str, directory: &str) -> Self {
self.public_assets_config = Some(AssetsConfig { self.public_assets_config = Some(AssetsConfig {
path: path.to_string(), path: path.to_string(),
@@ -80,24 +81,28 @@ where
} }
// Add a public router that will be merged at the root level // Add a public router that will be merged at the root level
#[must_use]
pub fn add_public_routes(mut self, routes: Router<S>) -> Self { pub fn add_public_routes(mut self, routes: Router<S>) -> Self {
self.public_routers.push(routes); self.public_routers.push(routes);
self self
} }
// Add a protected router that will be merged at the root level // Add a protected router that will be merged at the root level
#[must_use]
pub fn add_protected_routes(mut self, routes: Router<S>) -> Self { pub fn add_protected_routes(mut self, routes: Router<S>) -> Self {
self.protected_routers.push(routes); self.protected_routers.push(routes);
self self
} }
// Nest a public router under a path prefix // Nest a public router under a path prefix
#[must_use]
pub fn nest_public_routes(mut self, path: &str, routes: Router<S>) -> Self { pub fn nest_public_routes(mut self, path: &str, routes: Router<S>) -> Self {
self.nested_routes.push((path.to_string(), routes)); self.nested_routes.push((path.to_string(), routes));
self self
} }
// Nest a protected router under a path prefix // Nest a protected router under a path prefix
#[must_use]
pub fn nest_protected_routes(mut self, path: &str, routes: Router<S>) -> Self { pub fn nest_protected_routes(mut self, path: &str, routes: Router<S>) -> Self {
self.nested_protected_routes self.nested_protected_routes
.push((path.to_string(), routes)); .push((path.to_string(), routes));
@@ -105,6 +110,7 @@ where
} }
// Add custom middleware to be applied before the standard ones // Add custom middleware to be applied before the standard ones
#[must_use]
pub fn with_middleware<F>(mut self, middleware_fn: F) -> Self pub fn with_middleware<F>(mut self, middleware_fn: F) -> Self
where where
F: FnOnce(Router<S>) -> Router<S> + Send + 'static, F: FnOnce(Router<S>) -> Router<S> + Send + 'static,
@@ -114,6 +120,7 @@ where
} }
/// Enables response compression when building the router. /// Enables response compression when building the router.
#[must_use]
pub const fn with_compression(mut self) -> Self { pub const fn with_compression(mut self) -> Self {
self.compression_enabled = true; self.compression_enabled = true;
self self
@@ -191,7 +198,7 @@ where
// Apply Global Middleware (Compression) // Apply Global Middleware (Compression)
if self.compression_enabled { if self.compression_enabled {
final_router = final_router.layer(compression_layer()); final_router = final_router.layer(compression::layer());
} }
final_router final_router
+3 -3
View File
@@ -62,7 +62,7 @@ pub async fn set_api_key(
let api_key = User::set_api_key(&user.id, &state.db).await?; let api_key = User::set_api_key(&user.id, &state.db).await?;
// Clear the cache so new requests have access to the user with api key // Clear the cache so new requests have access to the user with api key
auth.cache_clear_user(user.id.to_string()); auth.cache_clear_user(user.id.clone());
// Render the API key section block // Render the API key section block
Ok(TemplateResponse::new_partial( Ok(TemplateResponse::new_partial(
@@ -106,7 +106,7 @@ pub async fn update_timezone(
User::update_timezone(&user.id, &form.timezone, &state.db).await?; User::update_timezone(&user.id, &form.timezone, &state.db).await?;
// Clear the cache // Clear the cache
auth.cache_clear_user(user.id.to_string()); auth.cache_clear_user(user.id.clone());
let timezones = TZ_VARIANTS let timezones = TZ_VARIANTS
.iter() .iter()
@@ -141,7 +141,7 @@ pub async fn update_theme(
User::update_theme(&user.id, &form.theme, &state.db).await?; User::update_theme(&user.id, &form.theme, &state.db).await?;
// Clear the cache // Clear the cache
auth.cache_clear_user(user.id.to_string()); auth.cache_clear_user(user.id.clone());
let theme_options = vec![ let theme_options = vec![
Theme::Light.as_str().to_string(), Theme::Light.as_str().to_string(),
+7 -12
View File
@@ -1,3 +1,5 @@
use std::sync::Arc;
use async_openai::types::ListModelResponse; use async_openai::types::ListModelResponse;
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
@@ -37,18 +39,14 @@ pub struct AdminPanelData {
current_section: AdminSection, current_section: AdminSection,
} }
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)] #[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum AdminSection { pub enum AdminSection {
#[default]
Overview, Overview,
Models, Models,
} }
impl Default for AdminSection {
fn default() -> Self {
Self::Overview
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct AdminPanelQuery { pub struct AdminPanelQuery {
@@ -107,10 +105,7 @@ fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
where where
D: serde::Deserializer<'de>, D: serde::Deserializer<'de>,
{ {
match String::deserialize(deserializer) { String::deserialize(deserializer).map(|s| s == "on")
Ok(string) => Ok(string == "on"),
Err(_) => Ok(false),
}
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -219,8 +214,8 @@ pub async fn update_model_settings(
if reembedding_needed { if reembedding_needed {
info!("Embedding dimensions changed. Spawning background re-embedding task..."); info!("Embedding dimensions changed. Spawning background re-embedding task...");
let db_for_task = state.db.clone(); let db_for_task = Arc::clone(&state.db);
let openai_for_task = state.openai_client.clone(); let openai_for_task = Arc::clone(&state.openai_client);
let new_model_for_task = new_settings.embedding_model.clone(); let new_model_for_task = new_settings.embedding_model.clone();
let new_dims_for_task = new_settings.embedding_dimensions; let new_dims_for_task = new_settings.embedding_dimensions;
+2 -2
View File
@@ -11,7 +11,7 @@ use crate::{
}; };
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct SignupParams { pub struct Params {
pub email: String, pub email: String,
pub password: String, pub password: String,
pub timezone: String, pub timezone: String,
@@ -39,7 +39,7 @@ pub async fn show_signup_form(
pub async fn process_signup_and_show_verification( pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>, State(state): State<HtmlState>,
auth: AuthSessionType, auth: AuthSessionType,
Form(form): Form<SignupParams>, Form(form): Form<Params>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let user = match User::create_new( let user = match User::create_new(
form.email, form.email,
+30 -27
View File
@@ -49,6 +49,8 @@ pub struct ChatPageData {
conversation: Option<Conversation>, conversation: Option<Conversation>,
} }
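/// Sets the `HX-Push` response header to `/chat/{conversation id}` when that string is a
/// valid header value; otherwise the header is left off the response instead of panicking.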
pub async fn show_initialized_chat( pub async fn show_initialized_chat(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
@@ -57,14 +59,14 @@ pub async fn show_initialized_chat(
let conversation = Conversation::new(user.id.clone(), "Test".to_owned()); let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let user_message = Message::new( let user_message = Message::new(
conversation.id.to_string(), conversation.id.clone(),
MessageRole::User, MessageRole::User,
form.user_query, form.user_query,
None, None,
); );
let ai_message = Message::new( let ai_message = Message::new(
conversation.id.to_string(), conversation.id.clone(),
MessageRole::AI, MessageRole::AI,
form.llm_response, form.llm_response,
Some(form.references), Some(form.references),
@@ -86,10 +88,9 @@ pub async fn show_initialized_chat(
) )
.into_response(); .into_response();
response.headers_mut().insert( if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
"HX-Push", response.headers_mut().insert("HX-Push", header_value);
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), }
);
Ok(response) Ok(response)
} }
@@ -130,12 +131,19 @@ pub async fn show_existing_chat(
)) ))
} }
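/// Sets the `HX-Push` response header to `/chat/{conversation id}` when that string is a
/// valid header value; otherwise the header is left off the response instead of panicking.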
pub async fn new_user_message( pub async fn new_user_message(
Path(conversation_id): Path<String>, Path(conversation_id): Path<String>,
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>, Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let conversation: Conversation = state let conversation: Conversation = state
.db .db
.get_item(&conversation_id) .get_item(&conversation_id)
@@ -150,33 +158,34 @@ pub async fn new_user_message(
state.db.store_item(user_message.clone()).await?; state.db.store_item(user_message.clone()).await?;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let mut response = TemplateResponse::new_template( let mut response = TemplateResponse::new_template(
"chat/streaming_response.html", "chat/streaming_response.html",
SSEResponseInitData { user_message }, SSEResponseInitData { user_message },
) )
.into_response(); .into_response();
response.headers_mut().insert( if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
"HX-Push", response.headers_mut().insert("HX-Push", header_value);
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), }
);
Ok(response) Ok(response)
} }
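/// Sets the `HX-Push` response header to `/chat/{conversation id}` when that string is a
/// valid header value; otherwise the header is left off the response instead of panicking.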
pub async fn new_chat_user_message( pub async fn new_chat_user_message(
State(state): State<HtmlState>, State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<NewMessageForm>, Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user { #[derive(Serialize)]
Some(user) => user, struct SSEResponseInitData {
None => return Ok(Redirect::to("/").into_response()), user_message: Message,
conversation: Conversation,
}
let Some(user) = auth.current_user else {
return Ok(Redirect::to("/").into_response());
}; };
let conversation = Conversation::new(user.id.clone(), "New chat".to_string()); let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
@@ -191,11 +200,6 @@ pub async fn new_chat_user_message(
state.db.store_item(user_message.clone()).await?; state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await; state.invalidate_conversation_archive_cache(&user.id).await;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
conversation: Conversation,
}
let mut response = TemplateResponse::new_template( let mut response = TemplateResponse::new_template(
"chat/new_chat_first_response.html", "chat/new_chat_first_response.html",
SSEResponseInitData { SSEResponseInitData {
@@ -205,10 +209,9 @@ pub async fn new_chat_user_message(
) )
.into_response(); .into_response();
response.headers_mut().insert( if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
"HX-Push", response.headers_mut().insert("HX-Push", header_value);
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), }
);
Ok(response.into_response()) Ok(response.into_response())
} }
@@ -53,26 +53,22 @@ fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
) )
} }
// Error handling function
fn create_error_stream(message: impl Into<String>) -> EventStream { fn create_error_stream(message: impl Into<String>) -> EventStream {
let message = message.into(); let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed() stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
} }
// Helper function to get message and user
async fn get_message_and_user( async fn get_message_and_user(
db: &SurrealDbClient, db: &SurrealDbClient,
current_user: Option<User>, current_user: Option<User>,
message_id: &str, message_id: &str,
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> { ) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
// Check authentication
let Some(user) = current_user else { let Some(user) = current_user else {
return Err(sse_with_keep_alive(create_error_stream( return Err(sse_with_keep_alive(create_error_stream(
"You must be signed in to use this feature", "You must be signed in to use this feature",
))); )));
}; };
// Retrieve message
let message = match db.get_item::<Message>(message_id).await { let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message, Ok(Some(message)) => message,
Ok(None) => { Ok(None) => {
@@ -88,7 +84,6 @@ async fn get_message_and_user(
} }
}; };
// Get conversation history
let (conversation, history) = let (conversation, history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{ {
@@ -209,7 +204,6 @@ pub async fn get_response_stream(
auth: AuthSessionType, auth: AuthSessionType,
Query(params): Query<QueryParams>, Query(params): Query<QueryParams>,
) -> SseResponse { ) -> SseResponse {
// 1. Authentication and initial data validation
let (user_message, user, _conversation, history, existing_ai_response) = let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await { match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
Ok((user_message, user, conversation, history, existing_ai_response)) => ( Ok((user_message, user, conversation, history, existing_ai_response)) => (
@@ -226,9 +220,123 @@ pub async fn get_response_stream(
return create_replayed_response_stream(&state, existing_ai_message); return create_replayed_response_stream(&state, existing_ai_message);
} }
// 2. Retrieve knowledge entities let (request, allowed_reference_ids) = match prepare_chat_request(&state, &user_message, &user, &history).await {
Ok(result) => result,
Err(sse) => return sse,
};
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};
build_chat_event_stream(state, openai_stream, &user_message, user.id.clone(), allowed_reference_ids)
}
fn build_chat_event_stream(
state: HtmlState,
openai_stream: impl Stream<Item = Result<async_openai::types::CreateChatCompletionStreamResponse, async_openai::error::OpenAIError>> + Send + 'static,
user_message: &Message,
user_id: String,
allowed_reference_ids: Vec<String>,
) -> SseResponse {
let (tx, rx) = channel::<String>(1000);
let (tx_final, mut rx_final) = channel::<Message>(1);
spawn_storage_task(Arc::clone(&state.db), rx, tx_final, user_message, user_id, allowed_reference_ids);
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
let _ = tx_storage.send(content.clone()).await;
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
}
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {e}")));
}
}
}
})
.flatten()
.chain(stream::once(async move {
#[derive(Serialize)]
struct LocalReferenceData {
message: Message,
}
if let Some(message) = rx_final.recv().await {
if message
.references
.as_ref()
.is_some_and(std::vec::Vec::is_empty)
{
return Ok(Event::default().event("empty"));
}
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(LocalReferenceData { message }),
) {
Ok(html) => Ok(Event::default().event("references").data(html)),
Err(_) => Ok(Event::default()
.event("error")
.data("Failed to render references")),
}
} else {
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}))
.boxed();
sse_with_keep_alive(event_stream)
}
async fn prepare_chat_request(
state: &HtmlState,
user_message: &Message,
user: &User,
history: &[Message],
) -> Result<
(async_openai::types::CreateChatCompletionRequest, Vec<String>),
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
let rerank_lease = match state.reranker_pool.as_ref() { let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await), Some(pool) => pool.checkout().await,
None => None, None => None,
}; };
@@ -248,59 +356,49 @@ pub async fn get_response_stream(
{ {
Ok(result) => result, Ok(result) => result,
Err(_e) => { Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge")); return Err(Sse::new(create_error_stream("Failed to retrieve knowledge")));
} }
}; };
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result); let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
// 3. Create the OpenAI request with appropriate context format let context_json = match retrieval_result {
let context_json = match &retrieval_result { retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => { retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(entities) retrieved_entities_to_json(entities)
} }
retrieval_pipeline::StrategyOutput::Search(search_result) => { retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
chunks_to_chat_context(&search_result.chunks) chunks_to_chat_context(&search_result.chunks)
} }
}; };
let formatted_user_message = let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content); create_user_message_with_history(&context_json, history, &user_message.content);
let Ok(settings) = SystemSettings::get_current(&state.db).await else { let Ok(settings) = SystemSettings::get_current(&state.db).await else {
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings")); return Err(Sse::new(create_error_stream("Failed to retrieve system settings")));
}; };
let Ok(request) = create_chat_request(formatted_user_message, &settings) else { let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
return sse_with_keep_alive(create_error_stream("Failed to create chat request")); return Err(Sse::new(create_error_stream("Failed to create chat request")));
}; };
// 4. Set up the OpenAI stream Ok((request, allowed_reference_ids))
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
} }
};
// 5. Create channel for collecting complete response fn spawn_storage_task(
let (tx, mut rx) = channel::<String>(1000); db_client: Arc<SurrealDbClient>,
let tx_clone = tx.clone(); mut rx: tokio::sync::mpsc::Receiver<String>,
let (tx_final, mut rx_final) = channel::<Message>(1); tx_final: tokio::sync::mpsc::Sender<Message>,
user_message: &Message,
user_id: String,
allowed_reference_ids: Vec<String>,
) {
let conversation_id = user_message.conversation_id.clone();
// 6. Set up the collection task for DB storage
let db_client = Arc::clone(&state.db);
let user_id = user.id.clone();
let allowed_reference_ids = allowed_reference_ids.clone();
tokio::spawn(async move { tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
// Collect full response
let mut full_json = String::new(); let mut full_json = String::new();
while let Some(chunk) = rx.recv().await { while let Some(chunk) = rx.recv().await {
full_json.push_str(&chunk); full_json.push_str(&chunk);
} }
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) { if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let raw_references = extract_reference_strings(&response); let raw_references = extract_reference_strings(&response);
let answer = response.answer; let answer = response.answer;
@@ -347,7 +445,7 @@ pub async fn get_response_stream(
); );
let ai_message = Message::new( let ai_message = Message::new(
user_message.conversation_id, conversation_id,
MessageRole::AI, MessageRole::AI,
answer, answer,
Some(initial_validation.valid_refs), Some(initial_validation.valid_refs),
@@ -362,104 +460,11 @@ pub async fn get_response_stream(
} else { } else {
error!("Failed to parse LLM response as structured format"); error!("Failed to parse LLM response as structured format");
// Fallback - store raw response let ai_message = Message::new(conversation_id, MessageRole::AI, full_json, None);
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
full_json,
None,
);
let _ = db_client.store_item(ai_message).await; let _ = db_client.store_item(ai_message).await;
} }
}); });
// Create a shared state for tracking the JSON parsing
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
// Always send raw content to storage
let _ = tx_storage.send(content.clone()).await;
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// If display_content is empty, don't yield anything
}
// If content is empty, don't yield anything
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {e}")));
}
}
}
})
.flatten()
.chain(stream::once(async move {
if let Some(message) = rx_final.recv().await {
// Don't send any event if references is empty
if message
.references
.as_ref()
.is_some_and(std::vec::Vec::is_empty)
{
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(ReferenceData { message }),
) {
Ok(html) => {
// Return the rendered HTML
Ok(Event::default().event("references").data(html))
}
Err(_) => {
// Handle template rendering error
Ok(Event::default()
.event("error")
.data("Failed to render references"))
}
}
} else {
// Handle case where no references were received
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}));
sse_with_keep_alive(event_stream.boxed())
} }
struct StreamParserState { struct StreamParserState {
@@ -478,23 +483,18 @@ impl StreamParserState {
} }
fn process_chunk(&mut self, chunk: &str) -> String { fn process_chunk(&mut self, chunk: &str) -> String {
// Feed all characters into the parser
for c in chunk.chars() { for c in chunk.chars() {
let _ = self.parser.add_char(c); let _ = self.parser.add_char(c);
} }
// Get the current state of the JSON
let json = self.parser.get_result(); let json = self.parser.get_result();
// Check if we're in the answer field
if let Some(obj) = json.as_object() { if let Some(obj) = json.as_object() {
if let Some(answer) = obj.get("answer") { if let Some(answer) = obj.get("answer") {
self.in_answer_field = true; self.in_answer_field = true;
// Get current answer content
let current_content = answer.as_str().unwrap_or_default().to_string(); let current_content = answer.as_str().unwrap_or_default().to_string();
// Calculate difference to send only new content
if current_content.len() > self.last_answer_content.len() { if current_content.len() > self.last_answer_content.len() {
let new_content = current_content[self.last_answer_content.len()..].to_string(); let new_content = current_content[self.last_answer_content.len()..].to_string();
self.last_answer_content = current_content; self.last_answer_content = current_content;
@@ -503,7 +503,6 @@ impl StreamParserState {
} }
} }
// No new content to return
String::new() String::new()
} }
} }
+6 -5
View File
@@ -10,8 +10,9 @@ use axum::{
}; };
pub use chat_handlers::{ pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title, delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat, reload_sidebar, show_conversation_editing_title,
show_initialized_chat, show_chat_base as show_base, show_existing_chat as show_existing,
show_initialized_chat as show_initialized,
}; };
use message_response_stream::get_response_stream; use message_response_stream::get_response_stream;
use references::show_reference_tooltip; use references::show_reference_tooltip;
@@ -24,10 +25,10 @@ where
HtmlState: FromRef<S>, HtmlState: FromRef<S>,
{ {
Router::new() Router::new()
.route("/chat", get(show_chat_base).post(new_chat_user_message)) .route("/chat", get(show_base).post(new_chat_user_message))
.route( .route(
"/chat/{id}", "/chat/{id}",
get(show_existing_chat) get(show_existing)
.post(new_user_message) .post(new_user_message)
.delete(delete_conversation), .delete(delete_conversation),
) )
@@ -36,7 +37,7 @@ where
get(show_conversation_editing_title).patch(patch_conversation_title), get(show_conversation_editing_title).patch(patch_conversation_title),
) )
.route("/chat/sidebar", get(reload_sidebar)) .route("/chat/sidebar", get(reload_sidebar))
.route("/initialized-chat", post(show_initialized_chat)) .route("/initialized-chat", post(show_initialized))
.route("/chat/response-stream", get(get_response_stream)) .route("/chat/response-stream", get(get_response_stream))
.route("/chat/reference/{id}", get(show_reference_tooltip)) .route("/chat/reference/{id}", get(show_reference_tooltip))
} }
+5 -4
View File
@@ -102,13 +102,13 @@ pub async fn show_text_content_edit_form(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
#[derive(Serialize)] #[derive(Serialize)]
pub struct TextContentEditModal { pub struct TextContentEditModal {
pub text_content: TextContent, pub text_content: TextContent,
} }
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"content/edit_text_content_modal.html", "content/edit_text_content_modal.html",
TextContentEditModal { text_content }, TextContentEditModal { text_content },
@@ -214,13 +214,14 @@ pub async fn show_content_read_modal(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
// Get and validate the text content
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
#[derive(Serialize)] #[derive(Serialize)]
pub struct TextContentReadModalData { pub struct TextContentReadModalData {
pub text_content: TextContent, pub text_content: TextContent,
} }
// Get and validate the text content
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"content/read_content_modal.html", "content/read_content_modal.html",
TextContentReadModalData { text_content }, TextContentReadModalData { text_content },
+5 -7
View File
@@ -226,7 +226,7 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) {
("Text".to_string(), truncate_summary(text, 80)) ("Text".to_string(), truncate_summary(text, 80))
} }
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => { common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
("URL".to_string(), url.to_string()) ("URL".to_string(), url.clone())
} }
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => { common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
("File".to_string(), file_info.file_name.clone()) ("File".to_string(), file_info.file_name.clone())
@@ -248,18 +248,16 @@ pub async fn serve_file(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(file_id): Path<String>, Path(file_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let file_info = match FileInfo::get_by_id(&file_id, &state.db).await { let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
Ok(info) => info, return Ok(TemplateResponse::not_found().into_response());
_ => return Ok(TemplateResponse::not_found().into_response()),
}; };
if file_info.user_id != user.id { if file_info.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response()); return Ok(TemplateResponse::unauthorized().into_response());
} }
let stream = match state.storage.get_stream(&file_info.path).await { let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
Ok(s) => s, return Ok(TemplateResponse::server_error().into_response());
Err(_) => return Ok(TemplateResponse::server_error().into_response()),
}; };
let body = Body::from_stream(stream); let body = Body::from_stream(stream);
+6 -6
View File
@@ -1,4 +1,4 @@
use std::{pin::Pin, time::Duration}; use std::{pin::Pin, sync::Arc, time::Duration};
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
@@ -51,13 +51,13 @@ pub async fn show_ingest_form(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
#[derive(Serialize)] #[derive(Serialize)]
pub struct ShowIngestFormData { pub struct ShowIngestFormData {
user_categories: Vec<String>, user_categories: Vec<String>,
} }
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"ingestion_modal.html", "ingestion_modal.html",
ShowIngestFormData { user_categories }, ShowIngestFormData { user_categories },
@@ -180,7 +180,7 @@ pub async fn get_task_updates_stream(
Query(params): Query<QueryParams>, Query(params): Query<QueryParams>,
) -> TaskSse { ) -> TaskSse {
let task_id = params.task_id.clone(); let task_id = params.task_id.clone();
let db = state.db.clone(); let db = Arc::clone(&state.db);
// 1. Check for authenticated user // 1. Check for authenticated user
let Some(current_user) = auth.current_user else { let Some(current_user) = auth.current_user else {
@@ -198,7 +198,7 @@ pub async fn get_task_updates_stream(
} }
let sse_stream = async_stream::stream! { let sse_stream = async_stream::stream! {
let mut consecutive_db_errors = 0; let mut consecutive_db_errors: u32 = 0;
let max_consecutive_db_errors = 3; let max_consecutive_db_errors = 3;
loop { loop {
@@ -263,7 +263,7 @@ pub async fn get_task_updates_stream(
} }
Err(db_err) => { Err(db_err) => {
error!("Database error while fetching task '{}': {:?}", task_id, db_err); error!("Database error while fetching task '{}': {:?}", task_id, db_err);
consecutive_db_errors += 1; consecutive_db_errors = consecutive_db_errors.saturating_add(1);
yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors})."))); yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors}).")));
if consecutive_db_errors >= max_consecutive_db_errors { if consecutive_db_errors >= max_consecutive_db_errors {
+21 -20
View File
@@ -39,7 +39,7 @@ use url::form_urlencoded;
const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12; const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12;
const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"]; const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"];
const DEFAULT_RELATIONSHIP_TYPE: &str = RELATIONSHIP_TYPE_OPTIONS[0]; const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo";
const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10; const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10;
const SUGGESTION_MIN_SCORE: f32 = 0.5; const SUGGESTION_MIN_SCORE: f32 = 0.5;
@@ -61,15 +61,15 @@ fn canonicalize_relationship_type(value: &str) -> String {
let key: String = trimmed let key: String = trimmed
.chars() .chars()
.filter(|c| c.is_ascii_alphanumeric()) .filter(char::is_ascii_alphanumeric)
.flat_map(|c| c.to_lowercase()) .flat_map(char::to_lowercase)
.collect(); .collect();
for option in RELATIONSHIP_TYPE_OPTIONS { for option in RELATIONSHIP_TYPE_OPTIONS {
let option_key: String = option let option_key: String = option
.chars() .chars()
.filter(|c| c.is_ascii_alphanumeric()) .filter(char::is_ascii_alphanumeric)
.flat_map(|c| c.to_lowercase()) .flat_map(char::to_lowercase)
.collect(); .collect();
if option_key == key { if option_key == key {
return (*option).to_string(); return (*option).to_string();
@@ -141,7 +141,7 @@ pub async fn show_new_knowledge_entity_form(
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let entity_types: Vec<String> = KnowledgeEntityType::variants() let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter() .iter()
.map(|&s| s.to_owned()) .map(ToString::to_string)
.collect(); .collect();
let existing_entities = User::get_knowledge_entities(&user.id, &state.db).await?; let existing_entities = User::get_knowledge_entities(&user.id, &state.db).await?;
@@ -278,7 +278,7 @@ pub async fn suggest_knowledge_relationships(
if !query_parts.is_empty() { if !query_parts.is_empty() {
let query = query_parts.join(" "); let query = query_parts.join(" ");
let rerank_lease = match state.reranker_pool.as_ref() { let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await), Some(pool) => pool.checkout().await,
None => None, None => None,
}; };
@@ -406,9 +406,10 @@ fn build_relationship_table_data(
.map(|relationship| { .map(|relationship| {
let relationship_type_label = let relationship_type_label =
canonicalize_relationship_type(&relationship.metadata.relationship_type); canonicalize_relationship_type(&relationship.metadata.relationship_type);
*frequency let count = frequency
.entry(relationship_type_label.clone()) .entry(relationship_type_label.clone())
.or_insert(0) += 1; .or_insert(0);
*count = count.saturating_add(1);
RelationshipTableRow { RelationshipTableRow {
relationship, relationship,
relationship_type_label, relationship_type_label,
@@ -417,9 +418,7 @@ fn build_relationship_table_data(
.collect(); .collect();
let default_relationship_type = frequency let default_relationship_type = frequency
.into_iter() .into_iter()
.max_by_key(|(_, count)| *count) .max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
.map(|(label, _)| label)
.unwrap_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string());
RelationshipTableData { RelationshipTableData {
entities, entities,
@@ -800,8 +799,10 @@ pub async fn get_knowledge_graph_json(
for rel in &relationships { for rel in &relationships {
if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) { if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) {
// undirected counting for degree // undirected counting for degree
*degree_count.entry(rel.in_.clone()).or_insert(0) += 1; let count = degree_count.entry(rel.in_.clone()).or_insert(0);
*degree_count.entry(rel.out.clone()).or_insert(0) += 1; *count = count.saturating_add(1);
let count = degree_count.entry(rel.out.clone()).or_insert(0);
*count = count.saturating_add(1);
links.push(GraphLink { links.push(GraphLink {
source: rel.out.clone(), source: rel.out.clone(),
target: rel.in_.clone(), target: rel.in_.clone(),
@@ -836,11 +837,11 @@ fn normalize_filter(input: Option<String>) -> Option<String> {
fn trim_matching_quotes(value: &str) -> &str { fn trim_matching_quotes(value: &str) -> &str {
let bytes = value.as_bytes(); let bytes = value.as_bytes();
if bytes.len() >= 2 { if let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) {
let first = bytes[0]; if bytes.len() >= 2
let last = bytes[bytes.len() - 1]; && ((first == b'"' && last == b'"') || (first == b'\'' && last == b'\''))
if (first == b'"' && last == b'"') || (first == b'\'' && last == b'\'') { {
return &value[1..value.len() - 1]; return &value[1..value.len().saturating_sub(1)];
} }
} }
value value
@@ -860,7 +861,7 @@ pub async fn show_edit_knowledge_entity_form(
// Get entity types // Get entity types
let entity_types: Vec<String> = KnowledgeEntityType::variants() let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter() .iter()
.map(|&s| s.to_owned()) .map(ToString::to_string)
.collect(); .collect();
// Get the entity and validate ownership // Get the entity and validate ownership
+103 -94
View File
@@ -11,6 +11,7 @@ use axum::{
use common::storage::types::{ use common::storage::types::{
serde_helpers::deserialize_flexible_id, serde_helpers::deserialize_flexible_id,
text_content::TextContent, text_content::TextContent,
user::User,
StoredObject, StoredObject,
}; };
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput}; use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
@@ -46,13 +47,11 @@ fn source_id_suffix(source_id: &str) -> String {
fn truncate_label(value: &str, max_chars: usize) -> String { fn truncate_label(value: &str, max_chars: usize) -> String {
let mut end = None; let mut end = None;
let mut count = 0; for (count, (idx, _)) in value.char_indices().enumerate() {
for (idx, _) in value.char_indices() {
if count == max_chars { if count == max_chars {
end = Some(idx); end = Some(idx);
break; break;
} }
count += 1;
} }
match end { match end {
@@ -174,11 +173,6 @@ struct KnowledgeEntityForTemplate {
score: f32, score: f32,
} }
pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(params): Query<SearchParams>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)] #[derive(Serialize)]
struct SearchResultForTemplate { struct SearchResultForTemplate {
result_type: String, result_type: String,
@@ -195,38 +189,119 @@ pub async fn search_result_handler(
query_param: String, query_param: String,
} }
pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(params): Query<SearchParams>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) = let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
params.query params.query
{ {
let trimmed_query = actual_query.trim(); perform_search(&state, &user, actual_query).await?
if trimmed_query.is_empty() {
(Vec::<SearchResultForTemplate>::new(), String::new())
} else { } else {
// Use retrieval pipeline Search strategy (Vec::<SearchResultForTemplate>::new(), String::new())
};
Ok(TemplateResponse::new_template(
"search/base.html",
AnswerData {
search_result: search_results_for_template,
query_param: final_query_param_for_template,
},
))
}
async fn perform_search(
state: &HtmlState,
user: &User,
query: String,
) -> Result<(Vec<SearchResultForTemplate>, String), HtmlError> {
const TOTAL_LIMIT: usize = 10;
let trimmed_query = query.trim();
if trimmed_query.is_empty() {
return Ok((Vec::new(), String::new()));
}
let config = RetrievalConfig::for_search(SearchTarget::Both); let config = RetrievalConfig::for_search(SearchTarget::Both);
// Checkout a reranker lease if pool is available
let reranker_lease = match &state.reranker_pool { let reranker_lease = match &state.reranker_pool {
Some(pool) => Some(pool.checkout().await), Some(pool) => pool.checkout().await,
None => None, None => None,
}; };
let result = retrieval_pipeline::pipeline::run_pipeline( let params = retrieval_pipeline::pipeline::StrategyParams {
&state.db, db_client: &state.db,
&state.openai_client, openai_client: &state.openai_client,
Some(&state.embedding_provider), embedding_provider: Some(&state.embedding_provider),
trimmed_query, input_text: trimmed_query,
&user.id, user_id: &user.id,
config, config,
reranker_lease, reranker: reranker_lease,
) };
.await?; let result = retrieval_pipeline::pipeline::execute(params).await?;
let search_result = match result { let search_result = match result {
StrategyOutput::Search(sr) => sr, StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]), _ => SearchResult::new(vec![], vec![]),
}; };
let source_label_map = resolve_source_labels(state, user, &search_result).await?;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
for chunk_result in search_result.chunks {
let source_label = source_label_map
.get(&chunk_result.chunk.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "text_chunk".to_string(),
score: chunk_result.score,
text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
source_label,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None,
});
}
for entity_result in search_result.entities {
let source_label = source_label_map
.get(&entity_result.entity.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score: entity_result.score,
text_chunk: None,
knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
source_label,
score: entity_result.score,
}),
});
}
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
combined_results.truncate(TOTAL_LIMIT);
Ok((combined_results, trimmed_query.to_string()))
}
async fn resolve_source_labels(
state: &HtmlState,
user: &User,
search_result: &SearchResult,
) -> Result<HashMap<String, String>, HtmlError> {
let mut source_ids = HashSet::new(); let mut source_ids = HashSet::new();
for chunk_result in &search_result.chunks { for chunk_result in &search_result.chunks {
source_ids.insert(chunk_result.chunk.source_id.clone()); source_ids.insert(chunk_result.chunk.source_id.clone());
@@ -235,9 +310,10 @@ pub async fn search_result_handler(
source_ids.insert(entity_result.entity.source_id.clone()); source_ids.insert(entity_result.entity.source_id.clone());
} }
let source_label_map = if source_ids.is_empty() { if source_ids.is_empty() {
HashMap::new() return Ok(HashMap::new());
} else { }
let record_ids: Vec<RecordId> = source_ids let record_ids: Vec<RecordId> = source_ids
.iter() .iter()
.filter_map(|id| { .filter_map(|id| {
@@ -276,72 +352,5 @@ pub async fn search_result_handler(
); );
} }
labels Ok(labels)
};
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
// Add chunk results
for chunk_result in search_result.chunks {
let source_label = source_label_map
.get(&chunk_result.chunk.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "text_chunk".to_string(),
score: chunk_result.score,
text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
source_label,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None,
});
}
// Add entity results
for entity_result in search_result.entities {
let source_label = source_label_map
.get(&entity_result.entity.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score: entity_result.score,
text_chunk: None,
knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
source_label,
score: entity_result.score,
}),
});
}
// Sort by score descending
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
// Limit results
const TOTAL_LIMIT: usize = 10;
combined_results.truncate(TOTAL_LIMIT);
(combined_results, trimmed_query.to_string())
}
} else {
(Vec::<SearchResultForTemplate>::new(), String::new())
};
Ok(TemplateResponse::new_template(
"search/base.html",
AnswerData {
search_result: search_results_for_template,
query_param: final_query_param_for_template,
},
))
} }
+5 -2
View File
@@ -1,7 +1,10 @@
mod handlers; mod handlers;
use axum::{extract::FromRef, routing::get, Router}; use axum::{extract::FromRef, routing::get, Router};
pub use handlers::{search_result_handler, SearchParams}; #[allow(clippy::module_name_repetitions)]
pub use handlers::{
search_result_handler as result_handler, SearchParams as SearchQueryParams,
};
use crate::html_state::HtmlState; use crate::html_state::HtmlState;
@@ -10,5 +13,5 @@ where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>, HtmlState: FromRef<S>,
{ {
Router::new().route("/search", get(search_result_handler)) Router::new().route("/search", get(result_handler))
} }
+8 -8
View File
@@ -31,8 +31,8 @@ impl Pagination {
} else { } else {
0 0
}; };
let start_index = if page_len == 0 { 0 } else { offset + 1 }; let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) };
let end_index = if page_len == 0 { 0 } else { offset + page_len }; let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) };
Self { Self {
current_page, current_page,
@@ -42,12 +42,12 @@ impl Pagination {
has_previous, has_previous,
has_next, has_next,
previous_page: if has_previous { previous_page: if has_previous {
Some(current_page - 1) Some(current_page.saturating_sub(1))
} else { } else {
None None
}, },
next_page: if has_next { next_page: if has_next {
Some(current_page + 1) Some(current_page.saturating_add(1))
} else { } else {
None None
}, },
@@ -68,7 +68,7 @@ pub fn paginate_items<T>(
let total_pages = if total_items == 0 { let total_pages = if total_items == 0 {
0 0
} else { } else {
((total_items - 1) / per_page) + 1 total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1)
}; };
let mut current_page = requested_page.unwrap_or(1); let mut current_page = requested_page.unwrap_or(1);
@@ -84,7 +84,7 @@ pub fn paginate_items<T>(
let offset = if total_pages == 0 { let offset = if total_pages == 0 {
0 0
} else { } else {
per_page.saturating_mul(current_page - 1) per_page.saturating_mul(current_page.saturating_sub(1))
}; };
let page_items: Vec<T> = items.into_iter().skip(offset).take(per_page).collect(); let page_items: Vec<T> = items.into_iter().skip(offset).take(per_page).collect();
@@ -136,8 +136,8 @@ mod tests {
assert_eq!(page, vec![5]); assert_eq!(page, vec![5]);
assert_eq!(meta.current_page, 3); assert_eq!(meta.current_page, 3);
assert_eq!(meta.total_pages, 3); assert_eq!(meta.total_pages, 3);
assert_eq!(meta.has_next, false); assert!(!meta.has_next, "expected no next page");
assert_eq!(meta.has_previous, true); assert!(meta.has_previous, "expected previous page");
assert_eq!(meta.start_index, 5); assert_eq!(meta.start_index, 5);
assert_eq!(meta.end_index, 5); assert_eq!(meta.end_index, 5);
} }
+1 -1
View File
@@ -180,7 +180,7 @@ impl PipelineServices for DefaultPipelineServices {
); );
let rerank_lease = match &self.reranker_pool { let rerank_lease = match &self.reranker_pool {
Some(pool) => Some(pool.checkout().await), Some(pool) => pool.checkout().await,
None => None, None => None,
}; };
@@ -4,7 +4,7 @@ use common::{
error::AppError, error::AppError,
storage::{ storage::{
db::SurrealDbClient, db::SurrealDbClient,
indexes::rebuild_indexes, indexes::rebuild,
types::{ types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
@@ -191,7 +191,7 @@ pub async fn persist(
ctx.db.store_item(text_content).await?; ctx.db.store_item(text_content).await?;
debug!("stored item"); debug!("stored item");
rebuild_indexes(ctx.db).await?; rebuild(ctx.db).await?;
debug!( debug!(
task_id = %ctx.task_id, task_id = %ctx.task_id,
@@ -301,8 +301,8 @@ async fn store_chunk_batch(
for embedded in batch { for embedded in batch {
TextChunk::store_with_embedding( TextChunk::store_with_embedding(
embedded.chunk.to_owned(), embedded.chunk.clone(),
embedded.embedding.to_owned(), embedded.embedding.clone(),
db, db,
) )
.await?; .await?;
+62 -72
View File
@@ -1,5 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::{self, Context};
use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{Duration as ChronoDuration, Utc}; use chrono::{Duration as ChronoDuration, Utc};
@@ -265,16 +266,12 @@ impl PipelineServices for ValidationServices {
} }
} }
async fn setup_db() -> SurrealDbClient { async fn setup_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "pipeline_test"; let namespace = "pipeline_test";
let database = Uuid::new_v4().to_string(); let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database) let db = SurrealDbClient::memory(namespace, &database).await?;
.await db.apply_migrations().await?;
.expect("Failed to create in-memory SurrealDB"); Ok(db)
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db
} }
fn pipeline_config() -> IngestionConfig { fn pipeline_config() -> IngestionConfig {
@@ -295,26 +292,28 @@ async fn reserve_task(
worker_id: &str, worker_id: &str,
payload: IngestionPayload, payload: IngestionPayload,
user_id: &str, user_id: &str,
) -> IngestionTask { ) -> anyhow::Result<IngestionTask> {
let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db) let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db).await?;
.await
.expect("task created");
let lease = task.lease_duration(); let lease = task.lease_duration();
IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease) let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
.await .await?
.expect("claim succeeds") .context("task claimed")?;
.expect("task claimed") Ok(claimed)
} }
#[tokio::test] #[tokio::test]
async fn ingestion_pipeline_happy_path_persists_entities() { async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()>
let db = setup_db().await; {
let db = setup_db().await?;
let worker_id = "worker-happy"; let worker_id = "worker-happy";
let user_id = "user-123"; let user_id = "user-123";
let services = Arc::new(MockServices::new(user_id)); let services = Arc::new(MockServices::new(user_id));
let pipeline = let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services);
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services.clone()) let pipeline = IngestionPipeline::with_services(
.expect("pipeline"); Arc::new(db.clone()),
pipeline_config(),
services_clone,
)?;
let task = reserve_task( let task = reserve_task(
&db, &db,
@@ -327,30 +326,22 @@ async fn ingestion_pipeline_happy_path_persists_entities() {
}, },
user_id, user_id,
) )
.await; .await?;
pipeline pipeline.process_task(task.clone()).await?;
.process_task(task.clone())
.await
.expect("pipeline succeeds");
let stored_task: IngestionTask = db let stored_task: IngestionTask = db
.get_item(&task.id) .get_item(&task.id)
.await .await?
.expect("retrieve task") .context("task present")?;
.expect("task present");
assert_eq!(stored_task.state, TaskState::Succeeded); assert_eq!(stored_task.state, TaskState::Succeeded);
let stored_entities: Vec<KnowledgeEntity> = db let stored_entities: Vec<KnowledgeEntity> = db
.get_all_stored_items::<KnowledgeEntity>() .get_all_stored_items::<KnowledgeEntity>()
.await .await?;
.expect("entities stored");
assert!(!stored_entities.is_empty(), "entities should be stored"); assert!(!stored_entities.is_empty(), "entities should be stored");
let stored_chunks: Vec<TextChunk> = db let stored_chunks: Vec<TextChunk> = db.get_all_stored_items::<TextChunk>().await?;
.get_all_stored_items::<TextChunk>()
.await
.expect("chunks stored");
assert!( assert!(
!stored_chunks.is_empty(), !stored_chunks.is_empty(),
"chunks should be stored for ingestion text" "chunks should be stored for ingestion text"
@@ -362,22 +353,29 @@ async fn ingestion_pipeline_happy_path_persists_entities() {
"expected at least one chunk embedding call" "expected at least one chunk embedding call"
); );
assert_eq!( assert_eq!(
&call_log[0..4], call_log.get(0..4),
["prepare", "retrieve", "enrich", "convert"] Some(&["prepare", "retrieve", "enrich", "convert"][..])
); );
assert!(call_log[4..].iter().all(|entry| *entry == "chunk")); assert!(
call_log.get(4..).is_some_and(|tail| tail.iter().all(|entry| *entry == "chunk"))
);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn ingestion_pipeline_chunk_only_skips_analysis() { async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> {
let db = setup_db().await; let db = setup_db().await?;
let worker_id = "worker-chunk-only"; let worker_id = "worker-chunk-only";
let user_id = "user-999"; let user_id = "user-999";
let services = Arc::new(MockServices::new(user_id)); let services = Arc::new(MockServices::new(user_id));
let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services);
let mut config = pipeline_config(); let mut config = pipeline_config();
config.chunk_only = true; config.chunk_only = true;
let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone()) let pipeline = IngestionPipeline::with_services(
.expect("pipeline"); Arc::new(db.clone()),
config,
services_clone,
)?;
let task = reserve_task( let task = reserve_task(
&db, &db,
@@ -390,17 +388,13 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
}, },
user_id, user_id,
) )
.await; .await?;
pipeline pipeline.process_task(task.clone()).await?;
.process_task(task.clone())
.await
.expect("pipeline succeeds");
let stored_entities: Vec<KnowledgeEntity> = db let stored_entities: Vec<KnowledgeEntity> = db
.get_all_stored_items::<KnowledgeEntity>() .get_all_stored_items::<KnowledgeEntity>()
.await .await?;
.expect("entities stored");
assert!( assert!(
stored_entities.is_empty(), stored_entities.is_empty(),
"chunk-only ingestion should not persist entities" "chunk-only ingestion should not persist entities"
@@ -408,8 +402,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
let relationship_count: Option<i64> = db let relationship_count: Option<i64> = db
.client .client
.query("SELECT count() as count FROM relates_to;") .query("SELECT count() as count FROM relates_to;")
.await .await?
.expect("query relationships")
.take::<Option<i64>>(0) .take::<Option<i64>>(0)
.unwrap_or_default(); .unwrap_or_default();
assert_eq!( assert_eq!(
@@ -417,10 +410,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
0, 0,
"chunk-only ingestion should not persist relationships" "chunk-only ingestion should not persist relationships"
); );
let stored_chunks: Vec<TextChunk> = db let stored_chunks: Vec<TextChunk> = db.get_all_stored_items::<TextChunk>().await?;
.get_all_stored_items::<TextChunk>()
.await
.expect("chunks stored");
assert!( assert!(
!stored_chunks.is_empty(), !stored_chunks.is_empty(),
"chunk-only ingestion should still persist chunks" "chunk-only ingestion should still persist chunks"
@@ -428,19 +418,19 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
let call_log = services.calls.lock().await.clone(); let call_log = services.calls.lock().await.clone();
assert_eq!(call_log, vec!["prepare", "chunk"]); assert_eq!(call_log, vec!["prepare", "chunk"]);
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn ingestion_pipeline_failure_marks_retry() { async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> {
let db = setup_db().await; let db = setup_db().await?;
let worker_id = "worker-fail"; let worker_id = "worker-fail";
let user_id = "user-456"; let user_id = "user-456";
let services = Arc::new(FailingServices { let services = Arc::new(FailingServices {
inner: MockServices::new(user_id), inner: MockServices::new(user_id),
}); });
let pipeline = let pipeline =
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services) IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?;
.expect("pipeline");
let task = reserve_task( let task = reserve_task(
&db, &db,
@@ -453,7 +443,7 @@ async fn ingestion_pipeline_failure_marks_retry() {
}, },
user_id, user_id,
) )
.await; .await?;
let result = pipeline.process_task(task.clone()).await; let result = pipeline.process_task(task.clone()).await;
assert!( assert!(
@@ -463,38 +453,38 @@ async fn ingestion_pipeline_failure_marks_retry() {
let stored_task: IngestionTask = db let stored_task: IngestionTask = db
.get_item(&task.id) .get_item(&task.id)
.await .await?
.expect("retrieve task") .context("task present")?;
.expect("task present");
assert_eq!(stored_task.state, TaskState::Failed); assert_eq!(stored_task.state, TaskState::Failed);
assert!( assert!(
stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5), stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5),
"failed task should schedule retry in the future" "failed task should schedule retry in the future"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn ingestion_pipeline_validation_failure_dead_letters_task() { async fn ingestion_pipeline_validation_failure_dead_letters_task(
let db = setup_db().await; ) -> anyhow::Result<()> {
let db = setup_db().await?;
let worker_id = "worker-validation"; let worker_id = "worker-validation";
let user_id = "user-789"; let user_id = "user-789";
let services = Arc::new(ValidationServices); let services = Arc::new(ValidationServices);
let pipeline = let pipeline =
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services) IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?;
.expect("pipeline");
let task = reserve_task( let task = reserve_task(
&db, &db,
worker_id, worker_id,
IngestionPayload::Text { IngestionPayload::Text {
text: "irrelevant".into(), text: "irrelevant".into(),
context: "".into(), context: String::new(),
category: "notes".into(), category: "notes".into(),
user_id: user_id.into(), user_id: user_id.into(),
}, },
user_id, user_id,
) )
.await; .await?;
let result = pipeline.process_task(task.clone()).await; let result = pipeline.process_task(task.clone()).await;
assert!( assert!(
@@ -504,8 +494,8 @@ async fn ingestion_pipeline_validation_failure_dead_letters_task() {
let stored_task: IngestionTask = db let stored_task: IngestionTask = db
.get_item(&task.id) .get_item(&task.id)
.await .await?
.expect("retrieve task") .context("task present")?;
.expect("task present");
assert_eq!(stored_task.state, TaskState::DeadLetter); assert_eq!(stored_task.state, TaskState::DeadLetter);
Ok(())
} }
@@ -155,21 +155,20 @@ mod tests {
}; };
#[tokio::test] #[tokio::test]
async fn extracts_text_using_memory_storage_backend() { async fn extracts_text_using_memory_storage_backend() -> anyhow::Result<()> {
let mut config = AppConfig::default(); let config = AppConfig {
config.storage = StorageKind::Memory; storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config) let storage = StorageManager::new(&config).await?;
.await
.expect("create storage manager");
let location = "user/test/file.txt"; let location = "user/test/file.txt";
let contents = b"hello from memory storage"; let contents = b"hello from memory storage";
storage storage
.put(location, Bytes::from(contents.as_slice().to_vec())) .put(location, Bytes::from(contents.as_slice().to_vec()))
.await .await?;
.expect("write object");
let now = Utc::now(); let now = Utc::now();
let file_info = FileInfo { let file_info = FileInfo {
@@ -185,16 +184,14 @@ mod tests {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("create surreal memory");
let openai_client = Client::with_config(OpenAIConfig::default()); let openai_client = Client::with_config(OpenAIConfig::default());
let text = extract_text_from_file(&file_info, &db, &openai_client, &config, &storage) let text = extract_text_from_file(&file_info, &db, &openai_client, &config, &storage)
.await .await?;
.expect("extract text");
assert_eq!(text, String::from_utf8_lossy(contents)); assert_eq!(text, String::from_utf8_lossy(contents));
Ok(())
} }
} }
@@ -715,6 +715,7 @@ const fn prompt_for_attempt(attempt: usize, base_prompt: &str) -> &str {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use anyhow::{self};
#[test] #[test]
fn test_looks_good_enough_short_text() { fn test_looks_good_enough_short_text() {
@@ -737,15 +738,16 @@ mod tests {
} }
#[test] #[test]
fn test_debug_dump_directory_env_var() { fn test_debug_dump_directory_env_var() -> anyhow::Result<()> {
std::env::remove_var(DEBUG_IMAGE_ENV_VAR); std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
assert!(debug_dump_directory().is_none()); assert!(debug_dump_directory().is_none());
std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug"); std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug");
let dir = debug_dump_directory().expect("expected debug directory"); let dir = debug_dump_directory().ok_or_else(|| anyhow::anyhow!("expected debug directory"))?;
assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug")); assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug"));
std::env::remove_var(DEBUG_IMAGE_ENV_VAR); std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
Ok(())
} }
#[test] #[test]
@@ -142,29 +142,34 @@ fn ensure_ingestion_url_allowed(url: &url::Url) -> Result<String, AppError> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use anyhow::{self};
#[test] #[test]
fn rejects_unsupported_scheme() { fn rejects_unsupported_scheme() -> anyhow::Result<()> {
let url = url::Url::parse("ftp://example.com").expect("url"); let url = url::Url::parse("ftp://example.com")?;
assert!(ensure_ingestion_url_allowed(&url).is_err()); assert!(ensure_ingestion_url_allowed(&url).is_err());
Ok(())
} }
#[test] #[test]
fn rejects_localhost() { fn rejects_localhost() -> anyhow::Result<()> {
let url = url::Url::parse("http://localhost/resource").expect("url"); let url = url::Url::parse("http://localhost/resource")?;
assert!(ensure_ingestion_url_allowed(&url).is_err()); assert!(ensure_ingestion_url_allowed(&url).is_err());
Ok(())
} }
#[test] #[test]
fn rejects_private_ipv4() { fn rejects_private_ipv4() -> anyhow::Result<()> {
let url = url::Url::parse("http://192.168.1.10/index.html").expect("url"); let url = url::Url::parse("http://192.168.1.10/index.html")?;
assert!(ensure_ingestion_url_allowed(&url).is_err()); assert!(ensure_ingestion_url_allowed(&url).is_err());
Ok(())
} }
#[test] #[test]
fn allows_public_domain_and_sanitizes() { fn allows_public_domain_and_sanitizes() -> anyhow::Result<()> {
let url = url::Url::parse("https://sub.example.com/path").expect("url"); let url = url::Url::parse("https://sub.example.com/path")?;
let sanitized = ensure_ingestion_url_allowed(&url).expect("allowed"); let sanitized = ensure_ingestion_url_allowed(&url)?;
assert_eq!(sanitized, "sub_example_com"); assert_eq!(sanitized, "sub_example_com");
Ok(())
} }
} }
+97 -118
View File
@@ -3,7 +3,7 @@ use axum::{extract::FromRef, Router};
use common::{ use common::{
storage::{ storage::{
db::SurrealDbClient, db::SurrealDbClient,
indexes::ensure_runtime_indexes, indexes::ensure_runtime,
store::StorageManager, store::StorageManager,
types::{ types::{
knowledge_entity::KnowledgeEntity, system_settings::SystemSettings, knowledge_entity::KnowledgeEntity, system_settings::SystemSettings,
@@ -12,7 +12,10 @@ use common::{
}, },
utils::{config::get_config, embedding::EmbeddingProvider}, utils::{config::get_config, embedding::EmbeddingProvider},
}; };
use html_router::{html_routes, html_state::HtmlState}; use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
};
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use retrieval_pipeline::reranking::RerankerPool; use retrieval_pipeline::reranking::RerankerPool;
use std::sync::Arc; use std::sync::Arc;
@@ -21,19 +24,77 @@ use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use tokio::task::LocalSet; use tokio::task::LocalSet;
fn spawn_server_thread(
listener: tokio::net::TcpListener,
app: Router,
) -> std::thread::JoinHandle<()> {
std::thread::spawn(move || {
let rt = match tokio::runtime::Runtime::new() {
Ok(rt) => rt,
Err(e) => {
error!("Failed to create server runtime: {e}");
return;
}
};
rt.block_on(async {
if let Err(e) = axum::serve(listener, app).await {
error!("Server error: {}", e);
}
});
})
}
async fn run_worker(
config: common::utils::config::AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
storage: StorageManager,
) -> anyhow::Result<()> {
let worker_db = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
let openai_client = Arc::new(async_openai::Client::with_config(
async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url),
));
let embedding_provider = Arc::new(
EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?,
);
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(
Arc::clone(&worker_db),
openai_client,
config,
reranker_pool,
storage,
embedding_provider,
)?,
);
info!("Starting worker process");
run_worker_loop(worker_db, ingestion_pipeline).await
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
// Set up tracing
tracing_subscriber::registry() tracing_subscriber::registry()
.with(fmt::layer().with_writer(std::io::stderr)) .with(fmt::layer().with_writer(std::io::stderr))
.with(EnvFilter::from_default_env()) .with(EnvFilter::from_default_env())
.try_init() .try_init()
.ok(); .ok();
// Get config
let config = get_config()?; let config = get_config()?;
// Set up router states
let db = Arc::new( let db = Arc::new(
SurrealDbClient::new( SurrealDbClient::new(
&config.surrealdb_address, &config.surrealdb_address,
@@ -45,7 +106,6 @@ async fn main() -> anyhow::Result<()> {
.await?, .await?,
); );
// Ensure db is initialized
db.apply_migrations().await?; db.apply_migrations().await?;
let session_store = Arc::new(db.create_session_store().await?); let session_store = Arc::new(db.create_session_store().await?);
@@ -55,27 +115,23 @@ async fn main() -> anyhow::Result<()> {
.with_api_base(&config.openai_base_url), .with_api_base(&config.openai_base_url),
)); ));
// Create embedding provider based on config before syncing settings.
let embedding_provider = let embedding_provider =
Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?); Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
info!( info!(
embedding_backend = ?config.embedding_backend, embedding_backend = ?config.embedding_backend,
embedding_dimension = embedding_provider.dimension(), embedding_dimension = embedding_provider.dimension(),
"Embedding provider initialized" "Embedding provider initialized"
); );
// Sync SystemSettings with provider's dimensions/model/backend
let (settings, dimensions_changed) = let (settings, dimensions_changed) =
SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?; SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?;
// If dimensions changed, re-embed existing data to keep queries working.
if dimensions_changed { if dimensions_changed {
warn!( warn!(
new_dimensions = settings.embedding_dimensions, new_dimensions = settings.embedding_dimensions,
"Embedding configuration changed; re-embedding existing data" "Embedding configuration changed; re-embedding existing data"
); );
// Re-embed text chunks
info!("Re-embedding TextChunks"); info!("Re-embedding TextChunks");
if let Err(e) = if let Err(e) =
TextChunk::update_all_embeddings_with_provider(&db, &embedding_provider).await TextChunk::update_all_embeddings_with_provider(&db, &embedding_provider).await
@@ -86,7 +142,6 @@ async fn main() -> anyhow::Result<()> {
); );
} }
// Re-embed knowledge entities
info!("Re-embedding KnowledgeEntities"); info!("Re-embedding KnowledgeEntities");
if let Err(e) = if let Err(e) =
KnowledgeEntity::update_all_embeddings_with_provider(&db, &embedding_provider).await KnowledgeEntity::update_all_embeddings_with_provider(&db, &embedding_provider).await
@@ -100,29 +155,25 @@ async fn main() -> anyhow::Result<()> {
info!("Re-embedding complete."); info!("Re-embedding complete.");
} }
// Now ensure runtime indexes with the correct (synced) dimensions ensure_runtime(&db, settings.embedding_dimensions as usize).await?;
ensure_runtime_indexes(&db, settings.embedding_dimensions as usize).await?;
let reranker_pool = RerankerPool::maybe_from_config(&config)?; let reranker_pool = RerankerPool::maybe_from_config(&config)?;
// Create global storage manager
let storage = StorageManager::new(&config).await?; let storage = StorageManager::new(&config).await?;
let html_state = HtmlState::new_with_resources( let html_state = HtmlState::new_with_resources(StateResources {
db, db,
openai_client, openai_client,
session_store, session_store,
storage.clone(), storage: storage.clone(),
config.clone(), config: config.clone(),
reranker_pool.clone(), reranker_pool: reranker_pool.clone(),
embedding_provider.clone(), embedding_provider: Arc::clone(&embedding_provider),
None, template_engine: None,
) });
.await;
let api_state = ApiState::new(&config, storage.clone()).await?; let api_state = ApiState::new(&config, storage.clone()).await?;
// Create Axum router
let app = Router::new() let app = Router::new()
.nest("/api/v1", api_routes_v1(&api_state)) .nest("/api/v1", api_routes_v1(&api_state))
.merge(html_routes(&html_state)) .merge(html_routes(&html_state))
@@ -135,72 +186,16 @@ async fn main() -> anyhow::Result<()> {
let serve_address = format!("0.0.0.0:{}", config.http_port); let serve_address = format!("0.0.0.0:{}", config.http_port);
let listener = tokio::net::TcpListener::bind(serve_address).await?; let listener = tokio::net::TcpListener::bind(serve_address).await?;
// Start the server in a separate OS thread with its own runtime let server_handle = spawn_server_thread(listener, app);
let server_handle = std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
if let Err(e) = axum::serve(listener, app).await {
error!("Server error: {}", e);
}
});
});
// Create a LocalSet for the worker
let local = LocalSet::new(); let local = LocalSet::new();
// Use a clone of the config for the worker
let worker_config = config.clone();
// Run the worker in the local set
local.spawn_local(async move { local.spawn_local(async move {
// Create worker db connection if let Err(e) = run_worker(config, reranker_pool, storage).await {
let worker_db = Arc::new( error!("Worker error: {}", e);
SurrealDbClient::new(
&worker_config.surrealdb_address,
&worker_config.surrealdb_username,
&worker_config.surrealdb_password,
&worker_config.surrealdb_namespace,
&worker_config.surrealdb_database,
)
.await
.unwrap(),
);
// Initialize worker components
let openai_client = Arc::new(async_openai::Client::with_config(
async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url),
));
// Create embedding provider based on config
let embedding_provider = Arc::new(
EmbeddingProvider::from_config(&config, Some(openai_client.clone()))
.await
.expect("failed to create embedding provider"),
);
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(
worker_db.clone(),
openai_client.clone(),
config.clone(),
reranker_pool.clone(),
storage.clone(),
embedding_provider,
)
.unwrap(),
);
info!("Starting worker process");
if let Err(e) = run_worker_loop(worker_db, ingestion_pipeline).await {
error!("Worker process error: {}", e);
} }
}); });
// Run the local set on the main thread
local.await; local.await;
// Wait for the server thread to finish (this likely won't be reached)
if let Err(e) = server_handle.join() { if let Err(e) = server_handle.join() {
error!("Server thread panicked: {:?}", e); error!("Server thread panicked: {:?}", e);
} }
@@ -253,52 +248,39 @@ mod tests {
let namespace = "test_ns"; let namespace = "test_ns";
let database = format!("test_db_{}", Uuid::new_v4()); let database = format!("test_db_{}", Uuid::new_v4());
let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4())); let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));
tokio::fs::create_dir_all(&data_dir).await
tokio::fs::create_dir_all(&data_dir)
.await
.expect("failed to create temp data directory"); .expect("failed to create temp data directory");
let config = smoke_test_config(namespace, &database, &data_dir); let config = smoke_test_config(namespace, &database, &data_dir);
let db = Arc::new( let db = Arc::new(SurrealDbClient::memory(namespace, &database).await?);
SurrealDbClient::memory(namespace, &database) db.apply_migrations().await?;
.await
.expect("failed to start in-memory surrealdb"),
);
db.apply_migrations()
.await
.expect("failed to apply migrations");
let session_store = Arc::new(db.create_session_store().await.expect("session store")); let session_store = Arc::new(db.create_session_store().await?);
let openai_client = Arc::new(async_openai::Client::with_config( let openai_client = Arc::new(async_openai::Client::with_config(
async_openai::config::OpenAIConfig::new() async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key) .with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url), .with_api_base(&config.openai_base_url),
)); ));
let storage = StorageManager::new(&config) let storage = StorageManager::new(&config).await?;
.await
.expect("failed to build storage manager");
// Use hashed embeddings for tests to avoid external dependencies
let embedding_provider = Arc::new( let embedding_provider = Arc::new(
common::utils::embedding::EmbeddingProvider::new_hashed(384) common::utils::embedding::EmbeddingProvider::new_hashed(384)?,
.expect("failed to create hashed embedding provider"),
); );
let html_state = HtmlState::new_with_resources( let html_state = HtmlState::new_with_resources(StateResources {
db.clone(), db: Arc::clone(&db),
openai_client, openai_client,
session_store, session_store,
storage.clone(), storage: storage.clone(),
config.clone(), config: config.clone(),
None, reranker_pool: None,
embedding_provider, embedding_provider,
None, template_engine: None,
) });
.await;
let api_state = ApiState { let api_state = ApiState {
db: db.clone(), db: Arc::clone(&db),
config: config.clone(), config: config.clone(),
storage, storage,
}; };
@@ -376,25 +358,22 @@ mod tests {
.oneshot( .oneshot(
Request::builder() Request::builder()
.uri("/api/v1/live") .uri("/api/v1/live")
.body(Body::empty()) .body(Body::empty())?,
.expect("request"),
) )
.await .await?;
.expect("router response");
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
let ready_response = app let ready_response = app
.oneshot( .oneshot(
Request::builder() Request::builder()
.uri("/api/v1/ready") .uri("/api/v1/ready")
.body(Body::empty()) .body(Body::empty())?,
.expect("request"),
) )
.await .await?;
.expect("ready response");
assert_eq!(ready_response.status(), StatusCode::OK); assert_eq!(ready_response.status(), StatusCode::OK);
tokio::fs::remove_dir_all(&data_dir).await.ok(); tokio::fs::remove_dir_all(&data_dir).await.ok();
Ok(())
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+10 -8
View File
@@ -6,7 +6,10 @@ use common::{
storage::{db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettings}, storage::{db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettings},
utils::{config::get_config, embedding::EmbeddingProvider}, utils::{config::get_config, embedding::EmbeddingProvider},
}; };
use html_router::{html_routes, html_state::HtmlState}; use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
};
use retrieval_pipeline::reranking::RerankerPool; use retrieval_pipeline::reranking::RerankerPool;
use tracing::info; use tracing::info;
use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use tracing_subscriber::{fmt, prelude::*, EnvFilter};
@@ -52,7 +55,7 @@ async fn main() -> anyhow::Result<()> {
// Create embedding provider based on config // Create embedding provider based on config
let embedding_provider = let embedding_provider =
Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?); Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
info!( info!(
embedding_backend = ?config.embedding_backend, embedding_backend = ?config.embedding_backend,
embedding_dimension = embedding_provider.dimension(), embedding_dimension = embedding_provider.dimension(),
@@ -63,17 +66,16 @@ async fn main() -> anyhow::Result<()> {
let (_settings, _dimensions_changed) = let (_settings, _dimensions_changed) =
SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?; SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?;
let html_state = HtmlState::new_with_resources( let html_state = HtmlState::new_with_resources(StateResources {
db, db,
openai_client, openai_client,
session_store, session_store,
storage.clone(), storage: storage.clone(),
config.clone(), config: config.clone(),
reranker_pool, reranker_pool,
embedding_provider, embedding_provider,
None, template_engine: None,
) });
.await;
let api_state = ApiState::new(&config, storage).await?; let api_state = ApiState::new(&config, storage).await?;
+3 -3
View File
@@ -42,7 +42,7 @@ async fn main() -> anyhow::Result<()> {
// Create embedding provider based on config // Create embedding provider based on config
let embedding_provider = let embedding_provider =
Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?); Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
info!( info!(
embedding_backend = ?config.embedding_backend, embedding_backend = ?config.embedding_backend,
"Embedding provider initialized for worker" "Embedding provider initialized for worker"
@@ -52,8 +52,8 @@ async fn main() -> anyhow::Result<()> {
let storage = StorageManager::new(&config).await?; let storage = StorageManager::new(&config).await?;
let ingestion_pipeline = Arc::new(IngestionPipeline::new( let ingestion_pipeline = Arc::new(IngestionPipeline::new(
db.clone(), Arc::clone(&db),
openai_client.clone(), Arc::clone(&openai_client),
config, config,
reranker_pool, reranker_pool,
storage, storage,
+4 -6
View File
@@ -118,18 +118,16 @@ pub fn create_chat_request(
} }
pub fn process_llm_response( pub fn process_llm_response(
response: CreateChatCompletionResponse, response: &CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, AppError> { ) -> Result<LLMResponseFormat, Box<AppError>> {
response response
.choices .choices
.first() .first()
.and_then(|choice| choice.message.content.as_ref()) .and_then(|choice| choice.message.content.as_ref())
.ok_or(AppError::LLMParsing( .ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
"No content found in LLM response".into(),
))
.and_then(|content| { .and_then(|content| {
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| { serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")) Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
}) })
}) })
} }
+16 -26
View File
@@ -20,7 +20,6 @@ use common::storage::{
/// * `entity_id` - ID of the entity to find neighbors for /// * `entity_id` - ID of the entity to find neighbors for
/// * `user_id` - User ID for access control /// * `user_id` - User ID for access control
/// * `limit` - Maximum number of neighbors to return /// * `limit` - Maximum number of neighbors to return
pub async fn find_entities_by_relationship_by_id( pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient, db: &SurrealDbClient,
entity_id: &str, entity_id: &str,
@@ -113,25 +112,23 @@ pub async fn find_entities_by_relationship_by_id(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::{self, Context};
use super::*; use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship; use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use uuid::Uuid; use uuid::Uuid;
#[tokio::test] #[tokio::test]
async fn test_find_entities_by_relationship_by_id() { async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
// Setup in-memory database for testing
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database)
.await .await
.expect("Failed to start in-memory surrealdb"); .with_context(|| "Failed to start in-memory surrealdb".to_string())?;
// Create some test entities
let entity_type = KnowledgeEntityType::Document; let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string(); let user_id = "user123".to_string();
// Create the central entity we'll query relationships for
let central_entity = KnowledgeEntity::new( let central_entity = KnowledgeEntity::new(
"central_source".to_string(), "central_source".to_string(),
"Central Entity".to_string(), "Central Entity".to_string(),
@@ -141,7 +138,6 @@ mod tests {
user_id.clone(), user_id.clone(),
); );
// Create related entities
let related_entity1 = KnowledgeEntity::new( let related_entity1 = KnowledgeEntity::new(
"related_source1".to_string(), "related_source1".to_string(),
"Related Entity 1".to_string(), "Related Entity 1".to_string(),
@@ -160,7 +156,6 @@ mod tests {
user_id.clone(), user_id.clone(),
); );
// Create an unrelated entity
let unrelated_entity = KnowledgeEntity::new( let unrelated_entity = KnowledgeEntity::new(
"unrelated_source".to_string(), "unrelated_source".to_string(),
"Unrelated Entity".to_string(), "Unrelated Entity".to_string(),
@@ -170,32 +165,29 @@ mod tests {
user_id.clone(), user_id.clone(),
); );
// Store all entities
let central_entity = db let central_entity = db
.store_item(central_entity.clone()) .store_item(central_entity.clone())
.await .await
.expect("Failed to store central entity") .with_context(|| "Failed to store central entity".to_string())?
.unwrap(); .ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
let related_entity1 = db let related_entity1 = db
.store_item(related_entity1.clone()) .store_item(related_entity1.clone())
.await .await
.expect("Failed to store related entity 1") .with_context(|| "Failed to store related entity 1".to_string())?
.unwrap(); .ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
let related_entity2 = db let related_entity2 = db
.store_item(related_entity2.clone()) .store_item(related_entity2.clone())
.await .await
.expect("Failed to store related entity 2") .with_context(|| "Failed to store related entity 2".to_string())?
.unwrap(); .ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
let _unrelated_entity = db let _unrelated_entity = db
.store_item(unrelated_entity.clone()) .store_item(unrelated_entity.clone())
.await .await
.expect("Failed to store unrelated entity") .with_context(|| "Failed to store unrelated entity".to_string())?
.unwrap(); .ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
// Create relationships
let source_id = "relationship_source".to_string(); let source_id = "relationship_source".to_string();
// Create relationship 1: central -> related1
let relationship1 = KnowledgeRelationship::new( let relationship1 = KnowledgeRelationship::new(
central_entity.id.clone(), central_entity.id.clone(),
related_entity1.id.clone(), related_entity1.id.clone(),
@@ -204,7 +196,6 @@ mod tests {
"references".to_string(), "references".to_string(),
); );
// Create relationship 2: central -> related2
let relationship2 = KnowledgeRelationship::new( let relationship2 = KnowledgeRelationship::new(
central_entity.id.clone(), central_entity.id.clone(),
related_entity2.id.clone(), related_entity2.id.clone(),
@@ -213,26 +204,25 @@ mod tests {
"contains".to_string(), "contains".to_string(),
); );
// Store relationships
relationship1 relationship1
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship 1"); .with_context(|| "Failed to store relationship 1".to_string())?;
relationship2 relationship2
.store_relationship(&db) .store_relationship(&db)
.await .await
.expect("Failed to store relationship 2"); .with_context(|| "Failed to store relationship 2".to_string())?;
// Test finding entities related to the central entity
let related_entities = let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX) find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await .await
.expect("Failed to find entities by relationship"); .with_context(|| "Failed to find entities by relationship".to_string())?;
// Check that we found relationships
assert!( assert!(
related_entities.len() >= 2, related_entities.len() >= 2,
"Should find related entities in both directions" "Should find related entities in both directions"
); );
Ok(())
} }
} }
+80 -100
View File
@@ -42,10 +42,14 @@ impl SearchResult {
} }
pub use pipeline::{ pub use pipeline::{
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig, retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
RetrievalStrategy, RetrievalTuning, SearchTarget, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
}; };
// Backward-compatible type aliases for external consumers
pub type PipelineDiagnostics = Diagnostics;
pub type PipelineStageTimings = StageTimings;
// Captures a supporting chunk plus its fused retrieval score for downstream prompts. // Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RetrievedChunk { pub struct RetrievedChunk {
@@ -61,7 +65,7 @@ pub struct RetrievedEntity {
pub chunks: Vec<RetrievedChunk>, pub chunks: Vec<RetrievedChunk>,
} }
/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text /// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
#[instrument(skip_all, fields(user_id))] #[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities( pub async fn retrieve_entities(
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
@@ -72,7 +76,7 @@ pub async fn retrieve_entities(
config: RetrievalConfig, config: RetrievalConfig,
reranker: Option<RerankerLease>, reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> { ) -> Result<StrategyOutput, AppError> {
pipeline::run_pipeline( let params = pipeline::StrategyParams {
db_client, db_client,
openai_client, openai_client,
embedding_provider, embedding_provider,
@@ -80,17 +84,16 @@ pub async fn retrieve_entities(
user_id, user_id,
config, config,
reranker, reranker,
) };
.await pipeline::execute(params).await
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use anyhow::{self};
use async_openai::Client; use async_openai::Client;
use common::storage::indexes::ensure_runtime_indexes; use common::storage::indexes::ensure_runtime;
use common::storage::types::text_chunk::TextChunk;
use pipeline::{RetrievalConfig, RetrievalStrategy};
use uuid::Uuid; use uuid::Uuid;
fn test_embedding() -> Vec<f32> { fn test_embedding() -> Vec<f32> {
@@ -105,27 +108,21 @@ mod tests {
vec![0.2, 0.8, 0.0] vec![0.2, 0.8, 0.0]
} }
async fn setup_test_db() -> SurrealDbClient { async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns"; let namespace = "test_ns";
let database = &Uuid::new_v4().to_string(); let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database) let db = SurrealDbClient::memory(namespace, database).await?;
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations() db.apply_migrations().await?;
.await
.expect("Failed to apply migrations");
ensure_runtime_indexes(&db, 3) ensure_runtime(&db, 3).await?;
.await
.expect("failed to build runtime indexes");
db Ok(db)
} }
#[tokio::test] #[tokio::test]
async fn test_default_strategy_retrieves_chunks() { async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "test_user"; let user_id = "test_user";
let chunk = TextChunk::new( let chunk = TextChunk::new(
"source_1".into(), "source_1".into(),
@@ -133,39 +130,38 @@ mod tests {
user_id.into(), user_id.into(),
); );
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
.await
.expect("Failed to store chunk");
let openai_client = Client::new(); let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding( let params = pipeline::StrategyParams {
&db, db_client: &db,
&openai_client, openai_client: &openai_client,
None, embedding_provider: None,
test_embedding(), input_text: "Rust concurrency async tasks",
"Rust concurrency async tasks",
user_id, user_id,
RetrievalConfig::default(), config: RetrievalConfig::default(),
None, reranker: None,
) };
.await let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.expect("Default strategy retrieval failed"); .await?;
let chunks = match results { let chunks = match results {
StrategyOutput::Chunks(items) => items, StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk results, got {:?}", other), other => anyhow::bail!("expected chunk results, got {other:?}"),
}; };
assert!(!chunks.is_empty(), "Expected at least one retrieval result"); assert!(!chunks.is_empty(), "Expected at least one retrieval result");
assert!( assert!(
chunks[0].chunk.chunk.contains("Tokio"), chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
"Expected chunk about Tokio" "Expected chunk about Tokio"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_default_strategy_returns_chunks_from_multiple_sources() { async fn test_default_strategy_returns_chunks_from_multiple_sources(
let db = setup_test_db().await; ) -> anyhow::Result<()> {
let db = setup_test_db().await?;
let user_id = "multi_source_user"; let user_id = "multi_source_user";
let primary_chunk = TextChunk::new( let primary_chunk = TextChunk::new(
@@ -179,30 +175,25 @@ mod tests {
user_id.into(), user_id.into(),
); );
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db) TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?;
.await TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
.expect("Failed to store primary chunk");
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db)
.await
.expect("Failed to store secondary chunk");
let openai_client = Client::new(); let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding( let params = pipeline::StrategyParams {
&db, db_client: &db,
&openai_client, openai_client: &openai_client,
None, embedding_provider: None,
test_embedding(), input_text: "Rust concurrency async tasks",
"Rust concurrency async tasks",
user_id, user_id,
RetrievalConfig::default(), config: RetrievalConfig::default(),
None, reranker: None,
) };
.await let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.expect("Default strategy retrieval failed"); .await?;
let chunks = match results { let chunks = match results {
StrategyOutput::Chunks(items) => items, StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk results, got {:?}", other), other => anyhow::bail!("expected chunk results, got {other:?}"),
}; };
assert!(chunks.len() >= 2, "Expected chunks from multiple sources"); assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
@@ -216,11 +207,12 @@ mod tests {
.any(|c| c.chunk.source_id == "secondary_source"), .any(|c| c.chunk.source_id == "secondary_source"),
"Should include secondary source chunk" "Should include secondary source chunk"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_revised_strategy_returns_chunks() { async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "chunk_user"; let user_id = "chunk_user";
let chunk_one = TextChunk::new( let chunk_one = TextChunk::new(
"src_alpha".into(), "src_alpha".into(),
@@ -233,31 +225,26 @@ mod tests {
user_id.into(), user_id.into(),
); );
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db) TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
.await TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
.expect("Failed to store chunk one");
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db)
.await
.expect("Failed to store chunk two");
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default); let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new(); let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding( let params = pipeline::StrategyParams {
&db, db_client: &db,
&openai_client, openai_client: &openai_client,
None, embedding_provider: None,
test_embedding(), input_text: "tokio runtime worker behavior",
"tokio runtime worker behavior",
user_id, user_id,
config, config,
None, reranker: None,
) };
.await let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.expect("Revised retrieval failed"); .await?;
let chunks = match results { let chunks = match results {
StrategyOutput::Chunks(items) => items, StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk output, got {:?}", other), other => anyhow::bail!("expected chunk results, got {other:?}"),
}; };
assert!( assert!(
@@ -270,11 +257,12 @@ mod tests {
.any(|entry| entry.chunk.chunk.contains("Tokio")), .any(|entry| entry.chunk.chunk.contains("Tokio")),
"Chunk results should contain relevant snippets" "Chunk results should contain relevant snippets"
); );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_search_strategy_returns_search_result() { async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
let db = setup_test_db().await; let db = setup_test_db().await?;
let user_id = "search_user"; let user_id = "search_user";
let chunk = TextChunk::new( let chunk = TextChunk::new(
"search_src".into(), "search_src".into(),
@@ -282,33 +270,24 @@ mod tests {
user_id.into(), user_id.into(),
); );
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
.await
.expect("Failed to store chunk");
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both); let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
let openai_client = Client::new(); let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding( let params = pipeline::StrategyParams {
&db, db_client: &db,
&openai_client, openai_client: &openai_client,
None, embedding_provider: None,
test_embedding(), input_text: "async rust programming",
"async rust programming",
user_id, user_id,
config, config,
None, reranker: None,
) };
.await let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
.expect("Search strategy retrieval failed"); .await?;
assert!( let StrategyOutput::Search(search_result) = results else {
matches!(results, StrategyOutput::Search(_)), anyhow::bail!("expected Search output");
"expected Search output, got {:?}",
results
);
let search_result = match results {
StrategyOutput::Search(sr) => sr,
_ => unreachable!(),
}; };
// Should return chunks (entities may be empty if none stored) // Should return chunks (entities may be empty if none stored)
@@ -323,5 +302,6 @@ mod tests {
.any(|c| c.chunk.chunk.contains("Tokio")), .any(|c| c.chunk.chunk.contains("Tokio")),
"Search results should contain relevant chunks" "Search results should contain relevant chunks"
); );
Ok(())
} }
} }
+91 -44
View File
@@ -1,12 +1,13 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt; use std::fmt;
use crate::scoring::FusionWeights; use crate::scoring::FusionWeights;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy { pub enum RetrievalStrategy {
/// Primary hybrid chunk retrieval for search/chat (formerly Revised) /// Primary hybrid chunk retrieval for search/chat (formerly Revised)
#[default]
Default, Default,
/// Entity retrieval for suggesting relationships when creating manual entities /// Entity retrieval for suggesting relationships when creating manual entities
RelationshipSuggestion, RelationshipSuggestion,
@@ -29,12 +30,6 @@ pub enum SearchTarget {
Both, Both,
} }
impl Default for RetrievalStrategy {
fn default() -> Self {
Self::Default
}
}
impl std::str::FromStr for RetrievalStrategy { impl std::str::FromStr for RetrievalStrategy {
type Err = String; type Err = String;
@@ -70,6 +65,91 @@ impl fmt::Display for RetrievalStrategy {
} }
} }
/// Two-variant flag that serializes as a bool for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BoolFlag {
#[default]
Disabled,
Enabled,
}
impl BoolFlag {
pub const fn as_bool(self) -> bool {
matches!(self, BoolFlag::Enabled)
}
}
impl From<bool> for BoolFlag {
fn from(value: bool) -> Self {
if value {
BoolFlag::Enabled
} else {
BoolFlag::Disabled
}
}
}
impl Serialize for BoolFlag {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_bool(self.as_bool())
}
}
impl<'de> Deserialize<'de> for BoolFlag {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
bool::deserialize(deserializer).map(|b| {
if b {
BoolFlag::Enabled
} else {
BoolFlag::Disabled
}
})
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RetrievalTuningFlags {
pub rerank_scores_only: BoolFlag,
pub normalize_vector_scores: BoolFlag,
pub normalize_fts_scores: BoolFlag,
pub chunk_rrf_use_vector: BoolFlag,
pub chunk_rrf_use_fts: BoolFlag,
}
impl RetrievalTuningFlags {
pub const fn rerank_scores_only(&self) -> bool {
self.rerank_scores_only.as_bool()
}
pub const fn normalize_vector_scores(&self) -> bool {
self.normalize_vector_scores.as_bool()
}
pub const fn normalize_fts_scores(&self) -> bool {
self.normalize_fts_scores.as_bool()
}
pub const fn chunk_rrf_use_vector(&self) -> bool {
self.chunk_rrf_use_vector.as_bool()
}
pub const fn chunk_rrf_use_fts(&self) -> bool {
self.chunk_rrf_use_fts.as_bool()
}
}
impl Default for RetrievalTuningFlags {
fn default() -> Self {
Self {
rerank_scores_only: BoolFlag::Disabled,
normalize_vector_scores: BoolFlag::Disabled,
normalize_fts_scores: BoolFlag::Enabled,
chunk_rrf_use_vector: BoolFlag::Enabled,
chunk_rrf_use_fts: BoolFlag::Enabled,
}
}
}
/// Tunable parameters that govern each retrieval stage. /// Tunable parameters that govern each retrieval stage.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning { pub struct RetrievalTuning {
@@ -89,15 +169,11 @@ pub struct RetrievalTuning {
pub graph_seed_min_score: f32, pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32, pub graph_vector_inheritance: f32,
pub rerank_blend_weight: f32, pub rerank_blend_weight: f32,
pub rerank_scores_only: bool, pub flags: RetrievalTuningFlags,
pub rerank_keep_top: usize, pub rerank_keep_top: usize,
pub chunk_result_cap: usize, pub chunk_result_cap: usize,
/// Optional fusion weights for hybrid search. If None, uses default weights. /// Optional fusion weights for hybrid search. If None, uses default weights.
pub fusion_weights: Option<FusionWeights>, pub fusion_weights: Option<FusionWeights>,
/// Normalize vector similarity scores before fusion (default: true)
pub normalize_vector_scores: bool,
/// Normalize FTS (BM25) scores before fusion (default: true)
pub normalize_fts_scores: bool,
/// Reciprocal rank fusion k value for chunk merging in Revised strategy. /// Reciprocal rank fusion k value for chunk merging in Revised strategy.
#[serde(default = "default_chunk_rrf_k")] #[serde(default = "default_chunk_rrf_k")]
pub chunk_rrf_k: f32, pub chunk_rrf_k: f32,
@@ -107,12 +183,6 @@ pub struct RetrievalTuning {
/// Weight applied to chunk FTS ranks in RRF. /// Weight applied to chunk FTS ranks in RRF.
#[serde(default = "default_chunk_rrf_fts_weight")] #[serde(default = "default_chunk_rrf_fts_weight")]
pub chunk_rrf_fts_weight: f32, pub chunk_rrf_fts_weight: f32,
/// Whether to include vector rankings in RRF.
#[serde(default = "default_chunk_rrf_use_vector")]
pub chunk_rrf_use_vector: bool,
/// Whether to include chunk FTS rankings in RRF.
#[serde(default = "default_chunk_rrf_use_fts")]
pub chunk_rrf_use_fts: bool,
} }
impl Default for RetrievalTuning { impl Default for RetrievalTuning {
@@ -134,26 +204,19 @@ impl Default for RetrievalTuning {
graph_seed_min_score: 0.4, graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6, graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65, rerank_blend_weight: 0.65,
rerank_scores_only: false, flags: RetrievalTuningFlags::default(),
rerank_keep_top: 8, rerank_keep_top: 8,
chunk_result_cap: 5, chunk_result_cap: 5,
fusion_weights: None, fusion_weights: None,
// Vector scores (cosine similarity) are already in [0,1] range
// Normalization only helps when there's significant variation
normalize_vector_scores: false,
// FTS scores (BM25) are unbounded, normalization helps more
normalize_fts_scores: true,
chunk_rrf_k: default_chunk_rrf_k(), chunk_rrf_k: default_chunk_rrf_k(),
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(), chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(), chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
} }
} }
} }
/// Wrapper containing tuning plus future flags for per-request overrides. /// Wrapper containing tuning plus future flags for per-request overrides.
#[derive(Debug, Clone)] #[derive(Debug, Clone, Default)]
pub struct RetrievalConfig { pub struct RetrievalConfig {
pub strategy: RetrievalStrategy, pub strategy: RetrievalStrategy,
pub tuning: RetrievalTuning, pub tuning: RetrievalTuning,
@@ -211,16 +274,6 @@ impl RetrievalConfig {
} }
} }
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
strategy: RetrievalStrategy::default(),
tuning: RetrievalTuning::default(),
search_target: SearchTarget::default(),
}
}
}
const fn default_chunk_rrf_k() -> f32 { const fn default_chunk_rrf_k() -> f32 {
60.0 60.0
} }
@@ -233,10 +286,4 @@ const fn default_chunk_rrf_fts_weight() -> f32 {
1.0 1.0
} }
const fn default_chunk_rrf_use_vector() -> bool {
true
}
const fn default_chunk_rrf_use_fts() -> bool {
true
}
@@ -2,7 +2,7 @@ use serde::Serialize;
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled. /// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
#[derive(Debug, Clone, Default, Serialize)] #[derive(Debug, Clone, Default, Serialize)]
pub struct PipelineDiagnostics { pub struct Diagnostics {
pub collect_candidates: Option<CollectCandidatesStats>, pub collect_candidates: Option<CollectCandidatesStats>,
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>, pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
pub assemble: Option<AssembleStats>, pub assemble: Option<AssembleStats>,
+66 -223
View File
@@ -3,10 +3,11 @@ mod diagnostics;
mod stages; mod stages;
mod strategies; mod strategies;
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning, SearchTarget}; pub use config::{
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
};
pub use diagnostics::{ pub use diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
PipelineDiagnostics,
}; };
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput}; use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
@@ -37,13 +38,13 @@ pub enum StageKind {
// Pipeline stage trait // Pipeline stage trait
#[async_trait] #[async_trait]
pub trait PipelineStage: Send + Sync { pub trait Stage: Send + Sync {
fn kind(&self) -> StageKind; fn kind(&self) -> StageKind;
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>; async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
} }
// Type alias for boxed stages // Type alias for boxed stages
pub type BoxedStage = Box<dyn PipelineStage>; pub type BoxedStage = Box<dyn Stage>;
// Strategy driver trait // Strategy driver trait
#[async_trait] #[async_trait]
@@ -51,16 +52,16 @@ pub trait StrategyDriver: Send + Sync {
type Output; type Output;
fn stages(&self) -> Vec<BoxedStage>; fn stages(&self) -> Vec<BoxedStage>;
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>; fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
} }
// Pipeline stage timings tracker // Pipeline stage timings tracker
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub struct PipelineStageTimings { pub struct StageTimings {
timings: Vec<(StageKind, Duration)>, timings: Vec<(StageKind, Duration)>,
} }
impl PipelineStageTimings { impl StageTimings {
pub fn record(&mut self, kind: StageKind, duration: Duration) { pub fn record(&mut self, kind: StageKind, duration: Duration) {
self.timings.push((kind, duration)); self.timings.push((kind, duration));
} }
@@ -74,8 +75,7 @@ impl PipelineStageTimings {
self.timings self.timings
.iter() .iter()
.find(|(k, _)| *k == kind) .find(|(k, _)| *k == kind)
.map(|(_, d)| d.as_millis()) .map_or(0, |(_, d)| d.as_millis())
.unwrap_or(0)
} }
pub fn embed_ms(&self) -> u128 { pub fn embed_ms(&self) -> u128 {
@@ -103,228 +103,100 @@ impl PipelineStageTimings {
} }
} }
pub struct PipelineRunOutput<T> { pub struct RunOutput<T> {
pub results: T, pub results: T,
pub diagnostics: Option<PipelineDiagnostics>, pub diagnostics: Option<Diagnostics>,
pub stage_timings: PipelineStageTimings, pub stage_timings: StageTimings,
} }
pub async fn run_pipeline( pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
db_client: &SurrealDbClient, let input_chars = params.input_text.chars().count();
openai_client: &Client<async_openai::config::OpenAIConfig>, let input_preview: String = params.input_text.chars().take(120).collect();
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " "); let input_preview_clean = input_preview.replace('\n', " ");
let preview_len = input_preview_clean.chars().count(); let preview_len = input_preview_clean.chars().count();
info!( info!(
%user_id, user_id = %params.user_id,
input_chars, input_chars,
preview_truncated = input_chars > preview_len, preview_truncated = input_chars > preview_len,
preview = %input_preview_clean, preview = %input_preview_clean,
strategy = %config.strategy, strategy = %params.config.strategy,
"Starting retrieval pipeline" "Starting retrieval pipeline"
); );
match config.strategy { let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => { RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, None, false).await?;
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Chunks(run.results)) Ok(StrategyOutput::Chunks(run.results))
} }
RetrievalStrategy::RelationshipSuggestion => { RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new(); let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, None, false).await?;
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results)) Ok(StrategyOutput::Entities(run.results))
} }
RetrievalStrategy::Ingestion => { RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new(); let driver = IngestionDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, None, false).await?;
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results)) Ok(StrategyOutput::Entities(run.results))
} }
RetrievalStrategy::Search => { RetrievalStrategy::Search => {
let search_target = config.search_target;
let driver = SearchStrategyDriver::new(search_target); let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy( let run = execute_strategy(driver, params, None, false).await?;
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Search(run.results)) Ok(StrategyOutput::Search(run.results))
} }
} }
} }
pub async fn run_pipeline_with_embedding( pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient, params: StrategyParams<'_>,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>, query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> { ) -> Result<StrategyOutput, AppError> {
match config.strategy { let strategy = params.config.strategy;
let search_target = params.config.search_target;
match strategy {
RetrievalStrategy::Default => { RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Chunks(run.results)) Ok(StrategyOutput::Chunks(run.results))
} }
RetrievalStrategy::RelationshipSuggestion => { RetrievalStrategy::RelationshipSuggestion => {
let driver = RelationshipSuggestionDriver::new(); let driver = RelationshipSuggestionDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results)) Ok(StrategyOutput::Entities(run.results))
} }
RetrievalStrategy::Ingestion => { RetrievalStrategy::Ingestion => {
let driver = IngestionDriver::new(); let driver = IngestionDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results)) Ok(StrategyOutput::Entities(run.results))
} }
RetrievalStrategy::Search => { RetrievalStrategy::Search => {
let search_target = config.search_target;
let driver = SearchStrategyDriver::new(search_target); let driver = SearchStrategyDriver::new(search_target);
let run = execute_strategy( let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Search(run.results)) Ok(StrategyOutput::Search(run.results))
} }
} }
} }
// Note: The metrics/diagnostics variants would follow the same pattern,
// but for brevity I'm only updating the main ones used by callers.
// If metrics/diagnostics are needed for non-chat strategies, they should be updated too.
// For now, I'll update them to support at least Initial/Revised as before.
pub async fn run_pipeline_with_embedding_with_metrics( pub async fn run_pipeline_with_embedding_with_metrics(
db_client: &SurrealDbClient, params: StrategyParams<'_>,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>, query_embedding: Vec<f32>,
input_text: &str, ) -> Result<RunOutput<StrategyOutput>, AppError> {
user_id: &str, let strategy = params.config.strategy;
config: RetrievalConfig,
reranker: Option<RerankerLease>, match strategy {
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy {
RetrievalStrategy::Default => { RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
driver, Ok(RunOutput {
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Chunks(run.results), results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics, diagnostics: run.diagnostics,
stage_timings: run.stage_timings, stage_timings: run.stage_timings,
}) })
} }
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
_ => Err(AppError::InternalError( _ => Err(AppError::InternalError(
"Metrics not supported for this strategy".into(), "Metrics not supported for this strategy".into(),
)), )),
@@ -332,32 +204,16 @@ pub async fn run_pipeline_with_embedding_with_metrics(
} }
pub async fn run_pipeline_with_embedding_with_diagnostics( pub async fn run_pipeline_with_embedding_with_diagnostics(
db_client: &SurrealDbClient, params: StrategyParams<'_>,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>, query_embedding: Vec<f32>,
input_text: &str, ) -> Result<RunOutput<StrategyOutput>, AppError> {
user_id: &str, let strategy = params.config.strategy;
config: RetrievalConfig,
reranker: Option<RerankerLease>, match strategy {
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy {
RetrievalStrategy::Default => { RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
driver, Ok(RunOutput {
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
true,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Chunks(run.results), results: StrategyOutput::Chunks(run.results),
diagnostics: run.diagnostics, diagnostics: run.diagnostics,
stage_timings: run.stage_timings, stage_timings: run.stage_timings,
@@ -391,38 +247,25 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
.collect::<Vec<_>>()) .collect::<Vec<_>>())
} }
pub struct StrategyParams<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
pub input_text: &'a str,
pub user_id: &'a str,
pub config: RetrievalConfig,
pub reranker: Option<RerankerLease>,
}
async fn execute_strategy<D: StrategyDriver>( async fn execute_strategy<D: StrategyDriver>(
driver: D, driver: D,
db_client: &SurrealDbClient, params: StrategyParams<'_>,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Option<Vec<f32>>, query_embedding: Option<Vec<f32>>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
capture_diagnostics: bool, capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> { ) -> Result<RunOutput<D::Output>, AppError> {
let ctx = match query_embedding { let ctx = match query_embedding {
Some(embedding) => PipelineContext::with_embedding( Some(embedding) => PipelineContext::with_embedding(params, embedding),
db_client, None => PipelineContext::new(params),
openai_client,
embedding_provider,
embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
),
None => PipelineContext::new(
db_client,
openai_client,
embedding_provider,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
),
}; };
run_with_driver(driver, ctx, capture_diagnostics).await run_with_driver(driver, ctx, capture_diagnostics).await
@@ -432,7 +275,7 @@ async fn run_with_driver<D: StrategyDriver>(
driver: D, driver: D,
mut ctx: PipelineContext<'_>, mut ctx: PipelineContext<'_>,
capture_diagnostics: bool, capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> { ) -> Result<RunOutput<D::Output>, AppError> {
if capture_diagnostics { if capture_diagnostics {
ctx.enable_diagnostics(); ctx.enable_diagnostics();
} }
@@ -445,9 +288,9 @@ async fn run_with_driver<D: StrategyDriver>(
let diagnostics = ctx.take_diagnostics(); let diagnostics = ctx.take_diagnostics();
let stage_timings = ctx.take_stage_timings(); let stage_timings = ctx.take_stage_timings();
let results = driver.finalize(&mut ctx)?; let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
Ok(PipelineRunOutput { Ok(RunOutput {
results, results,
diagnostics, diagnostics,
stage_timings, stage_timings,
+58 -85
View File
@@ -27,9 +27,9 @@ use super::{
config::{RetrievalConfig, RetrievalTuning}, config::{RetrievalConfig, RetrievalTuning},
diagnostics::{ diagnostics::{
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
PipelineDiagnostics, Diagnostics,
}, },
PipelineStage, PipelineStageTimings, StageKind, StageTimings, Stage, StageKind, StrategyParams,
}; };
pub struct PipelineContext<'a> { pub struct PipelineContext<'a> {
@@ -45,76 +45,51 @@ pub struct PipelineContext<'a> {
pub chunk_values: Vec<Scored<TextChunk>>, pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<TextChunk>>, pub revised_chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>, pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>, pub diagnostics: Option<Diagnostics>,
pub entity_results: Vec<RetrievedEntity>, pub entity_results: Vec<RetrievedEntity>,
pub chunk_results: Vec<RetrievedChunk>, pub chunk_results: Vec<RetrievedChunk>,
stage_timings: PipelineStageTimings, stage_timings: StageTimings,
} }
impl<'a> PipelineContext<'a> { impl<'a> PipelineContext<'a> {
pub fn new( pub fn new(params: StrategyParams<'a>) -> Self {
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
Self { Self {
db_client, db_client: params.db_client,
openai_client, openai_client: params.openai_client,
embedding_provider, embedding_provider: params.embedding_provider,
input_text, input_text: params.input_text.to_owned(),
user_id, user_id: params.user_id.to_owned(),
config, config: params.config,
query_embedding: None, query_embedding: None,
entity_candidates: HashMap::new(), entity_candidates: HashMap::new(),
filtered_entities: Vec::new(), filtered_entities: Vec::new(),
chunk_values: Vec::new(), chunk_values: Vec::new(),
revised_chunk_values: Vec::new(), revised_chunk_values: Vec::new(),
reranker, reranker: params.reranker,
diagnostics: None, diagnostics: None,
entity_results: Vec::new(), entity_results: Vec::new(),
chunk_results: Vec::new(), chunk_results: Vec::new(),
stage_timings: PipelineStageTimings::default(), stage_timings: StageTimings::default(),
} }
} }
pub fn with_embedding( pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec<f32>) -> Self {
db_client: &'a SurrealDbClient, let mut ctx = Self::new(params);
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
let mut ctx = Self::new(
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
reranker,
);
ctx.query_embedding = Some(query_embedding); ctx.query_embedding = Some(query_embedding);
ctx ctx
} }
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> { fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
self.query_embedding.as_ref().ok_or_else(|| { self.query_embedding.as_ref().ok_or_else(|| {
AppError::InternalError( Box::new(AppError::InternalError(
"query embedding missing before candidate collection".to_string(), "query embedding missing before candidate collection".to_string(),
) ))
}) })
} }
pub fn enable_diagnostics(&mut self) { pub fn enable_diagnostics(&mut self) {
if self.diagnostics.is_none() { if self.diagnostics.is_none() {
self.diagnostics = Some(PipelineDiagnostics::default()); self.diagnostics = Some(Diagnostics::default());
} }
} }
@@ -140,11 +115,11 @@ impl<'a> PipelineContext<'a> {
} }
} }
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> { pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
self.diagnostics.take() self.diagnostics.take()
} }
pub fn take_stage_timings(&mut self) -> PipelineStageTimings { pub fn take_stage_timings(&mut self) -> StageTimings {
std::mem::take(&mut self.stage_timings) std::mem::take(&mut self.stage_timings)
} }
@@ -165,7 +140,7 @@ impl<'a> PipelineContext<'a> {
pub struct EmbedStage; pub struct EmbedStage;
#[async_trait] #[async_trait]
impl PipelineStage for EmbedStage { impl Stage for EmbedStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::Embed StageKind::Embed
} }
@@ -179,7 +154,7 @@ impl PipelineStage for EmbedStage {
pub struct CollectCandidatesStage; pub struct CollectCandidatesStage;
#[async_trait] #[async_trait]
impl PipelineStage for CollectCandidatesStage { impl Stage for CollectCandidatesStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::CollectCandidates StageKind::CollectCandidates
} }
@@ -193,7 +168,7 @@ impl PipelineStage for CollectCandidatesStage {
pub struct GraphExpansionStage; pub struct GraphExpansionStage;
#[async_trait] #[async_trait]
impl PipelineStage for GraphExpansionStage { impl Stage for GraphExpansionStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::GraphExpansion StageKind::GraphExpansion
} }
@@ -207,7 +182,7 @@ impl PipelineStage for GraphExpansionStage {
pub struct RerankStage; pub struct RerankStage;
#[async_trait] #[async_trait]
impl PipelineStage for RerankStage { impl Stage for RerankStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::Rerank StageKind::Rerank
} }
@@ -221,7 +196,7 @@ impl PipelineStage for RerankStage {
pub struct AssembleEntitiesStage; pub struct AssembleEntitiesStage;
#[async_trait] #[async_trait]
impl PipelineStage for AssembleEntitiesStage { impl Stage for AssembleEntitiesStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::Assemble StageKind::Assemble
} }
@@ -235,7 +210,7 @@ impl PipelineStage for AssembleEntitiesStage {
pub struct ChunkVectorStage; pub struct ChunkVectorStage;
#[async_trait] #[async_trait]
impl PipelineStage for ChunkVectorStage { impl Stage for ChunkVectorStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::CollectCandidates StageKind::CollectCandidates
} }
@@ -249,7 +224,7 @@ impl PipelineStage for ChunkVectorStage {
pub struct ChunkRerankStage; pub struct ChunkRerankStage;
#[async_trait] #[async_trait]
impl PipelineStage for ChunkRerankStage { impl Stage for ChunkRerankStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::Rerank StageKind::Rerank
} }
@@ -263,7 +238,7 @@ impl PipelineStage for ChunkRerankStage {
pub struct ChunkAssembleStage; pub struct ChunkAssembleStage;
#[async_trait] #[async_trait]
impl PipelineStage for ChunkAssembleStage { impl Stage for ChunkAssembleStage {
fn kind(&self) -> StageKind { fn kind(&self) -> StageKind {
StageKind::Assemble StageKind::Assemble
} }
@@ -283,8 +258,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let embedding = if let Some(provider) = ctx.embedding_provider { let embedding = if let Some(provider) = ctx.embedding_provider {
provider.embed(&ctx.input_text).await.map_err(|e| { provider.embed(&ctx.input_text).await.map_err(|e| {
AppError::InternalError(format!( AppError::InternalError(format!(
"Failed to generate embedding with provider: {}", "Failed to generate embedding with provider: {e}",
e
)) ))
})? })?
} else { } else {
@@ -299,7 +273,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting initial candidates via vector and FTS search"); debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone(); let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning; let tuning = &ctx.config.tuning;
let weights = FusionWeights::default(); let weights = FusionWeights::default();
@@ -487,11 +461,11 @@ pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Collecting vector chunk candidates for revised strategy"); debug!("Collecting vector chunk candidates for revised strategy");
let embedding = ctx.ensure_embedding()?.clone(); let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
let tuning = &ctx.config.tuning; let tuning = &ctx.config.tuning;
let fts_take = tuning.chunk_fts_take; let fts_take = tuning.chunk_fts_take;
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text); let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty(); let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
let (vector_rows, fts_rows) = tokio::try_join!( let (vector_rows, fts_rows) = tokio::try_join!(
TextChunk::vector_search( TextChunk::vector_search(
@@ -532,8 +506,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
k: tuning.chunk_rrf_k, k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight, vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight, fts_weight,
use_vector: tuning.chunk_rrf_use_vector, use_vector: tuning.flags.chunk_rrf_use_vector(),
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0, use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
}; };
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config); let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
@@ -715,7 +689,7 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let mut per_entity_count = 0; let mut per_entity_count = 0;
for candidate in candidates.iter() { for candidate in candidates.iter() {
if let Some(trace) = entity_trace.as_mut() { if let Some(trace) = entity_trace.as_mut() {
trace.inspected_candidates += 1; trace.inspected_candidates = trace.inspected_candidates.saturating_add(1);
} }
if per_entity_count >= tuning.max_chunks_per_entity { if per_entity_count >= tuning.max_chunks_per_entity {
break; break;
@@ -723,17 +697,17 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
let estimated_tokens = let estimated_tokens =
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token); estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
if estimated_tokens > token_budget_remaining { if estimated_tokens > token_budget_remaining {
chunks_skipped_due_budget += 1; chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1);
if let Some(trace) = entity_trace.as_mut() { if let Some(trace) = entity_trace.as_mut() {
trace.skipped_due_budget += 1; trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1);
} }
continue; continue;
} }
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens); token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
tokens_spent += estimated_tokens; tokens_spent = tokens_spent.saturating_add(estimated_tokens);
per_entity_count += 1; per_entity_count = per_entity_count.saturating_add(1);
chunks_selected += 1; chunks_selected = chunks_selected.saturating_add(1);
selected_chunks.push(RetrievedChunk { selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(), chunk: candidate.item.clone(),
@@ -780,14 +754,14 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
const SCORE_SAMPLE_LIMIT: usize = 8; const SCORE_SAMPLE_LIMIT: usize = 8;
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32> fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where where
F: FnMut(&Scored<T>) -> f32, F: FnMut(&Scored<T>) -> f32,
{ {
items items
.iter() .iter()
.take(SCORE_SAMPLE_LIMIT) .take(SCORE_SAMPLE_LIMIT)
.map(|item| extractor(item)) .map(extractor)
.collect() .collect()
} }
@@ -912,7 +886,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect(); let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores); let normalized_scores = min_max_normalize(&raw_scores);
let use_only = ctx.config.tuning.rerank_scores_only; let use_only = ctx.config.tuning.flags.rerank_scores_only();
let blend = if use_only { let blend = if use_only {
1.0 1.0
} else { } else {
@@ -942,11 +916,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
} }
} }
for slot in remaining.into_iter() { reranked.extend(remaining.into_iter().flatten());
if let Some(candidate) = slot {
reranked.push(candidate);
}
}
ctx.filtered_entities = reranked; ctx.filtered_entities = reranked;
let keep_top = ctx.config.tuning.rerank_keep_top; let keep_top = ctx.config.tuning.rerank_keep_top;
@@ -970,7 +940,7 @@ fn apply_chunk_rerank_results(
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect(); let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores); let normalized_scores = min_max_normalize(&raw_scores);
let use_only = tuning.rerank_scores_only; let use_only = tuning.flags.rerank_scores_only();
let blend = if use_only { let blend = if use_only {
1.0 1.0
} else { } else {
@@ -1001,11 +971,7 @@ fn apply_chunk_rerank_results(
} }
} }
for slot in remaining.into_iter() { reranked.extend(remaining.into_iter().flatten());
if let Some(candidate) = slot {
reranked.push(candidate);
}
}
let keep_top = tuning.rerank_keep_top; let keep_top = tuning.rerank_keep_top;
if keep_top > 0 && reranked.len() > keep_top { if keep_top > 0 && reranked.len() > keep_top {
@@ -1017,7 +983,7 @@ fn apply_chunk_rerank_results(
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize { fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
let chars = text.chars().count().max(1); let chars = text.chars().count().max(1);
(chars / avg_chars_per_token).max(1) chars.checked_div(avg_chars_per_token).map_or(1, |v| v.max(1))
} }
fn rank_chunks_by_combined_score( fn rank_chunks_by_combined_score(
@@ -1053,13 +1019,20 @@ fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
return 0.0; return 0.0;
} }
let lower = haystack.to_ascii_lowercase(); let lower = haystack.to_ascii_lowercase();
let mut matches = 0usize; let mut matches: u32 = 0;
for term in terms { for term in terms {
if lower.contains(term) { if lower.contains(term) {
matches += 1; matches = matches.saturating_add(1);
} }
} }
(matches as f32) / (terms.len() as f32) let total = u32::try_from(terms.len()).unwrap_or(u32::MAX);
if total == 0 {
return 0.0;
}
let num = matches.min(total);
let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX);
let den_f32 = u16::try_from(total).map(f32::from).unwrap_or(f32::MAX);
num_f32 / den_f32
} }
#[derive(Clone)] #[derive(Clone)]
@@ -28,7 +28,7 @@ impl StrategyDriver for DefaultStrategyDriver {
] ]
} }
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> { fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_chunk_results()) Ok(ctx.take_chunk_results())
} }
} }
@@ -55,7 +55,7 @@ impl StrategyDriver for RelationshipSuggestionDriver {
] ]
} }
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> { fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results()) Ok(ctx.take_entity_results())
} }
} }
@@ -82,7 +82,7 @@ impl StrategyDriver for IngestionDriver {
] ]
} }
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> { fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
Ok(ctx.take_entity_results()) Ok(ctx.take_entity_results())
} }
} }
@@ -134,7 +134,7 @@ impl StrategyDriver for SearchStrategyDriver {
} }
} }
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> { fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
let chunks = match self.target { let chunks = match self.target {
SearchTarget::EntitiesOnly => Vec::new(), SearchTarget::EntitiesOnly => Vec::new(),
_ => ctx.take_chunk_results(), _ => ctx.take_chunk_results(),
+24 -26
View File
@@ -17,7 +17,7 @@ static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);
fn pick_engine_index(pool_len: usize) -> usize { fn pick_engine_index(pool_len: usize) -> usize {
let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed); let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
n % pool_len n.checked_rem(pool_len).unwrap_or(0)
} }
pub struct RerankerPool { pub struct RerankerPool {
@@ -28,30 +28,30 @@ pub struct RerankerPool {
impl RerankerPool { impl RerankerPool {
/// Build the pool at startup. /// Build the pool at startup.
/// `pool_size` controls max parallel reranks. /// `pool_size` controls max parallel reranks.
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> { pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
Self::new_with_options( let init_options =
pool_size, RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn), Self::new_with_options(pool_size, &init_options)
)
} }
fn new_with_options( fn new_with_options(
pool_size: usize, pool_size: usize,
init_options: RerankInitOptions, init_options: &RerankInitOptions,
) -> Result<Arc<Self>, AppError> { ) -> Result<Arc<Self>, Box<AppError>> {
if pool_size == 0 { if pool_size == 0 {
return Err(AppError::Validation( return Err(Box::new(AppError::Validation(
"RERANKING_POOL_SIZE must be greater than zero".to_string(), "RERANKING_POOL_SIZE must be greater than zero".to_string(),
)); )));
} }
fs::create_dir_all(&init_options.cache_dir)?; fs::create_dir_all(&init_options.cache_dir)
.map_err(|e| Box::new(AppError::from(e)))?;
let mut engines = Vec::with_capacity(pool_size); let mut engines = Vec::with_capacity(pool_size);
for x in 0..pool_size { for x in 0..pool_size {
debug!("Creating reranking engine: {x}"); debug!("Creating reranking engine: {x}");
let model = TextRerank::try_new(init_options.clone()) let model = TextRerank::try_new(init_options.clone())
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(|e| Box::new(AppError::InternalError(e.to_string())))?;
engines.push(Arc::new(Mutex::new(model))); engines.push(Arc::new(Mutex::new(model)));
} }
@@ -62,7 +62,7 @@ impl RerankerPool {
} }
/// Initialize a pool using application configuration. /// Initialize a pool using application configuration.
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> { pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, Box<AppError>> {
if !config.reranking_enabled { if !config.reranking_enabled {
return Ok(None); return Ok(None);
} }
@@ -70,30 +70,28 @@ impl RerankerPool {
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size); let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
let init_options = build_rerank_init_options(config)?; let init_options = build_rerank_init_options(config)?;
Self::new_with_options(pool_size, init_options).map(Some) Self::new_with_options(pool_size, &init_options).map(Some)
} }
/// Check out capacity + pick an engine. /// Check out capacity + pick an engine.
/// This returns a lease that can perform rerank(). /// This returns a lease that can perform `rerank()`.
pub async fn checkout(self: &Arc<Self>) -> RerankerLease { pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
// Acquire a permit. This enforces backpressure. // Acquire a permit. This enforces backpressure.
let permit = self let permit = Arc::clone(&self.semaphore)
.semaphore
.clone()
.acquire_owned() .acquire_owned()
.await .await
.expect("semaphore closed"); .ok()?;
// Pick an engine. // Pick an engine.
// This is naive: just pick based on a simple modulo counter. // This is naive: just pick based on a simple modulo counter.
// We use an atomic counter to avoid always choosing index 0. // We use an atomic counter to avoid always choosing index 0.
let idx = pick_engine_index(self.engines.len()); let idx = pick_engine_index(self.engines.len());
let engine = self.engines[idx].clone(); let engine = self.engines.get(idx).map(Arc::clone)?;
RerankerLease { Some(RerankerLease {
_permit: permit, _permit: permit,
engine, engine,
} })
} }
} }
@@ -111,7 +109,7 @@ fn is_truthy(value: &str) -> bool {
) )
} }
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> { fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Box<AppError>> {
let mut options = RerankInitOptions::default(); let mut options = RerankInitOptions::default();
let cache_dir = config let cache_dir = config
@@ -125,7 +123,7 @@ fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Ap
.join("fastembed") .join("fastembed")
.join("reranker") .join("reranker")
}); });
fs::create_dir_all(&cache_dir)?; fs::create_dir_all(&cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
options.cache_dir = cache_dir; options.cache_dir = cache_dir;
let show_progress = config let show_progress = config
@@ -150,7 +148,7 @@ fn env_bool(key: &str) -> Option<bool> {
env::var(key).ok().map(|value| is_truthy(&value)) env::var(key).ok().map(|value| is_truthy(&value))
} }
/// Active lease on a single TextRerank instance. /// Active lease on a single `TextRerank` instance.
pub struct RerankerLease { pub struct RerankerLease {
// When this drops the semaphore permit is released. // When this drops the semaphore permit is released.
_permit: OwnedSemaphorePermit, _permit: OwnedSemaphorePermit,
+14 -5
View File
@@ -28,16 +28,19 @@ impl<T> Scored<T> {
} }
} }
#[must_use]
pub const fn with_vector_score(mut self, score: f32) -> Self { pub const fn with_vector_score(mut self, score: f32) -> Self {
self.scores.vector = Some(score); self.scores.vector = Some(score);
self self
} }
#[must_use]
pub const fn with_fts_score(mut self, score: f32) -> Self { pub const fn with_fts_score(mut self, score: f32) -> Self {
self.scores.fts = Some(score); self.scores.fts = Some(score);
self self
} }
#[must_use]
pub const fn with_graph_score(mut self, score: f32) -> Self { pub const fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score); self.scores.graph = Some(score);
self self
@@ -168,7 +171,7 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
if scores.vector.is_some() && scores.fts.is_some() { if scores.vector.is_some() && scores.fts.is_some() {
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score // Multiplicative boost: multiply by (1 + bonus) to scale with the base score
// This ensures high-scoring golden chunks get boosted more than low-scoring ones // This ensures high-scoring golden chunks get boosted more than low-scoring ones
fused = fused * (1.0 + weights.multi_bonus); fused *= 1.0 + weights.multi_bonus;
} else { } else {
// For other multi-signal combinations (e.g., vector + graph), use additive bonus // For other multi-signal combinations (e.g., vector + graph), use additive bonus
fused += weights.multi_bonus; fused += weights.multi_bonus;
@@ -178,8 +181,8 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
clamp_unit(fused) clamp_unit(fused)
} }
pub fn merge_scored_by_id<T>( pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
target: &mut std::collections::HashMap<String, Scored<T>>, target: &mut std::collections::HashMap<String, Scored<T>, S>,
incoming: Vec<Scored<T>>, incoming: Vec<Scored<T>>,
) where ) where
T: StoredObject + Clone, T: StoredObject + Clone,
@@ -263,7 +266,10 @@ where
} }
} }
entry.item = candidate.item; entry.item = candidate.item;
entry.fused += vector_weight / (k + rank as f32 + 1.0); let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
entry.fused += vector_weight / (k + rank_f32 + 1.0);
} }
} }
@@ -290,7 +296,10 @@ where
} }
} }
entry.item = candidate.item; entry.item = candidate.item;
entry.fused += fts_weight / (k + rank as f32 + 1.0); let rank_f32: f32 = u16::try_from(rank)
.map(f32::from)
.unwrap_or(f32::MAX);
entry.fused += fts_weight / (k + rank_f32 + 1.0);
} }
} }