mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-30 10:01:40 +02:00
chore: centralize embedding errors, retrieval strategy, and test DB helpers.
Replace anyhow in embedding production code with EmbeddingError, move RetrievalStrategy into common config, and deduplicate Surreal test setup via common::test_utils.
This commit is contained in:
+131
-4
@@ -1,7 +1,8 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, sync::Once, str::FromStr};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::{env, fmt, str::FromStr, sync::Once};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
@@ -35,6 +36,83 @@ impl EmbeddingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Error returned when parsing a retrieval strategy name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown retrieval strategy '{input}'")]
|
||||
pub struct ParseRetrievalStrategyError {
|
||||
/// The unrecognized input string.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Selects which retrieval pipeline strategy to run for chat and search.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
/// Primary hybrid chunk retrieval for search/chat.
|
||||
#[default]
|
||||
Default,
|
||||
/// Entity retrieval for suggesting relationships when creating manual entities.
|
||||
RelationshipSuggestion,
|
||||
/// Entity retrieval for context during content ingestion.
|
||||
Ingestion,
|
||||
/// Unified search returning both chunks and entities.
|
||||
Search,
|
||||
}
|
||||
|
||||
impl RetrievalStrategy {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Default => "default",
|
||||
Self::RelationshipSuggestion => "relationship_suggestion",
|
||||
Self::Ingestion => "ingestion",
|
||||
Self::Search => "search",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for RetrievalStrategy {
|
||||
type Err = ParseRetrievalStrategyError;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"default" => Ok(Self::Default),
|
||||
"initial" | "revised" => {
|
||||
warn!(
|
||||
"retrieval strategy '{value}' is deprecated; use 'default' instead"
|
||||
);
|
||||
Ok(Self::Default)
|
||||
}
|
||||
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
||||
"ingestion" => Ok(Self::Ingestion),
|
||||
"search" => Ok(Self::Search),
|
||||
other => Err(ParseRetrievalStrategyError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RetrievalStrategy {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_optional_retrieval_strategy<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Option<RetrievalStrategy>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Option::<String>::deserialize(deserializer)?;
|
||||
match value {
|
||||
None => Ok(None),
|
||||
Some(raw) if raw.trim().is_empty() => Ok(None),
|
||||
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for EmbeddingBackend {
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
@@ -117,8 +195,8 @@ pub struct AppConfig {
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub retrieval_strategy: Option<String>,
|
||||
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
|
||||
pub retrieval_strategy: Option<RetrievalStrategy>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
#[serde(default = "default_ingest_max_body_bytes")]
|
||||
@@ -204,6 +282,14 @@ pub fn ensure_ort_path() {
|
||||
});
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
|
||||
#[must_use]
|
||||
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
|
||||
self.retrieval_strategy.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -249,3 +335,44 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
|
||||
config.try_deserialize()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for retrieval strategy parsing, formatting, and config resolution.

    use super::{AppConfig, ParseRetrievalStrategyError, RetrievalStrategy};

    #[test]
    fn retrieval_strategy_defaults_to_default() {
        assert_eq!(RetrievalStrategy::default(), RetrievalStrategy::Default);
    }

    #[test]
    fn retrieval_strategy_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
            "\"search\""
        );
        // A multi-word variant is what actually exercises the snake_case rename.
        assert_eq!(
            serde_json::to_string(&RetrievalStrategy::RelationshipSuggestion).expect("serialize"),
            "\"relationship_suggestion\""
        );
    }

    #[test]
    fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
        // Both legacy names must keep resolving to the default strategy.
        for alias in ["initial", "revised"] {
            assert_eq!(
                alias
                    .parse::<RetrievalStrategy>()
                    .expect("deprecated alias should parse"),
                RetrievalStrategy::Default
            );
        }
        assert!(matches!(
            "unknown".parse::<RetrievalStrategy>(),
            Err(ParseRetrievalStrategyError { .. })
        ));
    }

    #[test]
    fn retrieval_strategy_from_str_rejects_unknown_names() {
        assert!(matches!(
            "unknown".parse::<RetrievalStrategy>(),
            Err(ParseRetrievalStrategyError { .. })
        ));
        assert!("".parse::<RetrievalStrategy>().is_err());
    }

    #[test]
    fn retrieval_strategy_names_round_trip() {
        // Guards against `as_str`, `Display`, and `FromStr` drifting apart
        // when a variant is added or renamed.
        for strategy in [
            RetrievalStrategy::Default,
            RetrievalStrategy::RelationshipSuggestion,
            RetrievalStrategy::Ingestion,
            RetrievalStrategy::Search,
        ] {
            assert_eq!(strategy.as_str().parse::<RetrievalStrategy>(), Ok(strategy));
            assert_eq!(
                strategy.as_str().to_ascii_uppercase().parse::<RetrievalStrategy>(),
                Ok(strategy)
            );
            assert_eq!(strategy.to_string(), strategy.as_str());
        }
    }

    #[test]
    fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
        let config = AppConfig::default();
        assert_eq!(
            config.resolved_retrieval_strategy(),
            RetrievalStrategy::Default
        );
    }
}
|
||||
|
||||
@@ -5,13 +5,12 @@ use std::{
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
error::{AppError, EmbeddingError},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
@@ -57,16 +56,18 @@ enum EmbeddingInner {
|
||||
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = model
|
||||
.lock()
|
||||
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
|
||||
guard.embed(texts, None)
|
||||
.map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
.context("joining fastembed embedding task")?
|
||||
.context("generating fastembed embeddings")
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(join_error) => Err(EmbeddingError::from(join_error)),
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
@@ -102,17 +103,14 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the backend API call fails, FastEmbed initialisation fails,
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
|
||||
/// or the backend returns no embedding data.
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
||||
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -130,7 +128,7 @@ impl EmbeddingProvider {
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
||||
.ok_or(EmbeddingError::NoData)?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
@@ -143,9 +141,9 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the backend API call fails or returns no embedding data.
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||
/// Returns an empty `Vec` when `texts` is empty.
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
@@ -185,11 +183,14 @@ impl EmbeddingProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Errors
|
||||
///
|
||||
/// Currently infallible; reserved for future validation.
|
||||
pub fn new_openai(
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
model: String,
|
||||
dimensions: u32,
|
||||
) -> Result<Self> {
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
Ok(Self {
|
||||
inner: EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -199,9 +200,12 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
||||
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
@@ -210,15 +214,21 @@ impl EmbeddingProvider {
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
||||
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model =
|
||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
||||
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
.context("joining FastEmbed initialisation task")??;
|
||||
{
|
||||
Ok(result) => result?,
|
||||
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||
};
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
@@ -229,7 +239,10 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Currently infallible; reserved for future validation.
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::Hashed {
|
||||
dimension: dimension.max(1),
|
||||
@@ -242,24 +255,32 @@ impl EmbeddingProvider {
|
||||
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
|
||||
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
|
||||
/// persists the resolved backend to the database.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
|
||||
pub async fn from_system_settings(
|
||||
settings: &SystemSettings,
|
||||
config: &AppConfig,
|
||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
) -> Result<Self> {
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let dimensions = settings.embedding_dimensions;
|
||||
match config.embedding_backend {
|
||||
EmbeddingBackend::OpenAI => {
|
||||
let client = openai_client
|
||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
||||
let client = openai_client.ok_or_else(|| {
|
||||
EmbeddingError::Config(
|
||||
"openai embedding backend requires an openai client".into(),
|
||||
)
|
||||
})?;
|
||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => {
|
||||
let dimension = usize::try_from(dimensions)
|
||||
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
|
||||
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
|
||||
})?;
|
||||
Self::new_hashed(dimension)
|
||||
}
|
||||
}
|
||||
@@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
|
||||
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider
|
||||
.embed(input)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
provider.embed(input).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
|
||||
Reference in New Issue
Block a user