chore: centralize embedding errors, retrieval strategy, and test DB helpers.

Replace anyhow in embedding production code with EmbeddingError, move
RetrievalStrategy into common config, and deduplicate Surreal test setup
via common::test_utils.
This commit is contained in:
Per Stark
2026-05-29 14:35:07 +02:00
parent e3bb2935d0
commit d3443d4153
17 changed files with 366 additions and 304 deletions
+32
View File
@@ -4,6 +4,36 @@ use tokio::task::JoinError;
use crate::storage::types::file_info::FileError;
/// Errors from embedding provider operations.
#[allow(clippy::module_name_repetitions)]
#[derive(Error, Debug)]
pub enum EmbeddingError {
#[error("openai error: {0}")]
OpenAI(#[from] OpenAIError),
#[error("fastembed error: {0}")]
FastEmbed(String),
#[error("task join error: {0}")]
Join(#[from] JoinError),
#[error("fastembed model mutex poisoned: {0}")]
MutexPoisoned(String),
#[error("no embedding data received")]
NoData,
#[error("embedding configuration error: {0}")]
Config(String),
#[error("unknown fastembed model: {0}")]
UnknownModel(String),
}
impl EmbeddingError {
pub(crate) fn fastembed(err: impl std::fmt::Display) -> Self {
Self::FastEmbed(err.to_string())
}
pub(crate) fn mutex_poisoned(err: impl std::fmt::Display) -> Self {
Self::MutexPoisoned(err.to_string())
}
}
// Core internal errors
#[allow(clippy::module_name_repetitions)]
#[derive(Error, Debug)]
@@ -12,6 +42,8 @@ pub enum AppError {
Database(#[from] surrealdb::Error),
#[error("openai error: {0}")]
OpenAI(#[from] OpenAIError),
#[error("embedding error: {0}")]
Embedding(#[from] EmbeddingError),
#[error("file error: {0}")]
File(#[from] FileError),
#[error("not found: {0}")]
+3
View File
@@ -3,3 +3,6 @@
pub mod error;
pub mod storage;
pub mod utils;
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils;
+2 -13
View File
@@ -499,8 +499,7 @@ impl KnowledgeEntity {
let embedding = provider
.embed(&embedding_input)
.await
.map_err(AppError::internal)?;
.await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
@@ -599,21 +598,11 @@ mod tests {
#![allow(clippy::expect_used, clippy::must_use_candidate)]
use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use crate::storage::types::system_settings::SystemSettings;
use crate::test_utils::configure_embedding_dimension;
use anyhow::{self, Context};
use serde_json::json;
use uuid::Uuid;
async fn configure_embedding_dimension(
db: &SurrealDbClient,
dimension: u32,
) -> anyhow::Result<()> {
let mut settings = SystemSettings::get_current(db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(db, settings).await?;
Ok(())
}
#[tokio::test]
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
let source_id = "source123".to_string();
@@ -158,45 +158,11 @@ impl KnowledgeEntityEmbedding {
mod tests {
#![allow(clippy::expect_used, clippy::must_use_candidate)]
use super::*;
use crate::storage::db::SurrealDbClient;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use crate::storage::types::system_settings::SystemSettings;
use crate::test_utils::{prepare_knowledge_entity_test_db, setup_test_db};
use anyhow::{self, Context};
use chrono::Utc;
use surrealdb::Value as SurrealValue;
use uuid::Uuid;
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.with_context(|| "Failed to apply migrations".to_string())?;
Ok(db)
}
async fn setup_test_db_with_embedding_dimension(
dimension: u32,
) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db().await?;
let mut settings = SystemSettings::get_current(&db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(&db, settings).await?;
Ok(db)
}
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db_with_embedding_dimension(dimension).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
.await
.with_context(|| format!("set test index dimension to {dimension}"))?;
Ok(db)
}
fn build_knowledge_entity_with_id(
key: &str,
@@ -236,7 +202,7 @@ mod tests {
#[tokio::test]
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let user_id = "user_ke";
let entity_key = "entity-1";
let source_id = "source-ke";
@@ -266,7 +232,7 @@ mod tests {
#[tokio::test]
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let user_id = "user_ke";
let entity_key = "entity-delete";
let source_id = "source-del";
@@ -298,7 +264,7 @@ mod tests {
#[tokio::test]
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let user_id = "user_store";
let source_id = "source_store";
let embedding = vec![0.2_f32, 0.3, 0.4];
@@ -331,7 +297,7 @@ mod tests {
#[tokio::test]
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
let result =
@@ -344,7 +310,7 @@ mod tests {
#[tokio::test]
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let user_id = "user_ke";
let source_id = "shared-ke";
let other_source = "other-ke";
@@ -437,7 +403,7 @@ mod tests {
entity_id: KnowledgeEntity,
}
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let user_id = "user_ke";
let entity_key = "entity-fetch";
let source_id = "source-fetch";
@@ -475,7 +441,7 @@ mod tests {
#[tokio::test]
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_knowledge_entity_test_db(3).await?;
let user_id = "user-upsert";
let source_id = "source-upsert";
@@ -151,19 +151,7 @@ mod tests {
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use anyhow::{self, Context};
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db
}
use crate::test_utils::setup_test_db;
async fn get_relationship_by_id(
relationship_id: &str,
@@ -234,7 +222,7 @@ mod tests {
#[tokio::test]
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
@@ -282,7 +270,7 @@ mod tests {
#[tokio::test]
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
@@ -303,7 +291,7 @@ mod tests {
#[tokio::test]
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
@@ -342,7 +330,7 @@ mod tests {
#[tokio::test]
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
@@ -396,7 +384,7 @@ mod tests {
#[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let owner_user_id = "owner-user";
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
@@ -459,7 +447,7 @@ mod tests {
#[tokio::test]
async fn test_store_relationship_exists() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
@@ -543,7 +531,7 @@ mod tests {
#[tokio::test]
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
let db = setup_test_db().await;
let db = setup_test_db().await?;
let user_a = "user-a";
let user_b = "user-b";
@@ -584,7 +572,7 @@ mod tests {
#[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
{
let db = setup_test_db().await;
let db = setup_test_db().await?;
let user_id = "user123";
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
+2 -13
View File
@@ -384,8 +384,7 @@ impl TextChunk {
let embedding = provider
.embed(&chunk.chunk)
.await
.map_err(AppError::internal)?;
.await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
@@ -489,21 +488,11 @@ mod tests {
use super::*;
use crate::storage::indexes::{ensure_runtime, rebuild};
use crate::storage::types::system_settings::SystemSettings;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::test_utils::configure_embedding_dimension;
use surrealdb::RecordId;
use uuid::Uuid;
async fn configure_embedding_dimension(
db: &SurrealDbClient,
dimension: u32,
) -> anyhow::Result<()> {
let mut settings = SystemSettings::get_current(db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(db, settings).await?;
Ok(())
}
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
let snowball_sql = r#"
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
@@ -144,41 +144,8 @@ mod tests {
use super::*;
use crate::storage::db::SurrealDbClient;
use crate::storage::types::system_settings::SystemSettings;
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
use surrealdb::Value as SurrealValue;
use uuid::Uuid;
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.with_context(|| "Failed to apply migrations".to_string())?;
Ok(db)
}
async fn setup_test_db_with_embedding_dimension(
dimension: u32,
) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db().await?;
let mut settings = SystemSettings::get_current(&db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(&db, settings).await?;
Ok(db)
}
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
let db = setup_test_db_with_embedding_dimension(dimension).await?;
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
.await
.with_context(|| format!("set test index dimension to {dimension}"))?;
Ok(db)
}
async fn create_text_chunk_with_id(
db: &SurrealDbClient,
@@ -245,7 +212,7 @@ mod tests {
#[tokio::test]
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_text_chunk_test_db(3).await?;
let user_id = "user_a";
let chunk_key = "chunk-123";
@@ -279,7 +246,7 @@ mod tests {
#[tokio::test]
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_text_chunk_test_db(3).await?;
let user_id = "user_b";
let chunk_key = "chunk-delete";
@@ -316,7 +283,7 @@ mod tests {
#[tokio::test]
async fn test_delete_by_source_id() -> anyhow::Result<()> {
let db = prepare_test_db(1).await?;
let db = prepare_text_chunk_test_db(1).await?;
let user_id = "user_c";
let source_id = "shared-source";
@@ -377,7 +344,7 @@ mod tests {
#[tokio::test]
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
let db = prepare_test_db(3).await?;
let db = prepare_text_chunk_test_db(3).await?;
let user_id = "user-upsert";
let source_id = "source-upsert";
+4 -25
View File
@@ -202,25 +202,7 @@ mod tests {
use anyhow::{self, Context};
use super::*;
use crate::storage::indexes::{ensure_runtime, rebuild};
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.with_context(|| "Failed to apply migrations".to_string())?;
ensure_runtime(&db, 1536)
.await
.with_context(|| "ensure runtime indexes".to_string())?;
rebuild(&db)
.await
.with_context(|| "rebuild indexes".to_string())?;
Ok(db)
}
use crate::test_utils::setup_test_db_with_runtime_indexes;
#[tokio::test]
async fn test_text_content_creation() -> anyhow::Result<()> {
@@ -339,7 +321,7 @@ mod tests {
#[tokio::test]
async fn test_text_content_patch_not_found() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let db = setup_test_db_with_runtime_indexes().await?;
let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db)
.await
@@ -412,7 +394,7 @@ mod tests {
#[tokio::test]
async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let db = setup_test_db_with_runtime_indexes().await?;
let results = TextContent::search(&db, "hello", "user", 5)
.await
@@ -424,7 +406,7 @@ mod tests {
#[tokio::test]
async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> {
let db = setup_test_db().await?;
let db = setup_test_db_with_runtime_indexes().await?;
let user_id = "search_user";
let matching = TextContent::new(
@@ -450,9 +432,6 @@ mod tests {
db.store_item(other_user)
.await
.with_context(|| "store other user".to_string())?;
rebuild(&db)
.await
.with_context(|| "rebuild indexes".to_string())?;
let results = TextContent::search(&db, "rust", user_id, 5)
.await
+1 -14
View File
@@ -732,20 +732,7 @@ mod tests {
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
use std::collections::HashSet;
// Helper function to set up a test database with SystemSettings
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
db.apply_migrations()
.await
.with_context(|| "Failed to setup the migrations".to_string())?;
Ok(db)
}
use crate::test_utils::setup_test_db;
#[tokio::test]
async fn test_user_creation() -> anyhow::Result<()> {
+96
View File
@@ -0,0 +1,96 @@
//! Shared helpers for in-memory SurrealDB tests.
#![cfg(any(test, feature = "test-utils"))]
use anyhow::{Context, Result};
use uuid::Uuid;
use crate::storage::{
db::SurrealDbClient,
indexes::{ensure_runtime, rebuild},
types::{
knowledge_entity_embedding::KnowledgeEntityEmbedding,
system_settings::SystemSettings,
text_chunk_embedding::TextChunkEmbedding,
},
};
const TEST_NAMESPACE: &str = "test_ns";
/// Starts an in-memory database, applies migrations, and returns a client.
///
/// # Errors
///
/// Returns an error if the database cannot be started or migrations fail.
pub async fn setup_test_db() -> Result<SurrealDbClient> {
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(TEST_NAMESPACE, &database)
.await
.context("start in-memory surrealdb")?;
db.apply_migrations()
.await
.context("apply migrations")?;
Ok(db)
}
/// Updates singleton [`SystemSettings`] embedding dimensions for tests.
///
/// # Errors
///
/// Returns an error if settings cannot be loaded or updated.
pub async fn configure_embedding_dimension(db: &SurrealDbClient, dimension: u32) -> Result<()> {
let mut settings = SystemSettings::get_current(db).await?;
settings.embedding_dimensions = dimension;
SystemSettings::update(db, settings).await?;
Ok(())
}
/// Starts a test database and sets the embedding dimension in system settings.
///
/// # Errors
///
/// Returns an error if setup or settings update fails.
pub async fn setup_test_db_with_embedding_dimension(dimension: u32) -> Result<SurrealDbClient> {
let db = setup_test_db().await?;
configure_embedding_dimension(&db, dimension).await?;
Ok(db)
}
/// Prepares a database for text-chunk embedding tests at the given dimension.
///
/// # Errors
///
/// Returns an error if setup, settings update, or index redefinition fails.
pub async fn prepare_text_chunk_test_db(dimension: u32) -> Result<SurrealDbClient> {
let db = setup_test_db_with_embedding_dimension(dimension).await?;
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
.await
.with_context(|| format!("set text chunk index dimension to {dimension}"))?;
Ok(db)
}
/// Prepares a database for knowledge-entity embedding tests at the given dimension.
///
/// # Errors
///
/// Returns an error if setup, settings update, or index redefinition fails.
pub async fn prepare_knowledge_entity_test_db(dimension: u32) -> Result<SurrealDbClient> {
let db = setup_test_db_with_embedding_dimension(dimension).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
.await
.with_context(|| format!("set knowledge entity index dimension to {dimension}"))?;
Ok(db)
}
/// Starts a test database and ensures runtime FTS/HNSW indexes are ready.
///
/// # Errors
///
/// Returns an error if setup, index creation, or rebuild fails.
pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
let db = setup_test_db().await?;
ensure_runtime(&db, 1536).await?;
rebuild(&db).await?;
Ok(db)
}
+131 -4
View File
@@ -1,7 +1,8 @@
use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Serialize};
use std::{env, sync::Once, str::FromStr};
use serde::{Deserialize, Deserializer, Serialize};
use std::{env, fmt, str::FromStr, sync::Once};
use thiserror::Error;
use tracing::warn;
/// Error returned when parsing an embedding backend name.
#[derive(Debug, Error, PartialEq, Eq)]
@@ -35,6 +36,83 @@ impl EmbeddingBackend {
}
}
/// Error returned when parsing a retrieval strategy name.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown retrieval strategy '{input}'")]
pub struct ParseRetrievalStrategyError {
/// The unrecognized input string.
pub input: String,
}
/// Selects which retrieval pipeline strategy to run for chat and search.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
/// Primary hybrid chunk retrieval for search/chat.
#[default]
Default,
/// Entity retrieval for suggesting relationships when creating manual entities.
RelationshipSuggestion,
/// Entity retrieval for context during content ingestion.
Ingestion,
/// Unified search returning both chunks and entities.
Search,
}
impl RetrievalStrategy {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Default => "default",
Self::RelationshipSuggestion => "relationship_suggestion",
Self::Ingestion => "ingestion",
Self::Search => "search",
}
}
}
impl FromStr for RetrievalStrategy {
type Err = ParseRetrievalStrategyError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value.to_ascii_lowercase().as_str() {
"default" => Ok(Self::Default),
"initial" | "revised" => {
warn!(
"retrieval strategy '{value}' is deprecated; use 'default' instead"
);
Ok(Self::Default)
}
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
"ingestion" => Ok(Self::Ingestion),
"search" => Ok(Self::Search),
other => Err(ParseRetrievalStrategyError {
input: other.to_string(),
}),
}
}
}
impl fmt::Display for RetrievalStrategy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
fn deserialize_optional_retrieval_strategy<'de, D>(
deserializer: D,
) -> Result<Option<RetrievalStrategy>, D::Error>
where
D: Deserializer<'de>,
{
let value = Option::<String>::deserialize(deserializer)?;
match value {
None => Ok(None),
Some(raw) if raw.trim().is_empty() => Ok(None),
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
}
}
impl FromStr for EmbeddingBackend {
type Err = ParseEmbeddingBackendError;
@@ -117,8 +195,8 @@ pub struct AppConfig {
pub fastembed_show_download_progress: Option<bool>,
#[serde(default)]
pub fastembed_max_length: Option<usize>,
#[serde(default)]
pub retrieval_strategy: Option<String>,
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
pub retrieval_strategy: Option<RetrievalStrategy>,
#[serde(default)]
pub embedding_backend: EmbeddingBackend,
#[serde(default = "default_ingest_max_body_bytes")]
@@ -204,6 +282,14 @@ pub fn ensure_ort_path() {
});
}
impl AppConfig {
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
#[must_use]
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
self.retrieval_strategy.unwrap_or_default()
}
}
impl Default for AppConfig {
fn default() -> Self {
Self {
@@ -249,3 +335,44 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
config.try_deserialize()
}
#[cfg(test)]
mod tests {
use super::{ParseRetrievalStrategyError, RetrievalStrategy};
#[test]
fn retrieval_strategy_defaults_to_default() {
assert_eq!(
RetrievalStrategy::default(),
RetrievalStrategy::Default
);
}
#[test]
fn retrieval_strategy_serializes_snake_case() {
assert_eq!(
serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
"\"search\""
);
}
#[test]
fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
assert_eq!(
"initial".parse::<RetrievalStrategy>().expect("initial"),
RetrievalStrategy::Default
);
assert!(matches!(
"unknown".parse::<RetrievalStrategy>(),
Err(ParseRetrievalStrategyError { .. })
));
}
#[test]
fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
let config = super::AppConfig::default();
assert_eq!(
config.resolved_retrieval_strategy(),
RetrievalStrategy::Default
);
}
}
+54 -36
View File
@@ -5,13 +5,12 @@ use std::{
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tracing::debug;
use crate::{
error::AppError,
error::{AppError, EmbeddingError},
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
utils::config::AppConfig,
};
@@ -57,16 +56,18 @@ enum EmbeddingInner {
async fn run_fastembed(
model: Arc<Mutex<TextEmbedding>>,
texts: Vec<String>,
) -> Result<Vec<Vec<f32>>> {
tokio::task::spawn_blocking(move || {
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut guard = model
.lock()
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
guard.embed(texts, None)
.map_err(EmbeddingError::mutex_poisoned)?;
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
})
.await
.context("joining fastembed embedding task")?
.context("generating fastembed embeddings")
{
Ok(result) => result,
Err(join_error) => Err(EmbeddingError::from(join_error)),
}
}
impl EmbeddingProvider {
@@ -102,17 +103,14 @@ impl EmbeddingProvider {
///
/// # Errors
///
/// Returns `Err` if the backend API call fails, FastEmbed initialisation fails,
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
/// or the backend returns no embedding data.
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => {
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
}
EmbeddingInner::OpenAI {
client,
@@ -130,7 +128,7 @@ impl EmbeddingProvider {
let embedding = response
.data
.first()
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
.ok_or(EmbeddingError::NoData)?
.embedding
.clone();
@@ -143,9 +141,9 @@ impl EmbeddingProvider {
///
/// # Errors
///
/// Returns `Err` if the backend API call fails or returns no embedding data.
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
/// Returns an empty `Vec` when `texts` is empty.
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(texts
.into_iter()
@@ -185,11 +183,14 @@ impl EmbeddingProvider {
}
}
/// # Errors
///
/// Currently infallible; reserved for future validation.
pub fn new_openai(
client: Arc<Client<async_openai::config::OpenAIConfig>>,
model: String,
dimensions: u32,
) -> Result<Self> {
) -> Result<Self, EmbeddingError> {
Ok(Self {
inner: EmbeddingInner::OpenAI {
client,
@@ -199,9 +200,12 @@ impl EmbeddingProvider {
})
}
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
/// # Errors
///
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
let model_name = if let Some(code) = model_override {
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
} else {
EmbeddingModel::default()
};
@@ -210,15 +214,21 @@ impl EmbeddingProvider {
let model_name_for_task = model_name.clone();
let model_name_code = model_name.to_string();
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
let model =
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
let info = EmbeddingModel::get_model_info(&model_name_for_task)
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
EmbeddingError::Config(format!(
"fastembed model metadata missing for {model_name_code}"
))
})?;
Ok((model, info.dim))
})
.await
.context("joining FastEmbed initialisation task")??;
{
Ok(result) => result?,
Err(join_error) => return Err(EmbeddingError::from(join_error)),
};
Ok(EmbeddingProvider {
inner: EmbeddingInner::FastEmbed {
@@ -229,7 +239,10 @@ impl EmbeddingProvider {
})
}
pub fn new_hashed(dimension: usize) -> Result<Self> {
/// # Errors
///
/// Currently infallible; reserved for future validation.
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
Ok(EmbeddingProvider {
inner: EmbeddingInner::Hashed {
dimension: dimension.max(1),
@@ -242,24 +255,32 @@ impl EmbeddingProvider {
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
/// persists the resolved backend to the database.
///
/// # Errors
///
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
pub async fn from_system_settings(
settings: &SystemSettings,
config: &AppConfig,
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
) -> Result<Self> {
) -> Result<Self, EmbeddingError> {
let dimensions = settings.embedding_dimensions;
match config.embedding_backend {
EmbeddingBackend::OpenAI => {
let client = openai_client
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
let client = openai_client.ok_or_else(|| {
EmbeddingError::Config(
"openai embedding backend requires an openai client".into(),
)
})?;
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
}
EmbeddingBackend::FastEmbed => {
Self::new_fastembed(Some(settings.embedding_model.clone())).await
}
EmbeddingBackend::Hashed => {
let dimension = usize::try_from(dimensions)
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
let dimension = usize::try_from(dimensions).map_err(|_| {
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
})?;
Self::new_hashed(dimension)
}
}
@@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize {
///
/// # Errors
///
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
provider
.embed(input)
.await
.map_err(AppError::internal)
provider.embed(input).await.map_err(Into::into)
}
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.