chore: centralize embedding errors, retrieval strategy, and test DB helpers.

Replace anyhow in embedding production code with EmbeddingError, move
RetrievalStrategy into common config, and deduplicate Surreal test setup
via common::test_utils.
This commit is contained in:
Per Stark
2026-05-29 14:35:07 +02:00
parent e3bb2935d0
commit d3443d4153
17 changed files with 366 additions and 304 deletions
+54 -36
View File
@@ -5,13 +5,12 @@ use std::{
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tracing::debug;
use crate::{
error::AppError,
error::{AppError, EmbeddingError},
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
utils::config::AppConfig,
};
@@ -57,16 +56,18 @@ enum EmbeddingInner {
async fn run_fastembed(
model: Arc<Mutex<TextEmbedding>>,
texts: Vec<String>,
) -> Result<Vec<Vec<f32>>> {
tokio::task::spawn_blocking(move || {
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut guard = model
.lock()
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
guard.embed(texts, None)
.map_err(EmbeddingError::mutex_poisoned)?;
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
})
.await
.context("joining fastembed embedding task")?
.context("generating fastembed embeddings")
{
Ok(result) => result,
Err(join_error) => Err(EmbeddingError::from(join_error)),
}
}
impl EmbeddingProvider {
@@ -102,17 +103,14 @@ impl EmbeddingProvider {
///
/// # Errors
///
/// Returns `Err` if the backend API call fails, FastEmbed initialisation fails,
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
/// or the backend returns no embedding data.
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => {
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
}
EmbeddingInner::OpenAI {
client,
@@ -130,7 +128,7 @@ impl EmbeddingProvider {
let embedding = response
.data
.first()
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
.ok_or(EmbeddingError::NoData)?
.embedding
.clone();
@@ -143,9 +141,9 @@ impl EmbeddingProvider {
///
/// # Errors
///
/// Returns `Err` if the backend API call fails or returns no embedding data.
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
/// Returns an empty `Vec` when `texts` is empty.
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(texts
.into_iter()
@@ -185,11 +183,14 @@ impl EmbeddingProvider {
}
}
/// # Errors
///
/// Currently infallible; reserved for future validation.
pub fn new_openai(
client: Arc<Client<async_openai::config::OpenAIConfig>>,
model: String,
dimensions: u32,
) -> Result<Self> {
) -> Result<Self, EmbeddingError> {
Ok(Self {
inner: EmbeddingInner::OpenAI {
client,
@@ -199,9 +200,12 @@ impl EmbeddingProvider {
})
}
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
/// # Errors
///
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
let model_name = if let Some(code) = model_override {
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
} else {
EmbeddingModel::default()
};
@@ -210,15 +214,21 @@ impl EmbeddingProvider {
let model_name_for_task = model_name.clone();
let model_name_code = model_name.to_string();
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
let model =
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
let info = EmbeddingModel::get_model_info(&model_name_for_task)
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
EmbeddingError::Config(format!(
"fastembed model metadata missing for {model_name_code}"
))
})?;
Ok((model, info.dim))
})
.await
.context("joining FastEmbed initialisation task")??;
{
Ok(result) => result?,
Err(join_error) => return Err(EmbeddingError::from(join_error)),
};
Ok(EmbeddingProvider {
inner: EmbeddingInner::FastEmbed {
@@ -229,7 +239,10 @@ impl EmbeddingProvider {
})
}
pub fn new_hashed(dimension: usize) -> Result<Self> {
/// # Errors
///
/// Currently infallible; reserved for future validation.
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
Ok(EmbeddingProvider {
inner: EmbeddingInner::Hashed {
dimension: dimension.max(1),
@@ -242,24 +255,32 @@ impl EmbeddingProvider {
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
/// persists the resolved backend to the database.
///
/// # Errors
///
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
pub async fn from_system_settings(
settings: &SystemSettings,
config: &AppConfig,
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
) -> Result<Self> {
) -> Result<Self, EmbeddingError> {
let dimensions = settings.embedding_dimensions;
match config.embedding_backend {
EmbeddingBackend::OpenAI => {
let client = openai_client
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
let client = openai_client.ok_or_else(|| {
EmbeddingError::Config(
"openai embedding backend requires an openai client".into(),
)
})?;
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
}
EmbeddingBackend::FastEmbed => {
Self::new_fastembed(Some(settings.embedding_model.clone())).await
}
EmbeddingBackend::Hashed => {
let dimension = usize::try_from(dimensions)
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
let dimension = usize::try_from(dimensions).map_err(|_| {
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
})?;
Self::new_hashed(dimension)
}
}
@@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize {
///
/// # Errors
///
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
provider
.embed(input)
.await
.map_err(AppError::internal)
provider.embed(input).await.map_err(Into::into)
}
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.