use std::{
    collections::hash_map::DefaultHasher,
    hash::{Hash, Hasher},
    str::FromStr,
    sync::{Arc, Mutex},
    thread::available_parallelism,
};

use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use serde::Serialize;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::warn;

use crate::{
    error::{AppError, EmbeddingError},
    storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
    utils::config::AppConfig,
};

// Re-exported so callers can name the backend selector (and its parse error) through this module.
#[allow(clippy::module_name_repetitions)]
pub use crate::utils::config::{EmbeddingBackend, ParseEmbeddingBackendError};

/// Handle to the active embedding backend.
///
/// Cloning is cheap: the heavyweight state (API client, engine pool, model id) sits behind `Arc`.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct EmbeddingProvider {
    /// The backend that actually produces vectors.
    inner: EmbeddingInner,
}

/// Backend-specific state behind an [`EmbeddingProvider`].
#[derive(Clone)]
enum EmbeddingInner {
    /// Remote embeddings from an `OpenAI`-compatible HTTP API.
    OpenAI {
        /// Shared API client used for every embedding request.
        client: Arc<Client<async_openai::config::OpenAIConfig>>,
        /// Model identifier sent with each request.
        model: Arc<str>,
        /// Output dimensionality requested from the API.
        dimensions: u32,
    },
    /// Deterministic token-hash vectors computed in-process, with no network or model files.
    Hashed {
        /// Length of every produced vector.
        dimension: usize,
    },
    /// Local inference through a pool of `FastEmbed` engines.
    FastEmbed {
        /// Bounded pool of loaded engines used for local embedding.
        pool: Arc<FastEmbedPool>,
        /// Model the engines were loaded with (reported by `model_code`).
        model_name: EmbeddingModel,
        /// Length of every produced vector.
        dimension: usize,
    },
}

/// Number of stored items re-embedded per call during bulk re-embedding.
///
/// Caps peak memory and keeps progress logging granular, while still amortising the per-call
/// lock and dispatch overhead.
pub const RE_EMBED_BATCH_SIZE: usize = 128;

/// `model_code` of the default FastEmbed model (`BGESmallENV15`), used when neither the config
/// nor the database names a valid one.
pub const DEFAULT_FASTEMBED_MODEL_CODE: &str = "Xenova/bge-small-en-v1.5";

