From d3443d4153bf8b5e0b0ee31c9efaabc7ed803490 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Fri, 29 May 2026 14:35:07 +0200 Subject: [PATCH] chore: centralize embedding errors, retrieval strategy, and test DB helpers. Replace anyhow in embedding production code with EmbeddingError, move RetrievalStrategy into common config, and deduplicate Surreal test setup via common::test_utils. --- common/src/error.rs | 32 +++++ common/src/lib.rs | 3 + common/src/storage/types/knowledge_entity.rs | 15 +- .../types/knowledge_entity_embedding.rs | 50 ++----- .../storage/types/knowledge_relationship.rs | 30 ++-- common/src/storage/types/text_chunk.rs | 15 +- .../src/storage/types/text_chunk_embedding.rs | 43 +----- common/src/storage/types/text_content.rs | 29 +--- common/src/storage/types/user.rs | 15 +- common/src/test_utils.rs | 96 +++++++++++++ common/src/utils/config.rs | 135 +++++++++++++++++- common/src/utils/embedding.rs | 90 +++++++----- html-router/src/html_state.rs | 26 ++-- html-router/src/routes/search/handlers.rs | 30 +--- .../src/pipeline/enrichment_result.rs | 5 +- retrieval-pipeline/src/pipeline/config.rs | 50 +------ retrieval-pipeline/src/pipeline/stages/mod.rs | 6 +- 17 files changed, 366 insertions(+), 304 deletions(-) create mode 100644 common/src/test_utils.rs diff --git a/common/src/error.rs b/common/src/error.rs index 7fd4f22..c68559e 100644 --- a/common/src/error.rs +++ b/common/src/error.rs @@ -4,6 +4,36 @@ use tokio::task::JoinError; use crate::storage::types::file_info::FileError; +/// Errors from embedding provider operations. +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum EmbeddingError { + #[error("openai error: {0}")] + OpenAI(#[from] OpenAIError), + #[error("fastembed error: {0}")] + FastEmbed(String), + #[error("task join error: {0}")] + Join(#[from] JoinError), + #[error("fastembed model mutex poisoned: {0}")] + MutexPoisoned(String), + #[error("no embedding data received")] + NoData, + #[error("embedding configuration error: {0}")] + Config(String), + #[error("unknown fastembed model: {0}")] + UnknownModel(String), +} + +impl EmbeddingError { + pub(crate) fn fastembed(err: impl std::fmt::Display) -> Self { + Self::FastEmbed(err.to_string()) + } + + pub(crate) fn mutex_poisoned(err: impl std::fmt::Display) -> Self { + Self::MutexPoisoned(err.to_string()) + } +} + // Core internal errors #[allow(clippy::module_name_repetitions)] #[derive(Error, Debug)] @@ -12,6 +42,8 @@ pub enum AppError { Database(#[from] surrealdb::Error), #[error("openai error: {0}")] OpenAI(#[from] OpenAIError), + #[error("embedding error: {0}")] + Embedding(#[from] EmbeddingError), #[error("file error: {0}")] File(#[from] FileError), #[error("not found: {0}")] diff --git a/common/src/lib.rs b/common/src/lib.rs index ae9d7b2..2199d70 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -3,3 +3,6 @@ pub mod error; pub mod storage; pub mod utils; + +#[cfg(any(test, feature = "test-utils"))] +pub mod test_utils; diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 23d49c5..1d4c830 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -499,8 +499,7 @@ impl KnowledgeEntity { let embedding = provider .embed(&embedding_input) - .await - .map_err(AppError::internal)?; + .await?; // Safety check: ensure the generated embedding has the correct dimension. if embedding.len() != new_dimensions { @@ -599,21 +598,11 @@ mod tests { #![allow(clippy::expect_used, clippy::must_use_candidate)] use super::*; use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; - use crate::storage::types::system_settings::SystemSettings; + use crate::test_utils::configure_embedding_dimension; use anyhow::{self, Context}; use serde_json::json; use uuid::Uuid; - async fn configure_embedding_dimension( - db: &SurrealDbClient, - dimension: u32, - ) -> anyhow::Result<()> { - let mut settings = SystemSettings::get_current(db).await?; - settings.embedding_dimensions = dimension; - SystemSettings::update(db, settings).await?; - Ok(()) - } - #[tokio::test] async fn test_knowledge_entity_creation() -> anyhow::Result<()> { let source_id = "source123".to_string(); diff --git a/common/src/storage/types/knowledge_entity_embedding.rs b/common/src/storage/types/knowledge_entity_embedding.rs index 32d0c6b..60f61e5 100644 --- a/common/src/storage/types/knowledge_entity_embedding.rs +++ b/common/src/storage/types/knowledge_entity_embedding.rs @@ -158,45 +158,11 @@ impl KnowledgeEntityEmbedding { mod tests { #![allow(clippy::expect_used, clippy::must_use_candidate)] use super::*; - use crate::storage::db::SurrealDbClient; use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; - use crate::storage::types::system_settings::SystemSettings; + use crate::test_utils::{prepare_knowledge_entity_test_db, setup_test_db}; use anyhow::{self, Context}; use chrono::Utc; use surrealdb::Value as SurrealValue; - use uuid::Uuid; - - async fn setup_test_db() -> anyhow::Result { - let namespace = "test_ns"; - let database = Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, &database) - .await - .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - - db.apply_migrations() - .await - .with_context(|| "Failed to apply migrations".to_string())?; - - Ok(db) - } - - async fn setup_test_db_with_embedding_dimension( - dimension: u32, - ) -> anyhow::Result { - let db = setup_test_db().await?; - let mut settings = SystemSettings::get_current(&db).await?; - settings.embedding_dimensions = dimension; - SystemSettings::update(&db, settings).await?; - Ok(db) - } - - async fn prepare_test_db(dimension: u32) -> anyhow::Result { - let db = setup_test_db_with_embedding_dimension(dimension).await?; - KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize) - .await - .with_context(|| format!("set test index dimension to {dimension}"))?; - Ok(db) - } fn build_knowledge_entity_with_id( key: &str, @@ -236,7 +202,7 @@ mod tests { #[tokio::test] async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let user_id = "user_ke"; let entity_key = "entity-1"; let source_id = "source-ke"; @@ -266,7 +232,7 @@ mod tests { #[tokio::test] async fn test_delete_by_entity_id() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let user_id = "user_ke"; let entity_key = "entity-delete"; let source_id = "source-del"; @@ -298,7 +264,7 @@ mod tests { #[tokio::test] async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let user_id = "user_store"; let source_id = "source_store"; let embedding = vec![0.2_f32, 0.3, 0.4]; @@ -331,7 +297,7 @@ mod tests { #[tokio::test] async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim"); let result = @@ -344,7 +310,7 @@ mod tests { #[tokio::test] async fn test_delete_by_source_id() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let user_id = "user_ke"; let source_id = "shared-ke"; let other_source = "other-ke"; @@ -437,7 +403,7 @@ mod tests { entity_id: KnowledgeEntity, } - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let user_id = "user_ke"; let entity_key = "entity-fetch"; let source_id = "source-fetch"; @@ -475,7 +441,7 @@ mod tests { #[tokio::test] async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_knowledge_entity_test_db(3).await?; let user_id = "user-upsert"; let source_id = "source-upsert"; diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index 2f6cd07..29d0f71 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -151,19 +151,7 @@ mod tests { use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use anyhow::{self, Context}; - async fn setup_test_db() -> SurrealDbClient { - let namespace = "test_ns"; - let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); - - db.apply_migrations() - .await - .expect("Failed to apply migrations"); - - db - } + use crate::test_utils::setup_test_db; async fn get_relationship_by_id( relationship_id: &str, @@ -234,7 +222,7 @@ mod tests { #[tokio::test] async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let user_id = "user123"; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; @@ -282,7 +270,7 @@ mod tests { #[tokio::test] async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?; let other_entity = create_test_entity("Other entity", "other-user", &db).await?; @@ -303,7 +291,7 @@ mod tests { #[tokio::test] async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let user_id = "user123"; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; @@ -342,7 +330,7 @@ mod tests { #[tokio::test] async fn test_store_and_delete_relationship() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let user_id = "user123"; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; @@ -396,7 +384,7 @@ mod tests { #[tokio::test] async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let owner_user_id = "owner-user"; let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?; @@ -459,7 +447,7 @@ mod tests { #[tokio::test] async fn test_store_relationship_exists() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let user_id = "user123"; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; @@ -543,7 +531,7 @@ mod tests { #[tokio::test] async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let user_a = "user-a"; let user_b = "user-b"; @@ -584,7 +572,7 @@ mod tests { #[tokio::test] async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()> { - let db = setup_test_db().await; + let db = setup_test_db().await?; let user_id = "user123"; let entity1_id = create_test_entity("Entity 1", user_id, &db).await?; diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index b158d4e..1b5bf9c 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -384,8 +384,7 @@ impl TextChunk { let embedding = provider .embed(&chunk.chunk) - .await - .map_err(AppError::internal)?; + .await?; // Safety check: ensure the generated embedding has the correct dimension. if embedding.len() != new_dimensions { @@ -489,21 +488,11 @@ mod tests { use super::*; use crate::storage::indexes::{ensure_runtime, rebuild}; - use crate::storage::types::system_settings::SystemSettings; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; + use crate::test_utils::configure_embedding_dimension; use surrealdb::RecordId; use uuid::Uuid; - async fn configure_embedding_dimension( - db: &SurrealDbClient, - dimension: u32, - ) -> anyhow::Result<()> { - let mut settings = SystemSettings::get_current(db).await?; - settings.embedding_dimensions = dimension; - SystemSettings::update(db, settings).await?; - Ok(()) - } - async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> { let snowball_sql = r#" DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); diff --git a/common/src/storage/types/text_chunk_embedding.rs b/common/src/storage/types/text_chunk_embedding.rs index 69b45e0..71786a2 100644 --- a/common/src/storage/types/text_chunk_embedding.rs +++ b/common/src/storage/types/text_chunk_embedding.rs @@ -144,41 +144,8 @@ mod tests { use super::*; use crate::storage::db::SurrealDbClient; - use crate::storage::types::system_settings::SystemSettings; + use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db}; use surrealdb::Value as SurrealValue; - use uuid::Uuid; - - async fn setup_test_db() -> anyhow::Result { - let namespace = "test_ns"; - let database = Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, &database) - .await - .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - - db.apply_migrations() - .await - .with_context(|| "Failed to apply migrations".to_string())?; - - Ok(db) - } - - async fn setup_test_db_with_embedding_dimension( - dimension: u32, - ) -> anyhow::Result { - let db = setup_test_db().await?; - let mut settings = SystemSettings::get_current(&db).await?; - settings.embedding_dimensions = dimension; - SystemSettings::update(&db, settings).await?; - Ok(db) - } - - async fn prepare_test_db(dimension: u32) -> anyhow::Result { - let db = setup_test_db_with_embedding_dimension(dimension).await?; - TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize) - .await - .with_context(|| format!("set test index dimension to {dimension}"))?; - Ok(db) - } async fn create_text_chunk_with_id( db: &SurrealDbClient, @@ -245,7 +212,7 @@ mod tests { #[tokio::test] async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_text_chunk_test_db(3).await?; let user_id = "user_a"; let chunk_key = "chunk-123"; @@ -279,7 +246,7 @@ mod tests { #[tokio::test] async fn test_delete_by_chunk_id() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_text_chunk_test_db(3).await?; let user_id = "user_b"; let chunk_key = "chunk-delete"; @@ -316,7 +283,7 @@ mod tests { #[tokio::test] async fn test_delete_by_source_id() -> anyhow::Result<()> { - let db = prepare_test_db(1).await?; + let db = prepare_text_chunk_test_db(1).await?; let user_id = "user_c"; let source_id = "shared-source"; @@ -377,7 +344,7 @@ mod tests { #[tokio::test] async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> { - let db = prepare_test_db(3).await?; + let db = prepare_text_chunk_test_db(3).await?; let user_id = "user-upsert"; let source_id = "source-upsert"; diff --git a/common/src/storage/types/text_content.rs b/common/src/storage/types/text_content.rs index 27326dd..ab0b34a 100644 --- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -202,25 +202,7 @@ mod tests { use anyhow::{self, Context}; use super::*; - use crate::storage::indexes::{ensure_runtime, rebuild}; - - async fn setup_test_db() -> anyhow::Result { - let namespace = "test_ns"; - let database = Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, &database) - .await - .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - db.apply_migrations() - .await - .with_context(|| "Failed to apply migrations".to_string())?; - ensure_runtime(&db, 1536) - .await - .with_context(|| "ensure runtime indexes".to_string())?; - rebuild(&db) - .await - .with_context(|| "rebuild indexes".to_string())?; - Ok(db) - } + use crate::test_utils::setup_test_db_with_runtime_indexes; #[tokio::test] async fn test_text_content_creation() -> anyhow::Result<()> { @@ -339,7 +321,7 @@ mod tests { #[tokio::test] async fn test_text_content_patch_not_found() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = setup_test_db_with_runtime_indexes().await?; let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db) .await @@ -412,7 +394,7 @@ mod tests { #[tokio::test] async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = setup_test_db_with_runtime_indexes().await?; let results = TextContent::search(&db, "hello", "user", 5) .await @@ -424,7 +406,7 @@ mod tests { #[tokio::test] async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> { - let db = setup_test_db().await?; + let db = setup_test_db_with_runtime_indexes().await?; let user_id = "search_user"; let matching = TextContent::new( @@ -450,9 +432,6 @@ mod tests { db.store_item(other_user) .await .with_context(|| "store other user".to_string())?; - rebuild(&db) - .await - .with_context(|| "rebuild indexes".to_string())?; let results = TextContent::search(&db, "rust", user_id, 5) .await diff --git a/common/src/storage/types/user.rs b/common/src/storage/types/user.rs index 475773b..a39940a 100644 --- a/common/src/storage/types/user.rs +++ b/common/src/storage/types/user.rs @@ -732,20 +732,7 @@ mod tests { use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS}; use std::collections::HashSet; - // Helper function to set up a test database with SystemSettings - async fn setup_test_db() -> anyhow::Result { - let namespace = "test_ns"; - let database = Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, &database) - .await - .with_context(|| "Failed to start in-memory surrealdb".to_string())?; - - db.apply_migrations() - .await - .with_context(|| "Failed to setup the migrations".to_string())?; - - Ok(db) - } + use crate::test_utils::setup_test_db; #[tokio::test] async fn test_user_creation() -> anyhow::Result<()> { diff --git a/common/src/test_utils.rs b/common/src/test_utils.rs new file mode 100644 index 0000000..6d83409 --- /dev/null +++ b/common/src/test_utils.rs @@ -0,0 +1,96 @@ +//! Shared helpers for in-memory SurrealDB tests. +#![cfg(any(test, feature = "test-utils"))] + +use anyhow::{Context, Result}; +use uuid::Uuid; + +use crate::storage::{ + db::SurrealDbClient, + indexes::{ensure_runtime, rebuild}, + types::{ + knowledge_entity_embedding::KnowledgeEntityEmbedding, + system_settings::SystemSettings, + text_chunk_embedding::TextChunkEmbedding, + }, +}; + +const TEST_NAMESPACE: &str = "test_ns"; + +/// Starts an in-memory database, applies migrations, and returns a client. +/// +/// # Errors +/// +/// Returns an error if the database cannot be started or migrations fail. +pub async fn setup_test_db() -> Result { + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(TEST_NAMESPACE, &database) + .await + .context("start in-memory surrealdb")?; + + db.apply_migrations() + .await + .context("apply migrations")?; + + Ok(db) +} + +/// Updates singleton [`SystemSettings`] embedding dimensions for tests. +/// +/// # Errors +/// +/// Returns an error if settings cannot be loaded or updated. +pub async fn configure_embedding_dimension(db: &SurrealDbClient, dimension: u32) -> Result<()> { + let mut settings = SystemSettings::get_current(db).await?; + settings.embedding_dimensions = dimension; + SystemSettings::update(db, settings).await?; + Ok(()) +} + +/// Starts a test database and sets the embedding dimension in system settings. +/// +/// # Errors +/// +/// Returns an error if setup or settings update fails. +pub async fn setup_test_db_with_embedding_dimension(dimension: u32) -> Result { + let db = setup_test_db().await?; + configure_embedding_dimension(&db, dimension).await?; + Ok(db) +} + +/// Prepares a database for text-chunk embedding tests at the given dimension. +/// +/// # Errors +/// +/// Returns an error if setup, settings update, or index redefinition fails. +pub async fn prepare_text_chunk_test_db(dimension: u32) -> Result { + let db = setup_test_db_with_embedding_dimension(dimension).await?; + TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize) + .await + .with_context(|| format!("set text chunk index dimension to {dimension}"))?; + Ok(db) +} + +/// Prepares a database for knowledge-entity embedding tests at the given dimension. +/// +/// # Errors +/// +/// Returns an error if setup, settings update, or index redefinition fails. +pub async fn prepare_knowledge_entity_test_db(dimension: u32) -> Result { + let db = setup_test_db_with_embedding_dimension(dimension).await?; + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize) + .await + .with_context(|| format!("set knowledge entity index dimension to {dimension}"))?; + Ok(db) +} + +/// Starts a test database and ensures runtime FTS/HNSW indexes are ready. +/// +/// # Errors +/// +/// Returns an error if setup, index creation, or rebuild fails. +pub async fn setup_test_db_with_runtime_indexes() -> Result { + let db = setup_test_db().await?; + ensure_runtime(&db, 1536).await?; + rebuild(&db).await?; + Ok(db) +} diff --git a/common/src/utils/config.rs b/common/src/utils/config.rs index ef45320..9d4c727 100644 --- a/common/src/utils/config.rs +++ b/common/src/utils/config.rs @@ -1,7 +1,8 @@ use config::{Config, ConfigError, Environment, File}; -use serde::{Deserialize, Serialize}; -use std::{env, sync::Once, str::FromStr}; +use serde::{Deserialize, Deserializer, Serialize}; +use std::{env, fmt, str::FromStr, sync::Once}; use thiserror::Error; +use tracing::warn; /// Error returned when parsing an embedding backend name. #[derive(Debug, Error, PartialEq, Eq)] @@ -35,6 +36,83 @@ impl EmbeddingBackend { } } +/// Error returned when parsing a retrieval strategy name. +#[derive(Debug, Error, PartialEq, Eq)] +#[error("unknown retrieval strategy '{input}'")] +pub struct ParseRetrievalStrategyError { + /// The unrecognized input string. + pub input: String, +} + +/// Selects which retrieval pipeline strategy to run for chat and search. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum RetrievalStrategy { + /// Primary hybrid chunk retrieval for search/chat. + #[default] + Default, + /// Entity retrieval for suggesting relationships when creating manual entities. + RelationshipSuggestion, + /// Entity retrieval for context during content ingestion. + Ingestion, + /// Unified search returning both chunks and entities. + Search, +} + +impl RetrievalStrategy { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::Default => "default", + Self::RelationshipSuggestion => "relationship_suggestion", + Self::Ingestion => "ingestion", + Self::Search => "search", + } + } +} + +impl FromStr for RetrievalStrategy { + type Err = ParseRetrievalStrategyError; + + fn from_str(value: &str) -> Result { + match value.to_ascii_lowercase().as_str() { + "default" => Ok(Self::Default), + "initial" | "revised" => { + warn!( + "retrieval strategy '{value}' is deprecated; use 'default' instead" + ); + Ok(Self::Default) + } + "relationship_suggestion" => Ok(Self::RelationshipSuggestion), + "ingestion" => Ok(Self::Ingestion), + "search" => Ok(Self::Search), + other => Err(ParseRetrievalStrategyError { + input: other.to_string(), + }), + } + } +} + +impl fmt::Display for RetrievalStrategy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +fn deserialize_optional_retrieval_strategy<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let value = Option::::deserialize(deserializer)?; + match value { + None => Ok(None), + Some(raw) if raw.trim().is_empty() => Ok(None), + Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom), + } +} + impl FromStr for EmbeddingBackend { type Err = ParseEmbeddingBackendError; @@ -117,8 +195,8 @@ pub struct AppConfig { pub fastembed_show_download_progress: Option, #[serde(default)] pub fastembed_max_length: Option, - #[serde(default)] - pub retrieval_strategy: Option, + #[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")] + pub retrieval_strategy: Option, #[serde(default)] pub embedding_backend: EmbeddingBackend, #[serde(default = "default_ingest_max_body_bytes")] @@ -204,6 +282,14 @@ pub fn ensure_ort_path() { }); } +impl AppConfig { + /// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset. + #[must_use] + pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy { + self.retrieval_strategy.unwrap_or_default() + } +} + impl Default for AppConfig { fn default() -> Self { Self { @@ -249,3 +335,44 @@ pub fn get_config() -> Result { config.try_deserialize() } + +#[cfg(test)] +mod tests { + use super::{ParseRetrievalStrategyError, RetrievalStrategy}; + #[test] + fn retrieval_strategy_defaults_to_default() { + assert_eq!( + RetrievalStrategy::default(), + RetrievalStrategy::Default + ); + } + + #[test] + fn retrieval_strategy_serializes_snake_case() { + assert_eq!( + serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"), + "\"search\"" + ); + } + + #[test] + fn retrieval_strategy_from_str_accepts_deprecated_aliases() { + assert_eq!( + "initial".parse::().expect("initial"), + RetrievalStrategy::Default + ); + assert!(matches!( + "unknown".parse::(), + Err(ParseRetrievalStrategyError { .. }) + )); + } + + #[test] + fn app_config_resolved_retrieval_strategy_uses_default_when_unset() { + let config = super::AppConfig::default(); + assert_eq!( + config.resolved_retrieval_strategy(), + RetrievalStrategy::Default + ); + } +} diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index a1847b0..54c41fa 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -5,13 +5,12 @@ use std::{ sync::{Arc, Mutex}, }; -use anyhow::{anyhow, Context, Result}; use async_openai::{types::CreateEmbeddingRequestArgs, Client}; use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; use tracing::debug; use crate::{ - error::AppError, + error::{AppError, EmbeddingError}, storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, utils::config::AppConfig, }; @@ -57,16 +56,18 @@ enum EmbeddingInner { async fn run_fastembed( model: Arc>, texts: Vec, -) -> Result>> { - tokio::task::spawn_blocking(move || { +) -> Result>, EmbeddingError> { + match tokio::task::spawn_blocking(move || -> Result>, EmbeddingError> { let mut guard = model .lock() - .map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?; - guard.embed(texts, None) + .map_err(EmbeddingError::mutex_poisoned)?; + guard.embed(texts, None).map_err(EmbeddingError::fastembed) }) .await - .context("joining fastembed embedding task")? - .context("generating fastembed embeddings") + { + Ok(result) => result, + Err(join_error) => Err(EmbeddingError::from(join_error)), + } } impl EmbeddingProvider { @@ -102,17 +103,14 @@ impl EmbeddingProvider { /// /// # Errors /// - /// Returns `Err` if the backend API call fails, FastEmbed initialisation fails, + /// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails, /// or the backend returns no embedding data. - pub async fn embed(&self, text: &str) -> Result> { + pub async fn embed(&self, text: &str) -> Result, EmbeddingError> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), EmbeddingInner::FastEmbed { model, .. } => { let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?; - embeddings - .into_iter() - .next() - .ok_or_else(|| anyhow!("fastembed returned no embedding for input")) + embeddings.into_iter().next().ok_or(EmbeddingError::NoData) } EmbeddingInner::OpenAI { client, @@ -130,7 +128,7 @@ impl EmbeddingProvider { let embedding = response .data .first() - .ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))? + .ok_or(EmbeddingError::NoData)? .embedding .clone(); @@ -143,9 +141,9 @@ impl EmbeddingProvider { /// /// # Errors /// - /// Returns `Err` if the backend API call fails or returns no embedding data. + /// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data. /// Returns an empty `Vec` when `texts` is empty. - pub async fn embed_batch(&self, texts: Vec) -> Result>> { + pub async fn embed_batch(&self, texts: Vec) -> Result>, EmbeddingError> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(texts .into_iter() @@ -185,11 +183,14 @@ impl EmbeddingProvider { } } + /// # Errors + /// + /// Currently infallible; reserved for future validation. pub fn new_openai( client: Arc>, model: String, dimensions: u32, - ) -> Result { + ) -> Result { Ok(Self { inner: EmbeddingInner::OpenAI { client, @@ -199,9 +200,12 @@ impl EmbeddingProvider { }) } - pub async fn new_fastembed(model_override: Option) -> Result { + /// # Errors + /// + /// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails. + pub async fn new_fastembed(model_override: Option) -> Result { let model_name = if let Some(code) = model_override { - EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))? + EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)? } else { EmbeddingModel::default() }; @@ -210,15 +214,21 @@ impl EmbeddingProvider { let model_name_for_task = model_name.clone(); let model_name_code = model_name.to_string(); - let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> { + let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> { let model = - TextEmbedding::try_new(options).context("initialising FastEmbed text model")?; - let info = EmbeddingModel::get_model_info(&model_name_for_task) - .ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?; + TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?; + let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| { + EmbeddingError::Config(format!( + "fastembed model metadata missing for {model_name_code}" + )) + })?; Ok((model, info.dim)) }) .await - .context("joining FastEmbed initialisation task")??; + { + Ok(result) => result?, + Err(join_error) => return Err(EmbeddingError::from(join_error)), + }; Ok(EmbeddingProvider { inner: EmbeddingInner::FastEmbed { @@ -229,7 +239,10 @@ impl EmbeddingProvider { }) } - pub fn new_hashed(dimension: usize) -> Result { + /// # Errors + /// + /// Currently infallible; reserved for future validation. + pub fn new_hashed(dimension: usize) -> Result { Ok(EmbeddingProvider { inner: EmbeddingInner::Hashed { dimension: dimension.max(1), @@ -242,24 +255,32 @@ impl EmbeddingProvider { /// Model name and dimensions come from [`SystemSettings`]. The active backend is taken /// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`] /// persists the resolved backend to the database. + /// + /// # Errors + /// + /// Returns [`EmbeddingError`] if the selected backend cannot be initialised. pub async fn from_system_settings( settings: &SystemSettings, config: &AppConfig, openai_client: Option>>, - ) -> Result { + ) -> Result { let dimensions = settings.embedding_dimensions; match config.embedding_backend { EmbeddingBackend::OpenAI => { - let client = openai_client - .ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?; + let client = openai_client.ok_or_else(|| { + EmbeddingError::Config( + "openai embedding backend requires an openai client".into(), + ) + })?; Self::new_openai(client, settings.embedding_model.clone(), dimensions) } EmbeddingBackend::FastEmbed => { Self::new_fastembed(Some(settings.embedding_model.clone())).await } EmbeddingBackend::Hashed => { - let dimension = usize::try_from(dimensions) - .map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?; + let dimension = usize::try_from(dimensions).map_err(|_| { + EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into()) + })?; Self::new_hashed(dimension) } } @@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize { /// /// # Errors /// -/// Returns [`AppError::InternalError`] if the provider's embed call fails. +/// Returns [`AppError::Embedding`] if the provider's embed call fails. pub async fn generate_embedding_with_provider( provider: &EmbeddingProvider, input: &str, ) -> Result, AppError> { - provider - .embed(input) - .await - .map_err(AppError::internal) + provider.embed(input).await.map_err(Into::into) } /// Generates an embedding vector for the given input text using `OpenAI`'s embedding model. diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index c4bcdee..cbf006e 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -2,8 +2,11 @@ use common::storage::types::conversation::SidebarConversation; use common::storage::{db::SurrealDbClient, store::StorageManager}; use common::utils::embedding::EmbeddingProvider; use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine}; -use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig}; -use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy}; +use common::{ + create_template_engine, storage::db::ProvidesDb, + utils::config::{AppConfig, RetrievalStrategy}, +}; +use retrieval_pipeline::reranking::RerankerPool; use std::collections::HashMap; use std::sync::{ atomic::{AtomicUsize, Ordering}, @@ -16,6 +19,7 @@ use tracing::debug; use crate::{OpenAIClientType, SessionStoreType}; #[derive(Clone)] +/// Shared application state for HTML handlers and middleware. pub struct HtmlState { pub db: Arc, pub openai_client: Arc, @@ -31,7 +35,7 @@ pub struct HtmlState { #[derive(Clone)] struct ConversationArchiveCacheEntry { - conversations: Vec, + conversations: Arc<[SidebarConversation]>, expires_at: Instant, } @@ -72,23 +76,19 @@ impl HtmlState { } pub fn retrieval_strategy(&self) -> RetrievalStrategy { - self.config - .retrieval_strategy - .as_deref() - .and_then(|value| value.parse().ok()) - .unwrap_or(RetrievalStrategy::Default) + self.config.resolved_retrieval_strategy() } pub async fn get_cached_conversation_archive( &self, user_id: &str, - ) -> Option> { + ) -> Option> { let now = Instant::now(); let should_evict_expired = { let cache = self.conversation_archive_cache.read().await; if let Some(entry) = cache.get(user_id) { if entry.expires_at > now { - return Some(entry.conversations.clone()); + return Some(Arc::clone(&entry.conversations)); } true } else { @@ -107,7 +107,7 @@ impl HtmlState { pub async fn set_cached_conversation_archive( &self, user_id: &str, - conversations: Vec, + conversations: Arc<[SidebarConversation]>, ) { let now = Instant::now(); let mut cache = self.conversation_archive_cache.write().await; @@ -235,10 +235,10 @@ mod tests { cache.insert( user_id.to_string(), ConversationArchiveCacheEntry { - conversations: vec![SidebarConversation { + conversations: Arc::from([SidebarConversation { id: "conv-1".to_string(), title: "A stale chat".to_string(), - }], + }]), expires_at: Instant::now() - Duration::from_secs(1), }, ); diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs index 248fd2b..4629b36 100644 --- a/html-router/src/routes/search/handlers.rs +++ b/html-router/src/routes/search/handlers.rs @@ -20,6 +20,7 @@ use crate::{ auth_middleware::RequireUser, response_middleware::{HtmlError, TemplateResponse}, }, + utils::truncate::{first_non_empty_line, truncate_with_ellipsis}, }; /// Serde deserialization decorator to map empty Strings to None, @@ -41,31 +42,6 @@ fn source_id_suffix(source_id: &str) -> String { source_id[start..].to_string() } -fn truncate_label(value: &str, max_chars: usize) -> String { - let mut end = None; - for (count, (idx, _)) in value.char_indices().enumerate() { - if count == max_chars { - end = Some(idx); - break; - } - } - - match end { - Some(idx) => format!("{}...", &value[..idx]), - None => value.to_string(), - } -} - -fn first_non_empty_line(text: &str, max_chars: usize) -> Option { - for line in text.lines() { - let trimmed = line.trim(); - if !trimmed.is_empty() { - return Some(truncate_label(trimmed, max_chars)); - } - } - None -} - #[derive(Deserialize)] struct UrlInfoLabel { #[serde(default)] @@ -121,7 +97,7 @@ fn build_source_label(row: &SourceLabelRow) -> String { if let Some(context) = row.context.as_ref() { let trimmed = context.trim(); if !trimmed.is_empty() { - return truncate_label(trimmed, MAX_LABEL_CHARS); + return truncate_with_ellipsis(trimmed, MAX_LABEL_CHARS); } } @@ -131,7 +107,7 @@ fn build_source_label(row: &SourceLabelRow) -> String { let category = row.category.trim(); if !category.is_empty() { - return truncate_label(category, MAX_LABEL_CHARS); + return truncate_with_ellipsis(category, MAX_LABEL_CHARS); } format!("Text snippet: {}", source_id_suffix(&row.id)) diff --git a/ingestion-pipeline/src/pipeline/enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs index a2003b2..47c596a 100644 --- a/ingestion-pipeline/src/pipeline/enrichment_result.rs +++ b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -158,10 +158,7 @@ async fn create_single_entity( ); let embedding = if let Some(provider) = embedding_provider { - provider - .embed(&embedding_input) - .await - .map_err(|e| AppError::InternalError(format!("FastEmbed embedding for entity failed: {e}")))? + provider.embed(&embedding_input).await? } else { generate_embedding(openai_client, &embedding_input, db_client).await? }; diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index cee50eb..1c77d39 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -1,21 +1,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::fmt; use crate::scoring::FusionWeights; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] -#[serde(rename_all = "snake_case")] -pub enum RetrievalStrategy { - /// Primary hybrid chunk retrieval for search/chat (formerly Revised) - #[default] - Default, - /// Entity retrieval for suggesting relationships when creating manual entities - RelationshipSuggestion, - /// Entity retrieval for context during content ingestion - Ingestion, - /// Unified search returning both chunks and entities - Search, -} +pub use common::utils::config::RetrievalStrategy; /// Configures which result types to include in Search strategy #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] @@ -30,41 +17,6 @@ pub enum SearchTarget { Both, } -impl std::str::FromStr for RetrievalStrategy { - type Err = String; - - fn from_str(value: &str) -> Result { - match value.to_ascii_lowercase().as_str() { - "default" => Ok(Self::Default), - // Backward compatibility: treat "initial" and "revised" as "default" - "initial" | "revised" => { - tracing::warn!( - "Retrieval strategy '{}' is deprecated. Use 'default' instead. \ - The 'initial' strategy has been removed in favor of the simpler hybrid chunk retrieval.", - value - ); - Ok(Self::Default) - } - "relationship_suggestion" => Ok(Self::RelationshipSuggestion), - "ingestion" => Ok(Self::Ingestion), - "search" => Ok(Self::Search), - other => Err(format!("unknown retrieval strategy '{other}'")), - } - } -} - -impl fmt::Display for RetrievalStrategy { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let label = match self { - RetrievalStrategy::Default => "default", - RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion", - RetrievalStrategy::Ingestion => "ingestion", - RetrievalStrategy::Search => "search", - }; - f.write_str(label) - } -} - /// Two-variant flag that serializes as a bool for backward compatibility. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum BoolFlag { diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs index 60ab164..791d3ea 100644 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -256,11 +256,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { } else { debug!("Generating query embedding for hybrid retrieval"); let embedding = if let Some(provider) = ctx.embedding_provider { - provider.embed(&ctx.input_text).await.map_err(|e| { - AppError::InternalError(format!( - "Failed to generate embedding with provider: {e}", - )) - })? + provider.embed(&ctx.input_text).await? } else { generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await? };