use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, str::FromStr, sync::Arc, }; use anyhow::{anyhow, Context, Result}; use async_openai::{types::CreateEmbeddingRequestArgs, Client}; use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; use tokio::sync::Mutex; use tracing::debug; use crate::{ error::AppError, storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, }; /// Supported embedding backends. #[allow(clippy::module_name_repetitions)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum EmbeddingBackend { #[default] OpenAI, FastEmbed, Hashed, } impl std::str::FromStr for EmbeddingBackend { type Err = anyhow::Error; fn from_str(s: &str) -> Result { match s.to_ascii_lowercase().as_str() { "openai" => Ok(Self::OpenAI), "hashed" => Ok(Self::Hashed), "fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed), other => Err(anyhow!( "unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'." )), } } } /// Wrapper around the chosen embedding backend. #[allow(clippy::module_name_repetitions)] #[derive(Clone)] pub struct EmbeddingProvider { /// Concrete backend implementation. inner: EmbeddingInner, } /// Concrete embedding implementations. #[derive(Clone)] enum EmbeddingInner { /// Uses an `OpenAI`-compatible API. OpenAI { /// Client used to issue embedding requests. client: Arc>, /// Model identifier for the API. model: String, /// Expected output dimensions. dimensions: u32, }, /// Generates deterministic hashed embeddings without external calls. Hashed { /// Output vector length. dimension: usize, }, /// Uses `FastEmbed` running locally. FastEmbed { /// Shared `FastEmbed` model. model: Arc>, /// Model metadata used for info logging. model_name: EmbeddingModel, /// Output vector length. dimension: usize, }, } impl EmbeddingProvider { pub fn backend_label(&self) -> &'static str { match self.inner { EmbeddingInner::Hashed { .. } => "hashed", EmbeddingInner::FastEmbed { .. } => "fastembed", EmbeddingInner::OpenAI { .. } => "openai", } } pub fn dimension(&self) -> usize { match &self.inner { EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => { *dimension } EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize, } } pub fn model_code(&self) -> Option { match &self.inner { EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), EmbeddingInner::OpenAI { model, .. } => Some(model.clone()), EmbeddingInner::Hashed { .. } => None, } } pub async fn embed(&self, text: &str) -> Result> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), EmbeddingInner::FastEmbed { model, .. } => { let mut guard = model.lock().await; let embeddings = guard .embed(vec![text.to_owned()], None) .context("generating fastembed vector")?; embeddings .into_iter() .next() .ok_or_else(|| anyhow!("fastembed returned no embedding for input")) } EmbeddingInner::OpenAI { client, model, dimensions, } => { let request = CreateEmbeddingRequestArgs::default() .model(model.clone()) .input([text]) .dimensions(*dimensions) .build()?; let response = client.embeddings().create(request).await?; let embedding = response .data .first() .ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))? .embedding .clone(); Ok(embedding) } } } pub async fn embed_batch(&self, texts: Vec) -> Result>> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(texts .into_iter() .map(|text| hashed_embedding(&text, *dimension)) .collect()), EmbeddingInner::FastEmbed { model, .. } => { if texts.is_empty() { return Ok(Vec::new()); } let mut guard = model.lock().await; guard .embed(texts, None) .context("generating fastembed batch embeddings") } EmbeddingInner::OpenAI { client, model, dimensions, } => { if texts.is_empty() { return Ok(Vec::new()); } let request = CreateEmbeddingRequestArgs::default() .model(model.clone()) .input(texts) .dimensions(*dimensions) .build()?; let response = client.embeddings().create(request).await?; let embeddings: Vec> = response .data .into_iter() .map(|item| item.embedding) .collect(); Ok(embeddings) } } } pub fn new_openai( client: Arc>, model: String, dimensions: u32, ) -> Result { Ok(Self { inner: EmbeddingInner::OpenAI { client, model, dimensions, }, }) } pub async fn new_fastembed(model_override: Option) -> Result { let model_name = if let Some(code) = model_override { EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))? } else { EmbeddingModel::default() }; let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true); let model_name_for_task = model_name.clone(); let model_name_code = model_name.to_string(); let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> { let model = TextEmbedding::try_new(options).context("initialising FastEmbed text model")?; let info = EmbeddingModel::get_model_info(&model_name_for_task) .ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?; Ok((model, info.dim)) }) .await .context("joining FastEmbed initialisation task")??; Ok(EmbeddingProvider { inner: EmbeddingInner::FastEmbed { model: Arc::new(Mutex::new(model)), model_name, dimension, }, }) } pub fn new_hashed(dimension: usize) -> Result { Ok(EmbeddingProvider { inner: EmbeddingInner::Hashed { dimension: dimension.max(1), }, }) } /// Creates an embedding provider based on application configuration. /// /// Dispatches to the appropriate constructor based on `config.embedding_backend`: /// - `OpenAI`: Requires a valid OpenAI client /// - `FastEmbed`: Uses local embedding model /// - `Hashed`: Uses deterministic hashed embeddings (for testing) pub async fn from_config( config: &crate::utils::config::AppConfig, openai_client: Option>>, ) -> Result { use crate::utils::config::EmbeddingBackend; match config.embedding_backend { EmbeddingBackend::OpenAI => { let client = openai_client .ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?; // Use defaults that match SystemSettings initial values Self::new_openai(client, "text-embedding-3-small".to_string(), 1536) } EmbeddingBackend::FastEmbed => { // Use nomic-embed-text-v1.5 as the default FastEmbed model Self::new_fastembed(Some("nomic-ai/nomic-embed-text-v1.5".to_string())).await } EmbeddingBackend::Hashed => Self::new_hashed(384), } } } // Helper functions for hashed embeddings /// Generates a hashed embedding vector without external dependencies. fn hashed_embedding(text: &str, dimension: usize) -> Vec { let dim = dimension.max(1); let mut vector = vec![0.0f32; dim]; if text.is_empty() { return vector; } for token in tokens(text) { let idx = bucket(&token, dim); if let Some(slot) = vector.get_mut(idx) { *slot += 1.0; } } let norm = vector.iter().map(|v| v * v).sum::().sqrt(); if norm > 0.0 { for value in &mut vector { *value /= norm; } } vector } /// Tokenizes the text into alphanumeric lowercase tokens. fn tokens(text: &str) -> impl Iterator + '_ { text.split(|c: char| !c.is_ascii_alphanumeric()) .filter(|token| !token.is_empty()) .map(str::to_ascii_lowercase) } /// Buckets a token into the hashed embedding vector. #[allow(clippy::arithmetic_side_effects)] fn bucket(token: &str, dimension: usize) -> usize { let safe_dimension = dimension.max(1); let mut hasher = DefaultHasher::new(); token.hash(&mut hasher); usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension } // Backward compatibility function pub async fn generate_embedding_with_provider( provider: &EmbeddingProvider, input: &str, ) -> Result, AppError> { provider.embed(input).await.map_err(|e| AppError::InternalError(e.to_string())) } /// Generates an embedding vector for the given input text using `OpenAI`'s embedding model. /// /// This function takes a text input and converts it into a numerical vector representation (embedding) /// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity /// comparisons, vector search, and other natural language processing tasks. /// /// # Arguments /// /// * `client`: The `OpenAI` client instance used to make API requests. /// * `input`: The text string to generate embeddings for. /// /// # Returns /// /// Returns a `Result` containing either: /// * `Ok(Vec)`: A vector of 32-bit floating point numbers representing the text embedding /// * `Err(ProcessingError)`: An error if the embedding generation fails /// /// # Errors /// /// This function can return a `AppError` in the following cases: /// * If the `OpenAI` API request fails /// * If the request building fails /// * If no embedding data is received in the response #[allow(clippy::module_name_repetitions)] pub async fn generate_embedding( client: &async_openai::Client, input: &str, db: &SurrealDbClient, ) -> Result, AppError> { let model = SystemSettings::get_current(db).await?; let request = CreateEmbeddingRequestArgs::default() .model(model.embedding_model) .dimensions(model.embedding_dimensions) .input([input]) .build()?; // Send the request to OpenAI let response = client.embeddings().create(request).await?; // Extract the embedding vector let embedding: Vec = response .data .first() .ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))? .embedding .clone(); Ok(embedding) } /// Generates an embedding vector using a specific model and dimension. /// /// This is used for the re-embedding process where the model and dimensions /// are known ahead of time and shouldn't be repeatedly fetched from settings. pub async fn generate_embedding_with_params( client: &async_openai::Client, input: &str, model: &str, dimensions: u32, ) -> Result, AppError> { let request = CreateEmbeddingRequestArgs::default() .model(model) .input([input]) .dimensions(dimensions) .build()?; let response = client.embeddings().create(request).await?; let embedding = response .data .first() .ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))? .embedding .clone(); debug!( "Embedding was created with {:?} dimensions", embedding.len() ); Ok(embedding) }