chore: harden common errors, fastembed blocking, and ingest ownership

Run FastEmbed inference on spawn_blocking, propagate Surreal take
failures,
add AppError::internal and typed ingest/embedding parse errors, and take
owned file lists in ingestion payload construction.
This commit is contained in:
Per Stark
2026-05-28 20:25:12 +02:00
parent 9d5e7cd794
commit 85336d77a3
13 changed files with 153 additions and 125 deletions
+2 -2
View File
@@ -73,8 +73,8 @@ pub async fn ingest_data(
input.content,
input.context,
input.category,
&file_infos,
&user_id,
file_infos,
user_id.clone(),
)?;
let futures: Vec<_> = payloads
+8
View File
@@ -39,3 +39,11 @@ pub enum AppError {
#[error("internal service error: {0}")]
InternalError(String),
}
impl AppError {
/// Builds an [`AppError::InternalError`] from a displayable message.
#[must_use]
pub fn internal(msg: impl std::fmt::Display) -> Self {
Self::InternalError(msg.to_string())
}
}
+1 -1
View File
@@ -124,7 +124,7 @@ impl SurrealDbClient {
.load_files(&MIGRATIONS_DIR)
.up()
.await
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
Ok(())
}
+2 -2
View File
@@ -169,7 +169,7 @@ pub async fn ensure_runtime(
) -> Result<(), AppError> {
ensure_runtime_inner(db, embedding_dimension)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
.map_err(AppError::internal)
}
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
@@ -180,7 +180,7 @@ pub async fn ensure_runtime(
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_inner(db)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
.map_err(AppError::internal)
}
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
+1 -1
View File
@@ -261,7 +261,7 @@ impl FileInfo {
// Remove the object's parent prefix in the object store
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
storage
.delete_prefix(&parent_prefix)
.await
+54 -64
View File
@@ -1,9 +1,4 @@
#![allow(
clippy::result_large_err,
clippy::needless_pass_by_value,
clippy::implicit_clone,
clippy::semicolon_if_nothing_returned
)]
#![allow(clippy::result_large_err)]
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
@@ -49,48 +44,49 @@ impl IngestionPayload {
content: Option<String>,
context: String,
category: String,
files: &[FileInfo],
user_id: &str,
files: Vec<FileInfo>,
user_id: String,
) -> Result<Vec<IngestionPayload>, AppError> {
// Initialize list
let mut object_list = Vec::new();
let has_content = content
.as_ref()
.is_some_and(|c| c.len() > 2);
#[allow(clippy::arithmetic_side_effects)]
let capacity = files.len() + usize::from(has_content);
let mut object_list = Vec::with_capacity(capacity);
for file in files {
object_list.push(IngestionPayload::File {
file_info: file,
context: context.clone(),
category: category.clone(),
user_id: user_id.clone(),
});
}
// Create a IngestionPayload from content if it exists, checking for URL or text
if let Some(input_content) = content {
match Url::parse(&input_content) {
Ok(url) => {
info!("Detected URL: {}", url);
object_list.push(IngestionPayload::Url {
url: url.to_string(),
context: context.clone(),
category: category.clone(),
user_id: user_id.into(),
context,
category,
user_id,
});
}
Err(_) => {
if input_content.len() > 2 {
info!("Treating input as plain text");
object_list.push(IngestionPayload::Text {
text: input_content.to_string(),
context: context.clone(),
category: category.clone(),
user_id: user_id.into(),
text: input_content,
context,
category,
user_id,
});
}
}
}
}
for file in files {
object_list.push(IngestionPayload::File {
file_info: file.clone(),
context: context.clone(),
category: category.clone(),
user_id: user_id.into(),
})
}
// If no objects are constructed, we return Err
if object_list.is_empty() {
return Err(AppError::NotFound(
"No valid content or files provided".into(),
@@ -138,14 +134,13 @@ mod tests {
let context = "Process this URL";
let category = "websites";
let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload(
Some(url.to_string()),
context.to_string(),
category.to_string(),
&files,
user_id,
vec![],
user_id.to_string(),
)
.with_context(|| "create_ingestion_payload".to_string())?;
@@ -174,14 +169,13 @@ mod tests {
let context = "Process this text";
let category = "notes";
let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload(
Some(text.to_string()),
context.to_string(),
category.to_string(),
&files,
user_id,
vec![],
user_id.to_string(),
)
.with_context(|| "create_ingestion_payload".to_string())?;
@@ -215,16 +209,15 @@ mod tests {
};
let file_info: FileInfo = mock_file.into();
let files = vec![file_info.clone()];
let file_id = file_info.id.clone();
let result = IngestionPayload::create_ingestion_payload(
None,
context.to_string(),
category.to_string(),
&files,
user_id,
)
.with_context(|| "create_ingestion_payload".to_string())?;
vec![file_info],
user_id.to_string(),
)?;
assert_eq!(result.len(), 1);
match result.first().context("expected one result")? {
@@ -234,7 +227,7 @@ mod tests {
category: payload_category,
user_id: payload_user_id,
} => {
assert_eq!(payload_file_info.id, file_info.id);
assert_eq!(payload_file_info.id, file_id);
assert_eq!(payload_context, context);
assert_eq!(payload_category, category);
assert_eq!(payload_user_id, user_id);
@@ -257,39 +250,38 @@ mod tests {
};
let file_info: FileInfo = mock_file.into();
let files = vec![file_info.clone()];
let file_id = file_info.id.clone();
let result = IngestionPayload::create_ingestion_payload(
Some(url.to_string()),
context.to_string(),
category.to_string(),
&files,
user_id,
)
.with_context(|| "create_ingestion_payload".to_string())?;
vec![file_info],
user_id.to_string(),
)?;
assert_eq!(result.len(), 2);
// Check first item is URL
// Check first item is File (files processed first to minimize clones)
match result.first().context("expected first item")? {
IngestionPayload::File {
file_info: payload_file_info,
..
} => {
assert_eq!(payload_file_info.id, file_id);
}
_ => anyhow::bail!("Expected first item to be File variant"),
}
// Check second item is URL
match result.get(1).context("expected second item")? {
IngestionPayload::Url {
url: payload_url, ..
} => {
// URL parser may normalize the URL by adding a trailing slash
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
}
_ => anyhow::bail!("Expected first item to be Url variant"),
}
// Check second item is File
match result.get(1).context("expected second item")? {
IngestionPayload::File {
file_info: payload_file_info,
..
} => {
assert_eq!(payload_file_info.id, file_info.id);
}
_ => anyhow::bail!("Expected second item to be File variant"),
_ => anyhow::bail!("Expected second item to be Url variant"),
}
Ok(())
}
@@ -299,14 +291,13 @@ mod tests {
let context = "Process something";
let category = "empty";
let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload(
None,
context.to_string(),
category.to_string(),
&files,
user_id,
vec![],
user_id.to_string(),
);
assert!(result.is_err());
@@ -325,14 +316,13 @@ mod tests {
let context = "Process this";
let category = "notes";
let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload(
Some(text.to_string()),
context.to_string(),
category.to_string(),
&files,
user_id,
vec![],
user_id.to_string(),
);
assert!(result.is_err());
+15 -13
View File
@@ -250,7 +250,7 @@ impl KnowledgeEntity {
.bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string()))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
.map_err(AppError::Database)?;
response = response.check().map_err(AppError::Database)?;
@@ -326,7 +326,7 @@ impl KnowledgeEntity {
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.first()
.map(|r| r.user_id.clone())
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string()))
.ok_or_else(|| AppError::internal("user not found for entity"))
}
/// Re-creates embeddings for all knowledge entities in the database.
@@ -385,7 +385,7 @@ impl KnowledgeEntity {
entity.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::InternalError(err_msg));
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
}
@@ -397,8 +397,9 @@ impl KnowledgeEntity {
// Add all update statements to the embedding table
for (id, (embedding, user_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding)
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!(
transaction_query,
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
@@ -408,14 +409,14 @@ impl KnowledgeEntity {
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
}
write!(
transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;");
@@ -472,7 +473,7 @@ impl KnowledgeEntity {
let embedding = provider
.embed(&embedding_input)
.await
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
.map_err(AppError::internal)?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
@@ -481,7 +482,7 @@ impl KnowledgeEntity {
entity.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::InternalError(err_msg));
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
}
@@ -517,8 +518,9 @@ impl KnowledgeEntity {
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding)
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!(
transaction_query,
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
@@ -528,14 +530,14 @@ impl KnowledgeEntity {
created_at = time::now(), \
updated_at = time::now();",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
}
write!(
transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;");
@@ -94,7 +94,7 @@ impl KnowledgeRelationship {
.bind(("id", id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0)?;
if authorized.is_empty() {
let mut exists_result = db_client
+1 -1
View File
@@ -69,7 +69,7 @@ pub fn format_history(history: &[Message]) -> String {
if i > 0 {
out.push('\n');
}
write!(out, "{msg}").unwrap_or_default();
let _ = write!(out, "{msg}");
}
out
}
+18 -14
View File
@@ -135,9 +135,11 @@ impl TextChunk {
.bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string()))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
.map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default();
response = response.check().map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
Ok(rows
.into_iter()
@@ -198,7 +200,7 @@ impl TextChunk {
.bind(("user_id", user_id.to_owned()))
.bind(("limit", limit))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
.map_err(AppError::Database)?;
response = response.check().map_err(AppError::Database)?;
@@ -277,7 +279,7 @@ impl TextChunk {
chunk.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::InternalError(err_msg));
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
chunk.id.clone(),
@@ -291,8 +293,9 @@ impl TextChunk {
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding)
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!(
&mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
@@ -303,14 +306,14 @@ impl TextChunk {
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
}
write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;");
@@ -357,7 +360,7 @@ impl TextChunk {
let embedding = provider
.embed(&chunk.chunk)
.await
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
.map_err(AppError::internal)?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
@@ -366,7 +369,7 @@ impl TextChunk {
chunk.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::InternalError(err_msg));
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
chunk.id.clone(),
@@ -402,8 +405,9 @@ impl TextChunk {
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding)
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
let embedding = serde_json::to_string(&embedding).map_err(|e| {
AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!(
&mut transaction_query,
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
@@ -414,14 +418,14 @@ impl TextChunk {
created_at = time::now(), \
updated_at = time::now();",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
}
write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
.map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;");
+39 -19
View File
@@ -2,13 +2,13 @@ use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
str::FromStr,
sync::Arc,
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tokio::sync::Mutex;
use thiserror::Error;
use tracing::debug;
use crate::{
@@ -16,6 +16,14 @@ use crate::{
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Error returned when parsing an embedding backend name.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
pub struct ParseEmbeddingBackendError {
/// The unrecognized input string.
pub input: String,
}
/// Supported embedding backends.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
@@ -27,16 +35,16 @@ pub enum EmbeddingBackend {
}
impl std::str::FromStr for EmbeddingBackend {
type Err = anyhow::Error;
type Err = ParseEmbeddingBackendError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"openai" => Ok(Self::OpenAI),
"hashed" => Ok(Self::Hashed),
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
other => Err(anyhow!(
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
)),
other => Err(ParseEmbeddingBackendError {
input: other.to_string(),
}),
}
}
}
@@ -68,7 +76,7 @@ enum EmbeddingInner {
},
/// Uses `FastEmbed` running locally.
FastEmbed {
/// Shared `FastEmbed` model.
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
model: Arc<Mutex<TextEmbedding>>,
/// Model metadata used for info logging.
model_name: EmbeddingModel,
@@ -77,6 +85,21 @@ enum EmbeddingInner {
},
}
async fn run_fastembed(
model: Arc<Mutex<TextEmbedding>>,
texts: Vec<String>,
) -> Result<Vec<Vec<f32>>> {
tokio::task::spawn_blocking(move || {
let mut guard = model
.lock()
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
guard.embed(texts, None)
})
.await
.context("joining fastembed embedding task")?
.context("generating fastembed embeddings")
}
impl EmbeddingProvider {
#[must_use]
pub fn backend_label(&self) -> &'static str {
@@ -116,10 +139,7 @@ impl EmbeddingProvider {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => {
let mut guard = model.lock().await;
let embeddings = guard
.embed(vec![text.to_owned()], None)
.context("generating fastembed vector")?;
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
embeddings
.into_iter()
.next()
@@ -166,10 +186,7 @@ impl EmbeddingProvider {
if texts.is_empty() {
return Ok(Vec::new());
}
let mut guard = model.lock().await;
guard
.embed(texts, None)
.context("generating fastembed batch embeddings")
run_fastembed(Arc::clone(model), texts).await
}
EmbeddingInner::OpenAI {
client,
@@ -325,12 +342,15 @@ fn bucket(token: &str, dimension: usize) -> usize {
///
/// # Errors
///
/// Returns `AppError::InternalError` if the provider's embed call fails.
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
provider.embed(input).await.map_err(|e| AppError::InternalError(e.to_string()))
provider
.embed(input)
.await
.map_err(AppError::internal)
}
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
@@ -377,7 +397,7 @@ pub async fn generate_embedding(
let embedding: Vec<f32> = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
.embedding
.clone();
@@ -409,7 +429,7 @@ pub async fn generate_embedding_with_params(
let embedding = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
.embedding
.clone();
+9 -5
View File
@@ -1,11 +1,15 @@
use thiserror::Error;
use super::config::AppConfig;
/// Errors raised when validating ingestion payloads against configured limits.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum IngestValidationError {
/// The payload exceeds a configured size limit (content, context, or category).
#[error("payload too large: {0}")]
PayloadTooLarge(String),
/// The request violates a non-size constraint (e.g., too many files).
#[error("bad request: {0}")]
BadRequest(String),
}
@@ -27,7 +31,7 @@ pub fn validate_ingest_input(
) -> Result<(), IngestValidationError> {
if file_count > config.ingest_max_files {
return Err(IngestValidationError::BadRequest(format!(
"Too many files. Maximum allowed is {}",
"too many files: maximum allowed is {}",
config.ingest_max_files
)));
}
@@ -35,7 +39,7 @@ pub fn validate_ingest_input(
if let Some(content) = content {
if content.len() > config.ingest_max_content_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Content is too large. Maximum allowed is {} bytes",
"content is too large: maximum allowed is {} bytes",
config.ingest_max_content_bytes
)));
}
@@ -43,14 +47,14 @@ pub fn validate_ingest_input(
if ctx.len() > config.ingest_max_context_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Context is too large. Maximum allowed is {} bytes",
"context is too large: maximum allowed is {} bytes",
config.ingest_max_context_bytes
)));
}
if category.len() > config.ingest_max_category_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Category is too large. Maximum allowed is {} bytes",
"category is too large: maximum allowed is {} bytes",
config.ingest_max_category_bytes
)));
}
+2 -2
View File
@@ -142,8 +142,8 @@ pub async fn process_ingest_form(
input.content,
input.context,
input.category,
&file_infos,
user.id.as_str(),
file_infos,
user.id.clone(),
)?;
let futures: Vec<_> = payloads