chore: harden common errors, fastembed blocking, and ingest ownership

Run FastEmbed inference via `spawn_blocking`, propagate Surreal take
failures, add `AppError::internal` and typed ingest/embedding parse
errors, and take owned file lists in ingestion payload construction.
This commit is contained in:
Per Stark
2026-05-28 20:25:12 +02:00
parent 1e25705377
commit 1e0dba72c8
13 changed files with 153 additions and 125 deletions
+39 -19
View File
@@ -2,13 +2,13 @@ use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
str::FromStr,
sync::Arc,
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tokio::sync::Mutex;
use thiserror::Error;
use tracing::debug;
use crate::{
@@ -16,6 +16,14 @@ use crate::{
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Error returned when parsing an embedding backend name.
///
/// The `Display` message quotes the rejected input and lists the canonical backend
/// names (`openai`, `hashed`, `fastembed`). It does not mention the additional
/// `fast-embed` / `fast` spellings that the `FromStr` impl for `EmbeddingBackend`
/// also accepts.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
pub struct ParseEmbeddingBackendError {
/// The unrecognized input string.
///
/// When populated by `EmbeddingBackend::from_str`, this is the ASCII-lowercased
/// form of the caller's string, not the original spelling.
pub input: String,
}
/// Supported embedding backends.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
@@ -27,16 +35,16 @@ pub enum EmbeddingBackend {
}
// Parses an `EmbeddingBackend` from its textual name, ignoring ASCII case.
impl std::str::FromStr for EmbeddingBackend {
type Err = anyhow::Error;
type Err = ParseEmbeddingBackendError;
/// Converts a backend name into an [`EmbeddingBackend`].
///
/// Accepted names (compared after ASCII-lowercasing): `openai`, `hashed`, and
/// `fastembed` together with its aliases `fast-embed` and `fast`.
///
/// # Errors
///
/// Returns an error that names the unrecognized (lowercased) input when `s`
/// matches none of the accepted spellings.
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Lowercase once up front; every arm below, including the `other` fallback,
// therefore sees the lowercased string rather than the caller's original.
match s.to_ascii_lowercase().as_str() {
"openai" => Ok(Self::OpenAI),
"hashed" => Ok(Self::Hashed),
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
other => Err(anyhow!(
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
)),
other => Err(ParseEmbeddingBackendError {
input: other.to_string(),
}),
}
}
}
@@ -68,7 +76,7 @@ enum EmbeddingInner {
},
/// Uses `FastEmbed` running locally.
FastEmbed {
/// Shared `FastEmbed` model.
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
model: Arc<Mutex<TextEmbedding>>,
/// Model metadata used for info logging.
model_name: EmbeddingModel,
@@ -77,6 +85,21 @@ enum EmbeddingInner {
},
}
/// Runs a `FastEmbed` embedding call for `texts` inside `tokio::task::spawn_blocking`.
///
/// The shared model is locked inside the blocking closure, so the `std::sync::Mutex`
/// guard is never held across an `.await`. `embed` is called with `None` as its second
/// argument (NOTE(review): presumably the batch size, i.e. use the library default —
/// confirm against the fastembed API).
///
/// # Errors
///
/// Returns an error if the model mutex is poisoned, if the blocking task cannot be
/// joined, or if the `embed` call itself fails.
async fn run_fastembed(
model: Arc<Mutex<TextEmbedding>>,
texts: Vec<String>,
) -> Result<Vec<Vec<f32>>> {
tokio::task::spawn_blocking(move || {
// A poisoned lock is reported as an error instead of being unwrapped.
let mut guard = model
.lock()
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
guard.embed(texts, None)
})
.await
// Outer layer: failure to join the blocking task.
.context("joining fastembed embedding task")?
// Inner layer: the closure's own result. This context is attached to the
// mutex-poisoned error as well as to errors returned by `embed`.
.context("generating fastembed embeddings")
}
impl EmbeddingProvider {
#[must_use]
pub fn backend_label(&self) -> &'static str {
@@ -116,10 +139,7 @@ impl EmbeddingProvider {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => {
let mut guard = model.lock().await;
let embeddings = guard
.embed(vec![text.to_owned()], None)
.context("generating fastembed vector")?;
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
embeddings
.into_iter()
.next()
@@ -166,10 +186,7 @@ impl EmbeddingProvider {
if texts.is_empty() {
return Ok(Vec::new());
}
let mut guard = model.lock().await;
guard
.embed(texts, None)
.context("generating fastembed batch embeddings")
run_fastembed(Arc::clone(model), texts).await
}
EmbeddingInner::OpenAI {
client,
@@ -325,12 +342,15 @@ fn bucket(token: &str, dimension: usize) -> usize {
///
/// # Errors
///
/// Returns `AppError::InternalError` if the provider's embed call fails.
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
// Delegate to the provider's `embed` and convert any failure into an `AppError`.
provider.embed(input).await.map_err(|e| AppError::InternalError(e.to_string()))
provider
.embed(input)
.await
.map_err(AppError::internal)
}
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
@@ -377,7 +397,7 @@ pub async fn generate_embedding(
let embedding: Vec<f32> = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
.embedding
.clone();
@@ -409,7 +429,7 @@ pub async fn generate_embedding_with_params(
let embedding = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
.embedding
.clone();
+9 -5
View File
@@ -1,11 +1,15 @@
use thiserror::Error;
use super::config::AppConfig;
/// Errors raised when validating ingestion payloads against configured limits.
///
/// Each variant carries a human-readable detail string. The `#[error(...)]`
/// attributes make `Display` prefix that detail with `payload too large: ` or
/// `bad request: ` respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum IngestValidationError {
/// The payload exceeds a configured size limit (content, context, or category).
#[error("payload too large: {0}")]
PayloadTooLarge(String),
/// The request violates a non-size constraint (e.g., too many files).
#[error("bad request: {0}")]
BadRequest(String),
}
@@ -27,7 +31,7 @@ pub fn validate_ingest_input(
) -> Result<(), IngestValidationError> {
if file_count > config.ingest_max_files {
return Err(IngestValidationError::BadRequest(format!(
"Too many files. Maximum allowed is {}",
"too many files: maximum allowed is {}",
config.ingest_max_files
)));
}
@@ -35,7 +39,7 @@ pub fn validate_ingest_input(
if let Some(content) = content {
if content.len() > config.ingest_max_content_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Content is too large. Maximum allowed is {} bytes",
"content is too large: maximum allowed is {} bytes",
config.ingest_max_content_bytes
)));
}
@@ -43,14 +47,14 @@ pub fn validate_ingest_input(
if ctx.len() > config.ingest_max_context_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Context is too large. Maximum allowed is {} bytes",
"context is too large: maximum allowed is {} bytes",
config.ingest_max_context_bytes
)));
}
if category.len() > config.ingest_max_category_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Category is too large. Maximum allowed is {} bytes",
"category is too large: maximum allowed is {} bytes",
config.ingest_max_category_bytes
)));
}