mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-30 03:10:45 +02:00
chore: centralize embedding errors, retrieval strategy, and test DB helpers.
Replace anyhow in embedding production code with EmbeddingError, move RetrievalStrategy into common config, and deduplicate Surreal test setup via common::test_utils.
This commit is contained in:
@@ -4,6 +4,36 @@ use tokio::task::JoinError;
|
|||||||
|
|
||||||
use crate::storage::types::file_info::FileError;
|
use crate::storage::types::file_info::FileError;
|
||||||
|
|
||||||
|
/// Errors from embedding provider operations.
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum EmbeddingError {
|
||||||
|
#[error("openai error: {0}")]
|
||||||
|
OpenAI(#[from] OpenAIError),
|
||||||
|
#[error("fastembed error: {0}")]
|
||||||
|
FastEmbed(String),
|
||||||
|
#[error("task join error: {0}")]
|
||||||
|
Join(#[from] JoinError),
|
||||||
|
#[error("fastembed model mutex poisoned: {0}")]
|
||||||
|
MutexPoisoned(String),
|
||||||
|
#[error("no embedding data received")]
|
||||||
|
NoData,
|
||||||
|
#[error("embedding configuration error: {0}")]
|
||||||
|
Config(String),
|
||||||
|
#[error("unknown fastembed model: {0}")]
|
||||||
|
UnknownModel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingError {
|
||||||
|
pub(crate) fn fastembed(err: impl std::fmt::Display) -> Self {
|
||||||
|
Self::FastEmbed(err.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn mutex_poisoned(err: impl std::fmt::Display) -> Self {
|
||||||
|
Self::MutexPoisoned(err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Core internal errors
|
// Core internal errors
|
||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
@@ -12,6 +42,8 @@ pub enum AppError {
|
|||||||
Database(#[from] surrealdb::Error),
|
Database(#[from] surrealdb::Error),
|
||||||
#[error("openai error: {0}")]
|
#[error("openai error: {0}")]
|
||||||
OpenAI(#[from] OpenAIError),
|
OpenAI(#[from] OpenAIError),
|
||||||
|
#[error("embedding error: {0}")]
|
||||||
|
Embedding(#[from] EmbeddingError),
|
||||||
#[error("file error: {0}")]
|
#[error("file error: {0}")]
|
||||||
File(#[from] FileError),
|
File(#[from] FileError),
|
||||||
#[error("not found: {0}")]
|
#[error("not found: {0}")]
|
||||||
|
|||||||
@@ -3,3 +3,6 @@
|
|||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-utils"))]
|
||||||
|
pub mod test_utils;
|
||||||
|
|||||||
@@ -499,8 +499,7 @@ impl KnowledgeEntity {
|
|||||||
|
|
||||||
let embedding = provider
|
let embedding = provider
|
||||||
.embed(&embedding_input)
|
.embed(&embedding_input)
|
||||||
.await
|
.await?;
|
||||||
.map_err(AppError::internal)?;
|
|
||||||
|
|
||||||
// Safety check: ensure the generated embedding has the correct dimension.
|
// Safety check: ensure the generated embedding has the correct dimension.
|
||||||
if embedding.len() != new_dimensions {
|
if embedding.len() != new_dimensions {
|
||||||
@@ -599,21 +598,11 @@ mod tests {
|
|||||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||||
use crate::storage::types::system_settings::SystemSettings;
|
use crate::test_utils::configure_embedding_dimension;
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
async fn configure_embedding_dimension(
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
dimension: u32,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let mut settings = SystemSettings::get_current(db).await?;
|
|
||||||
settings.embedding_dimensions = dimension;
|
|
||||||
SystemSettings::update(db, settings).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
|
async fn test_knowledge_entity_creation() -> anyhow::Result<()> {
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
|
|||||||
@@ -158,45 +158,11 @@ impl KnowledgeEntityEmbedding {
|
|||||||
mod tests {
|
mod tests {
|
||||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
|
||||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
use crate::storage::types::system_settings::SystemSettings;
|
use crate::test_utils::{prepare_knowledge_entity_test_db, setup_test_db};
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use surrealdb::Value as SurrealValue;
|
use surrealdb::Value as SurrealValue;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn setup_test_db_with_embedding_dimension(
|
|
||||||
dimension: u32,
|
|
||||||
) -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let db = setup_test_db().await?;
|
|
||||||
let mut settings = SystemSettings::get_current(&db).await?;
|
|
||||||
settings.embedding_dimensions = dimension;
|
|
||||||
SystemSettings::update(&db, settings).await?;
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("set test index dimension to {dimension}"))?;
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_knowledge_entity_with_id(
|
fn build_knowledge_entity_with_id(
|
||||||
key: &str,
|
key: &str,
|
||||||
@@ -236,7 +202,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let entity_key = "entity-1";
|
let entity_key = "entity-1";
|
||||||
let source_id = "source-ke";
|
let source_id = "source-ke";
|
||||||
@@ -266,7 +232,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let entity_key = "entity-delete";
|
let entity_key = "entity-delete";
|
||||||
let source_id = "source-del";
|
let source_id = "source-del";
|
||||||
@@ -298,7 +264,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_store";
|
let user_id = "user_store";
|
||||||
let source_id = "source_store";
|
let source_id = "source_store";
|
||||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||||
@@ -331,7 +297,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
|
||||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||||
let result =
|
let result =
|
||||||
@@ -344,7 +310,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let source_id = "shared-ke";
|
let source_id = "shared-ke";
|
||||||
let other_source = "other-ke";
|
let other_source = "other-ke";
|
||||||
@@ -437,7 +403,7 @@ mod tests {
|
|||||||
entity_id: KnowledgeEntity,
|
entity_id: KnowledgeEntity,
|
||||||
}
|
}
|
||||||
|
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let entity_key = "entity-fetch";
|
let entity_key = "entity-fetch";
|
||||||
let source_id = "source-fetch";
|
let source_id = "source-fetch";
|
||||||
@@ -475,7 +441,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
|
||||||
let user_id = "user-upsert";
|
let user_id = "user-upsert";
|
||||||
let source_id = "source-upsert";
|
let source_id = "source-upsert";
|
||||||
|
|||||||
@@ -151,19 +151,7 @@ mod tests {
|
|||||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
async fn setup_test_db() -> SurrealDbClient {
|
use crate::test_utils::setup_test_db;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to apply migrations");
|
|
||||||
|
|
||||||
db
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_relationship_by_id(
|
async fn get_relationship_by_id(
|
||||||
relationship_id: &str,
|
relationship_id: &str,
|
||||||
@@ -234,7 +222,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
@@ -282,7 +270,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
|
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
|
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
|
||||||
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
|
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
|
||||||
@@ -303,7 +291,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
@@ -342,7 +330,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
@@ -396,7 +384,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let owner_user_id = "owner-user";
|
let owner_user_id = "owner-user";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
|
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
|
||||||
@@ -459,7 +447,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
@@ -543,7 +531,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
|
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
let user_a = "user-a";
|
let user_a = "user-a";
|
||||||
let user_b = "user-b";
|
let user_b = "user-b";
|
||||||
@@ -584,7 +572,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
|
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
|
||||||
{
|
{
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
|||||||
@@ -384,8 +384,7 @@ impl TextChunk {
|
|||||||
|
|
||||||
let embedding = provider
|
let embedding = provider
|
||||||
.embed(&chunk.chunk)
|
.embed(&chunk.chunk)
|
||||||
.await
|
.await?;
|
||||||
.map_err(AppError::internal)?;
|
|
||||||
|
|
||||||
// Safety check: ensure the generated embedding has the correct dimension.
|
// Safety check: ensure the generated embedding has the correct dimension.
|
||||||
if embedding.len() != new_dimensions {
|
if embedding.len() != new_dimensions {
|
||||||
@@ -489,21 +488,11 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||||
use crate::storage::types::system_settings::SystemSettings;
|
|
||||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||||
|
use crate::test_utils::configure_embedding_dimension;
|
||||||
use surrealdb::RecordId;
|
use surrealdb::RecordId;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
async fn configure_embedding_dimension(
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
dimension: u32,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let mut settings = SystemSettings::get_current(db).await?;
|
|
||||||
settings.embedding_dimensions = dimension;
|
|
||||||
SystemSettings::update(db, settings).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||||
let snowball_sql = r#"
|
let snowball_sql = r#"
|
||||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||||
|
|||||||
@@ -144,41 +144,8 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
use crate::storage::db::SurrealDbClient;
|
||||||
use crate::storage::types::system_settings::SystemSettings;
|
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||||
use surrealdb::Value as SurrealValue;
|
use surrealdb::Value as SurrealValue;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn setup_test_db_with_embedding_dimension(
|
|
||||||
dimension: u32,
|
|
||||||
) -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let db = setup_test_db().await?;
|
|
||||||
let mut settings = SystemSettings::get_current(&db).await?;
|
|
||||||
settings.embedding_dimensions = dimension;
|
|
||||||
SystemSettings::update(&db, settings).await?;
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn prepare_test_db(dimension: u32) -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("set test index dimension to {dimension}"))?;
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_text_chunk_with_id(
|
async fn create_text_chunk_with_id(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
@@ -245,7 +212,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
let user_id = "user_a";
|
let user_id = "user_a";
|
||||||
let chunk_key = "chunk-123";
|
let chunk_key = "chunk-123";
|
||||||
@@ -279,7 +246,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
let user_id = "user_b";
|
let user_id = "user_b";
|
||||||
let chunk_key = "chunk-delete";
|
let chunk_key = "chunk-delete";
|
||||||
@@ -316,7 +283,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(1).await?;
|
let db = prepare_text_chunk_test_db(1).await?;
|
||||||
|
|
||||||
let user_id = "user_c";
|
let user_id = "user_c";
|
||||||
let source_id = "shared-source";
|
let source_id = "shared-source";
|
||||||
@@ -377,7 +344,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||||
let db = prepare_test_db(3).await?;
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
let user_id = "user-upsert";
|
let user_id = "user-upsert";
|
||||||
let source_id = "source-upsert";
|
let source_id = "source-upsert";
|
||||||
|
|||||||
@@ -202,25 +202,7 @@ mod tests {
|
|||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
||||||
|
|
||||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
ensure_runtime(&db, 1536)
|
|
||||||
.await
|
|
||||||
.with_context(|| "ensure runtime indexes".to_string())?;
|
|
||||||
rebuild(&db)
|
|
||||||
.await
|
|
||||||
.with_context(|| "rebuild indexes".to_string())?;
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||||
@@ -339,7 +321,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_patch_not_found() -> anyhow::Result<()> {
|
async fn test_text_content_patch_not_found() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await?;
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
|
||||||
let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db)
|
let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db)
|
||||||
.await
|
.await
|
||||||
@@ -412,7 +394,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> {
|
async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await?;
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
|
||||||
let results = TextContent::search(&db, "hello", "user", 5)
|
let results = TextContent::search(&db, "hello", "user", 5)
|
||||||
.await
|
.await
|
||||||
@@ -424,7 +406,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> {
|
async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await?;
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
let user_id = "search_user";
|
let user_id = "search_user";
|
||||||
|
|
||||||
let matching = TextContent::new(
|
let matching = TextContent::new(
|
||||||
@@ -450,9 +432,6 @@ mod tests {
|
|||||||
db.store_item(other_user)
|
db.store_item(other_user)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store other user".to_string())?;
|
.with_context(|| "store other user".to_string())?;
|
||||||
rebuild(&db)
|
|
||||||
.await
|
|
||||||
.with_context(|| "rebuild indexes".to_string())?;
|
|
||||||
|
|
||||||
let results = TextContent::search(&db, "rust", user_id, 5)
|
let results = TextContent::search(&db, "rust", user_id, 5)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -732,20 +732,7 @@ mod tests {
|
|||||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
// Helper function to set up a test database with SystemSettings
|
use crate::test_utils::setup_test_db;
|
||||||
async fn setup_test_db() -> anyhow::Result<SurrealDbClient> {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to setup the migrations".to_string())?;
|
|
||||||
|
|
||||||
Ok(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_user_creation() -> anyhow::Result<()> {
|
async fn test_user_creation() -> anyhow::Result<()> {
|
||||||
|
|||||||
@@ -0,0 +1,96 @@
|
|||||||
|
//! Shared helpers for in-memory SurrealDB tests.
|
||||||
|
#![cfg(any(test, feature = "test-utils"))]
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::storage::{
|
||||||
|
db::SurrealDbClient,
|
||||||
|
indexes::{ensure_runtime, rebuild},
|
||||||
|
types::{
|
||||||
|
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||||
|
system_settings::SystemSettings,
|
||||||
|
text_chunk_embedding::TextChunkEmbedding,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const TEST_NAMESPACE: &str = "test_ns";
|
||||||
|
|
||||||
|
/// Starts an in-memory database, applies migrations, and returns a client.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the database cannot be started or migrations fail.
|
||||||
|
pub async fn setup_test_db() -> Result<SurrealDbClient> {
|
||||||
|
let database = Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(TEST_NAMESPACE, &database)
|
||||||
|
.await
|
||||||
|
.context("start in-memory surrealdb")?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.context("apply migrations")?;
|
||||||
|
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates singleton [`SystemSettings`] embedding dimensions for tests.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if settings cannot be loaded or updated.
|
||||||
|
pub async fn configure_embedding_dimension(db: &SurrealDbClient, dimension: u32) -> Result<()> {
|
||||||
|
let mut settings = SystemSettings::get_current(db).await?;
|
||||||
|
settings.embedding_dimensions = dimension;
|
||||||
|
SystemSettings::update(db, settings).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a test database and sets the embedding dimension in system settings.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup or settings update fails.
|
||||||
|
pub async fn setup_test_db_with_embedding_dimension(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
configure_embedding_dimension(&db, dimension).await?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepares a database for text-chunk embedding tests at the given dimension.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||||
|
pub async fn prepare_text_chunk_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||||
|
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("set text chunk index dimension to {dimension}"))?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepares a database for knowledge-entity embedding tests at the given dimension.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||||
|
pub async fn prepare_knowledge_entity_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||||
|
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("set knowledge entity index dimension to {dimension}"))?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a test database and ensures runtime FTS/HNSW indexes are ready.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, index creation, or rebuild fails.
|
||||||
|
pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
ensure_runtime(&db, 1536).await?;
|
||||||
|
rebuild(&db).await?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
+131
-4
@@ -1,7 +1,8 @@
|
|||||||
use config::{Config, ConfigError, Environment, File};
|
use config::{Config, ConfigError, Environment, File};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
use std::{env, sync::Once, str::FromStr};
|
use std::{env, fmt, str::FromStr, sync::Once};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
/// Error returned when parsing an embedding backend name.
|
/// Error returned when parsing an embedding backend name.
|
||||||
#[derive(Debug, Error, PartialEq, Eq)]
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
@@ -35,6 +36,83 @@ impl EmbeddingBackend {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Error returned when parsing a retrieval strategy name.
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
|
#[error("unknown retrieval strategy '{input}'")]
|
||||||
|
pub struct ParseRetrievalStrategyError {
|
||||||
|
/// The unrecognized input string.
|
||||||
|
pub input: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Selects which retrieval pipeline strategy to run for chat and search.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum RetrievalStrategy {
|
||||||
|
/// Primary hybrid chunk retrieval for search/chat.
|
||||||
|
#[default]
|
||||||
|
Default,
|
||||||
|
/// Entity retrieval for suggesting relationships when creating manual entities.
|
||||||
|
RelationshipSuggestion,
|
||||||
|
/// Entity retrieval for context during content ingestion.
|
||||||
|
Ingestion,
|
||||||
|
/// Unified search returning both chunks and entities.
|
||||||
|
Search,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetrievalStrategy {
|
||||||
|
#[must_use]
|
||||||
|
pub fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Default => "default",
|
||||||
|
Self::RelationshipSuggestion => "relationship_suggestion",
|
||||||
|
Self::Ingestion => "ingestion",
|
||||||
|
Self::Search => "search",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for RetrievalStrategy {
|
||||||
|
type Err = ParseRetrievalStrategyError;
|
||||||
|
|
||||||
|
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||||
|
match value.to_ascii_lowercase().as_str() {
|
||||||
|
"default" => Ok(Self::Default),
|
||||||
|
"initial" | "revised" => {
|
||||||
|
warn!(
|
||||||
|
"retrieval strategy '{value}' is deprecated; use 'default' instead"
|
||||||
|
);
|
||||||
|
Ok(Self::Default)
|
||||||
|
}
|
||||||
|
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
||||||
|
"ingestion" => Ok(Self::Ingestion),
|
||||||
|
"search" => Ok(Self::Search),
|
||||||
|
other => Err(ParseRetrievalStrategyError {
|
||||||
|
input: other.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for RetrievalStrategy {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.write_str(self.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_optional_retrieval_strategy<'de, D>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> Result<Option<RetrievalStrategy>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Option::<String>::deserialize(deserializer)?;
|
||||||
|
match value {
|
||||||
|
None => Ok(None),
|
||||||
|
Some(raw) if raw.trim().is_empty() => Ok(None),
|
||||||
|
Some(raw) => RetrievalStrategy::from_str(&raw).map(Some).map_err(serde::de::Error::custom),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl FromStr for EmbeddingBackend {
|
impl FromStr for EmbeddingBackend {
|
||||||
type Err = ParseEmbeddingBackendError;
|
type Err = ParseEmbeddingBackendError;
|
||||||
|
|
||||||
@@ -117,8 +195,8 @@ pub struct AppConfig {
|
|||||||
pub fastembed_show_download_progress: Option<bool>,
|
pub fastembed_show_download_progress: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub fastembed_max_length: Option<usize>,
|
pub fastembed_max_length: Option<usize>,
|
||||||
#[serde(default)]
|
#[serde(default, deserialize_with = "deserialize_optional_retrieval_strategy")]
|
||||||
pub retrieval_strategy: Option<String>,
|
pub retrieval_strategy: Option<RetrievalStrategy>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub embedding_backend: EmbeddingBackend,
|
pub embedding_backend: EmbeddingBackend,
|
||||||
#[serde(default = "default_ingest_max_body_bytes")]
|
#[serde(default = "default_ingest_max_body_bytes")]
|
||||||
@@ -204,6 +282,14 @@ pub fn ensure_ort_path() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AppConfig {
|
||||||
|
/// Returns the configured retrieval strategy, or [`RetrievalStrategy::Default`] when unset.
|
||||||
|
#[must_use]
|
||||||
|
pub fn resolved_retrieval_strategy(&self) -> RetrievalStrategy {
|
||||||
|
self.retrieval_strategy.unwrap_or_default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for AppConfig {
|
impl Default for AppConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -249,3 +335,44 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
|||||||
|
|
||||||
config.try_deserialize()
|
config.try_deserialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{ParseRetrievalStrategyError, RetrievalStrategy};
|
||||||
|
#[test]
|
||||||
|
fn retrieval_strategy_defaults_to_default() {
|
||||||
|
assert_eq!(
|
||||||
|
RetrievalStrategy::default(),
|
||||||
|
RetrievalStrategy::Default
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn retrieval_strategy_serializes_snake_case() {
|
||||||
|
assert_eq!(
|
||||||
|
serde_json::to_string(&RetrievalStrategy::Search).expect("serialize"),
|
||||||
|
"\"search\""
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn retrieval_strategy_from_str_accepts_deprecated_aliases() {
|
||||||
|
assert_eq!(
|
||||||
|
"initial".parse::<RetrievalStrategy>().expect("initial"),
|
||||||
|
RetrievalStrategy::Default
|
||||||
|
);
|
||||||
|
assert!(matches!(
|
||||||
|
"unknown".parse::<RetrievalStrategy>(),
|
||||||
|
Err(ParseRetrievalStrategyError { .. })
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn app_config_resolved_retrieval_strategy_uses_default_when_unset() {
|
||||||
|
let config = super::AppConfig::default();
|
||||||
|
assert_eq!(
|
||||||
|
config.resolved_retrieval_strategy(),
|
||||||
|
RetrievalStrategy::Default
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ use std::{
|
|||||||
sync::{Arc, Mutex},
|
sync::{Arc, Mutex},
|
||||||
};
|
};
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
|
||||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::AppError,
|
error::{AppError, EmbeddingError},
|
||||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||||
utils::config::AppConfig,
|
utils::config::AppConfig,
|
||||||
};
|
};
|
||||||
@@ -57,16 +56,18 @@ enum EmbeddingInner {
|
|||||||
async fn run_fastembed(
|
async fn run_fastembed(
|
||||||
model: Arc<Mutex<TextEmbedding>>,
|
model: Arc<Mutex<TextEmbedding>>,
|
||||||
texts: Vec<String>,
|
texts: Vec<String>,
|
||||||
) -> Result<Vec<Vec<f32>>> {
|
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
tokio::task::spawn_blocking(move || {
|
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
let mut guard = model
|
let mut guard = model
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|e| anyhow!("fastembed model mutex poisoned: {e}"))?;
|
.map_err(EmbeddingError::mutex_poisoned)?;
|
||||||
guard.embed(texts, None)
|
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.context("joining fastembed embedding task")?
|
{
|
||||||
.context("generating fastembed embeddings")
|
Ok(result) => result,
|
||||||
|
Err(join_error) => Err(EmbeddingError::from(join_error)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddingProvider {
|
impl EmbeddingProvider {
|
||||||
@@ -102,17 +103,14 @@ impl EmbeddingProvider {
|
|||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns `Err` if the backend API call fails, FastEmbed initialisation fails,
|
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
|
||||||
/// or the backend returns no embedding data.
|
/// or the backend returns no embedding data.
|
||||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||||
EmbeddingInner::FastEmbed { model, .. } => {
|
EmbeddingInner::FastEmbed { model, .. } => {
|
||||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||||
embeddings
|
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
|
||||||
}
|
}
|
||||||
EmbeddingInner::OpenAI {
|
EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
@@ -130,7 +128,7 @@ impl EmbeddingProvider {
|
|||||||
let embedding = response
|
let embedding = response
|
||||||
.data
|
.data
|
||||||
.first()
|
.first()
|
||||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
.ok_or(EmbeddingError::NoData)?
|
||||||
.embedding
|
.embedding
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
@@ -143,9 +141,9 @@ impl EmbeddingProvider {
|
|||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns `Err` if the backend API call fails or returns no embedding data.
|
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||||
/// Returns an empty `Vec` when `texts` is empty.
|
/// Returns an empty `Vec` when `texts` is empty.
|
||||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -185,11 +183,14 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Currently infallible; reserved for future validation.
|
||||||
pub fn new_openai(
|
pub fn new_openai(
|
||||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||||
model: String,
|
model: String,
|
||||||
dimensions: u32,
|
dimensions: u32,
|
||||||
) -> Result<Self> {
|
) -> Result<Self, EmbeddingError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: EmbeddingInner::OpenAI {
|
inner: EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
@@ -199,9 +200,12 @@ impl EmbeddingProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||||
|
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
|
||||||
let model_name = if let Some(code) = model_override {
|
let model_name = if let Some(code) = model_override {
|
||||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
|
||||||
} else {
|
} else {
|
||||||
EmbeddingModel::default()
|
EmbeddingModel::default()
|
||||||
};
|
};
|
||||||
@@ -210,15 +214,21 @@ impl EmbeddingProvider {
|
|||||||
let model_name_for_task = model_name.clone();
|
let model_name_for_task = model_name.clone();
|
||||||
let model_name_code = model_name.to_string();
|
let model_name_code = model_name.to_string();
|
||||||
|
|
||||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||||
let model =
|
let model =
|
||||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
EmbeddingError::Config(format!(
|
||||||
|
"fastembed model metadata missing for {model_name_code}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
Ok((model, info.dim))
|
Ok((model, info.dim))
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.context("joining FastEmbed initialisation task")??;
|
{
|
||||||
|
Ok(result) => result?,
|
||||||
|
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||||
|
};
|
||||||
|
|
||||||
Ok(EmbeddingProvider {
|
Ok(EmbeddingProvider {
|
||||||
inner: EmbeddingInner::FastEmbed {
|
inner: EmbeddingInner::FastEmbed {
|
||||||
@@ -229,7 +239,10 @@ impl EmbeddingProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Currently infallible; reserved for future validation.
|
||||||
|
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
|
||||||
Ok(EmbeddingProvider {
|
Ok(EmbeddingProvider {
|
||||||
inner: EmbeddingInner::Hashed {
|
inner: EmbeddingInner::Hashed {
|
||||||
dimension: dimension.max(1),
|
dimension: dimension.max(1),
|
||||||
@@ -242,24 +255,32 @@ impl EmbeddingProvider {
|
|||||||
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
|
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
|
||||||
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
|
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
|
||||||
/// persists the resolved backend to the database.
|
/// persists the resolved backend to the database.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
|
||||||
pub async fn from_system_settings(
|
pub async fn from_system_settings(
|
||||||
settings: &SystemSettings,
|
settings: &SystemSettings,
|
||||||
config: &AppConfig,
|
config: &AppConfig,
|
||||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self, EmbeddingError> {
|
||||||
let dimensions = settings.embedding_dimensions;
|
let dimensions = settings.embedding_dimensions;
|
||||||
match config.embedding_backend {
|
match config.embedding_backend {
|
||||||
EmbeddingBackend::OpenAI => {
|
EmbeddingBackend::OpenAI => {
|
||||||
let client = openai_client
|
let client = openai_client.ok_or_else(|| {
|
||||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
EmbeddingError::Config(
|
||||||
|
"openai embedding backend requires an openai client".into(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||||
}
|
}
|
||||||
EmbeddingBackend::FastEmbed => {
|
EmbeddingBackend::FastEmbed => {
|
||||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||||
}
|
}
|
||||||
EmbeddingBackend::Hashed => {
|
EmbeddingBackend::Hashed => {
|
||||||
let dimension = usize::try_from(dimensions)
|
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||||
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
|
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
|
||||||
|
})?;
|
||||||
Self::new_hashed(dimension)
|
Self::new_hashed(dimension)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -312,15 +333,12 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
|||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns [`AppError::InternalError`] if the provider's embed call fails.
|
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
|
||||||
pub async fn generate_embedding_with_provider(
|
pub async fn generate_embedding_with_provider(
|
||||||
provider: &EmbeddingProvider,
|
provider: &EmbeddingProvider,
|
||||||
input: &str,
|
input: &str,
|
||||||
) -> Result<Vec<f32>, AppError> {
|
) -> Result<Vec<f32>, AppError> {
|
||||||
provider
|
provider.embed(input).await.map_err(Into::into)
|
||||||
.embed(input)
|
|
||||||
.await
|
|
||||||
.map_err(AppError::internal)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ use common::storage::types::conversation::SidebarConversation;
|
|||||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||||
use common::utils::embedding::EmbeddingProvider;
|
use common::utils::embedding::EmbeddingProvider;
|
||||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
use common::{
|
||||||
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
|
create_template_engine, storage::db::ProvidesDb,
|
||||||
|
utils::config::{AppConfig, RetrievalStrategy},
|
||||||
|
};
|
||||||
|
use retrieval_pipeline::reranking::RerankerPool;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
@@ -16,6 +19,7 @@ use tracing::debug;
|
|||||||
use crate::{OpenAIClientType, SessionStoreType};
|
use crate::{OpenAIClientType, SessionStoreType};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
/// Shared application state for HTML handlers and middleware.
|
||||||
pub struct HtmlState {
|
pub struct HtmlState {
|
||||||
pub db: Arc<SurrealDbClient>,
|
pub db: Arc<SurrealDbClient>,
|
||||||
pub openai_client: Arc<OpenAIClientType>,
|
pub openai_client: Arc<OpenAIClientType>,
|
||||||
@@ -31,7 +35,7 @@ pub struct HtmlState {
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct ConversationArchiveCacheEntry {
|
struct ConversationArchiveCacheEntry {
|
||||||
conversations: Vec<SidebarConversation>,
|
conversations: Arc<[SidebarConversation]>,
|
||||||
expires_at: Instant,
|
expires_at: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,23 +76,19 @@ impl HtmlState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
|
pub fn retrieval_strategy(&self) -> RetrievalStrategy {
|
||||||
self.config
|
self.config.resolved_retrieval_strategy()
|
||||||
.retrieval_strategy
|
|
||||||
.as_deref()
|
|
||||||
.and_then(|value| value.parse().ok())
|
|
||||||
.unwrap_or(RetrievalStrategy::Default)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_cached_conversation_archive(
|
pub async fn get_cached_conversation_archive(
|
||||||
&self,
|
&self,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Option<Vec<SidebarConversation>> {
|
) -> Option<Arc<[SidebarConversation]>> {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let should_evict_expired = {
|
let should_evict_expired = {
|
||||||
let cache = self.conversation_archive_cache.read().await;
|
let cache = self.conversation_archive_cache.read().await;
|
||||||
if let Some(entry) = cache.get(user_id) {
|
if let Some(entry) = cache.get(user_id) {
|
||||||
if entry.expires_at > now {
|
if entry.expires_at > now {
|
||||||
return Some(entry.conversations.clone());
|
return Some(Arc::clone(&entry.conversations));
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
@@ -107,7 +107,7 @@ impl HtmlState {
|
|||||||
pub async fn set_cached_conversation_archive(
|
pub async fn set_cached_conversation_archive(
|
||||||
&self,
|
&self,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
conversations: Vec<SidebarConversation>,
|
conversations: Arc<[SidebarConversation]>,
|
||||||
) {
|
) {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let mut cache = self.conversation_archive_cache.write().await;
|
let mut cache = self.conversation_archive_cache.write().await;
|
||||||
@@ -235,10 +235,10 @@ mod tests {
|
|||||||
cache.insert(
|
cache.insert(
|
||||||
user_id.to_string(),
|
user_id.to_string(),
|
||||||
ConversationArchiveCacheEntry {
|
ConversationArchiveCacheEntry {
|
||||||
conversations: vec![SidebarConversation {
|
conversations: Arc::from([SidebarConversation {
|
||||||
id: "conv-1".to_string(),
|
id: "conv-1".to_string(),
|
||||||
title: "A stale chat".to_string(),
|
title: "A stale chat".to_string(),
|
||||||
}],
|
}]),
|
||||||
expires_at: Instant::now() - Duration::from_secs(1),
|
expires_at: Instant::now() - Duration::from_secs(1),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ use crate::{
|
|||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{HtmlError, TemplateResponse},
|
response_middleware::{HtmlError, TemplateResponse},
|
||||||
},
|
},
|
||||||
|
utils::truncate::{first_non_empty_line, truncate_with_ellipsis},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Serde deserialization decorator to map empty Strings to None,
|
/// Serde deserialization decorator to map empty Strings to None,
|
||||||
@@ -41,31 +42,6 @@ fn source_id_suffix(source_id: &str) -> String {
|
|||||||
source_id[start..].to_string()
|
source_id[start..].to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn truncate_label(value: &str, max_chars: usize) -> String {
|
|
||||||
let mut end = None;
|
|
||||||
for (count, (idx, _)) in value.char_indices().enumerate() {
|
|
||||||
if count == max_chars {
|
|
||||||
end = Some(idx);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match end {
|
|
||||||
Some(idx) => format!("{}...", &value[..idx]),
|
|
||||||
None => value.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn first_non_empty_line(text: &str, max_chars: usize) -> Option<String> {
|
|
||||||
for line in text.lines() {
|
|
||||||
let trimmed = line.trim();
|
|
||||||
if !trimmed.is_empty() {
|
|
||||||
return Some(truncate_label(trimmed, max_chars));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct UrlInfoLabel {
|
struct UrlInfoLabel {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -121,7 +97,7 @@ fn build_source_label(row: &SourceLabelRow) -> String {
|
|||||||
if let Some(context) = row.context.as_ref() {
|
if let Some(context) = row.context.as_ref() {
|
||||||
let trimmed = context.trim();
|
let trimmed = context.trim();
|
||||||
if !trimmed.is_empty() {
|
if !trimmed.is_empty() {
|
||||||
return truncate_label(trimmed, MAX_LABEL_CHARS);
|
return truncate_with_ellipsis(trimmed, MAX_LABEL_CHARS);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +107,7 @@ fn build_source_label(row: &SourceLabelRow) -> String {
|
|||||||
|
|
||||||
let category = row.category.trim();
|
let category = row.category.trim();
|
||||||
if !category.is_empty() {
|
if !category.is_empty() {
|
||||||
return truncate_label(category, MAX_LABEL_CHARS);
|
return truncate_with_ellipsis(category, MAX_LABEL_CHARS);
|
||||||
}
|
}
|
||||||
|
|
||||||
format!("Text snippet: {}", source_id_suffix(&row.id))
|
format!("Text snippet: {}", source_id_suffix(&row.id))
|
||||||
|
|||||||
@@ -158,10 +158,7 @@ async fn create_single_entity(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let embedding = if let Some(provider) = embedding_provider {
|
let embedding = if let Some(provider) = embedding_provider {
|
||||||
provider
|
provider.embed(&embedding_input).await?
|
||||||
.embed(&embedding_input)
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::InternalError(format!("FastEmbed embedding for entity failed: {e}")))?
|
|
||||||
} else {
|
} else {
|
||||||
generate_embedding(openai_client, &embedding_input, db_client).await?
|
generate_embedding(openai_client, &embedding_input, db_client).await?
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,21 +1,8 @@
|
|||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
use crate::scoring::FusionWeights;
|
use crate::scoring::FusionWeights;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
pub use common::utils::config::RetrievalStrategy;
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum RetrievalStrategy {
|
|
||||||
/// Primary hybrid chunk retrieval for search/chat (formerly Revised)
|
|
||||||
#[default]
|
|
||||||
Default,
|
|
||||||
/// Entity retrieval for suggesting relationships when creating manual entities
|
|
||||||
RelationshipSuggestion,
|
|
||||||
/// Entity retrieval for context during content ingestion
|
|
||||||
Ingestion,
|
|
||||||
/// Unified search returning both chunks and entities
|
|
||||||
Search,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Configures which result types to include in Search strategy
|
/// Configures which result types to include in Search strategy
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
@@ -30,41 +17,6 @@ pub enum SearchTarget {
|
|||||||
Both,
|
Both,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::str::FromStr for RetrievalStrategy {
|
|
||||||
type Err = String;
|
|
||||||
|
|
||||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
|
||||||
match value.to_ascii_lowercase().as_str() {
|
|
||||||
"default" => Ok(Self::Default),
|
|
||||||
// Backward compatibility: treat "initial" and "revised" as "default"
|
|
||||||
"initial" | "revised" => {
|
|
||||||
tracing::warn!(
|
|
||||||
"Retrieval strategy '{}' is deprecated. Use 'default' instead. \
|
|
||||||
The 'initial' strategy has been removed in favor of the simpler hybrid chunk retrieval.",
|
|
||||||
value
|
|
||||||
);
|
|
||||||
Ok(Self::Default)
|
|
||||||
}
|
|
||||||
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
|
|
||||||
"ingestion" => Ok(Self::Ingestion),
|
|
||||||
"search" => Ok(Self::Search),
|
|
||||||
other => Err(format!("unknown retrieval strategy '{other}'")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for RetrievalStrategy {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
let label = match self {
|
|
||||||
RetrievalStrategy::Default => "default",
|
|
||||||
RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion",
|
|
||||||
RetrievalStrategy::Ingestion => "ingestion",
|
|
||||||
RetrievalStrategy::Search => "search",
|
|
||||||
};
|
|
||||||
f.write_str(label)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Two-variant flag that serializes as a bool for backward compatibility.
|
/// Two-variant flag that serializes as a bool for backward compatibility.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||||
pub enum BoolFlag {
|
pub enum BoolFlag {
|
||||||
|
|||||||
@@ -256,11 +256,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
|||||||
} else {
|
} else {
|
||||||
debug!("Generating query embedding for hybrid retrieval");
|
debug!("Generating query embedding for hybrid retrieval");
|
||||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||||
provider.embed(&ctx.input_text).await.map_err(|e| {
|
provider.embed(&ctx.input_text).await?
|
||||||
AppError::InternalError(format!(
|
|
||||||
"Failed to generate embedding with provider: {e}",
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
} else {
|
} else {
|
||||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user