mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-29 19:00:51 +02:00
chore: centralize embedding errors, retrieval strategy, and test DB helpers.
Replace anyhow in embedding production code with EmbeddingError, move RetrievalStrategy into common config, and deduplicate Surreal test setup via common::test_utils.
This commit is contained in:
@@ -4,6 +4,36 @@ use tokio::task::JoinError;
|
||||
|
||||
use crate::storage::types::file_info::FileError;
|
||||
|
||||
/// Errors from embedding provider operations.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum EmbeddingError {
|
||||
#[error("openai error: {0}")]
|
||||
OpenAI(#[from] OpenAIError),
|
||||
#[error("fastembed error: {0}")]
|
||||
FastEmbed(String),
|
||||
#[error("task join error: {0}")]
|
||||
Join(#[from] JoinError),
|
||||
#[error("fastembed model mutex poisoned: {0}")]
|
||||
MutexPoisoned(String),
|
||||
#[error("no embedding data received")]
|
||||
NoData,
|
||||
#[error("embedding configuration error: {0}")]
|
||||
Config(String),
|
||||
#[error("unknown fastembed model: {0}")]
|
||||
UnknownModel(String),
|
||||
}
|
||||
|
||||
impl EmbeddingError {
|
||||
pub(crate) fn fastembed(err: impl std::fmt::Display) -> Self {
|
||||
Self::FastEmbed(err.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn mutex_poisoned(err: impl std::fmt::Display) -> Self {
|
||||
Self::MutexPoisoned(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Core internal errors
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Error, Debug)]
|
||||
@@ -12,6 +42,8 @@ pub enum AppError {
|
||||
Database(#[from] surrealdb::Error),
|
||||
#[error("openai error: {0}")]
|
||||
OpenAI(#[from] OpenAIError),
|
||||
#[error("embedding error: {0}")]
|
||||
Embedding(#[from] EmbeddingError),
|
||||
#[error("file error: {0}")]
|
||||
File(#[from] FileError),
|
||||
#[error("not found: {0}")]
|
||||
|
||||
@@ -3,3 +3,6 @@
|
||||
pub mod error;
|
||||
pub mod storage;
|
||||
pub mod utils;
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
|
||||
pub mod test_utils;
|
||||
|
||||
@@ -499,8 +499,7 @@ impl KnowledgeEntity {
|
||||
|
||||
let embedding = provider
|
||||
.embed(&embedding_input)
|
||||
.await
|
||||
.map_err(AppError::internal)?;
|
||||
.await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
@@ -599,21 +598,11 @@ mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::test_utils::configure_embedding_dimension;
|
||||
use anyhow::{self, Context};
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn configure_embedding_dimension(
|
||||
db: &SurrealDbClient,
|
||||
dimension: u32,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut settings = SystemSettings::get_current(db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(db, settings).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
|
||||
let source_id = "source123".to_string();
|
||||
|
||||
@@ -158,45 +158,11 @@ impl KnowledgeEntityEmbedding {
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::test_utils::{prepare_knowledge_entity_test_db, setup_test_db};
|
||||
use anyhow::{self, Context};
|
||||
use chrono::Utc;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn setup_test_db_with_embedding_dimension(
|
||||
dimension: u32,
|
||||
) -> anyhow::Result<SurrealDbClient> {
|
||||
let db = setup_test_db().await?;
|
||||
let mut settings = SystemSettings::get_current(&db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(&db, settings).await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
|
||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||
.await
|
||||
.with_context(|| format!("set test index dimension to {dimension}"))?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn build_knowledge_entity_with_id(
|
||||
key: &str,
|
||||
@@ -236,7 +202,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-1";
|
||||
let source_id = "source-ke";
|
||||
@@ -266,7 +232,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-delete";
|
||||
let source_id = "source-del";
|
||||
@@ -298,7 +264,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
let user_id = "user_store";
|
||||
let source_id = "source_store";
|
||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||
@@ -331,7 +297,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||
let result =
|
||||
@@ -344,7 +310,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let source_id = "shared-ke";
|
||||
let other_source = "other-ke";
|
||||
@@ -437,7 +403,7 @@ mod tests {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-fetch";
|
||||
let source_id = "source-fetch";
|
||||
@@ -475,7 +441,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user-upsert";
|
||||
let source_id = "source-upsert";
|
||||
|
||||
@@ -151,19 +151,7 @@ mod tests {
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use anyhow::{self, Context};
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db
|
||||
}
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
async fn get_relationship_by_id(
|
||||
relationship_id: &str,
|
||||
@@ -234,7 +222,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
@@ -282,7 +270,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
|
||||
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
|
||||
@@ -303,7 +291,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
@@ -342,7 +330,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
@@ -396,7 +384,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
let owner_user_id = "owner-user";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
|
||||
@@ -459,7 +447,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
@@ -543,7 +531,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_a = "user-a";
|
||||
let user_b = "user-b";
|
||||
@@ -584,7 +572,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
|
||||
{
|
||||
let db = setup_test_db().await;
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "user123";
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||
|
||||
@@ -384,8 +384,7 @@ impl TextChunk {
|
||||
|
||||
let embedding = provider
|
||||
.embed(&chunk.chunk)
|
||||
.await
|
||||
.map_err(AppError::internal)?;
|
||||
.await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
@@ -489,21 +488,11 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::test_utils::configure_embedding_dimension;
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn configure_embedding_dimension(
|
||||
db: &SurrealDbClient,
|
||||
dimension: u32,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut settings = SystemSettings::get_current(db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(db, settings).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
|
||||
@@ -144,41 +144,8 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn setup_test_db_with_embedding_dimension(
|
||||
dimension: u32,
|
||||
) -> anyhow::Result<SurrealDbClient> {
|
||||
let db = setup_test_db().await?;
|
||||
let mut settings = SystemSettings::get_current(&db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(&db, settings).await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
|
||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||
.await
|
||||
.with_context(|| format!("set test index dimension to {dimension}"))?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn create_text_chunk_with_id(
|
||||
db: &SurrealDbClient,
|
||||
@@ -245,7 +212,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user_a";
|
||||
let chunk_key = "chunk-123";
|
||||
@@ -279,7 +246,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user_b";
|
||||
let chunk_key = "chunk-delete";
|
||||
@@ -316,7 +283,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(1).await?;
|
||||
let db = prepare_text_chunk_test_db(1).await?;
|
||||
|
||||
let user_id = "user_c";
|
||||
let source_id = "shared-source";
|
||||
@@ -377,7 +344,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||
let db = prepare_test_db(3).await?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user-upsert";
|
||||
let source_id = "source-upsert";
|
||||
|
||||
@@ -202,25 +202,7 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
ensure_runtime(&db, 1536)
|
||||
.await
|
||||
.with_context(|| "ensure runtime indexes".to_string())?;
|
||||
rebuild(&db)
|
||||
.await
|
||||
.with_context(|| "rebuild indexes".to_string())?;
|
||||
Ok(db)
|
||||
}
|
||||
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||
@@ -339,7 +321,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_patch_not_found() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let db = setup_test_db_with_runtime_indexes().await?;
|
||||
|
||||
let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db)
|
||||
.await
|
||||
@@ -412,7 +394,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let db = setup_test_db_with_runtime_indexes().await?;
|
||||
|
||||
let results = TextContent::search(&db, "hello", "user", 5)
|
||||
.await
|
||||
@@ -424,7 +406,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let db = setup_test_db_with_runtime_indexes().await?;
|
||||
let user_id = "search_user";
|
||||
|
||||
let matching = TextContent::new(
|
||||
@@ -450,9 +432,6 @@ mod tests {
|
||||
db.store_item(other_user)
|
||||
.await
|
||||
.with_context(|| "store other user".to_string())?;
|
||||
rebuild(&db)
|
||||
.await
|
||||
.with_context(|| "rebuild indexes".to_string())?;
|
||||
|
||||
let results = TextContent::search(&db, "rust", user_id, 5)
|
||||
.await
|
||||
|
||||
@@ -732,20 +732,7 @@ mod tests {
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to setup the migrations".to_string())?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_creation() -> anyhow::Result<()> {
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
//! Shared helpers for in-memory SurrealDB tests.
|
||||
#![cfg(any(test, feature = "test-utils"))]
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::{ensure_runtime, rebuild},
|
||||
types::{
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
system_settings::SystemSettings,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
},
|
||||
};
|
||||
|
||||
const TEST_NAMESPACE: &str = "test_ns";
|
||||
|
||||
/// Starts an in-memory database, applies migrations, and returns a client.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the database cannot be started or migrations fail.
|
||||
pub async fn setup_test_db() -> Result<SurrealDbClient> {
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(TEST_NAMESPACE, &database)
|
||||
.await
|
||||
.context("start in-memory surrealdb")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.context("apply migrations")?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Updates singleton [`SystemSettings`] embedding dimensions for tests.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if settings cannot be loaded or updated.
|
||||
pub async fn configure_embedding_dimension(db: &SurrealDbClient, dimension: u32) -> Result<()> {
|
||||
let mut settings = SystemSettings::get_current(db).await?;
|
||||
settings.embedding_dimensions = dimension;
|
||||
SystemSettings::update(db, settings).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Starts a test database and sets the embedding dimension in system settings.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if setup or settings update fails.
|
||||
pub async fn setup_test_db_with_embedding_dimension(dimension: u32) -> Result<SurrealDbClient> {
|
||||
let db = setup_test_db().await?;
|
||||
configure_embedding_dimension(&db, dimension).await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Prepares a database for text-chunk embedding tests at the given dimension.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||
pub async fn prepare_text_chunk_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||
.await
|
||||
.with_context(|| format!("set text chunk index dimension to {dimension}"))?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Prepares a database for knowledge-entity embedding tests at the given dimension.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||
pub async fn prepare_knowledge_entity_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||
.await
|
||||
.with_context(|| format!("set knowledge entity index dimension to {dimension}"))?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Starts a test database and ensures runtime FTS/HNSW indexes are ready.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if setup, index creation, or rebuild fails.
|
||||
pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
||||
let db = setup_test_db().await?;
|
||||
ensure_runtime(&db, 1536).await?;
|
||||
rebuild(&db).await?;
|
||||
Ok(db)
|
||||
}
|
||||
+131
-4
@@ -1,7 +1,8 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, sync::Once, str::FromStr};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::{env, fmt, str::FromStr, sync::Once};
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
@@ -35,6 +36,83 @@ impl EmbeddingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Error returned when parsing a retrieval strategy name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown retrieval strategy '{input}'")]
|
||||
pub struct ParseRetrievalStrategyError {
|
||||
/// The unrecognized input string.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Selects which retrieval pipeline strategy to run for chat and search.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
/// Primary hybrid chunk retrieval for search/chat.
|
||||
#[default]
|
||||
Default,
|
||||
/// Entity retrieval for suggesting relationships when creating manual entities.
|
||||
RelationshipSuggestion,
|
||||
/// Entity retrieval for context during content ingestion.
|
||||
Ingestion,
|
||||
/// Unified search returning both chunks and entities.
|
||||
Search,
|
||||
}
|
||||
|
||||
impl RetrievalStrategy {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Default => "default",
|
||||
Self::RelationshipSuggestion => "relationship_suggestion",
|
||||
Self::Ingestion => "ingestion",
|
||||
Self::Search => "search",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for RetrievalStrategy {
|
||||
type Err = ParseRetrievalStrategyError;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"default" => Ok(Self::Default),
|
||||
"initial" | "revised" => {
|
||||
warn!(
|
||||
"retrieval strategy '{value}' is deprecated; use 'default' instead"
|
||||
);
|
||||
Ok(Self::Default)
|
||||
}
|
||||
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
||||
"ingestion" => Ok(Self::Ingestion),
|
||||
"search" => Ok(Self::Search),
|
||||
other => Err(ParseRetrievalStrategyError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RetrievalStrategy {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_optional_retrieval_strategy<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Option<RetrievalStrategy>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Option::<String>::deserialize(deserializer)?;
|
||||
match value {
|
||||
None => Ok(None),
|
||||
Some(raw) if raw.trim().is_empty() => Ok(None),
|
||||
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for EmbeddingBackend {
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
@@ -117,8 +195,8 @@ pub struct AppConfig {
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub retrieval_strategy: Option<String>,
|
||||
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
|
||||
pub retrieval_strategy: Option<RetrievalStrategy>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
#[serde(default = "default_ingest_max_body_bytes")]
|
||||
@@ -204,6 +282,14 @@ pub fn ensure_ort_path() {
|
||||
});
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
|
||||
#[must_use]
|
||||
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
|
||||
self.retrieval_strategy.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -249,3 +335,44 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
|
||||
config.try_deserialize()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{ParseRetrievalStrategyError, RetrievalStrategy};

    /// The derived `Default` must select the `Default` variant.
    #[test]
    fn retrieval_strategy_defaults_to_default() {
        assert_eq!(
            RetrievalStrategy::default(),
            RetrievalStrategy::Default
        );
    }

    /// Serde output uses the snake_case variant name.
    #[test]
    fn retrieval_strategy_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
            "\"search\""
        );
    }

    /// Both deprecated names map to `Default`; unknown names are rejected.
    #[test]
    fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
        assert_eq!(
            "initial".parse::<RetrievalStrategy>().expect("initial"),
            RetrievalStrategy::Default
        );
        assert_eq!(
            "revised".parse::<RetrievalStrategy>().expect("revised"),
            RetrievalStrategy::Default
        );
        assert!(matches!(
            "unknown".parse::<RetrievalStrategy>(),
            Err(ParseRetrievalStrategyError { .. })
        ));
    }

    /// Every variant's `as_str` name parses back to the same variant.
    #[test]
    fn retrieval_strategy_round_trips_through_as_str() {
        let all = [
            RetrievalStrategy::Default,
            RetrievalStrategy::RelationshipSuggestion,
            RetrievalStrategy::Ingestion,
            RetrievalStrategy::Search,
        ];
        for strategy in all {
            assert_eq!(strategy.as_str().parse::<RetrievalStrategy>(), Ok(strategy));
        }
    }

    /// An unset strategy in the default config resolves to `Default`.
    #[test]
    fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
        let config = super::AppConfig::default();
        assert_eq!(
            config.resolved_retrieval_strategy(),
            RetrievalStrategy::Default
        );
    }
}
|
||||
|
||||
@@ -5,13 +5,12 @@ use std::{
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
error::{AppError, EmbeddingError},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
@@ -57,16 +56,18 @@ enum EmbeddingInner {
|
||||
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = model
|
||||
.lock()
|
||||
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
|
||||
guard.embed(texts, None)
|
||||
.map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
.context("joining fastembed embedding task")?
|
||||
.context("generating fastembed embeddings")
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(join_error) => Err(EmbeddingError::from(join_error)),
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
@@ -102,17 +103,14 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the backend API call fails, FastEmbed initialisation fails,
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
|
||||
/// or the backend returns no embedding data.
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
||||
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -130,7 +128,7 @@ impl EmbeddingProvider {
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
||||
.ok_or(EmbeddingError::NoData)?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
@@ -143,9 +141,9 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the backend API call fails or returns no embedding data.
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||
/// Returns an empty `Vec` when `texts` is empty.
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
@@ -185,11 +183,14 @@ impl EmbeddingProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Errors
|
||||
///
|
||||
/// Currently infallible; reserved for future validation.
|
||||
pub fn new_openai(
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
model: String,
|
||||
dimensions: u32,
|
||||
) -> Result<Self> {
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
Ok(Self {
|
||||
inner: EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -199,9 +200,12 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
||||
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
@@ -210,15 +214,21 @@ impl EmbeddingProvider {
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
||||
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model =
|
||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
||||
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
.context("joining FastEmbed initialisation task")??;
|
||||
{
|
||||
Ok(result) => result?,
|
||||
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||
};
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
@@ -229,7 +239,10 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
||||
/// # Errors
|
||||
///
|
||||
/// Currently infallible; reserved for future validation.
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::Hashed {
|
||||
dimension: dimension.max(1),
|
||||
@@ -242,24 +255,32 @@ impl EmbeddingProvider {
|
||||
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
|
||||
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
|
||||
/// persists the resolved backend to the database.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
|
||||
pub async fn from_system_settings(
|
||||
settings: &SystemSettings,
|
||||
config: &AppConfig,
|
||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
) -> Result<Self> {
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let dimensions = settings.embedding_dimensions;
|
||||
match config.embedding_backend {
|
||||
EmbeddingBackend::OpenAI => {
|
||||
let client = openai_client
|
||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
||||
let client = openai_client.ok_or_else(|| {
|
||||
EmbeddingError::Config(
|
||||
"openai embedding backend requires an openai client".into(),
|
||||
)
|
||||
})?;
|
||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => {
|
||||
let dimension = usize::try_from(dimensions)
|
||||
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
|
||||
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
|
||||
})?;
|
||||
Self::new_hashed(dimension)
|
||||
}
|
||||
}
|
||||
@@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
|
||||
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider
|
||||
.embed(input)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
provider.embed(input).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
|
||||
Reference in New Issue
Block a user