From 2939e4c2a480eaa3725d2fe8ce7f05837efdb374 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sat, 29 Nov 2025 20:07:48 +0100 Subject: [PATCH] fix: removed stale embeddings handler --- eval/src/args.rs | 2 - eval/src/db_helpers.rs | 2 +- eval/src/embedding.rs | 171 -------------------------------------- eval/src/ingest/config.rs | 51 +----------- eval/src/main.rs | 1 - main/src/main.rs | 2 +- 6 files changed, 3 insertions(+), 226 deletions(-) delete mode 100644 eval/src/embedding.rs diff --git a/eval/src/args.rs b/eval/src/args.rs index fde67ba..88e9a0d 100644 --- a/eval/src/args.rs +++ b/eval/src/args.rs @@ -439,7 +439,6 @@ impl Config { pub struct ParsedArgs { pub config: Config, - pub show_help: bool, } pub fn parse() -> Result { @@ -447,7 +446,6 @@ pub fn parse() -> Result { config.finalize()?; Ok(ParsedArgs { config, - show_help: false, // Clap handles help automatically }) } diff --git a/eval/src/db_helpers.rs b/eval/src/db_helpers.rs index b12b74c..9be1fc4 100644 --- a/eval/src/db_helpers.rs +++ b/eval/src/db_helpers.rs @@ -1,6 +1,5 @@ use anyhow::{Context, Result}; use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes}; -use serde::Deserialize; use tracing::info; // Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings. @@ -50,6 +49,7 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s #[cfg(test)] mod tests { use super::*; + use serde::Deserialize; use uuid::Uuid; #[derive(Debug, Deserialize)] diff --git a/eval/src/embedding.rs b/eval/src/embedding.rs deleted file mode 100644 index c17f1fc..0000000 --- a/eval/src/embedding.rs +++ /dev/null @@ -1,171 +0,0 @@ -use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, - str::FromStr, - sync::Arc, -}; - -use anyhow::{anyhow, Context, Result}; -use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; -use tokio::sync::Mutex; - -use crate::args::{Config, EmbeddingBackend}; - -#[derive(Clone)] -pub struct EmbeddingProvider { - inner: EmbeddingInner, -} - -#[derive(Clone)] -enum EmbeddingInner { - Hashed { - dimension: usize, - }, - FastEmbed { - model: Arc>, - model_name: EmbeddingModel, - dimension: usize, - }, -} - -impl EmbeddingProvider { - pub fn backend_label(&self) -> &'static str { - match self.inner { - EmbeddingInner::Hashed { .. } => "hashed", - EmbeddingInner::FastEmbed { .. } => "fastembed", - } - } - - pub fn dimension(&self) -> usize { - match &self.inner { - EmbeddingInner::Hashed { dimension } => *dimension, - EmbeddingInner::FastEmbed { dimension, .. } => *dimension, - } - } - - pub fn model_code(&self) -> Option { - match &self.inner { - EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), - EmbeddingInner::Hashed { .. } => None, - } - } - - pub async fn embed(&self, text: &str) -> Result> { - match &self.inner { - EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), - EmbeddingInner::FastEmbed { model, .. } => { - let mut guard = model.lock().await; - let embeddings = guard - .embed(vec![text.to_owned()], None) - .context("generating fastembed vector")?; - embeddings - .into_iter() - .next() - .ok_or_else(|| anyhow!("fastembed returned no embedding for input")) - } - } - } - - pub async fn embed_batch(&self, texts: Vec) -> Result>> { - match &self.inner { - EmbeddingInner::Hashed { dimension } => Ok(texts - .into_iter() - .map(|text| hashed_embedding(&text, *dimension)) - .collect()), - EmbeddingInner::FastEmbed { model, .. } => { - if texts.is_empty() { - return Ok(Vec::new()); - } - let mut guard = model.lock().await; - guard - .embed(texts, None) - .context("generating fastembed batch embeddings") - } - } - } -} - -pub async fn build_provider( - config: &Config, - default_dimension: usize, -) -> Result { - match config.embedding_backend { - EmbeddingBackend::Hashed => Ok(EmbeddingProvider { - inner: EmbeddingInner::Hashed { - dimension: default_dimension.max(1), - }, - }), - EmbeddingBackend::FastEmbed => { - let model_name = if let Some(code) = config.embedding_model.as_deref() { - EmbeddingModel::from_str(code).map_err(|err| anyhow!(err))? - } else { - EmbeddingModel::default() - }; - - let options = - TextInitOptions::new(model_name.clone()).with_show_download_progress(true); - let model_name_for_task = model_name.clone(); - let model_name_code = model_name.to_string(); - - let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> { - let model = - TextEmbedding::try_new(options).context("initialising FastEmbed text model")?; - let info = - EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| { - anyhow!("FastEmbed model metadata missing for {model_name_code}") - })?; - Ok((model, info.dim)) - }) - .await - .context("joining FastEmbed initialisation task")??; - - Ok(EmbeddingProvider { - inner: EmbeddingInner::FastEmbed { - model: Arc::new(Mutex::new(model)), - model_name, - dimension, - }, - }) - } - } -} - -fn hashed_embedding(text: &str, dimension: usize) -> Vec { - let dim = dimension.max(1); - let mut vector = vec![0.0f32; dim]; - if text.is_empty() { - return vector; - } - - let mut token_count = 0f32; - for token in tokens(text) { - token_count += 1.0; - let idx = bucket(&token, dim); - vector[idx] += 1.0; - } - - if token_count == 0.0 { - return vector; - } - - let norm = vector.iter().map(|v| v * v).sum::().sqrt(); - if norm > 0.0 { - for value in &mut vector { - *value /= norm; - } - } - - vector -} - -fn tokens(text: &str) -> impl Iterator + '_ { - text.split(|c: char| !c.is_ascii_alphanumeric()) - .filter(|token| !token.is_empty()) - .map(|token| token.to_ascii_lowercase()) -} - -fn bucket(token: &str, dimension: usize) -> usize { - let mut hasher = DefaultHasher::new(); - token.hash(&mut hasher); - (hasher.finish() as usize) % dimension -} diff --git a/eval/src/ingest/config.rs b/eval/src/ingest/config.rs index c238cb8..3837a0d 100644 --- a/eval/src/ingest/config.rs +++ b/eval/src/ingest/config.rs @@ -1,9 +1,6 @@ use std::path::PathBuf; -use anyhow::Result; -use async_trait::async_trait; - -use crate::{args::Config, embedding::EmbeddingProvider}; +use crate::args::Config; #[derive(Debug, Clone)] pub struct CorpusCacheConfig { @@ -32,52 +29,6 @@ impl CorpusCacheConfig { } } -#[async_trait] -pub trait CorpusEmbeddingProvider: Send + Sync { - fn backend_label(&self) -> &str; - fn model_code(&self) -> Option; - fn dimension(&self) -> usize; - async fn embed_batch(&self, texts: Vec) -> Result>>; -} - -#[async_trait] -impl CorpusEmbeddingProvider for EmbeddingProvider { - fn backend_label(&self) -> &str { - EmbeddingProvider::backend_label(self) - } - - fn model_code(&self) -> Option { - EmbeddingProvider::model_code(self) - } - - fn dimension(&self) -> usize { - EmbeddingProvider::dimension(self) - } - - async fn embed_batch(&self, texts: Vec) -> Result>> { - EmbeddingProvider::embed_batch(self, texts).await - } -} - -#[async_trait] -impl CorpusEmbeddingProvider for common::utils::embedding::EmbeddingProvider { - fn backend_label(&self) -> &str { - common::utils::embedding::EmbeddingProvider::backend_label(self) - } - - fn model_code(&self) -> Option { - common::utils::embedding::EmbeddingProvider::model_code(self) - } - - fn dimension(&self) -> usize { - common::utils::embedding::EmbeddingProvider::dimension(self) - } - - async fn embed_batch(&self, texts: Vec) -> Result>> { - common::utils::embedding::EmbeddingProvider::embed_batch(self, texts).await - } -} - impl From<&Config> for CorpusCacheConfig { fn from(config: &Config) -> Self { CorpusCacheConfig::new( diff --git a/eval/src/main.rs b/eval/src/main.rs index f312acf..97d3773 100644 --- a/eval/src/main.rs +++ b/eval/src/main.rs @@ -2,7 +2,6 @@ mod args; mod cache; mod datasets; mod db_helpers; -mod embedding; mod eval; mod ingest; mod inspection; diff --git a/main/src/main.rs b/main/src/main.rs index 442dc00..7db36ca 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -112,7 +112,7 @@ async fn main() -> Result<(), Box> { .await .unwrap(), ); - let settings = SystemSettings::get_current(&worker_db) + let _settings = SystemSettings::get_current(&worker_db) .await .expect("failed to load system settings");