chore: harden common errors, fastembed blocking, and ingest ownership

Run FastEmbed inference on spawn_blocking, propagate Surreal take
failures,
add AppError::internal and typed ingest/embedding parse errors, and take
owned file lists in ingestion payload construction.
This commit is contained in:
Per Stark
2026-05-28 20:25:12 +02:00
parent 9d5e7cd794
commit 85336d77a3
13 changed files with 153 additions and 125 deletions
+2 -2
View File
@@ -73,8 +73,8 @@ pub async fn ingest_data(
input.content, input.content,
input.context, input.context,
input.category, input.category,
&file_infos, file_infos,
&user_id, user_id.clone(),
)?; )?;
let futures: Vec<_> = payloads let futures: Vec<_> = payloads
+8
View File
@@ -39,3 +39,11 @@ pub enum AppError {
#[error("internal service error: {0}")] #[error("internal service error: {0}")]
InternalError(String), InternalError(String),
} }
impl AppError {
/// Builds an [`AppError::InternalError`] from a displayable message.
#[must_use]
pub fn internal(msg: impl std::fmt::Display) -> Self {
Self::InternalError(msg.to_string())
}
}
+1 -1
View File
@@ -124,7 +124,7 @@ impl SurrealDbClient {
.load_files(&MIGRATIONS_DIR) .load_files(&MIGRATIONS_DIR)
.up() .up()
.await .await
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
Ok(()) Ok(())
} }
+2 -2
View File
@@ -169,7 +169,7 @@ pub async fn ensure_runtime(
) -> Result<(), AppError> { ) -> Result<(), AppError> {
ensure_runtime_inner(db, embedding_dimension) ensure_runtime_inner(db, embedding_dimension)
.await .await
.map_err(|err| AppError::InternalError(err.to_string())) .map_err(AppError::internal)
} }
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined. /// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
@@ -180,7 +180,7 @@ pub async fn ensure_runtime(
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> { pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_inner(db) rebuild_inner(db)
.await .await
.map_err(|err| AppError::InternalError(err.to_string())) .map_err(AppError::internal)
} }
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> { async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
+1 -1
View File
@@ -261,7 +261,7 @@ impl FileInfo {
// Remove the object's parent prefix in the object store // Remove the object's parent prefix in the object store
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path) let (parent_prefix, _file_name) = store::split_object_path(&file_info.path)
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
storage storage
.delete_prefix(&parent_prefix) .delete_prefix(&parent_prefix)
.await .await
+54 -64
View File
@@ -1,9 +1,4 @@
#![allow( #![allow(clippy::result_large_err)]
clippy::result_large_err,
clippy::needless_pass_by_value,
clippy::implicit_clone,
clippy::semicolon_if_nothing_returned
)]
use crate::{error::AppError, storage::types::file_info::FileInfo}; use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::info; use tracing::info;
@@ -49,48 +44,49 @@ impl IngestionPayload {
content: Option<String>, content: Option<String>,
context: String, context: String,
category: String, category: String,
files: &[FileInfo], files: Vec<FileInfo>,
user_id: &str, user_id: String,
) -> Result<Vec<IngestionPayload>, AppError> { ) -> Result<Vec<IngestionPayload>, AppError> {
// Initialize list let has_content = content
let mut object_list = Vec::new(); .as_ref()
.is_some_and(|c| c.len() > 2);
#[allow(clippy::arithmetic_side_effects)]
let capacity = files.len() + usize::from(has_content);
let mut object_list = Vec::with_capacity(capacity);
for file in files {
object_list.push(IngestionPayload::File {
file_info: file,
context: context.clone(),
category: category.clone(),
user_id: user_id.clone(),
});
}
// Create a IngestionPayload from content if it exists, checking for URL or text
if let Some(input_content) = content { if let Some(input_content) = content {
match Url::parse(&input_content) { match Url::parse(&input_content) {
Ok(url) => { Ok(url) => {
info!("Detected URL: {}", url); info!("Detected URL: {}", url);
object_list.push(IngestionPayload::Url { object_list.push(IngestionPayload::Url {
url: url.to_string(), url: url.to_string(),
context: context.clone(), context,
category: category.clone(), category,
user_id: user_id.into(), user_id,
}); });
} }
Err(_) => { Err(_) => {
if input_content.len() > 2 { if input_content.len() > 2 {
info!("Treating input as plain text"); info!("Treating input as plain text");
object_list.push(IngestionPayload::Text { object_list.push(IngestionPayload::Text {
text: input_content.to_string(), text: input_content,
context: context.clone(), context,
category: category.clone(), category,
user_id: user_id.into(), user_id,
}); });
} }
} }
} }
} }
for file in files {
object_list.push(IngestionPayload::File {
file_info: file.clone(),
context: context.clone(),
category: category.clone(),
user_id: user_id.into(),
})
}
// If no objects are constructed, we return Err
if object_list.is_empty() { if object_list.is_empty() {
return Err(AppError::NotFound( return Err(AppError::NotFound(
"No valid content or files provided".into(), "No valid content or files provided".into(),
@@ -138,14 +134,13 @@ mod tests {
let context = "Process this URL"; let context = "Process this URL";
let category = "websites"; let category = "websites";
let user_id = "user123"; let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload( let result = IngestionPayload::create_ingestion_payload(
Some(url.to_string()), Some(url.to_string()),
context.to_string(), context.to_string(),
category.to_string(), category.to_string(),
&files, vec![],
user_id, user_id.to_string(),
) )
.with_context(|| "create_ingestion_payload".to_string())?; .with_context(|| "create_ingestion_payload".to_string())?;
@@ -174,14 +169,13 @@ mod tests {
let context = "Process this text"; let context = "Process this text";
let category = "notes"; let category = "notes";
let user_id = "user123"; let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload( let result = IngestionPayload::create_ingestion_payload(
Some(text.to_string()), Some(text.to_string()),
context.to_string(), context.to_string(),
category.to_string(), category.to_string(),
&files, vec![],
user_id, user_id.to_string(),
) )
.with_context(|| "create_ingestion_payload".to_string())?; .with_context(|| "create_ingestion_payload".to_string())?;
@@ -215,16 +209,15 @@ mod tests {
}; };
let file_info: FileInfo = mock_file.into(); let file_info: FileInfo = mock_file.into();
let files = vec![file_info.clone()]; let file_id = file_info.id.clone();
let result = IngestionPayload::create_ingestion_payload( let result = IngestionPayload::create_ingestion_payload(
None, None,
context.to_string(), context.to_string(),
category.to_string(), category.to_string(),
&files, vec![file_info],
user_id, user_id.to_string(),
) )?;
.with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
match result.first().context("expected one result")? { match result.first().context("expected one result")? {
@@ -234,7 +227,7 @@ mod tests {
category: payload_category, category: payload_category,
user_id: payload_user_id, user_id: payload_user_id,
} => { } => {
assert_eq!(payload_file_info.id, file_info.id); assert_eq!(payload_file_info.id, file_id);
assert_eq!(payload_context, context); assert_eq!(payload_context, context);
assert_eq!(payload_category, category); assert_eq!(payload_category, category);
assert_eq!(payload_user_id, user_id); assert_eq!(payload_user_id, user_id);
@@ -257,39 +250,38 @@ mod tests {
}; };
let file_info: FileInfo = mock_file.into(); let file_info: FileInfo = mock_file.into();
let files = vec![file_info.clone()]; let file_id = file_info.id.clone();
let result = IngestionPayload::create_ingestion_payload( let result = IngestionPayload::create_ingestion_payload(
Some(url.to_string()), Some(url.to_string()),
context.to_string(), context.to_string(),
category.to_string(), category.to_string(),
&files, vec![file_info],
user_id, user_id.to_string(),
) )?;
.with_context(|| "create_ingestion_payload".to_string())?;
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
// Check first item is URL // Check first item is File (files processed first to minimize clones)
match result.first().context("expected first item")? { match result.first().context("expected first item")? {
IngestionPayload::File {
file_info: payload_file_info,
..
} => {
assert_eq!(payload_file_info.id, file_id);
}
_ => anyhow::bail!("Expected first item to be File variant"),
}
// Check second item is URL
match result.get(1).context("expected second item")? {
IngestionPayload::Url { IngestionPayload::Url {
url: payload_url, .. url: payload_url, ..
} => { } => {
// URL parser may normalize the URL by adding a trailing slash // URL parser may normalize the URL by adding a trailing slash
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/")); assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
} }
_ => anyhow::bail!("Expected first item to be Url variant"), _ => anyhow::bail!("Expected second item to be Url variant"),
}
// Check second item is File
match result.get(1).context("expected second item")? {
IngestionPayload::File {
file_info: payload_file_info,
..
} => {
assert_eq!(payload_file_info.id, file_info.id);
}
_ => anyhow::bail!("Expected second item to be File variant"),
} }
Ok(()) Ok(())
} }
@@ -299,14 +291,13 @@ mod tests {
let context = "Process something"; let context = "Process something";
let category = "empty"; let category = "empty";
let user_id = "user123"; let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload( let result = IngestionPayload::create_ingestion_payload(
None, None,
context.to_string(), context.to_string(),
category.to_string(), category.to_string(),
&files, vec![],
user_id, user_id.to_string(),
); );
assert!(result.is_err()); assert!(result.is_err());
@@ -325,14 +316,13 @@ mod tests {
let context = "Process this"; let context = "Process this";
let category = "notes"; let category = "notes";
let user_id = "user123"; let user_id = "user123";
let files = vec![];
let result = IngestionPayload::create_ingestion_payload( let result = IngestionPayload::create_ingestion_payload(
Some(text.to_string()), Some(text.to_string()),
context.to_string(), context.to_string(),
category.to_string(), category.to_string(),
&files, vec![],
user_id, user_id.to_string(),
); );
assert!(result.is_err()); assert!(result.is_err());
+15 -13
View File
@@ -250,7 +250,7 @@ impl KnowledgeEntity {
.bind(("embedding", query_embedding)) .bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string())) .bind(("user_id", user_id.to_string()))
.await .await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; .map_err(AppError::Database)?;
response = response.check().map_err(AppError::Database)?; response = response.check().map_err(AppError::Database)?;
@@ -326,7 +326,7 @@ impl KnowledgeEntity {
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?; let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.first() rows.first()
.map(|r| r.user_id.clone()) .map(|r| r.user_id.clone())
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string())) .ok_or_else(|| AppError::internal("user not found for entity"))
} }
/// Re-creates embeddings for all knowledge entities in the database. /// Re-creates embeddings for all knowledge entities in the database.
@@ -385,7 +385,7 @@ impl KnowledgeEntity {
entity.id, embedding.len(), new_dimensions entity.id, embedding.len(), new_dimensions
); );
error!("{err_msg}"); error!("{err_msg}");
return Err(AppError::InternalError(err_msg)); return Err(AppError::internal(err_msg));
} }
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
} }
@@ -397,8 +397,9 @@ impl KnowledgeEntity {
// Add all update statements to the embedding table // Add all update statements to the embedding table
for (id, (embedding, user_id)) in new_embeddings { for (id, (embedding, user_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding) let embedding = serde_json::to_string(&embedding).map_err(|e| {
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?; AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!( write!(
transaction_query, transaction_query,
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \ "UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
@@ -408,14 +409,14 @@ impl KnowledgeEntity {
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();", updated_at = time::now();",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
} }
write!( write!(
transaction_query, transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", "DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;"); transaction_query.push_str("COMMIT TRANSACTION;");
@@ -472,7 +473,7 @@ impl KnowledgeEntity {
let embedding = provider let embedding = provider
.embed(&embedding_input) .embed(&embedding_input)
.await .await
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?; .map_err(AppError::internal)?;
// Safety check: ensure the generated embedding has the correct dimension. // Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions { if embedding.len() != new_dimensions {
@@ -481,7 +482,7 @@ impl KnowledgeEntity {
entity.id, embedding.len(), new_dimensions entity.id, embedding.len(), new_dimensions
); );
error!("{err_msg}"); error!("{err_msg}");
return Err(AppError::InternalError(err_msg)); return Err(AppError::internal(err_msg));
} }
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
} }
@@ -517,8 +518,9 @@ impl KnowledgeEntity {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id)) in new_embeddings { for (id, (embedding, user_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding) let embedding = serde_json::to_string(&embedding).map_err(|e| {
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?; AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!( write!(
transaction_query, transaction_query,
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \ "CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
@@ -528,14 +530,14 @@ impl KnowledgeEntity {
created_at = time::now(), \ created_at = time::now(), \
updated_at = time::now();", updated_at = time::now();",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
} }
write!( write!(
transaction_query, transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", "DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;"); transaction_query.push_str("COMMIT TRANSACTION;");
@@ -94,7 +94,7 @@ impl KnowledgeRelationship {
.bind(("id", id.to_owned())) .bind(("id", id.to_owned()))
.bind(("user_id", user_id.to_owned())) .bind(("user_id", user_id.to_owned()))
.await?; .await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default(); let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0)?;
if authorized.is_empty() { if authorized.is_empty() {
let mut exists_result = db_client let mut exists_result = db_client
+1 -1
View File
@@ -69,7 +69,7 @@ pub fn format_history(history: &[Message]) -> String {
if i > 0 { if i > 0 {
out.push('\n'); out.push('\n');
} }
write!(out, "{msg}").unwrap_or_default(); let _ = write!(out, "{msg}");
} }
out out
} }
+18 -14
View File
@@ -135,9 +135,11 @@ impl TextChunk {
.bind(("embedding", query_embedding)) .bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string())) .bind(("user_id", user_id.to_string()))
.await .await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; .map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default(); response = response.check().map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
Ok(rows Ok(rows
.into_iter() .into_iter()
@@ -198,7 +200,7 @@ impl TextChunk {
.bind(("user_id", user_id.to_owned())) .bind(("user_id", user_id.to_owned()))
.bind(("limit", limit)) .bind(("limit", limit))
.await .await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; .map_err(AppError::Database)?;
response = response.check().map_err(AppError::Database)?; response = response.check().map_err(AppError::Database)?;
@@ -277,7 +279,7 @@ impl TextChunk {
chunk.id, embedding.len(), new_dimensions chunk.id, embedding.len(), new_dimensions
); );
error!("{err_msg}"); error!("{err_msg}");
return Err(AppError::InternalError(err_msg)); return Err(AppError::internal(err_msg));
} }
new_embeddings.insert( new_embeddings.insert(
chunk.id.clone(), chunk.id.clone(),
@@ -291,8 +293,9 @@ impl TextChunk {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding) let embedding = serde_json::to_string(&embedding).map_err(|e| {
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?; AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!( write!(
&mut transaction_query, &mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \ "UPSERT type::thing('text_chunk_embedding', '{id}') SET \
@@ -303,14 +306,14 @@ impl TextChunk {
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();", updated_at = time::now();",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
} }
write!( write!(
&mut transaction_query, &mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;"); transaction_query.push_str("COMMIT TRANSACTION;");
@@ -357,7 +360,7 @@ impl TextChunk {
let embedding = provider let embedding = provider
.embed(&chunk.chunk) .embed(&chunk.chunk)
.await .await
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?; .map_err(AppError::internal)?;
// Safety check: ensure the generated embedding has the correct dimension. // Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions { if embedding.len() != new_dimensions {
@@ -366,7 +369,7 @@ impl TextChunk {
chunk.id, embedding.len(), new_dimensions chunk.id, embedding.len(), new_dimensions
); );
error!("{err_msg}"); error!("{err_msg}");
return Err(AppError::InternalError(err_msg)); return Err(AppError::internal(err_msg));
} }
new_embeddings.insert( new_embeddings.insert(
chunk.id.clone(), chunk.id.clone(),
@@ -402,8 +405,9 @@ impl TextChunk {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding) let embedding = serde_json::to_string(&embedding).map_err(|e| {
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?; AppError::internal(format!("embedding serialization failed: {e}"))
})?;
write!( write!(
&mut transaction_query, &mut transaction_query,
"CREATE type::thing('text_chunk_embedding', '{id}') SET \ "CREATE type::thing('text_chunk_embedding', '{id}') SET \
@@ -414,14 +418,14 @@ impl TextChunk {
created_at = time::now(), \ created_at = time::now(), \
updated_at = time::now();", updated_at = time::now();",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
} }
write!( write!(
&mut transaction_query, &mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};", "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
) )
.map_err(|e| AppError::InternalError(e.to_string()))?; .map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;"); transaction_query.push_str("COMMIT TRANSACTION;");
+39 -19
View File
@@ -2,13 +2,13 @@ use std::{
collections::hash_map::DefaultHasher, collections::hash_map::DefaultHasher,
hash::{Hash, Hasher}, hash::{Hash, Hasher},
str::FromStr, str::FromStr,
sync::Arc, sync::{Arc, Mutex},
}; };
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client}; use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tokio::sync::Mutex; use thiserror::Error;
use tracing::debug; use tracing::debug;
use crate::{ use crate::{
@@ -16,6 +16,14 @@ use crate::{
storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
}; };
/// Error returned when parsing an embedding backend name.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
pub struct ParseEmbeddingBackendError {
/// The unrecognized input string.
pub input: String,
}
/// Supported embedding backends. /// Supported embedding backends.
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
@@ -27,16 +35,16 @@ pub enum EmbeddingBackend {
} }
impl std::str::FromStr for EmbeddingBackend { impl std::str::FromStr for EmbeddingBackend {
type Err = anyhow::Error; type Err = ParseEmbeddingBackendError;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() { match s.to_ascii_lowercase().as_str() {
"openai" => Ok(Self::OpenAI), "openai" => Ok(Self::OpenAI),
"hashed" => Ok(Self::Hashed), "hashed" => Ok(Self::Hashed),
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed), "fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
other => Err(anyhow!( other => Err(ParseEmbeddingBackendError {
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'." input: other.to_string(),
)), }),
} }
} }
} }
@@ -68,7 +76,7 @@ enum EmbeddingInner {
}, },
/// Uses `FastEmbed` running locally. /// Uses `FastEmbed` running locally.
FastEmbed { FastEmbed {
/// Shared `FastEmbed` model. /// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
model: Arc<Mutex<TextEmbedding>>, model: Arc<Mutex<TextEmbedding>>,
/// Model metadata used for info logging. /// Model metadata used for info logging.
model_name: EmbeddingModel, model_name: EmbeddingModel,
@@ -77,6 +85,21 @@ enum EmbeddingInner {
}, },
} }
async fn run_fastembed(
model: Arc<Mutex<TextEmbedding>>,
texts: Vec<String>,
) -> Result<Vec<Vec<f32>>> {
tokio::task::spawn_blocking(move || {
let mut guard = model
.lock()
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
guard.embed(texts, None)
})
.await
.context("joining fastembed embedding task")?
.context("generating fastembed embeddings")
}
impl EmbeddingProvider { impl EmbeddingProvider {
#[must_use] #[must_use]
pub fn backend_label(&self) -> &'static str { pub fn backend_label(&self) -> &'static str {
@@ -116,10 +139,7 @@ impl EmbeddingProvider {
match &self.inner { match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => { EmbeddingInner::FastEmbed { model, .. } => {
let mut guard = model.lock().await; let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
let embeddings = guard
.embed(vec![text.to_owned()], None)
.context("generating fastembed vector")?;
embeddings embeddings
.into_iter() .into_iter()
.next() .next()
@@ -166,10 +186,7 @@ impl EmbeddingProvider {
if texts.is_empty() { if texts.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let mut guard = model.lock().await; run_fastembed(Arc::clone(model), texts).await
guard
.embed(texts, None)
.context("generating fastembed batch embeddings")
} }
EmbeddingInner::OpenAI { EmbeddingInner::OpenAI {
client, client,
@@ -325,12 +342,15 @@ fn bucket(token: &str, dimension: usize) -> usize {
/// ///
/// # Errors /// # Errors
/// ///
/// Returns `AppError::InternalError` if the provider's embed call fails. /// Returns [`AppError::InternalError`] if the provider's embed call fails.
pub async fn generate_embedding_with_provider( pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider, provider: &EmbeddingProvider,
input: &str, input: &str,
) -> Result<Vec<f32>, AppError> { ) -> Result<Vec<f32>, AppError> {
provider.embed(input).await.map_err(|e| AppError::InternalError(e.to_string())) provider
.embed(input)
.await
.map_err(AppError::internal)
} }
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model. /// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
@@ -377,7 +397,7 @@ pub async fn generate_embedding(
let embedding: Vec<f32> = response let embedding: Vec<f32> = response
.data .data
.first() .first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))? .ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
.embedding .embedding
.clone(); .clone();
@@ -409,7 +429,7 @@ pub async fn generate_embedding_with_params(
let embedding = response let embedding = response
.data .data
.first() .first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))? .ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
.embedding .embedding
.clone(); .clone();
+9 -5
View File
@@ -1,11 +1,15 @@
use thiserror::Error;
use super::config::AppConfig; use super::config::AppConfig;
/// Errors raised when validating ingestion payloads against configured limits. /// Errors raised when validating ingestion payloads against configured limits.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum IngestValidationError { pub enum IngestValidationError {
/// The payload exceeds a configured size limit (content, context, or category). /// The payload exceeds a configured size limit (content, context, or category).
#[error("payload too large: {0}")]
PayloadTooLarge(String), PayloadTooLarge(String),
/// The request violates a non-size constraint (e.g., too many files). /// The request violates a non-size constraint (e.g., too many files).
#[error("bad request: {0}")]
BadRequest(String), BadRequest(String),
} }
@@ -27,7 +31,7 @@ pub fn validate_ingest_input(
) -> Result<(), IngestValidationError> { ) -> Result<(), IngestValidationError> {
if file_count > config.ingest_max_files { if file_count > config.ingest_max_files {
return Err(IngestValidationError::BadRequest(format!( return Err(IngestValidationError::BadRequest(format!(
"Too many files. Maximum allowed is {}", "too many files: maximum allowed is {}",
config.ingest_max_files config.ingest_max_files
))); )));
} }
@@ -35,7 +39,7 @@ pub fn validate_ingest_input(
if let Some(content) = content { if let Some(content) = content {
if content.len() > config.ingest_max_content_bytes { if content.len() > config.ingest_max_content_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!( return Err(IngestValidationError::PayloadTooLarge(format!(
"Content is too large. Maximum allowed is {} bytes", "content is too large: maximum allowed is {} bytes",
config.ingest_max_content_bytes config.ingest_max_content_bytes
))); )));
} }
@@ -43,14 +47,14 @@ pub fn validate_ingest_input(
if ctx.len() > config.ingest_max_context_bytes { if ctx.len() > config.ingest_max_context_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!( return Err(IngestValidationError::PayloadTooLarge(format!(
"Context is too large. Maximum allowed is {} bytes", "context is too large: maximum allowed is {} bytes",
config.ingest_max_context_bytes config.ingest_max_context_bytes
))); )));
} }
if category.len() > config.ingest_max_category_bytes { if category.len() > config.ingest_max_category_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!( return Err(IngestValidationError::PayloadTooLarge(format!(
"Category is too large. Maximum allowed is {} bytes", "category is too large: maximum allowed is {} bytes",
config.ingest_max_category_bytes config.ingest_max_category_bytes
))); )));
} }
+2 -2
View File
@@ -142,8 +142,8 @@ pub async fn process_ingest_form(
input.content, input.content,
input.context, input.context,
input.category, input.category,
&file_infos, file_infos,
user.id.as_str(), user.id.clone(),
)?; )?;
let futures: Vec<_> = payloads let futures: Vec<_> = payloads