mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-29 19:00:51 +02:00
chore: harden common errors, fastembed blocking, and ingest ownership
Run FastEmbed inference inside `spawn_blocking`, propagate SurrealDB `take` failures instead of silently falling back to empty results, add `AppError::internal` and typed ingest/embedding parse errors, and take owned file lists in ingestion payload construction.
This commit is contained in:
@@ -73,8 +73,8 @@ pub async fn ingest_data(
|
||||
input.content,
|
||||
input.context,
|
||||
input.category,
|
||||
&file_infos,
|
||||
&user_id,
|
||||
file_infos,
|
||||
user_id.clone(),
|
||||
)?;
|
||||
|
||||
let futures: Vec<_> = payloads
|
||||
|
||||
@@ -39,3 +39,11 @@ pub enum AppError {
|
||||
#[error("internal service error: {0}")]
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
// Convenience constructors for `AppError`.
impl AppError {
|
||||
/// Builds an [`AppError::InternalError`] from a displayable message.
///
/// The message is rendered with `to_string()`, so any value implementing
/// `Display` (a string, a formatted message, or another error) is accepted.
/// Because it takes a single argument, the function can be passed directly
/// to `map_err`.
|
||||
#[must_use]
|
||||
pub fn internal(msg: impl std::fmt::Display) -> Self {
|
||||
Self::InternalError(msg.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ impl SurrealDbClient {
|
||||
.load_files(&MIGRATIONS_DIR)
|
||||
.up()
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -169,7 +169,7 @@ pub async fn ensure_runtime(
|
||||
) -> Result<(), AppError> {
|
||||
ensure_runtime_inner(db, embedding_dimension)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||
@@ -180,7 +180,7 @@ pub async fn ensure_runtime(
|
||||
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
|
||||
|
||||
@@ -261,7 +261,7 @@ impl FileInfo {
|
||||
|
||||
// Remove the object's parent prefix in the object store
|
||||
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
storage
|
||||
.delete_prefix(&parent_prefix)
|
||||
.await
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
#![allow(
|
||||
clippy::result_large_err,
|
||||
clippy::needless_pass_by_value,
|
||||
clippy::implicit_clone,
|
||||
clippy::semicolon_if_nothing_returned
|
||||
)]
|
||||
#![allow(clippy::result_large_err)]
|
||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
@@ -49,48 +44,49 @@ impl IngestionPayload {
|
||||
content: Option<String>,
|
||||
context: String,
|
||||
category: String,
|
||||
files: &[FileInfo],
|
||||
user_id: &str,
|
||||
files: Vec<FileInfo>,
|
||||
user_id: String,
|
||||
) -> Result<Vec<IngestionPayload>, AppError> {
|
||||
// Initialize list
|
||||
let mut object_list = Vec::new();
|
||||
let has_content = content
|
||||
.as_ref()
|
||||
.is_some_and(|c| c.len() > 2);
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
let capacity = files.len() + usize::from(has_content);
|
||||
let mut object_list = Vec::with_capacity(capacity);
|
||||
for file in files {
|
||||
object_list.push(IngestionPayload::File {
|
||||
file_info: file,
|
||||
context: context.clone(),
|
||||
category: category.clone(),
|
||||
user_id: user_id.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Create an IngestionPayload from content if it exists, checking for URL or text
|
||||
if let Some(input_content) = content {
|
||||
match Url::parse(&input_content) {
|
||||
Ok(url) => {
|
||||
info!("Detected URL: {}", url);
|
||||
object_list.push(IngestionPayload::Url {
|
||||
url: url.to_string(),
|
||||
context: context.clone(),
|
||||
category: category.clone(),
|
||||
user_id: user_id.into(),
|
||||
context,
|
||||
category,
|
||||
user_id,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
if input_content.len() > 2 {
|
||||
info!("Treating input as plain text");
|
||||
object_list.push(IngestionPayload::Text {
|
||||
text: input_content.to_string(),
|
||||
context: context.clone(),
|
||||
category: category.clone(),
|
||||
user_id: user_id.into(),
|
||||
text: input_content,
|
||||
context,
|
||||
category,
|
||||
user_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for file in files {
|
||||
object_list.push(IngestionPayload::File {
|
||||
file_info: file.clone(),
|
||||
context: context.clone(),
|
||||
category: category.clone(),
|
||||
user_id: user_id.into(),
|
||||
})
|
||||
}
|
||||
|
||||
// If no objects are constructed, we return Err
|
||||
if object_list.is_empty() {
|
||||
return Err(AppError::NotFound(
|
||||
"No valid content or files provided".into(),
|
||||
@@ -138,14 +134,13 @@ mod tests {
|
||||
let context = "Process this URL";
|
||||
let category = "websites";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(url.to_string()),
|
||||
context.to_string(),
|
||||
category.to_string(),
|
||||
&files,
|
||||
user_id,
|
||||
vec![],
|
||||
user_id.to_string(),
|
||||
)
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
@@ -174,14 +169,13 @@ mod tests {
|
||||
let context = "Process this text";
|
||||
let category = "notes";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(text.to_string()),
|
||||
context.to_string(),
|
||||
category.to_string(),
|
||||
&files,
|
||||
user_id,
|
||||
vec![],
|
||||
user_id.to_string(),
|
||||
)
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
|
||||
@@ -215,16 +209,15 @@ mod tests {
|
||||
};
|
||||
|
||||
let file_info: FileInfo = mock_file.into();
|
||||
let files = vec![file_info.clone()];
|
||||
let file_id = file_info.id.clone();
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
None,
|
||||
context.to_string(),
|
||||
category.to_string(),
|
||||
&files,
|
||||
user_id,
|
||||
)
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
vec![file_info],
|
||||
user_id.to_string(),
|
||||
)?;
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
match result.first().context("expected one result")? {
|
||||
@@ -234,7 +227,7 @@ mod tests {
|
||||
category: payload_category,
|
||||
user_id: payload_user_id,
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
assert_eq!(payload_file_info.id, file_id);
|
||||
assert_eq!(payload_context, context);
|
||||
assert_eq!(payload_category, category);
|
||||
assert_eq!(payload_user_id, user_id);
|
||||
@@ -257,39 +250,38 @@ mod tests {
|
||||
};
|
||||
|
||||
let file_info: FileInfo = mock_file.into();
|
||||
let files = vec![file_info.clone()];
|
||||
let file_id = file_info.id.clone();
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(url.to_string()),
|
||||
context.to_string(),
|
||||
category.to_string(),
|
||||
&files,
|
||||
user_id,
|
||||
)
|
||||
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||
vec![file_info],
|
||||
user_id.to_string(),
|
||||
)?;
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
// Check first item is URL
|
||||
// Check first item is File (files processed first to minimize clones)
|
||||
match result.first().context("expected first item")? {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_id);
|
||||
}
|
||||
_ => anyhow::bail!("Expected first item to be File variant"),
|
||||
}
|
||||
|
||||
// Check second item is URL
|
||||
match result.get(1).context("expected second item")? {
|
||||
IngestionPayload::Url {
|
||||
url: payload_url, ..
|
||||
} => {
|
||||
// URL parser may normalize the URL by adding a trailing slash
|
||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||
}
|
||||
_ => anyhow::bail!("Expected first item to be Url variant"),
|
||||
}
|
||||
|
||||
// Check second item is File
|
||||
match result.get(1).context("expected second item")? {
|
||||
IngestionPayload::File {
|
||||
file_info: payload_file_info,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(payload_file_info.id, file_info.id);
|
||||
}
|
||||
_ => anyhow::bail!("Expected second item to be File variant"),
|
||||
_ => anyhow::bail!("Expected second item to be Url variant"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -299,14 +291,13 @@ mod tests {
|
||||
let context = "Process something";
|
||||
let category = "empty";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
None,
|
||||
context.to_string(),
|
||||
category.to_string(),
|
||||
&files,
|
||||
user_id,
|
||||
vec![],
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
@@ -325,14 +316,13 @@ mod tests {
|
||||
let context = "Process this";
|
||||
let category = "notes";
|
||||
let user_id = "user123";
|
||||
let files = vec![];
|
||||
|
||||
let result = IngestionPayload::create_ingestion_payload(
|
||||
Some(text.to_string()),
|
||||
context.to_string(),
|
||||
category.to_string(),
|
||||
&files,
|
||||
user_id,
|
||||
vec![],
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
|
||||
@@ -250,7 +250,7 @@ impl KnowledgeEntity {
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
response = response.check().map_err(AppError::Database)?;
|
||||
|
||||
@@ -326,7 +326,7 @@ impl KnowledgeEntity {
|
||||
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
|
||||
rows.first()
|
||||
.map(|r| r.user_id.clone())
|
||||
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string()))
|
||||
.ok_or_else(|| AppError::internal("user not found for entity"))
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities in the database.
|
||||
@@ -385,7 +385,7 @@ impl KnowledgeEntity {
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
@@ -397,8 +397,9 @@ impl KnowledgeEntity {
|
||||
|
||||
// Add all update statements to the embedding table
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
write!(
|
||||
transaction_query,
|
||||
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
@@ -408,14 +409,14 @@ impl KnowledgeEntity {
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
@@ -472,7 +473,7 @@ impl KnowledgeEntity {
|
||||
let embedding = provider
|
||||
.embed(&embedding_input)
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
@@ -481,7 +482,7 @@ impl KnowledgeEntity {
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
@@ -517,8 +518,9 @@ impl KnowledgeEntity {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
write!(
|
||||
transaction_query,
|
||||
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
@@ -528,14 +530,14 @@ impl KnowledgeEntity {
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ impl KnowledgeRelationship {
|
||||
.bind(("id", id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0)?;
|
||||
|
||||
if authorized.is_empty() {
|
||||
let mut exists_result = db_client
|
||||
|
||||
@@ -69,7 +69,7 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
if i > 0 {
|
||||
out.push('\n');
|
||||
}
|
||||
write!(out, "{msg}").unwrap_or_default();
|
||||
let _ = write!(out, "{msg}");
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
@@ -135,9 +135,11 @@ impl TextChunk {
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default();
|
||||
response = response.check().map_err(AppError::Database)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
@@ -198,7 +200,7 @@ impl TextChunk {
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", limit))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
response = response.check().map_err(AppError::Database)?;
|
||||
|
||||
@@ -277,7 +279,7 @@ impl TextChunk {
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
@@ -291,8 +293,9 @@ impl TextChunk {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
|
||||
@@ -303,14 +306,14 @@ impl TextChunk {
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
@@ -357,7 +360,7 @@ impl TextChunk {
|
||||
let embedding = provider
|
||||
.embed(&chunk.chunk)
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
@@ -366,7 +369,7 @@ impl TextChunk {
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
@@ -402,8 +405,9 @@ impl TextChunk {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::InternalError(format!("embedding serialization failed: {e}")))?;
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
|
||||
@@ -414,14 +418,14 @@ impl TextChunk {
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {new_dimensions};",
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
|
||||
@@ -2,13 +2,13 @@ use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::Mutex;
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
@@ -16,6 +16,14 @@ use crate::{
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
///
/// The rendered message lists the canonical names `openai`, `hashed`, and
/// `fastembed`.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
|
||||
pub struct ParseEmbeddingBackendError {
|
||||
/// The unrecognized input string, interpolated into the error message.
// NOTE(review): the `FromStr` impl appears to store the ASCII-lowercased
// form of the input here rather than the caller's original casing — confirm.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Supported embedding backends.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
@@ -27,16 +35,16 @@ pub enum EmbeddingBackend {
|
||||
}
|
||||
|
||||
impl std::str::FromStr for EmbeddingBackend {
|
||||
type Err = anyhow::Error;
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"openai" => Ok(Self::OpenAI),
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(anyhow!(
|
||||
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
|
||||
)),
|
||||
other => Err(ParseEmbeddingBackendError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -68,7 +76,7 @@ enum EmbeddingInner {
|
||||
},
|
||||
/// Uses `FastEmbed` running locally.
|
||||
FastEmbed {
|
||||
/// Shared `FastEmbed` model.
|
||||
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
/// Model metadata used for info logging.
|
||||
model_name: EmbeddingModel,
|
||||
@@ -77,6 +85,21 @@ enum EmbeddingInner {
|
||||
},
|
||||
}
|
||||
|
||||
/// Runs `FastEmbed` inference for `texts` inside `tokio::task::spawn_blocking`.
///
/// The shared model is locked within the blocking closure and
/// `TextEmbedding::embed` is called with the owned `texts`; its output is
/// returned unchanged.
///
/// # Errors
///
/// Returns an error if the model mutex cannot be locked (reported as
/// poisoned), if the blocking task cannot be joined, or if `embed` itself
/// fails.
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
// The lock is taken inside the closure, so the guard never lives across an `.await`.
let mut guard = model
|
||||
.lock()
|
||||
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
|
||||
guard.embed(texts, None)
|
||||
})
|
||||
.await
|
||||
// Outer `Result`: failure to join the spawned blocking task.
.context("joining fastembed embedding task")?
|
||||
// Inner `Result`: failure returned by the closure (lock or embed error).
.context("generating fastembed embeddings")
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
#[must_use]
|
||||
pub fn backend_label(&self) -> &'static str {
|
||||
@@ -116,10 +139,7 @@ impl EmbeddingProvider {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let mut guard = model.lock().await;
|
||||
let embeddings = guard
|
||||
.embed(vec![text.to_owned()], None)
|
||||
.context("generating fastembed vector")?;
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
@@ -166,10 +186,7 @@ impl EmbeddingProvider {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut guard = model.lock().await;
|
||||
guard
|
||||
.embed(texts, None)
|
||||
.context("generating fastembed batch embeddings")
|
||||
run_fastembed(Arc::clone(model), texts).await
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -325,12 +342,15 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the provider's embed call fails.
|
||||
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider.embed(input).await.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
provider
|
||||
.embed(input)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
@@ -377,7 +397,7 @@ pub async fn generate_embedding(
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
@@ -409,7 +429,7 @@ pub async fn generate_embedding_with_params(
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use super::config::AppConfig;
|
||||
|
||||
/// Errors raised when validating ingestion payloads against configured limits.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Error, Debug, Clone, PartialEq, Eq)]
|
||||
pub enum IngestValidationError {
|
||||
/// The payload exceeds a configured size limit (content, context, or category).
|
||||
#[error("payload too large: {0}")]
|
||||
PayloadTooLarge(String),
|
||||
/// The request violates a non-size constraint (e.g., too many files).
|
||||
#[error("bad request: {0}")]
|
||||
BadRequest(String),
|
||||
}
|
||||
|
||||
@@ -27,7 +31,7 @@ pub fn validate_ingest_input(
|
||||
) -> Result<(), IngestValidationError> {
|
||||
if file_count > config.ingest_max_files {
|
||||
return Err(IngestValidationError::BadRequest(format!(
|
||||
"Too many files. Maximum allowed is {}",
|
||||
"too many files: maximum allowed is {}",
|
||||
config.ingest_max_files
|
||||
)));
|
||||
}
|
||||
@@ -35,7 +39,7 @@ pub fn validate_ingest_input(
|
||||
if let Some(content) = content {
|
||||
if content.len() > config.ingest_max_content_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"Content is too large. Maximum allowed is {} bytes",
|
||||
"content is too large: maximum allowed is {} bytes",
|
||||
config.ingest_max_content_bytes
|
||||
)));
|
||||
}
|
||||
@@ -43,14 +47,14 @@ pub fn validate_ingest_input(
|
||||
|
||||
if ctx.len() > config.ingest_max_context_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"Context is too large. Maximum allowed is {} bytes",
|
||||
"context is too large: maximum allowed is {} bytes",
|
||||
config.ingest_max_context_bytes
|
||||
)));
|
||||
}
|
||||
|
||||
if category.len() > config.ingest_max_category_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"Category is too large. Maximum allowed is {} bytes",
|
||||
"category is too large: maximum allowed is {} bytes",
|
||||
config.ingest_max_category_bytes
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -142,8 +142,8 @@ pub async fn process_ingest_form(
|
||||
input.content,
|
||||
input.context,
|
||||
input.category,
|
||||
&file_infos,
|
||||
user.id.as_str(),
|
||||
file_infos,
|
||||
user.id.clone(),
|
||||
)?;
|
||||
|
||||
let futures: Vec<_> = payloads
|
||||
|
||||
Reference in New Issue
Block a user