mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-28 10:29:30 +02:00
clippy: adhere to pedantic clippy, uniform test error handling
This commit is contained in:
+2
-2
@@ -106,11 +106,11 @@ missing_errors_doc = "allow"
|
||||
missing_panics_doc = "warn"
|
||||
module_name_repetitions = "warn"
|
||||
wildcard_dependencies = "warn"
|
||||
missing_docs_in_private_items = "warn"
|
||||
missing_docs_in_private_items = "allow"
|
||||
|
||||
# Allow noisy lints that don't add value for this project
|
||||
needless_raw_string_hashes = "allow"
|
||||
multiple_bound_locations = "allow"
|
||||
cargo_common_metadata = "allow"
|
||||
multiple-crate-versions = "allow"
|
||||
module_name_repetition = "allow"
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ impl ApiState {
|
||||
surreal_db_client.apply_migrations().await?;
|
||||
|
||||
let app_state = Self {
|
||||
db: surreal_db_client.clone(),
|
||||
db: Arc::clone(&surreal_db_client),
|
||||
config: config.clone(),
|
||||
storage,
|
||||
};
|
||||
|
||||
+23
-22
@@ -8,7 +8,7 @@ use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Serialize, Clone)]
|
||||
pub enum ApiError {
|
||||
pub enum ApiErr {
|
||||
#[error("Internal server error")]
|
||||
InternalError(String),
|
||||
|
||||
@@ -25,7 +25,7 @@ pub enum ApiError {
|
||||
PayloadTooLarge(String),
|
||||
}
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
impl From<AppError> for ApiErr {
|
||||
fn from(err: AppError) -> Self {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
||||
@@ -39,7 +39,7 @@ impl From<AppError> for ApiError {
|
||||
}
|
||||
}
|
||||
}
|
||||
impl IntoResponse for ApiError {
|
||||
impl IntoResponse for ApiErr {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_response) = match self {
|
||||
Self::InternalError(message) => (
|
||||
@@ -94,6 +94,7 @@ mod tests {
|
||||
use super::*;
|
||||
use common::error::AppError;
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
|
||||
// Helper to check status code
|
||||
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
|
||||
@@ -105,42 +106,42 @@ mod tests {
|
||||
fn test_app_error_to_api_error_conversion() {
|
||||
// Test NotFound error conversion
|
||||
let not_found = AppError::NotFound("resource not found".to_string());
|
||||
let api_error = ApiError::from(not_found);
|
||||
assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found"));
|
||||
let api_error = ApiErr::from(not_found);
|
||||
assert!(matches!(api_error, ApiErr::NotFound(msg) if msg == "resource not found"));
|
||||
|
||||
// Test Validation error conversion
|
||||
let validation = AppError::Validation("invalid input".to_string());
|
||||
let api_error = ApiError::from(validation);
|
||||
assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input"));
|
||||
let api_error = ApiErr::from(validation);
|
||||
assert!(matches!(api_error, ApiErr::ValidationError(msg) if msg == "invalid input"));
|
||||
|
||||
// Test Auth error conversion
|
||||
let auth = AppError::Auth("unauthorized".to_string());
|
||||
let api_error = ApiError::from(auth);
|
||||
assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized"));
|
||||
let api_error = ApiErr::from(auth);
|
||||
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
|
||||
|
||||
// Test for internal errors - create a mock error that doesn't require surrealdb
|
||||
let internal_error =
|
||||
AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error"));
|
||||
let api_error = ApiError::from(internal_error);
|
||||
assert!(matches!(api_error, ApiError::InternalError(_)));
|
||||
AppError::Io(io::Error::other("io error"));
|
||||
let api_error = ApiErr::from(internal_error);
|
||||
assert!(matches!(api_error, ApiErr::InternalError(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_error_response_status_codes() {
|
||||
// Test internal error status
|
||||
let error = ApiError::InternalError("server error".to_string());
|
||||
let error = ApiErr::InternalError("server error".to_string());
|
||||
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
// Test not found status
|
||||
let error = ApiError::NotFound("not found".to_string());
|
||||
let error = ApiErr::NotFound("not found".to_string());
|
||||
assert_status_code(error, StatusCode::NOT_FOUND);
|
||||
|
||||
// Test validation error status
|
||||
let error = ApiError::ValidationError("invalid input".to_string());
|
||||
let error = ApiErr::ValidationError("invalid input".to_string());
|
||||
assert_status_code(error, StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test unauthorized status
|
||||
let error = ApiError::Unauthorized("not allowed".to_string());
|
||||
let error = ApiErr::Unauthorized("not allowed".to_string());
|
||||
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
||||
|
||||
// Test payload too large status
|
||||
@@ -153,15 +154,15 @@ mod tests {
|
||||
fn test_error_messages() {
|
||||
// For validation errors
|
||||
let message = "invalid data format";
|
||||
let error = ApiError::ValidationError(message.to_string());
|
||||
let error = ApiErr::ValidationError(message.to_string());
|
||||
|
||||
// Check that the error itself contains the message
|
||||
assert_eq!(error.to_string(), format!("Validation error: {}", message));
|
||||
assert_eq!(error.to_string(), format!("Validation error: {message}"));
|
||||
|
||||
// For not found errors
|
||||
let message = "user not found";
|
||||
let error = ApiError::NotFound(message.to_string());
|
||||
assert_eq!(error.to_string(), format!("Not found: {}", message));
|
||||
let error = ApiErr::NotFound(message.to_string());
|
||||
assert_eq!(error.to_string(), format!("Not found: {message}"));
|
||||
}
|
||||
|
||||
// Alternative approach for internal error test
|
||||
@@ -170,8 +171,8 @@ mod tests {
|
||||
// Create a sensitive error message
|
||||
let sensitive_info = "db password incorrect";
|
||||
|
||||
// Create ApiError with sensitive info
|
||||
let api_error = ApiError::InternalError(sensitive_info.to_string());
|
||||
// Create ApiErr with sensitive info
|
||||
let api_error = ApiErr::InternalError(sensitive_info.to_string());
|
||||
|
||||
// Check the error message is correctly set
|
||||
assert_eq!(api_error.to_string(), "Internal server error");
|
||||
|
||||
@@ -6,19 +6,19 @@ use axum::{
|
||||
|
||||
use common::storage::types::user::User;
|
||||
|
||||
use crate::{api_state::ApiState, error::ApiError};
|
||||
use crate::{api_state::ApiState, error::ApiErr};
|
||||
|
||||
pub async fn api_auth(
|
||||
State(state): State<ApiState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
) -> Result<Response, ApiErr> {
|
||||
let api_key = extract_api_key(&request)
|
||||
.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
||||
let user =
|
||||
user.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use axum::{extract::State, response::IntoResponse, Extension, Json};
|
||||
use common::storage::types::user::User;
|
||||
|
||||
use crate::{api_state::ApiState, error::ApiError};
|
||||
use crate::{api_state::ApiState, error::ApiErr};
|
||||
|
||||
pub async fn get_categories(
|
||||
pub async fn list(
|
||||
State(state): State<ApiState>,
|
||||
Extension(user): Extension<User>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
) -> Result<impl IntoResponse, ApiErr> {
|
||||
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
Ok(Json(categories))
|
||||
|
||||
@@ -13,7 +13,7 @@ use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{api_state::ApiState, error::ApiError};
|
||||
use crate::{api_state::ApiState, error::ApiErr};
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct IngestParams {
|
||||
|
||||
+35
-28
@@ -202,6 +202,7 @@ impl SurrealDbClient {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use crate::stored_object;
|
||||
|
||||
use super::*;
|
||||
@@ -212,19 +213,17 @@ mod tests {
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialization_and_crud() {
|
||||
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Call your initialization
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
|
||||
// Test basic CRUD
|
||||
let dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
name: "first".to_string(),
|
||||
@@ -232,50 +231,50 @@ mod tests {
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
// Store
|
||||
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
|
||||
let stored = db
|
||||
.store_item(dummy.clone())
|
||||
.await
|
||||
.with_context(|| "Failed to store".to_string())?;
|
||||
assert!(stored.is_some());
|
||||
|
||||
// Read
|
||||
let fetched = db
|
||||
.get_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to fetch");
|
||||
.with_context(|| "Failed to fetch".to_string())?;
|
||||
assert_eq!(fetched, Some(dummy.clone()));
|
||||
|
||||
// Read all
|
||||
let all = db
|
||||
.get_all_stored_items::<Dummy>()
|
||||
.await
|
||||
.expect("Failed to fetch all");
|
||||
.with_context(|| "Failed to fetch all".to_string())?;
|
||||
assert!(all.contains(&dummy));
|
||||
|
||||
// Delete
|
||||
let deleted = db
|
||||
.delete_item::<Dummy>(&dummy.id)
|
||||
.await
|
||||
.expect("Failed to delete");
|
||||
.with_context(|| "Failed to delete".to_string())?;
|
||||
assert_eq!(deleted, Some(dummy));
|
||||
|
||||
// After delete, should not be present
|
||||
let fetch_post = db
|
||||
.get_item::<Dummy>("abc")
|
||||
.await
|
||||
.expect("Failed fetch post delete");
|
||||
.with_context(|| "Failed fetch post delete".to_string())?;
|
||||
assert!(fetch_post.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_item_overwrites_existing_records() {
|
||||
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
|
||||
let mut dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
@@ -286,17 +285,21 @@ mod tests {
|
||||
|
||||
db.store_item(dummy.clone())
|
||||
.await
|
||||
.expect("Failed to store initial record");
|
||||
.with_context(|| "Failed to store initial record".to_string())?;
|
||||
|
||||
dummy.name = "updated".to_string();
|
||||
let upserted = db
|
||||
.upsert_item(dummy.clone())
|
||||
.await
|
||||
.expect("Failed to upsert record");
|
||||
.with_context(|| "Failed to upsert record".to_string())?;
|
||||
assert!(upserted.is_some());
|
||||
|
||||
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert");
|
||||
assert_eq!(fetched.unwrap().name, "updated");
|
||||
let fetched: Option<Dummy> = db
|
||||
.get_item(&dummy.id)
|
||||
.await
|
||||
.with_context(|| "fetch after upsert".to_string())?;
|
||||
let fetched = fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?;
|
||||
assert_eq!(fetched.name, "updated");
|
||||
|
||||
let new_record = Dummy {
|
||||
id: "def".to_string(),
|
||||
@@ -306,25 +309,29 @@ mod tests {
|
||||
};
|
||||
db.upsert_item(new_record.clone())
|
||||
.await
|
||||
.expect("Failed to upsert new record");
|
||||
.with_context(|| "Failed to upsert new record".to_string())?;
|
||||
|
||||
let fetched_new: Option<Dummy> = db
|
||||
.get_item(&new_record.id)
|
||||
.await
|
||||
.expect("fetch inserted via upsert");
|
||||
.with_context(|| "fetch inserted via upsert".to_string())?;
|
||||
assert_eq!(fetched_new, Some(new_record));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_applying_migrations() {
|
||||
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to build indexes");
|
||||
.with_context(|| "Failed to build indexes".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,23 +159,23 @@ impl FtsIndexSpec {
|
||||
|
||||
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
||||
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
||||
pub async fn ensure_runtime_indexes(
|
||||
pub async fn ensure_runtime(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
ensure_runtime_indexes_inner(db, embedding_dimension)
|
||||
ensure_runtime_inner(db, embedding_dimension)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_indexes_inner(db)
|
||||
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
async fn ensure_runtime_indexes_inner(
|
||||
async fn ensure_runtime_inner(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<()> {
|
||||
@@ -262,9 +262,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
||||
.context("checking index status")?;
|
||||
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
|
||||
|
||||
let info = match info {
|
||||
Some(i) => i,
|
||||
None => return Ok("unknown".to_string()),
|
||||
let Some(info) = info else {
|
||||
return Ok("unknown".to_string());
|
||||
};
|
||||
|
||||
let building = info.get("building");
|
||||
@@ -277,7 +276,7 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
debug!("Rebuilding indexes with concurrent definitions");
|
||||
create_fts_analyzer(db).await?;
|
||||
|
||||
@@ -385,10 +384,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
||||
// an existing analyzer definition.
|
||||
let snowball_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
FILTERS lowercase, ascii, snowball(english);"
|
||||
);
|
||||
|
||||
match db.client.query(snowball_query).await {
|
||||
@@ -410,10 +408,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||
}
|
||||
|
||||
let fallback_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii;",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
FILTERS lowercase, ascii;"
|
||||
);
|
||||
|
||||
let res = db
|
||||
@@ -446,6 +443,7 @@ async fn create_index_with_polling(
|
||||
table: &str,
|
||||
progress_table: Option<&str>,
|
||||
) -> Result<()> {
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
let expected_total = match progress_table {
|
||||
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
||||
format!("counting rows in {table} for index {index_name} progress")
|
||||
@@ -453,10 +451,9 @@ async fn create_index_with_polling(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
let mut attempts: usize = 0;
|
||||
loop {
|
||||
attempts += 1;
|
||||
attempts = attempts.saturating_add(1);
|
||||
let res = db
|
||||
.client
|
||||
.query(definition.clone())
|
||||
@@ -527,8 +524,8 @@ async fn poll_index_build_status(
|
||||
break;
|
||||
};
|
||||
|
||||
match snapshot.progress_pct {
|
||||
Some(pct) => debug!(
|
||||
if let Some(pct) = snapshot.progress_pct {
|
||||
debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
@@ -539,8 +536,9 @@ async fn poll_index_build_status(
|
||||
total = snapshot.total_rows,
|
||||
progress_pct = format_args!("{pct:.1}"),
|
||||
"Index build status"
|
||||
),
|
||||
None => debug!(
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
@@ -549,7 +547,7 @@ async fn poll_index_build_status(
|
||||
updated = snapshot.updated,
|
||||
processed = snapshot.processed,
|
||||
"Index build status"
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if snapshot.is_ready() {
|
||||
@@ -611,17 +609,17 @@ fn parse_index_build_info(
|
||||
|
||||
let initial = building
|
||||
.and_then(|b| b.get("initial"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let pending = building
|
||||
.and_then(|b| b.get("pending"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let updated = building
|
||||
.and_then(|b| b.get("updated"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
|
||||
@@ -631,7 +629,7 @@ fn parse_index_build_info(
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
((processed as f64 / total as f64).min(1.0)) * 100.0
|
||||
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX)) / f64::from(u32::try_from(total).unwrap_or(1))).min(1.0)) * 100.0
|
||||
}
|
||||
});
|
||||
|
||||
@@ -673,7 +671,7 @@ async fn table_index_definitions(
|
||||
.client
|
||||
.query(info_query)
|
||||
.await
|
||||
.with_context(|| format!("fetching table info for {}", table))?;
|
||||
.with_context(|| format!("fetching table info for {table}"))?;
|
||||
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
@@ -700,12 +698,15 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self, Context};
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_reports_progress() {
|
||||
fn parse_index_build_info_reports_progress() -> anyhow::Result<()> {
|
||||
let info = json!({
|
||||
"building": {
|
||||
"initial": 56894,
|
||||
@@ -715,7 +716,8 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
|
||||
let snapshot = parse_index_build_info(Some(info), Some(61081))
|
||||
.context("snapshot")?;
|
||||
assert_eq!(
|
||||
snapshot,
|
||||
IndexBuildSnapshot {
|
||||
@@ -729,16 +731,19 @@ mod tests {
|
||||
}
|
||||
);
|
||||
assert!(!snapshot.is_ready());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
|
||||
fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> {
|
||||
// Surreal returns `{}` when the index exists but isn't building.
|
||||
let info = json!({});
|
||||
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
|
||||
let snapshot = parse_index_build_info(Some(info), Some(10))
|
||||
.context("snapshot")?;
|
||||
assert!(snapshot.is_ready());
|
||||
assert_eq!(snapshot.processed, 0);
|
||||
assert_eq!(snapshot.progress_pct, Some(0.0));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -748,48 +753,40 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_runtime_indexes_is_idempotent() {
|
||||
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db");
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should succeed");
|
||||
.context("migrations should succeed")?;
|
||||
|
||||
// First run creates everything
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("initial index creation");
|
||||
|
||||
// Second run should be a no-op and still succeed
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("second index creation");
|
||||
ensure_runtime(&db, 1536).await
|
||||
.context("first call should succeed")?;
|
||||
ensure_runtime(&db, 1536).await
|
||||
.context("second index creation")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_hnsw_index_overwrites_dimension() {
|
||||
async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_dim";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db");
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should succeed");
|
||||
.context("migrations should succeed")?;
|
||||
|
||||
// Create initial index with default dimension
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("initial index creation");
|
||||
|
||||
// Change dimension and ensure overwrite path is exercised
|
||||
ensure_runtime_indexes(&db, 128)
|
||||
.await
|
||||
.expect("overwritten index creation");
|
||||
ensure_runtime(&db, 1536).await
|
||||
.context("initial index creation")?;
|
||||
ensure_runtime(&db, 128).await
|
||||
.context("overwritten index creation")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+142
-108
@@ -13,13 +13,13 @@ use object_store::{path::Path as ObjPath, ObjectStore};
|
||||
|
||||
use crate::utils::config::{AppConfig, StorageKind};
|
||||
|
||||
pub type DynStore = Arc<dyn ObjectStore>;
|
||||
pub type DynStorage = Arc<dyn ObjectStore>;
|
||||
|
||||
/// Storage manager with persistent state and proper lifecycle management.
|
||||
#[derive(Clone)]
|
||||
pub struct StorageManager {
|
||||
// Store from objectstore wrapped as dyn
|
||||
store: DynStore,
|
||||
store: DynStorage,
|
||||
// Simple enum to track which kind
|
||||
backend_kind: StorageKind,
|
||||
// Where on disk
|
||||
@@ -46,7 +46,7 @@ impl StorageManager {
|
||||
///
|
||||
/// This method is useful for testing scenarios where you want to inject
|
||||
/// a specific storage backend.
|
||||
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
|
||||
pub fn with_backend(store: DynStorage, backend_kind: StorageKind) -> Self {
|
||||
Self {
|
||||
store,
|
||||
backend_kind,
|
||||
@@ -216,7 +216,7 @@ impl StorageManager {
|
||||
/// storage backends with proper error handling and validation.
|
||||
async fn create_storage_backend(
|
||||
cfg: &AppConfig,
|
||||
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
|
||||
) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
|
||||
match cfg.storage {
|
||||
StorageKind::Local => {
|
||||
let base = resolve_base_dir(cfg);
|
||||
@@ -261,9 +261,7 @@ async fn create_storage_backend(
|
||||
builder = builder.with_endpoint(endpoint);
|
||||
}
|
||||
|
||||
if let Some(region) = &cfg.s3_region {
|
||||
builder = builder.with_region(region);
|
||||
}
|
||||
builder = builder.with_region(&cfg.s3_region);
|
||||
|
||||
let store = builder.build()?;
|
||||
Ok((Arc::new(store), None))
|
||||
@@ -342,7 +340,7 @@ pub mod testing {
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: base.into(),
|
||||
data_dir: base,
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
@@ -382,7 +380,7 @@ pub mod testing {
|
||||
#[derive(Clone)]
|
||||
pub struct TestStorageManager {
|
||||
storage: StorageManager,
|
||||
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||
temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||
}
|
||||
|
||||
impl TestStorageManager {
|
||||
@@ -396,7 +394,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
temp_dir: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -413,7 +411,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: resolved,
|
||||
temp_dir: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -437,7 +435,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
temp_dir: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -454,7 +452,7 @@ pub mod testing {
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: temp_dir,
|
||||
temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -508,7 +506,7 @@ pub mod testing {
|
||||
impl Drop for TestStorageManager {
|
||||
fn drop(&mut self) {
|
||||
// Clean up temporary directories for local storage
|
||||
if let Some((_, path)) = &self._temp_dir {
|
||||
if let Some((_, path)) = &self.temp_dir {
|
||||
if path.exists() {
|
||||
let _ = std::fs::remove_dir_all(path);
|
||||
}
|
||||
@@ -584,6 +582,7 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Context;
|
||||
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
||||
use bytes::Bytes;
|
||||
use uuid::Uuid;
|
||||
@@ -623,11 +622,11 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_basic_operations() {
|
||||
async fn test_storage_manager_memory_basic_operations() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
assert!(storage.local_base_path().is_none());
|
||||
|
||||
let location = "test/data/file.txt";
|
||||
@@ -637,31 +636,33 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
.with_context(|| "exists check after delete".to_string())?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_local_basic_operations() {
|
||||
async fn test_storage_manager_local_basic_operations() -> anyhow::Result<()> {
|
||||
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
|
||||
let cfg = test_config(&base);
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
let resolved_base = storage
|
||||
.local_base_path()
|
||||
.expect("resolved base dir")
|
||||
.with_context(|| "resolved base dir".to_string())?
|
||||
.to_path_buf();
|
||||
assert_eq!(resolved_base, PathBuf::from(&base));
|
||||
|
||||
@@ -672,42 +673,44 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
let object_dir = resolved_base.join("test/data");
|
||||
tokio::fs::metadata(&object_dir)
|
||||
.await
|
||||
.expect("object directory exists after write");
|
||||
.with_context(|| "object directory exists after write".to_string())?;
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
assert!(storage.exists(location).await.with_context(|| "exists check".to_string())?);
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
storage.delete_prefix("test/data/").await.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
.with_context(|| "exists check after delete".to_string())?);
|
||||
assert!(
|
||||
tokio::fs::metadata(&object_dir).await.is_err(),
|
||||
"object directory should be removed"
|
||||
);
|
||||
tokio::fs::metadata(&resolved_base)
|
||||
.await
|
||||
.expect("base directory remains intact");
|
||||
.with_context(|| "base directory remains intact".to_string())?;
|
||||
|
||||
// Clean up
|
||||
let _ = tokio::fs::remove_dir_all(&base).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_persistence() {
|
||||
async fn test_storage_manager_memory_persistence() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
let location = "persistence/test.txt";
|
||||
let data1 = b"first data";
|
||||
@@ -717,32 +720,34 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data1.to_vec()))
|
||||
.await
|
||||
.expect("put first");
|
||||
.with_context(|| "put first".to_string())?;
|
||||
|
||||
// Retrieve and verify first data
|
||||
let retrieved1 = storage.get(location).await.expect("get first");
|
||||
let retrieved1 = storage.get(location).await.with_context(|| "get first".to_string())?;
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
|
||||
// Overwrite with second data
|
||||
storage
|
||||
.put(location, Bytes::from(data2.to_vec()))
|
||||
.await
|
||||
.expect("put second");
|
||||
.with_context(|| "put second".to_string())?;
|
||||
|
||||
// Retrieve and verify second data
|
||||
let retrieved2 = storage.get(location).await.expect("get second");
|
||||
let retrieved2 = storage.get(location).await.with_context(|| "get second".to_string())?;
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
|
||||
// Data persists across multiple operations using the same StorageManager
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_list_operations() {
|
||||
async fn test_storage_manager_list_operations() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
// Create multiple files
|
||||
let files = vec![
|
||||
@@ -755,15 +760,15 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
.with_context(|| "put".to_string())?;
|
||||
}
|
||||
|
||||
// Test listing without prefix
|
||||
let all_files = storage.list(None).await.expect("list all");
|
||||
let all_files = storage.list(None).await.with_context(|| "list all".to_string())?;
|
||||
assert_eq!(all_files.len(), 3);
|
||||
|
||||
// Test listing with prefix
|
||||
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
|
||||
let dir1_files = storage.list(Some("dir1/")).await.with_context(|| "list dir1".to_string())?;
|
||||
assert_eq!(dir1_files.len(), 2);
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
@@ -776,16 +781,18 @@ mod tests {
|
||||
let empty_files = storage
|
||||
.list(Some("nonexistent/"))
|
||||
.await
|
||||
.expect("list nonexistent");
|
||||
.with_context(|| "list nonexistent".to_string())?;
|
||||
assert_eq!(empty_files.len(), 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_stream_operations() {
|
||||
async fn test_storage_manager_stream_operations() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
let location = "stream/test.bin";
|
||||
let content = vec![42u8; 1024 * 64]; // 64KB of data
|
||||
@@ -794,22 +801,24 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(content.clone()))
|
||||
.await
|
||||
.expect("put large data");
|
||||
.with_context(|| "put large data".to_string())?;
|
||||
|
||||
// Get as stream
|
||||
let mut stream = storage.get_stream(location).await.expect("get stream");
|
||||
let mut stream = storage.get_stream(location).await.with_context(|| "get stream".to_string())?;
|
||||
let mut collected = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.expect("stream chunk");
|
||||
let chunk = chunk.with_context(|| "stream chunk".to_string())?;
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
assert_eq!(collected, content);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_with_custom_backend() {
|
||||
async fn test_storage_manager_with_custom_backend() -> anyhow::Result<()> {
|
||||
use object_store::memory::InMemory;
|
||||
|
||||
// Create custom memory backend
|
||||
@@ -823,20 +832,22 @@ mod tests {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
assert!(storage.exists(location).await.expect("exists"));
|
||||
assert!(storage.exists(location).await.with_context(|| "exists".to_string())?);
|
||||
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_error_handling() {
|
||||
async fn test_storage_manager_error_handling() -> anyhow::Result<()> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
.with_context(|| "create storage manager".to_string())?;
|
||||
|
||||
// Test getting non-existent file
|
||||
let result = storage.get("nonexistent.txt").await;
|
||||
@@ -846,124 +857,136 @@ mod tests {
|
||||
let exists = storage
|
||||
.exists("nonexistent.txt")
|
||||
.await
|
||||
.expect("exists check");
|
||||
.with_context(|| "exists check".to_string())?;
|
||||
assert!(!exists);
|
||||
|
||||
// Test listing with invalid location (should not panic)
|
||||
let _result = storage.get("").await;
|
||||
// This may or may not error depending on the backend implementation
|
||||
// The important thing is that it doesn't panic
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TestStorageManager tests
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_memory() {
|
||||
async fn test_test_storage_manager_memory() -> anyhow::Result<()> {
|
||||
let test_storage = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
let location = "test/storage/file.txt";
|
||||
let data = b"test data with TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
test_storage.put(location, data).await.with_context(|| "put".to_string())?;
|
||||
let retrieved = test_storage.get(location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
assert!(test_storage.exists(location).await.with_context(|| "exists".to_string())?);
|
||||
|
||||
// Test list
|
||||
let files = test_storage
|
||||
.list(Some("test/storage/"))
|
||||
.await
|
||||
.expect("list");
|
||||
.with_context(|| "list".to_string())?;
|
||||
assert_eq!(files.len(), 1);
|
||||
|
||||
// Test delete
|
||||
test_storage
|
||||
.delete_prefix("test/storage/")
|
||||
.await
|
||||
.expect("delete");
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists after delete"));
|
||||
.with_context(|| "exists after delete".to_string())?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_local() {
|
||||
async fn test_test_storage_manager_local() -> anyhow::Result<()> {
|
||||
let test_storage = testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
let location = "test/local/file.txt";
|
||||
let data = b"test data with local TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
test_storage.put(location, data).await
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = test_storage.get(location).await
|
||||
.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
assert!(test_storage.exists(location).await
|
||||
.with_context(|| "exists".to_string())?);
|
||||
|
||||
// The storage should be automatically cleaned up when test_storage is dropped
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_isolation() {
|
||||
async fn test_test_storage_manager_isolation() -> anyhow::Result<()> {
|
||||
let storage1 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 1");
|
||||
.with_context(|| "create test storage 1".to_string())?;
|
||||
let storage2 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 2");
|
||||
.with_context(|| "create test storage 2".to_string())?;
|
||||
|
||||
let location = "isolation/test.txt";
|
||||
let data1 = b"storage 1 data";
|
||||
let data2 = b"storage 2 data";
|
||||
|
||||
// Put different data in each storage
|
||||
storage1.put(location, data1).await.expect("put storage 1");
|
||||
storage2.put(location, data2).await.expect("put storage 2");
|
||||
storage1.put(location, data1).await
|
||||
.with_context(|| "put storage 1".to_string())?;
|
||||
storage2.put(location, data2).await
|
||||
.with_context(|| "put storage 2".to_string())?;
|
||||
|
||||
// Verify isolation
|
||||
let retrieved1 = storage1.get(location).await.expect("get storage 1");
|
||||
let retrieved2 = storage2.get(location).await.expect("get storage 2");
|
||||
let retrieved1 = storage1.get(location).await
|
||||
.with_context(|| "get storage 1".to_string())?;
|
||||
let retrieved2 = storage2.get(location).await
|
||||
.with_context(|| "get storage 2".to_string())?;
|
||||
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_config() {
|
||||
async fn test_test_storage_manager_config() -> anyhow::Result<()> {
|
||||
let cfg = testing::test_config_memory();
|
||||
let test_storage = testing::TestStorageManager::with_config(&cfg)
|
||||
.await
|
||||
.expect("create test storage with config");
|
||||
.with_context(|| "create test storage with config".to_string())?;
|
||||
|
||||
let location = "config/test.txt";
|
||||
let data = b"test data with custom config";
|
||||
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
test_storage.put(location, data).await
|
||||
.with_context(|| "put".to_string())?;
|
||||
let retrieved = test_storage.get(location).await
|
||||
.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Verify it's using memory backend
|
||||
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
|
||||
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_basic_operations() {
|
||||
async fn test_storage_manager_s3_basic_operations() -> anyhow::Result<()> {
|
||||
// Skip if S3 connection fails (e.g. no MinIO)
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
eprintln!("Skipping S3 test (setup failed)");
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let prefix = format!("test-basic-{}", Uuid::new_v4());
|
||||
@@ -973,31 +996,33 @@ mod tests {
|
||||
// Test put
|
||||
if let Err(e) = storage.put(&location, data).await {
|
||||
eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Test get
|
||||
let retrieved = storage.get(&location).await.expect("get");
|
||||
let retrieved = storage.get(&location).await.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(&location).await.expect("exists"));
|
||||
assert!(storage.exists(&location).await.with_context(|| "exists".to_string())?);
|
||||
|
||||
// Test delete
|
||||
storage
|
||||
.delete_prefix(&format!("{prefix}/"))
|
||||
.await
|
||||
.expect("delete");
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(&location)
|
||||
.await
|
||||
.expect("exists after delete"));
|
||||
.with_context(|| "exists after delete".to_string())?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_list_operations() {
|
||||
async fn test_storage_manager_s3_list_operations() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let prefix = format!("test-list-{}", Uuid::new_v4());
|
||||
@@ -1009,23 +1034,25 @@ mod tests {
|
||||
|
||||
for (loc, data) in &files {
|
||||
if storage.put(loc, *data).await.is_err() {
|
||||
return; // Abort if put fails
|
||||
return Ok(()); // Abort if put fails
|
||||
}
|
||||
}
|
||||
|
||||
// List with prefix
|
||||
let list_prefix = format!("{prefix}/");
|
||||
let items = storage.list(Some(&list_prefix)).await.expect("list");
|
||||
let items = storage.list(Some(&list_prefix)).await.with_context(|| "list".to_string())?;
|
||||
assert_eq!(items.len(), 3);
|
||||
|
||||
// Cleanup
|
||||
storage.delete_prefix(&list_prefix).await.expect("cleanup");
|
||||
storage.delete_prefix(&list_prefix).await.with_context(|| "cleanup".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_stream_operations() {
|
||||
async fn test_storage_manager_s3_stream_operations() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let prefix = format!("test-stream-{}", Uuid::new_v4());
|
||||
@@ -1033,38 +1060,45 @@ mod tests {
|
||||
let content = vec![42u8; 1024 * 10]; // 10KB
|
||||
|
||||
if storage.put(&location, &content).await.is_err() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut stream = storage.get_stream(&location).await.expect("get stream");
|
||||
let mut stream = storage.get_stream(&location).await.with_context(|| "get stream".to_string())?;
|
||||
let mut collected = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
collected.extend_from_slice(&chunk.expect("chunk"));
|
||||
collected.extend_from_slice(&chunk.with_context(|| "chunk".to_string())?);
|
||||
}
|
||||
assert_eq!(collected, content);
|
||||
|
||||
storage
|
||||
.delete_prefix(&format!("{prefix}/"))
|
||||
.await
|
||||
.expect("cleanup");
|
||||
.with_context(|| "cleanup".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_backend_kind() {
|
||||
async fn test_storage_manager_s3_backend_kind() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_error_handling() {
|
||||
async fn test_storage_manager_s3_error_handling() -> anyhow::Result<()> {
|
||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
|
||||
assert!(storage.get(&location).await.is_err());
|
||||
assert!(!storage.exists(&location).await.expect("exists check"));
|
||||
// exists may fail if S3 is unavailable; treat error as false
|
||||
assert!(!storage.exists(&location).await.unwrap_or(false));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +90,7 @@ impl Analytics {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::stored_object;
|
||||
use anyhow::{self};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TestUser, "user", {
|
||||
@@ -99,18 +100,14 @@ mod tests {
|
||||
});
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analytics_initialization() {
|
||||
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Test initialization of analytics
|
||||
let analytics = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(analytics.id, "current");
|
||||
@@ -118,159 +115,134 @@ mod tests {
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let analytics_again = Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to get analytics after initialization");
|
||||
let analytics_again = Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
assert_eq!(analytics.id, analytics_again.id);
|
||||
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
||||
assert_eq!(analytics.visitors, analytics_again.visitors);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_analytics() {
|
||||
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Test get_current method
|
||||
let analytics = Analytics::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current analytics");
|
||||
let analytics = Analytics::get_current(&db).await?;
|
||||
|
||||
assert_eq!(analytics.id, "current");
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors() {
|
||||
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Test increment_visitors method
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors");
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_visitors(&db)
|
||||
.await
|
||||
.expect("Failed to increment visitors again");
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 2);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads() {
|
||||
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db)
|
||||
.await
|
||||
.expect("Failed to initialize analytics");
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
// Test increment_page_loads method
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads");
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
|
||||
// Increment again and check
|
||||
let analytics = Analytics::increment_page_loads(&db)
|
||||
.await
|
||||
.expect("Failed to increment page loads again");
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
assert_eq!(analytics.page_loads, 2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_users_amount() {
|
||||
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Test with no users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount");
|
||||
let count = Analytics::get_users_amount(&db).await?;
|
||||
assert_eq!(count, 0);
|
||||
|
||||
// Create a few test users
|
||||
for i in 0..3 {
|
||||
let user = TestUser {
|
||||
id: format!("user{}", i),
|
||||
email: format!("user{}@example.com", i),
|
||||
id: format!("user{i}"),
|
||||
email: format!("user{i}@example.com"),
|
||||
password: "password".to_string(),
|
||||
user_id: format!("uid{}", i),
|
||||
user_id: format!("uid{i}"),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
db.store_item(user)
|
||||
.await
|
||||
.expect("Failed to create test user");
|
||||
db.store_item(user).await?;
|
||||
}
|
||||
|
||||
// Test users amount after adding users
|
||||
let count = Analytics::get_users_amount(&db)
|
||||
.await
|
||||
.expect("Failed to get users amount after adding users");
|
||||
let count = Analytics::get_users_amount(&db).await?;
|
||||
assert_eq!(count, 3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
// Don't initialize analytics and try to get it
|
||||
let result = Analytics::get_current(&db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(err) = result {
|
||||
match err {
|
||||
AppError::NotFound(_) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected NotFound error, got: {:?}", err),
|
||||
}
|
||||
match result {
|
||||
Ok(_) => anyhow::bail!("Expected NotFound error, got success"),
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,76 +144,71 @@ impl Conversation {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use crate::storage::types::message::MessageRole;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create a new conversation
|
||||
let user_id = "test_user";
|
||||
let title = "Test Conversation";
|
||||
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
||||
|
||||
// Verify conversation properties
|
||||
assert_eq!(conversation.user_id, user_id);
|
||||
assert_eq!(conversation.title, title);
|
||||
assert!(!conversation.id.is_empty());
|
||||
|
||||
// Store the conversation
|
||||
let result = db.store_item(conversation.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<Conversation> = db
|
||||
.get_item(&conversation.id)
|
||||
.await
|
||||
.expect("Failed to retrieve conversation");
|
||||
assert!(retrieved.is_some());
|
||||
.with_context(|| "Failed to retrieve conversation".to_string())?;
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?;
|
||||
assert_eq!(retrieved.id, conversation.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_not_found() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to get a conversation that doesn't exist
|
||||
let result =
|
||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected NotFound error"),
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_unauthorized() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation =
|
||||
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
||||
@@ -221,27 +216,28 @@ mod tests {
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
// Try to access with a different user
|
||||
let user_id_2 = "user_2";
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_success() {
|
||||
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let user_id = "user_1";
|
||||
let original_title = "Original Title";
|
||||
@@ -250,49 +246,50 @@ mod tests {
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
let new_title = "Updated Title";
|
||||
|
||||
// Patch title successfully
|
||||
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Retrieve from DB to verify
|
||||
let updated_conversation = db
|
||||
.get_item::<Conversation>(&conversation_id)
|
||||
.await
|
||||
.expect("Failed to get conversation")
|
||||
.expect("Conversation missing");
|
||||
.with_context(|| "Failed to get conversation".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Conversation missing"))?;
|
||||
assert_eq!(updated_conversation.title, new_title);
|
||||
assert_eq!(updated_conversation.user_id, user_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found() {
|
||||
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to patch non-existing conversation
|
||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_unauthorized() {
|
||||
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user_id = "intruder";
|
||||
@@ -301,17 +298,18 @@ mod tests {
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
// Attempt patch with unauthorized user
|
||||
let result =
|
||||
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -405,24 +403,21 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create and store a conversation for user_id_1
|
||||
let user_id_1 = "user_1";
|
||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
|
||||
// Create messages
|
||||
let message1 = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
@@ -442,46 +437,44 @@ mod tests {
|
||||
None,
|
||||
);
|
||||
|
||||
// Store messages
|
||||
db.store_item(message1)
|
||||
.await
|
||||
.expect("Failed to store message1");
|
||||
.with_context(|| "Failed to store message1".to_string())?;
|
||||
db.store_item(message2)
|
||||
.await
|
||||
.expect("Failed to store message2");
|
||||
.with_context(|| "Failed to store message2".to_string())?;
|
||||
db.store_item(message3)
|
||||
.await
|
||||
.expect("Failed to store message3");
|
||||
.with_context(|| "Failed to store message3".to_string())?;
|
||||
|
||||
// Retrieve the complete conversation
|
||||
let result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
||||
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
||||
|
||||
let (retrieved_conversation, messages) = result.unwrap();
|
||||
let (retrieved_conversation, retrieved_messages) = result
|
||||
.with_context(|| "Failed to retrieve complete conversation".to_string())?;
|
||||
|
||||
// Verify conversation data
|
||||
assert_eq!(retrieved_conversation.id, conversation_id);
|
||||
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
||||
assert_eq!(retrieved_conversation.title, "Conversation");
|
||||
|
||||
// Verify messages
|
||||
assert_eq!(messages.len(), 3);
|
||||
assert_eq!(retrieved_messages.len(), 3);
|
||||
|
||||
// Verify messages are sorted by updated_at
|
||||
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
|
||||
let message_contents: Vec<&str> =
|
||||
retrieved_messages.iter().map(|m| m.content.as_str()).collect();
|
||||
assert!(message_contents.contains(&"Hello, AI!"));
|
||||
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
||||
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
||||
|
||||
// Make sure we can't access with different user
|
||||
let user_id_2 = "user_2";
|
||||
let unauthorized_result =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||
assert!(unauthorized_result.is_err());
|
||||
match unauthorized_result {
|
||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
||||
_ => panic!("Expected Auth error"),
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,6 +320,8 @@ impl FileInfo {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::store::testing::TestStorageManager;
|
||||
use axum::http::HeaderMap;
|
||||
@@ -328,11 +330,11 @@ mod tests {
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
/// Creates a test temporary file with the given content
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
|
||||
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> anyhow::Result<FieldData<NamedTempFile>> {
|
||||
let mut temp_file = NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?;
|
||||
temp_file
|
||||
.write_all(content)
|
||||
.expect("Failed to write to temp file");
|
||||
.with_context(|| "Failed to write to temp file".to_string())?;
|
||||
|
||||
let metadata = FieldMetadata {
|
||||
name: Some("file".to_string()),
|
||||
@@ -341,31 +343,29 @@ mod tests {
|
||||
headers: HeaderMap::default(),
|
||||
};
|
||||
|
||||
let field_data = FieldData {
|
||||
Ok(FieldData {
|
||||
metadata,
|
||||
contents: temp_file,
|
||||
};
|
||||
|
||||
field_data
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() {
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager operations";
|
||||
let file_name = "storage_manager_test.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
|
||||
// Create test storage manager (memory backend)
|
||||
let test_storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test storage manager");
|
||||
.with_context(|| "Failed to create test storage manager".to_string())?;
|
||||
|
||||
// Create a FileInfo instance with storage manager
|
||||
let user_id = "test_user";
|
||||
@@ -374,20 +374,20 @@ mod tests {
|
||||
let file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file with StorageManager");
|
||||
.with_context(|| "Failed to create file with StorageManager".to_string())?;
|
||||
assert_eq!(file_info.file_name, file_name);
|
||||
|
||||
// Verify the file exists via StorageManager and has correct content
|
||||
let bytes = file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to read file content via StorageManager");
|
||||
.with_context(|| "Failed to read file content via StorageManager".to_string())?;
|
||||
assert_eq!(bytes.as_ref(), content);
|
||||
|
||||
// Test file reading
|
||||
let retrieved = FileInfo::get_by_id(&file_info.id, &db)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
.with_context(|| "Failed to retrieve file info".to_string())?;
|
||||
assert_eq!(retrieved.id, file_info.id);
|
||||
assert_eq!(retrieved.sha256, file_info.sha256);
|
||||
assert_eq!(retrieved.file_name, file_name);
|
||||
@@ -395,65 +395,65 @@ mod tests {
|
||||
// Test file deletion with StorageManager
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete file with StorageManager");
|
||||
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
|
||||
|
||||
let deleted_result = file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await;
|
||||
assert!(deleted_result.is_err(), "File should be deleted");
|
||||
|
||||
// No cleanup needed - TestStorageManager handles it automatically
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() {
|
||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"filename sanitization";
|
||||
let original_name = "Complex name (1).txt";
|
||||
let expected_sanitized = "Complex_name__1_.txt";
|
||||
let field_data = create_test_file(content, original_name);
|
||||
let field_data = create_test_file(content, original_name)?;
|
||||
|
||||
let test_storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test storage manager");
|
||||
.with_context(|| "Failed to create test storage manager".to_string())?;
|
||||
|
||||
let file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file via storage manager");
|
||||
.with_context(|| "Failed to create file via storage manager".to_string())?;
|
||||
|
||||
assert_eq!(file_info.file_name, original_name);
|
||||
|
||||
let stored_name = Path::new(&file_info.path)
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.expect("stored name");
|
||||
.with_context(|| "stored name".to_string())?;
|
||||
assert_eq!(stored_name, expected_sanitized);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() {
|
||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager duplicate detection";
|
||||
let file_name = "storage_manager_duplicate.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
|
||||
// Create test storage manager
|
||||
let test_storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test storage manager");
|
||||
.with_context(|| "Failed to create test storage manager".to_string())?;
|
||||
|
||||
// Create a FileInfo instance with storage manager
|
||||
let user_id = "test_user";
|
||||
@@ -462,17 +462,17 @@ mod tests {
|
||||
let original_file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create original file with StorageManager");
|
||||
.with_context(|| "Failed to create original file with StorageManager".to_string())?;
|
||||
|
||||
// Create another file with the same content but different name
|
||||
let duplicate_name = "storage_manager_duplicate_2.txt";
|
||||
let field_data2 = create_test_file(content, duplicate_name);
|
||||
let field_data2 = create_test_file(content, duplicate_name)?;
|
||||
|
||||
// The system should detect it's the same file and return the original FileInfo
|
||||
let duplicate_file_info =
|
||||
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to process duplicate file with StorageManager");
|
||||
.with_context(|| "Failed to process duplicate file with StorageManager".to_string())?;
|
||||
|
||||
// Verify duplicate detection worked
|
||||
assert_eq!(duplicate_file_info.id, original_file_info.id);
|
||||
@@ -484,46 +484,44 @@ mod tests {
|
||||
let original_content = original_file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get original content".to_string())?;
|
||||
let duplicate_content = duplicate_file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get duplicate content".to_string())?;
|
||||
assert_eq!(original_content.as_ref(), content);
|
||||
assert_eq!(duplicate_content.as_ref(), content);
|
||||
|
||||
// Clean up
|
||||
FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete original file with StorageManager");
|
||||
.with_context(|| "Failed to delete original file with StorageManager".to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_creation() {
|
||||
async fn test_file_creation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file content";
|
||||
let file_name = "test_file.txt";
|
||||
let field_data = create_test_file(content, file_name);
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
|
||||
// Create a FileInfo instance with StorageManager
|
||||
let user_id = "test_user";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage manager");
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
let file_info =
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()).await;
|
||||
|
||||
// Verify the FileInfo was created successfully
|
||||
assert!(file_info.is_ok());
|
||||
let file_info = file_info.unwrap();
|
||||
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
|
||||
.await?;
|
||||
|
||||
// Check essential properties
|
||||
assert!(!file_info.id.is_empty());
|
||||
@@ -533,32 +531,32 @@ mod tests {
|
||||
// path should be logical: "user_id/uuid/file_name"
|
||||
let parts: Vec<&str> = file_info.path.split('/').collect();
|
||||
assert_eq!(parts.len(), 3);
|
||||
assert_eq!(parts[0], user_id);
|
||||
assert_eq!(parts[2], file_name);
|
||||
assert_eq!(parts.first(), Some(&user_id));
|
||||
assert_eq!(parts.get(2), Some(&file_name));
|
||||
assert!(file_info.mime_type.contains("text/plain"));
|
||||
|
||||
// Verify it's in the database
|
||||
let stored: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
let stored = db
|
||||
.get_item::<FileInfo>(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(stored.is_some());
|
||||
let stored = stored.unwrap();
|
||||
.with_context(|| "Failed to retrieve file info".to_string())?
|
||||
.with_context(|| "expected stored file".to_string())?;
|
||||
assert_eq!(stored.id, file_info.id);
|
||||
assert_eq!(stored.file_name, file_info.file_name);
|
||||
assert_eq!(stored.sha256, file_info.sha256);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_duplicate_detection() {
|
||||
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// First, store a file with known content
|
||||
let content = b"This is a test file for duplicate detection";
|
||||
@@ -567,23 +565,23 @@ mod tests {
|
||||
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage manager");
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
|
||||
let field_data1 = create_test_file(content, file_name);
|
||||
let field_data1 = create_test_file(content, file_name)?;
|
||||
let original_file_info =
|
||||
FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to create original file");
|
||||
.with_context(|| "Failed to create original file".to_string())?;
|
||||
|
||||
// Now try to store another file with the same content but different name
|
||||
let duplicate_name = "duplicate.txt";
|
||||
let field_data2 = create_test_file(content, duplicate_name);
|
||||
let field_data2 = create_test_file(content, duplicate_name)?;
|
||||
|
||||
// The system should detect it's the same file and return the original FileInfo
|
||||
let duplicate_file_info =
|
||||
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("Failed to process duplicate file");
|
||||
.with_context(|| "Failed to process duplicate file".to_string())?;
|
||||
|
||||
// The returned FileInfo should match the original
|
||||
assert_eq!(duplicate_file_info.id, original_file_info.id);
|
||||
@@ -592,10 +590,11 @@ mod tests {
|
||||
// But it should retain the original file name, not the duplicate's name
|
||||
assert_eq!(duplicate_file_info.file_name, file_name);
|
||||
assert_ne!(duplicate_file_info.file_name, duplicate_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_guess_mime_type() {
|
||||
async fn test_guess_mime_type() -> anyhow::Result<()> {
|
||||
// Test common file extensions
|
||||
assert_eq!(
|
||||
FileInfo::guess_mime_type(Path::new("test.txt")),
|
||||
@@ -619,10 +618,11 @@ mod tests {
|
||||
FileInfo::guess_mime_type(Path::new("unknown.929yz")),
|
||||
"application/octet-stream".to_string()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sanitize_file_name() {
|
||||
async fn test_sanitize_file_name() -> anyhow::Result<()> {
|
||||
// Safe characters should remain unchanged
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("normal_file.txt"),
|
||||
@@ -647,26 +647,26 @@ mod tests {
|
||||
FileInfo::sanitize_file_name("../dangerous.txt"),
|
||||
"___dangerous.txt"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found() {
|
||||
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to find a file with a SHA that doesn't exist
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
Err(FileError::FileNotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
_ => panic!("Expected FileNotFound error"),
|
||||
Err(FileError::FileNotFound(_)) => {}
|
||||
_ => anyhow::bail!("Expected FileNotFound error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -705,7 +705,7 @@ mod tests {
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
@@ -725,40 +725,39 @@ mod tests {
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
let retrieved = db
|
||||
.get_item::<FileInfo>(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to retrieve file info");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
.with_context(|| "Failed to retrieve file info".to_string())?
|
||||
.with_context(|| "expected file".to_string())?;
|
||||
assert_eq!(retrieved.id, file_info.id);
|
||||
assert_eq!(retrieved.sha256, file_info.sha256);
|
||||
assert_eq!(retrieved.file_name, file_info.file_name);
|
||||
assert_eq!(retrieved.path, file_info.path);
|
||||
assert_eq!(retrieved.mime_type, file_info.mime_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id() {
|
||||
async fn test_delete_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create and persist a test file via FileInfo::new_with_storage
|
||||
let user_id = "user123";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage manager");
|
||||
let temp = create_test_file(b"test content", "test_file.txt");
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
let temp = create_test_file(b"test content", "test_file.txt")?;
|
||||
let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage())
|
||||
.await
|
||||
.expect("create file");
|
||||
.with_context(|| "create file".to_string())?;
|
||||
|
||||
// Delete the file using StorageManager
|
||||
let delete_result =
|
||||
@@ -767,15 +766,14 @@ mod tests {
|
||||
// Delete should be successful
|
||||
assert!(
|
||||
delete_result.is_ok(),
|
||||
"Failed to delete file: {:?}",
|
||||
delete_result
|
||||
"Failed to delete file: {delete_result:?}"
|
||||
);
|
||||
|
||||
// Verify the file is removed from the database
|
||||
let retrieved: Option<FileInfo> = db
|
||||
.get_item(&file_info.id)
|
||||
.await
|
||||
.expect("Failed to query database");
|
||||
.with_context(|| "Failed to query database".to_string())?;
|
||||
assert!(
|
||||
retrieved.is_none(),
|
||||
"FileInfo should be deleted from the database"
|
||||
@@ -783,32 +781,37 @@ mod tests {
|
||||
|
||||
// Verify content no longer retrievable from storage
|
||||
assert!(test_storage.storage().get(&file_info.path).await.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id_not_found() {
|
||||
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to delete a file that doesn't exist
|
||||
let test_storage = TestStorageManager::new_memory().await.unwrap();
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
let result =
|
||||
FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage())
|
||||
.await;
|
||||
|
||||
// Should succeed even if the file record does not exist
|
||||
assert!(result.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id() {
|
||||
async fn test_get_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
@@ -827,28 +830,27 @@ mod tests {
|
||||
// Store it in the database
|
||||
db.store_item(original_file_info.clone())
|
||||
.await
|
||||
.expect("Failed to store item for get_by_id test");
|
||||
.with_context(|| "Failed to store item for get_by_id test".to_string())?;
|
||||
|
||||
// Retrieve it using get_by_id
|
||||
let result = FileInfo::get_by_id(&file_id, &db).await;
|
||||
|
||||
// Assert success and content match
|
||||
assert!(result.is_ok());
|
||||
let retrieved_info = result.unwrap();
|
||||
let retrieved_info = FileInfo::get_by_id(&file_id, &db)
|
||||
.await
|
||||
.with_context(|| "get_by_id".to_string())?;
|
||||
assert_eq!(retrieved_info.id, original_file_info.id);
|
||||
assert_eq!(retrieved_info.sha256, original_file_info.sha256);
|
||||
assert_eq!(retrieved_info.file_name, original_file_info.file_name);
|
||||
assert_eq!(retrieved_info.path, original_file_info.path);
|
||||
assert_eq!(retrieved_info.mime_type, original_file_info.mime_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id_not_found() {
|
||||
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to retrieve a non-existent ID
|
||||
let non_existent_id = "non-existent-file-id";
|
||||
@@ -862,33 +864,34 @@ mod tests {
|
||||
Err(FileError::FileNotFound(id)) => {
|
||||
assert_eq!(id, non_existent_id);
|
||||
}
|
||||
Err(e) => panic!("Expected FileNotFound error, but got {:?}", e),
|
||||
Ok(_) => panic!("Expected an error, but got Ok"),
|
||||
Err(e) => anyhow::bail!("Expected FileNotFound error, but got {e:?}"),
|
||||
Ok(_) => anyhow::bail!("Expected an error, but got Ok"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// StorageManager-based tests
|
||||
#[tokio::test]
|
||||
async fn test_file_info_new_with_storage_memory() {
|
||||
async fn test_file_info_new_with_storage_memory() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager";
|
||||
let field_data = create_test_file(content, "test_storage.txt");
|
||||
let field_data = create_test_file(content, "test_storage.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create test storage manager
|
||||
let storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
// Test file creation with StorageManager
|
||||
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file with StorageManager");
|
||||
.with_context(|| "Failed to create file with StorageManager".to_string())?;
|
||||
|
||||
// Verify the file was created correctly
|
||||
assert_eq!(file_info.user_id, user_id);
|
||||
@@ -900,40 +903,41 @@ mod tests {
|
||||
let retrieved_content = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.expect("Failed to get file content with StorageManager");
|
||||
.with_context(|| "Failed to get file content with StorageManager".to_string())?;
|
||||
assert_eq!(retrieved_content.as_ref(), content);
|
||||
|
||||
// Test file deletion with StorageManager
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete file with StorageManager");
|
||||
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
|
||||
|
||||
// Verify file is deleted
|
||||
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
|
||||
assert!(deleted_content_result.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_info_new_with_storage_local() {
|
||||
async fn test_file_info_new_with_storage_local() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_storage_local")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"This is a test file for StorageManager with local storage";
|
||||
let field_data = create_test_file(content, "test_local.txt");
|
||||
let field_data = create_test_file(content, "test_local.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create test storage manager with local backend
|
||||
let storage = store::testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
// Test file creation with StorageManager
|
||||
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file with StorageManager");
|
||||
.with_context(|| "Failed to create file with StorageManager".to_string())?;
|
||||
|
||||
// Verify the file was created correctly
|
||||
assert_eq!(file_info.user_id, user_id);
|
||||
@@ -945,50 +949,51 @@ mod tests {
|
||||
let retrieved_content = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.expect("Failed to get file content with StorageManager");
|
||||
.with_context(|| "Failed to get file content with StorageManager".to_string())?;
|
||||
assert_eq!(retrieved_content.as_ref(), content);
|
||||
|
||||
// Test file deletion with StorageManager
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
|
||||
.await
|
||||
.expect("Failed to delete file with StorageManager");
|
||||
.with_context(|| "Failed to delete file with StorageManager".to_string())?;
|
||||
|
||||
// Verify file is deleted
|
||||
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
|
||||
assert!(deleted_content_result.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_info_storage_manager_persistence() {
|
||||
async fn test_file_info_storage_manager_persistence() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_persistence")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"Test content for persistence";
|
||||
let field_data = create_test_file(content, "persistence_test.txt");
|
||||
let field_data = create_test_file(content, "persistence_test.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create test storage manager
|
||||
let storage = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create test storage".to_string())?;
|
||||
|
||||
// Create file
|
||||
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
|
||||
.await
|
||||
.expect("Failed to create file");
|
||||
.with_context(|| "Failed to create file".to_string())?;
|
||||
|
||||
// Test that data persists across multiple operations with the same StorageManager
|
||||
let retrieved_content_1 = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get content 1".to_string())?;
|
||||
let retrieved_content_2 = file_info
|
||||
.get_content_with_storage(storage.storage())
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get content 2".to_string())?;
|
||||
|
||||
assert_eq!(retrieved_content_1.as_ref(), content);
|
||||
assert_eq!(retrieved_content_2.as_ref(), content);
|
||||
@@ -996,68 +1001,70 @@ mod tests {
|
||||
// Test that different StorageManager instances don't share data (memory storage isolation)
|
||||
let storage2 = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create second storage".to_string())?;
|
||||
let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await;
|
||||
assert!(
|
||||
isolated_content_result.is_err(),
|
||||
"Different StorageManager should not have access to same data"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_info_storage_manager_equivalence() {
|
||||
async fn test_file_info_storage_manager_equivalence() -> anyhow::Result<()> {
|
||||
// Setup
|
||||
let db = SurrealDbClient::memory("test_ns", "test_file_equivalence")
|
||||
.await
|
||||
.unwrap();
|
||||
db.apply_migrations().await.unwrap();
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"Test content for equivalence testing";
|
||||
let field_data1 = create_test_file(content, "equivalence_test_1.txt");
|
||||
let field_data2 = create_test_file(content, "equivalence_test_2.txt");
|
||||
let field_data1 = create_test_file(content, "equivalence_test_1.txt")?;
|
||||
let field_data2 = create_test_file(content, "equivalence_test_2.txt")?;
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create single storage manager and reuse it
|
||||
let storage_manager = store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "create storage".to_string())?;
|
||||
let storage = storage_manager.storage();
|
||||
|
||||
// Create multiple files with the same storage manager
|
||||
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, &storage)
|
||||
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, storage)
|
||||
.await
|
||||
.expect("Failed to create file 1");
|
||||
.with_context(|| "Failed to create file 1".to_string())?;
|
||||
|
||||
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, &storage)
|
||||
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, storage)
|
||||
.await
|
||||
.expect("Failed to create file 2");
|
||||
.with_context(|| "Failed to create file 2".to_string())?;
|
||||
|
||||
// Test that both files can be retrieved with the same storage backend
|
||||
let content_1 = file_info_1
|
||||
.get_content_with_storage(&storage)
|
||||
.get_content_with_storage(storage)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get file 1 content".to_string())?;
|
||||
let content_2 = file_info_2
|
||||
.get_content_with_storage(&storage)
|
||||
.get_content_with_storage(storage)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get file 2 content".to_string())?;
|
||||
|
||||
assert_eq!(content_1.as_ref(), content);
|
||||
assert_eq!(content_2.as_ref(), content);
|
||||
|
||||
// Test that files can be deleted with the same storage manager
|
||||
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, &storage)
|
||||
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, storage)
|
||||
.await
|
||||
.unwrap();
|
||||
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, &storage)
|
||||
.with_context(|| "delete file 1".to_string())?;
|
||||
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, storage)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "delete file 2".to_string())?;
|
||||
|
||||
// Verify files are deleted
|
||||
let deleted_content_1 = file_info_1.get_content_with_storage(&storage).await;
|
||||
let deleted_content_2 = file_info_2.get_content_with_storage(&storage).await;
|
||||
let deleted_content_1 = file_info_1.get_content_with_storage(storage).await;
|
||||
let deleted_content_2 = file_info_2.get_content_with_storage(storage).await;
|
||||
|
||||
assert!(deleted_content_1.is_err());
|
||||
assert!(deleted_content_2.is_err());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,7 @@ impl IngestionPayload {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use chrono::Utc;
|
||||
|
||||
use super::*;
|
||||
@@ -131,7 +132,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url() {
|
||||
fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> {
|
||||
let url = "https://example.com";
|
||||
let context = "Process this URL";
|
||||
let category = "websites";
|
||||
@@ -145,10 +146,10 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
match result.first().context("expected one result")? {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url,
|
||||
context: payload_context,
|
||||
@@ -156,17 +157,18 @@ mod tests {
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||
assert_eq!(payload_context, &context);
|
||||
assert_eq!(payload_category, &category);
|
||||
assert_eq!(payload_user_id, &user_id);
|
||||
}
|
||||
_ => panic!("Expected Url variant"),
|
||||
_ => anyhow::bail!("Expected Url variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_text() {
|
||||
fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> {
|
||||
let text = "This is some text content";
|
||||
let context = "Process this text";
|
||||
let category = "notes";
|
||||
@@ -180,10 +182,10 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
match result.first().context("expected one result")? {
|
||||
IngestionPayload::Text {
|
||||
text: payload_text,
|
||||
context: payload_context,
|
||||
@@ -195,12 +197,13 @@ mod tests {
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected Text variant"),
|
||||
_ => anyhow::bail!("Expected Text variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_file() {
|
||||
fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> {
|
||||
let context = "Process this file";
|
||||
let category = "documents";
|
||||
let user_id = "user123";
|
||||
@@ -220,10 +223,10 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match &result[0] {
|
||||
match result.first().context("expected one result")? {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
context: payload_context,
|
||||
@@ -235,12 +238,13 @@ mod tests {
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
_ => anyhow::bail!("Expected File variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_url_and_file() {
|
||||
fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> {
|
||||
let url = "https://example.com";
|
||||
let context = "Process this data";
|
||||
let category = "mixed";
|
||||
@@ -261,35 +265,36 @@ mod tests {
|
||||
files,
|
||||
user_id,
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
// Check first item is URL
|
||||
match &result[0] {
|
||||
match result.first().context("expected first item")? {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url, ..
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||
}
|
||||
_ => panic!("Expected first item to be Url variant"),
|
||||
_ => anyhow::bail!("Expected first item to be Url variant"),
|
||||
}
|
||||
|
||||
// Check second item is File
|
||||
match &result[1] {
|
||||
match result.get(1).context("expected second item")? {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
}
|
||||
_ => panic!("Expected second item to be File variant"),
|
||||
_ => anyhow::bail!("Expected second item to be File variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_empty_input() {
|
||||
fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> {
|
||||
let context = "Process something";
|
||||
let category = "empty";
|
||||
let user_id = "user123";
|
||||
@@ -308,12 +313,13 @@ mod tests {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_with_empty_text() {
|
||||
fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> {
|
||||
let text = ""; // Empty text
|
||||
let context = "Process this";
|
||||
let category = "notes";
|
||||
@@ -333,7 +339,8 @@ mod tests {
|
||||
Err(AppError::NotFound(msg)) => {
|
||||
assert_eq!(msg, "No valid content or files provided");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
_ => anyhow::bail!("Expected NotFound error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -529,6 +529,8 @@ impl IngestionTask {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
|
||||
@@ -541,16 +543,16 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn memory_db() -> SurrealDbClient {
|
||||
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("in-memory surrealdb")
|
||||
.with_context(|| "in-memory surrealdb".to_string())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_task_defaults() {
|
||||
async fn test_new_task_defaults() -> anyhow::Result<()> {
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
@@ -562,73 +564,76 @@ mod tests {
|
||||
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
||||
assert!(task.locked_at.is_none());
|
||||
assert!(task.worker_id.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_store_task() {
|
||||
let db = memory_db().await;
|
||||
async fn test_create_and_store_task() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let created =
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("store");
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
let stored: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&created.id)
|
||||
.await
|
||||
.expect("fetch");
|
||||
.with_context(|| "fetch".to_string())?;
|
||||
|
||||
let stored = stored.expect("task exists");
|
||||
let stored = stored.with_context(|| "task exists".to_string())?;
|
||||
assert_eq!(stored.id, created.id);
|
||||
assert_eq!(stored.state, TaskState::Pending);
|
||||
assert_eq!(stored.attempts, 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_claim_and_transition() {
|
||||
let db = memory_db().await;
|
||||
async fn test_claim_and_transition() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let worker_id = "worker-1";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim");
|
||||
.with_context(|| "claim".to_string())?
|
||||
.with_context(|| "task claimed".to_string())?;
|
||||
|
||||
let claimed = claimed.expect("task claimed");
|
||||
assert_eq!(claimed.state, TaskState::Reserved);
|
||||
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
|
||||
assert_eq!(processing.state, TaskState::Processing);
|
||||
|
||||
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
|
||||
let succeeded = processing.mark_succeeded(&db).await.with_context(|| "succeeded".to_string())?;
|
||||
assert_eq!(succeeded.state, TaskState::Succeeded);
|
||||
assert!(succeeded.worker_id.is_none());
|
||||
assert!(succeeded.locked_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fail_and_dead_letter() {
|
||||
let db = memory_db().await;
|
||||
async fn test_fail_and_dead_letter() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let worker_id = "worker-dead";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim")
|
||||
.expect("claimed");
|
||||
.with_context(|| "claim".to_string())?
|
||||
.with_context(|| "claimed".to_string())?;
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
let processing = claimed.mark_processing(&db).await.with_context(|| "processing".to_string())?;
|
||||
|
||||
let error_info = TaskErrorInfo {
|
||||
code: Some("pipeline_error".into()),
|
||||
@@ -638,7 +643,7 @@ mod tests {
|
||||
let failed = processing
|
||||
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
||||
.await
|
||||
.expect("failed update");
|
||||
.with_context(|| "failed update".to_string())?;
|
||||
assert_eq!(failed.state, TaskState::Failed);
|
||||
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
||||
assert!(failed.worker_id.is_none());
|
||||
@@ -648,19 +653,20 @@ mod tests {
|
||||
let dead = failed
|
||||
.mark_dead_letter(error_info.clone(), &db)
|
||||
.await
|
||||
.expect("dead letter");
|
||||
.with_context(|| "dead letter".to_string())?;
|
||||
assert_eq!(dead.state, TaskState::DeadLetter);
|
||||
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mark_processing_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
async fn test_mark_processing_requires_reservation() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let err = task
|
||||
.mark_processing(&db)
|
||||
@@ -674,18 +680,19 @@ mod tests {
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mark_failed_requires_processing() {
|
||||
let db = memory_db().await;
|
||||
async fn test_mark_failed_requires_processing() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let err = task
|
||||
.mark_failed(
|
||||
@@ -706,18 +713,19 @@ mod tests {
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_release_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
async fn test_release_requires_reservation() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
db.store_item(task.clone()).await.with_context(|| "store".to_string())?;
|
||||
|
||||
let err = task
|
||||
.release(&db)
|
||||
@@ -731,7 +739,8 @@ mod tests {
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
clippy::format_push_string,
|
||||
clippy::uninlined_format_args,
|
||||
clippy::explicit_iter_loop,
|
||||
clippy::items_after_statements,
|
||||
clippy::get_first,
|
||||
clippy::redundant_closure_for_method_calls
|
||||
)]
|
||||
@@ -317,6 +316,11 @@ impl KnowledgeEntity {
|
||||
}
|
||||
|
||||
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
|
||||
@@ -324,10 +328,6 @@ impl KnowledgeEntity {
|
||||
.bind(("id", id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
|
||||
rows.get(0)
|
||||
.map(|r| r.user_id.clone())
|
||||
@@ -497,7 +497,6 @@ impl KnowledgeEntity {
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
|
||||
info!("Removing old index and clearing embeddings...");
|
||||
@@ -572,14 +571,14 @@ impl KnowledgeEntity {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() {
|
||||
// Create basic test entity
|
||||
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
|
||||
let source_id = "source123".to_string();
|
||||
let name = "Test Entity".to_string();
|
||||
let description = "Test Description".to_string();
|
||||
@@ -596,7 +595,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert_eq!(entity.source_id, source_id);
|
||||
assert_eq!(entity.name, name);
|
||||
assert_eq!(entity.description, description);
|
||||
@@ -604,11 +602,12 @@ mod tests {
|
||||
assert_eq!(entity.metadata, metadata);
|
||||
assert_eq!(entity.user_id, user_id);
|
||||
assert!(!entity.id.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_type_from_string() {
|
||||
// Test conversion from String to KnowledgeEntityType
|
||||
async fn test_knowledge_entity_type_from_string() -> anyhow::Result<()> {
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("idea".to_string()),
|
||||
KnowledgeEntityType::Idea
|
||||
@@ -639,15 +638,16 @@ mod tests {
|
||||
KnowledgeEntityType::TextSnippet
|
||||
);
|
||||
|
||||
// Test default case
|
||||
assert_eq!(
|
||||
KnowledgeEntityType::from("unknown".to_string()),
|
||||
KnowledgeEntityType::Document
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_variants() {
|
||||
async fn test_knowledge_entity_variants() -> anyhow::Result<()> {
|
||||
let variants = KnowledgeEntityType::variants();
|
||||
assert_eq!(variants.len(), 5);
|
||||
assert!(variants.contains(&"Idea"));
|
||||
@@ -655,28 +655,28 @@ mod tests {
|
||||
assert!(variants.contains(&"Document"));
|
||||
assert!(variants.contains(&"Page"));
|
||||
assert!(variants.contains(&"TextSnippet"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create two entities with the same source_id
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
@@ -696,7 +696,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an entity with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_entity = KnowledgeEntity::new(
|
||||
different_source_id.clone(),
|
||||
@@ -708,23 +707,20 @@ mod tests {
|
||||
);
|
||||
|
||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
// Store the entities
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
.with_context(|| "Failed to store entity 1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
.with_context(|| "Failed to store entity 2".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store different entity");
|
||||
.with_context(|| "Failed to store different entity".to_string())?;
|
||||
|
||||
// Delete by source_id
|
||||
KnowledgeEntity::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete entities by source_id");
|
||||
.with_context(|| "Failed to delete entities by source_id".to_string())?;
|
||||
|
||||
// Verify all entities with the specified source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
@@ -734,16 +730,11 @@ mod tests {
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All entities with the source_id should be deleted"
|
||||
);
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert!(remaining.is_empty(), "All entities with the source_id should be deleted");
|
||||
|
||||
// Verify the entity with a different source_id still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
KnowledgeEntity::table_name(),
|
||||
@@ -753,15 +744,20 @@ mod tests {
|
||||
.client
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Entity with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
assert_eq!(
|
||||
different_remaining.first().context("Expected entity to exist")?.id,
|
||||
different_entity.id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -833,35 +829,37 @@ mod tests {
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.expect("vector search");
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
assert!(results.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -876,9 +874,12 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity with embedding");
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
let stored_entity: Option<KnowledgeEntity> = db
|
||||
.get_item(&entity.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
@@ -888,42 +889,44 @@ mod tests {
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.with_context(|| "query embeddings".to_string())?
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
.with_context(|| "take embeddings".to_string())?;
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
.await
|
||||
.expect("fetch embedding");
|
||||
.with_context(|| "fetch embedding".to_string())?;
|
||||
assert!(fetched_emb.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
let res = results.first().context("Expected at least one result")?;
|
||||
assert_eq!(res.entity.id, entity.id);
|
||||
assert_eq!(res.entity.source_id, source_id);
|
||||
assert_eq!(res.entity.name, "hello");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let e1 = KnowledgeEntity::new(
|
||||
@@ -945,13 +948,19 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e1");
|
||||
.with_context(|| "store e1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e2");
|
||||
.with_context(|| "store e2".to_string())?;
|
||||
|
||||
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
|
||||
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
|
||||
let stored_e1: Option<KnowledgeEntity> = db
|
||||
.get_item(&e1.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
let stored_e2: Option<KnowledgeEntity> = db
|
||||
.get_item(&e2.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
assert!(stored_e1.is_some() && stored_e2.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
@@ -961,45 +970,53 @@ mod tests {
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.with_context(|| "query embeddings".to_string())?
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
.with_context(|| "take embeddings".to_string())?;
|
||||
assert_eq!(stored_embeddings.len(), 2);
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get embedding e1".to_string())?
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].entity.id, e2.id);
|
||||
assert_eq!(results[1].entity.id, e1.id);
|
||||
assert_eq!(
|
||||
results.first().context("Expected at least one result")?.entity.id,
|
||||
e2.id
|
||||
);
|
||||
assert_eq!(
|
||||
results.get(1).context("Expected at least two results")?.entity.id,
|
||||
e1.id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_with_orphaned_embedding() {
|
||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_orphan";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -1014,21 +1031,20 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity with embedding");
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
// Manually delete the entity to create an orphan
|
||||
let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id);
|
||||
db.client.query(query).await.expect("delete entity");
|
||||
db.client
|
||||
.query(query)
|
||||
.await
|
||||
.with_context(|| "delete entity".to_string())?;
|
||||
|
||||
// Now search
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.expect("search should succeed even with orphans");
|
||||
.with_context(|| "search should succeed even with orphans".to_string())?;
|
||||
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Should return empty result for orphan, got: {:?}",
|
||||
results
|
||||
);
|
||||
assert!(results.is_empty(), "Should return empty result for orphan, got: {:?}", results);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,11 +110,15 @@ impl KnowledgeEntityEmbedding {
|
||||
}
|
||||
|
||||
/// Delete embeddings by source_id (via joining to knowledge_entity table)
|
||||
#[allow(clippy::items_after_statements)]
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
|
||||
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
|
||||
let mut res = db
|
||||
.client
|
||||
@@ -122,11 +126,6 @@ impl KnowledgeEntityEmbedding {
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
for row in ids {
|
||||
@@ -138,6 +137,7 @@ impl KnowledgeEntityEmbedding {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
@@ -145,18 +145,18 @@ mod tests {
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn build_knowledge_entity_with_id(
|
||||
@@ -178,11 +178,11 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_entity_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-1";
|
||||
let source_id = "source-ke";
|
||||
@@ -192,26 +192,28 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by entity_id")
|
||||
.expect("Expected embedding to exist");
|
||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.entity_id, entity_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_entity_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-delete";
|
||||
let source_id = "source-del";
|
||||
@@ -220,61 +222,67 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some());
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by entity_id");
|
||||
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||
|
||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user_store";
|
||||
let source_id = "source_store";
|
||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
let stored_entity: Option<KnowledgeEntity> = db
|
||||
.get_item(&entity.id)
|
||||
.await
|
||||
.with_context(|| "Failed to get entity".to_string())?;
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to fetch embedding");
|
||||
assert!(stored_embedding.is_some());
|
||||
let stored_embedding = stored_embedding.unwrap();
|
||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||
let stored_embedding = stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
assert_eq!(stored_embedding.user_id, user_id);
|
||||
assert_eq!(stored_embedding.entity_id, entity_rid);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let source_id = "shared-ke";
|
||||
let other_source = "other-ke";
|
||||
@@ -285,13 +293,13 @@ mod tests {
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||
@@ -299,59 +307,74 @@ mod tests {
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get other embedding after delete".to_string())?
|
||||
.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
.with_context(|| "failed to redefine index".to_string())?;
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE knowledge_entity_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_knowledge_entity_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_knowledge_entity_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 16"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_entity_via_record_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
.with_context(|| "set test index dimension".to_string())?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-fetch";
|
||||
let source_id = "source-fetch";
|
||||
@@ -359,15 +382,10 @@ mod tests {
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let mut res = db
|
||||
.client
|
||||
.query(
|
||||
@@ -375,13 +393,17 @@ mod tests {
|
||||
)
|
||||
.bind(("id", entity_rid.clone()))
|
||||
.await
|
||||
.expect("failed to fetch embedding with FETCH");
|
||||
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
|
||||
.with_context(|| "failed to fetch embedding with FETCH".to_string())?;
|
||||
let rows: Vec<Row> = res
|
||||
.take(0)
|
||||
.with_context(|| "failed to deserialize fetch rows".to_string())?;
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
let fetched_entity = &rows[0].entity_id;
|
||||
let fetched_entity = &rows.first().context("Expected at least one result")?.entity_id;
|
||||
assert_eq!(fetched_entity.id, entity_key);
|
||||
assert_eq!(fetched_entity.name, "Test entity");
|
||||
assert_eq!(fetched_entity.user_id, user_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,6 +124,7 @@ impl KnowledgeRelationship {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
|
||||
@@ -155,10 +156,9 @@ mod tests {
|
||||
result.take(0).expect("failed to take relationship by id")
|
||||
}
|
||||
|
||||
// Helper function to create a test knowledge entity for the relationship tests
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {}", name);
|
||||
let description = format!("Description for {name}");
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
@@ -174,12 +174,14 @@ mod tests {
|
||||
let stored: Option<KnowledgeEntity> = db_client
|
||||
.store_item(entity)
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
stored.unwrap().id
|
||||
.with_context(|| "Failed to store entity".to_string())?;
|
||||
stored
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some"))
|
||||
.map(|e| e.id)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relationship_creation() {
|
||||
async fn test_relationship_creation() -> anyhow::Result<()> {
|
||||
let in_id = "entity1".to_string();
|
||||
let out_id = "entity2".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
@@ -194,25 +196,23 @@ mod tests {
|
||||
relationship_type.clone(),
|
||||
);
|
||||
|
||||
// Verify fields are correctly set
|
||||
assert_eq!(relationship.in_, in_id);
|
||||
assert_eq!(relationship.out, out_id);
|
||||
assert_eq!(relationship.metadata.user_id, user_id);
|
||||
assert_eq!(relationship.metadata.source_id, source_id);
|
||||
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
||||
assert!(!relationship.id.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_verify_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
@@ -225,11 +225,10 @@ mod tests {
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let persisted = get_relationship_by_id(&relationship.id, &db)
|
||||
.await
|
||||
@@ -239,8 +238,6 @@ mod tests {
|
||||
assert_eq!(persisted.metadata.user_id, user_id);
|
||||
assert_eq!(persisted.metadata.source_id, source_id);
|
||||
|
||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
||||
// This approach is more reliable than trying to look up by ID
|
||||
let mut check_result = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
@@ -253,14 +250,16 @@ mod tests {
|
||||
1,
|
||||
"Expected one relationship for source_id"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() {
|
||||
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id,
|
||||
@@ -288,18 +287,17 @@ mod tests {
|
||||
rows[0].metadata.source_id,
|
||||
"source123'; DELETE FROM relates_to; --"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_delete_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
// Create relationship
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let relationship_type = "references".to_string();
|
||||
@@ -312,52 +310,44 @@ mod tests {
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
// Ensure relationship exists before deletion attempt
|
||||
let mut existing_before_delete = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let before_results: Vec<KnowledgeRelationship> =
|
||||
existing_before_delete.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before deletion"
|
||||
);
|
||||
assert!(!before_results.is_empty(), "Relationship should exist before deletion");
|
||||
|
||||
// Delete relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationship by ID");
|
||||
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||
|
||||
// Query to verify relationship was deleted
|
||||
let mut result = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{user_id}' AND metadata.source_id = '{source_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify relationship no longer exists
|
||||
assert!(results.is_empty(), "Relationship should be deleted");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id_unauthorized() {
|
||||
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
|
||||
let owner_user_id = "owner-user".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
@@ -373,20 +363,16 @@ mod tests {
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let mut before_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before unauthorized delete attempt"
|
||||
);
|
||||
assert!(!before_results.is_empty(), "Relationship should exist before unauthorized delete attempt");
|
||||
|
||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||
&relationship.id,
|
||||
@@ -397,40 +383,34 @@ mod tests {
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected authorization error when deleting someone else's relationship"),
|
||||
_ => anyhow::bail!("Expected authorization error when deleting someone else's relationship"),
|
||||
}
|
||||
|
||||
let mut after_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{owner_user_id}'"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
.with_context(|| "Query failed".to_string())?;
|
||||
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
!results.is_empty(),
|
||||
"Relationship should still exist after unauthorized delete attempt"
|
||||
);
|
||||
assert!(!results.is_empty(), "Relationship should still exist after unauthorized delete attempt");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_exists() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await?;
|
||||
|
||||
// Create relationships with the same source_id
|
||||
let user_id = "user123".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
let different_source_id = "different_source".to_string();
|
||||
|
||||
// Create two relationships with the same source_id
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
@@ -447,7 +427,6 @@ mod tests {
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Create a relationship with a different source_id
|
||||
let different_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity3_id.clone(),
|
||||
@@ -456,21 +435,19 @@ mod tests {
|
||||
"mentions".to_string(),
|
||||
);
|
||||
|
||||
// Store all relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||
different_relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store different relationship");
|
||||
.with_context(|| "Failed to store different relationship".to_string())?;
|
||||
|
||||
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
|
||||
let mut before_delete = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
@@ -489,31 +466,30 @@ mod tests {
|
||||
before_delete_different.take(0).unwrap_or_default();
|
||||
assert_eq!(before_delete_different_rows.len(), 1);
|
||||
|
||||
// Delete relationships by source_id
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationships by source_id");
|
||||
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||
|
||||
// Query to verify the specific relationships with source_id were deleted.
|
||||
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
|
||||
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
|
||||
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
|
||||
|
||||
// Verify relationships with the source_id are deleted
|
||||
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||
let remaining =
|
||||
different_result.expect("Relationship with different source_id should remain");
|
||||
assert_eq!(remaining.metadata.source_id, different_source_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() {
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await?;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await?;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await?;
|
||||
|
||||
let safe_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
@@ -552,5 +528,7 @@ mod tests {
|
||||
remaining_other.is_some(),
|
||||
"Other relationship should remain"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,12 +66,12 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_creation() {
|
||||
// Test basic message creation
|
||||
async fn test_message_creation() -> anyhow::Result<()> {
|
||||
let conversation_id = "test_conversation";
|
||||
let content = "This is a test message";
|
||||
let role = MessageRole::User;
|
||||
@@ -84,24 +84,23 @@ mod tests {
|
||||
references.clone(),
|
||||
);
|
||||
|
||||
// Verify message properties
|
||||
assert_eq!(message.conversation_id, conversation_id);
|
||||
assert_eq!(message.content, content);
|
||||
assert_eq!(message.role, role);
|
||||
assert_eq!(message.references, references);
|
||||
assert!(!message.id.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_persistence() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create and store a message
|
||||
let conversation_id = "test_conversation";
|
||||
let message = Message::new(
|
||||
conversation_id.to_string(),
|
||||
@@ -111,39 +110,37 @@ mod tests {
|
||||
);
|
||||
let message_id = message.id.clone();
|
||||
|
||||
// Store the message
|
||||
db.store_item(message.clone())
|
||||
.await
|
||||
.expect("Failed to store message");
|
||||
.with_context(|| "Failed to store message".to_string())?;
|
||||
|
||||
// Retrieve the message
|
||||
let retrieved: Option<Message> = db
|
||||
.get_item(&message_id)
|
||||
.await
|
||||
.expect("Failed to retrieve message");
|
||||
.with_context(|| "Failed to retrieve message".to_string())?;
|
||||
|
||||
assert!(retrieved.is_some());
|
||||
let retrieved = retrieved.unwrap();
|
||||
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?;
|
||||
|
||||
// Verify retrieved properties match original
|
||||
assert_eq!(retrieved.id, message.id);
|
||||
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
||||
assert_eq!(retrieved.role, message.role);
|
||||
assert_eq!(retrieved.content, message.content);
|
||||
assert_eq!(retrieved.references, message.references);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_role_display() {
|
||||
// Test the Display implementation for MessageRole
|
||||
async fn test_message_role_display() -> anyhow::Result<()> {
|
||||
assert_eq!(format!("{}", MessageRole::User), "User");
|
||||
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
||||
assert_eq!(format!("{}", MessageRole::System), "System");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_display() {
|
||||
// Test the Display implementation for Message
|
||||
async fn test_message_display() -> anyhow::Result<()> {
|
||||
let message = Message {
|
||||
id: "test_id".to_string(),
|
||||
created_at: Utc::now(),
|
||||
@@ -154,12 +151,13 @@ mod tests {
|
||||
references: None,
|
||||
};
|
||||
|
||||
assert_eq!(format!("{}", message), "User: Hello world");
|
||||
assert_eq!(format!("{message}"), "User: Hello world");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_format_history() {
|
||||
// Create a vector of messages
|
||||
async fn test_format_history() -> anyhow::Result<()> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
id: "1".to_string(),
|
||||
@@ -181,10 +179,10 @@ mod tests {
|
||||
},
|
||||
];
|
||||
|
||||
// Format the history
|
||||
let formatted = format_history(&messages);
|
||||
|
||||
// Verify the formatting
|
||||
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,20 +216,22 @@ impl Scratchpad {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_scratchpad() {
|
||||
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create a new scratchpad
|
||||
let user_id = "test_user";
|
||||
@@ -254,29 +256,28 @@ mod tests {
|
||||
let retrieved: Option<Scratchpad> = db
|
||||
.get_item(&scratchpad.id)
|
||||
.await
|
||||
.expect("Failed to retrieve scratchpad");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
.with_context(|| "Failed to retrieve scratchpad".to_string())?;
|
||||
let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?;
|
||||
assert_eq!(retrieved.id, scratchpad.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
assert!(!retrieved.is_archived);
|
||||
assert!(retrieved.archived_at.is_none());
|
||||
assert!(retrieved.ingested_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_user() {
|
||||
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
|
||||
@@ -288,19 +289,21 @@ mod tests {
|
||||
// Store them
|
||||
let scratchpad1_id = scratchpad1.id.clone();
|
||||
let scratchpad2_id = scratchpad2.id.clone();
|
||||
db.store_item(scratchpad1).await.unwrap();
|
||||
db.store_item(scratchpad2).await.unwrap();
|
||||
db.store_item(scratchpad3).await.unwrap();
|
||||
db.store_item(scratchpad1).await.with_context(|| "store scratchpad1".to_string())?;
|
||||
db.store_item(scratchpad2).await.with_context(|| "store scratchpad2".to_string())?;
|
||||
db.store_item(scratchpad3).await.with_context(|| "store scratchpad3".to_string())?;
|
||||
|
||||
// Archive one of the user's scratchpads
|
||||
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "archive".to_string())?;
|
||||
|
||||
// Get scratchpads for user_id
|
||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap();
|
||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db)
|
||||
.await
|
||||
.with_context(|| "get_by_user".to_string())?;
|
||||
assert_eq!(user_scratchpads.len(), 1);
|
||||
assert_eq!(user_scratchpads[0].id, scratchpad1_id);
|
||||
assert_eq!(user_scratchpads.first().map(|s| &s.id), Some(&scratchpad1_id));
|
||||
|
||||
// Verify they belong to the user
|
||||
for scratchpad in &user_scratchpads {
|
||||
@@ -309,177 +312,183 @@ mod tests {
|
||||
|
||||
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get_archived_by_user".to_string())?;
|
||||
assert_eq!(archived.len(), 1);
|
||||
assert_eq!(archived[0].id, scratchpad2_id);
|
||||
assert!(archived[0].is_archived);
|
||||
assert!(archived[0].ingested_at.is_none());
|
||||
assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id));
|
||||
assert!(archived.first().is_some_and(|s| s.is_archived));
|
||||
assert!(archived.first().is_some_and(|s| s.ingested_at.is_none()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_archive_and_restore() {
|
||||
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
||||
.await
|
||||
.expect("Failed to archive");
|
||||
.with_context(|| "Failed to archive".to_string())?;
|
||||
assert!(archived.is_archived);
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.ingested_at.is_some());
|
||||
|
||||
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.expect("Failed to restore");
|
||||
.with_context(|| "Failed to restore".to_string())?;
|
||||
assert!(!restored.is_archived);
|
||||
assert!(restored.archived_at.is_none());
|
||||
assert!(restored.ingested_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content() {
|
||||
async fn test_update_content() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let new_content = "Updated content";
|
||||
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "update_content".to_string())?;
|
||||
|
||||
assert_eq!(updated.content, new_content);
|
||||
assert!(!updated.is_dirty);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content_unauthorized() {
|
||||
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_scratchpad() {
|
||||
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
// Delete should succeed
|
||||
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it's gone
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
|
||||
assert!(retrieved.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unauthorized() {
|
||||
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
_ => anyhow::bail!("Expected Auth error"),
|
||||
}
|
||||
|
||||
// Verify it still exists
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.with_context(|| "get_item".to_string())?;
|
||||
assert!(retrieved.is_some());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_aware_scratchpad_conversion() {
|
||||
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to create test database");
|
||||
.with_context(|| "Failed to create test database".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let user_id = "test_user_123";
|
||||
let scratchpad =
|
||||
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
db.store_item(scratchpad).await.with_context(|| "store scratchpad".to_string())?;
|
||||
|
||||
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "get_by_id".to_string())?;
|
||||
|
||||
// Test that datetime fields are preserved and can be used for timezone formatting
|
||||
assert!(retrieved.created_at.timestamp() > 0);
|
||||
@@ -493,10 +502,11 @@ mod tests {
|
||||
// Archive the scratchpad to test optional datetime handling
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "archive".to_string())?;
|
||||
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.archived_at.unwrap().timestamp() > 0);
|
||||
assert!(archived.archived_at.with_context(|| "expected archived_at".to_string())?.timestamp() > 0);
|
||||
assert!(archived.ingested_at.is_none());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +64,14 @@ impl SystemSettings {
|
||||
let mut needs_update = false;
|
||||
|
||||
let backend_label = provider.backend_label().to_string();
|
||||
let provider_dimensions = provider.dimension() as u32;
|
||||
let provider_dimensions = u32::try_from(provider.dimension())
|
||||
.unwrap_or_else(|_| {
|
||||
tracing::warn!(
|
||||
"Provider dimension {} exceeds u32 max; falling back to 0",
|
||||
provider.dimension()
|
||||
);
|
||||
0u32
|
||||
});
|
||||
let provider_model = provider.model_code();
|
||||
|
||||
// Sync backend label
|
||||
@@ -107,7 +114,8 @@ impl SystemSettings {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::indexes::ensure_runtime_indexes;
|
||||
use anyhow::{self, Context};
|
||||
use crate::storage::indexes::ensure_runtime;
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||
use async_openai::Client;
|
||||
|
||||
@@ -118,68 +126,102 @@ mod tests {
|
||||
db: &SurrealDbClient,
|
||||
table_name: &str,
|
||||
index_name: &str,
|
||||
) -> u32 {
|
||||
) -> anyhow::Result<u32> {
|
||||
let query = format!("INFO FOR TABLE {table_name};");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Failed to fetch table info");
|
||||
.with_context(|| "Failed to fetch table info".to_string())?;
|
||||
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
.expect("Failed to extract table info response");
|
||||
.with_context(|| "Failed to extract table info response".to_string())?;
|
||||
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("Failed to convert info to json");
|
||||
serde_json::to_value(info).with_context(|| "Failed to convert info to json".to_string())?;
|
||||
|
||||
let indexes = info_json["Object"]["indexes"]["Object"]
|
||||
.as_object()
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
|
||||
let indexes = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.as_object())
|
||||
.with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?;
|
||||
|
||||
let definition = indexes
|
||||
.get(index_name)
|
||||
.and_then(|definition| definition.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
|
||||
.with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?;
|
||||
|
||||
let dimension_part = definition
|
||||
.split("DIMENSION")
|
||||
.nth(1)
|
||||
.expect("Index definition missing DIMENSION clause");
|
||||
.with_context(|| "Index definition missing DIMENSION clause".to_string())?;
|
||||
|
||||
let dimension_token = dimension_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.expect("Dimension value missing in definition")
|
||||
.with_context(|| "Dimension value missing in definition".to_string())?
|
||||
.trim_end_matches(';');
|
||||
|
||||
dimension_token
|
||||
.parse::<u32>()
|
||||
.expect("Dimension value is not a valid number")
|
||||
.with_context(|| "Dimension value is not a valid number".to_string())
|
||||
}
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) -> anyhow::Result<()> {
|
||||
db.query(
|
||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||
)
|
||||
.await
|
||||
.with_context(|| "remove index".to_string())?;
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};"
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.with_context(|| "Re-defining index should succeed".to_string())?;
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||
|
||||
db.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await
|
||||
.with_context(|| "upsert embedding".to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Test initialization of system settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get system settings");
|
||||
.with_context(|| "Failed to get system settings".to_string())?;
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert!(settings.registrations_enabled);
|
||||
assert!(!settings.require_email_verification);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
||||
@@ -196,10 +238,10 @@ mod tests {
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let settings_again = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get settings after initialization");
|
||||
.with_context(|| "Failed to get settings after initialization".to_string())?;
|
||||
|
||||
assert_eq!(settings.id, settings_again.id);
|
||||
assert_eq!(
|
||||
@@ -210,48 +252,52 @@ mod tests {
|
||||
settings.require_email_verification,
|
||||
settings_again.require_email_verification
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_settings() {
|
||||
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings");
|
||||
.with_context(|| "Failed to get current settings".to_string())?;
|
||||
|
||||
assert_eq!(settings.id, "current");
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert!(settings.registrations_enabled);
|
||||
assert!(!settings.require_email_verification);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_settings() {
|
||||
async fn test_update_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
|
||||
let mut updated_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.with_context(|| "get_current".to_string())?;
|
||||
updated_settings.id = "current".to_string();
|
||||
updated_settings.registrations_enabled = false;
|
||||
updated_settings.require_email_verification = true;
|
||||
@@ -260,31 +306,32 @@ mod tests {
|
||||
// Test update method
|
||||
let result = SystemSettings::update(&db, updated_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
.with_context(|| "Failed to update settings".to_string())?;
|
||||
|
||||
assert_eq!(result.id, "current");
|
||||
assert_eq!(result.registrations_enabled, false);
|
||||
assert_eq!(result.require_email_verification, true);
|
||||
assert!(!result.registrations_enabled);
|
||||
assert!(result.require_email_verification);
|
||||
assert_eq!(result.query_model, "gpt-4");
|
||||
|
||||
// Verify changes persisted by getting current settings
|
||||
let current = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get current settings after update");
|
||||
.with_context(|| "Failed to get current settings after update".to_string())?;
|
||||
|
||||
assert_eq!(current.registrations_enabled, false);
|
||||
assert_eq!(current.require_email_verification, true);
|
||||
assert!(!current.registrations_enabled);
|
||||
assert!(current.require_email_verification);
|
||||
assert_eq!(current.query_model, "gpt-4");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() {
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Don't initialize settings and try to get them
|
||||
let result = SystemSettings::get_current(&db).await;
|
||||
@@ -294,21 +341,22 @@ mod tests {
|
||||
Err(AppError::NotFound(_)) => {
|
||||
// Expected error
|
||||
}
|
||||
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
|
||||
Ok(_) => panic!("Expected error but got Ok"),
|
||||
Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"),
|
||||
Ok(_) => anyhow::bail!("Expected error but got Ok"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_after_changing_embedding_length() {
|
||||
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
@@ -318,43 +366,11 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||
.await
|
||||
.expect("Failed to store initial chunk with embedding");
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) {
|
||||
db.query(
|
||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
target_dimension
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.expect("Re-defining index should succeed");
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||
|
||||
let update_result = db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await;
|
||||
|
||||
assert!(update_result.is_ok());
|
||||
}
|
||||
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||
|
||||
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
||||
let target_dimension = 1536usize;
|
||||
simulate_reembedding(&db, target_dimension, initial_chunk).await;
|
||||
simulate_reembedding(&db, target_dimension, initial_chunk).await?;
|
||||
|
||||
let migration_result = db.apply_migrations().await;
|
||||
|
||||
@@ -363,34 +379,35 @@ mod tests {
|
||||
"Migrations should not fail: {:?}",
|
||||
migration_result.err()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to load current settings");
|
||||
.with_context(|| "Failed to load current settings".to_string())?;
|
||||
|
||||
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
|
||||
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize)
|
||||
ensure_runtime(&db, current_settings.embedding_dimensions as usize)
|
||||
.await
|
||||
.expect("failed to build runtime indexes");
|
||||
.with_context(|| "failed to build runtime indexes".to_string())?;
|
||||
|
||||
let initial_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||
@@ -405,7 +422,7 @@ mod tests {
|
||||
|
||||
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
.with_context(|| "Failed to update settings".to_string())?;
|
||||
|
||||
assert_eq!(
|
||||
updated_settings.embedding_dimensions, new_dimension,
|
||||
@@ -416,23 +433,23 @@ mod tests {
|
||||
|
||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("TextChunk re-embedding should succeed on fresh DB");
|
||||
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
|
||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
||||
.with_context(|| "KnowledgeEntity re-embedding should succeed on fresh DB".to_string())?;
|
||||
|
||||
let text_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
let knowledge_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"knowledge_entity_embedding",
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
text_chunk_dimension, new_dimension,
|
||||
@@ -445,10 +462,11 @@ mod tests {
|
||||
|
||||
let persisted_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to reload updated settings");
|
||||
.with_context(|| "Failed to reload updated settings".to_string())?;
|
||||
assert_eq!(
|
||||
persisted_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should persist new embedding dimension"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
@@ -237,10 +237,7 @@ impl TextChunk {
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all text chunks. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
info!("Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}");
|
||||
|
||||
// Fetch all chunks first
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
@@ -252,7 +249,7 @@ impl TextChunk {
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
info!("Found {total_chunks} chunks to process.");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
@@ -276,7 +273,7 @@ impl TextChunk {
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
@@ -300,6 +297,7 @@ impl TextChunk {
|
||||
.join(",")
|
||||
);
|
||||
// Use the chunk id as the embedding record id to keep a 1:1 mapping
|
||||
let embedding = embedding_str;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
|
||||
@@ -309,18 +307,13 @@ impl TextChunk {
|
||||
user_id = '{user_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
@@ -377,7 +370,7 @@ impl TextChunk {
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
@@ -422,6 +415,7 @@ impl TextChunk {
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
let embedding = embedding_str;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
|
||||
@@ -431,18 +425,13 @@ impl TextChunk {
|
||||
user_id = '{user_id}', \
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
@@ -462,20 +451,21 @@ impl TextChunk {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
|
||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
@@ -484,12 +474,13 @@ mod tests {
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}"));
|
||||
.with_context(|| format!("define chunk fts index fallback: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() {
|
||||
async fn test_text_chunk_creation() -> anyhow::Result<()> {
|
||||
let source_id = "source123".to_string();
|
||||
let chunk = "This is a text chunk for testing embeddings".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
@@ -500,22 +491,23 @@ mod tests {
|
||||
assert_eq!(text_chunk.chunk, chunk);
|
||||
assert_eq!(text_chunk.user_id, user_id);
|
||||
assert!(!text_chunk.id.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let source_id = "source123".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
@@ -535,61 +527,63 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
TextChunk::store_with_embedding(
|
||||
different_chunk.clone(),
|
||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("store different chunk");
|
||||
.with_context(|| "store different chunk".to_string())?;
|
||||
|
||||
TextChunk::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete chunks by source_id");
|
||||
.with_context(|| "Failed to delete chunks by source_id".to_string())?;
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
"SELECT * FROM {} WHERE source_id = '{source_id}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(remaining.len(), 0);
|
||||
|
||||
let different_remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
"SELECT * FROM {} WHERE source_id = 'different_source'",
|
||||
TextChunk::table_name(),
|
||||
"different_source"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(different_remaining.len(), 1);
|
||||
assert_eq!(different_remaining[0].id, different_chunk.id);
|
||||
assert_eq!(
|
||||
different_remaining.first().map(|r| &r.id),
|
||||
Some(&different_chunk.id)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() {
|
||||
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
@@ -600,24 +594,24 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk");
|
||||
.with_context(|| "store chunk".to_string())?;
|
||||
|
||||
TextChunk::delete_by_source_id("nonexistent_source", &db)
|
||||
.await
|
||||
.expect("Delete should succeed");
|
||||
.with_context(|| "Delete should succeed".to_string())?;
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
"SELECT * FROM {} WHERE source_id = '{real_source_id}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.with_context(|| "Query failed".to_string())?
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
.with_context(|| "Failed to get query results".to_string())?;
|
||||
assert_eq!(remaining.len(), 1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -672,13 +666,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() {
|
||||
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let source_id = "store-src".to_string();
|
||||
let user_id = "user_store".to_string();
|
||||
@@ -686,43 +680,43 @@ mod tests {
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some());
|
||||
let stored_chunk = stored_chunk.unwrap();
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
|
||||
.await
|
||||
.with_context(|| "get_item".to_string())?;
|
||||
let stored_chunk = stored_chunk.with_context(|| "expected stored chunk".to_string())?;
|
||||
assert_eq!(stored_chunk.source_id, source_id);
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some());
|
||||
let embedding = embedding.unwrap();
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "expected embedding".to_string())?;
|
||||
assert_eq!(embedding.chunk_id, rid);
|
||||
assert_eq!(embedding.user_id, user_id);
|
||||
assert_eq!(embedding.source_id, source_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_with_runtime_indexes() {
|
||||
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_runtime";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
// Ensure runtime indexes are built with the expected dimension.
|
||||
let embedding_dimension = 3usize;
|
||||
ensure_runtime_indexes(&db, embedding_dimension)
|
||||
ensure_runtime(&db, embedding_dimension)
|
||||
.await
|
||||
.expect("ensure runtime indexes");
|
||||
.with_context(|| "ensure runtime indexes".to_string())?;
|
||||
|
||||
let chunk = TextChunk::new(
|
||||
"runtime_src".to_string(),
|
||||
@@ -732,55 +726,60 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some(), "chunk should be stored");
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id)
|
||||
.await
|
||||
.with_context(|| "get_item".to_string())?;
|
||||
let stored_chunk = stored_chunk.with_context(|| "chunk should be stored".to_string())?;
|
||||
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some(), "embedding should exist");
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "embedding should exist".to_string())?;
|
||||
assert_eq!(
|
||||
embedding.unwrap().embedding.len(),
|
||||
embedding.embedding.len(),
|
||||
embedding_dimension,
|
||||
"embedding dimension should match runtime index"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
assert!(results.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let source_id = "src".to_string();
|
||||
let user_id = "user".to_string();
|
||||
@@ -792,32 +791,33 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store");
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
let res = results.first().context("expected first result")?;
|
||||
assert_eq!(res.chunk.id, chunk.id);
|
||||
assert_eq!(res.chunk.source_id, source_id);
|
||||
assert_eq!(res.chunk.chunk, "hello world");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||
@@ -825,49 +825,59 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].chunk.id, chunk2.id);
|
||||
assert_eq!(results[1].chunk.id, chunk1.id);
|
||||
assert!(results[0].score >= results[1].score);
|
||||
assert_eq!(
|
||||
results.first().map(|r| &r.chunk.id),
|
||||
Some(&chunk2.id)
|
||||
);
|
||||
assert_eq!(
|
||||
results.get(1).map(|r| &r.chunk.id),
|
||||
Some(&chunk1.id)
|
||||
);
|
||||
let r0 = results.first().context("expected first result")?;
|
||||
let r1 = results.get(1).context("expected second result")?;
|
||||
assert!(r0.score >= r1.score);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() {
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextChunk::fts_search(5, "hello", &db, "user")
|
||||
.await
|
||||
.expect("fts search");
|
||||
.with_context(|| "fts search".to_string())?;
|
||||
|
||||
assert!(results.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() {
|
||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let chunk = TextChunk::new(
|
||||
@@ -875,27 +885,29 @@ mod tests {
|
||||
"rustaceans love rust".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(chunk.clone()).await.expect("store chunk");
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
db.store_item(chunk.clone()).await.with_context(|| "store chunk".to_string())?;
|
||||
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextChunk::fts_search(3, "rust", &db, user_id)
|
||||
.await
|
||||
.expect("fts search");
|
||||
.with_context(|| "fts search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].chunk.id, chunk.id);
|
||||
assert!(results[0].score.is_finite(), "expected a finite FTS score");
|
||||
let r0 = results.first().context("expected first result")?;
|
||||
assert_eq!(r0.chunk.id, chunk.id);
|
||||
assert!(r0.score.is_finite(), "expected a finite FTS score");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() {
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations().await.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_chunk = TextChunk::new(
|
||||
@@ -916,18 +928,18 @@ mod tests {
|
||||
|
||||
db.store_item(high_score_chunk.clone())
|
||||
.await
|
||||
.expect("store high score chunk");
|
||||
.with_context(|| "store high score chunk".to_string())?;
|
||||
db.store_item(low_score_chunk.clone())
|
||||
.await
|
||||
.expect("store low score chunk");
|
||||
.with_context(|| "store low score chunk".to_string())?;
|
||||
db.store_item(other_user_chunk)
|
||||
.await
|
||||
.expect("store other user chunk");
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
.with_context(|| "store other user chunk".to_string())?;
|
||||
rebuild(&db).await.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextChunk::fts_search(3, "apple", &db, user_id)
|
||||
.await
|
||||
.expect("fts search");
|
||||
.with_context(|| "fts search".to_string())?;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
|
||||
@@ -936,9 +948,12 @@ mod tests {
|
||||
&& ids.contains(&low_score_chunk.id.as_str()),
|
||||
"expected only the two chunks for the same user"
|
||||
);
|
||||
let r0 = results.first().context("expected first result")?;
|
||||
let r1 = results.get(1).context("expected second result")?;
|
||||
assert!(
|
||||
results[0].score >= results[1].score,
|
||||
r0.score >= r1.score,
|
||||
"expected results ordered by descending score"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,24 +126,26 @@ impl TextChunkEmbedding {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Helper to create an in-memory DB and apply migrations
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Helper: create a text_chunk with a known key, return its RecordId
|
||||
@@ -152,7 +154,7 @@ mod tests {
|
||||
key: &str,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
) -> RecordId {
|
||||
) -> anyhow::Result<RecordId> {
|
||||
let chunk = TextChunk {
|
||||
id: key.to_owned(),
|
||||
created_at: Utc::now(),
|
||||
@@ -164,21 +166,42 @@ mod tests {
|
||||
|
||||
db.store_item(chunk)
|
||||
.await
|
||||
.expect("Failed to create text_chunk");
|
||||
.with_context(|| "Failed to create text_chunk".to_string())?;
|
||||
|
||||
RecordId::from_table_key(TextChunk::table_name(), key)
|
||||
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||
}
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res.take(0).with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_a";
|
||||
let chunk_key = "chunk-123";
|
||||
let source_id = "source-1";
|
||||
|
||||
// 1) Create a text_chunk with a known key
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||
|
||||
// 2) Create and store an embedding for that chunk
|
||||
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
||||
@@ -191,39 +214,37 @@ mod tests {
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
.with_context(|| "Failed to store embedding".to_string())?
|
||||
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
|
||||
|
||||
// 3) Fetch it via get_by_chunk_id
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by chunk_id");
|
||||
|
||||
assert!(fetched.is_some(), "Expected an embedding to be found");
|
||||
let fetched = fetched.unwrap();
|
||||
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.chunk_id, chunk_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_b";
|
||||
let chunk_key = "chunk-delete";
|
||||
let source_id = "source-del";
|
||||
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||
|
||||
let emb = TextChunkEmbedding::new(
|
||||
chunk_key,
|
||||
@@ -234,50 +255,50 @@ mod tests {
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
.with_context(|| "Failed to store embedding".to_string())?
|
||||
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
|
||||
|
||||
// Ensure it exists
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
// Delete by chunk_id
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by chunk_id");
|
||||
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||
|
||||
// Ensure it no longer exists
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_c";
|
||||
let source_id = "shared-source";
|
||||
let other_source = "other-source";
|
||||
|
||||
// Two chunks with the same source_id
|
||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
|
||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
|
||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?;
|
||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?;
|
||||
|
||||
// One chunk with a different source_id
|
||||
let chunk_other_rid =
|
||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
|
||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?;
|
||||
|
||||
// Create embeddings for all three
|
||||
let emb1 = TextChunkEmbedding::new(
|
||||
@@ -302,7 +323,7 @@ mod tests {
|
||||
// Update length on index
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
|
||||
for emb in [emb1, emb2, emb3] {
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
@@ -310,102 +331,82 @@ mod tests {
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
.with_context(|| "Failed to store embedding".to_string())?
|
||||
.with_context(|| "Failed to deserialize stored embedding".to_string())?;
|
||||
}
|
||||
|
||||
// Sanity check: they all exist
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get chunk1".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get chunk2".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "get chunk_other".to_string())?
|
||||
.is_some());
|
||||
|
||||
// Delete embeddings by source_id (shared-source)
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
// Chunks from shared-source should have no embeddings
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "check chunk1".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "check chunk2".to_string())?
|
||||
.is_none());
|
||||
|
||||
// The other chunk should still have its embedding
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_context(|| "check chunk_other".to_string())?
|
||||
.is_some());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Change the index dimension from default (1536) to a smaller test value.
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
.with_context(|| "failed to redefine index".to_string())?;
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
let idx_sql = get_idx_sql(&db).await?;
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 8"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_is_idempotent() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("first redefine failed");
|
||||
.with_context(|| "first redefine failed".to_string())?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("second redefine failed");
|
||||
.with_context(|| "second redefine failed".to_string())?;
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
let idx_sql = get_idx_sql(&db).await?;
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 4"),
|
||||
"expected index definition to retain dimension 4, got: {idx_sql}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,10 +185,12 @@ impl TextContent {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() {
|
||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||
// Test basic object creation
|
||||
let text = "Test content text".to_string();
|
||||
let context = "Test context".to_string();
|
||||
@@ -212,10 +214,11 @@ mod tests {
|
||||
assert!(text_content.file_info.is_none());
|
||||
assert!(text_content.url_info.is_none());
|
||||
assert!(!text_content.id.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_with_url() {
|
||||
async fn test_text_content_with_url() -> anyhow::Result<()> {
|
||||
// Test creating with URL
|
||||
let text = "Content with URL".to_string();
|
||||
let context = "URL context".to_string();
|
||||
@@ -232,26 +235,27 @@ mod tests {
|
||||
});
|
||||
|
||||
let text_content = TextContent::new(
|
||||
text.clone(),
|
||||
Some(context.clone()),
|
||||
category.clone(),
|
||||
text,
|
||||
Some(context),
|
||||
category,
|
||||
None,
|
||||
url_info.clone(),
|
||||
user_id.clone(),
|
||||
user_id,
|
||||
);
|
||||
|
||||
// Check URL field is set
|
||||
assert_eq!(text_content.url_info, url_info);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_patch() {
|
||||
async fn test_text_content_patch() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create initial text content
|
||||
let initial_text = "Initial text".to_string();
|
||||
@@ -272,7 +276,7 @@ mod tests {
|
||||
let stored: Option<TextContent> = db
|
||||
.store_item(text_content.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
.with_context(|| "Failed to store text content".to_string())?;
|
||||
assert!(stored.is_some());
|
||||
|
||||
// New values for patch
|
||||
@@ -283,31 +287,30 @@ mod tests {
|
||||
// Apply the patch
|
||||
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
|
||||
.await
|
||||
.expect("Failed to patch text content");
|
||||
.with_context(|| "Failed to patch text content".to_string())?;
|
||||
|
||||
// Retrieve the updated content
|
||||
let updated: Option<TextContent> = db
|
||||
.get_item(&text_content.id)
|
||||
.await
|
||||
.expect("Failed to get updated text content");
|
||||
assert!(updated.is_some());
|
||||
|
||||
let updated_content = updated.unwrap();
|
||||
.with_context(|| "Failed to get updated text content".to_string())?;
|
||||
let updated_content = updated.with_context(|| "expected updated content".to_string())?;
|
||||
|
||||
// Verify the updates
|
||||
assert_eq!(updated_content.context, Some(new_context.to_string()));
|
||||
assert_eq!(updated_content.category, new_category);
|
||||
assert_eq!(updated_content.text, new_text);
|
||||
assert!(updated_content.updated_at > text_content.updated_at);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_has_other_with_file_detects_shared_usage() {
|
||||
async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let file_info = FileInfo {
|
||||
@@ -340,24 +343,25 @@ mod tests {
|
||||
|
||||
db.store_item(content_a.clone())
|
||||
.await
|
||||
.expect("Failed to store first content");
|
||||
.with_context(|| "Failed to store first content".to_string())?;
|
||||
db.store_item(content_b.clone())
|
||||
.await
|
||||
.expect("Failed to store second content");
|
||||
.with_context(|| "Failed to store second content".to_string())?;
|
||||
|
||||
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check for shared file usage");
|
||||
.with_context(|| "Failed to check for shared file usage".to_string())?;
|
||||
assert!(has_other);
|
||||
|
||||
let _removed: Option<TextContent> = db
|
||||
.delete_item(&content_b.id)
|
||||
.await
|
||||
.expect("Failed to delete second content");
|
||||
.with_context(|| "Failed to delete second content".to_string())?;
|
||||
|
||||
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check shared usage after delete");
|
||||
.with_context(|| "Failed to check shared usage after delete".to_string())?;
|
||||
assert!(!has_other_after);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -723,30 +723,32 @@ impl User {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to setup the migrations");
|
||||
.with_context(|| "Failed to setup the migrations".to_string())?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_creation() {
|
||||
async fn test_user_creation() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "test@example.com";
|
||||
@@ -761,7 +763,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Verify user properties
|
||||
assert!(!user.id.is_empty());
|
||||
@@ -774,18 +776,17 @@ mod tests {
|
||||
let retrieved: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let retrieved = retrieved.with_context(|| "expected user to exist".to_string())?;
|
||||
assert_eq!(retrieved.id, user.id);
|
||||
assert_eq!(retrieved.email, email);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_authentication() {
|
||||
async fn test_user_authentication() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "auth_test@example.com";
|
||||
@@ -799,7 +800,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Test successful authentication
|
||||
let auth_result = User::authenticate(email, password, &db).await;
|
||||
@@ -812,11 +813,12 @@ mod tests {
|
||||
// Test failed authentication with non-existent user
|
||||
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
|
||||
assert!(nonexistent.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "unfinished_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
@@ -830,14 +832,14 @@ mod tests {
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(created_task.clone())
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
.with_context(|| "Failed to store created task".to_string())?;
|
||||
|
||||
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
processing_task.state = TaskState::Processing;
|
||||
processing_task.attempts = 1;
|
||||
db.store_item(processing_task.clone())
|
||||
.await
|
||||
.expect("Failed to store processing task");
|
||||
.with_context(|| "Failed to store processing task".to_string())?;
|
||||
|
||||
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_retry_task.state = TaskState::Failed;
|
||||
@@ -845,7 +847,7 @@ mod tests {
|
||||
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
||||
db.store_item(failed_retry_task.clone())
|
||||
.await
|
||||
.expect("Failed to store retryable failed task");
|
||||
.with_context(|| "Failed to store retryable failed task".to_string())?;
|
||||
|
||||
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_blocked_task.state = TaskState::Failed;
|
||||
@@ -853,13 +855,13 @@ mod tests {
|
||||
failed_blocked_task.error_message = Some("Too many failures".into());
|
||||
db.store_item(failed_blocked_task.clone())
|
||||
.await
|
||||
.expect("Failed to store blocked task");
|
||||
.with_context(|| "Failed to store blocked task".to_string())?;
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
completed_task.state = TaskState::Succeeded;
|
||||
db.store_item(completed_task.clone())
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
.with_context(|| "Failed to store completed task".to_string())?;
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
@@ -870,11 +872,11 @@ mod tests {
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task)
|
||||
.await
|
||||
.expect("Failed to store other user task");
|
||||
.with_context(|| "Failed to store other user task".to_string())?;
|
||||
|
||||
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch unfinished tasks");
|
||||
.with_context(|| "Failed to fetch unfinished tasks".to_string())?;
|
||||
|
||||
let unfinished_ids: HashSet<String> =
|
||||
unfinished.iter().map(|task| task.id.clone()).collect();
|
||||
@@ -885,11 +887,12 @@ mod tests {
|
||||
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
||||
assert!(!unfinished_ids.contains(&completed_task.id));
|
||||
assert_eq!(unfinished_ids.len(), 3);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "archive_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
@@ -902,15 +905,15 @@ mod tests {
|
||||
|
||||
// Oldest task
|
||||
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
first.created_at = first.created_at - chrono::Duration::minutes(1);
|
||||
first.created_at -= chrono::Duration::minutes(1);
|
||||
first.updated_at = first.created_at;
|
||||
first.state = TaskState::Succeeded;
|
||||
db.store_item(first.clone()).await.expect("store first");
|
||||
db.store_item(first.clone()).await.with_context(|| "store first".to_string())?;
|
||||
|
||||
// Latest task
|
||||
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
second.state = TaskState::Processing;
|
||||
db.store_item(second.clone()).await.expect("store second");
|
||||
db.store_item(second.clone()).await.with_context(|| "store second".to_string())?;
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
@@ -919,21 +922,22 @@ mod tests {
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task).await.expect("store other");
|
||||
db.store_item(other_task).await.with_context(|| "store other".to_string())?;
|
||||
|
||||
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("fetch all tasks");
|
||||
.with_context(|| "fetch all tasks".to_string())?;
|
||||
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0].id, second.id); // newest first
|
||||
assert_eq!(tasks[1].id, first.id);
|
||||
assert_eq!(tasks.first().map(|t| &t.id), Some(&second.id)); // newest first
|
||||
assert_eq!(tasks.get(1).map(|t| &t.id), Some(&first.id));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_by_email() {
|
||||
async fn test_find_by_email() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "find_test@example.com";
|
||||
@@ -947,28 +951,28 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Test finding user by email
|
||||
let found_user = User::find_by_email(email, &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
.with_context(|| "Error searching for user".to_string())?
|
||||
.with_context(|| "expected user to exist".to_string())?;
|
||||
assert_eq!(found_user.id, created_user.id);
|
||||
assert_eq!(found_user.email, email);
|
||||
|
||||
// Test finding non-existent user
|
||||
let not_found = User::find_by_email("nonexistent@example.com", &db)
|
||||
.await
|
||||
.expect("Error searching for user");
|
||||
.with_context(|| "Error searching for user".to_string())?;
|
||||
assert!(not_found.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_key_management() {
|
||||
async fn test_api_key_management() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "apikey_test@example.com";
|
||||
@@ -982,7 +986,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Initially, user should have no API key
|
||||
assert!(user.api_key.is_none());
|
||||
@@ -990,7 +994,7 @@ mod tests {
|
||||
// Generate API key
|
||||
let api_key = User::set_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to set API key");
|
||||
.with_context(|| "Failed to set API key".to_string())?;
|
||||
assert!(!api_key.is_empty());
|
||||
assert!(api_key.starts_with("sk_"));
|
||||
|
||||
@@ -998,38 +1002,36 @@ mod tests {
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
|
||||
assert_eq!(updated_user.api_key, Some(api_key.clone()));
|
||||
|
||||
// Test finding user by API key
|
||||
let found_user = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
assert!(found_user.is_some());
|
||||
let found_user = found_user.unwrap();
|
||||
.with_context(|| "Error searching by API key".to_string())?
|
||||
.with_context(|| "expected user found by api key".to_string())?;
|
||||
assert_eq!(found_user.id, user.id);
|
||||
|
||||
// Revoke API key
|
||||
User::revoke_api_key(&user.id, &db)
|
||||
.await
|
||||
.expect("Failed to revoke API key");
|
||||
.with_context(|| "Failed to revoke API key".to_string())?;
|
||||
|
||||
// Verify API key was revoked
|
||||
let revoked_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(revoked_user.is_some());
|
||||
let revoked_user = revoked_user.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let revoked_user = revoked_user.with_context(|| "expected revoked user".to_string())?;
|
||||
assert!(revoked_user.api_key.is_none());
|
||||
|
||||
// Test searching by revoked API key
|
||||
let not_found = User::find_by_api_key(&api_key, &db)
|
||||
.await
|
||||
.expect("Error searching by API key");
|
||||
.with_context(|| "Error searching by API key".to_string())?;
|
||||
assert!(not_found.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1069,9 +1071,9 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_password_update() {
|
||||
async fn test_password_update() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a user
|
||||
let email = "pwd_test@example.com";
|
||||
@@ -1086,7 +1088,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
// Authenticate with old password
|
||||
let auth_result = User::authenticate(email, old_password, &db).await;
|
||||
@@ -1095,7 +1097,7 @@ mod tests {
|
||||
// Update password
|
||||
User::patch_password(email, new_password, &db)
|
||||
.await
|
||||
.expect("Failed to update password");
|
||||
.with_context(|| "Failed to update password".to_string())?;
|
||||
|
||||
// Old password should no longer work
|
||||
let old_auth = User::authenticate(email, old_password, &db).await;
|
||||
@@ -1104,10 +1106,11 @@ mod tests {
|
||||
// New password should work
|
||||
let new_auth = User::authenticate(email, new_password, &db).await;
|
||||
assert!(new_auth.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_timezone() {
|
||||
async fn test_validate_timezone() -> anyhow::Result<()> {
|
||||
// Valid timezones should be accepted as-is
|
||||
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
|
||||
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
|
||||
@@ -1117,12 +1120,13 @@ mod tests {
|
||||
// Invalid timezones should be replaced with UTC
|
||||
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
|
||||
assert_eq!(validate_timezone("Not_Real"), "UTC");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_update() {
|
||||
async fn test_timezone_update() -> anyhow::Result<()> {
|
||||
// Setup test database
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create user with default timezone
|
||||
let email = "timezone_test@example.com";
|
||||
@@ -1134,7 +1138,7 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
assert_eq!(user.timezone, "UTC");
|
||||
|
||||
@@ -1142,58 +1146,61 @@ mod tests {
|
||||
let new_timezone = "Europe/Paris";
|
||||
User::update_timezone(&user.id, new_timezone, &db)
|
||||
.await
|
||||
.expect("Failed to update timezone");
|
||||
.with_context(|| "Failed to update timezone".to_string())?;
|
||||
|
||||
// Verify timezone was updated
|
||||
let updated_user: Option<User> = db
|
||||
.get_item(&user.id)
|
||||
.await
|
||||
.expect("Failed to retrieve user");
|
||||
assert!(updated_user.is_some());
|
||||
let updated_user = updated_user.unwrap();
|
||||
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
|
||||
assert_eq!(updated_user.timezone, new_timezone);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_conversations_order() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_conversations_order() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user_order_test";
|
||||
|
||||
// Create conversations with varying updated_at timestamps
|
||||
let mut conversations = Vec::new();
|
||||
for i in 0..5 {
|
||||
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {}", i));
|
||||
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {i}"));
|
||||
// Fake updated_at i minutes apart
|
||||
conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i);
|
||||
db.store_item(conv.clone())
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
.with_context(|| "Failed to store conversation".to_string())?;
|
||||
conversations.push(conv);
|
||||
}
|
||||
|
||||
// Retrieve via get_user_conversations - should be ordered by updated_at DESC
|
||||
let retrieved = User::get_user_conversations(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to get conversations");
|
||||
.with_context(|| "Failed to get conversations".to_string())?;
|
||||
|
||||
assert_eq!(retrieved.len(), conversations.len());
|
||||
|
||||
for window in retrieved.windows(2) {
|
||||
// Assert each earlier conversation has updated_at >= later conversation
|
||||
for pair in retrieved.windows(2) {
|
||||
let a = pair.first().context("expected first in pair")?;
|
||||
let b = pair.get(1).context("expected second in pair")?;
|
||||
assert!(
|
||||
window[0].created_at >= window[1].created_at,
|
||||
a.created_at >= b.created_at,
|
||||
"Conversations not ordered descending by created_at"
|
||||
);
|
||||
}
|
||||
|
||||
// Check first conversation title matches the most recently updated
|
||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap();
|
||||
assert_eq!(retrieved[0].id, most_recent.id);
|
||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).context("expected most recent")?;
|
||||
let r0 = retrieved.first().context("expected first result")?;
|
||||
assert_eq!(r0.id, most_recent.id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_latest_text_contents_returns_last_five() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_get_latest_text_contents_returns_last_five() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "latest_text_user";
|
||||
|
||||
let mut inserted_ids = Vec::new();
|
||||
@@ -1201,8 +1208,8 @@ mod tests {
|
||||
|
||||
for i in 0..12 {
|
||||
let mut item = TextContent::new(
|
||||
format!("Text {}", i),
|
||||
Some(format!("Context {}", i)),
|
||||
format!("Text {i}"),
|
||||
Some(format!("Context {i}")),
|
||||
"Category".to_string(),
|
||||
None,
|
||||
None,
|
||||
@@ -1215,18 +1222,19 @@ mod tests {
|
||||
|
||||
db.store_item(item.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
.with_context(|| "Failed to store text content".to_string())?;
|
||||
|
||||
inserted_ids.push(item.id.clone());
|
||||
}
|
||||
|
||||
let latest = User::get_latest_text_contents(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch latest text contents");
|
||||
.with_context(|| "Failed to fetch latest text contents".to_string())?;
|
||||
|
||||
assert_eq!(latest.len(), 5, "Expected exactly five items");
|
||||
|
||||
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec();
|
||||
let start = inserted_ids.len().saturating_sub(5);
|
||||
let mut expected_ids = inserted_ids.get(start..).unwrap_or_default().to_vec();
|
||||
expected_ids.reverse();
|
||||
|
||||
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
|
||||
@@ -1235,25 +1243,29 @@ mod tests {
|
||||
"Latest items did not match expectation"
|
||||
);
|
||||
|
||||
for window in latest.windows(2) {
|
||||
for pair in latest.windows(2) {
|
||||
let a = pair.first().context("expected first in pair")?;
|
||||
let b = pair.get(1).context("expected second in pair")?;
|
||||
assert!(
|
||||
window[0].created_at >= window[1].created_at,
|
||||
a.created_at >= b.created_at,
|
||||
"Results are not ordered by created_at descending"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validate_theme() {
|
||||
async fn test_validate_theme() -> anyhow::Result<()> {
|
||||
assert_eq!(validate_theme("light"), Theme::Light);
|
||||
assert_eq!(validate_theme("dark"), Theme::Dark);
|
||||
assert_eq!(validate_theme("system"), Theme::System);
|
||||
assert_eq!(validate_theme("invalid"), Theme::System);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_theme_update() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_theme_update() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let email = "theme_test@example.com";
|
||||
let user = User::create_new(
|
||||
email.to_string(),
|
||||
@@ -1263,30 +1275,31 @@ mod tests {
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create user");
|
||||
.with_context(|| "Failed to create user".to_string())?;
|
||||
|
||||
assert_eq!(user.theme, Theme::System);
|
||||
|
||||
User::update_theme(&user.id, "dark", &db)
|
||||
.await
|
||||
.expect("update theme");
|
||||
.with_context(|| "update theme".to_string())?;
|
||||
|
||||
let updated = db
|
||||
.get_item::<User>(&user.id)
|
||||
.await
|
||||
.expect("get user")
|
||||
.unwrap();
|
||||
.with_context(|| "get user".to_string())?
|
||||
.with_context(|| "expected user".to_string())?;
|
||||
assert_eq!(updated.theme, Theme::Dark);
|
||||
|
||||
// Invalid theme should default to system (but update_theme calls validate_theme)
|
||||
User::update_theme(&user.id, "invalid", &db)
|
||||
.await
|
||||
.expect("update theme invalid");
|
||||
.with_context(|| "update theme invalid".to_string())?;
|
||||
let updated2 = db
|
||||
.get_item::<User>(&user.id)
|
||||
.await
|
||||
.expect("get user")
|
||||
.unwrap();
|
||||
.with_context(|| "get user".to_string())?
|
||||
.with_context(|| "expected user".to_string())?;
|
||||
assert_eq!(updated2.theme, Theme::System);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ fn default_storage_kind() -> StorageKind {
|
||||
StorageKind::Local
|
||||
}
|
||||
|
||||
fn default_s3_region() -> Option<String> {
|
||||
Some("us-east-1".to_string())
|
||||
fn default_s3_region() -> String {
|
||||
"us-east-1".to_string()
|
||||
}
|
||||
|
||||
/// Selects the strategy used for PDF ingestion.
|
||||
@@ -69,7 +69,7 @@ pub struct AppConfig {
|
||||
#[serde(default)]
|
||||
pub s3_endpoint: Option<String>,
|
||||
#[serde(default = "default_s3_region")]
|
||||
pub s3_region: Option<String>,
|
||||
pub s3_region: String,
|
||||
#[serde(default = "default_pdf_ingest_mode")]
|
||||
pub pdf_ingest_mode: PdfIngestMode,
|
||||
#[serde(default = "default_reranking_enabled")]
|
||||
|
||||
@@ -14,7 +14,7 @@ use common::utils::config::get_config;
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
store::{DynStore, StorageManager},
|
||||
store::{DynStorage, StorageManager},
|
||||
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
|
||||
},
|
||||
utils::config::{AppConfig, StorageKind},
|
||||
@@ -432,7 +432,7 @@ async fn ingest_paragraph_batch(
|
||||
storage: StorageKind::Memory,
|
||||
..Default::default()
|
||||
};
|
||||
let backend: DynStore = Arc::new(InMemory::new());
|
||||
let backend: DynStorage = Arc::new(InMemory::new());
|
||||
let storage = StorageManager::with_backend(backend, StorageKind::Memory);
|
||||
|
||||
let pipeline_config = ingestion_config.clone();
|
||||
|
||||
@@ -861,7 +861,7 @@ mod tests {
|
||||
let question = CorpusQuestion {
|
||||
question_id: "q1".to_string(),
|
||||
paragraph_id: paragraph_one.paragraph_id.clone(),
|
||||
text_content_id: text_content_id,
|
||||
text_content_id,
|
||||
question_text: "What is this?".to_string(),
|
||||
answers: vec!["Hello".to_string()],
|
||||
is_impossible: false,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use anyhow::{Context, Result};
|
||||
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes};
|
||||
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime};
|
||||
use tracing::info;
|
||||
|
||||
// Helper functions for index management during namespace reseed
|
||||
@@ -11,7 +11,7 @@ pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
|
||||
|
||||
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
|
||||
ensure_runtime_indexes(db, dimension)
|
||||
ensure_runtime(db, dimension)
|
||||
.await
|
||||
.context("creating runtime indexes")
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use common::{
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{PipelineStageTimings, RetrievalConfig},
|
||||
pipeline::{StageTimings, RetrievalConfig},
|
||||
reranking::RerankerPool,
|
||||
};
|
||||
|
||||
@@ -56,7 +56,7 @@ pub(super) struct EvaluationContext<'a> {
|
||||
pub corpus_handle: Option<corpus::CorpusHandle>,
|
||||
pub cases: Vec<SeededCase>,
|
||||
pub filtered_questions: usize,
|
||||
pub stage_latency_samples: Vec<PipelineStageTimings>,
|
||||
pub stage_latency_samples: Vec<StageTimings>,
|
||||
pub latencies: Vec<u128>,
|
||||
pub diagnostics_output: Vec<CaseDiagnostics>,
|
||||
pub query_summaries: Vec<CaseSummary>,
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::eval::{
|
||||
CaseSummary, RetrievedSummary,
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{self, PipelineStageTimings, RetrievalConfig},
|
||||
pipeline::{self, StageTimings, RetrievalConfig},
|
||||
reranking::RerankerPool,
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
@@ -75,10 +75,10 @@ pub(crate) async fn run_queries(
|
||||
retrieval_config.tuning.chunk_rrf_fts_weight = value;
|
||||
}
|
||||
if let Some(value) = config.retrieval.chunk_rrf_use_vector {
|
||||
retrieval_config.tuning.chunk_rrf_use_vector = value;
|
||||
retrieval_config.tuning.flags.chunk_rrf_use_vector = value.into();
|
||||
}
|
||||
if let Some(value) = config.retrieval.chunk_rrf_use_fts {
|
||||
retrieval_config.tuning.chunk_rrf_use_fts = value;
|
||||
retrieval_config.tuning.flags.chunk_rrf_use_fts = value.into();
|
||||
}
|
||||
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
|
||||
retrieval_config.tuning.avg_chars_per_token = value;
|
||||
@@ -113,8 +113,8 @@ pub(crate) async fn run_queries(
|
||||
chunk_rrf_k = active_tuning.chunk_rrf_k,
|
||||
chunk_rrf_vector_weight = active_tuning.chunk_rrf_vector_weight,
|
||||
chunk_rrf_fts_weight = active_tuning.chunk_rrf_fts_weight,
|
||||
chunk_rrf_use_vector = active_tuning.chunk_rrf_use_vector,
|
||||
chunk_rrf_use_fts = active_tuning.chunk_rrf_use_fts,
|
||||
chunk_rrf_use_vector = active_tuning.flags.chunk_rrf_use_vector.as_bool(),
|
||||
chunk_rrf_use_fts = active_tuning.flags.chunk_rrf_use_fts.as_bool(),
|
||||
embedding_backend = ctx.embedding_provider().backend_label(),
|
||||
embedding_model = ctx
|
||||
.embedding_provider()
|
||||
@@ -181,35 +181,32 @@ pub(crate) async fn run_queries(
|
||||
embedding_provider.embed(&question).await.with_context(|| {
|
||||
format!("generating embedding for question {}", question_id)
|
||||
})?;
|
||||
let reranker = match &rerank_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
let reranker = match rerank_pool.as_ref() {
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: Some(&embedding_provider),
|
||||
input_text: &question,
|
||||
user_id: &user_id,
|
||||
config: (*retrieval_config).clone(),
|
||||
reranker,
|
||||
};
|
||||
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
|
||||
&db,
|
||||
&openai_client,
|
||||
Some(&embedding_provider),
|
||||
params,
|
||||
query_embedding,
|
||||
&question,
|
||||
&user_id,
|
||||
(*retrieval_config).clone(),
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("running pipeline for question {}", question_id))?;
|
||||
(outcome.results, outcome.diagnostics, outcome.stage_timings)
|
||||
} else {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
|
||||
&db,
|
||||
&openai_client,
|
||||
Some(&embedding_provider),
|
||||
params,
|
||||
query_embedding,
|
||||
&question,
|
||||
&user_id,
|
||||
(*retrieval_config).clone(),
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("running pipeline for question {}", question_id))?;
|
||||
@@ -327,7 +324,7 @@ pub(crate) async fn run_queries(
|
||||
usize,
|
||||
CaseSummary,
|
||||
Option<CaseDiagnostics>,
|
||||
PipelineStageTimings,
|
||||
StageTimings,
|
||||
),
|
||||
anyhow::Error,
|
||||
>((idx, summary, diagnostics, stage_timings))
|
||||
|
||||
@@ -205,8 +205,8 @@ pub(crate) async fn summarize(
|
||||
chunk_rrf_k: active_tuning.chunk_rrf_k,
|
||||
chunk_rrf_vector_weight: active_tuning.chunk_rrf_vector_weight,
|
||||
chunk_rrf_fts_weight: active_tuning.chunk_rrf_fts_weight,
|
||||
chunk_rrf_use_vector: active_tuning.chunk_rrf_use_vector,
|
||||
chunk_rrf_use_fts: active_tuning.chunk_rrf_use_fts,
|
||||
chunk_rrf_use_vector: active_tuning.flags.chunk_rrf_use_vector.as_bool(),
|
||||
chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(),
|
||||
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
|
||||
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
|
||||
ingest_chunks_only: config.ingest.ingest_chunks_only,
|
||||
|
||||
+25
-27
@@ -1037,6 +1037,31 @@ fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use crate::args::Config;
|
||||
|
||||
impl<'a> From<&'a Config> for SliceConfig<'a> {
|
||||
fn from(config: &'a Config) -> Self {
|
||||
slice_config_with_limit(config, None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn slice_config_with_limit<'a>(
|
||||
config: &'a Config,
|
||||
limit_override: Option<usize>,
|
||||
) -> SliceConfig<'a> {
|
||||
SliceConfig {
|
||||
cache_dir: config.cache_dir.as_path(),
|
||||
force_convert: config.force_convert,
|
||||
explicit_slice: config.slice.as_deref(),
|
||||
limit: limit_override.or(config.limit),
|
||||
corpus_limit: config.corpus_limit,
|
||||
slice_seed: config.slice_seed,
|
||||
llm_mode: config.llm_mode,
|
||||
negative_multiplier: config.negative_multiplier,
|
||||
require_verified_chunks: config.retrieval.require_verified_chunks,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -1214,30 +1239,3 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Config integration (merged from slice.rs)
|
||||
|
||||
use crate::args::Config;
|
||||
|
||||
impl<'a> From<&'a Config> for SliceConfig<'a> {
|
||||
fn from(config: &'a Config) -> Self {
|
||||
slice_config_with_limit(config, None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn slice_config_with_limit<'a>(
|
||||
config: &'a Config,
|
||||
limit_override: Option<usize>,
|
||||
) -> SliceConfig<'a> {
|
||||
SliceConfig {
|
||||
cache_dir: config.cache_dir.as_path(),
|
||||
force_convert: config.force_convert,
|
||||
explicit_slice: config.slice.as_deref(),
|
||||
limit: limit_override.or(config.limit),
|
||||
corpus_limit: config.corpus_limit,
|
||||
slice_seed: config.slice_seed,
|
||||
llm_mode: config.llm_mode,
|
||||
negative_multiplier: config.negative_multiplier,
|
||||
require_verified_chunks: config.retrieval.require_verified_chunks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,30 +39,33 @@ const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30);
|
||||
const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024;
|
||||
const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64;
|
||||
|
||||
pub struct StateResources {
|
||||
pub db: Arc<SurrealDbClient>,
|
||||
pub openai_client: Arc<OpenAIClientType>,
|
||||
pub session_store: Arc<SessionStoreType>,
|
||||
pub storage: StorageManager,
|
||||
pub config: AppConfig,
|
||||
pub reranker_pool: Option<Arc<RerankerPool>>,
|
||||
pub embedding_provider: Arc<EmbeddingProvider>,
|
||||
pub template_engine: Option<Arc<TemplateEngine>>,
|
||||
}
|
||||
|
||||
impl HtmlState {
|
||||
pub async fn new_with_resources(
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<OpenAIClientType>,
|
||||
session_store: Arc<SessionStoreType>,
|
||||
storage: StorageManager,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
embedding_provider: Arc<EmbeddingProvider>,
|
||||
template_engine: Option<Arc<TemplateEngine>>,
|
||||
) -> Self {
|
||||
let templates =
|
||||
template_engine.unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
|
||||
pub fn new_with_resources(resources: StateResources) -> Self {
|
||||
let templates = resources
|
||||
.template_engine
|
||||
.unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
|
||||
debug!("Template engine configured for html_router.");
|
||||
|
||||
Self {
|
||||
db,
|
||||
openai_client,
|
||||
session_store,
|
||||
db: resources.db,
|
||||
openai_client: resources.openai_client,
|
||||
templates,
|
||||
config,
|
||||
storage,
|
||||
reranker_pool,
|
||||
embedding_provider,
|
||||
session_store: resources.session_store,
|
||||
config: resources.config,
|
||||
storage: resources.storage,
|
||||
reranker_pool: resources.reranker_pool,
|
||||
embedding_provider: resources.embedding_provider,
|
||||
conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
@@ -210,18 +213,16 @@ mod tests {
|
||||
EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"),
|
||||
);
|
||||
|
||||
HtmlState::new_with_resources(
|
||||
HtmlState::new_with_resources(StateResources {
|
||||
db,
|
||||
Arc::new(async_openai::Client::new()),
|
||||
openai_client: Arc::new(async_openai::Client::new()),
|
||||
session_store,
|
||||
storage,
|
||||
config,
|
||||
None,
|
||||
reranker_pool: None,
|
||||
embedding_provider,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create HtmlState")
|
||||
template_engine: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -2,6 +2,6 @@ use tower_http::compression::CompressionLayer;
|
||||
|
||||
/// Provides a default compression layer that negotiates encoding based on the
|
||||
/// `Accept-Encoding` header of the incoming request.
|
||||
pub fn compression_layer() -> CompressionLayer {
|
||||
pub fn layer() -> CompressionLayer {
|
||||
CompressionLayer::new()
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use axum::{
|
||||
use axum_htmx::{HxRequest, HX_TRIGGER};
|
||||
use common::{
|
||||
error::AppError,
|
||||
utils::template_engine::{ProvidesTemplateEngine, Value},
|
||||
utils::template_engine::{ProvidesTemplateEngine, TemplateEngine, Value},
|
||||
};
|
||||
use minijinja::context;
|
||||
use serde::Serialize;
|
||||
@@ -146,67 +146,21 @@ struct ContextWrapper<'a> {
|
||||
context: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
pub async fn with_template_response<S>(
|
||||
State(state): State<S>,
|
||||
HxRequest(is_htmx): HxRequest,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Response
|
||||
where
|
||||
S: ProvidesTemplateEngine + ProvidesHtmlState + Clone + Send + Sync + 'static,
|
||||
{
|
||||
let mut user_theme = Theme::System.as_str();
|
||||
let mut initial_theme = Theme::System.initial_theme();
|
||||
let mut is_authenticated = false;
|
||||
let mut current_user_id = None;
|
||||
let mut current_user = None;
|
||||
const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"];
|
||||
|
||||
{
|
||||
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
|
||||
if let Some(user) = &auth.current_user {
|
||||
is_authenticated = true;
|
||||
current_user_id = Some(user.id.clone());
|
||||
user_theme = user.theme.as_str();
|
||||
initial_theme = user.theme.initial_theme();
|
||||
current_user = Some(TemplateUser::from(user));
|
||||
fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) {
|
||||
for &header_name in HTMX_HEADERS_TO_FORWARD {
|
||||
if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) {
|
||||
if let Some(value) = from.get(&name) {
|
||||
to.insert(name.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = next.run(req).await;
|
||||
|
||||
// Headers to forward from the original response
|
||||
const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"];
|
||||
|
||||
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
|
||||
let template_engine = state.template_engine();
|
||||
|
||||
let mut conversation_archive = Vec::new();
|
||||
|
||||
let should_load_conversation_archive =
|
||||
matches!(&template_response.template_kind, TemplateKind::Full(_));
|
||||
|
||||
if should_load_conversation_archive {
|
||||
if let Some(user_id) = current_user_id {
|
||||
let html_state = state.html_state();
|
||||
if let Some(cached_archive) =
|
||||
html_state.get_cached_conversation_archive(&user_id).await
|
||||
{
|
||||
conversation_archive = cached_archive;
|
||||
} else if let Ok(archive) =
|
||||
Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await
|
||||
{
|
||||
html_state
|
||||
.set_cached_conversation_archive(&user_id, archive.clone())
|
||||
.await;
|
||||
conversation_archive = archive;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn context_to_map(
|
||||
fn context_to_map(
|
||||
value: &Value,
|
||||
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
|
||||
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
|
||||
match value.kind() {
|
||||
minijinja::value::ValueKind::Map => {
|
||||
let mut map = HashMap::new();
|
||||
@@ -224,15 +178,57 @@ where
|
||||
}
|
||||
other => Err(other),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn with_template_response<S>(
|
||||
State(state): State<S>,
|
||||
HxRequest(is_htmx): HxRequest,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Response
|
||||
where
|
||||
S: ProvidesTemplateEngine + ProvidesHtmlState + Clone + Send + Sync + 'static,
|
||||
{
|
||||
let mut user_theme = Theme::System.as_str();
|
||||
let mut initial_theme = Theme::System.initial_theme();
|
||||
let mut is_authenticated = false;
|
||||
let mut current_user = None;
|
||||
|
||||
{
|
||||
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
|
||||
if let Some(user) = &auth.current_user {
|
||||
is_authenticated = true;
|
||||
user_theme = user.theme.as_str();
|
||||
initial_theme = user.theme.initial_theme();
|
||||
current_user = Some(TemplateUser::from(user));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to forward relevant headers
|
||||
fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) {
|
||||
for &header_name in HTMX_HEADERS_TO_FORWARD {
|
||||
if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) {
|
||||
if let Some(value) = from.get(&name) {
|
||||
to.insert(name.clone(), value.clone());
|
||||
}
|
||||
let response = next.run(req).await;
|
||||
|
||||
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
|
||||
let template_engine = state.template_engine();
|
||||
|
||||
let mut conversation_archive = Vec::new();
|
||||
|
||||
let should_load_conversation_archive =
|
||||
matches!(&template_response.template_kind, TemplateKind::Full(_));
|
||||
|
||||
if should_load_conversation_archive {
|
||||
if let Some(user_id) = current_user.as_ref().map(|u| &u.id) {
|
||||
let html_state = state.html_state();
|
||||
if let Some(cached_archive) =
|
||||
html_state.get_cached_conversation_archive(user_id).await
|
||||
{
|
||||
conversation_archive = cached_archive;
|
||||
} else if let Ok(archive) =
|
||||
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
|
||||
{
|
||||
html_state
|
||||
.set_cached_conversation_archive(user_id, archive.clone())
|
||||
.await;
|
||||
conversation_archive = archive;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -290,18 +286,17 @@ where
|
||||
}
|
||||
TemplateKind::Error(status) => {
|
||||
if is_htmx {
|
||||
// HTMX request: Send 204 + HX-Trigger for toast
|
||||
let title = template_response
|
||||
.context
|
||||
.get_attr("title")
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(String::from))
|
||||
.and_then(|v| v.as_str().map(|s| s.to_string()))
|
||||
.unwrap_or_else(|| "Error".to_string());
|
||||
let description = template_response
|
||||
.context
|
||||
.get_attr("description")
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(String::from))
|
||||
.and_then(|v| v.as_str().map(|s| s.to_string()))
|
||||
.unwrap_or_else(|| "An error occurred.".to_string());
|
||||
|
||||
let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}});
|
||||
@@ -312,14 +307,12 @@ where
|
||||
});
|
||||
(StatusCode::NO_CONTENT, [(HX_TRIGGER, trigger_value)], "").into_response()
|
||||
} else {
|
||||
// Non-HTMX request: Render the full errors/error.html page
|
||||
match template_engine
|
||||
.render("errors/error.html", &Value::from_serialize(&context))
|
||||
{
|
||||
Ok(html) => (*status, Html(html)).into_response(),
|
||||
Err(e) => {
|
||||
error!("Critical: Failed to render 'errors/error.html': {:?}", e);
|
||||
// Fallback HTML, but use the intended status code
|
||||
(*status, Html(fallback_error())).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
analytics_middleware::analytics_middleware, auth_middleware::require_auth,
|
||||
compression::compression_layer, response_middleware::with_template_response,
|
||||
compression, response_middleware::with_template_response,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -71,6 +71,7 @@ where
|
||||
}
|
||||
|
||||
// Add a serving of assets
|
||||
#[must_use]
|
||||
pub fn with_public_assets(mut self, path: &str, directory: &str) -> Self {
|
||||
self.public_assets_config = Some(AssetsConfig {
|
||||
path: path.to_string(),
|
||||
@@ -80,24 +81,28 @@ where
|
||||
}
|
||||
|
||||
// Add a public router that will be merged at the root level
|
||||
#[must_use]
|
||||
pub fn add_public_routes(mut self, routes: Router<S>) -> Self {
|
||||
self.public_routers.push(routes);
|
||||
self
|
||||
}
|
||||
|
||||
// Add a protected router that will be merged at the root level
|
||||
#[must_use]
|
||||
pub fn add_protected_routes(mut self, routes: Router<S>) -> Self {
|
||||
self.protected_routers.push(routes);
|
||||
self
|
||||
}
|
||||
|
||||
// Nest a public router under a path prefix
|
||||
#[must_use]
|
||||
pub fn nest_public_routes(mut self, path: &str, routes: Router<S>) -> Self {
|
||||
self.nested_routes.push((path.to_string(), routes));
|
||||
self
|
||||
}
|
||||
|
||||
// Nest a protected router under a path prefix
|
||||
#[must_use]
|
||||
pub fn nest_protected_routes(mut self, path: &str, routes: Router<S>) -> Self {
|
||||
self.nested_protected_routes
|
||||
.push((path.to_string(), routes));
|
||||
@@ -105,6 +110,7 @@ where
|
||||
}
|
||||
|
||||
// Add custom middleware to be applied before the standard ones
|
||||
#[must_use]
|
||||
pub fn with_middleware<F>(mut self, middleware_fn: F) -> Self
|
||||
where
|
||||
F: FnOnce(Router<S>) -> Router<S> + Send + 'static,
|
||||
@@ -114,6 +120,7 @@ where
|
||||
}
|
||||
|
||||
/// Enables response compression when building the router.
|
||||
#[must_use]
|
||||
pub const fn with_compression(mut self) -> Self {
|
||||
self.compression_enabled = true;
|
||||
self
|
||||
@@ -191,7 +198,7 @@ where
|
||||
|
||||
// Apply Global Middleware (Compression)
|
||||
if self.compression_enabled {
|
||||
final_router = final_router.layer(compression_layer());
|
||||
final_router = final_router.layer(compression::layer());
|
||||
}
|
||||
|
||||
final_router
|
||||
|
||||
@@ -62,7 +62,7 @@ pub async fn set_api_key(
|
||||
let api_key = User::set_api_key(&user.id, &state.db).await?;
|
||||
|
||||
// Clear the cache so new requests have access to the user with api key
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
auth.cache_clear_user(user.id.clone());
|
||||
|
||||
// Render the API key section block
|
||||
Ok(TemplateResponse::new_partial(
|
||||
@@ -106,7 +106,7 @@ pub async fn update_timezone(
|
||||
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
|
||||
|
||||
// Clear the cache
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
auth.cache_clear_user(user.id.clone());
|
||||
|
||||
let timezones = TZ_VARIANTS
|
||||
.iter()
|
||||
@@ -141,7 +141,7 @@ pub async fn update_theme(
|
||||
User::update_theme(&user.id, &form.theme, &state.db).await?;
|
||||
|
||||
// Clear the cache
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
auth.cache_clear_user(user.id.clone());
|
||||
|
||||
let theme_options = vec![
|
||||
Theme::Light.as_str().to_string(),
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::types::ListModelResponse;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
@@ -37,18 +39,14 @@ pub struct AdminPanelData {
|
||||
current_section: AdminSection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AdminSection {
|
||||
#[default]
|
||||
Overview,
|
||||
Models,
|
||||
}
|
||||
|
||||
impl Default for AdminSection {
|
||||
fn default() -> Self {
|
||||
Self::Overview
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AdminPanelQuery {
|
||||
@@ -107,10 +105,7 @@ fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
match String::deserialize(deserializer) {
|
||||
Ok(string) => Ok(string == "on"),
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
String::deserialize(deserializer).map(|s| s == "on")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -219,8 +214,8 @@ pub async fn update_model_settings(
|
||||
if reembedding_needed {
|
||||
info!("Embedding dimensions changed. Spawning background re-embedding task...");
|
||||
|
||||
let db_for_task = state.db.clone();
|
||||
let openai_for_task = state.openai_client.clone();
|
||||
let db_for_task = Arc::clone(&state.db);
|
||||
let openai_for_task = Arc::clone(&state.openai_client);
|
||||
let new_model_for_task = new_settings.embedding_model.clone();
|
||||
let new_dims_for_task = new_settings.embedding_dimensions;
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{
|
||||
};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct SignupParams {
|
||||
pub struct Params {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
pub timezone: String,
|
||||
@@ -39,7 +39,7 @@ pub async fn show_signup_form(
|
||||
pub async fn process_signup_and_show_verification(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<SignupParams>,
|
||||
Form(form): Form<Params>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match User::create_new(
|
||||
form.email,
|
||||
|
||||
@@ -49,6 +49,8 @@ pub struct ChatPageData {
|
||||
conversation: Option<Conversation>,
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn show_initialized_chat(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
@@ -57,14 +59,14 @@ pub async fn show_initialized_chat(
|
||||
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
|
||||
|
||||
let user_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
form.user_query,
|
||||
None,
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
conversation.id.clone(),
|
||||
MessageRole::AI,
|
||||
form.llm_response,
|
||||
Some(form.references),
|
||||
@@ -86,10 +88,9 @@ pub async fn show_initialized_chat(
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -130,12 +131,19 @@ pub async fn show_existing_chat(
|
||||
))
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn new_user_message(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let conversation: Conversation = state
|
||||
.db
|
||||
.get_item(&conversation_id)
|
||||
@@ -150,33 +158,34 @@ pub async fn new_user_message(
|
||||
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn new_chat_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
conversation: Conversation,
|
||||
}
|
||||
|
||||
let Some(user) = auth.current_user else {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
|
||||
@@ -191,11 +200,6 @@ pub async fn new_chat_user_message(
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
conversation: Conversation,
|
||||
}
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/new_chat_first_response.html",
|
||||
SSEResponseInitData {
|
||||
@@ -205,10 +209,9 @@ pub async fn new_chat_user_message(
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response.into_response())
|
||||
}
|
||||
|
||||
@@ -53,26 +53,22 @@ fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
|
||||
)
|
||||
}
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(message: impl Into<String>) -> EventStream {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
|
||||
// Helper function to get message and user
|
||||
async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
|
||||
// Check authentication
|
||||
let Some(user) = current_user else {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)));
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
let message = match db.get_item::<Message>(message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
@@ -88,7 +84,6 @@ async fn get_message_and_user(
|
||||
}
|
||||
};
|
||||
|
||||
// Get conversation history
|
||||
let (conversation, history) =
|
||||
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
|
||||
{
|
||||
@@ -209,7 +204,6 @@ pub async fn get_response_stream(
|
||||
auth: AuthSessionType,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> SseResponse {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user, _conversation, history, existing_ai_response) =
|
||||
match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await {
|
||||
Ok((user_message, user, conversation, history, existing_ai_response)) => (
|
||||
@@ -226,9 +220,123 @@ pub async fn get_response_stream(
|
||||
return create_replayed_response_stream(&state, existing_ai_message);
|
||||
}
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let (request, allowed_reference_ids) = match prepare_chat_request(&state, &user_message, &user, &history).await {
|
||||
Ok(result) => result,
|
||||
Err(sse) => return sse,
|
||||
};
|
||||
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
|
||||
build_chat_event_stream(state, openai_stream, &user_message, user.id.clone(), allowed_reference_ids)
|
||||
}
|
||||
|
||||
fn build_chat_event_stream(
|
||||
state: HtmlState,
|
||||
openai_stream: impl Stream<Item = Result<async_openai::types::CreateChatCompletionStreamResponse, async_openai::error::OpenAIError>> + Send + 'static,
|
||||
user_message: &Message,
|
||||
user_id: String,
|
||||
allowed_reference_ids: Vec<String>,
|
||||
) -> SseResponse {
|
||||
let (tx, rx) = channel::<String>(1000);
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
spawn_storage_task(Arc::clone(&state.db), rx, tx_final, user_message, user_id, allowed_reference_ids);
|
||||
|
||||
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
|
||||
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
let mut state = json_state.lock().await;
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
yield Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {e}")));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.chain(stream::once(async move {
|
||||
#[derive(Serialize)]
|
||||
struct LocalReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
if message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(std::vec::Vec::is_empty)
|
||||
{
|
||||
return Ok(Event::default().event("empty"));
|
||||
}
|
||||
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(LocalReferenceData { message }),
|
||||
) {
|
||||
Ok(html) => Ok(Event::default().event("references").data(html)),
|
||||
Err(_) => Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to render references")),
|
||||
}
|
||||
} else {
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to retrieve references"))
|
||||
}
|
||||
}))
|
||||
.chain(once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}))
|
||||
.boxed();
|
||||
|
||||
sse_with_keep_alive(event_stream)
|
||||
}
|
||||
|
||||
async fn prepare_chat_request(
|
||||
state: &HtmlState,
|
||||
user_message: &Message,
|
||||
user: &User,
|
||||
history: &[Message],
|
||||
) -> Result<
|
||||
(async_openai::types::CreateChatCompletionRequest, Vec<String>),
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -248,59 +356,49 @@ pub async fn get_response_stream(
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(_e) => {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
|
||||
return Err(Sse::new(create_error_stream("Failed to retrieve knowledge")));
|
||||
}
|
||||
};
|
||||
|
||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||
|
||||
// 3. Create the OpenAI request with appropriate context format
|
||||
let context_json = match &retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
|
||||
let context_json = match retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||
retrieved_entities_to_json(entities)
|
||||
}
|
||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||
// For chat, use chunks from the search result
|
||||
chunks_to_chat_context(&search_result.chunks)
|
||||
}
|
||||
};
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&context_json, &history, &user_message.content);
|
||||
create_user_message_with_history(&context_json, history, &user_message.content);
|
||||
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
|
||||
return Err(Sse::new(create_error_stream("Failed to retrieve system settings")));
|
||||
};
|
||||
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
|
||||
return Err(Sse::new(create_error_stream("Failed to create chat request")));
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
Ok((request, allowed_reference_ids))
|
||||
}
|
||||
|
||||
// 5. Create channel for collecting complete response
|
||||
let (tx, mut rx) = channel::<String>(1000);
|
||||
let tx_clone = tx.clone();
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
fn spawn_storage_task(
|
||||
db_client: Arc<SurrealDbClient>,
|
||||
mut rx: tokio::sync::mpsc::Receiver<String>,
|
||||
tx_final: tokio::sync::mpsc::Sender<Message>,
|
||||
user_message: &Message,
|
||||
user_id: String,
|
||||
allowed_reference_ids: Vec<String>,
|
||||
) {
|
||||
let conversation_id = user_message.conversation_id.clone();
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = Arc::clone(&state.db);
|
||||
let user_id = user.id.clone();
|
||||
let allowed_reference_ids = allowed_reference_ids.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
// Collect full response
|
||||
let mut full_json = String::new();
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
full_json.push_str(&chunk);
|
||||
}
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let raw_references = extract_reference_strings(&response);
|
||||
let answer = response.answer;
|
||||
@@ -347,7 +445,7 @@ pub async fn get_response_stream(
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
conversation_id,
|
||||
MessageRole::AI,
|
||||
answer,
|
||||
Some(initial_validation.valid_refs),
|
||||
@@ -362,104 +460,11 @@ pub async fn get_response_stream(
|
||||
} else {
|
||||
error!("Failed to parse LLM response as structured format");
|
||||
|
||||
// Fallback - store raw response
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
full_json,
|
||||
None,
|
||||
);
|
||||
let ai_message = Message::new(conversation_id, MessageRole::AI, full_json, None);
|
||||
|
||||
let _ = db_client.store_item(ai_message).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create a shared state for tracking the JSON parsing
|
||||
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
|
||||
|
||||
// 7. Create the response event stream
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
// Always send raw content to storage
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
// Process through JSON parser
|
||||
let mut state = json_state.lock().await;
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
yield Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
// If display_content is empty, don't yield anything
|
||||
}
|
||||
// If content is empty, don't yield anything
|
||||
}
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {e}")));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.chain(stream::once(async move {
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
// Don't send any event if references is empty
|
||||
if message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(std::vec::Vec::is_empty)
|
||||
{
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(ReferenceData { message }),
|
||||
) {
|
||||
Ok(html) => {
|
||||
// Return the rendered HTML
|
||||
Ok(Event::default().event("references").data(html))
|
||||
}
|
||||
Err(_) => {
|
||||
// Handle template rendering error
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to render references"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle case where no references were received
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to retrieve references"))
|
||||
}
|
||||
}))
|
||||
.chain(once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}));
|
||||
|
||||
sse_with_keep_alive(event_stream.boxed())
|
||||
}
|
||||
|
||||
struct StreamParserState {
|
||||
@@ -478,23 +483,18 @@ impl StreamParserState {
|
||||
}
|
||||
|
||||
fn process_chunk(&mut self, chunk: &str) -> String {
|
||||
// Feed all characters into the parser
|
||||
for c in chunk.chars() {
|
||||
let _ = self.parser.add_char(c);
|
||||
}
|
||||
|
||||
// Get the current state of the JSON
|
||||
let json = self.parser.get_result();
|
||||
|
||||
// Check if we're in the answer field
|
||||
if let Some(obj) = json.as_object() {
|
||||
if let Some(answer) = obj.get("answer") {
|
||||
self.in_answer_field = true;
|
||||
|
||||
// Get current answer content
|
||||
let current_content = answer.as_str().unwrap_or_default().to_string();
|
||||
|
||||
// Calculate difference to send only new content
|
||||
if current_content.len() > self.last_answer_content.len() {
|
||||
let new_content = current_content[self.last_answer_content.len()..].to_string();
|
||||
self.last_answer_content = current_content;
|
||||
@@ -503,7 +503,6 @@ impl StreamParserState {
|
||||
}
|
||||
}
|
||||
|
||||
// No new content to return
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,9 @@ use axum::{
|
||||
};
|
||||
pub use chat_handlers::{
|
||||
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
|
||||
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat,
|
||||
show_initialized_chat,
|
||||
reload_sidebar, show_conversation_editing_title,
|
||||
show_chat_base as show_base, show_existing_chat as show_existing,
|
||||
show_initialized_chat as show_initialized,
|
||||
};
|
||||
use message_response_stream::get_response_stream;
|
||||
use references::show_reference_tooltip;
|
||||
@@ -24,10 +25,10 @@ where
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/chat", get(show_chat_base).post(new_chat_user_message))
|
||||
.route("/chat", get(show_base).post(new_chat_user_message))
|
||||
.route(
|
||||
"/chat/{id}",
|
||||
get(show_existing_chat)
|
||||
get(show_existing)
|
||||
.post(new_user_message)
|
||||
.delete(delete_conversation),
|
||||
)
|
||||
@@ -36,7 +37,7 @@ where
|
||||
get(show_conversation_editing_title).patch(patch_conversation_title),
|
||||
)
|
||||
.route("/chat/sidebar", get(reload_sidebar))
|
||||
.route("/initialized-chat", post(show_initialized_chat))
|
||||
.route("/initialized-chat", post(show_initialized))
|
||||
.route("/chat/response-stream", get(get_response_stream))
|
||||
.route("/chat/reference/{id}", get(show_reference_tooltip))
|
||||
}
|
||||
|
||||
@@ -102,13 +102,13 @@ pub async fn show_text_content_edit_form(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentEditModal {
|
||||
pub text_content: TextContent,
|
||||
}
|
||||
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"content/edit_text_content_modal.html",
|
||||
TextContentEditModal { text_content },
|
||||
@@ -214,13 +214,14 @@ pub async fn show_content_read_modal(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Get and validate the text content
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentReadModalData {
|
||||
pub text_content: TextContent,
|
||||
}
|
||||
|
||||
// Get and validate the text content
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"content/read_content_modal.html",
|
||||
TextContentReadModalData { text_content },
|
||||
|
||||
@@ -226,7 +226,7 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) {
|
||||
("Text".to_string(), truncate_summary(text, 80))
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
|
||||
("URL".to_string(), url.to_string())
|
||||
("URL".to_string(), url.clone())
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
|
||||
("File".to_string(), file_info.file_name.clone())
|
||||
@@ -248,18 +248,16 @@ pub async fn serve_file(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let file_info = match FileInfo::get_by_id(&file_id, &state.db).await {
|
||||
Ok(info) => info,
|
||||
_ => return Ok(TemplateResponse::not_found().into_response()),
|
||||
let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
|
||||
return Ok(TemplateResponse::not_found().into_response());
|
||||
};
|
||||
|
||||
if file_info.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
}
|
||||
|
||||
let stream = match state.storage.get_stream(&file_info.path).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Ok(TemplateResponse::server_error().into_response()),
|
||||
let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
|
||||
return Ok(TemplateResponse::server_error().into_response());
|
||||
};
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{pin::Pin, time::Duration};
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
@@ -51,13 +51,13 @@ pub async fn show_ingest_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ShowIngestFormData {
|
||||
user_categories: Vec<String>,
|
||||
}
|
||||
|
||||
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"ingestion_modal.html",
|
||||
ShowIngestFormData { user_categories },
|
||||
@@ -180,7 +180,7 @@ pub async fn get_task_updates_stream(
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> TaskSse {
|
||||
let task_id = params.task_id.clone();
|
||||
let db = state.db.clone();
|
||||
let db = Arc::clone(&state.db);
|
||||
|
||||
// 1. Check for authenticated user
|
||||
let Some(current_user) = auth.current_user else {
|
||||
@@ -198,7 +198,7 @@ pub async fn get_task_updates_stream(
|
||||
}
|
||||
|
||||
let sse_stream = async_stream::stream! {
|
||||
let mut consecutive_db_errors = 0;
|
||||
let mut consecutive_db_errors: u32 = 0;
|
||||
let max_consecutive_db_errors = 3;
|
||||
|
||||
loop {
|
||||
@@ -263,7 +263,7 @@ pub async fn get_task_updates_stream(
|
||||
}
|
||||
Err(db_err) => {
|
||||
error!("Database error while fetching task '{}': {:?}", task_id, db_err);
|
||||
consecutive_db_errors += 1;
|
||||
consecutive_db_errors = consecutive_db_errors.saturating_add(1);
|
||||
yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors}).")));
|
||||
|
||||
if consecutive_db_errors >= max_consecutive_db_errors {
|
||||
|
||||
@@ -39,7 +39,7 @@ use url::form_urlencoded;
|
||||
|
||||
const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12;
|
||||
const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"];
|
||||
const DEFAULT_RELATIONSHIP_TYPE: &str = RELATIONSHIP_TYPE_OPTIONS[0];
|
||||
const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo";
|
||||
const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10;
|
||||
const SUGGESTION_MIN_SCORE: f32 = 0.5;
|
||||
|
||||
@@ -61,15 +61,15 @@ fn canonicalize_relationship_type(value: &str) -> String {
|
||||
|
||||
let key: String = trimmed
|
||||
.chars()
|
||||
.filter(|c| c.is_ascii_alphanumeric())
|
||||
.flat_map(|c| c.to_lowercase())
|
||||
.filter(char::is_ascii_alphanumeric)
|
||||
.flat_map(char::to_lowercase)
|
||||
.collect();
|
||||
|
||||
for option in RELATIONSHIP_TYPE_OPTIONS {
|
||||
let option_key: String = option
|
||||
.chars()
|
||||
.filter(|c| c.is_ascii_alphanumeric())
|
||||
.flat_map(|c| c.to_lowercase())
|
||||
.filter(char::is_ascii_alphanumeric)
|
||||
.flat_map(char::to_lowercase)
|
||||
.collect();
|
||||
if option_key == key {
|
||||
return (*option).to_string();
|
||||
@@ -141,7 +141,7 @@ pub async fn show_new_knowledge_entity_form(
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let entity_types: Vec<String> = KnowledgeEntityType::variants()
|
||||
.iter()
|
||||
.map(|&s| s.to_owned())
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
|
||||
let existing_entities = User::get_knowledge_entities(&user.id, &state.db).await?;
|
||||
@@ -278,7 +278,7 @@ pub async fn suggest_knowledge_relationships(
|
||||
if !query_parts.is_empty() {
|
||||
let query = query_parts.join(" ");
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -406,9 +406,10 @@ fn build_relationship_table_data(
|
||||
.map(|relationship| {
|
||||
let relationship_type_label =
|
||||
canonicalize_relationship_type(&relationship.metadata.relationship_type);
|
||||
*frequency
|
||||
let count = frequency
|
||||
.entry(relationship_type_label.clone())
|
||||
.or_insert(0) += 1;
|
||||
.or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
RelationshipTableRow {
|
||||
relationship,
|
||||
relationship_type_label,
|
||||
@@ -417,9 +418,7 @@ fn build_relationship_table_data(
|
||||
.collect();
|
||||
let default_relationship_type = frequency
|
||||
.into_iter()
|
||||
.max_by_key(|(_, count)| *count)
|
||||
.map(|(label, _)| label)
|
||||
.unwrap_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string());
|
||||
.max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
|
||||
|
||||
RelationshipTableData {
|
||||
entities,
|
||||
@@ -800,8 +799,10 @@ pub async fn get_knowledge_graph_json(
|
||||
for rel in &relationships {
|
||||
if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) {
|
||||
// undirected counting for degree
|
||||
*degree_count.entry(rel.in_.clone()).or_insert(0) += 1;
|
||||
*degree_count.entry(rel.out.clone()).or_insert(0) += 1;
|
||||
let count = degree_count.entry(rel.in_.clone()).or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
let count = degree_count.entry(rel.out.clone()).or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
links.push(GraphLink {
|
||||
source: rel.out.clone(),
|
||||
target: rel.in_.clone(),
|
||||
@@ -836,11 +837,11 @@ fn normalize_filter(input: Option<String>) -> Option<String> {
|
||||
|
||||
fn trim_matching_quotes(value: &str) -> &str {
|
||||
let bytes = value.as_bytes();
|
||||
if bytes.len() >= 2 {
|
||||
let first = bytes[0];
|
||||
let last = bytes[bytes.len() - 1];
|
||||
if (first == b'"' && last == b'"') || (first == b'\'' && last == b'\'') {
|
||||
return &value[1..value.len() - 1];
|
||||
if let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) {
|
||||
if bytes.len() >= 2
|
||||
&& ((first == b'"' && last == b'"') || (first == b'\'' && last == b'\''))
|
||||
{
|
||||
return &value[1..value.len().saturating_sub(1)];
|
||||
}
|
||||
}
|
||||
value
|
||||
@@ -860,7 +861,7 @@ pub async fn show_edit_knowledge_entity_form(
|
||||
// Get entity types
|
||||
let entity_types: Vec<String> = KnowledgeEntityType::variants()
|
||||
.iter()
|
||||
.map(|&s| s.to_owned())
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
|
||||
// Get the entity and validate ownership
|
||||
|
||||
@@ -11,6 +11,7 @@ use axum::{
|
||||
use common::storage::types::{
|
||||
serde_helpers::deserialize_flexible_id,
|
||||
text_content::TextContent,
|
||||
user::User,
|
||||
StoredObject,
|
||||
};
|
||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
||||
@@ -46,13 +47,11 @@ fn source_id_suffix(source_id: &str) -> String {
|
||||
|
||||
fn truncate_label(value: &str, max_chars: usize) -> String {
|
||||
let mut end = None;
|
||||
let mut count = 0;
|
||||
for (idx, _) in value.char_indices() {
|
||||
for (count, (idx, _)) in value.char_indices().enumerate() {
|
||||
if count == max_chars {
|
||||
end = Some(idx);
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
|
||||
match end {
|
||||
@@ -174,59 +173,135 @@ struct KnowledgeEntityForTemplate {
|
||||
score: f32,
|
||||
}
|
||||
|
||||
pub async fn search_result_handler(
|
||||
State(state): State<HtmlState>,
|
||||
Query(params): Query<SearchParams>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
#[derive(Serialize)]
|
||||
struct SearchResultForTemplate {
|
||||
#[derive(Serialize)]
|
||||
struct SearchResultForTemplate {
|
||||
result_type: String,
|
||||
score: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
text_chunk: Option<TextChunkForTemplate>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
knowledge_entity: Option<KnowledgeEntityForTemplate>,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AnswerData {
|
||||
#[derive(Serialize)]
|
||||
pub struct AnswerData {
|
||||
search_result: Vec<SearchResultForTemplate>,
|
||||
query_param: String,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn search_result_handler(
|
||||
State(state): State<HtmlState>,
|
||||
Query(params): Query<SearchParams>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
|
||||
params.query
|
||||
{
|
||||
let trimmed_query = actual_query.trim();
|
||||
if trimmed_query.is_empty() {
|
||||
(Vec::<SearchResultForTemplate>::new(), String::new())
|
||||
perform_search(&state, &user, actual_query).await?
|
||||
} else {
|
||||
// Use retrieval pipeline Search strategy
|
||||
(Vec::<SearchResultForTemplate>::new(), String::new())
|
||||
};
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"search/base.html",
|
||||
AnswerData {
|
||||
search_result: search_results_for_template,
|
||||
query_param: final_query_param_for_template,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
async fn perform_search(
|
||||
state: &HtmlState,
|
||||
user: &User,
|
||||
query: String,
|
||||
) -> Result<(Vec<SearchResultForTemplate>, String), HtmlError> {
|
||||
const TOTAL_LIMIT: usize = 10;
|
||||
|
||||
let trimmed_query = query.trim();
|
||||
if trimmed_query.is_empty() {
|
||||
return Ok((Vec::new(), String::new()));
|
||||
}
|
||||
|
||||
let config = RetrievalConfig::for_search(SearchTarget::Both);
|
||||
|
||||
// Checkout a reranker lease if pool is available
|
||||
let reranker_lease = match &state.reranker_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
let result = retrieval_pipeline::pipeline::run_pipeline(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&state.embedding_provider),
|
||||
trimmed_query,
|
||||
&user.id,
|
||||
let params = retrieval_pipeline::pipeline::StrategyParams {
|
||||
db_client: &state.db,
|
||||
openai_client: &state.openai_client,
|
||||
embedding_provider: Some(&state.embedding_provider),
|
||||
input_text: trimmed_query,
|
||||
user_id: &user.id,
|
||||
config,
|
||||
reranker_lease,
|
||||
)
|
||||
.await?;
|
||||
reranker: reranker_lease,
|
||||
};
|
||||
let result = retrieval_pipeline::pipeline::execute(params).await?;
|
||||
|
||||
let search_result = match result {
|
||||
StrategyOutput::Search(sr) => sr,
|
||||
_ => SearchResult::new(vec![], vec![]),
|
||||
};
|
||||
|
||||
let source_label_map = resolve_source_labels(state, user, &search_result).await?;
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
|
||||
|
||||
for chunk_result in search_result.chunks {
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "text_chunk".to_string(),
|
||||
score: chunk_result.score,
|
||||
text_chunk: Some(TextChunkForTemplate {
|
||||
id: chunk_result.chunk.id,
|
||||
source_id: chunk_result.chunk.source_id,
|
||||
source_label,
|
||||
chunk: chunk_result.chunk.chunk,
|
||||
score: chunk_result.score,
|
||||
}),
|
||||
knowledge_entity: None,
|
||||
});
|
||||
}
|
||||
|
||||
for entity_result in search_result.entities {
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "knowledge_entity".to_string(),
|
||||
score: entity_result.score,
|
||||
text_chunk: None,
|
||||
knowledge_entity: Some(KnowledgeEntityForTemplate {
|
||||
id: entity_result.entity.id,
|
||||
name: entity_result.entity.name,
|
||||
description: entity_result.entity.description,
|
||||
entity_type: format!("{:?}", entity_result.entity.entity_type),
|
||||
source_id: entity_result.entity.source_id,
|
||||
source_label,
|
||||
score: entity_result.score,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||
combined_results.truncate(TOTAL_LIMIT);
|
||||
|
||||
Ok((combined_results, trimmed_query.to_string()))
|
||||
}
|
||||
|
||||
async fn resolve_source_labels(
|
||||
state: &HtmlState,
|
||||
user: &User,
|
||||
search_result: &SearchResult,
|
||||
) -> Result<HashMap<String, String>, HtmlError> {
|
||||
let mut source_ids = HashSet::new();
|
||||
for chunk_result in &search_result.chunks {
|
||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||
@@ -235,9 +310,10 @@ pub async fn search_result_handler(
|
||||
source_ids.insert(entity_result.entity.source_id.clone());
|
||||
}
|
||||
|
||||
let source_label_map = if source_ids.is_empty() {
|
||||
HashMap::new()
|
||||
} else {
|
||||
if source_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let record_ids: Vec<RecordId> = source_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
@@ -276,72 +352,5 @@ pub async fn search_result_handler(
|
||||
);
|
||||
}
|
||||
|
||||
labels
|
||||
};
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
|
||||
|
||||
// Add chunk results
|
||||
for chunk_result in search_result.chunks {
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "text_chunk".to_string(),
|
||||
score: chunk_result.score,
|
||||
text_chunk: Some(TextChunkForTemplate {
|
||||
id: chunk_result.chunk.id,
|
||||
source_id: chunk_result.chunk.source_id,
|
||||
source_label,
|
||||
chunk: chunk_result.chunk.chunk,
|
||||
score: chunk_result.score,
|
||||
}),
|
||||
knowledge_entity: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Add entity results
|
||||
for entity_result in search_result.entities {
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "knowledge_entity".to_string(),
|
||||
score: entity_result.score,
|
||||
text_chunk: None,
|
||||
knowledge_entity: Some(KnowledgeEntityForTemplate {
|
||||
id: entity_result.entity.id,
|
||||
name: entity_result.entity.name,
|
||||
description: entity_result.entity.description,
|
||||
entity_type: format!("{:?}", entity_result.entity.entity_type),
|
||||
source_id: entity_result.entity.source_id,
|
||||
source_label,
|
||||
score: entity_result.score,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||
|
||||
// Limit results
|
||||
const TOTAL_LIMIT: usize = 10;
|
||||
combined_results.truncate(TOTAL_LIMIT);
|
||||
|
||||
(combined_results, trimmed_query.to_string())
|
||||
}
|
||||
} else {
|
||||
(Vec::<SearchResultForTemplate>::new(), String::new())
|
||||
};
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"search/base.html",
|
||||
AnswerData {
|
||||
search_result: search_results_for_template,
|
||||
query_param: final_query_param_for_template,
|
||||
},
|
||||
))
|
||||
Ok(labels)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
mod handlers;
|
||||
|
||||
use axum::{extract::FromRef, routing::get, Router};
|
||||
pub use handlers::{search_result_handler, SearchParams};
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub use handlers::{
|
||||
search_result_handler as result_handler, SearchParams as SearchQueryParams,
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
@@ -10,5 +13,5 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new().route("/search", get(search_result_handler))
|
||||
Router::new().route("/search", get(result_handler))
|
||||
}
|
||||
|
||||
@@ -31,8 +31,8 @@ impl Pagination {
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let start_index = if page_len == 0 { 0 } else { offset + 1 };
|
||||
let end_index = if page_len == 0 { 0 } else { offset + page_len };
|
||||
let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) };
|
||||
let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) };
|
||||
|
||||
Self {
|
||||
current_page,
|
||||
@@ -42,12 +42,12 @@ impl Pagination {
|
||||
has_previous,
|
||||
has_next,
|
||||
previous_page: if has_previous {
|
||||
Some(current_page - 1)
|
||||
Some(current_page.saturating_sub(1))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
next_page: if has_next {
|
||||
Some(current_page + 1)
|
||||
Some(current_page.saturating_add(1))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
@@ -68,7 +68,7 @@ pub fn paginate_items<T>(
|
||||
let total_pages = if total_items == 0 {
|
||||
0
|
||||
} else {
|
||||
((total_items - 1) / per_page) + 1
|
||||
total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1)
|
||||
};
|
||||
|
||||
let mut current_page = requested_page.unwrap_or(1);
|
||||
@@ -84,7 +84,7 @@ pub fn paginate_items<T>(
|
||||
let offset = if total_pages == 0 {
|
||||
0
|
||||
} else {
|
||||
per_page.saturating_mul(current_page - 1)
|
||||
per_page.saturating_mul(current_page.saturating_sub(1))
|
||||
};
|
||||
|
||||
let page_items: Vec<T> = items.into_iter().skip(offset).take(per_page).collect();
|
||||
@@ -136,8 +136,8 @@ mod tests {
|
||||
assert_eq!(page, vec![5]);
|
||||
assert_eq!(meta.current_page, 3);
|
||||
assert_eq!(meta.total_pages, 3);
|
||||
assert_eq!(meta.has_next, false);
|
||||
assert_eq!(meta.has_previous, true);
|
||||
assert!(!meta.has_next, "expected no next page");
|
||||
assert!(meta.has_previous, "expected previous page");
|
||||
assert_eq!(meta.start_index, 5);
|
||||
assert_eq!(meta.end_index, 5);
|
||||
}
|
||||
|
||||
@@ -180,7 +180,7 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
);
|
||||
|
||||
let rerank_lease = match &self.reranker_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::rebuild_indexes,
|
||||
indexes::rebuild,
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
@@ -191,7 +191,7 @@ pub async fn persist(
|
||||
ctx.db.store_item(text_content).await?;
|
||||
|
||||
debug!("stored item");
|
||||
rebuild_indexes(ctx.db).await?;
|
||||
rebuild(ctx.db).await?;
|
||||
|
||||
debug!(
|
||||
task_id = %ctx.task_id,
|
||||
@@ -301,8 +301,8 @@ async fn store_chunk_batch(
|
||||
|
||||
for embedded in batch {
|
||||
TextChunk::store_with_embedding(
|
||||
embedded.chunk.to_owned(),
|
||||
embedded.embedding.to_owned(),
|
||||
embedded.chunk.clone(),
|
||||
embedded.embedding.clone(),
|
||||
db,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{self, Context};
|
||||
use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk};
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
@@ -265,16 +266,12 @@ impl PipelineServices for ValidationServices {
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup_db() -> SurrealDbClient {
|
||||
async fn setup_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "pipeline_test";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to create in-memory SurrealDB");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
db
|
||||
let db = SurrealDbClient::memory(namespace, &database).await?;
|
||||
db.apply_migrations().await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn pipeline_config() -> IngestionConfig {
|
||||
@@ -295,26 +292,28 @@ async fn reserve_task(
|
||||
worker_id: &str,
|
||||
payload: IngestionPayload,
|
||||
user_id: &str,
|
||||
) -> IngestionTask {
|
||||
let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db)
|
||||
.await
|
||||
.expect("task created");
|
||||
) -> anyhow::Result<IngestionTask> {
|
||||
let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db).await?;
|
||||
let lease = task.lease_duration();
|
||||
IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
|
||||
.await
|
||||
.expect("claim succeeds")
|
||||
.expect("task claimed")
|
||||
let claimed = IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
|
||||
.await?
|
||||
.context("task claimed")?;
|
||||
Ok(claimed)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ingestion_pipeline_happy_path_persists_entities() {
|
||||
let db = setup_db().await;
|
||||
async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()>
|
||||
{
|
||||
let db = setup_db().await?;
|
||||
let worker_id = "worker-happy";
|
||||
let user_id = "user-123";
|
||||
let services = Arc::new(MockServices::new(user_id));
|
||||
let pipeline =
|
||||
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services.clone())
|
||||
.expect("pipeline");
|
||||
let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services);
|
||||
let pipeline = IngestionPipeline::with_services(
|
||||
Arc::new(db.clone()),
|
||||
pipeline_config(),
|
||||
services_clone,
|
||||
)?;
|
||||
|
||||
let task = reserve_task(
|
||||
&db,
|
||||
@@ -327,30 +326,22 @@ async fn ingestion_pipeline_happy_path_persists_entities() {
|
||||
},
|
||||
user_id,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
pipeline
|
||||
.process_task(task.clone())
|
||||
.await
|
||||
.expect("pipeline succeeds");
|
||||
pipeline.process_task(task.clone()).await?;
|
||||
|
||||
let stored_task: IngestionTask = db
|
||||
.get_item(&task.id)
|
||||
.await
|
||||
.expect("retrieve task")
|
||||
.expect("task present");
|
||||
.await?
|
||||
.context("task present")?;
|
||||
assert_eq!(stored_task.state, TaskState::Succeeded);
|
||||
|
||||
let stored_entities: Vec<KnowledgeEntity> = db
|
||||
.get_all_stored_items::<KnowledgeEntity>()
|
||||
.await
|
||||
.expect("entities stored");
|
||||
.await?;
|
||||
assert!(!stored_entities.is_empty(), "entities should be stored");
|
||||
|
||||
let stored_chunks: Vec<TextChunk> = db
|
||||
.get_all_stored_items::<TextChunk>()
|
||||
.await
|
||||
.expect("chunks stored");
|
||||
let stored_chunks: Vec<TextChunk> = db.get_all_stored_items::<TextChunk>().await?;
|
||||
assert!(
|
||||
!stored_chunks.is_empty(),
|
||||
"chunks should be stored for ingestion text"
|
||||
@@ -362,22 +353,29 @@ async fn ingestion_pipeline_happy_path_persists_entities() {
|
||||
"expected at least one chunk embedding call"
|
||||
);
|
||||
assert_eq!(
|
||||
&call_log[0..4],
|
||||
["prepare", "retrieve", "enrich", "convert"]
|
||||
call_log.get(0..4),
|
||||
Some(&["prepare", "retrieve", "enrich", "convert"][..])
|
||||
);
|
||||
assert!(call_log[4..].iter().all(|entry| *entry == "chunk"));
|
||||
assert!(
|
||||
call_log.get(4..).is_some_and(|tail| tail.iter().all(|entry| *entry == "chunk"))
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ingestion_pipeline_chunk_only_skips_analysis() {
|
||||
let db = setup_db().await;
|
||||
async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> {
|
||||
let db = setup_db().await?;
|
||||
let worker_id = "worker-chunk-only";
|
||||
let user_id = "user-999";
|
||||
let services = Arc::new(MockServices::new(user_id));
|
||||
let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services);
|
||||
let mut config = pipeline_config();
|
||||
config.chunk_only = true;
|
||||
let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone())
|
||||
.expect("pipeline");
|
||||
let pipeline = IngestionPipeline::with_services(
|
||||
Arc::new(db.clone()),
|
||||
config,
|
||||
services_clone,
|
||||
)?;
|
||||
|
||||
let task = reserve_task(
|
||||
&db,
|
||||
@@ -390,17 +388,13 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
|
||||
},
|
||||
user_id,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
pipeline
|
||||
.process_task(task.clone())
|
||||
.await
|
||||
.expect("pipeline succeeds");
|
||||
pipeline.process_task(task.clone()).await?;
|
||||
|
||||
let stored_entities: Vec<KnowledgeEntity> = db
|
||||
.get_all_stored_items::<KnowledgeEntity>()
|
||||
.await
|
||||
.expect("entities stored");
|
||||
.await?;
|
||||
assert!(
|
||||
stored_entities.is_empty(),
|
||||
"chunk-only ingestion should not persist entities"
|
||||
@@ -408,8 +402,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
|
||||
let relationship_count: Option<i64> = db
|
||||
.client
|
||||
.query("SELECT count() as count FROM relates_to;")
|
||||
.await
|
||||
.expect("query relationships")
|
||||
.await?
|
||||
.take::<Option<i64>>(0)
|
||||
.unwrap_or_default();
|
||||
assert_eq!(
|
||||
@@ -417,10 +410,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
|
||||
0,
|
||||
"chunk-only ingestion should not persist relationships"
|
||||
);
|
||||
let stored_chunks: Vec<TextChunk> = db
|
||||
.get_all_stored_items::<TextChunk>()
|
||||
.await
|
||||
.expect("chunks stored");
|
||||
let stored_chunks: Vec<TextChunk> = db.get_all_stored_items::<TextChunk>().await?;
|
||||
assert!(
|
||||
!stored_chunks.is_empty(),
|
||||
"chunk-only ingestion should still persist chunks"
|
||||
@@ -428,19 +418,19 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
|
||||
|
||||
let call_log = services.calls.lock().await.clone();
|
||||
assert_eq!(call_log, vec!["prepare", "chunk"]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ingestion_pipeline_failure_marks_retry() {
|
||||
let db = setup_db().await;
|
||||
async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> {
|
||||
let db = setup_db().await?;
|
||||
let worker_id = "worker-fail";
|
||||
let user_id = "user-456";
|
||||
let services = Arc::new(FailingServices {
|
||||
inner: MockServices::new(user_id),
|
||||
});
|
||||
let pipeline =
|
||||
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
|
||||
.expect("pipeline");
|
||||
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?;
|
||||
|
||||
let task = reserve_task(
|
||||
&db,
|
||||
@@ -453,7 +443,7 @@ async fn ingestion_pipeline_failure_marks_retry() {
|
||||
},
|
||||
user_id,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
let result = pipeline.process_task(task.clone()).await;
|
||||
assert!(
|
||||
@@ -463,38 +453,38 @@ async fn ingestion_pipeline_failure_marks_retry() {
|
||||
|
||||
let stored_task: IngestionTask = db
|
||||
.get_item(&task.id)
|
||||
.await
|
||||
.expect("retrieve task")
|
||||
.expect("task present");
|
||||
.await?
|
||||
.context("task present")?;
|
||||
assert_eq!(stored_task.state, TaskState::Failed);
|
||||
assert!(
|
||||
stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5),
|
||||
"failed task should schedule retry in the future"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ingestion_pipeline_validation_failure_dead_letters_task() {
|
||||
let db = setup_db().await;
|
||||
async fn ingestion_pipeline_validation_failure_dead_letters_task(
|
||||
) -> anyhow::Result<()> {
|
||||
let db = setup_db().await?;
|
||||
let worker_id = "worker-validation";
|
||||
let user_id = "user-789";
|
||||
let services = Arc::new(ValidationServices);
|
||||
let pipeline =
|
||||
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
|
||||
.expect("pipeline");
|
||||
IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?;
|
||||
|
||||
let task = reserve_task(
|
||||
&db,
|
||||
worker_id,
|
||||
IngestionPayload::Text {
|
||||
text: "irrelevant".into(),
|
||||
context: "".into(),
|
||||
context: String::new(),
|
||||
category: "notes".into(),
|
||||
user_id: user_id.into(),
|
||||
},
|
||||
user_id,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
let result = pipeline.process_task(task.clone()).await;
|
||||
assert!(
|
||||
@@ -504,8 +494,8 @@ async fn ingestion_pipeline_validation_failure_dead_letters_task() {
|
||||
|
||||
let stored_task: IngestionTask = db
|
||||
.get_item(&task.id)
|
||||
.await
|
||||
.expect("retrieve task")
|
||||
.expect("task present");
|
||||
.await?
|
||||
.context("task present")?;
|
||||
assert_eq!(stored_task.state, TaskState::DeadLetter);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -155,21 +155,20 @@ mod tests {
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn extracts_text_using_memory_storage_backend() {
|
||||
let mut config = AppConfig::default();
|
||||
config.storage = StorageKind::Memory;
|
||||
async fn extracts_text_using_memory_storage_backend() -> anyhow::Result<()> {
|
||||
let config = AppConfig {
|
||||
storage: StorageKind::Memory,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let storage = StorageManager::new(&config)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
let storage = StorageManager::new(&config).await?;
|
||||
|
||||
let location = "user/test/file.txt";
|
||||
let contents = b"hello from memory storage";
|
||||
|
||||
storage
|
||||
.put(location, Bytes::from(contents.as_slice().to_vec()))
|
||||
.await
|
||||
.expect("write object");
|
||||
.await?;
|
||||
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
@@ -185,16 +184,14 @@ mod tests {
|
||||
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("create surreal memory");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let openai_client = Client::with_config(OpenAIConfig::default());
|
||||
|
||||
let text = extract_text_from_file(&file_info, &db, &openai_client, &config, &storage)
|
||||
.await
|
||||
.expect("extract text");
|
||||
.await?;
|
||||
|
||||
assert_eq!(text, String::from_utf8_lossy(contents));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,6 +715,7 @@ const fn prompt_for_attempt(attempt: usize, base_prompt: &str) -> &str {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self};
|
||||
|
||||
#[test]
|
||||
fn test_looks_good_enough_short_text() {
|
||||
@@ -737,15 +738,16 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debug_dump_directory_env_var() {
|
||||
fn test_debug_dump_directory_env_var() -> anyhow::Result<()> {
|
||||
std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
|
||||
assert!(debug_dump_directory().is_none());
|
||||
|
||||
std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug");
|
||||
let dir = debug_dump_directory().expect("expected debug directory");
|
||||
let dir = debug_dump_directory().ok_or_else(|| anyhow::anyhow!("expected debug directory"))?;
|
||||
assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug"));
|
||||
|
||||
std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -142,29 +142,34 @@ fn ensure_ingestion_url_allowed(url: &url::Url) -> Result<String, AppError> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self};
|
||||
|
||||
#[test]
|
||||
fn rejects_unsupported_scheme() {
|
||||
let url = url::Url::parse("ftp://example.com").expect("url");
|
||||
fn rejects_unsupported_scheme() -> anyhow::Result<()> {
|
||||
let url = url::Url::parse("ftp://example.com")?;
|
||||
assert!(ensure_ingestion_url_allowed(&url).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_localhost() {
|
||||
let url = url::Url::parse("http://localhost/resource").expect("url");
|
||||
fn rejects_localhost() -> anyhow::Result<()> {
|
||||
let url = url::Url::parse("http://localhost/resource")?;
|
||||
assert!(ensure_ingestion_url_allowed(&url).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_private_ipv4() {
|
||||
let url = url::Url::parse("http://192.168.1.10/index.html").expect("url");
|
||||
fn rejects_private_ipv4() -> anyhow::Result<()> {
|
||||
let url = url::Url::parse("http://192.168.1.10/index.html")?;
|
||||
assert!(ensure_ingestion_url_allowed(&url).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allows_public_domain_and_sanitizes() {
|
||||
let url = url::Url::parse("https://sub.example.com/path").expect("url");
|
||||
let sanitized = ensure_ingestion_url_allowed(&url).expect("allowed");
|
||||
fn allows_public_domain_and_sanitizes() -> anyhow::Result<()> {
|
||||
let url = url::Url::parse("https://sub.example.com/path")?;
|
||||
let sanitized = ensure_ingestion_url_allowed(&url)?;
|
||||
assert_eq!(sanitized, "sub_example_com");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+97
-118
@@ -3,7 +3,7 @@ use axum::{extract::FromRef, Router};
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::ensure_runtime_indexes,
|
||||
indexes::ensure_runtime,
|
||||
store::StorageManager,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, system_settings::SystemSettings,
|
||||
@@ -12,7 +12,10 @@ use common::{
|
||||
},
|
||||
utils::{config::get_config, embedding::EmbeddingProvider},
|
||||
};
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use html_router::{
|
||||
html_routes,
|
||||
html_state::{HtmlState, StateResources},
|
||||
};
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use std::sync::Arc;
|
||||
@@ -21,19 +24,77 @@ use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
use tokio::task::LocalSet;
|
||||
|
||||
fn spawn_server_thread(
|
||||
listener: tokio::net::TcpListener,
|
||||
app: Router,
|
||||
) -> std::thread::JoinHandle<()> {
|
||||
std::thread::spawn(move || {
|
||||
let rt = match tokio::runtime::Runtime::new() {
|
||||
Ok(rt) => rt,
|
||||
Err(e) => {
|
||||
error!("Failed to create server runtime: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
rt.block_on(async {
|
||||
if let Err(e) = axum::serve(listener, app).await {
|
||||
error!("Server error: {}", e);
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_worker(
|
||||
config: common::utils::config::AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
storage: StorageManager,
|
||||
) -> anyhow::Result<()> {
|
||||
let worker_db = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
&config.surrealdb_namespace,
|
||||
&config.surrealdb_database,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
let openai_client = Arc::new(async_openai::Client::with_config(
|
||||
async_openai::config::OpenAIConfig::new()
|
||||
.with_api_key(&config.openai_api_key)
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
let embedding_provider = Arc::new(
|
||||
EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?,
|
||||
);
|
||||
|
||||
let ingestion_pipeline = Arc::new(
|
||||
IngestionPipeline::new(
|
||||
Arc::clone(&worker_db),
|
||||
openai_client,
|
||||
config,
|
||||
reranker_pool,
|
||||
storage,
|
||||
embedding_provider,
|
||||
)?,
|
||||
);
|
||||
|
||||
info!("Starting worker process");
|
||||
run_worker_loop(worker_db, ingestion_pipeline).await
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Set up tracing
|
||||
tracing_subscriber::registry()
|
||||
.with(fmt::layer().with_writer(std::io::stderr))
|
||||
.with(EnvFilter::from_default_env())
|
||||
.try_init()
|
||||
.ok();
|
||||
|
||||
// Get config
|
||||
let config = get_config()?;
|
||||
|
||||
// Set up router states
|
||||
let db = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
@@ -45,7 +106,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
.await?,
|
||||
);
|
||||
|
||||
// Ensure db is initialized
|
||||
db.apply_migrations().await?;
|
||||
|
||||
let session_store = Arc::new(db.create_session_store().await?);
|
||||
@@ -55,27 +115,23 @@ async fn main() -> anyhow::Result<()> {
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
// Create embedding provider based on config before syncing settings.
|
||||
let embedding_provider =
|
||||
Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?);
|
||||
Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
|
||||
info!(
|
||||
embedding_backend = ?config.embedding_backend,
|
||||
embedding_dimension = embedding_provider.dimension(),
|
||||
"Embedding provider initialized"
|
||||
);
|
||||
|
||||
// Sync SystemSettings with provider's dimensions/model/backend
|
||||
let (settings, dimensions_changed) =
|
||||
SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?;
|
||||
|
||||
// If dimensions changed, re-embed existing data to keep queries working.
|
||||
if dimensions_changed {
|
||||
warn!(
|
||||
new_dimensions = settings.embedding_dimensions,
|
||||
"Embedding configuration changed; re-embedding existing data"
|
||||
);
|
||||
|
||||
// Re-embed text chunks
|
||||
info!("Re-embedding TextChunks");
|
||||
if let Err(e) =
|
||||
TextChunk::update_all_embeddings_with_provider(&db, &embedding_provider).await
|
||||
@@ -86,7 +142,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
}
|
||||
|
||||
// Re-embed knowledge entities
|
||||
info!("Re-embedding KnowledgeEntities");
|
||||
if let Err(e) =
|
||||
KnowledgeEntity::update_all_embeddings_with_provider(&db, &embedding_provider).await
|
||||
@@ -100,29 +155,25 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Re-embedding complete.");
|
||||
}
|
||||
|
||||
// Now ensure runtime indexes with the correct (synced) dimensions
|
||||
ensure_runtime_indexes(&db, settings.embedding_dimensions as usize).await?;
|
||||
ensure_runtime(&db, settings.embedding_dimensions as usize).await?;
|
||||
|
||||
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
|
||||
|
||||
// Create global storage manager
|
||||
let storage = StorageManager::new(&config).await?;
|
||||
|
||||
let html_state = HtmlState::new_with_resources(
|
||||
let html_state = HtmlState::new_with_resources(StateResources {
|
||||
db,
|
||||
openai_client,
|
||||
session_store,
|
||||
storage.clone(),
|
||||
config.clone(),
|
||||
reranker_pool.clone(),
|
||||
embedding_provider.clone(),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
storage: storage.clone(),
|
||||
config: config.clone(),
|
||||
reranker_pool: reranker_pool.clone(),
|
||||
embedding_provider: Arc::clone(&embedding_provider),
|
||||
template_engine: None,
|
||||
});
|
||||
|
||||
let api_state = ApiState::new(&config, storage.clone()).await?;
|
||||
|
||||
// Create Axum router
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", api_routes_v1(&api_state))
|
||||
.merge(html_routes(&html_state))
|
||||
@@ -135,72 +186,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let serve_address = format!("0.0.0.0:{}", config.http_port);
|
||||
let listener = tokio::net::TcpListener::bind(serve_address).await?;
|
||||
|
||||
// Start the server in a separate OS thread with its own runtime
|
||||
let server_handle = std::thread::spawn(move || {
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
rt.block_on(async {
|
||||
if let Err(e) = axum::serve(listener, app).await {
|
||||
error!("Server error: {}", e);
|
||||
}
|
||||
});
|
||||
});
|
||||
let server_handle = spawn_server_thread(listener, app);
|
||||
|
||||
// Create a LocalSet for the worker
|
||||
let local = LocalSet::new();
|
||||
|
||||
// Use a clone of the config for the worker
|
||||
let worker_config = config.clone();
|
||||
|
||||
// Run the worker in the local set
|
||||
local.spawn_local(async move {
|
||||
// Create worker db connection
|
||||
let worker_db = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&worker_config.surrealdb_address,
|
||||
&worker_config.surrealdb_username,
|
||||
&worker_config.surrealdb_password,
|
||||
&worker_config.surrealdb_namespace,
|
||||
&worker_config.surrealdb_database,
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Initialize worker components
|
||||
let openai_client = Arc::new(async_openai::Client::with_config(
|
||||
async_openai::config::OpenAIConfig::new()
|
||||
.with_api_key(&config.openai_api_key)
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
// Create embedding provider based on config
|
||||
let embedding_provider = Arc::new(
|
||||
EmbeddingProvider::from_config(&config, Some(openai_client.clone()))
|
||||
.await
|
||||
.expect("failed to create embedding provider"),
|
||||
);
|
||||
let ingestion_pipeline = Arc::new(
|
||||
IngestionPipeline::new(
|
||||
worker_db.clone(),
|
||||
openai_client.clone(),
|
||||
config.clone(),
|
||||
reranker_pool.clone(),
|
||||
storage.clone(),
|
||||
embedding_provider,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
info!("Starting worker process");
|
||||
if let Err(e) = run_worker_loop(worker_db, ingestion_pipeline).await {
|
||||
error!("Worker process error: {}", e);
|
||||
if let Err(e) = run_worker(config, reranker_pool, storage).await {
|
||||
error!("Worker error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Run the local set on the main thread
|
||||
local.await;
|
||||
|
||||
// Wait for the server thread to finish (this likely won't be reached)
|
||||
if let Err(e) = server_handle.join() {
|
||||
error!("Server thread panicked: {:?}", e);
|
||||
}
|
||||
@@ -253,52 +248,39 @@ mod tests {
|
||||
let namespace = "test_ns";
|
||||
let database = format!("test_db_{}", Uuid::new_v4());
|
||||
let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));
|
||||
|
||||
tokio::fs::create_dir_all(&data_dir)
|
||||
.await
|
||||
tokio::fs::create_dir_all(&data_dir).await
|
||||
.expect("failed to create temp data directory");
|
||||
|
||||
let config = smoke_test_config(namespace, &database, &data_dir);
|
||||
let db = Arc::new(
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("failed to start in-memory surrealdb"),
|
||||
);
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
let db = Arc::new(SurrealDbClient::memory(namespace, &database).await?);
|
||||
db.apply_migrations().await?;
|
||||
|
||||
let session_store = Arc::new(db.create_session_store().await.expect("session store"));
|
||||
let session_store = Arc::new(db.create_session_store().await?);
|
||||
let openai_client = Arc::new(async_openai::Client::with_config(
|
||||
async_openai::config::OpenAIConfig::new()
|
||||
.with_api_key(&config.openai_api_key)
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
let storage = StorageManager::new(&config)
|
||||
.await
|
||||
.expect("failed to build storage manager");
|
||||
let storage = StorageManager::new(&config).await?;
|
||||
|
||||
// Use hashed embeddings for tests to avoid external dependencies
|
||||
let embedding_provider = Arc::new(
|
||||
common::utils::embedding::EmbeddingProvider::new_hashed(384)
|
||||
.expect("failed to create hashed embedding provider"),
|
||||
common::utils::embedding::EmbeddingProvider::new_hashed(384)?,
|
||||
);
|
||||
|
||||
let html_state = HtmlState::new_with_resources(
|
||||
db.clone(),
|
||||
let html_state = HtmlState::new_with_resources(StateResources {
|
||||
db: Arc::clone(&db),
|
||||
openai_client,
|
||||
session_store,
|
||||
storage.clone(),
|
||||
config.clone(),
|
||||
None,
|
||||
storage: storage.clone(),
|
||||
config: config.clone(),
|
||||
reranker_pool: None,
|
||||
embedding_provider,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
template_engine: None,
|
||||
});
|
||||
|
||||
let api_state = ApiState {
|
||||
db: db.clone(),
|
||||
db: Arc::clone(&db),
|
||||
config: config.clone(),
|
||||
storage,
|
||||
};
|
||||
@@ -376,25 +358,22 @@ mod tests {
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/api/v1/live")
|
||||
.body(Body::empty())
|
||||
.expect("request"),
|
||||
.body(Body::empty())?,
|
||||
)
|
||||
.await
|
||||
.expect("router response");
|
||||
.await?;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let ready_response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/api/v1/ready")
|
||||
.body(Body::empty())
|
||||
.expect("request"),
|
||||
.body(Body::empty())?,
|
||||
)
|
||||
.await
|
||||
.expect("ready response");
|
||||
.await?;
|
||||
assert_eq!(ready_response.status(), StatusCode::OK);
|
||||
|
||||
tokio::fs::remove_dir_all(&data_dir).await.ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
|
||||
+10
-8
@@ -6,7 +6,10 @@ use common::{
|
||||
storage::{db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettings},
|
||||
utils::{config::get_config, embedding::EmbeddingProvider},
|
||||
};
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use html_router::{
|
||||
html_routes,
|
||||
html_state::{HtmlState, StateResources},
|
||||
};
|
||||
use retrieval_pipeline::reranking::RerankerPool;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
@@ -52,7 +55,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Create embedding provider based on config
|
||||
let embedding_provider =
|
||||
Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?);
|
||||
Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
|
||||
info!(
|
||||
embedding_backend = ?config.embedding_backend,
|
||||
embedding_dimension = embedding_provider.dimension(),
|
||||
@@ -63,17 +66,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let (_settings, _dimensions_changed) =
|
||||
SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?;
|
||||
|
||||
let html_state = HtmlState::new_with_resources(
|
||||
let html_state = HtmlState::new_with_resources(StateResources {
|
||||
db,
|
||||
openai_client,
|
||||
session_store,
|
||||
storage.clone(),
|
||||
config.clone(),
|
||||
storage: storage.clone(),
|
||||
config: config.clone(),
|
||||
reranker_pool,
|
||||
embedding_provider,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
template_engine: None,
|
||||
});
|
||||
|
||||
let api_state = ApiState::new(&config, storage).await?;
|
||||
|
||||
|
||||
+3
-3
@@ -42,7 +42,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Create embedding provider based on config
|
||||
let embedding_provider =
|
||||
Arc::new(EmbeddingProvider::from_config(&config, Some(openai_client.clone())).await?);
|
||||
Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
|
||||
info!(
|
||||
embedding_backend = ?config.embedding_backend,
|
||||
"Embedding provider initialized for worker"
|
||||
@@ -52,8 +52,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let storage = StorageManager::new(&config).await?;
|
||||
|
||||
let ingestion_pipeline = Arc::new(IngestionPipeline::new(
|
||||
db.clone(),
|
||||
openai_client.clone(),
|
||||
Arc::clone(&db),
|
||||
Arc::clone(&openai_client),
|
||||
config,
|
||||
reranker_pool,
|
||||
storage,
|
||||
|
||||
@@ -118,18 +118,16 @@ pub fn create_chat_request(
|
||||
}
|
||||
|
||||
pub fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response: &CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, Box<AppError>> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))
|
||||
.ok_or_else(|| Box::new(AppError::LLMParsing("No content found in LLM response".into())))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
|
||||
Box::new(AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ use common::storage::{
|
||||
/// * `entity_id` - ID of the entity to find neighbors for
|
||||
/// * `user_id` - User ID for access control
|
||||
/// * `limit` - Maximum number of neighbors to return
|
||||
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: &str,
|
||||
@@ -113,25 +112,23 @@ pub async fn find_entities_by_relationship_by_id(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::{self, Context};
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_relationship_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
async fn test_find_entities_by_relationship_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Create some test entities
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create the central entity we'll query relationships for
|
||||
let central_entity = KnowledgeEntity::new(
|
||||
"central_source".to_string(),
|
||||
"Central Entity".to_string(),
|
||||
@@ -141,7 +138,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create related entities
|
||||
let related_entity1 = KnowledgeEntity::new(
|
||||
"related_source1".to_string(),
|
||||
"Related Entity 1".to_string(),
|
||||
@@ -160,7 +156,6 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an unrelated entity
|
||||
let unrelated_entity = KnowledgeEntity::new(
|
||||
"unrelated_source".to_string(),
|
||||
"Unrelated Entity".to_string(),
|
||||
@@ -170,32 +165,29 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store all entities
|
||||
let central_entity = db
|
||||
.store_item(central_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store central entity")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store central entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Central entity not returned after store"))?;
|
||||
let related_entity1 = db
|
||||
.store_item(related_entity1.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 1")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store related entity 1".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 1 not returned after store"))?;
|
||||
let related_entity2 = db
|
||||
.store_item(related_entity2.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 2")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store related entity 2".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Related entity 2 not returned after store"))?;
|
||||
let _unrelated_entity = db
|
||||
.store_item(unrelated_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store unrelated entity")
|
||||
.unwrap();
|
||||
.with_context(|| "Failed to store unrelated entity".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Unrelated entity not returned after store"))?;
|
||||
|
||||
// Create relationships
|
||||
let source_id = "relationship_source".to_string();
|
||||
|
||||
// Create relationship 1: central -> related1
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity1.id.clone(),
|
||||
@@ -204,7 +196,6 @@ mod tests {
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
// Create relationship 2: central -> related2
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity2.id.clone(),
|
||||
@@ -213,26 +204,25 @@ mod tests {
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Store relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||
|
||||
// Test finding entities related to the central entity
|
||||
let related_entities =
|
||||
find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX)
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
.with_context(|| "Failed to find entities by relationship".to_string())?;
|
||||
|
||||
// Check that we found relationships
|
||||
assert!(
|
||||
related_entities.len() >= 2,
|
||||
"Should find related entities in both directions"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+80
-100
@@ -42,10 +42,14 @@ impl SearchResult {
|
||||
}
|
||||
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, SearchTarget,
|
||||
retrieved_entities_to_json, Diagnostics, StageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
|
||||
// Backward-compatible type aliases for external consumers
|
||||
pub type PipelineDiagnostics = Diagnostics;
|
||||
pub type PipelineStageTimings = StageTimings;
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedChunk {
|
||||
@@ -61,7 +65,7 @@ pub struct RetrievedEntity {
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
/// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
||||
/// Primary orchestrator for the process of retrieving `KnowledgeEntity` values related to an `input_text`
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
@@ -72,7 +76,7 @@ pub async fn retrieve_entities(
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
pipeline::run_pipeline(
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
@@ -80,17 +84,16 @@ pub async fn retrieve_entities(
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
};
|
||||
pipeline::execute(params).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self};
|
||||
use async_openai::Client;
|
||||
use common::storage::indexes::ensure_runtime_indexes;
|
||||
use common::storage::types::text_chunk::TextChunk;
|
||||
use pipeline::{RetrievalConfig, RetrievalStrategy};
|
||||
use common::storage::indexes::ensure_runtime;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
@@ -105,27 +108,21 @@ mod tests {
|
||||
vec![0.2, 0.8, 0.0]
|
||||
}
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
db.apply_migrations().await?;
|
||||
|
||||
ensure_runtime_indexes(&db, 3)
|
||||
.await
|
||||
.expect("failed to build runtime indexes");
|
||||
ensure_runtime(&db, 3).await?;
|
||||
|
||||
db
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_retrieves_chunks() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_default_strategy_retrieves_chunks() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "test_user";
|
||||
let chunk = TextChunk::new(
|
||||
"source_1".into(),
|
||||
@@ -133,39 +130,38 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Default strategy retrieval failed");
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk results, got {:?}", other),
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||
assert!(
|
||||
chunks[0].chunk.chunk.contains("Tokio"),
|
||||
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Expected chunk about Tokio"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_strategy_returns_chunks_from_multiple_sources() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_default_strategy_returns_chunks_from_multiple_sources(
|
||||
) -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "multi_source_user";
|
||||
|
||||
let primary_chunk = TextChunk::new(
|
||||
@@ -179,30 +175,25 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store primary chunk");
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db)
|
||||
.await
|
||||
.expect("Failed to store secondary chunk");
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Default strategy retrieval failed");
|
||||
config: RetrievalConfig::default(),
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk results, got {:?}", other),
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
|
||||
@@ -216,11 +207,12 @@ mod tests {
|
||||
.any(|c| c.chunk.source_id == "secondary_source"),
|
||||
"Should include secondary source chunk"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_revised_strategy_returns_chunks() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_revised_strategy_returns_chunks() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "chunk_user";
|
||||
let chunk_one = TextChunk::new(
|
||||
"src_alpha".into(),
|
||||
@@ -233,31 +225,26 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk one");
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk two");
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"tokio runtime worker behavior",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "tokio runtime worker behavior",
|
||||
user_id,
|
||||
config,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Revised retrieval failed");
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
let chunks = match results {
|
||||
StrategyOutput::Chunks(items) => items,
|
||||
other => panic!("expected chunk output, got {:?}", other),
|
||||
other => anyhow::bail!("expected chunk results, got {other:?}"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
@@ -270,11 +257,12 @@ mod tests {
|
||||
.any(|entry| entry.chunk.chunk.contains("Tokio")),
|
||||
"Chunk results should contain relevant snippets"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_strategy_returns_search_result() {
|
||||
let db = setup_test_db().await;
|
||||
async fn test_search_strategy_returns_search_result() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "search_user";
|
||||
let chunk = TextChunk::new(
|
||||
"search_src".into(),
|
||||
@@ -282,33 +270,24 @@ mod tests {
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let config = RetrievalConfig::for_search(pipeline::SearchTarget::Both);
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"async rust programming",
|
||||
let params = pipeline::StrategyParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
input_text: "async rust programming",
|
||||
user_id,
|
||||
config,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Search strategy retrieval failed");
|
||||
reranker: None,
|
||||
};
|
||||
let results = pipeline::run_pipeline_with_embedding(params, test_embedding())
|
||||
.await?;
|
||||
|
||||
assert!(
|
||||
matches!(results, StrategyOutput::Search(_)),
|
||||
"expected Search output, got {:?}",
|
||||
results
|
||||
);
|
||||
let search_result = match results {
|
||||
StrategyOutput::Search(sr) => sr,
|
||||
_ => unreachable!(),
|
||||
let StrategyOutput::Search(search_result) = results else {
|
||||
anyhow::bail!("expected Search output");
|
||||
};
|
||||
|
||||
// Should return chunks (entities may be empty if none stored)
|
||||
@@ -323,5 +302,6 @@ mod tests {
|
||||
.any(|c| c.chunk.chunk.contains("Tokio")),
|
||||
"Search results should contain relevant chunks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
use crate::scoring::FusionWeights;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
/// Primary hybrid chunk retrieval for search/chat (formerly Revised)
|
||||
#[default]
|
||||
Default,
|
||||
/// Entity retrieval for suggesting relationships when creating manual entities
|
||||
RelationshipSuggestion,
|
||||
@@ -29,12 +30,6 @@ pub enum SearchTarget {
|
||||
Both,
|
||||
}
|
||||
|
||||
impl Default for RetrievalStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Default
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RetrievalStrategy {
|
||||
type Err = String;
|
||||
|
||||
@@ -70,6 +65,91 @@ impl fmt::Display for RetrievalStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-variant flag that serializes as a bool for backward compatibility.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum BoolFlag {
|
||||
#[default]
|
||||
Disabled,
|
||||
Enabled,
|
||||
}
|
||||
|
||||
impl BoolFlag {
|
||||
pub const fn as_bool(self) -> bool {
|
||||
matches!(self, BoolFlag::Enabled)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bool> for BoolFlag {
|
||||
fn from(value: bool) -> Self {
|
||||
if value {
|
||||
BoolFlag::Enabled
|
||||
} else {
|
||||
BoolFlag::Disabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for BoolFlag {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serializer.serialize_bool(self.as_bool())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for BoolFlag {
|
||||
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||
bool::deserialize(deserializer).map(|b| {
|
||||
if b {
|
||||
BoolFlag::Enabled
|
||||
} else {
|
||||
BoolFlag::Disabled
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuningFlags {
|
||||
pub rerank_scores_only: BoolFlag,
|
||||
pub normalize_vector_scores: BoolFlag,
|
||||
pub normalize_fts_scores: BoolFlag,
|
||||
pub chunk_rrf_use_vector: BoolFlag,
|
||||
pub chunk_rrf_use_fts: BoolFlag,
|
||||
}
|
||||
|
||||
impl RetrievalTuningFlags {
|
||||
pub const fn rerank_scores_only(&self) -> bool {
|
||||
self.rerank_scores_only.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_vector_scores(&self) -> bool {
|
||||
self.normalize_vector_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn normalize_fts_scores(&self) -> bool {
|
||||
self.normalize_fts_scores.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_vector(&self) -> bool {
|
||||
self.chunk_rrf_use_vector.as_bool()
|
||||
}
|
||||
|
||||
pub const fn chunk_rrf_use_fts(&self) -> bool {
|
||||
self.chunk_rrf_use_fts.as_bool()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuningFlags {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
rerank_scores_only: BoolFlag::Disabled,
|
||||
normalize_vector_scores: BoolFlag::Disabled,
|
||||
normalize_fts_scores: BoolFlag::Enabled,
|
||||
chunk_rrf_use_vector: BoolFlag::Enabled,
|
||||
chunk_rrf_use_fts: BoolFlag::Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
@@ -89,15 +169,11 @@ pub struct RetrievalTuning {
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
pub rerank_blend_weight: f32,
|
||||
pub rerank_scores_only: bool,
|
||||
pub flags: RetrievalTuningFlags,
|
||||
pub rerank_keep_top: usize,
|
||||
pub chunk_result_cap: usize,
|
||||
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
||||
pub fusion_weights: Option<FusionWeights>,
|
||||
/// Normalize vector similarity scores before fusion (default: true)
|
||||
pub normalize_vector_scores: bool,
|
||||
/// Normalize FTS (BM25) scores before fusion (default: true)
|
||||
pub normalize_fts_scores: bool,
|
||||
/// Reciprocal rank fusion k value for chunk merging in Revised strategy.
|
||||
#[serde(default = "default_chunk_rrf_k")]
|
||||
pub chunk_rrf_k: f32,
|
||||
@@ -107,12 +183,6 @@ pub struct RetrievalTuning {
|
||||
/// Weight applied to chunk FTS ranks in RRF.
|
||||
#[serde(default = "default_chunk_rrf_fts_weight")]
|
||||
pub chunk_rrf_fts_weight: f32,
|
||||
/// Whether to include vector rankings in RRF.
|
||||
#[serde(default = "default_chunk_rrf_use_vector")]
|
||||
pub chunk_rrf_use_vector: bool,
|
||||
/// Whether to include chunk FTS rankings in RRF.
|
||||
#[serde(default = "default_chunk_rrf_use_fts")]
|
||||
pub chunk_rrf_use_fts: bool,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
@@ -134,26 +204,19 @@ impl Default for RetrievalTuning {
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
rerank_scores_only: false,
|
||||
flags: RetrievalTuningFlags::default(),
|
||||
rerank_keep_top: 8,
|
||||
chunk_result_cap: 5,
|
||||
fusion_weights: None,
|
||||
// Vector scores (cosine similarity) are already in [0,1] range
|
||||
// Normalization only helps when there's significant variation
|
||||
normalize_vector_scores: false,
|
||||
// FTS scores (BM25) are unbounded, normalization helps more
|
||||
normalize_fts_scores: true,
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_vector_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_fts_weight(),
|
||||
chunk_rrf_use_vector: default_chunk_rrf_use_vector(),
|
||||
chunk_rrf_use_fts: default_chunk_rrf_use_fts(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
@@ -211,16 +274,6 @@ impl RetrievalConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::default(),
|
||||
tuning: RetrievalTuning::default(),
|
||||
search_target: SearchTarget::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
60.0
|
||||
}
|
||||
@@ -233,10 +286,4 @@ const fn default_chunk_rrf_fts_weight() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_use_vector() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_use_fts() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct PipelineDiagnostics {
|
||||
pub struct Diagnostics {
|
||||
pub collect_candidates: Option<CollectCandidatesStats>,
|
||||
pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
|
||||
pub assemble: Option<AssembleStats>,
|
||||
|
||||
@@ -3,10 +3,11 @@ mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning, SearchTarget};
|
||||
pub use config::{
|
||||
RetrievalConfig, RetrievalStrategy, RetrievalTuning, RetrievalTuningFlags, SearchTarget,
|
||||
};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace, Diagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity, StrategyOutput};
|
||||
@@ -37,13 +38,13 @@ pub enum StageKind {
|
||||
|
||||
// Pipeline stage trait
|
||||
#[async_trait]
|
||||
pub trait PipelineStage: Send + Sync {
|
||||
pub trait Stage: Send + Sync {
|
||||
fn kind(&self) -> StageKind;
|
||||
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
|
||||
}
|
||||
|
||||
// Type alias for boxed stages
|
||||
pub type BoxedStage = Box<dyn PipelineStage>;
|
||||
pub type BoxedStage = Box<dyn Stage>;
|
||||
|
||||
// Strategy driver trait
|
||||
#[async_trait]
|
||||
@@ -51,16 +52,16 @@ pub trait StrategyDriver: Send + Sync {
|
||||
type Output;
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>>;
|
||||
}
|
||||
|
||||
// Pipeline stage timings tracker
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub struct StageTimings {
|
||||
timings: Vec<(StageKind, Duration)>,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
impl StageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
self.timings.push((kind, duration));
|
||||
}
|
||||
@@ -74,8 +75,7 @@ impl PipelineStageTimings {
|
||||
self.timings
|
||||
.iter()
|
||||
.find(|(k, _)| *k == kind)
|
||||
.map(|(_, d)| d.as_millis())
|
||||
.unwrap_or(0)
|
||||
.map_or(0, |(_, d)| d.as_millis())
|
||||
}
|
||||
|
||||
pub fn embed_ms(&self) -> u128 {
|
||||
@@ -103,228 +103,100 @@ impl PipelineStageTimings {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PipelineRunOutput<T> {
|
||||
pub struct RunOutput<T> {
|
||||
pub results: T,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub stage_timings: PipelineStageTimings,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
pub async fn execute(params: StrategyParams<'_>) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = params.input_text.chars().count();
|
||||
let input_preview: String = params.input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
user_id = %params.user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
strategy = %config.strategy,
|
||||
strategy = %params.config.strategy,
|
||||
"Starting retrieval pipeline"
|
||||
);
|
||||
|
||||
match config.strategy {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let search_target = config.search_target;
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, None, false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
match config.strategy {
|
||||
let strategy = params.config.strategy;
|
||||
let search_target = params.config.search_target;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
RetrievalStrategy::RelationshipSuggestion => {
|
||||
let driver = RelationshipSuggestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Ingestion => {
|
||||
let driver = IngestionDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Entities(run.results))
|
||||
}
|
||||
RetrievalStrategy::Search => {
|
||||
let search_target = config.search_target;
|
||||
let driver = SearchStrategyDriver::new(search_target);
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(StrategyOutput::Search(run.results))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The metrics/diagnostics variants would follow the same pattern,
|
||||
// but for brevity I'm only updating the main ones used by callers.
|
||||
// If metrics/diagnostics are needed for non-chat strategies, they should be updated too.
|
||||
// For now, I'll update them to support at least Initial/Revised as before.
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
match config.strategy {
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), false).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
})
|
||||
}
|
||||
// Fallback for others if needed, or error. For now assuming metrics mainly for chat.
|
||||
_ => Err(AppError::InternalError(
|
||||
"Metrics not supported for this strategy".into(),
|
||||
)),
|
||||
@@ -332,32 +204,16 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
|
||||
match config.strategy {
|
||||
) -> Result<RunOutput<StrategyOutput>, AppError> {
|
||||
let strategy = params.config.strategy;
|
||||
|
||||
match strategy {
|
||||
RetrievalStrategy::Default => {
|
||||
let driver = DefaultStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(PipelineRunOutput {
|
||||
let run = execute_strategy(driver, params, Some(query_embedding), true).await?;
|
||||
Ok(RunOutput {
|
||||
results: StrategyOutput::Chunks(run.results),
|
||||
diagnostics: run.diagnostics,
|
||||
stage_timings: run.stage_timings,
|
||||
@@ -391,38 +247,25 @@ pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::V
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub struct StrategyParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
params: StrategyParams<'_>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
let ctx = match query_embedding {
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
None => PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
),
|
||||
Some(embedding) => PipelineContext::with_embedding(params, embedding),
|
||||
None => PipelineContext::new(params),
|
||||
};
|
||||
|
||||
run_with_driver(driver, ctx, capture_diagnostics).await
|
||||
@@ -432,7 +275,7 @@ async fn run_with_driver<D: StrategyDriver>(
|
||||
driver: D,
|
||||
mut ctx: PipelineContext<'_>,
|
||||
capture_diagnostics: bool,
|
||||
) -> Result<PipelineRunOutput<D::Output>, AppError> {
|
||||
) -> Result<RunOutput<D::Output>, AppError> {
|
||||
if capture_diagnostics {
|
||||
ctx.enable_diagnostics();
|
||||
}
|
||||
@@ -445,9 +288,9 @@ async fn run_with_driver<D: StrategyDriver>(
|
||||
|
||||
let diagnostics = ctx.take_diagnostics();
|
||||
let stage_timings = ctx.take_stage_timings();
|
||||
let results = driver.finalize(&mut ctx)?;
|
||||
let results = driver.finalize(&mut ctx).map_err(|e| *e)?;
|
||||
|
||||
Ok(PipelineRunOutput {
|
||||
Ok(RunOutput {
|
||||
results,
|
||||
diagnostics,
|
||||
stage_timings,
|
||||
|
||||
@@ -27,9 +27,9 @@ use super::{
|
||||
config::{RetrievalConfig, RetrievalTuning},
|
||||
diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
Diagnostics,
|
||||
},
|
||||
PipelineStage, PipelineStageTimings, StageKind,
|
||||
StageTimings, Stage, StageKind, StrategyParams,
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
@@ -45,76 +45,51 @@ pub struct PipelineContext<'a> {
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub revised_chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub diagnostics: Option<Diagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
pub chunk_results: Vec<RetrievedChunk>,
|
||||
stage_timings: PipelineStageTimings,
|
||||
stage_timings: StageTimings,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
pub fn new(params: StrategyParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
config: params.config,
|
||||
query_embedding: None,
|
||||
entity_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
revised_chunk_values: Vec::new(),
|
||||
reranker,
|
||||
reranker: params.reranker,
|
||||
diagnostics: None,
|
||||
entity_results: Vec::new(),
|
||||
chunk_results: Vec::new(),
|
||||
stage_timings: PipelineStageTimings::default(),
|
||||
stage_timings: StageTimings::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
let mut ctx = Self::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
pub fn with_embedding(params: StrategyParams<'a>, query_embedding: Vec<f32>) -> Self {
|
||||
let mut ctx = Self::new(params);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, Box<AppError>> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
AppError::InternalError(
|
||||
Box::new(AppError::InternalError(
|
||||
"query embedding missing before candidate collection".to_string(),
|
||||
)
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_diagnostics(&mut self) {
|
||||
if self.diagnostics.is_none() {
|
||||
self.diagnostics = Some(PipelineDiagnostics::default());
|
||||
self.diagnostics = Some(Diagnostics::default());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,11 +115,11 @@ impl<'a> PipelineContext<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_diagnostics(&mut self) -> Option<PipelineDiagnostics> {
|
||||
pub fn take_diagnostics(&mut self) -> Option<Diagnostics> {
|
||||
self.diagnostics.take()
|
||||
}
|
||||
|
||||
pub fn take_stage_timings(&mut self) -> PipelineStageTimings {
|
||||
pub fn take_stage_timings(&mut self) -> StageTimings {
|
||||
std::mem::take(&mut self.stage_timings)
|
||||
}
|
||||
|
||||
@@ -165,7 +140,7 @@ impl<'a> PipelineContext<'a> {
|
||||
pub struct EmbedStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for EmbedStage {
|
||||
impl Stage for EmbedStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Embed
|
||||
}
|
||||
@@ -179,7 +154,7 @@ impl PipelineStage for EmbedStage {
|
||||
pub struct CollectCandidatesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for CollectCandidatesStage {
|
||||
impl Stage for CollectCandidatesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
@@ -193,7 +168,7 @@ impl PipelineStage for CollectCandidatesStage {
|
||||
pub struct GraphExpansionStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for GraphExpansionStage {
|
||||
impl Stage for GraphExpansionStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::GraphExpansion
|
||||
}
|
||||
@@ -207,7 +182,7 @@ impl PipelineStage for GraphExpansionStage {
|
||||
pub struct RerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for RerankStage {
|
||||
impl Stage for RerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
@@ -221,7 +196,7 @@ impl PipelineStage for RerankStage {
|
||||
pub struct AssembleEntitiesStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for AssembleEntitiesStage {
|
||||
impl Stage for AssembleEntitiesStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
@@ -235,7 +210,7 @@ impl PipelineStage for AssembleEntitiesStage {
|
||||
pub struct ChunkVectorStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkVectorStage {
|
||||
impl Stage for ChunkVectorStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::CollectCandidates
|
||||
}
|
||||
@@ -249,7 +224,7 @@ impl PipelineStage for ChunkVectorStage {
|
||||
pub struct ChunkRerankStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkRerankStage {
|
||||
impl Stage for ChunkRerankStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Rerank
|
||||
}
|
||||
@@ -263,7 +238,7 @@ impl PipelineStage for ChunkRerankStage {
|
||||
pub struct ChunkAssembleStage;
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineStage for ChunkAssembleStage {
|
||||
impl Stage for ChunkAssembleStage {
|
||||
fn kind(&self) -> StageKind {
|
||||
StageKind::Assemble
|
||||
}
|
||||
@@ -283,8 +258,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await.map_err(|e| {
|
||||
AppError::InternalError(format!(
|
||||
"Failed to generate embedding with provider: {}",
|
||||
e
|
||||
"Failed to generate embedding with provider: {e}",
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
@@ -299,7 +273,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let weights = FusionWeights::default();
|
||||
@@ -487,11 +461,11 @@ pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Collecting vector chunk candidates for revised strategy");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let fts_take = tuning.chunk_fts_take;
|
||||
let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text);
|
||||
let fts_enabled = tuning.chunk_rrf_use_fts && fts_take > 0 && !fts_query.is_empty();
|
||||
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty();
|
||||
|
||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||
TextChunk::vector_search(
|
||||
@@ -532,8 +506,8 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
||||
k: tuning.chunk_rrf_k,
|
||||
vector_weight: tuning.chunk_rrf_vector_weight,
|
||||
fts_weight,
|
||||
use_vector: tuning.chunk_rrf_use_vector,
|
||||
use_fts: tuning.chunk_rrf_use_fts && fts_candidates > 0,
|
||||
use_vector: tuning.flags.chunk_rrf_use_vector(),
|
||||
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
|
||||
};
|
||||
|
||||
let mut vector_chunks = reciprocal_rank_fusion(vector_scored, fts_scored, rrf_config);
|
||||
@@ -715,7 +689,7 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let mut per_entity_count = 0;
|
||||
for candidate in candidates.iter() {
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.inspected_candidates += 1;
|
||||
trace.inspected_candidates = trace.inspected_candidates.saturating_add(1);
|
||||
}
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
@@ -723,17 +697,17 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
chunks_skipped_due_budget += 1;
|
||||
chunks_skipped_due_budget = chunks_skipped_due_budget.saturating_add(1);
|
||||
if let Some(trace) = entity_trace.as_mut() {
|
||||
trace.skipped_due_budget += 1;
|
||||
trace.skipped_due_budget = trace.skipped_due_budget.saturating_add(1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
tokens_spent += estimated_tokens;
|
||||
per_entity_count += 1;
|
||||
chunks_selected += 1;
|
||||
tokens_spent = tokens_spent.saturating_add(estimated_tokens);
|
||||
per_entity_count = per_entity_count.saturating_add(1);
|
||||
chunks_selected = chunks_selected.saturating_add(1);
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
@@ -780,14 +754,14 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
|
||||
const SCORE_SAMPLE_LIMIT: usize = 8;
|
||||
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], mut extractor: F) -> Vec<f32>
|
||||
fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
||||
where
|
||||
F: FnMut(&Scored<T>) -> f32,
|
||||
{
|
||||
items
|
||||
.iter()
|
||||
.take(SCORE_SAMPLE_LIMIT)
|
||||
.map(|item| extractor(item))
|
||||
.map(extractor)
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -912,7 +886,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = ctx.config.tuning.rerank_scores_only;
|
||||
let use_only = ctx.config.tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
@@ -942,11 +916,7 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
ctx.filtered_entities = reranked;
|
||||
let keep_top = ctx.config.tuning.rerank_keep_top;
|
||||
@@ -970,7 +940,7 @@ fn apply_chunk_rerank_results(
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = tuning.rerank_scores_only;
|
||||
let use_only = tuning.flags.rerank_scores_only();
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
@@ -1001,11 +971,7 @@ fn apply_chunk_rerank_results(
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
reranked.extend(remaining.into_iter().flatten());
|
||||
|
||||
let keep_top = tuning.rerank_keep_top;
|
||||
if keep_top > 0 && reranked.len() > keep_top {
|
||||
@@ -1017,7 +983,7 @@ fn apply_chunk_rerank_results(
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
chars.checked_div(avg_chars_per_token).map_or(1, |v| v.max(1))
|
||||
}
|
||||
|
||||
fn rank_chunks_by_combined_score(
|
||||
@@ -1053,13 +1019,20 @@ fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
||||
return 0.0;
|
||||
}
|
||||
let lower = haystack.to_ascii_lowercase();
|
||||
let mut matches = 0usize;
|
||||
let mut matches: u32 = 0;
|
||||
for term in terms {
|
||||
if lower.contains(term) {
|
||||
matches += 1;
|
||||
matches = matches.saturating_add(1);
|
||||
}
|
||||
}
|
||||
(matches as f32) / (terms.len() as f32)
|
||||
let total = u32::try_from(terms.len()).unwrap_or(u32::MAX);
|
||||
if total == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let num = matches.min(total);
|
||||
let num_f32 = u16::try_from(num).map(f32::from).unwrap_or(f32::MAX);
|
||||
let den_f32 = u16::try_from(total).map(f32::from).unwrap_or(f32::MAX);
|
||||
num_f32 / den_f32
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -28,7 +28,7 @@ impl StrategyDriver for DefaultStrategyDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
}
|
||||
@@ -55,7 +55,7 @@ impl StrategyDriver for RelationshipSuggestionDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
@@ -82,7 +82,7 @@ impl StrategyDriver for IngestionDriver {
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
@@ -134,7 +134,7 @@ impl StrategyDriver for SearchStrategyDriver {
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, Box<AppError>> {
|
||||
let chunks = match self.target {
|
||||
SearchTarget::EntitiesOnly => Vec::new(),
|
||||
_ => ctx.take_chunk_results(),
|
||||
|
||||
@@ -17,7 +17,7 @@ static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
fn pick_engine_index(pool_len: usize) -> usize {
|
||||
let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
|
||||
n % pool_len
|
||||
n.checked_rem(pool_len).unwrap_or(0)
|
||||
}
|
||||
|
||||
pub struct RerankerPool {
|
||||
@@ -28,30 +28,30 @@ pub struct RerankerPool {
|
||||
impl RerankerPool {
|
||||
/// Build the pool at startup.
|
||||
/// `pool_size` controls max parallel reranks.
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
|
||||
Self::new_with_options(
|
||||
pool_size,
|
||||
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn),
|
||||
)
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
|
||||
let init_options =
|
||||
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
|
||||
Self::new_with_options(pool_size, &init_options)
|
||||
}
|
||||
|
||||
fn new_with_options(
|
||||
pool_size: usize,
|
||||
init_options: RerankInitOptions,
|
||||
) -> Result<Arc<Self>, AppError> {
|
||||
init_options: &RerankInitOptions,
|
||||
) -> Result<Arc<Self>, Box<AppError>> {
|
||||
if pool_size == 0 {
|
||||
return Err(AppError::Validation(
|
||||
return Err(Box::new(AppError::Validation(
|
||||
"RERANKING_POOL_SIZE must be greater than zero".to_string(),
|
||||
));
|
||||
)));
|
||||
}
|
||||
|
||||
fs::create_dir_all(&init_options.cache_dir)?;
|
||||
fs::create_dir_all(&init_options.cache_dir)
|
||||
.map_err(|e| Box::new(AppError::from(e)))?;
|
||||
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for x in 0..pool_size {
|
||||
debug!("Creating reranking engine: {x}");
|
||||
let model = TextRerank::try_new(init_options.clone())
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(|e| Box::new(AppError::InternalError(e.to_string())))?;
|
||||
engines.push(Arc::new(Mutex::new(model)));
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ impl RerankerPool {
|
||||
}
|
||||
|
||||
/// Initialize a pool using application configuration.
|
||||
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
|
||||
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, Box<AppError>> {
|
||||
if !config.reranking_enabled {
|
||||
return Ok(None);
|
||||
}
|
||||
@@ -70,30 +70,28 @@ impl RerankerPool {
|
||||
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
|
||||
|
||||
let init_options = build_rerank_init_options(config)?;
|
||||
Self::new_with_options(pool_size, init_options).map(Some)
|
||||
Self::new_with_options(pool_size, &init_options).map(Some)
|
||||
}
|
||||
|
||||
/// Check out capacity + pick an engine.
|
||||
/// This returns a lease that can perform rerank().
|
||||
pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
|
||||
/// This returns a lease that can perform `rerank()`.
|
||||
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
|
||||
// Acquire a permit. This enforces backpressure.
|
||||
let permit = self
|
||||
.semaphore
|
||||
.clone()
|
||||
let permit = Arc::clone(&self.semaphore)
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("semaphore closed");
|
||||
.ok()?;
|
||||
|
||||
// Pick an engine.
|
||||
// This is naive: just pick based on a simple modulo counter.
|
||||
// We use an atomic counter to avoid always choosing index 0.
|
||||
let idx = pick_engine_index(self.engines.len());
|
||||
let engine = self.engines[idx].clone();
|
||||
let engine = self.engines.get(idx).map(Arc::clone)?;
|
||||
|
||||
RerankerLease {
|
||||
Some(RerankerLease {
|
||||
_permit: permit,
|
||||
engine,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +109,7 @@ fn is_truthy(value: &str) -> bool {
|
||||
)
|
||||
}
|
||||
|
||||
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
|
||||
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Box<AppError>> {
|
||||
let mut options = RerankInitOptions::default();
|
||||
|
||||
let cache_dir = config
|
||||
@@ -125,7 +123,7 @@ fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, Ap
|
||||
.join("fastembed")
|
||||
.join("reranker")
|
||||
});
|
||||
fs::create_dir_all(&cache_dir)?;
|
||||
fs::create_dir_all(&cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
|
||||
options.cache_dir = cache_dir;
|
||||
|
||||
let show_progress = config
|
||||
@@ -150,7 +148,7 @@ fn env_bool(key: &str) -> Option<bool> {
|
||||
env::var(key).ok().map(|value| is_truthy(&value))
|
||||
}
|
||||
|
||||
/// Active lease on a single TextRerank instance.
|
||||
/// Active lease on a single `TextRerank` instance.
|
||||
pub struct RerankerLease {
|
||||
// When this drops the semaphore permit is released.
|
||||
_permit: OwnedSemaphorePermit,
|
||||
|
||||
@@ -28,16 +28,19 @@ impl<T> Scored<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_vector_score(mut self, score: f32) -> Self {
|
||||
self.scores.vector = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_fts_score(mut self, score: f32) -> Self {
|
||||
self.scores.fts = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn with_graph_score(mut self, score: f32) -> Self {
|
||||
self.scores.graph = Some(score);
|
||||
self
|
||||
@@ -168,7 +171,7 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
if scores.vector.is_some() && scores.fts.is_some() {
|
||||
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
|
||||
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
|
||||
fused = fused * (1.0 + weights.multi_bonus);
|
||||
fused *= 1.0 + weights.multi_bonus;
|
||||
} else {
|
||||
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
|
||||
fused += weights.multi_bonus;
|
||||
@@ -178,8 +181,8 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>>,
|
||||
pub fn merge_scored_by_id<T, S: std::hash::BuildHasher>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>, S>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
@@ -263,7 +266,10 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
entry.fused += vector_weight / (k + rank as f32 + 1.0);
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,7 +296,10 @@ where
|
||||
}
|
||||
}
|
||||
entry.item = candidate.item;
|
||||
entry.fused += fts_weight / (k + rank as f32 + 1.0);
|
||||
let rank_f32: f32 = u16::try_from(rank)
|
||||
.map(f32::from)
|
||||
.unwrap_or(f32::MAX);
|
||||
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user