mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-25 19:36:20 +02:00
chore: harden common errors, fastembed blocking, and ingest ownership
Run FastEmbed inference on spawn_blocking, propagate Surreal take failures, add AppError::internal and typed ingest/embedding parse errors, and take owned file lists in ingestion payload construction.
This commit is contained in:
@@ -2,13 +2,13 @@ use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::Mutex;
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
@@ -16,6 +16,14 @@ use crate::{
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
|
||||
pub struct ParseEmbeddingBackendError {
|
||||
/// The unrecognized input string.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Supported embedding backends.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
@@ -27,16 +35,16 @@ pub enum EmbeddingBackend {
|
||||
}
|
||||
|
||||
impl std::str::FromStr for EmbeddingBackend {
|
||||
type Err = anyhow::Error;
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"openai" => Ok(Self::OpenAI),
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(anyhow!(
|
||||
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
|
||||
)),
|
||||
other => Err(ParseEmbeddingBackendError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -68,7 +76,7 @@ enum EmbeddingInner {
|
||||
},
|
||||
/// Uses `FastEmbed` running locally.
|
||||
FastEmbed {
|
||||
/// Shared `FastEmbed` model.
|
||||
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
/// Model metadata used for info logging.
|
||||
model_name: EmbeddingModel,
|
||||
@@ -77,6 +85,21 @@ enum EmbeddingInner {
|
||||
},
|
||||
}
|
||||
|
||||
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut guard = model
|
||||
.lock()
|
||||
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
|
||||
guard.embed(texts, None)
|
||||
})
|
||||
.await
|
||||
.context("joining fastembed embedding task")?
|
||||
.context("generating fastembed embeddings")
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
#[must_use]
|
||||
pub fn backend_label(&self) -> &'static str {
|
||||
@@ -116,10 +139,7 @@ impl EmbeddingProvider {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let mut guard = model.lock().await;
|
||||
let embeddings = guard
|
||||
.embed(vec![text.to_owned()], None)
|
||||
.context("generating fastembed vector")?;
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
@@ -166,10 +186,7 @@ impl EmbeddingProvider {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut guard = model.lock().await;
|
||||
guard
|
||||
.embed(texts, None)
|
||||
.context("generating fastembed batch embeddings")
|
||||
run_fastembed(Arc::clone(model), texts).await
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -325,12 +342,15 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the provider's embed call fails.
|
||||
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider.embed(input).await.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
provider
|
||||
.embed(input)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
@@ -377,7 +397,7 @@ pub async fn generate_embedding(
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
@@ -409,7 +429,7 @@ pub async fn generate_embedding_with_params(
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user