/// A supported FastEmbed model, as surfaced in the admin UI and documentation.
#[derive(Clone, Debug, Serialize)] pub struct FastEmbedModelOption { /// HuggingFace-style `model_code` accepted by [`EmbeddingModel::from_str`]. pub model_code: String, /// Fixed output dimension for this model. pub dimension: u32, /// Short human-readable description from fastembed metadata. pub description: String, } /// Lists supported FastEmbed text embedding models (sorted by `model_code`). #[must_use] pub fn list_fastembed_embedding_models() -> Vec { let mut list: Vec = TextEmbedding::list_supported_models() .into_iter() .filter_map(|info| { let dimension = u32::try_from(info.dim).ok()?; Some(FastEmbedModelOption { model_code: info.model_code, dimension, description: info.description, }) }) .collect(); list.sort_by(|left, right| left.model_code.cmp(&right.model_code)); list } /// Returns true when `code` is a supported FastEmbed `model_code` (HuggingFace-style id). #[must_use] pub fn is_valid_fastembed_model_code(code: &str) -> bool { !code.trim().is_empty() && EmbeddingModel::from_str(code.trim()).is_ok() } /// Vector dimension for a supported FastEmbed `model_code`. /// /// # Errors /// /// Returns [`EmbeddingError::UnknownModel`] when the code is not recognized. pub fn fastembed_model_dimension(code: &str) -> Result { let model = EmbeddingModel::from_str(code.trim()) .map_err(|_| EmbeddingError::UnknownModel(unknown_fastembed_model_message(code)))?; let dim = EmbeddingModel::get_model_info(&model) .ok_or_else(|| { EmbeddingError::Config(format!("fastembed model metadata missing for {code}")) })? .dim; u32::try_from(dim).map_err(|_| { EmbeddingError::Config(format!("fastembed model dimension {dim} exceeds u32::MAX")) }) } /// Resolves the FastEmbed model code to load: config override, then DB, then default. /// /// When `config.fastembed_model` is set it must be valid. When only the DB value is used and it /// is not a FastEmbed code (e.g. legacy `text-embedding-3-small`), returns the default model. 
///
/// # Errors
///
/// Returns [`EmbeddingError::UnknownModel`] if `config.fastembed_model` is set but invalid, and
/// [`EmbeddingError::Config`] if it is set but blank.
pub fn resolve_fastembed_model_code(
    config: &AppConfig,
    settings_model: &str,
) -> Result<String, EmbeddingError> {
    // An explicit config override always wins, but it has to name a real model.
    if let Some(raw) = config.fastembed_model.as_deref() {
        let configured = raw.trim();
        if configured.is_empty() {
            return Err(EmbeddingError::Config(
                "fastembed_model must not be empty when set".into(),
            ));
        }
        return match EmbeddingModel::from_str(configured) {
            Ok(_) => Ok(configured.to_owned()),
            Err(_) => Err(EmbeddingError::UnknownModel(
                unknown_fastembed_model_message(configured),
            )),
        };
    }

    // Otherwise, trust the stored value only if it is a FastEmbed code.
    let stored = settings_model.trim();
    if is_valid_fastembed_model_code(stored) {
        return Ok(stored.to_owned());
    }

    // A non-empty but foreign value (e.g. an OpenAI model id) is worth a warning.
    if !stored.is_empty() {
        warn!(
            stored_model = stored,
            default_model = DEFAULT_FASTEMBED_MODEL_CODE,
            "system_settings.embedding_model is not a FastEmbed model code; using default"
        );
    }
    Ok(DEFAULT_FASTEMBED_MODEL_CODE.to_owned())
}

