From 15c9f18f6ed3c0594b75a1a1acb92a67eccf5dc5 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Wed, 3 Jun 2026 22:10:33 +0200 Subject: [PATCH] feat: pool fastembed, batch embeddings, and reconcile embedding config on startup --- common/src/storage/indexes.rs | 49 ++++ common/src/storage/types/knowledge_entity.rs | 188 ++++--------- common/src/storage/types/system_settings.rs | 10 +- common/src/storage/types/text_chunk.rs | 159 +++-------- common/src/utils/config.rs | 3 + common/src/utils/embedding.rs | 252 +++++++++--------- evaluations/src/pipeline/stages/prepare_db.rs | 11 +- .../src/pipeline/stages/run_queries.rs | 5 +- html-router/src/routes/admin/handlers.rs | 50 +--- .../routes/chat/message_response_stream.rs | 3 +- html-router/src/routes/knowledge/handlers.rs | 12 +- html-router/src/routes/search/handlers.rs | 3 +- html-router/tests/router_integration.rs | 10 +- .../src/pipeline/enrichment_result.rs | 28 +- ingestion-pipeline/src/pipeline/services.rs | 38 ++- main/src/bootstrap/mod.rs | 2 +- main/src/bootstrap/startup.rs | 231 ++++++++++++++-- main/src/main.rs | 7 +- main/src/server.rs | 4 +- main/src/worker.rs | 5 +- retrieval-pipeline/src/lib.rs | 25 +- retrieval-pipeline/src/pipeline/context.rs | 5 +- retrieval-pipeline/src/pipeline/mod.rs | 4 +- retrieval-pipeline/src/pipeline/stages.rs | 7 +- 24 files changed, 565 insertions(+), 546 deletions(-) diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index 9ed2ba8..35de433 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -211,6 +211,27 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> { rebuild_inner(db).await.map_err(AppError::internal) } +/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any. +/// +/// Stored embeddings always share this index's dimension because re-embedding rewrites the +/// vectors and the index together, so it acts as a persisted marker of the embedding space +/// actually present in the database. Returns `Ok(None)` when the index has not been created yet +/// (for example on a fresh database with no ingested data). +/// +/// # Errors +/// +/// Returns `AppError::InternalError` if the index metadata cannot be read. +pub async fn embedding_index_dimension(db: &SurrealDbClient) -> Result, AppError> { + let spec = HnswIndexSpec { + index_name: "idx_embedding_text_chunk_embedding", + table: "text_chunk_embedding", + options: HNSW_INDEX_OPTIONS, + }; + existing_hnsw_dimension(db, &spec) + .await + .map_err(AppError::internal) +} + async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> { create_fts_analyzer(db).await?; @@ -906,6 +927,34 @@ mod tests { Ok(()) } + #[tokio::test] + async fn embedding_index_dimension_reflects_runtime_state() -> anyhow::Result<()> { + let namespace = "indexes_marker"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .context("in-memory db")?; + + db.apply_migrations() + .await + .context("migrations should succeed")?; + + // Before any index exists, there is no stored embedding dimension to detect. + assert_eq!(embedding_index_dimension(&db).await?, None); + + ensure_runtime(&db, 1536) + .await + .context("initial index creation")?; + assert_eq!(embedding_index_dimension(&db).await?, Some(1536)); + + // After a dimension change the marker tracks the new index dimension. + ensure_runtime(&db, 256) + .await + .context("overwritten index creation")?; + assert_eq!(embedding_index_dimension(&db).await?, Some(256)); + Ok(()) + } + #[tokio::test] async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> { let namespace = "indexes_dim"; diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 4509b5a..965de4f 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -6,12 +6,7 @@ use crate::{ error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql, storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, storage::types::system_settings::SystemSettings, stored_object, - utils::embedding::{generate_embedding_with_params, generate_embedding_with_provider, EmbeddingProvider}, -}; -use async_openai::{config::OpenAIConfig, Client}; -use tokio_retry::{ - strategy::{jitter, ExponentialBackoff}, - Retry, + utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE}, }; use tracing::{error, info}; use uuid::Uuid; @@ -321,8 +316,7 @@ impl KnowledgeEntity { ) -> Result<(), AppError> { let embedding_input = format!("name: {name}, description: {description}, type: {entity_type:?}",); - let embedding = - generate_embedding_with_provider(embedding_provider, &embedding_input).await?; + let embedding = embedding_provider.embed(&embedding_input).await?; let entity: KnowledgeEntity = db_client .get_item(id) @@ -368,120 +362,17 @@ impl KnowledgeEntity { Ok(()) } - /// Re-creates embeddings for all knowledge entities in the database. + /// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`. /// /// This is a costly operation that should be run in the background. It follows the same /// pattern as the text chunk update: /// 1. Re-defines the vector index with the new dimensions. /// 2. Fetches all existing entities. /// 3. Sequentially regenerates the embedding for each and updates the record. + #[allow(clippy::too_many_lines)] pub async fn update_all_embeddings( db: &SurrealDbClient, - openai_client: &Client, - new_model: &str, - new_dimensions: u32, - ) -> Result<(), AppError> { - info!( - "Starting re-embedding process for all knowledge entities. New dimensions: {}", - new_dimensions - ); - - // Fetch all entities first - let all_entities: Vec = db.select(Self::table_name()).await?; - let total_entities = all_entities.len(); - if total_entities == 0 { - info!("No knowledge entities to update. Just updating the idx"); - - KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?; - return Ok(()); - } - info!("Found {total_entities} entities to process."); - - // Generate all new embeddings in memory - let mut new_embeddings: HashMap, String, String)> = HashMap::new(); - info!("Generating new embeddings for all entities..."); - for entity in &all_entities { - let embedding_input = format!( - "name: {}, description: {}, type: {:?}", - entity.name, entity.description, entity.entity_type - ); - let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); - - let embedding = Retry::spawn(retry_strategy, || { - generate_embedding_with_params( - openai_client, - &embedding_input, - new_model, - new_dimensions, - ) - }) - .await?; - - // Check embedding lengths - if embedding.len() != new_dimensions as usize { - let err_msg = format!( - "CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.", - entity.id, embedding.len(), new_dimensions - ); - error!("{err_msg}"); - return Err(AppError::internal(err_msg)); - } - new_embeddings.insert( - entity.id.clone(), - (embedding, entity.user_id.clone(), entity.source_id.clone()), - ); - } - info!("Successfully generated all new embeddings."); - - // Perform DB updates in a single transaction - info!("Applying embedding updates in a transaction..."); - let mut transaction_query = String::from("BEGIN TRANSACTION;"); - - // Add all update statements to the embedding table - for (id, (embedding, user_id, source_id)) in new_embeddings { - let embedding = serde_json::to_string(&embedding) - .map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?; - write!( - transaction_query, - "UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \ - entity_id = type::thing('knowledge_entity', '{id}'), \ - embedding = {embedding}, \ - user_id = '{user_id}', \ - source_id = '{source_id}', \ - created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ - updated_at = time::now();", - ) - .map_err(AppError::internal)?; - } - - write!( - transaction_query, - "{}", - hnsw_index_overwrite_sql( - "idx_embedding_knowledge_entity_embedding", - KnowledgeEntityEmbedding::table_name(), - new_dimensions as usize, - ) - ) - .map_err(AppError::internal)?; - - transaction_query.push_str("COMMIT TRANSACTION;"); - - // Execute the entire atomic operation - db.query(transaction_query).await?; - - info!("Re-embedding process for knowledge entities completed successfully."); - Ok(()) - } - - /// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`. - /// - /// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.) - /// instead of directly calling OpenAI. Used during startup when embedding configuration changes. - #[allow(clippy::too_many_lines)] - pub async fn update_all_embeddings_with_provider( - db: &SurrealDbClient, - provider: &crate::utils::embedding::EmbeddingProvider, + provider: &EmbeddingProvider, ) -> Result<(), AppError> { let new_dimensions = provider.dimension(); info!( @@ -500,38 +391,53 @@ impl KnowledgeEntity { } info!(entities = total_entities, "Found entities to process"); - // Generate all new embeddings in memory - let mut new_embeddings: HashMap, String, String)> = HashMap::new(); + // Generate all new embeddings in memory, batching to amortise lock/dispatch overhead + // while keeping memory bounded and preserving progress logging. + let mut new_embeddings: HashMap, String, String)> = + HashMap::with_capacity(total_entities); info!("Generating new embeddings for all entities..."); - for (i, entity) in all_entities.iter().enumerate() { - if i > 0 && i % 100 == 0 { - info!( - progress = i, - total = total_entities, - "Re-embedding progress" + let mut processed = 0usize; + for batch in all_entities.chunks(RE_EMBED_BATCH_SIZE) { + let inputs: Vec = batch + .iter() + .map(|entity| { + format!( + "name: {}, description: {}, type: {:?}", + entity.name, entity.description, entity.entity_type + ) + }) + .collect(); + let embeddings = provider.embed_batch(inputs).await?; + if embeddings.len() != batch.len() { + return Err(AppError::internal(format!( + "embedding batch returned {} vectors for {} entities", + embeddings.len(), + batch.len() + ))); + } + + for (entity, embedding) in batch.iter().zip(embeddings) { + // Safety check: ensure the generated embedding has the correct dimension. + if embedding.len() != new_dimensions { + let err_msg = format!( + "CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.", + entity.id, embedding.len(), new_dimensions + ); + error!("{err_msg}"); + return Err(AppError::internal(err_msg)); + } + new_embeddings.insert( + entity.id.clone(), + (embedding, entity.user_id.clone(), entity.source_id.clone()), ); } - let embedding_input = format!( - "name: {}, description: {}, type: {:?}", - entity.name, entity.description, entity.entity_type - ); - - let embedding = provider.embed(&embedding_input).await?; - - // Safety check: ensure the generated embedding has the correct dimension. - if embedding.len() != new_dimensions { - let err_msg = format!( - "CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.", - entity.id, embedding.len(), new_dimensions - ); - error!("{err_msg}"); - return Err(AppError::internal(err_msg)); - } - new_embeddings.insert( - entity.id.clone(), - (embedding, entity.user_id.clone(), entity.source_id.clone()), + processed = processed.saturating_add(batch.len()); + info!( + progress = processed, + total = total_entities, + "Re-embedding progress" ); } info!("Successfully generated all new embeddings."); diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index 17347d6..7bedfb1 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -235,7 +235,6 @@ mod tests { use crate::storage::indexes::ensure_runtime; use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}; use anyhow::{self, Context}; - use async_openai::Client; use super::*; use uuid::Uuid; @@ -713,6 +712,8 @@ mod tests { #[tokio::test] async fn test_should_change_embedding_length_on_indexes_when_switching_length( ) -> anyhow::Result<()> { + use crate::utils::embedding::EmbeddingProvider; + let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string()) .await .with_context(|| "Failed to start DB".to_string())?; @@ -758,12 +759,13 @@ mod tests { "Settings should reflect the new embedding dimension" ); - let openai_client = Client::new(); + let provider = EmbeddingProvider::new_hashed(new_dimension as usize) + .map_err(|e| anyhow::anyhow!("{e}"))?; - TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) + TextChunk::update_all_embeddings(&db, &provider) .await .with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?; - KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension) + KnowledgeEntity::update_all_embeddings(&db, &provider) .await .with_context(|| { "KnowledgeEntity re-embedding should succeed on fresh DB".to_string() diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index a0c573e..6d97efc 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -5,12 +5,8 @@ use std::fmt::Write; use crate::storage::indexes::hnsw_index_overwrite_sql; use crate::storage::types::system_settings::SystemSettings; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; +use crate::utils::embedding::RE_EMBED_BATCH_SIZE; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; -use async_openai::{config::OpenAIConfig, Client}; -use tokio_retry::{ - strategy::{jitter, ExponentialBackoff}, - Retry, -}; use tracing::{error, info, warn}; use uuid::Uuid; @@ -238,7 +234,7 @@ impl TextChunk { .collect()) } - /// Re-creates embeddings for all text chunks using a safe, atomic transaction. + /// Re-creates embeddings for all text chunks using an `EmbeddingProvider`. /// /// This is a costly operation that should be run in the background. It performs these steps: /// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory. @@ -246,109 +242,8 @@ impl TextChunk { /// has the wrong dimension, the entire operation is aborted before any DB changes are made. /// 3. **Executes Atomic Transaction**: All data updates and the index recreation are /// performed in a single, all-or-nothing database transaction. + #[allow(clippy::too_many_lines)] pub async fn update_all_embeddings( - db: &SurrealDbClient, - openai_client: &Client, - new_model: &str, - new_dimensions: u32, - ) -> Result<(), AppError> { - info!( - "Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}" - ); - - // Fetch all chunks first - let all_chunks: Vec = db.select(Self::table_name()).await?; - let total_chunks = all_chunks.len(); - if total_chunks == 0 { - info!("No text chunks to update. Just updating the idx"); - - TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?; - - return Ok(()); - } - info!("Found {total_chunks} chunks to process."); - - // Generate all new embeddings in memory - let mut new_embeddings: HashMap, String, String)> = HashMap::new(); - info!("Generating new embeddings for all chunks..."); - for chunk in &all_chunks { - let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); - - let embedding = Retry::spawn(retry_strategy, || { - crate::utils::embedding::generate_embedding_with_params( - openai_client, - &chunk.chunk, - new_model, - new_dimensions, - ) - }) - .await?; - - // Safety check: ensure the generated embedding has the correct dimension. - if embedding.len() != new_dimensions as usize { - let err_msg = format!( - "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.", - chunk.id, embedding.len(), new_dimensions - ); - error!("{err_msg}"); - return Err(AppError::internal(err_msg)); - } - new_embeddings.insert( - chunk.id.clone(), - (embedding, chunk.user_id.clone(), chunk.source_id.clone()), - ); - } - info!("Successfully generated all new embeddings."); - - // Perform DB updates in a single transaction against the embedding table - info!("Applying embedding updates in a transaction..."); - let mut transaction_query = String::from("BEGIN TRANSACTION;"); - - for (id, (embedding, user_id, source_id)) in new_embeddings { - let embedding = serde_json::to_string(&embedding) - .map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?; - let id = surql_json_string(&id)?; - let user_id = surql_json_string(&user_id)?; - let source_id = surql_json_string(&source_id)?; - write!( - &mut transaction_query, - "UPSERT type::thing('{emb_table}', {id}) SET \ - chunk_id = type::thing('{chunk_table}', {id}), \ - source_id = {source_id}, \ - embedding = {embedding}, \ - user_id = {user_id}, \ - created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ - updated_at = time::now();", - emb_table = TextChunkEmbedding::table_name(), - chunk_table = Self::table_name(), - ) - .map_err(AppError::internal)?; - } - - write!( - &mut transaction_query, - "{}", - hnsw_index_overwrite_sql( - "idx_embedding_text_chunk_embedding", - TextChunkEmbedding::table_name(), - new_dimensions as usize, - ) - ) - .map_err(AppError::internal)?; - - transaction_query.push_str("COMMIT TRANSACTION;"); - - db.query(transaction_query).await?; - - info!("Re-embedding process for text chunks completed successfully."); - Ok(()) - } - - /// Re-creates embeddings for all text chunks using an `EmbeddingProvider`. - /// - /// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.) - /// instead of directly calling OpenAI. Used during startup when embedding configuration changes. - pub async fn update_all_embeddings_with_provider( db: &SurrealDbClient, provider: &crate::utils::embedding::EmbeddingProvider, ) -> Result<(), AppError> { @@ -369,30 +264,42 @@ impl TextChunk { } info!(chunks = total_chunks, "Found chunks to process"); - // Generate all new embeddings in memory - let mut new_embeddings: HashMap, String, String)> = HashMap::new(); + // Generate all new embeddings in memory, batching to amortise lock/dispatch overhead + // while keeping memory bounded and preserving progress logging. + let mut new_embeddings: HashMap, String, String)> = + HashMap::with_capacity(total_chunks); info!("Generating new embeddings for all chunks..."); - for (i, chunk) in all_chunks.iter().enumerate() { - if i > 0 && i % 100 == 0 { - info!(progress = i, total = total_chunks, "Re-embedding progress"); + let mut processed = 0usize; + for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) { + let inputs: Vec = batch.iter().map(|chunk| chunk.chunk.clone()).collect(); + let embeddings = provider.embed_batch(inputs).await?; + if embeddings.len() != batch.len() { + return Err(AppError::internal(format!( + "embedding batch returned {} vectors for {} chunks", + embeddings.len(), + batch.len() + ))); } - let embedding = provider.embed(&chunk.chunk).await?; - - // Safety check: ensure the generated embedding has the correct dimension. - if embedding.len() != new_dimensions { - let err_msg = format!( - "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.", - chunk.id, embedding.len(), new_dimensions + for (chunk, embedding) in batch.iter().zip(embeddings) { + // Safety check: ensure the generated embedding has the correct dimension. + if embedding.len() != new_dimensions { + let err_msg = format!( + "CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.", + chunk.id, embedding.len(), new_dimensions + ); + error!("{err_msg}"); + return Err(AppError::internal(err_msg)); + } + new_embeddings.insert( + chunk.id.clone(), + (embedding, chunk.user_id.clone(), chunk.source_id.clone()), ); - error!("{err_msg}"); - return Err(AppError::internal(err_msg)); } - new_embeddings.insert( - chunk.id.clone(), - (embedding, chunk.user_id.clone(), chunk.source_id.clone()), - ); + + processed = processed.saturating_add(batch.len()); + info!(progress = processed, total = total_chunks, "Re-embedding progress"); } info!("Successfully generated all new embeddings."); diff --git a/common/src/utils/config.rs b/common/src/utils/config.rs index ed306f4..0e07eae 100644 --- a/common/src/utils/config.rs +++ b/common/src/utils/config.rs @@ -119,6 +119,8 @@ pub struct AppConfig { pub fastembed_max_length: Option, #[serde(default)] pub embedding_backend: EmbeddingBackend, + #[serde(default)] + pub embedding_pool_size: Option, #[serde(default = "default_ingest_max_body_bytes")] pub ingest_max_body_bytes: usize, #[serde(default = "default_ingest_max_files")] @@ -225,6 +227,7 @@ impl Default for AppConfig { fastembed_show_download_progress: None, fastembed_max_length: None, embedding_backend: EmbeddingBackend::default(), + embedding_pool_size: None, ingest_max_body_bytes: default_ingest_max_body_bytes(), ingest_max_files: default_ingest_max_files(), ingest_max_content_bytes: default_ingest_max_content_bytes(), diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index 5e9809e..f02f2c5 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -3,15 +3,16 @@ use std::{ hash::{Hash, Hasher}, str::FromStr, sync::{Arc, Mutex}, + thread::available_parallelism, }; use async_openai::{types::CreateEmbeddingRequestArgs, Client}; use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; -use tracing::debug; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use crate::{ - error::{AppError, EmbeddingError}, - storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, + error::EmbeddingError, + storage::types::system_settings::SystemSettings, utils::config::AppConfig, }; @@ -45,8 +46,8 @@ enum EmbeddingInner { }, /// Uses `FastEmbed` running locally. FastEmbed { - /// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`). - model: Arc>, + /// Pool of `FastEmbed` engines providing bounded-concurrency local embedding. + pool: Arc, /// Model metadata used for info logging. model_name: EmbeddingModel, /// Output vector length. @@ -54,19 +55,99 @@ enum EmbeddingInner { }, } +/// Batch size used when re-embedding stored data in bulk. Bounds peak memory and preserves +/// progress logging while still amortising per-call lock/dispatch overhead. +pub const RE_EMBED_BATCH_SIZE: usize = 128; + +/// Default FastEmbed pool size. +/// +/// Kept small on purpose: the ONNX runtime already uses intra-op threads per inference, so +/// running many engines concurrently oversubscribes the CPU and each engine duplicates the +/// model weights in memory. Mirrors the reranker pool default. +#[must_use] +pub fn default_embedding_pool_size() -> usize { + available_parallelism() + .map_or(2, |value| value.get().min(2)) + .max(1) +} + +/// Pool of `FastEmbed` engines enabling bounded-concurrency local embedding. +/// +/// A single [`TextEmbedding`] embeds one batch at a time (`&mut self`), so the pool keeps +/// several instances and hands out a distinct idle engine per checkout. The semaphore bounds +/// total in-flight embeds (backpressure); the free list guarantees each active lease holds a +/// different engine — unlike a round-robin index, which can hand the same engine to two callers. +struct FastEmbedPool { + /// Idle engines; one is popped on checkout and returned on lease drop. + engines: Mutex>>>, + /// Sized to the engine count; gates concurrent checkouts. + semaphore: Arc, +} + +impl FastEmbedPool { + fn new(engines: Vec>>) -> Self { + let permits = engines.len().max(1); + Self { + engines: Mutex::new(engines), + semaphore: Arc::new(Semaphore::new(permits)), + } + } + + /// Acquire a permit and borrow a distinct idle engine. The permit guarantees an engine is + /// available, so the pop always succeeds for a correctly sized pool. + async fn checkout(self: &Arc) -> Result { + let permit = Arc::clone(&self.semaphore) + .acquire_owned() + .await + .map_err(|_| EmbeddingError::Config("embedding pool is closed".into()))?; + let engine = self + .engines + .lock() + .map_err(EmbeddingError::mutex_poisoned)? + .pop() + .ok_or_else(|| EmbeddingError::Config("embedding pool unexpectedly empty".into()))?; + Ok(FastEmbedLease { + pool: Arc::clone(self), + engine, + _permit: permit, + }) + } +} + +/// Active borrow of a single `FastEmbed` engine; returns it to the pool on drop. +struct FastEmbedLease { + pool: Arc, + engine: Arc>, + /// Released after the engine is returned, unblocking the next checkout. + _permit: OwnedSemaphorePermit, +} + +impl FastEmbedLease { + async fn embed(&self, texts: Vec) -> Result>, EmbeddingError> { + let engine = Arc::clone(&self.engine); + tokio::task::spawn_blocking(move || -> Result>, EmbeddingError> { + let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?; + guard.embed(texts, None).map_err(EmbeddingError::fastembed) + }) + .await + .map_err(EmbeddingError::from)? + } +} + +impl Drop for FastEmbedLease { + fn drop(&mut self) { + if let Ok(mut free) = self.pool.engines.lock() { + free.push(Arc::clone(&self.engine)); + } + } +} + async fn run_fastembed( - model: Arc>, + pool: &Arc, texts: Vec, ) -> Result>, EmbeddingError> { - match tokio::task::spawn_blocking(move || -> Result>, EmbeddingError> { - let mut guard = model.lock().map_err(EmbeddingError::mutex_poisoned)?; - guard.embed(texts, None).map_err(EmbeddingError::fastembed) - }) - .await - { - Ok(result) => result, - Err(join_error) => Err(EmbeddingError::from(join_error)), - } + let lease = pool.checkout().await?; + lease.embed(texts).await } impl EmbeddingProvider { @@ -107,8 +188,8 @@ impl EmbeddingProvider { pub async fn embed(&self, text: &str) -> Result, EmbeddingError> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), - EmbeddingInner::FastEmbed { model, .. } => { - let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?; + EmbeddingInner::FastEmbed { pool, .. } => { + let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?; embeddings.into_iter().next().ok_or(EmbeddingError::NoData) } EmbeddingInner::OpenAI { @@ -148,11 +229,11 @@ impl EmbeddingProvider { .into_iter() .map(|text| hashed_embedding(&text, *dimension)) .collect()), - EmbeddingInner::FastEmbed { model, .. } => { + EmbeddingInner::FastEmbed { pool, .. } => { if texts.is_empty() { return Ok(Vec::new()); } - run_fastembed(Arc::clone(model), texts).await + run_fastembed(pool, texts).await } EmbeddingInner::OpenAI { client, @@ -199,30 +280,46 @@ impl EmbeddingProvider { }) } + /// Initialise a local FastEmbed provider backed by a pool of `pool_size` engines. + /// + /// `pool_size` is clamped to at least 1. Larger pools allow concurrent embeds at the cost of + /// `pool_size`× model memory; see [`default_embedding_pool_size`] for guidance. + /// /// # Errors /// /// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails. - pub async fn new_fastembed(model_override: Option) -> Result { + pub async fn new_fastembed( + model_override: Option, + pool_size: usize, + ) -> Result { + let pool_size = pool_size.max(1); let model_name = if let Some(code) = model_override { EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)? } else { EmbeddingModel::default() }; - let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true); let model_name_for_task = model_name.clone(); let model_name_code = model_name.to_string(); - let (model, dimension) = + let (engines, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> { - let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?; let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| { EmbeddingError::Config(format!( "fastembed model metadata missing for {model_name_code}" )) })?; - Ok((model, info.dim)) + let mut engines = Vec::with_capacity(pool_size); + for index in 0..pool_size { + let options = TextInitOptions::new(model_name_for_task.clone()) + // Only the first engine reports download progress; the rest reuse the cache. + .with_show_download_progress(index == 0); + let model = + TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?; + engines.push(Arc::new(Mutex::new(model))); + } + Ok((engines, info.dim)) }) .await { @@ -232,7 +329,7 @@ impl EmbeddingProvider { Ok(EmbeddingProvider { inner: EmbeddingInner::FastEmbed { - model: Arc::new(Mutex::new(model)), + pool: Arc::new(FastEmbedPool::new(engines)), model_name, dimension, }, @@ -275,7 +372,10 @@ impl EmbeddingProvider { Self::new_openai(client, settings.embedding_model.clone(), dimensions) } EmbeddingBackend::FastEmbed => { - Self::new_fastembed(Some(settings.embedding_model.clone())).await + let pool_size = config + .embedding_pool_size + .unwrap_or_else(default_embedding_pool_size); + Self::new_fastembed(Some(settings.embedding_model.clone()), pool_size).await } EmbeddingBackend::Hashed => { let dimension = usize::try_from(dimensions).map_err(|_| { @@ -329,106 +429,6 @@ fn bucket(token: &str, dimension: usize) -> usize { usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension } -/// Generate an embedding using the given provider. -/// -/// # Errors -/// -/// Returns [`AppError::Embedding`] if the provider's embed call fails. -pub async fn generate_embedding_with_provider( - provider: &EmbeddingProvider, - input: &str, -) -> Result, AppError> { - provider.embed(input).await.map_err(Into::into) -} - -/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model. -/// -/// This function takes a text input and converts it into a numerical vector representation (embedding) -/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity -/// comparisons, vector search, and other natural language processing tasks. -/// -/// # Arguments -/// -/// * `client`: The `OpenAI` client instance used to make API requests. -/// * `input`: The text string to generate embeddings for. -/// -/// # Returns -/// -/// Returns a `Result` containing either: -/// * `Ok(Vec)`: A vector of 32-bit floating point numbers representing the text embedding -/// * `Err(ProcessingError)`: An error if the embedding generation fails -/// -/// # Errors -/// -/// This function can return a `AppError` in the following cases: -/// * If the `OpenAI` API request fails -/// * If the request building fails -/// * If no embedding data is received in the response -#[allow(clippy::module_name_repetitions)] -pub async fn generate_embedding( - client: &async_openai::Client, - input: &str, - db: &SurrealDbClient, -) -> Result, AppError> { - let model = SystemSettings::get_current(db).await?; - - let request = CreateEmbeddingRequestArgs::default() - .model(model.embedding_model) - .dimensions(model.embedding_dimensions) - .input([input]) - .build()?; - - // Send the request to OpenAI - let response = client.embeddings().create(request).await?; - - // Extract the embedding vector - let embedding: Vec = response - .data - .first() - .ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))? - .embedding - .clone(); - - Ok(embedding) -} - -/// Generates an embedding vector using a specific model and dimension. -/// -/// This is used for the re-embedding process where the model and dimensions -/// are known ahead of time and shouldn't be repeatedly fetched from settings. -/// -/// # Errors -/// -/// Returns `AppError` if the OpenAI API request fails or returns no embedding data. -pub async fn generate_embedding_with_params( - client: &async_openai::Client, - input: &str, - model: &str, - dimensions: u32, -) -> Result, AppError> { - let request = CreateEmbeddingRequestArgs::default() - .model(model) - .input([input]) - .dimensions(dimensions) - .build()?; - - let response = client.embeddings().create(request).await?; - - let embedding = response - .data - .first() - .ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))? - .embedding - .clone(); - - debug!( - "Embedding was created with {:?} dimensions", - embedding.len() - ); - - Ok(embedding) -} - #[cfg(test)] mod tests { #![allow(clippy::expect_used)] diff --git a/evaluations/src/pipeline/stages/prepare_db.rs b/evaluations/src/pipeline/stages/prepare_db.rs index 89990b3..b18e3be 100644 --- a/evaluations/src/pipeline/stages/prepare_db.rs +++ b/evaluations/src/pipeline/stages/prepare_db.rs @@ -11,7 +11,7 @@ use crate::{ }, openai, }; -use common::utils::embedding::EmbeddingProvider; +use common::utils::embedding::{default_embedding_pool_size, EmbeddingProvider}; use super::super::{ context::{EvalStage, EvaluationContext}, @@ -43,9 +43,12 @@ pub(crate) async fn prepare_db( // Create embedding provider directly from config (eval only supports FastEmbed and Hashed) let embedding_provider = match config.embedding_backend { crate::args::EmbeddingBackend::FastEmbed => { - EmbeddingProvider::new_fastembed(config.embedding_model.clone()) - .await - .context("creating FastEmbed provider")? + EmbeddingProvider::new_fastembed( + config.embedding_model.clone(), + default_embedding_pool_size(), + ) + .await + .context("creating FastEmbed provider")? } crate::args::EmbeddingBackend::Hashed => { EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")? diff --git a/evaluations/src/pipeline/stages/run_queries.rs b/evaluations/src/pipeline/stages/run_queries.rs index cfc32c5..c8683f5 100644 --- a/evaluations/src/pipeline/stages/run_queries.rs +++ b/evaluations/src/pipeline/stages/run_queries.rs @@ -136,12 +136,10 @@ pub(crate) async fn run_queries( let embedding_provider_for_queries = ctx.embedding_provider()?.clone(); let rerank_pool_for_queries = rerank_pool.clone(); let db = ctx.db()?.clone(); - let openai_client = ctx.openai_client()?; let raw_results = stream::iter(cases_iter) .map(move |(idx, case)| { let db = db.clone(); - let openai_client = Arc::clone(&openai_client); let user_id = user_id.clone(); let retrieval_config = Arc::clone(&retrieval_config); let embedding_provider = embedding_provider_for_queries.clone(); @@ -180,8 +178,7 @@ pub(crate) async fn run_queries( let params = pipeline::RetrievalParams { db_client: &db, - openai_client: &openai_client, - embedding_provider: Some(&embedding_provider), + embedding_provider: &embedding_provider, input_text: &question, user_id: &user_id, config: (*retrieval_config).clone(), diff --git a/html-router/src/routes/admin/handlers.rs b/html-router/src/routes/admin/handlers.rs index b0657ad..2fda03f 100644 --- a/html-router/src/routes/admin/handlers.rs +++ b/html-router/src/routes/admin/handlers.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use async_openai::types::ListModelResponse; use axum::{ extract::{Query, State}, @@ -11,17 +9,15 @@ use common::{ error::AppError, storage::types::{ analytics::Analytics, - knowledge_entity::KnowledgeEntity, system_prompts::{ DEFAULT_IMAGE_PROCESSING_PROMPT, DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT, DEFAULT_QUERY_SYSTEM_PROMPT, }, system_settings::{SystemSettings, SystemSettingsPatch}, - text_chunk::TextChunk, }, utils::embedding::EmbeddingBackend, }; -use tracing::{error, info}; +use tracing::info; use crate::{ html_state::HtmlState, @@ -209,41 +205,15 @@ pub async fn update_model_settings( .await?; if reembedding_needed { - info!("Embedding dimensions changed. Spawning background re-embedding task..."); - - let db_for_task = Arc::clone(&state.db); - let openai_for_task = Arc::clone(&state.openai_client); - let new_model_for_task = new_settings.embedding_model.clone(); - let new_dims_for_task = new_settings.embedding_dimensions; - - tokio::spawn(async move { - // First, update all text chunks - if let Err(e) = TextChunk::update_all_embeddings( - &db_for_task, - &openai_for_task, - &new_model_for_task, - new_dims_for_task, - ) - .await - { - error!("Background re-embedding task failed for TextChunks: {}", e); - } - - // Second, update all knowledge entities - if let Err(e) = KnowledgeEntity::update_all_embeddings( - &db_for_task, - &openai_for_task, - &new_model_for_task, - new_dims_for_task, - ) - .await - { - error!( - "Background re-embedding task failed for KnowledgeEntities: {}", - e - ); - } - }); + // Re-embedding is owned by startup (the worker/combined binary), not the admin request. + // Doing it inline here would leave the live, startup-built embedding provider embedding + // queries at the old dimension while stored vectors move to the new one — broken retrieval + // until restart. Persisting the new settings is enough: on the next restart the maintainer + // detects the index/dimension mismatch and re-embeds before rebuilding indexes. + info!( + new_dimensions = new_settings.embedding_dimensions, + "Embedding dimensions changed; restart the worker/server to re-embed and apply" + ); } let available_models = state diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index 44cd413..b00ca76 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -359,8 +359,7 @@ async fn prepare_chat_request( let retrieval_result = match retrieval_pipeline::retrieve( &state.db, - &state.openai_client, - Some(&*state.embedding_provider), + &state.embedding_provider, &user_message.content, &user.id, config, diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 0a0578a..2af580e 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -24,7 +24,7 @@ use common::{ user::User, }, }, - utils::embedding::{generate_embedding_with_provider, EmbeddingProvider}, + utils::embedding::EmbeddingProvider, }; use retrieval_pipeline::{ normalize_fts_terms, reciprocal_rank_fusion, RetrievalTuning, RrfConfig, Scored, @@ -187,8 +187,11 @@ pub async fn create_knowledge_entity( let embedding_input = format!("name: {name}, description: {description}, type: {entity_type:?}"); - let embedding = - generate_embedding_with_provider(&state.embedding_provider, &embedding_input).await?; + let embedding = state + .embedding_provider + .embed(&embedding_input) + .await + .map_err(AppError::from)?; let source_id = format!("manual::{}", Uuid::new_v4()); let new_entity = KnowledgeEntity::new( @@ -373,8 +376,7 @@ async fn suggest_related_entities( "name: {}, description: {}, type: {:?}", draft.name, draft.description, draft.entity_type ); - let embedding = - generate_embedding_with_provider(embedding_provider, &embedding_input).await?; + let embedding = embedding_provider.embed(&embedding_input).await?; let take = MAX_RELATIONSHIP_SUGGESTIONS * 2; let tuning = RetrievalTuning::default(); diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs index 4fe6093..4174c18 100644 --- a/html-router/src/routes/search/handlers.rs +++ b/html-router/src/routes/search/handlers.rs @@ -171,8 +171,7 @@ async fn perform_search( let result = retrieve( &state.db, - &state.openai_client, - Some(&state.embedding_provider), + &state.embedding_provider, trimmed_query, &user.id, config, diff --git a/html-router/tests/router_integration.rs b/html-router/tests/router_integration.rs index 25df4a6..5349785 100644 --- a/html-router/tests/router_integration.rs +++ b/html-router/tests/router_integration.rs @@ -132,7 +132,7 @@ fn extract_authenticated_main(html: &str) -> &str { let start = html .find(AUTHENTICATED_MAIN_OPEN) .expect("authenticated page main column") - + AUTHENTICATED_MAIN_OPEN.len(); + .saturating_add(AUTHENTICATED_MAIN_OPEN.len()); let rest = &html[start..]; let end = rest .find("") @@ -184,8 +184,12 @@ async fn create_scratchpad_and_get_id(app: &Router, cookie: &str, title: &str) - let list = get_html(app, "/scratchpad", Some(cookie)).await; let marker = "/scratchpad/"; - let start = list.find(marker).expect("scratchpad link present") + marker.len(); - list[start..start + list[start..].find('/').expect("id terminator")].to_string() + let start = list + .find(marker) + .expect("scratchpad link present") + .saturating_add(marker.len()); + let end = start.saturating_add(list[start..].find('/').expect("id terminator")); + list[start..end].to_string() } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] diff --git a/ingestion-pipeline/src/pipeline/enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs index 0db064e..b9ee892 100644 --- a/ingestion-pipeline/src/pipeline/enrichment_result.rs +++ b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -7,13 +7,12 @@ use serde::{Deserialize, Serialize}; use common::{ error::AppError, storage::{ - db::SurrealDbClient, types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_relationship::KnowledgeRelationship, }, }, - utils::{embedding::generate_embedding, embedding::EmbeddingProvider}, + utils::embedding::EmbeddingProvider, }; use crate::pipeline::context::EmbeddedKnowledgeEntity; @@ -46,10 +45,8 @@ impl LLMEnrichmentResult { &self, source_id: &str, user_id: &str, - openai_client: &async_openai::Client, - db_client: &SurrealDbClient, entity_concurrency: usize, - embedding_provider: Option<&EmbeddingProvider>, + embedding_provider: &EmbeddingProvider, ) -> Result<(Vec, Vec), AppError> { let mapper = Arc::new(self.create_mapper()); @@ -58,8 +55,6 @@ impl LLMEnrichmentResult { source_id, user_id, Arc::clone(&mapper), - openai_client, - db_client, entity_concurrency, embedding_provider, ) @@ -80,23 +75,18 @@ impl LLMEnrichmentResult { mapper } - #[allow(clippy::too_many_arguments)] async fn process_entities( &self, source_id: &str, user_id: &str, mapper: Arc, - openai_client: &async_openai::Client, - db_client: &SurrealDbClient, entity_concurrency: usize, - embedding_provider: Option<&EmbeddingProvider>, + embedding_provider: &EmbeddingProvider, ) -> Result, AppError> { stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| { let mapper = Arc::clone(&mapper); - let openai_client = openai_client.clone(); let source_id = source_id.to_string(); let user_id = user_id.to_string(); - let db_client = db_client.clone(); async move { create_single_entity( @@ -104,8 +94,6 @@ impl LLMEnrichmentResult { &source_id, &user_id, mapper, - &openai_client, - &db_client, embedding_provider, ) .await @@ -145,9 +133,7 @@ async fn create_single_entity( source_id: &str, user_id: &str, mapper: Arc, - openai_client: &async_openai::Client, - db_client: &SurrealDbClient, - embedding_provider: Option<&EmbeddingProvider>, + embedding_provider: &EmbeddingProvider, ) -> Result { let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); @@ -156,11 +142,7 @@ async fn create_single_entity( llm_entity.name, llm_entity.description, llm_entity.entity_type ); - let embedding = if let Some(provider) = embedding_provider { - provider.embed(&embedding_input).await? - } else { - generate_embedding(openai_client, &embedding_input, db_client).await? - }; + let embedding = embedding_provider.embed(&embedding_input).await?; let now = Utc::now(); let entity = KnowledgeEntity { diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index be752bc..365e6da 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -187,8 +187,7 @@ impl PipelineServices for DefaultPipelineServices { let config = retrieval_pipeline::RetrievalConfig::with_entities(); match retrieval_pipeline::retrieve( &self.db, - &self.openai_client, - Some(&*self.embedding_provider), + &self.embedding_provider, &input_text, &content.user_id, config, @@ -237,10 +236,8 @@ impl PipelineServices for DefaultPipelineServices { .to_database_entities( content.id(), &content.user_id, - &self.openai_client, - &self.db, entity_concurrency, - Some(&*self.embedding_provider), + &self.embedding_provider, ) .await } @@ -258,15 +255,30 @@ impl PipelineServices for DefaultPipelineServices { overlap_tokens, )?; + if chunk_candidates.is_empty() { + return Ok(Vec::new()); + } + + // Embed all chunks of this document in one batch: a single lock acquisition and one + // blocking hop, letting the backend batch the inference internally. + let embeddings = self + .embedding_provider + .embed_batch(chunk_candidates.clone()) + .await + .map_err(|e| { + AppError::InternalError(format!("FastEmbed embedding for chunks failed: {e}")) + })?; + + if embeddings.len() != chunk_candidates.len() { + return Err(AppError::InternalError(format!( + "embedding batch returned {} vectors for {} chunks", + embeddings.len(), + chunk_candidates.len() + ))); + } + let mut chunks = Vec::with_capacity(chunk_candidates.len()); - for chunk_text in chunk_candidates { - let embedding = self - .embedding_provider - .embed(&chunk_text) - .await - .map_err(|e| { - AppError::InternalError(format!("FastEmbed embedding for chunk failed: {e}")) - })?; + for (chunk_text, embedding) in chunk_candidates.into_iter().zip(embeddings) { let chunk_struct = TextChunk::new( content.id().to_string(), chunk_text, diff --git a/main/src/bootstrap/mod.rs b/main/src/bootstrap/mod.rs index 1925091..affe0c0 100644 --- a/main/src/bootstrap/mod.rs +++ b/main/src/bootstrap/mod.rs @@ -1,7 +1,7 @@ mod startup; pub mod wiring; -pub use startup::prepare_embedding_runtime; +pub use startup::{prepare_embedding_runtime, EmbeddingRuntimeRole}; use std::sync::Arc; diff --git a/main/src/bootstrap/startup.rs b/main/src/bootstrap/startup.rs index 056199a..4d4d82b 100644 --- a/main/src/bootstrap/startup.rs +++ b/main/src/bootstrap/startup.rs @@ -2,7 +2,7 @@ use anyhow::Context; use common::{ storage::{ db::SurrealDbClient, - indexes::ensure_runtime, + indexes::{embedding_index_dimension, ensure_runtime}, types::{ knowledge_entity::KnowledgeEntity, system_settings::SystemSettings, text_chunk::TextChunk, @@ -10,37 +10,129 @@ use common::{ }, utils::embedding::EmbeddingProvider, }; +use std::time::{SystemTime, UNIX_EPOCH}; use tracing::{info, warn}; use super::SharedServices; -/// Syncs embedding settings, re-embeds stored vectors when dimensions change, and -/// ensures runtime indexes match the active embedding dimension. -pub async fn prepare_embedding_runtime(services: &SharedServices) -> anyhow::Result { - let (settings, dimensions_changed) = +/// How a process participates in embedding-runtime maintenance. +/// +/// Embedding configuration changes (model/dimension) take effect on restart: the active +/// [`EmbeddingProvider`] is built once at startup, so the stored vectors must be reconciled to it +/// before indexes are rebuilt. Only a single maintainer should perform that (potentially long, +/// destructive) re-embed; query-only servers stay read-only to avoid racing it. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Each binary (main/worker/server) constructs only one variant, so the other looks dead within +// that single compilation unit even though both are used across the binary set. +#[allow(dead_code)] +pub enum EmbeddingRuntimeRole { + /// Combined binary or worker: re-embeds stored data when it no longer matches the provider. + Maintainer, + /// Server-only: never mutates stored embeddings; aligns indexes to the data that exists. + ReadOnly, +} + +/// Re-embed lock TTL. Generously sized so a slow re-embed of a large corpus never expires +/// out from under the maintainer that holds it; an abandoned lock (crashed maintainer) self-heals. +const REEMBED_LOCK_TTL: &str = "30m"; + +/// Reconciles embeddings with the active provider and ensures runtime indexes are ready. +/// +/// Detection is based on the stored chunk-embedding HNSW index dimension (a persisted marker of +/// the embedding space actually in the database). When it differs from the active provider's +/// dimension, a [`EmbeddingRuntimeRole::Maintainer`] re-embeds before indexes are (re)built; +/// a [`EmbeddingRuntimeRole::ReadOnly`] server leaves indexes aligned to the existing data and +/// serves in a degraded state until a maintainer reconciles. +/// +/// # Errors +/// +/// Returns an error if syncing settings, inspecting/building indexes, or re-embedding fails. +pub async fn prepare_embedding_runtime( + services: &SharedServices, + role: EmbeddingRuntimeRole, +) -> anyhow::Result { + // Keep SystemSettings in sync with the active provider so the admin UI reflects the real + // backend/model/dimension. This does not, by itself, decide whether a re-embed is needed. + let (settings, _changed) = SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider) .await .context("sync system settings from embedding provider")?; - if dimensions_changed { - re_embed_all( - &services.db, - &services.embedding_provider, - settings.embedding_dimensions, - ) - .await?; - } + let target_dim = services.embedding_provider.dimension(); + let stored_dim = embedding_index_dimension(&services.db) + .await + .context("inspect stored embedding index dimension")?; + let mismatch = matches!(stored_dim, Some(dim) if dim != target_dim); - ensure_runtime( - &services.db, - settings.embedding_dimensions as usize, - ) - .await - .context("ensure runtime indexes")?; + let index_dim = if mismatch { + match role { + EmbeddingRuntimeRole::Maintainer => { + reconcile_embeddings(&services.db, &services.embedding_provider, target_dim).await?; + target_dim + } + EmbeddingRuntimeRole::ReadOnly => { + warn!( + stored_dimension = stored_dim, + target_dimension = target_dim, + "Stored embeddings do not match the active embedding dimension. A maintainer \ + (worker) must re-embed; serving in a degraded state and keeping indexes \ + aligned to the existing data until then." + ); + // Preserve the index that matches the vectors actually stored. Do not overwrite it + // to the new dimension here — that would happen before the data is re-embedded and + // would break retrieval entirely. + stored_dim.unwrap_or(target_dim) + } + } + } else { + target_dim + }; + + ensure_runtime(&services.db, index_dim) + .await + .context("ensure runtime indexes")?; Ok(settings) } +/// Acquires the re-embed lock (so only one maintainer reconciles), re-embeds, then releases it. +async fn reconcile_embeddings( + db: &SurrealDbClient, + embedding_provider: &EmbeddingProvider, + target_dim: usize, +) -> anyhow::Result<()> { + let owner = reembed_lock_owner(); + + if !try_acquire_reembed_lock(db, &owner).await? { + info!("Another maintainer holds the re-embed lock; skipping re-embed on this instance"); + return Ok(()); + } + + let result = reconcile_under_lock(db, embedding_provider, target_dim).await; + release_reembed_lock(db, &owner).await; + result +} + +/// Re-embed body executed while holding the lock, with a re-check to avoid duplicate work. +async fn reconcile_under_lock( + db: &SurrealDbClient, + embedding_provider: &EmbeddingProvider, + target_dim: usize, +) -> anyhow::Result<()> { + // A peer may have finished re-embedding between detection and lock acquisition. + let stored_dim = embedding_index_dimension(db) + .await + .context("re-check stored embedding dimension under lock")?; + if !matches!(stored_dim, Some(dim) if dim != target_dim) { + info!("Stored embeddings already match the active dimension; skipping re-embed"); + return Ok(()); + } + + let target_dim_u32 = u32::try_from(target_dim) + .map_err(|_| anyhow::anyhow!("embedding dimension {target_dim} exceeds u32::MAX"))?; + re_embed_all(db, embedding_provider, target_dim_u32).await +} + async fn re_embed_all( db: &SurrealDbClient, embedding_provider: &EmbeddingProvider, @@ -52,15 +144,112 @@ async fn re_embed_all( ); info!("Re-embedding TextChunks"); - TextChunk::update_all_embeddings_with_provider(db, embedding_provider) + TextChunk::update_all_embeddings(db, embedding_provider) .await .context("re-embed text chunks after embedding dimension change")?; info!("Re-embedding KnowledgeEntities"); - KnowledgeEntity::update_all_embeddings_with_provider(db, embedding_provider) + KnowledgeEntity::update_all_embeddings(db, embedding_provider) .await .context("re-embed knowledge entities after embedding dimension change")?; info!("Re-embedding complete"); Ok(()) } + +/// A process-unique token identifying this re-embed lock acquisition (for release). +fn reembed_lock_owner() -> String { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |d| d.as_nanos()); + format!("reembed-{}-{nanos}", std::process::id()) +} + +/// Best-effort atomic mutex over the (potentially long) re-embed using a singleton record. +/// +/// `CREATE` of a fixed record id fails if it already exists, which serializes concurrent +/// maintainers. An expired lock is reaped first so a crashed maintainer cannot block forever. +async fn try_acquire_reembed_lock(db: &SurrealDbClient, owner: &str) -> anyhow::Result { + db.client + .query("DEFINE TABLE IF NOT EXISTS maintenance_lock SCHEMALESS;") + .await + .and_then(surrealdb::Response::check) + .context("define maintenance_lock table")?; + + db.client + .query("DELETE maintenance_lock:reembed WHERE expires_at < time::now();") + .await + .and_then(surrealdb::Response::check) + .context("reap expired re-embed lock")?; + + // `CREATE` of a fixed record id succeeds for the first caller and errors with an + // "already exists" record conflict for any concurrent caller, giving us an atomic mutex. + let acquired = db + .client + .query(format!( + "CREATE maintenance_lock:reembed SET owner = $owner, expires_at = time::now() + {REEMBED_LOCK_TTL};" + )) + .bind(("owner", owner.to_string())) + .await + .and_then(surrealdb::Response::check) + .is_ok(); + + Ok(acquired) +} + +async fn release_reembed_lock(db: &SurrealDbClient, owner: &str) { + let released = db + .client + .query("DELETE maintenance_lock:reembed WHERE owner = $owner;") + .bind(("owner", owner.to_string())) + .await + .and_then(surrealdb::Response::check); + + if let Err(err) = released { + warn!(error = %err, "Failed to release re-embed lock; it will expire automatically"); + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + use common::storage::db::SurrealDbClient; + + async fn test_db() -> SurrealDbClient { + SurrealDbClient::memory("reembed_lock_ns", &reembed_lock_owner()) + .await + .expect("in-memory db") + } + + #[tokio::test] + async fn reembed_lock_is_exclusive_and_reusable_after_release() { + let db = test_db().await; + + let first = reembed_lock_owner(); + assert!( + try_acquire_reembed_lock(&db, &first) + .await + .expect("acquire first"), + "the first acquirer should win the lock" + ); + + // A second, concurrent maintainer must not be able to take a held lock. + let second = format!("{first}-peer"); + assert!( + !try_acquire_reembed_lock(&db, &second) + .await + .expect("contend for lock"), + "a held lock must not be granted to another owner" + ); + + // Releasing it (only the holder can) frees it for the next maintainer. + release_reembed_lock(&db, &first).await; + assert!( + try_acquire_reembed_lock(&db, &second) + .await + .expect("re-acquire after release"), + "the lock should be grantable again once released" + ); + } +} diff --git a/main/src/main.rs b/main/src/main.rs index 34f21db..1c44955 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -6,6 +6,7 @@ use axum::extract::FromRef; use bootstrap::{ init, prepare_embedding_runtime, wiring::{build_api_state, build_html_state, minne_routes}, + EmbeddingRuntimeRole, }; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use tracing::info; @@ -20,7 +21,8 @@ async fn main() -> anyhow::Result<()> { "Embedding provider initialized" ); - prepare_embedding_runtime(&services).await?; + // The combined binary runs the worker in-process, so it owns re-embedding. + prepare_embedding_runtime(&services, EmbeddingRuntimeRole::Maintainer).await?; let html_state = build_html_state(&services).await?; let api_state = build_api_state(&services); @@ -88,6 +90,7 @@ mod tests { prepare_embedding_runtime, tests::init_smoke_services, wiring::{build_api_state, build_html_state, minne_routes}, + EmbeddingRuntimeRole, }; use common::storage::types::{system_settings::SystemSettings, user::User}; use tower::ServiceExt; @@ -97,7 +100,7 @@ mod tests { .await .expect("failed to init services"); - prepare_embedding_runtime(&services) + prepare_embedding_runtime(&services, EmbeddingRuntimeRole::Maintainer) .await .expect("failed to prepare embedding runtime"); diff --git a/main/src/server.rs b/main/src/server.rs index 1b79391..6192b86 100644 --- a/main/src/server.rs +++ b/main/src/server.rs @@ -4,13 +4,15 @@ use axum::extract::FromRef; use bootstrap::{ init, prepare_embedding_runtime, wiring::{build_api_state, build_html_state, minne_routes}, + EmbeddingRuntimeRole, }; use tracing::info; #[tokio::main(flavor = "multi_thread", worker_threads = 2)] async fn main() -> anyhow::Result<()> { let services = init().await?; - prepare_embedding_runtime(&services).await?; + // The server never re-embeds; the worker owns that. It only ensures indexes are ready. + prepare_embedding_runtime(&services, EmbeddingRuntimeRole::ReadOnly).await?; let html_state = build_html_state(&services).await?; let api_state = build_api_state(&services); diff --git a/main/src/worker.rs b/main/src/worker.rs index 52eb8a5..4afc499 100644 --- a/main/src/worker.rs +++ b/main/src/worker.rs @@ -2,14 +2,15 @@ mod bootstrap; use std::sync::Arc; -use bootstrap::{init, prepare_embedding_runtime}; +use bootstrap::{init, prepare_embedding_runtime, EmbeddingRuntimeRole}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use tracing::info; #[tokio::main] async fn main() -> anyhow::Result<()> { let services = init().await?; - prepare_embedding_runtime(&services).await?; + // The worker owns re-embedding: it reconciles stored vectors to the active provider. + prepare_embedding_runtime(&services, EmbeddingRuntimeRole::Maintainer).await?; info!( embedding_backend = ?services.config.embedding_backend, diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 3bc9068..009b7a1 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -59,8 +59,7 @@ pub struct RetrievedEntity { #[instrument(skip_all, fields(user_id))] pub async fn retrieve( db_client: &SurrealDbClient, - openai_client: &async_openai::Client, - embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, + embedding_provider: &common::utils::embedding::EmbeddingProvider, input_text: &str, user_id: &str, config: RetrievalConfig, @@ -68,7 +67,6 @@ pub async fn retrieve( ) -> Result { let params = pipeline::RetrievalParams { db_client, - openai_client, embedding_provider, input_text, user_id, @@ -82,12 +80,16 @@ pub async fn retrieve( mod tests { use super::*; use anyhow::{self}; - use async_openai::Client; use common::storage::indexes::ensure_runtime; use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use common::storage::types::system_settings::SystemSettings; + use common::utils::embedding::EmbeddingProvider; use uuid::Uuid; + fn test_embedding_provider() -> EmbeddingProvider { + EmbeddingProvider::new_hashed(3).unwrap_or_else(|_| unreachable!()) + } + fn test_embedding() -> Vec { vec![0.9, 0.1, 0.0] } @@ -135,11 +137,10 @@ mod tests { TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?; - let openai_client = Client::new(); + let embedding_provider = test_embedding_provider(); let params = pipeline::RetrievalParams { db_client: &db, - openai_client: &openai_client, - embedding_provider: None, + embedding_provider: &embedding_provider, input_text: "Rust concurrency async tasks", user_id, config: RetrievalConfig::default(), @@ -181,11 +182,10 @@ mod tests { TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?; TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?; - let openai_client = Client::new(); + let embedding_provider = test_embedding_provider(); let params = pipeline::RetrievalParams { db_client: &db, - openai_client: &openai_client, - embedding_provider: None, + embedding_provider: &embedding_provider, input_text: "Rust concurrency async tasks", user_id, config: RetrievalConfig::default(), @@ -236,11 +236,10 @@ mod tests { ); db.store_item(entity).await?; - let openai_client = Client::new(); + let embedding_provider = test_embedding_provider(); let params = pipeline::RetrievalParams { db_client: &db, - openai_client: &openai_client, - embedding_provider: None, + embedding_provider: &embedding_provider, input_text: "async rust programming", user_id, config: RetrievalConfig::with_entities(), diff --git a/retrieval-pipeline/src/pipeline/context.rs b/retrieval-pipeline/src/pipeline/context.rs index 3111d5e..94bee5d 100644 --- a/retrieval-pipeline/src/pipeline/context.rs +++ b/retrieval-pipeline/src/pipeline/context.rs @@ -1,4 +1,3 @@ -use async_openai::Client; use common::{ error::AppError, storage::{db::SurrealDbClient, types::text_chunk::TextChunk}, @@ -18,8 +17,7 @@ use super::{ /// Mutable working state threaded through every retrieval stage. pub(crate) struct PipelineContext<'a> { pub db_client: &'a SurrealDbClient, - pub openai_client: &'a Client, - pub embedding_provider: Option<&'a EmbeddingProvider>, + pub embedding_provider: &'a EmbeddingProvider, pub input_text: String, pub user_id: String, pub config: RetrievalConfig, @@ -36,7 +34,6 @@ impl<'a> PipelineContext<'a> { pub fn new(params: RetrievalParams<'a>) -> Self { Self { db_client: params.db_client, - openai_client: params.openai_client, embedding_provider: params.embedding_provider, input_text: params.input_text.to_owned(), user_id: params.user_id.to_owned(), diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index 1761099..b833ae1 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -7,7 +7,6 @@ pub use config::{RetrievalConfig, RetrievalTuning}; pub use diagnostics::Diagnostics; use crate::{round_score, RetrievalOutput, RetrievedEntity}; -use async_openai::Client; use async_trait::async_trait; use common::{error::AppError, storage::db::SurrealDbClient}; use std::time::{Duration, Instant}; @@ -91,8 +90,7 @@ pub struct RunOutput { /// Inputs required to run a retrieval. pub struct RetrievalParams<'a> { pub db_client: &'a SurrealDbClient, - pub openai_client: &'a Client, - pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>, + pub embedding_provider: &'a common::utils::embedding::EmbeddingProvider, pub input_text: &'a str, pub user_id: &'a str, pub config: RetrievalConfig, diff --git a/retrieval-pipeline/src/pipeline/stages.rs b/retrieval-pipeline/src/pipeline/stages.rs index a9b107c..857eb36 100644 --- a/retrieval-pipeline/src/pipeline/stages.rs +++ b/retrieval-pipeline/src/pipeline/stages.rs @@ -2,7 +2,6 @@ use async_trait::async_trait; use common::{ error::AppError, storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, - utils::embedding::generate_embedding, }; use fastembed::RerankResult; use std::collections::HashMap; @@ -97,11 +96,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Reusing cached query embedding for hybrid retrieval"); } else { debug!("Generating query embedding for hybrid retrieval"); - let embedding = if let Some(provider) = ctx.embedding_provider { - provider.embed(&ctx.input_text).await? - } else { - generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await? - }; + let embedding = ctx.embedding_provider.embed(&ctx.input_text).await?; ctx.query_embedding = Some(embedding); }