From c53ec8c0a1b11aa2cc47f0c621d254aa27c16013 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sat, 6 Jun 2026 23:05:53 +0200 Subject: [PATCH] fix: arc-share retrieved chunks, centralize entity embeddings, and trim hot-path clones. --- api-router/src/middleware_api_auth.rs | 7 +- api-router/tests/api_router_integration.rs | 14 +--- common/src/storage/types/knowledge_entity.rs | 59 +++++++++++-- .../storage/types/knowledge_relationship.rs | 15 ++-- common/src/storage/types/text_chunk.rs | 6 +- common/src/storage/types/text_content.rs | 3 +- common/src/utils/embedding.rs | 39 +++++---- evaluations/src/pipeline/stages/prepare_db.rs | 14 ++-- evaluations/src/types.rs | 8 +- html-router/src/html_state.rs | 4 +- .../src/middlewares/response_middleware.rs | 4 +- html-router/src/router_factory.rs | 4 +- html-router/src/routes/account/handlers.rs | 4 +- html-router/src/routes/admin/handlers.rs | 49 +++++------ html-router/src/routes/auth/signup.rs | 5 +- html-router/src/routes/chat/chat_handlers.rs | 4 +- .../routes/chat/message_response_stream.rs | 7 +- html-router/src/routes/chat/mod.rs | 10 +-- html-router/src/routes/chat/references.rs | 4 +- html-router/src/routes/index/handlers.rs | 5 +- html-router/src/routes/ingestion/handlers.rs | 7 +- html-router/src/routes/knowledge/handlers.rs | 32 +++---- html-router/src/routes/scratchpad/handlers.rs | 2 +- html-router/src/routes/search/handlers.rs | 25 +++--- html-router/src/routes/search/mod.rs | 4 +- html-router/src/utils/pagination.rs | 18 +++- html-router/tests/router_integration.rs | 21 ++--- .../src/pipeline/enrichment_result.rs | 67 ++++++++------- .../src/pipeline/persistence.rs | 8 +- ingestion-pipeline/src/pipeline/stages.rs | 2 +- ingestion-pipeline/src/pipeline/tests.rs | 2 +- .../src/utils/file_text_extraction.rs | 5 +- main/src/bootstrap/mod.rs | 5 +- main/src/bootstrap/startup.rs | 28 ++++--- main/src/main.rs | 6 +- main/src/worker.rs | 5 +- retrieval-pipeline/src/lib.rs | 8 +- retrieval-pipeline/src/pipeline/context.rs | 4 +- retrieval-pipeline/src/pipeline/stages.rs | 84 +++++++++++++------ retrieval-pipeline/src/reranking.rs | 17 ++-- retrieval-pipeline/src/scoring.rs | 42 ++++++++-- 41 files changed, 368 insertions(+), 289 deletions(-) diff --git a/api-router/src/middleware_api_auth.rs b/api-router/src/middleware_api_auth.rs index 69369d3..92ad698 100644 --- a/api-router/src/middleware_api_auth.rs +++ b/api-router/src/middleware_api_auth.rs @@ -92,9 +92,10 @@ mod tests { #[test] fn extract_api_key_rejects_invalid_header_values() { let mut request = request_with_headers(&[]); - request - .headers_mut() - .insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header")); + request.headers_mut().insert( + "X-API-Key", + HeaderValue::from_bytes(&[0xFF]).expect("invalid header"), + ); assert_eq!(extract_api_key(&request), None); } } diff --git a/api-router/tests/api_router_integration.rs b/api-router/tests/api_router_integration.rs index 11c7022..9227221 100644 --- a/api-router/tests/api_router_integration.rs +++ b/api-router/tests/api_router_integration.rs @@ -9,11 +9,7 @@ use axum::{ Router, }; use common::{ - storage::{ - db::SurrealDbClient, - store::StorageManager, - types::user::User, - }, + storage::{db::SurrealDbClient, store::StorageManager, types::user::User}, utils::config::{AppConfig, StorageKind}, }; use tower::ServiceExt; @@ -34,9 +30,7 @@ async fn build_test_app() -> (Router, Arc) { storage: StorageKind::Memory, ..Default::default() }; - let storage = StorageManager::new(&config) - .await - .expect("storage manager"); + let storage = StorageManager::new(&config).await.expect("storage manager"); let state = ApiState { db: Arc::clone(&db), @@ -147,9 +141,7 @@ async fn authenticated_user_can_list_categories() { .await .expect("test user"); - let api_key = User::set_api_key(&user.id, &db) - .await - .expect("api key"); + let api_key = User::set_api_key(&user.id, &db).await.expect("api key"); let response = app .clone() diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index ed0694e..19fab47 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -3,9 +3,12 @@ use std::collections::HashMap; use std::fmt::Write; use crate::{ - error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql, + error::AppError, + storage::db::SurrealDbClient, + storage::indexes::hnsw_index_overwrite_sql, storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, - storage::types::system_settings::SystemSettings, stored_object, + storage::types::system_settings::SystemSettings, + stored_object, utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE}, }; use tracing::{error, info}; @@ -25,6 +28,17 @@ impl KnowledgeEntityType { pub fn variants() -> &'static [&'static str] { &["Idea", "Project", "Document", "Page", "TextSnippet"] } + + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::Idea => "Idea", + Self::Project => "Project", + Self::Document => "Document", + Self::Page => "Page", + Self::TextSnippet => "TextSnippet", + } + } } impl From for KnowledgeEntityType { @@ -80,6 +94,27 @@ impl KnowledgeEntity { } } + /// Canonical text fed to the embedding provider for a knowledge entity. + #[must_use] + pub fn embedding_input_text( + name: &str, + description: &str, + entity_type: KnowledgeEntityType, + ) -> String { + let mut out = String::with_capacity( + name.len() + .saturating_add(description.len()) + .saturating_add(entity_type.as_str().len()) + .saturating_add(32), + ); + let _ = write!( + out, + "name: {name}, description: {description}, type: {}", + entity_type.as_str() + ); + out + } + /// Full-text search over knowledge entities using the BM25 FTS index. pub async fn fts_search( take: usize, @@ -314,8 +349,7 @@ impl KnowledgeEntity { db_client: &SurrealDbClient, embedding_provider: &EmbeddingProvider, ) -> Result<(), AppError> { - let embedding_input = - format!("name: {name}, description: {description}, type: {entity_type:?}",); + let embedding_input = Self::embedding_input_text(name, description, *entity_type); let embedding = embedding_provider.embed(&embedding_input).await?; let entity: KnowledgeEntity = db_client @@ -402,9 +436,10 @@ impl KnowledgeEntity { let inputs: Vec = batch .iter() .map(|entity| { - format!( - "name: {}, description: {}, type: {:?}", - entity.name, entity.description, entity.entity_type + Self::embedding_input_text( + &entity.name, + &entity.description, + entity.entity_type, ) }) .collect(); @@ -523,6 +558,16 @@ mod tests { use anyhow::{self, Context}; use uuid::Uuid; + #[test] + fn embedding_input_text_uses_canonical_type_label() { + let text = KnowledgeEntity::embedding_input_text( + "Alpha", + "Beta", + KnowledgeEntityType::TextSnippet, + ); + assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet"); + } + async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> { let snowball_sql = r#" DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index 6c3b4fa..b23265e 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -122,8 +122,7 @@ impl KnowledgeRelationship { .bind(("user_id", user_id.to_owned())) .await .map_err(AppError::from)?; - let deleted: Vec = - delete_result.take(0).map_err(AppError::from)?; + let deleted: Vec = delete_result.take(0).map_err(AppError::from)?; if !deleted.is_empty() { return Ok(()); @@ -567,8 +566,8 @@ mod tests { shared_source.to_string(), "references".to_string(), ); - let rel_a_id = rel_a.id.clone(); - let rel_b_id = rel_b.id.clone(); + let owner_relationship_id = rel_a.id.clone(); + let other_relationship_id = rel_b.id.clone(); rel_a.store_relationship(&db).await?; rel_b.store_relationship(&db).await?; @@ -576,8 +575,12 @@ mod tests { KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db) .await?; - assert!(get_relationship_by_id(&rel_a_id, &db).await.is_none()); - assert!(get_relationship_by_id(&rel_b_id, &db).await.is_some()); + assert!(get_relationship_by_id(&owner_relationship_id, &db) + .await + .is_none()); + assert!(get_relationship_by_id(&other_relationship_id, &db) + .await + .is_some()); Ok(()) } diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 73c8306..5adbe0c 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -299,7 +299,11 @@ impl TextChunk { } processed = processed.saturating_add(batch.len()); - info!(progress = processed, total = total_chunks, "Re-embedding progress"); + info!( + progress = processed, + total = total_chunks, + "Re-embedding progress" + ); } info!("Successfully generated all new embeddings."); diff --git a/common/src/storage/types/text_content.rs b/common/src/storage/types/text_content.rs index c18c109..f1f1af0 100644 --- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -140,8 +140,7 @@ impl TextContent { .await .map_err(AppError::from)?; - let existing: Option = - response.take(0).map_err(AppError::from)?; + let existing: Option = response.take(0).map_err(AppError::from)?; Ok(existing.is_some()) } diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index ae38373..a052e2c 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -38,7 +38,7 @@ enum EmbeddingInner { /// Client used to issue embedding requests. client: Arc>, /// Model identifier for the API. - model: String, + model: Arc, /// Expected output dimensions. dimensions: u32, }, @@ -272,8 +272,9 @@ struct FastEmbedLease { } impl FastEmbedLease { - async fn embed(&self, texts: Vec) -> Result>, EmbeddingError> { + async fn embed(&self, texts: &[String]) -> Result>, EmbeddingError> { let engine = Arc::clone(&self.engine); + let texts = texts.to_vec(); tokio::task::spawn_blocking(move || -> Result>, EmbeddingError> { let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?; guard.embed(texts, None).map_err(EmbeddingError::fastembed) @@ -293,7 +294,7 @@ impl Drop for FastEmbedLease { async fn run_fastembed( pool: &Arc, - texts: Vec, + texts: &[String], ) -> Result>, EmbeddingError> { let lease = pool.checkout().await?; lease.embed(texts).await @@ -323,7 +324,7 @@ impl EmbeddingProvider { pub fn model_code(&self) -> Option { match &self.inner { EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), - EmbeddingInner::OpenAI { model, .. } => Some(model.clone()), + EmbeddingInner::OpenAI { model, .. } => Some(model.as_ref().to_owned()), EmbeddingInner::Hashed { .. } => None, } } @@ -338,7 +339,8 @@ impl EmbeddingProvider { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), EmbeddingInner::FastEmbed { pool, .. } => { - let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?; + let text = text.to_owned(); + let embeddings = run_fastembed(pool, std::slice::from_ref(&text)).await?; embeddings.into_iter().next().ok_or(EmbeddingError::NoData) } EmbeddingInner::OpenAI { @@ -347,7 +349,7 @@ impl EmbeddingProvider { dimensions, } => { let request = CreateEmbeddingRequestArgs::default() - .model(model.clone()) + .model(model.as_ref()) .input([text]) .dimensions(*dimensions) .build()?; @@ -382,7 +384,7 @@ impl EmbeddingProvider { if texts.is_empty() { return Ok(Vec::new()); } - run_fastembed(pool, texts.to_vec()).await + run_fastembed(pool, texts).await } EmbeddingInner::OpenAI { client, @@ -394,7 +396,7 @@ impl EmbeddingProvider { } let request = CreateEmbeddingRequestArgs::default() - .model(model.clone()) + .model(model.as_ref()) .input(texts.to_vec()) .dimensions(*dimensions) .build()?; @@ -417,13 +419,13 @@ impl EmbeddingProvider { /// Currently infallible; reserved for future validation. pub fn new_openai( client: Arc>, - model: String, + model: impl AsRef, dimensions: u32, ) -> Result { Ok(Self { inner: EmbeddingInner::OpenAI { client, - model, + model: Arc::from(model.as_ref()), dimensions, }, }) @@ -520,7 +522,7 @@ impl EmbeddingProvider { "openai embedding backend requires an openai client".into(), ) })?; - Self::new_openai(client, settings.embedding_model.clone(), dimensions) + Self::new_openai(client, settings.embedding_model.as_str(), dimensions) } EmbeddingBackend::FastEmbed => { let pool_size = config @@ -586,11 +588,12 @@ mod tests { #![allow(clippy::expect_used)] use super::{ - align_fastembed_system_settings, fastembed_model_dimension, list_fastembed_embedding_models, - resolve_fastembed_model_code, DEFAULT_FASTEMBED_MODEL_CODE, EmbeddingError, + align_fastembed_system_settings, fastembed_model_dimension, + list_fastembed_embedding_models, resolve_fastembed_model_code, EmbeddingError, + DEFAULT_FASTEMBED_MODEL_CODE, }; - use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError}; use crate::storage::types::system_settings::SystemSettings; + use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError}; use serde_json::json; #[test] @@ -656,16 +659,16 @@ mod tests { fastembed_model: Some("Xenova/bge-base-en-v1.5".into()), ..AppConfig::default() }; - let resolved = resolve_fastembed_model_code(&config, "text-embedding-3-small") - .expect("config model"); + let resolved = + resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("config model"); assert_eq!(resolved, "Xenova/bge-base-en-v1.5"); } #[test] fn resolve_fastembed_model_falls_back_from_openai_default() { let config = AppConfig::default(); - let resolved = resolve_fastembed_model_code(&config, "text-embedding-3-small") - .expect("default model"); + let resolved = + resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("default model"); assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE); } diff --git a/evaluations/src/pipeline/stages/prepare_db.rs b/evaluations/src/pipeline/stages/prepare_db.rs index b18e3be..01eff64 100644 --- a/evaluations/src/pipeline/stages/prepare_db.rs +++ b/evaluations/src/pipeline/stages/prepare_db.rs @@ -42,14 +42,12 @@ pub(crate) async fn prepare_db( // Create embedding provider directly from config (eval only supports FastEmbed and Hashed) let embedding_provider = match config.embedding_backend { - crate::args::EmbeddingBackend::FastEmbed => { - EmbeddingProvider::new_fastembed( - config.embedding_model.clone(), - default_embedding_pool_size(), - ) - .await - .context("creating FastEmbed provider")? - } + crate::args::EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed( + config.embedding_model.clone(), + default_embedding_pool_size(), + ) + .await + .context("creating FastEmbed provider")?, crate::args::EmbeddingBackend::Hashed => { EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")? } diff --git a/evaluations/src/types.rs b/evaluations/src/types.rs index b54d369..0f4950f 100644 --- a/evaluations/src/types.rs +++ b/evaluations/src/types.rs @@ -196,7 +196,7 @@ pub struct EvaluationCandidate { } impl EvaluationCandidate { - fn from_entity(entity: RetrievedEntity) -> Self { + fn from_entity(entity: &RetrievedEntity) -> Self { let entity_category = Some(format!("{:?}", entity.entity.entity_type)); Self { entity_id: entity.entity.id().to_string(), @@ -223,9 +223,9 @@ impl EvaluationCandidate { } } -fn candidates_from_entities(entities: Vec) -> Vec { +fn candidates_from_entities(entities: &[RetrievedEntity]) -> Vec { entities - .into_iter() + .iter() .map(EvaluationCandidate::from_entity) .collect() } @@ -241,7 +241,7 @@ pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec candidates_from_chunks(chunks), RetrievalOutput::WithEntities { chunks, entities } => { - let mut candidates = candidates_from_entities(entities); + let mut candidates = candidates_from_entities(&entities); candidates.extend(candidates_from_chunks(chunks)); candidates.sort_by(|a, b| b.score.total_cmp(&a.score)); candidates diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index ce8ffb2..946851d 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -142,7 +142,9 @@ impl HtmlState { return; } - let overflow = cache.len().saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS); + let overflow = cache + .len() + .saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS); let mut by_expiry: Vec<(String, Instant)> = cache .iter() .map(|(user_id, entry)| (user_id.clone(), entry.expires_at)) diff --git a/html-router/src/middlewares/response_middleware.rs b/html-router/src/middlewares/response_middleware.rs index 5cf3d0a..4be6d96 100644 --- a/html-router/src/middlewares/response_middleware.rs +++ b/html-router/src/middlewares/response_middleware.rs @@ -183,9 +183,7 @@ fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) } } -fn context_to_map( - value: &Value, -) -> Result, minijinja::value::ValueKind> { +fn context_to_map(value: &Value) -> Result, minijinja::value::ValueKind> { match value.kind() { minijinja::value::ValueKind::Map => { let mut map = HashMap::new(); diff --git a/html-router/src/router_factory.rs b/html-router/src/router_factory.rs index 2c85166..9c796fe 100644 --- a/html-router/src/router_factory.rs +++ b/html-router/src/router_factory.rs @@ -8,8 +8,8 @@ use surrealdb::{engine::any::Any, Surreal}; use crate::{ html_state::HtmlState, middlewares::{ - analytics_middleware::analytics_middleware, auth_middleware::require_auth, - compression, response_middleware::with_template_response, + analytics_middleware::analytics_middleware, auth_middleware::require_auth, compression, + response_middleware::with_template_response, }, }; diff --git a/html-router/src/routes/account/handlers.rs b/html-router/src/routes/account/handlers.rs index 902be8a..50a0ffa 100644 --- a/html-router/src/routes/account/handlers.rs +++ b/html-router/src/routes/account/handlers.rs @@ -164,9 +164,7 @@ pub async fn update_theme( )) } -pub async fn show_change_password( - RequireUser(_user): RequireUser, -) -> TemplateResult { +pub async fn show_change_password(RequireUser(_user): RequireUser) -> TemplateResult { Ok(TemplateResponse::new_template( "auth/change_password_form.html", (), diff --git a/html-router/src/routes/admin/handlers.rs b/html-router/src/routes/admin/handlers.rs index b46b51c..b80a970 100644 --- a/html-router/src/routes/admin/handlers.rs +++ b/html-router/src/routes/admin/handlers.rs @@ -18,8 +18,8 @@ use common::{ utils::{ config::AppConfig, embedding::{ - fastembed_model_dimension, is_valid_fastembed_model_code, list_fastembed_embedding_models, - EmbeddingBackend, FastEmbedModelOption, + fastembed_model_dimension, is_valid_fastembed_model_code, + list_fastembed_embedding_models, EmbeddingBackend, FastEmbedModelOption, }, }, }; @@ -52,7 +52,6 @@ pub enum AdminSection { Models, } - #[derive(Deserialize)] pub struct AdminPanelQuery { section: Option, @@ -101,8 +100,9 @@ pub async fn show_admin_panel( (None, None, false) }; - let effective_backend = - effective_embedding_backend(&settings, &state.config).as_str().to_string(); + let effective_backend = effective_embedding_backend(&settings, &state.config) + .as_str() + .to_string(); Ok(TemplateResponse::new_template( "admin/base.html", @@ -187,7 +187,9 @@ struct EmbeddingSettingsPlan { } fn effective_embedding_backend(settings: &SystemSettings, config: &AppConfig) -> EmbeddingBackend { - settings.embedding_backend.unwrap_or(config.embedding_backend) + settings + .embedding_backend + .unwrap_or(config.embedding_backend) } fn is_fastembed_admin_context(settings: &SystemSettings, config: &AppConfig) -> bool { @@ -241,11 +243,10 @@ fn plan_embedding_settings_update( ))); } - let embedding_dimensions = fastembed_model_dimension(&embedding_model) - .map_err(AppError::from)?; + let embedding_dimensions = + fastembed_model_dimension(&embedding_model).map_err(AppError::from)?; let reembedding_needed = embedding_dimensions != current.embedding_dimensions; - let restart_needed = - embedding_model != current.embedding_model || reembedding_needed; + let restart_needed = embedding_model != current.embedding_model || reembedding_needed; Ok(EmbeddingSettingsPlan { embedding_model, @@ -274,8 +275,7 @@ pub async fn update_model_settings( Form(input): Form, ) -> TemplateResult { let current_settings = SystemSettings::get_current(&state.db).await?; - let embedding_plan = - plan_embedding_settings_update(¤t_settings, &input, &state.config)?; + let embedding_plan = plan_embedding_settings_update(¤t_settings, &input, &state.config)?; let new_settings = SystemSettingsPatch { query_model: Some(input.query_model), @@ -309,10 +309,11 @@ pub async fn update_model_settings( .await .map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?; - let effective_backend = - effective_embedding_backend(&new_settings, &state.config).as_str().to_string(); - let show_fastembed_models = - is_fastembed_admin_context(&new_settings, &state.config).then(list_fastembed_embedding_models); + let effective_backend = effective_embedding_backend(&new_settings, &state.config) + .as_str() + .to_string(); + let show_fastembed_models = is_fastembed_admin_context(&new_settings, &state.config) + .then(list_fastembed_embedding_models); Ok(TemplateResponse::new_partial( "admin/sections/models.html", @@ -368,8 +369,8 @@ mod tests { embedding_model: Some("Xenova/bge-base-en-v1.5".into()), embedding_dimensions: None, }; - let plan = plan_embedding_settings_update(¤t, &input, &AppConfig::default()) - .expect("plan"); + let plan = + plan_embedding_settings_update(¤t, &input, &AppConfig::default()).expect("plan"); assert_eq!(plan.embedding_model, "Xenova/bge-base-en-v1.5"); assert_eq!(plan.embedding_dimensions, 768); assert!(plan.reembedding_needed); @@ -407,9 +408,7 @@ pub struct SystemPromptEditData { default_query_prompt: String, } -pub async fn show_edit_system_prompt( - State(state): State, -) -> TemplateResult { +pub async fn show_edit_system_prompt(State(state): State) -> TemplateResult { let settings = SystemSettings::get_current(&state.db).await?; Ok(TemplateResponse::new_template( @@ -457,9 +456,7 @@ pub struct IngestionPromptEditData { default_ingestion_prompt: String, } -pub async fn show_edit_ingestion_prompt( - State(state): State, -) -> TemplateResult { +pub async fn show_edit_ingestion_prompt(State(state): State) -> TemplateResult { let settings = SystemSettings::get_current(&state.db).await?; Ok(TemplateResponse::new_template( @@ -502,9 +499,7 @@ pub struct ImagePromptEditData { default_image_prompt: String, } -pub async fn show_edit_image_prompt( - State(state): State, -) -> TemplateResult { +pub async fn show_edit_image_prompt(State(state): State) -> TemplateResult { let settings = SystemSettings::get_current(&state.db).await?; Ok(TemplateResponse::new_template( diff --git a/html-router/src/routes/auth/signup.rs b/html-router/src/routes/auth/signup.rs index d195803..dfa36be 100644 --- a/html-router/src/routes/auth/signup.rs +++ b/html-router/src/routes/auth/signup.rs @@ -2,7 +2,10 @@ use axum::{extract::State, Form}; use axum_htmx::HxBoosted; use serde::{Deserialize, Serialize}; -use common::{error::AppError, storage::types::user::{Theme, User}}; +use common::{ + error::AppError, + storage::types::user::{Theme, User}, +}; use crate::{ html_state::HtmlState, diff --git a/html-router/src/routes/chat/chat_handlers.rs b/html-router/src/routes/chat/chat_handlers.rs index 7e7e0ef..0ad838d 100644 --- a/html-router/src/routes/chat/chat_handlers.rs +++ b/html-router/src/routes/chat/chat_handlers.rs @@ -18,8 +18,8 @@ use crate::{ middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_as_response, template_with_headers, TemplateResponse, TemplateResult, - ResponseResult, + template_as_response, template_with_headers, ResponseResult, TemplateResponse, + TemplateResult, }, }, }; diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index b00ca76..11bdab8 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -22,8 +22,8 @@ use retrieval_pipeline::answer_retrieval::{ }; use serde::{Deserialize, Serialize}; use serde_json::from_str; -use tokio::sync::Mutex; use tokio::sync::mpsc::channel; +use tokio::sync::Mutex; use tracing::{debug, error, info}; use common::storage::{ @@ -36,10 +36,7 @@ use common::storage::{ }, }; -use crate::{ - html_state::HtmlState, - middlewares::auth_middleware::RequireUser, -}; +use crate::{html_state::HtmlState, middlewares::auth_middleware::RequireUser}; use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references}; diff --git a/html-router/src/routes/chat/mod.rs b/html-router/src/routes/chat/mod.rs index e4162e8..20912ef 100644 --- a/html-router/src/routes/chat/mod.rs +++ b/html-router/src/routes/chat/mod.rs @@ -3,15 +3,11 @@ mod message_response_stream; mod reference_validation; mod references; -use axum::{ - extract::FromRef, - routing::get, - Router, -}; +use axum::{extract::FromRef, routing::get, Router}; pub use chat_handlers::{ delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title, - reload_sidebar, show_conversation_editing_title, - show_chat_base as show_base, show_existing_chat as show_existing, + reload_sidebar, show_chat_base as show_base, show_conversation_editing_title, + show_existing_chat as show_existing, }; use message_response_stream::get_response_stream; use references::show_reference_tooltip; diff --git a/html-router/src/routes/chat/references.rs b/html-router/src/routes/chat/references.rs index f09c294..93cc8e1 100644 --- a/html-router/src/routes/chat/references.rs +++ b/html-router/src/routes/chat/references.rs @@ -1,8 +1,6 @@ #![allow(clippy::missing_docs_in_private_items)] -use axum::{ - extract::{Path, State}, -}; +use axum::extract::{Path, State}; use chrono::{DateTime, Utc}; use chrono_tz::Tz; use serde::Serialize; diff --git a/html-router/src/routes/index/handlers.rs b/html-router/src/routes/index/handlers.rs index 0bca199..2b76a12 100644 --- a/html-router/src/routes/index/handlers.rs +++ b/html-router/src/routes/index/handlers.rs @@ -13,7 +13,7 @@ use crate::{ middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_as_response, TemplateResponse, TemplateResult, ResponseResult, + template_as_response, ResponseResult, TemplateResponse, TemplateResult, }, }, utils::text_content_preview::truncate_text_contents, @@ -84,7 +84,8 @@ pub async fn delete_text_content( // Delete the text content and any related data TextChunk::delete_by_source_id(&text_content.id, &state.db).await?; KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?; - KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db).await?; + KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db) + .await?; state .db .delete_item::(&text_content.id) diff --git a/html-router/src/routes/ingestion/handlers.rs b/html-router/src/routes/ingestion/handlers.rs index 49dd303..cefc10a 100644 --- a/html-router/src/routes/ingestion/handlers.rs +++ b/html-router/src/routes/ingestion/handlers.rs @@ -63,9 +63,7 @@ pub async fn show_ingest_form( )) } -pub async fn hide_ingest_form( - RequireUser(_user): RequireUser, -) -> TemplateResult { +pub async fn hide_ingest_form(RequireUser(_user): RequireUser) -> TemplateResult { Ok(TemplateResponse::new_template( "ingestion/add_content_button.html", (), @@ -148,8 +146,7 @@ pub async fn process_ingest_form( user.id.clone(), )?; - let tasks = - IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?; + let tasks = IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?; Ok(TemplateResponse::new_template( "dashboard/current_task.html", diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index fe0f3c9..db5a297 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -37,7 +37,7 @@ use crate::{ middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_with_headers, TemplateResponse, TemplateResult, ResponseResult, + template_with_headers, ResponseResult, TemplateResponse, TemplateResult, }, }, utils::pagination::{paginate_items, paginate_slice, Pagination}, @@ -185,8 +185,7 @@ pub async fn create_knowledge_entity( let description = form.description.trim().to_string(); let entity_type = KnowledgeEntityType::from(form.entity_type.trim().to_string()); - let embedding_input = - format!("name: {name}, description: {description}, type: {entity_type:?}"); + let embedding_input = KnowledgeEntity::embedding_input_text(&name, &description, entity_type); let embedding = state .embedding_provider .embed(&embedding_input) @@ -290,10 +289,12 @@ pub async fn suggest_knowledge_relationships( if !query_parts.is_empty() { let name = form.name.as_deref().unwrap_or("").trim(); let description = form.description.as_deref().unwrap_or("").trim(); - let entity_type = form.entity_type.as_deref().map_or( - KnowledgeEntityType::Document, - |value| KnowledgeEntityType::from(value.to_string()), - ); + let entity_type = form + .entity_type + .as_deref() + .map_or(KnowledgeEntityType::Document, |value| { + KnowledgeEntityType::from(value.to_string()) + }); let suggested = suggest_related_entities( &state.db, @@ -374,10 +375,8 @@ async fn suggest_related_entities( draft: DraftEntityQuery<'_>, entity_lookup: &HashMap, ) -> Result, AppError> { - let embedding_input = format!( - "name: {}, description: {}, type: {:?}", - draft.name, draft.description, draft.entity_type - ); + let embedding_input = + KnowledgeEntity::embedding_input_text(draft.name, draft.description, draft.entity_type); let embedding = embedding_provider.embed(&embedding_input).await?; let take = MAX_RELATIONSHIP_SUGGESTIONS * 2; @@ -484,11 +483,7 @@ fn build_relationship_options( fn build_relationship_rows( relationships: Vec, -) -> ( - Vec, - Vec, - String, -) { +) -> (Vec, Vec, String) { let relationship_type_options = collect_relationship_type_options(&relationships); let mut frequency: HashMap = HashMap::new(); let relationships = relationships @@ -509,10 +504,7 @@ fn build_relationship_rows( let default_relationship_type = frequency .into_iter() .max_by_key(|(_, count)| *count) - .map_or_else( - || DEFAULT_RELATIONSHIP_TYPE.to_string(), - |(label, _)| label, - ); + .map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label); ( relationships, diff --git a/html-router/src/routes/scratchpad/handlers.rs b/html-router/src/routes/scratchpad/handlers.rs index 383dce9..12ad051 100644 --- a/html-router/src/routes/scratchpad/handlers.rs +++ b/html-router/src/routes/scratchpad/handlers.rs @@ -12,7 +12,7 @@ use crate::html_state::HtmlState; use crate::middlewares::{ auth_middleware::RequireUser, response_middleware::{ - template_with_headers, TemplateResponse, TemplateResult, ResponseResult, + template_with_headers, ResponseResult, TemplateResponse, TemplateResult, }, }; use common::storage::types::{ diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs index 4174c18..9ce386e 100644 --- a/html-router/src/routes/search/handlers.rs +++ b/html-router/src/routes/search/handlers.rs @@ -1,11 +1,11 @@ use std::collections::HashSet; -use axum::{ - extract::{Query, State}, -}; +use axum::extract::{Query, State}; use axum_htmx::{HxBoosted, HxRequest}; use common::storage::types::{text_content::TextContent, user::User}; -use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity}; +use retrieval_pipeline::{ + retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity, +}; use serde::{de, Deserialize, Deserializer, Serialize}; use std::{fmt, str::FromStr}; @@ -48,9 +48,7 @@ impl<'de> Deserialize<'de> for SearchView { Some("chunks") => SearchView::Chunks, Some("entities") => SearchView::Entities, Some(other) => { - return Err(de::Error::custom(format!( - "invalid search view: {other}" - ))); + return Err(de::Error::custom(format!("invalid search view: {other}"))); } }) } @@ -121,13 +119,12 @@ pub async fn search_result_handler( HxBoosted(is_boosted): HxBoosted, ) -> TemplateResult { let view = params.view; - let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) = - params.query - { - perform_search(&state, &user, actual_query, view).await? - } else { - (Vec::::new(), String::new()) - }; + let (search_results_for_template, final_query_param_for_template) = + if let Some(actual_query) = params.query { + perform_search(&state, &user, actual_query, view).await? + } else { + (Vec::::new(), String::new()) + }; let data = AnswerData { search_result: search_results_for_template, diff --git a/html-router/src/routes/search/mod.rs b/html-router/src/routes/search/mod.rs index cbd93b0..a1a62d1 100644 --- a/html-router/src/routes/search/mod.rs +++ b/html-router/src/routes/search/mod.rs @@ -2,9 +2,7 @@ mod handlers; use axum::{extract::FromRef, routing::get, Router}; #[allow(clippy::module_name_repetitions)] -pub use handlers::{ - search_result_handler as result_handler, SearchParams as SearchQueryParams, -}; +pub use handlers::{search_result_handler as result_handler, SearchParams as SearchQueryParams}; use crate::html_state::HtmlState; diff --git a/html-router/src/utils/pagination.rs b/html-router/src/utils/pagination.rs index b20ffd2..3b0ea89 100644 --- a/html-router/src/utils/pagination.rs +++ b/html-router/src/utils/pagination.rs @@ -31,8 +31,16 @@ impl Pagination { } else { 0 }; - let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) }; - let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) }; + let start_index = if page_len == 0 { + 0 + } else { + offset.saturating_add(1) + }; + let end_index = if page_len == 0 { + 0 + } else { + offset.saturating_add(page_len) + }; Self { current_page, @@ -109,7 +117,11 @@ pub fn paginate_items( let total_pages = if total_items == 0 { 0 } else { - total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1) + total_items + .saturating_sub(1) + .checked_div(per_page) + .unwrap_or(0) + .saturating_add(1) }; let mut current_page = requested_page.unwrap_or(1); diff --git a/html-router/tests/router_integration.rs b/html-router/tests/router_integration.rs index 5349785..0f8347f 100644 --- a/html-router/tests/router_integration.rs +++ b/html-router/tests/router_integration.rs @@ -9,11 +9,7 @@ use axum::{ Router, }; use common::{ - storage::{ - db::SurrealDbClient, - store::StorageManager, - types::user::User, - }, + storage::{db::SurrealDbClient, store::StorageManager, types::user::User}, utils::{ config::{AppConfig, StorageKind}, embedding::EmbeddingProvider, @@ -37,24 +33,17 @@ async fn build_test_app() -> (Router, Arc) { .await .expect("migrations should apply"); - let session_store = Arc::new( - db.create_session_store() - .await - .expect("session store"), - ); + let session_store = Arc::new(db.create_session_store().await.expect("session store")); let config = AppConfig { storage: StorageKind::Memory, ..Default::default() }; - let storage = StorageManager::new(&config) - .await - .expect("storage manager"); + let storage = StorageManager::new(&config).await.expect("storage manager"); - let embedding_provider = Arc::new( - EmbeddingProvider::new_hashed(8).expect("embedding provider"), - ); + let embedding_provider = + Arc::new(EmbeddingProvider::new_hashed(8).expect("embedding provider")); let state = HtmlState::new_with_resources(StateResources { db: Arc::clone(&db), diff --git a/ingestion-pipeline/src/pipeline/enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs index b9ee892..aa2ceae 100644 --- a/ingestion-pipeline/src/pipeline/enrichment_result.rs +++ b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -6,11 +6,9 @@ use serde::{Deserialize, Serialize}; use common::{ error::AppError, - storage::{ - types::{ - knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, - knowledge_relationship::KnowledgeRelationship, - }, + storage::types::{ + knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, + knowledge_relationship::KnowledgeRelationship, }, utils::embedding::EmbeddingProvider, }; @@ -83,25 +81,32 @@ impl LLMEnrichmentResult { entity_concurrency: usize, embedding_provider: &EmbeddingProvider, ) -> Result, AppError> { - stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| { - let mapper = Arc::clone(&mapper); - let source_id = source_id.to_string(); - let user_id = user_id.to_string(); + let tasks: Vec<_> = self + .knowledge_entities + .iter() + .map(|entity| { + let llm_entity = entity.clone(); + let mapper = Arc::clone(&mapper); + let source_id = source_id.to_string(); + let user_id = user_id.to_string(); - async move { - create_single_entity( - &entity, - &source_id, - &user_id, - mapper, - embedding_provider, - ) - .await - } - })) - .buffer_unordered(entity_concurrency.max(1)) - .try_collect() - .await + async move { + create_single_entity( + llm_entity, + &source_id, + &user_id, + mapper, + embedding_provider, + ) + .await + } + }) + .collect(); + + stream::iter(tasks) + .buffer_unordered(entity_concurrency.max(1)) + .try_collect() + .await } fn process_relationships( @@ -129,7 +134,7 @@ impl LLMEnrichmentResult { } async fn create_single_entity( - llm_entity: &LLMKnowledgeEntity, + llm_entity: LLMKnowledgeEntity, source_id: &str, user_id: &str, mapper: Arc, @@ -137,9 +142,11 @@ async fn create_single_entity( ) -> Result { let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); - let embedding_input = format!( - "name: {}, description: {}, type: {}", - llm_entity.name, llm_entity.description, llm_entity.entity_type + let entity_type = KnowledgeEntityType::from(llm_entity.entity_type); + let embedding_input = KnowledgeEntity::embedding_input_text( + &llm_entity.name, + &llm_entity.description, + entity_type, ); let embedding = embedding_provider.embed(&embedding_input).await?; @@ -149,9 +156,9 @@ async fn create_single_entity( id: assigned_id, created_at: now, updated_at: now, - name: llm_entity.name.clone(), - description: llm_entity.description.clone(), - entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()), + name: llm_entity.name, + description: llm_entity.description, + entity_type, source_id: source_id.to_string(), metadata: None, user_id: user_id.into(), diff --git a/ingestion-pipeline/src/pipeline/persistence.rs b/ingestion-pipeline/src/pipeline/persistence.rs index 2c37f35..3182552 100644 --- a/ingestion-pipeline/src/pipeline/persistence.rs +++ b/ingestion-pipeline/src/pipeline/persistence.rs @@ -48,20 +48,20 @@ const STORE_RELATIONSHIPS: &str = r" pub(super) async fn store_vector_chunks( db: &SurrealDbClient, task_id: &str, - chunks: &[EmbeddedTextChunk], + chunks: Vec, ) -> Result { + let chunk_count = chunks.len(); for embedded in chunks { - TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db) - .await?; debug!( task_id = %task_id, chunk_id = %embedded.chunk.id, chunk_len = embedded.chunk.chunk.chars().count(), "chunk persisted" ); + TextChunk::store_with_embedding(embedded.chunk, embedded.embedding, db).await?; } - Ok(chunks.len()) + Ok(chunk_count) } /// Persists knowledge entities and their relationships. diff --git a/ingestion-pipeline/src/pipeline/stages.rs b/ingestion-pipeline/src/pipeline/stages.rs index 9ed4b9b..9484141 100644 --- a/ingestion-pipeline/src/pipeline/stages.rs +++ b/ingestion-pipeline/src/pipeline/stages.rs @@ -155,7 +155,7 @@ pub async fn persist( let entity_count = entities.len(); let relationship_count = relationships.len(); - let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), &chunks).await?; + let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), chunks).await?; store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?; ctx.db.store_item(text_content).await?; rebuild(ctx.db).await?; diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index c702027..0a6ea8f 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -92,7 +92,7 @@ impl MockServices { entity: retrieved_entity, score: 0.8, chunks: std::sync::Arc::new(vec![RetrievedChunk { - chunk: retrieved_chunk, + chunk: std::sync::Arc::new(retrieved_chunk), score: 0.7, }]), }], diff --git a/ingestion-pipeline/src/utils/file_text_extraction.rs b/ingestion-pipeline/src/utils/file_text_extraction.rs index 90df7c7..dee232b 100644 --- a/ingestion-pipeline/src/utils/file_text_extraction.rs +++ b/ingestion-pipeline/src/utils/file_text_extraction.rs @@ -74,10 +74,7 @@ pub async fn extract_text_from_file( config: &AppConfig, storage: &StorageManager, ) -> Result { - let file_bytes = storage - .get(&file_info.path) - .await - .map_err(AppError::from)?; + let file_bytes = storage.get(&file_info.path).await.map_err(AppError::from)?; let local_path = resolve_existing_local_path(storage, &file_info.path).await; match file_info.mime_type.as_str() { diff --git a/main/src/bootstrap/mod.rs b/main/src/bootstrap/mod.rs index 67724dc..f6db73c 100644 --- a/main/src/bootstrap/mod.rs +++ b/main/src/bootstrap/mod.rs @@ -8,10 +8,7 @@ use std::sync::Arc; use anyhow::Context; use async_openai::Client; use common::{ - storage::{ - db::SurrealDbClient, - store::StorageManager, - }, + storage::{db::SurrealDbClient, store::StorageManager}, utils::{ config::{get_config, AppConfig}, embedding::{align_fastembed_system_settings, EmbeddingProvider}, diff --git a/main/src/bootstrap/startup.rs b/main/src/bootstrap/startup.rs index a2f423e..9418e1f 100644 --- a/main/src/bootstrap/startup.rs +++ b/main/src/bootstrap/startup.rs @@ -67,7 +67,8 @@ pub async fn prepare_embedding_runtime( let index_dim = if mismatch { match role { EmbeddingRuntimeRole::Maintainer => { - reconcile_embeddings(&services.db, &services.embedding_provider, target_dim).await?; + reconcile_embeddings(&services.db, &services.embedding_provider, target_dim) + .await?; target_dim } EmbeddingRuntimeRole::ReadOnly => { @@ -238,9 +239,7 @@ mod tests { stored_dim: usize, target_dim: usize, ) -> (super::SharedServices, std::path::PathBuf) { - let (mut services, data_dir) = init_smoke_services() - .await - .expect("smoke services"); + let (mut services, data_dir) = init_smoke_services().await.expect("smoke services"); ensure_runtime(&services.db, stored_dim) .await @@ -254,9 +253,8 @@ mod tests { .await .expect("update settings"); - services.embedding_provider = Arc::new( - EmbeddingProvider::new_hashed(target_dim).expect("hashed provider for test"), - ); + services.embedding_provider = + Arc::new(EmbeddingProvider::new_hashed(target_dim).expect("hashed provider for test")); (services, data_dir) } @@ -270,7 +268,9 @@ mod tests { .expect("maintainer startup"); assert_eq!( - embedding_index_dimension(&services.db).await.expect("index dim"), + embedding_index_dimension(&services.db) + .await + .expect("index dim"), Some(5), "maintainer should rebuild the index to the provider dimension" ); @@ -287,7 +287,9 @@ mod tests { .expect("read-only startup"); assert_eq!( - embedding_index_dimension(&services.db).await.expect("index dim"), + embedding_index_dimension(&services.db) + .await + .expect("index dim"), Some(3), "read-only server must not overwrite the index before a maintainer re-embeds" ); @@ -297,9 +299,7 @@ mod tests { #[tokio::test] async fn maintainer_reembeds_chunks_when_index_dimension_differs() { - let (mut services, data_dir) = init_smoke_services() - .await - .expect("smoke services"); + let (mut services, data_dir) = init_smoke_services().await.expect("smoke services"); let mut settings = SystemSettings::get_current(&services.db) .await @@ -339,7 +339,9 @@ mod tests { .expect("maintainer startup with data"); assert_eq!( - embedding_index_dimension(&services.db).await.expect("index dim"), + embedding_index_dimension(&services.db) + .await + .expect("index dim"), Some(5) ); diff --git a/main/src/main.rs b/main/src/main.rs index 1c44955..700effd 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -95,7 +95,11 @@ mod tests { use common::storage::types::{system_settings::SystemSettings, user::User}; use tower::ServiceExt; - async fn build_test_app() -> (Router, Arc, std::path::PathBuf) { + async fn build_test_app() -> ( + Router, + Arc, + std::path::PathBuf, + ) { let (services, data_dir) = init_smoke_services() .await .expect("failed to init services"); diff --git a/main/src/worker.rs b/main/src/worker.rs index 4afc499..79c52eb 100644 --- a/main/src/worker.rs +++ b/main/src/worker.rs @@ -68,9 +68,8 @@ mod tests { let db = Arc::clone(&services.db); let pipeline = Arc::new(pipeline); - let worker = tokio::spawn(async move { - ingestion_pipeline::run_worker_loop(db, pipeline).await - }); + let worker = + tokio::spawn(async move { ingestion_pipeline::run_worker_loop(db, pipeline).await }); tokio::time::sleep(Duration::from_millis(250)).await; assert!( diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index d27eb32..975e737 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -7,6 +7,8 @@ pub mod scoring; use std::sync::Arc; +pub use scoring::RetrievalCandidate; + use common::{ error::AppError, storage::{ @@ -45,7 +47,7 @@ pub(crate) fn round_score(value: f32) -> f64 { // Captures a supporting chunk plus its fused retrieval score for downstream prompts. #[derive(Debug, Clone)] pub struct RetrievedChunk { - pub chunk: TextChunk, + pub chunk: Arc, pub score: f32, } @@ -159,7 +161,9 @@ mod tests { assert!(!chunks.is_empty(), "Expected at least one retrieval result"); assert!( - chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")), + chunks + .first() + .is_some_and(|c| c.chunk.chunk.contains("Tokio")), "Expected chunk about Tokio" ); Ok(()) diff --git a/retrieval-pipeline/src/pipeline/context.rs b/retrieval-pipeline/src/pipeline/context.rs index 94bee5d..831036e 100644 --- a/retrieval-pipeline/src/pipeline/context.rs +++ b/retrieval-pipeline/src/pipeline/context.rs @@ -11,7 +11,7 @@ use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity}; use super::{ config::RetrievalConfig, diagnostics::{AssembleStats, Diagnostics, SearchStats}, - StageKind, StageTimings, RetrievalParams, + RetrievalParams, StageKind, StageTimings, }; /// Mutable working state threaded through every retrieval stage. @@ -22,7 +22,7 @@ pub(crate) struct PipelineContext<'a> { pub user_id: String, pub config: RetrievalConfig, pub query_embedding: Option>, - pub chunk_values: Vec>, + pub chunk_values: Vec>>, pub reranker: Option, pub diagnostics: Option, pub entity_results: Vec, diff --git a/retrieval-pipeline/src/pipeline/stages.rs b/retrieval-pipeline/src/pipeline/stages.rs index e752284..91e3690 100644 --- a/retrieval-pipeline/src/pipeline/stages.rs +++ b/retrieval-pipeline/src/pipeline/stages.rs @@ -131,14 +131,14 @@ pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError let vector_candidates = vector_rows.len(); let fts_candidates = fts_rows.len(); - let vector_scored: Vec> = vector_rows + let vector_scored: Vec>> = vector_rows .into_iter() - .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) + .map(|row| Scored::new(Arc::new(row.chunk)).with_vector_score(row.score)) .collect(); - let fts_scored: Vec> = fts_rows + let fts_scored: Vec>> = fts_rows .into_iter() - .map(|row| Scored::new(row.chunk).with_fts_score(row.score)) + .map(|row| Scored::new(Arc::new(row.chunk)).with_fts_score(row.score)) .collect(); let mut fts_weight = tuning.chunk_rrf_fts_weight; @@ -222,40 +222,63 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError /// and the contributing chunks are attached. #[instrument(level = "trace", skip_all)] pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - if ctx.chunk_values.is_empty() { + let chunk_values = std::mem::take(&mut ctx.chunk_values); + if chunk_values.is_empty() { return Ok(()); } let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1); + struct IndexedChunk { + idx: usize, + score: f32, + } + let mut source_order: Vec = Vec::new(); - let mut chunks_by_source: HashMap> = HashMap::new(); + let mut chunks_by_source: HashMap> = HashMap::new(); let mut best_score: HashMap = HashMap::new(); - for scored in &ctx.chunk_values { - let source_id = &scored.item.source_id; - let is_new_source = !chunks_by_source.contains_key(source_id); - if is_new_source { - source_order.push(source_id.clone()); + for (idx, scored) in chunk_values.iter().enumerate() { + if let Some(attached) = chunks_by_source.get_mut(&scored.item.source_id) { + if attached.len() < max_chunks { + attached.push(IndexedChunk { + idx, + score: scored.fused, + }); + } + } else { + let source_id = scored.item.source_id.clone(); best_score.insert(source_id.clone(), scored.fused); - } - - let attached = chunks_by_source - .entry(source_id.clone()) - .or_default(); - if attached.len() < max_chunks { - attached.push(RetrievedChunk { - chunk: scored.item.clone(), - score: scored.fused, - }); + source_order.push(source_id.clone()); + chunks_by_source.insert( + source_id, + vec![IndexedChunk { + idx, + score: scored.fused, + }], + ); } } let chunks_by_source: HashMap>> = chunks_by_source .into_iter() - .map(|(source, chunks)| (source, Arc::new(chunks))) + .map(|(source, indices)| { + let chunks = indices + .into_iter() + .filter_map(|indexed| { + let scored = chunk_values.get(indexed.idx)?; + Some(RetrievedChunk { + chunk: Arc::clone(&scored.item), + score: indexed.score, + }) + }) + .collect(); + (source, Arc::new(chunks)) + }) .collect(); + ctx.chunk_values = chunk_values; + let entities = KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?; @@ -336,10 +359,17 @@ fn sample_scores(items: &[Scored], extractor: F) -> Vec where F: FnMut(&Scored) -> f32, { - items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect() + items + .iter() + .take(SCORE_SAMPLE_LIMIT) + .map(extractor) + .collect() } -fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) -> Vec { +fn build_chunk_rerank_documents( + chunks: &[Scored>], + max_chunks: usize, +) -> Vec { let take = chunks.len().min(max_chunks); let mut documents = Vec::with_capacity(take); let mut buffer = String::with_capacity(512); @@ -363,7 +393,7 @@ fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) } fn apply_chunk_rerank_results( - chunks: &mut Vec>, + chunks: &mut Vec>>, tuning: &RetrievalTuning, results: Vec, ) { @@ -371,7 +401,7 @@ fn apply_chunk_rerank_results( return; } - let mut remaining: Vec>> = + let mut remaining: Vec>>> = std::mem::take(chunks).into_iter().map(Some).collect(); let raw_scores: Vec = results.iter().map(|r| r.score).collect(); @@ -384,7 +414,7 @@ fn apply_chunk_rerank_results( clamp_unit(tuning.rerank_blend_weight) }; - let mut reranked: Vec> = Vec::with_capacity(remaining.len()); + let mut reranked: Vec>> = Vec::with_capacity(remaining.len()); for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { if let Some(slot) = remaining.get_mut(result.index) { if let Some(mut candidate) = slot.take() { diff --git a/retrieval-pipeline/src/reranking.rs b/retrieval-pipeline/src/reranking.rs index 823e03e..f11fb24 100644 --- a/retrieval-pipeline/src/reranking.rs +++ b/retrieval-pipeline/src/reranking.rs @@ -29,8 +29,7 @@ impl RerankerPool { /// Build the pool at startup. /// `pool_size` controls max parallel reranks. pub fn new(pool_size: usize) -> Result, Box> { - let init_options = - RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn); + let init_options = RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn); Self::new_with_options(pool_size, &init_options) } @@ -44,8 +43,7 @@ impl RerankerPool { ))); } - fs::create_dir_all(&init_options.cache_dir) - .map_err(|e| Box::new(AppError::from(e)))?; + fs::create_dir_all(&init_options.cache_dir).map_err(|e| Box::new(AppError::from(e)))?; let mut engines = Vec::with_capacity(pool_size); for x in 0..pool_size { @@ -77,10 +75,7 @@ impl RerankerPool { /// This returns a lease that can perform `rerank()`. pub async fn checkout(self: &Arc) -> Option { // Acquire a permit. This enforces backpressure. - let permit = Arc::clone(&self.semaphore) - .acquire_owned() - .await - .ok()?; + let permit = Arc::clone(&self.semaphore).acquire_owned().await.ok()?; // Pick an engine. // This is naive: just pick based on a simple modulo counter. @@ -165,9 +160,9 @@ impl RerankerLease { let engine = Arc::clone(&self.engine); tokio::task::spawn_blocking(move || { - let mut guard = engine.lock().map_err(|_| { - AppError::InternalError("reranker engine mutex poisoned".into()) - })?; + let mut guard = engine + .lock() + .map_err(|_| AppError::InternalError("reranker engine mutex poisoned".into()))?; guard .rerank(query, documents, false, None) .map_err(|e| AppError::InternalError(e.to_string())) diff --git a/retrieval-pipeline/src/scoring.rs b/retrieval-pipeline/src/scoring.rs index 849d5ed..0327c93 100644 --- a/retrieval-pipeline/src/scoring.rs +++ b/retrieval-pipeline/src/scoring.rs @@ -1,9 +1,35 @@ use std::{ cmp::Ordering, collections::{hash_map::Entry, HashMap}, + sync::Arc, }; -use common::storage::types::StoredObject; +use common::storage::types::{ + knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject, +}; + +/// Identifier access for retrieval fusion and sorting. +pub trait RetrievalCandidate { + fn candidate_id(&self) -> &str; +} + +impl RetrievalCandidate for TextChunk { + fn candidate_id(&self) -> &str { + self.id() + } +} + +impl RetrievalCandidate for Arc { + fn candidate_id(&self) -> &str { + self.as_ref().id() + } +} + +impl RetrievalCandidate for KnowledgeEntity { + fn candidate_id(&self) -> &str { + self.id() + } +} /// Holds optional subscores gathered from the vector and full-text retrieval signals. #[derive(Debug, Clone, Copy, Default)] @@ -102,13 +128,13 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec { pub fn sort_by_fused_desc(items: &mut [Scored]) where - T: StoredObject, + T: RetrievalCandidate, { items.sort_by(|a, b| { b.fused .partial_cmp(&a.fused) .unwrap_or(Ordering::Equal) - .then_with(|| a.item.id().cmp(b.item.id())) + .then_with(|| a.item.candidate_id().cmp(b.item.candidate_id())) }); } @@ -122,7 +148,7 @@ pub fn reciprocal_rank_fusion( config: RrfConfig, ) -> Vec> where - T: StoredObject, + T: RetrievalCandidate, { let mut merged: HashMap> = HashMap::new(); let k = if config.k <= 0.0 { 60.0 } else { config.k }; @@ -144,11 +170,11 @@ where b_score .partial_cmp(&a_score) .unwrap_or(Ordering::Equal) - .then_with(|| a.item.id().cmp(b.item.id())) + .then_with(|| a.item.candidate_id().cmp(b.item.candidate_id())) }); for (rank, candidate) in vector_ranked.into_iter().enumerate() { - let id = candidate.item.id().to_owned(); + let id = candidate.item.candidate_id().to_owned(); let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); let contribution = vector_weight / (k + rank_f32 + 1.0); @@ -183,11 +209,11 @@ where b_score .partial_cmp(&a_score) .unwrap_or(Ordering::Equal) - .then_with(|| a.item.id().cmp(b.item.id())) + .then_with(|| a.item.candidate_id().cmp(b.item.candidate_id())) }); for (rank, candidate) in fts_ranked.into_iter().enumerate() { - let id = candidate.item.id().to_owned(); + let id = candidate.item.candidate_id().to_owned(); let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); let contribution = fts_weight / (k + rank_f32 + 1.0);