mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-23 10:26:46 +02:00
feat: pool fastembed, batch embeddings, and reconcile embedding config on startup
This commit is contained in:
+126
-126
@@ -3,15 +3,16 @@ use std::{
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::{Arc, Mutex},
|
||||
thread::available_parallelism,
|
||||
};
|
||||
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tracing::debug;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
|
||||
use crate::{
|
||||
error::{AppError, EmbeddingError},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
error::EmbeddingError,
|
||||
storage::types::system_settings::SystemSettings,
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
@@ -45,8 +46,8 @@ enum EmbeddingInner {
|
||||
},
|
||||
/// Uses `FastEmbed` running locally.
|
||||
FastEmbed {
|
||||
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
/// Pool of `FastEmbed` engines providing bounded-concurrency local embedding.
|
||||
pool: Arc<FastEmbedPool>,
|
||||
/// Model metadata used for info logging.
|
||||
model_name: EmbeddingModel,
|
||||
/// Output vector length.
|
||||
@@ -54,19 +55,99 @@ enum EmbeddingInner {
|
||||
},
|
||||
}
|
||||
|
||||
/// Number of texts sent per call when stored data is re-embedded in bulk.
///
/// Chunking keeps peak memory bounded and leaves room for progress logging between
/// chunks, while still spreading the per-call lock/dispatch cost over many texts.
pub const RE_EMBED_BATCH_SIZE: usize = 128;
|
||||
|
||||
/// Pool size used for FastEmbed when the configuration does not specify one.
///
/// The value is deliberately capped at two engines: the ONNX runtime already spreads a
/// single inference over intra-op threads, so many concurrent engines oversubscribe the
/// CPU, and every extra engine holds its own copy of the model weights. The same default
/// is used for the reranker pool.
///
/// Returns the detected parallelism limited to the range `1..=2`, or `2` when the
/// parallelism cannot be determined.
#[must_use]
pub fn default_embedding_pool_size() -> usize {
    const MAX_ENGINES: usize = 2;

    match available_parallelism() {
        Ok(cores) => cores.get().clamp(1, MAX_ENGINES),
        Err(_) => MAX_ENGINES,
    }
}
|
||||
|
||||
/// Pool of `FastEmbed` engines enabling bounded-concurrency local embedding.
|
||||
///
|
||||
/// A single [`TextEmbedding`] embeds one batch at a time (`&mut self`), so the pool keeps
|
||||
/// several instances and hands out a distinct idle engine per checkout. The semaphore bounds
|
||||
/// total in-flight embeds (backpressure); the free list guarantees each active lease holds a
|
||||
/// different engine — unlike a round-robin index, which can hand the same engine to two callers.
|
||||
struct FastEmbedPool {
|
||||
/// Idle engines; one is popped on checkout and returned on lease drop.
|
||||
engines: Mutex<Vec<Arc<Mutex<TextEmbedding>>>>,
|
||||
/// Sized to the engine count; gates concurrent checkouts.
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl FastEmbedPool {
|
||||
fn new(engines: Vec<Arc<Mutex<TextEmbedding>>>) -> Self {
|
||||
let permits = engines.len().max(1);
|
||||
Self {
|
||||
engines: Mutex::new(engines),
|
||||
semaphore: Arc::new(Semaphore::new(permits)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a permit and borrow a distinct idle engine. The permit guarantees an engine is
|
||||
/// available, so the pop always succeeds for a correctly sized pool.
|
||||
async fn checkout(self: &Arc<Self>) -> Result<FastEmbedLease, EmbeddingError> {
|
||||
let permit = Arc::clone(&self.semaphore)
|
||||
.acquire_owned()
|
||||
.await
|
||||
.map_err(|_| EmbeddingError::Config("embedding pool is closed".into()))?;
|
||||
let engine = self
|
||||
.engines
|
||||
.lock()
|
||||
.map_err(EmbeddingError::mutex_poisoned)?
|
||||
.pop()
|
||||
.ok_or_else(|| EmbeddingError::Config("embedding pool unexpectedly empty".into()))?;
|
||||
Ok(FastEmbedLease {
|
||||
pool: Arc::clone(self),
|
||||
engine,
|
||||
_permit: permit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Active borrow of a single `FastEmbed` engine; returns it to the pool on drop.
|
||||
struct FastEmbedLease {
|
||||
pool: Arc<FastEmbedPool>,
|
||||
engine: Arc<Mutex<TextEmbedding>>,
|
||||
/// Released after the engine is returned, unblocking the next checkout.
|
||||
_permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
impl FastEmbedLease {
|
||||
async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let engine = Arc::clone(&self.engine);
|
||||
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
.map_err(EmbeddingError::from)?
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FastEmbedLease {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut free) = self.pool.engines.lock() {
|
||||
free.push(Arc::clone(&self.engine));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
pool: &Arc<FastEmbedPool>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = model.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(join_error) => Err(EmbeddingError::from(join_error)),
|
||||
}
|
||||
let lease = pool.checkout().await?;
|
||||
lease.embed(texts).await
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
@@ -107,8 +188,8 @@ impl EmbeddingProvider {
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||
let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?;
|
||||
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
@@ -148,11 +229,11 @@ impl EmbeddingProvider {
|
||||
.into_iter()
|
||||
.map(|text| hashed_embedding(&text, *dimension))
|
||||
.collect()),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
run_fastembed(Arc::clone(model), texts).await
|
||||
run_fastembed(pool, texts).await
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -199,30 +280,46 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialise a local FastEmbed provider backed by a pool of `pool_size` engines.
|
||||
///
|
||||
/// `pool_size` is clamped to at least 1. Larger pools allow concurrent embeds at the cost of
|
||||
/// `pool_size`× model memory; see [`default_embedding_pool_size`] for guidance.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
|
||||
pub async fn new_fastembed(
|
||||
model_override: Option<String>,
|
||||
pool_size: usize,
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let pool_size = pool_size.max(1);
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
|
||||
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) =
|
||||
let (engines, dimension) =
|
||||
match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info =
|
||||
EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for index in 0..pool_size {
|
||||
let options = TextInitOptions::new(model_name_for_task.clone())
|
||||
// Only the first engine reports download progress; the rest reuse the cache.
|
||||
.with_show_download_progress(index == 0);
|
||||
let model =
|
||||
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
engines.push(Arc::new(Mutex::new(model)));
|
||||
}
|
||||
Ok((engines, info.dim))
|
||||
})
|
||||
.await
|
||||
{
|
||||
@@ -232,7 +329,7 @@ impl EmbeddingProvider {
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
model: Arc::new(Mutex::new(model)),
|
||||
pool: Arc::new(FastEmbedPool::new(engines)),
|
||||
model_name,
|
||||
dimension,
|
||||
},
|
||||
@@ -275,7 +372,10 @@ impl EmbeddingProvider {
|
||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||
let pool_size = config
|
||||
.embedding_pool_size
|
||||
.unwrap_or_else(default_embedding_pool_size);
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone()), pool_size).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => {
|
||||
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||
@@ -329,106 +429,6 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
|
||||
}
|
||||
|
||||
/// Generate an embedding using the given provider.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider.embed(input).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
|
||||
/// comparisons, vector search, and other natural language processing tasks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `client`: The `OpenAI` client instance used to make API requests.
|
||||
/// * `input`: The text string to generate embeddings for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
|
||||
/// * `Err(ProcessingError)`: An error if the embedding generation fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function can return a `AppError` in the following cases:
|
||||
/// * If the `OpenAI` API request fails
|
||||
/// * If the request building fails
|
||||
/// * If no embedding data is received in the response
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let model = SystemSettings::get_current(db).await?;
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.embedding_model)
|
||||
.dimensions(model.embedding_dimensions)
|
||||
.input([input])
|
||||
.build()?;
|
||||
|
||||
// Send the request to OpenAI
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
// Extract the embedding vector
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector using a specific model and dimension.
|
||||
///
|
||||
/// This is used for the re-embedding process where the model and dimensions
|
||||
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError` if the OpenAI API request fails or returns no embedding data.
|
||||
pub async fn generate_embedding_with_params(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
model: &str,
|
||||
dimensions: u32,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model)
|
||||
.input([input])
|
||||
.dimensions(dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
debug!(
|
||||
"Embedding was created with {:?} dimensions",
|
||||
embedding.len()
|
||||
);
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
Reference in New Issue
Block a user