mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-22 06:29:33 +02:00
chore: centralize embedding errors, retrieval strategy, and test DB helpers.
Replace anyhow in embedding production code with EmbeddingError, move RetrievalStrategy into common config, and deduplicate Surreal test setup via common::test_utils.
This commit is contained in:
@@ -5,13 +5,12 @@ use std::{
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
error::{AppError, EmbeddingError},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
@@ -57,16 +56,18 @@ enum EmbeddingInner {
|
||||
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = model
|
||||
.lock()
|
||||
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
|
||||
guard.embed(texts, None)
|
||||
.map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
.context("joining fastembed embedding task")?
|
||||
.context("generating fastembed embeddings")
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(join_error) => Err(EmbeddingError::from(join_error)),
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
@@ -102,17 +103,14 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the backend API call fails, FastEmbed initialisation fails,
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
|
||||
/// or the backend returns no embedding data.
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
||||
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -130,7 +128,7 @@ impl EmbeddingProvider {
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
||||
.ok_or(EmbeddingError::NoData)?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
@@ -143,9 +141,9 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the backend API call fails or returns no embedding data.
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||
/// Returns an empty `Vec` when `texts` is empty.
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
@@ -185,11 +183,14 @@ impl EmbeddingProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Errors
|
||||
///
|
||||
/// Currently infallible; reserved for future validation.
|
||||
pub fn new_openai(
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
model: String,
|
||||
dimensions: u32,
|
||||
) -> Result<Self> {
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
Ok(Self {
|
||||
inner: EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -199,9 +200,12 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
||||
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
@@ -210,15 +214,21 @@ impl EmbeddingProvider {
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
||||
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model =
|
||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
||||
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
.context("joining FastEmbed initialisation task")??;
|
||||
{
|
||||
Ok(result) => result?,
|
||||
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||
};
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
@@ -229,7 +239,10 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Currently infallible; reserved for future validation.
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::Hashed {
|
||||
dimension: dimension.max(1),
|
||||
@@ -242,24 +255,32 @@ impl EmbeddingProvider {
|
||||
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
|
||||
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
|
||||
/// persists the resolved backend to the database.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
|
||||
pub async fn from_system_settings(
|
||||
settings: &SystemSettings,
|
||||
config: &AppConfig,
|
||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
) -> Result<Self> {
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let dimensions = settings.embedding_dimensions;
|
||||
match config.embedding_backend {
|
||||
EmbeddingBackend::OpenAI => {
|
||||
let client = openai_client
|
||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
||||
let client = openai_client.ok_or_else(|| {
|
||||
EmbeddingError::Config(
|
||||
"openai embedding backend requires an openai client".into(),
|
||||
)
|
||||
})?;
|
||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => {
|
||||
let dimension = usize::try_from(dimensions)
|
||||
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
|
||||
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
|
||||
})?;
|
||||
Self::new_hashed(dimension)
|
||||
}
|
||||
}
|
||||
@@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
|
||||
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider
|
||||
.embed(input)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
provider.embed(input).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
|
||||
Reference in New Issue
Block a user