/// Before startup, rewrites `embedding_model` and `embedding_dimensions` in the stored system
/// settings so they match the FastEmbed model that will actually be loaded.
///
/// Unless `config.embedding_backend` is FastEmbed, this only loads and returns the current
/// settings. With FastEmbed, the target model comes from [`resolve_fastembed_model_code`].
/// Settings are written only when the stored model or dimension differ, for example when they
/// still carry OpenAI defaults such as `text-embedding-3-small`.
///
/// # Errors
///
/// Returns [`AppError`] if settings cannot be loaded, resolved, or updated.
pub async fn align_fastembed_system_settings( db: &SurrealDbClient, config: &AppConfig, ) -> Result { if config.embedding_backend != EmbeddingBackend::FastEmbed { return SystemSettings::get_current(db).await; } let mut settings = SystemSettings::get_current(db).await?; let resolved = resolve_fastembed_model_code(config, &settings.embedding_model)?; let dimension = fastembed_model_dimension(&resolved)?; if settings.embedding_model == resolved && settings.embedding_dimensions == dimension { return Ok(settings); } tracing::info!( old_model = %settings.embedding_model, new_model = %resolved, old_dimensions = settings.embedding_dimensions, new_dimensions = dimension, "Aligning system settings with FastEmbed model" ); settings.embedding_model = resolved; settings.embedding_dimensions = dimension; SystemSettings::update(db, settings).await } fn unknown_fastembed_model_message(code: &str) -> String { let mut codes: Vec = TextEmbedding::list_supported_models() .into_iter() .map(|info| info.model_code) .collect(); codes.sort(); let examples: Vec<&str> = codes.iter().take(6).map(String::as_str).collect(); format!( "unknown FastEmbed model '{code}' (expected a HuggingFace model_code such as {}). \ Set fastembed_model in config.yaml or update system_settings; \ see docs/configuration.md ({count} models supported)", examples.join(", "), count = codes.len() ) } /// Default FastEmbed pool size. /// /// Kept small on purpose: the ONNX runtime already uses intra-op threads per inference, so /// running many engines concurrently oversubscribes the CPU and each engine duplicates the /// model weights in memory. Mirrors the reranker pool default. #[must_use] pub fn default_embedding_pool_size() -> usize { available_parallelism() .map_or(2, |value| value.get().min(2)) .max(1) } /// Pool of `FastEmbed` engines enabling bounded-concurrency local embedding. 
/// /// A single [`TextEmbedding`] embeds one batch at a time (`&mut self`), so the pool keeps /// several instances and hands out a distinct idle engine per checkout. The semaphore bounds /// total in-flight embeds (backpressure); the free list guarantees each active lease holds a /// different engine — unlike a round-robin index, which can hand the same engine to two callers. struct FastEmbedPool { /// Idle engines; one is popped on checkout and returned on lease drop. engines: Mutex>>>, /// Sized to the engine count; gates concurrent checkouts. semaphore: Arc, } impl FastEmbedPool { fn new(engines: Vec>>) -> Self { let permits = engines.len().max(1); Self { engines: Mutex::new(engines), semaphore: Arc::new(Semaphore::new(permits)), } } /// Acquire a permit and borrow a distinct idle engine. The permit guarantees an engine is /// available, so the pop always succeeds for a correctly sized pool. async fn checkout(self: &Arc) -> Result { let permit = Arc::clone(&self.semaphore) .acquire_owned() .await .map_err(|_| EmbeddingError::Config("embedding pool is closed".into()))?; let engine = self .engines .lock() .map_err(EmbeddingError::mutex_poisoned)? .pop() .ok_or_else(|| EmbeddingError::Config("embedding pool unexpectedly empty".into()))?; Ok(FastEmbedLease { pool: Arc::clone(self), engine, _permit: permit, }) } } /// Active borrow of a single `FastEmbed` engine; returns it to the pool on drop. struct FastEmbedLease { pool: Arc, engine: Arc>, /// Released after the engine is returned, unblocking the next checkout. _permit: OwnedSemaphorePermit, } impl FastEmbedLease { async fn embed(&self, texts: &[String]) -> Result>, EmbeddingError> { let engine = Arc::clone(&self.engine); let texts = texts.to_vec(); tokio::task::spawn_blocking(move || -> Result>, EmbeddingError> { let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?; guard.embed(texts, None).map_err(EmbeddingError::fastembed) }) .await .map_err(EmbeddingError::from)? 
} } impl Drop for FastEmbedLease { fn drop(&mut self) { if let Ok(mut free) = self.pool.engines.lock() { free.push(Arc::clone(&self.engine)); } } } async fn run_fastembed( pool: &Arc, texts: &[String], ) -> Result>, EmbeddingError> { let lease = pool.checkout().await?; lease.embed(texts).await } impl EmbeddingProvider { #[must_use] pub fn backend_label(&self) -> &'static str { match self.inner { EmbeddingInner::Hashed { .. } => "hashed", EmbeddingInner::FastEmbed { .. } => "fastembed", EmbeddingInner::OpenAI { .. } => "openai", } } #[must_use] pub fn dimension(&self) -> usize { match &self.inner { EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => { *dimension } EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize, } } #[must_use] pub fn model_code(&self) -> Option { match &self.inner { EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), EmbeddingInner::OpenAI { model, .. } => Some(model.as_ref().to_owned()), EmbeddingInner::Hashed { .. } => None, } } /// Generate an embedding vector for the given text. /// /// # Errors /// /// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails, /// or the backend returns no embedding data. pub async fn embed(&self, text: &str) -> Result, EmbeddingError> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), EmbeddingInner::FastEmbed { pool, .. } => { let text = text.to_owned(); let embeddings = run_fastembed(pool, std::slice::from_ref(&text)).await?; embeddings.into_iter().next().ok_or(EmbeddingError::NoData) } EmbeddingInner::OpenAI { client, model, dimensions, } => { let request = CreateEmbeddingRequestArgs::default() .model(model.as_ref()) .input([text]) .dimensions(*dimensions) .build()?; let response = client.embeddings().create(request).await?; let embedding = response .data .first() .ok_or(EmbeddingError::NoData)? 
.embedding .clone(); Ok(embedding) } } } /// Generate embedding vectors for a batch of texts. /// /// # Errors /// /// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data. /// Returns an empty `Vec` when `texts` is empty. pub async fn embed_batch(&self, texts: &[String]) -> Result>, EmbeddingError> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(texts .iter() .map(|text| hashed_embedding(text, *dimension)) .collect()), EmbeddingInner::FastEmbed { pool, .. } => { if texts.is_empty() { return Ok(Vec::new()); } run_fastembed(pool, texts).await } EmbeddingInner::OpenAI { client, model, dimensions, } => { if texts.is_empty() { return Ok(Vec::new()); } let request = CreateEmbeddingRequestArgs::default() .model(model.as_ref()) .input(texts.to_vec()) .dimensions(*dimensions) .build()?; let response = client.embeddings().create(request).await?; let embeddings: Vec> = response .data .into_iter() .map(|item| item.embedding) .collect(); Ok(embeddings) } } } /// # Errors /// /// Currently infallible; reserved for future validation. pub fn new_openai( client: Arc>, model: impl AsRef, dimensions: u32, ) -> Result { Ok(Self { inner: EmbeddingInner::OpenAI { client, model: Arc::from(model.as_ref()), dimensions, }, }) } /// Initialise a local FastEmbed provider backed by a pool of `pool_size` engines. /// /// `pool_size` is clamped to at least 1. Larger pools allow concurrent embeds at the cost of /// `pool_size`× model memory; see [`default_embedding_pool_size`] for guidance. /// /// # Errors /// /// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails. pub async fn new_fastembed( model_override: Option, pool_size: usize, ) -> Result { let pool_size = pool_size.max(1); let model_name = if let Some(code) = model_override { EmbeddingModel::from_str(code.trim()) .map_err(|_| EmbeddingError::UnknownModel(unknown_fastembed_model_message(&code)))? 
} else { EmbeddingModel::default() }; let model_name_for_task = model_name.clone(); let model_name_code = model_name.to_string(); let (engines, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> { let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| { EmbeddingError::Config(format!( "fastembed model metadata missing for {model_name_code}" )) })?; let mut engines = Vec::with_capacity(pool_size); for index in 0..pool_size { let options = TextInitOptions::new(model_name_for_task.clone()) // Only the first engine reports download progress; the rest reuse the cache. .with_show_download_progress(index == 0); let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?; engines.push(Arc::new(Mutex::new(model))); } Ok((engines, info.dim)) }) .await { Ok(result) => result?, Err(join_error) => return Err(EmbeddingError::from(join_error)), }; Ok(EmbeddingProvider { inner: EmbeddingInner::FastEmbed { pool: Arc::new(FastEmbedPool::new(engines)), model_name, dimension, }, }) } /// # Errors /// /// Currently infallible; reserved for future validation. pub fn new_hashed(dimension: usize) -> Result { Ok(EmbeddingProvider { inner: EmbeddingInner::Hashed { dimension: dimension.max(1), }, }) } /// Creates an embedding provider from persisted settings and bootstrap config. /// /// OpenAI/hashed model settings come from [`SystemSettings`]. FastEmbed uses /// [`resolve_fastembed_model_code`] (config `fastembed_model` overrides DB). The active /// backend is taken from `config.embedding_backend`; [`SystemSettings::sync_from_embedding_provider`] /// persists the resolved backend to the database after startup. /// /// # Errors /// /// Returns [`EmbeddingError`] if the selected backend cannot be initialised. 
pub async fn from_system_settings( settings: &SystemSettings, config: &AppConfig, openai_client: Option>>, ) -> Result { let dimensions = settings.embedding_dimensions; match config.embedding_backend { EmbeddingBackend::OpenAI => { let client = openai_client.ok_or_else(|| { EmbeddingError::Config( "openai embedding backend requires an openai client".into(), ) })?; Self::new_openai(client, settings.embedding_model.as_str(), dimensions) } EmbeddingBackend::FastEmbed => { let pool_size = config .embedding_pool_size .unwrap_or_else(default_embedding_pool_size); let model_code = resolve_fastembed_model_code(config, &settings.embedding_model)?; Self::new_fastembed(Some(model_code), pool_size).await } EmbeddingBackend::Hashed => { let dimension = usize::try_from(dimensions).map_err(|_| { EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into()) })?; Self::new_hashed(dimension) } } } } // Helper functions for hashed embeddings /// Generates a hashed embedding vector without external dependencies. fn hashed_embedding(text: &str, dimension: usize) -> Vec { let dim = dimension.max(1); let mut vector = vec![0.0f32; dim]; if text.is_empty() { return vector; } for token in tokens(text) { let idx = bucket(&token, dim); if let Some(slot) = vector.get_mut(idx) { *slot += 1.0; } } let norm = vector.iter().map(|v| v * v).sum::().sqrt(); if norm > 0.0 { for value in &mut vector { *value /= norm; } } vector } /// Tokenizes the text into alphanumeric lowercase tokens. fn tokens(text: &str) -> impl Iterator + '_ { text.split(|c: char| !c.is_ascii_alphanumeric()) .filter(|token| !token.is_empty()) .map(str::to_ascii_lowercase) } /// Buckets a token into the hashed embedding vector. 
///
/// The 64-bit hash is reduced modulo the dimension *before* being narrowed to `usize`.
/// Narrowing first with `usize::try_from(hash)` fails for most hashes on 32-bit targets, and
/// the fallback collapsed nearly every token into bucket 0. On 64-bit targets both orders give
/// the same index.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to stay the same across Rust
/// releases. Hashed embeddings persisted under one toolchain may not match queries embedded
/// under another; re-embed after upgrading if that matters.
// `%` cannot panic: the divisor is clamped to at least 1.
#[allow(clippy::arithmetic_side_effects)]
fn bucket(token: &str, dimension: usize) -> usize {
    let safe_dimension = dimension.max(1);
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    let modulus = u64::try_from(safe_dimension).unwrap_or(u64::MAX);
    // The remainder is below `safe_dimension`, so it always fits back into `usize`.
    usize::try_from(hasher.finish() % modulus).unwrap_or_default()
}

