feat: pool fastembed, batch embeddings, and reconcile embedding config on startup

This commit is contained in:
Per Stark
2026-06-03 22:10:33 +02:00
parent 7b850769c9
commit 15c9f18f6e
24 changed files with 565 additions and 546 deletions
+49
View File
@@ -211,6 +211,27 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_inner(db).await.map_err(AppError::internal)
}
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
///
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
/// vectors and the index together, so it acts as a persisted marker of the embedding space
/// actually present in the database. Returns `Ok(None)` when the index has not been created yet
/// (for example on a fresh database with no ingested data).
///
/// # Errors
///
/// Returns `AppError::InternalError` if the index metadata cannot be read.
pub async fn embedding_index_dimension(db: &SurrealDbClient) -> Result<Option<usize>, AppError> {
let spec = HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: HNSW_INDEX_OPTIONS,
};
existing_hnsw_dimension(db, &spec)
.await
.map_err(AppError::internal)
}
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
create_fts_analyzer(db).await?;
@@ -906,6 +927,34 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn embedding_index_dimension_reflects_runtime_state() -> anyhow::Result<()> {
let namespace = "indexes_marker";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.context("in-memory db")?;
db.apply_migrations()
.await
.context("migrations should succeed")?;
// Before any index exists, there is no stored embedding dimension to detect.
assert_eq!(embedding_index_dimension(&db).await?, None);
ensure_runtime(&db, 1536)
.await
.context("initial index creation")?;
assert_eq!(embedding_index_dimension(&db).await?, Some(1536));
// After a dimension change the marker tracks the new index dimension.
ensure_runtime(&db, 256)
.await
.context("overwritten index creation")?;
assert_eq!(embedding_index_dimension(&db).await?, Some(256));
Ok(())
}
#[tokio::test]
async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
let namespace = "indexes_dim";
+47 -141
View File
@@ -6,12 +6,7 @@ use crate::{
error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
storage::types::system_settings::SystemSettings, stored_object,
utils::embedding::{generate_embedding_with_params, generate_embedding_with_provider, EmbeddingProvider},
};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
};
use tracing::{error, info};
use uuid::Uuid;
@@ -321,8 +316,7 @@ impl KnowledgeEntity {
) -> Result<(), AppError> {
let embedding_input =
format!("name: {name}, description: {description}, type: {entity_type:?}",);
let embedding =
generate_embedding_with_provider(embedding_provider, &embedding_input).await?;
let embedding = embedding_provider.embed(&embedding_input).await?;
let entity: KnowledgeEntity = db_client
.get_item(id)
@@ -368,120 +362,17 @@ impl KnowledgeEntity {
Ok(())
}
/// Re-creates embeddings for all knowledge entities in the database.
/// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`.
///
/// This is a costly operation that should be run in the background. It follows the same
/// pattern as the text chunk update:
/// 1. Re-defines the vector index with the new dimensions.
/// 2. Fetches all existing entities.
/// 3. Sequentially regenerates the embedding for each and updates the record.
#[allow(clippy::too_many_lines)]
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
new_dimensions
);
// Fetch all entities first
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
let total_entities = all_entities.len();
if total_entities == 0 {
info!("No knowledge entities to update. Just updating the idx");
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
return Ok(());
}
info!("Found {total_entities} entities to process.");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in &all_entities {
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
entity.name, entity.description, entity.entity_type
);
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
generate_embedding_with_params(
openai_client,
&embedding_input,
new_model,
new_dimensions,
)
})
.await?;
// Check embedding lengths
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
entity.id.clone(),
(embedding, entity.user_id.clone(), entity.source_id.clone()),
);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements to the embedding table
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding)
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
write!(
transaction_query,
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \
user_id = '{user_id}', \
source_id = '{source_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
)
.map_err(AppError::internal)?;
}
write!(
transaction_query,
"{}",
hnsw_index_overwrite_sql(
"idx_embedding_knowledge_entity_embedding",
KnowledgeEntityEmbedding::table_name(),
new_dimensions as usize,
)
)
.map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for knowledge entities completed successfully.");
Ok(())
}
/// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`.
///
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
#[allow(clippy::too_many_lines)]
pub async fn update_all_embeddings_with_provider(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
provider: &EmbeddingProvider,
) -> Result<(), AppError> {
let new_dimensions = provider.dimension();
info!(
@@ -500,38 +391,53 @@ impl KnowledgeEntity {
}
info!(entities = total_entities, "Found entities to process");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
// Generate all new embeddings in memory, batching to amortise lock/dispatch overhead
// while keeping memory bounded and preserving progress logging.
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> =
HashMap::with_capacity(total_entities);
info!("Generating new embeddings for all entities...");
for (i, entity) in all_entities.iter().enumerate() {
if i > 0 && i % 100 == 0 {
info!(
progress = i,
total = total_entities,
"Re-embedding progress"
let mut processed = 0usize;
for batch in all_entities.chunks(RE_EMBED_BATCH_SIZE) {
let inputs: Vec<String> = batch
.iter()
.map(|entity| {
format!(
"name: {}, description: {}, type: {:?}",
entity.name, entity.description, entity.entity_type
)
})
.collect();
let embeddings = provider.embed_batch(inputs).await?;
if embeddings.len() != batch.len() {
return Err(AppError::internal(format!(
"embedding batch returned {} vectors for {} entities",
embeddings.len(),
batch.len()
)));
}
for (entity, embedding) in batch.iter().zip(embeddings) {
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
entity.id.clone(),
(embedding, entity.user_id.clone(), entity.source_id.clone()),
);
}
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
entity.name, entity.description, entity.entity_type
);
let embedding = provider.embed(&embedding_input).await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
entity.id.clone(),
(embedding, entity.user_id.clone(), entity.source_id.clone()),
processed = processed.saturating_add(batch.len());
info!(
progress = processed,
total = total_entities,
"Re-embedding progress"
);
}
info!("Successfully generated all new embeddings.");
+6 -4
View File
@@ -235,7 +235,6 @@ mod tests {
use crate::storage::indexes::ensure_runtime;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
use anyhow::{self, Context};
use async_openai::Client;
use super::*;
use uuid::Uuid;
@@ -713,6 +712,8 @@ mod tests {
#[tokio::test]
async fn test_should_change_embedding_length_on_indexes_when_switching_length(
) -> anyhow::Result<()> {
use crate::utils::embedding::EmbeddingProvider;
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.with_context(|| "Failed to start DB".to_string())?;
@@ -758,12 +759,13 @@ mod tests {
"Settings should reflect the new embedding dimension"
);
let openai_client = Client::new();
let provider = EmbeddingProvider::new_hashed(new_dimension as usize)
.map_err(|e| anyhow::anyhow!("{e}"))?;
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
TextChunk::update_all_embeddings(&db, &provider)
.await
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
KnowledgeEntity::update_all_embeddings(&db, &provider)
.await
.with_context(|| {
"KnowledgeEntity re-embedding should succeed on fresh DB".to_string()
+33 -126
View File
@@ -5,12 +5,8 @@ use std::fmt::Write;
use crate::storage::indexes::hnsw_index_overwrite_sql;
use crate::storage::types::system_settings::SystemSettings;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info, warn};
use uuid::Uuid;
@@ -238,7 +234,7 @@ impl TextChunk {
.collect())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
/// Re-creates embeddings for all text chunks using an `EmbeddingProvider`.
///
/// This is a costly operation that should be run in the background. It performs these steps:
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
@@ -246,109 +242,8 @@ impl TextChunk {
/// has the wrong dimension, the entire operation is aborted before any DB changes are made.
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
/// performed in a single, all-or-nothing database transaction.
#[allow(clippy::too_many_lines)]
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}"
);
// Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
let total_chunks = all_chunks.len();
if total_chunks == 0 {
info!("No text chunks to update. Just updating the idx");
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
return Ok(());
}
info!("Found {total_chunks} chunks to process.");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for chunk in &all_chunks {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&chunk.chunk,
new_model,
new_dimensions,
)
})
.await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
chunk.id.clone(),
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction against the embedding table
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding)
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
let id = surql_json_string(&id)?;
let user_id = surql_json_string(&user_id)?;
let source_id = surql_json_string(&source_id)?;
write!(
&mut transaction_query,
"UPSERT type::thing('{emb_table}', {id}) SET \
chunk_id = type::thing('{chunk_table}', {id}), \
source_id = {source_id}, \
embedding = {embedding}, \
user_id = {user_id}, \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
emb_table = TextChunkEmbedding::table_name(),
chunk_table = Self::table_name(),
)
.map_err(AppError::internal)?;
}
write!(
&mut transaction_query,
"{}",
hnsw_index_overwrite_sql(
"idx_embedding_text_chunk_embedding",
TextChunkEmbedding::table_name(),
new_dimensions as usize,
)
)
.map_err(AppError::internal)?;
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
info!("Re-embedding process for text chunks completed successfully.");
Ok(())
}
/// Re-creates embeddings for all text chunks using an `EmbeddingProvider`.
///
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
pub async fn update_all_embeddings_with_provider(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
) -> Result<(), AppError> {
@@ -369,30 +264,42 @@ impl TextChunk {
}
info!(chunks = total_chunks, "Found chunks to process");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
// Generate all new embeddings in memory, batching to amortise lock/dispatch overhead
// while keeping memory bounded and preserving progress logging.
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> =
HashMap::with_capacity(total_chunks);
info!("Generating new embeddings for all chunks...");
for (i, chunk) in all_chunks.iter().enumerate() {
if i > 0 && i % 100 == 0 {
info!(progress = i, total = total_chunks, "Re-embedding progress");
let mut processed = 0usize;
for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) {
let inputs: Vec<String> = batch.iter().map(|chunk| chunk.chunk.clone()).collect();
let embeddings = provider.embed_batch(inputs).await?;
if embeddings.len() != batch.len() {
return Err(AppError::internal(format!(
"embedding batch returned {} vectors for {} chunks",
embeddings.len(),
batch.len()
)));
}
let embedding = provider.embed(&chunk.chunk).await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
for (chunk, embedding) in batch.iter().zip(embeddings) {
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
chunk.id.clone(),
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
);
error!("{err_msg}");
return Err(AppError::internal(err_msg));
}
new_embeddings.insert(
chunk.id.clone(),
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
);
processed = processed.saturating_add(batch.len());
info!(progress = processed, total = total_chunks, "Re-embedding progress");
}
info!("Successfully generated all new embeddings.");
+3
View File
@@ -119,6 +119,8 @@ pub struct AppConfig {
pub fastembed_max_length: Option<usize>,
#[serde(default)]
pub embedding_backend: EmbeddingBackend,
#[serde(default)]
pub embedding_pool_size: Option<usize>,
#[serde(default = "default_ingest_max_body_bytes")]
pub ingest_max_body_bytes: usize,
#[serde(default = "default_ingest_max_files")]
@@ -225,6 +227,7 @@ impl Default for AppConfig {
fastembed_show_download_progress: None,
fastembed_max_length: None,
embedding_backend: EmbeddingBackend::default(),
embedding_pool_size: None,
ingest_max_body_bytes: default_ingest_max_body_bytes(),
ingest_max_files: default_ingest_max_files(),
ingest_max_content_bytes: default_ingest_max_content_bytes(),
+126 -126
View File
@@ -3,15 +3,16 @@ use std::{
hash::{Hash, Hasher},
str::FromStr,
sync::{Arc, Mutex},
thread::available_parallelism,
};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tracing::debug;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use crate::{
error::{AppError, EmbeddingError},
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
error::EmbeddingError,
storage::types::system_settings::SystemSettings,
utils::config::AppConfig,
};
@@ -45,8 +46,8 @@ enum EmbeddingInner {
},
/// Uses `FastEmbed` running locally.
FastEmbed {
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
model: Arc<Mutex<TextEmbedding>>,
/// Pool of `FastEmbed` engines providing bounded-concurrency local embedding.
pool: Arc<FastEmbedPool>,
/// Model metadata used for info logging.
model_name: EmbeddingModel,
/// Output vector length.
@@ -54,19 +55,99 @@ enum EmbeddingInner {
},
}
/// Batch size used when re-embedding stored data in bulk. Bounds peak memory and preserves
/// progress logging while still amortising per-call lock/dispatch overhead.
pub const RE_EMBED_BATCH_SIZE: usize = 128;
/// Default FastEmbed pool size.
///
/// Kept small on purpose: the ONNX runtime already uses intra-op threads per inference, so
/// running many engines concurrently oversubscribes the CPU and each engine duplicates the
/// model weights in memory. Mirrors the reranker pool default.
#[must_use]
pub fn default_embedding_pool_size() -> usize {
available_parallelism()
.map_or(2, |value| value.get().min(2))
.max(1)
}
/// Pool of `FastEmbed` engines enabling bounded-concurrency local embedding.
///
/// A single [`TextEmbedding`] embeds one batch at a time (`&mut self`), so the pool keeps
/// several instances and hands out a distinct idle engine per checkout. The semaphore bounds
/// total in-flight embeds (backpressure); the free list guarantees each active lease holds a
/// different engine — unlike a round-robin index, which can hand the same engine to two callers.
struct FastEmbedPool {
/// Idle engines; one is popped on checkout and returned on lease drop.
engines: Mutex<Vec<Arc<Mutex<TextEmbedding>>>>,
/// Sized to the engine count; gates concurrent checkouts.
semaphore: Arc<Semaphore>,
}
impl FastEmbedPool {
fn new(engines: Vec<Arc<Mutex<TextEmbedding>>>) -> Self {
let permits = engines.len().max(1);
Self {
engines: Mutex::new(engines),
semaphore: Arc::new(Semaphore::new(permits)),
}
}
/// Acquire a permit and borrow a distinct idle engine. The permit guarantees an engine is
/// available, so the pop always succeeds for a correctly sized pool.
async fn checkout(self: &Arc<Self>) -> Result<FastEmbedLease, EmbeddingError> {
let permit = Arc::clone(&self.semaphore)
.acquire_owned()
.await
.map_err(|_| EmbeddingError::Config("embedding pool is closed".into()))?;
let engine = self
.engines
.lock()
.map_err(EmbeddingError::mutex_poisoned)?
.pop()
.ok_or_else(|| EmbeddingError::Config("embedding pool unexpectedly empty".into()))?;
Ok(FastEmbedLease {
pool: Arc::clone(self),
engine,
_permit: permit,
})
}
}
/// Active borrow of a single `FastEmbed` engine; returns it to the pool on drop.
struct FastEmbedLease {
pool: Arc<FastEmbedPool>,
engine: Arc<Mutex<TextEmbedding>>,
/// Released after the engine is returned, unblocking the next checkout.
_permit: OwnedSemaphorePermit,
}
impl FastEmbedLease {
async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
let engine = Arc::clone(&self.engine);
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
})
.await
.map_err(EmbeddingError::from)?
}
}
impl Drop for FastEmbedLease {
fn drop(&mut self) {
if let Ok(mut free) = self.pool.engines.lock() {
free.push(Arc::clone(&self.engine));
}
}
}
async fn run_fastembed(
model: Arc<Mutex<TextEmbedding>>,
pool: &Arc<FastEmbedPool>,
texts: Vec<String>,
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut guard = model.lock().map_err(EmbeddingError::mutex_poisoned)?;
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
})
.await
{
Ok(result) => result,
Err(join_error) => Err(EmbeddingError::from(join_error)),
}
let lease = pool.checkout().await?;
lease.embed(texts).await
}
impl EmbeddingProvider {
@@ -107,8 +188,8 @@ impl EmbeddingProvider {
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => {
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
EmbeddingInner::FastEmbed { pool, .. } => {
let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?;
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
}
EmbeddingInner::OpenAI {
@@ -148,11 +229,11 @@ impl EmbeddingProvider {
.into_iter()
.map(|text| hashed_embedding(&text, *dimension))
.collect()),
EmbeddingInner::FastEmbed { model, .. } => {
EmbeddingInner::FastEmbed { pool, .. } => {
if texts.is_empty() {
return Ok(Vec::new());
}
run_fastembed(Arc::clone(model), texts).await
run_fastembed(pool, texts).await
}
EmbeddingInner::OpenAI {
client,
@@ -199,30 +280,46 @@ impl EmbeddingProvider {
})
}
/// Initialise a local FastEmbed provider backed by a pool of `pool_size` engines.
///
/// `pool_size` is clamped to at least 1. Larger pools allow concurrent embeds at the cost of
/// `pool_size`× model memory; see [`default_embedding_pool_size`] for guidance.
///
/// # Errors
///
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
pub async fn new_fastembed(
model_override: Option<String>,
pool_size: usize,
) -> Result<Self, EmbeddingError> {
let pool_size = pool_size.max(1);
let model_name = if let Some(code) = model_override {
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
} else {
EmbeddingModel::default()
};
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
let model_name_for_task = model_name.clone();
let model_name_code = model_name.to_string();
let (model, dimension) =
let (engines, dimension) =
match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
let info =
EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
EmbeddingError::Config(format!(
"fastembed model metadata missing for {model_name_code}"
))
})?;
Ok((model, info.dim))
let mut engines = Vec::with_capacity(pool_size);
for index in 0..pool_size {
let options = TextInitOptions::new(model_name_for_task.clone())
// Only the first engine reports download progress; the rest reuse the cache.
.with_show_download_progress(index == 0);
let model =
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
engines.push(Arc::new(Mutex::new(model)));
}
Ok((engines, info.dim))
})
.await
{
@@ -232,7 +329,7 @@ impl EmbeddingProvider {
Ok(EmbeddingProvider {
inner: EmbeddingInner::FastEmbed {
model: Arc::new(Mutex::new(model)),
pool: Arc::new(FastEmbedPool::new(engines)),
model_name,
dimension,
},
@@ -275,7 +372,10 @@ impl EmbeddingProvider {
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
}
EmbeddingBackend::FastEmbed => {
Self::new_fastembed(Some(settings.embedding_model.clone())).await
let pool_size = config
.embedding_pool_size
.unwrap_or_else(default_embedding_pool_size);
Self::new_fastembed(Some(settings.embedding_model.clone()), pool_size).await
}
EmbeddingBackend::Hashed => {
let dimension = usize::try_from(dimensions).map_err(|_| {
@@ -329,106 +429,6 @@ fn bucket(token: &str, dimension: usize) -> usize {
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
}
/// Generate an embedding using the given provider.
///
/// # Errors
///
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
provider.embed(input).await.map_err(Into::into)
}
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
/// comparisons, vector search, and other natural language processing tasks.
///
/// # Arguments
///
/// * `client`: The `OpenAI` client instance used to make API requests.
/// * `input`: The text string to generate embeddings for.
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
/// * `Err(ProcessingError)`: An error if the embedding generation fails
///
/// # Errors
///
/// This function can return a `AppError` in the following cases:
/// * If the `OpenAI` API request fails
/// * If the request building fails
/// * If no embedding data is received in the response
#[allow(clippy::module_name_repetitions)]
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
db: &SurrealDbClient,
) -> Result<Vec<f32>, AppError> {
let model = SystemSettings::get_current(db).await?;
let request = CreateEmbeddingRequestArgs::default()
.model(model.embedding_model)
.dimensions(model.embedding_dimensions)
.input([input])
.build()?;
// Send the request to OpenAI
let response = client.embeddings().create(request).await?;
// Extract the embedding vector
let embedding: Vec<f32> = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
.embedding
.clone();
Ok(embedding)
}
/// Generates an embedding vector using a specific model and dimension.
///
/// This is used for the re-embedding process where the model and dimensions
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
///
/// # Errors
///
/// Returns `AppError` if the OpenAI API request fails or returns no embedding data.
pub async fn generate_embedding_with_params(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
model: &str,
dimensions: u32,
) -> Result<Vec<f32>, AppError> {
let request = CreateEmbeddingRequestArgs::default()
.model(model)
.input([input])
.dimensions(dimensions)
.build()?;
let response = client.embeddings().create(request).await?;
let embedding = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
.embedding
.clone();
debug!(
"Embedding was created with {:?} dimensions",
embedding.len()
);
Ok(embedding)
}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]