#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used)]

    use super::{
        align_fastembed_system_settings, bucket, fastembed_model_dimension, hashed_embedding,
        list_fastembed_embedding_models, resolve_fastembed_model_code, EmbeddingError,
        DEFAULT_FASTEMBED_MODEL_CODE,
    };
    use crate::storage::types::system_settings::SystemSettings;
    use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
    use serde_json::json;

    #[test]
    fn embedding_backend_defaults_to_fastembed() {
        assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
    }

    #[test]
    fn embedding_backend_as_str_matches_serde_names() {
        assert_eq!(EmbeddingBackend::OpenAI.as_str(), "openai");
        assert_eq!(EmbeddingBackend::FastEmbed.as_str(), "fastembed");
        assert_eq!(EmbeddingBackend::Hashed.as_str(), "hashed");
        assert_eq!(
            serde_json::to_string(&EmbeddingBackend::FastEmbed).expect("serialize"),
            "\"fastembed\""
        );
    }

    #[test]
    fn embedding_backend_deserializes_lowercase_values() {
        let openai: EmbeddingBackend = serde_json::from_str("\"openai\"").expect("openai");
        let fastembed: EmbeddingBackend =
            serde_json::from_str("\"fastembed\"").expect("fastembed");
        let hashed: EmbeddingBackend = serde_json::from_str("\"hashed\"").expect("hashed");

        assert_eq!(openai, EmbeddingBackend::OpenAI);
        assert_eq!(fastembed, EmbeddingBackend::FastEmbed);
        assert_eq!(hashed, EmbeddingBackend::Hashed);
    }

    #[test]
    fn embedding_backend_from_str_accepts_aliases() {
        assert_eq!(
            "fast-embed"
                .parse::<EmbeddingBackend>()
                .expect("fast-embed"),
            EmbeddingBackend::FastEmbed
        );
        assert_eq!(
            "FASTEMBED".parse::<EmbeddingBackend>().expect("FASTEMBED"),
            EmbeddingBackend::FastEmbed
        );
        assert!(matches!(
            "unknown-backend".parse::<EmbeddingBackend>(),
            Err(ParseEmbeddingBackendError { .. })
        ));
    }

    #[test]
    fn list_fastembed_embedding_models_includes_default() {
        let models = list_fastembed_embedding_models();
        assert!(
            models
                .iter()
                .any(|m| m.model_code == DEFAULT_FASTEMBED_MODEL_CODE),
            "catalog should include the default FastEmbed model"
        );
    }

    #[test]
    fn resolve_fastembed_model_prefers_config_over_db() {
        let config = AppConfig {
            fastembed_model: Some("Xenova/bge-base-en-v1.5".into()),
            ..AppConfig::default()
        };
        let resolved =
            resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("config model");
        assert_eq!(resolved, "Xenova/bge-base-en-v1.5");
    }

    #[test]
    fn resolve_fastembed_model_falls_back_from_openai_default() {
        let config = AppConfig::default();
        let resolved =
            resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("default model");
        assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE);
    }

    #[test]
    fn resolve_fastembed_model_rejects_invalid_config_override() {
        let config = AppConfig {
            fastembed_model: Some("not-a-real-model".into()),
            ..AppConfig::default()
        };
        let err = resolve_fastembed_model_code(&config, "Xenova/bge-small-en-v1.5")
            .expect_err("invalid config model");
        assert!(matches!(err, EmbeddingError::UnknownModel(_)));
    }

    #[test]
    fn fastembed_model_dimension_matches_model_metadata() {
        let dim = fastembed_model_dimension(DEFAULT_FASTEMBED_MODEL_CODE).expect("dim");
        assert_eq!(dim, 384);
    }

    #[test]
    fn bucket_stays_within_dimension() {
        for token in ["", "alpha", "beta", "gamma", "δέλτα"] {
            for dimension in [1_usize, 2, 7, 384] {
                assert!(bucket(token, dimension) < dimension);
            }
        }
        // A zero dimension is clamped to a single bucket.
        assert_eq!(bucket("alpha", 0), 0);
    }

    #[test]
    fn hashed_embedding_is_case_insensitive_and_unit_length() {
        let lower = hashed_embedding("hello world", 32);
        let upper = hashed_embedding("HELLO, World!", 32);
        assert_eq!(lower, upper);

        let norm = lower.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);

        assert_eq!(hashed_embedding("", 4), vec![0.0; 4]);
    }

    #[tokio::test]
    async fn align_fastembed_system_settings_replaces_openai_default() -> anyhow::Result<()> {
        use crate::storage::db::SurrealDbClient;
        use uuid::Uuid;

        let db = SurrealDbClient::memory("align_fe", &Uuid::new_v4().to_string()).await?;
        db.apply_migrations().await?;

        let config = AppConfig {
            embedding_backend: EmbeddingBackend::FastEmbed,
            ..AppConfig::default()
        };
        let settings = align_fastembed_system_settings(&db, &config).await?;
        assert_eq!(settings.embedding_model, DEFAULT_FASTEMBED_MODEL_CODE);
        assert_eq!(settings.embedding_dimensions, 384);
        Ok(())
    }

    #[test]
    fn system_settings_deserializes_embedding_backend_field() {
        let value = json!({
            "id": "current",
            "registrations_enabled": true,
            "require_email_verification": false,
            "query_model": "gpt-4o-mini",
            "processing_model": "gpt-4o-mini",
            "embedding_model": "text-embedding-3-small",
            "embedding_dimensions": 1536,
            "embedding_backend": "hashed",
            "query_system_prompt": "query",
            "ingestion_system_prompt": "ingestion",
            "image_processing_model": "gpt-4o-mini",
            "image_processing_prompt": "image",
            "voice_processing_model": "whisper-1",
        });

        let settings: SystemSettings =
            serde_json::from_value(value).expect("deserialize system settings");
        assert_eq!(settings.embedding_backend, Some(EmbeddingBackend::Hashed));
    }
}