mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-12 17:24:26 +02:00
feat: pool fastembed, batch embeddings, and reconcile embedding config on startup
This commit is contained in:
@@ -211,6 +211,27 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
||||
///
|
||||
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
||||
/// vectors and the index together, so it acts as a persisted marker of the embedding space
|
||||
/// actually present in the database. Returns `Ok(None)` when the index has not been created yet
|
||||
/// (for example on a fresh database with no ingested data).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the index metadata cannot be read.
|
||||
pub async fn embedding_index_dimension(db: &SurrealDbClient) -> Result<Option<usize>, AppError> {
|
||||
let spec = HnswIndexSpec {
|
||||
index_name: "idx_embedding_text_chunk_embedding",
|
||||
table: "text_chunk_embedding",
|
||||
options: HNSW_INDEX_OPTIONS,
|
||||
};
|
||||
existing_hnsw_dimension(db, &spec)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
|
||||
create_fts_analyzer(db).await?;
|
||||
|
||||
@@ -906,6 +927,34 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn embedding_index_dimension_reflects_runtime_state() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_marker";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.context("migrations should succeed")?;
|
||||
|
||||
// Before any index exists, there is no stored embedding dimension to detect.
|
||||
assert_eq!(embedding_index_dimension(&db).await?, None);
|
||||
|
||||
ensure_runtime(&db, 1536)
|
||||
.await
|
||||
.context("initial index creation")?;
|
||||
assert_eq!(embedding_index_dimension(&db).await?, Some(1536));
|
||||
|
||||
// After a dimension change the marker tracks the new index dimension.
|
||||
ensure_runtime(&db, 256)
|
||||
.await
|
||||
.context("overwritten index creation")?;
|
||||
assert_eq!(embedding_index_dimension(&db).await?, Some(256));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_dim";
|
||||
|
||||
@@ -6,12 +6,7 @@ use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
storage::types::system_settings::SystemSettings, stored_object,
|
||||
utils::embedding::{generate_embedding_with_params, generate_embedding_with_provider, EmbeddingProvider},
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
@@ -321,8 +316,7 @@ impl KnowledgeEntity {
|
||||
) -> Result<(), AppError> {
|
||||
let embedding_input =
|
||||
format!("name: {name}, description: {description}, type: {entity_type:?}",);
|
||||
let embedding =
|
||||
generate_embedding_with_provider(embedding_provider, &embedding_input).await?;
|
||||
let embedding = embedding_provider.embed(&embedding_input).await?;
|
||||
|
||||
let entity: KnowledgeEntity = db_client
|
||||
.get_item(id)
|
||||
@@ -368,120 +362,17 @@ impl KnowledgeEntity {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities in the database.
|
||||
/// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It follows the same
|
||||
/// pattern as the text chunk update:
|
||||
/// 1. Re-defines the vector index with the new dimensions.
|
||||
/// 2. Fetches all existing entities.
|
||||
/// 3. Sequentially regenerates the embedding for each and updates the record.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn update_all_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
openai_client: &Client<OpenAIConfig>,
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
|
||||
// Fetch all entities first
|
||||
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
|
||||
let total_entities = all_entities.len();
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Just updating the idx");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {total_entities} entities to process.");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
for entity in &all_entities {
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
);
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
generate_embedding_with_params(
|
||||
openai_client,
|
||||
&embedding_input,
|
||||
new_model,
|
||||
new_dimensions,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Check embedding lengths
|
||||
if embedding.len() != new_dimensions as usize {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(embedding, entity.user_id.clone(), entity.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements to the embedding table
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
|
||||
write!(
|
||||
transaction_query,
|
||||
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
entity_id = type::thing('knowledge_entity', '{id}'), \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
source_id = '{source_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
)
|
||||
.map_err(AppError::internal)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
transaction_query,
|
||||
"{}",
|
||||
hnsw_index_overwrite_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
KnowledgeEntityEmbedding::table_name(),
|
||||
new_dimensions as usize,
|
||||
)
|
||||
)
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
// Execute the entire atomic operation
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for knowledge entities completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`.
|
||||
///
|
||||
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
|
||||
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn update_all_embeddings_with_provider(
|
||||
db: &SurrealDbClient,
|
||||
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||
provider: &EmbeddingProvider,
|
||||
) -> Result<(), AppError> {
|
||||
let new_dimensions = provider.dimension();
|
||||
info!(
|
||||
@@ -500,38 +391,53 @@ impl KnowledgeEntity {
|
||||
}
|
||||
info!(entities = total_entities, "Found entities to process");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
// Generate all new embeddings in memory, batching to amortise lock/dispatch overhead
|
||||
// while keeping memory bounded and preserving progress logging.
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> =
|
||||
HashMap::with_capacity(total_entities);
|
||||
info!("Generating new embeddings for all entities...");
|
||||
|
||||
for (i, entity) in all_entities.iter().enumerate() {
|
||||
if i > 0 && i % 100 == 0 {
|
||||
info!(
|
||||
progress = i,
|
||||
total = total_entities,
|
||||
"Re-embedding progress"
|
||||
let mut processed = 0usize;
|
||||
for batch in all_entities.chunks(RE_EMBED_BATCH_SIZE) {
|
||||
let inputs: Vec<String> = batch
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let embeddings = provider.embed_batch(inputs).await?;
|
||||
if embeddings.len() != batch.len() {
|
||||
return Err(AppError::internal(format!(
|
||||
"embedding batch returned {} vectors for {} entities",
|
||||
embeddings.len(),
|
||||
batch.len()
|
||||
)));
|
||||
}
|
||||
|
||||
for (entity, embedding) in batch.iter().zip(embeddings) {
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(embedding, entity.user_id.clone(), entity.source_id.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
);
|
||||
|
||||
let embedding = provider.embed(&embedding_input).await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(embedding, entity.user_id.clone(), entity.source_id.clone()),
|
||||
processed = processed.saturating_add(batch.len());
|
||||
info!(
|
||||
progress = processed,
|
||||
total = total_entities,
|
||||
"Re-embedding progress"
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
@@ -235,7 +235,6 @@ mod tests {
|
||||
use crate::storage::indexes::ensure_runtime;
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||
use anyhow::{self, Context};
|
||||
use async_openai::Client;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
@@ -713,6 +712,8 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length(
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
@@ -758,12 +759,13 @@ mod tests {
|
||||
"Settings should reflect the new embedding dimension"
|
||||
);
|
||||
|
||||
let openai_client = Client::new();
|
||||
let provider = EmbeddingProvider::new_hashed(new_dimension as usize)
|
||||
.map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
|
||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
TextChunk::update_all_embeddings(&db, &provider)
|
||||
.await
|
||||
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
|
||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
KnowledgeEntity::update_all_embeddings(&db, &provider)
|
||||
.await
|
||||
.with_context(|| {
|
||||
"KnowledgeEntity re-embedding should succeed on fresh DB".to_string()
|
||||
|
||||
@@ -5,12 +5,8 @@ use std::fmt::Write;
|
||||
use crate::storage::indexes::hnsw_index_overwrite_sql;
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
@@ -238,7 +234,7 @@ impl TextChunk {
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
|
||||
/// Re-creates embeddings for all text chunks using an `EmbeddingProvider`.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It performs these steps:
|
||||
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
|
||||
@@ -246,109 +242,8 @@ impl TextChunk {
|
||||
/// has the wrong dimension, the entire operation is aborted before any DB changes are made.
|
||||
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
|
||||
/// performed in a single, all-or-nothing database transaction.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn update_all_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
openai_client: &Client<OpenAIConfig>,
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all text chunks. New dimensions: {new_dimensions}"
|
||||
);
|
||||
|
||||
// Fetch all chunks first
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
let total_chunks = all_chunks.len();
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Just updating the idx");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {total_chunks} chunks to process.");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all chunks...");
|
||||
for chunk in &all_chunks {
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
crate::utils::embedding::generate_embedding_with_params(
|
||||
openai_client,
|
||||
&chunk.chunk,
|
||||
new_model,
|
||||
new_dimensions,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions as usize {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction against the embedding table
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
|
||||
let id = surql_json_string(&id)?;
|
||||
let user_id = surql_json_string(&user_id)?;
|
||||
let source_id = surql_json_string(&source_id)?;
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"UPSERT type::thing('{emb_table}', {id}) SET \
|
||||
chunk_id = type::thing('{chunk_table}', {id}), \
|
||||
source_id = {source_id}, \
|
||||
embedding = {embedding}, \
|
||||
user_id = {user_id}, \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
chunk_table = Self::table_name(),
|
||||
)
|
||||
.map_err(AppError::internal)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"{}",
|
||||
hnsw_index_overwrite_sql(
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
TextChunkEmbedding::table_name(),
|
||||
new_dimensions as usize,
|
||||
)
|
||||
)
|
||||
.map_err(AppError::internal)?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for text chunks completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all text chunks using an `EmbeddingProvider`.
|
||||
///
|
||||
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
|
||||
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
|
||||
pub async fn update_all_embeddings_with_provider(
|
||||
db: &SurrealDbClient,
|
||||
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||
) -> Result<(), AppError> {
|
||||
@@ -369,30 +264,42 @@ impl TextChunk {
|
||||
}
|
||||
info!(chunks = total_chunks, "Found chunks to process");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
// Generate all new embeddings in memory, batching to amortise lock/dispatch overhead
|
||||
// while keeping memory bounded and preserving progress logging.
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> =
|
||||
HashMap::with_capacity(total_chunks);
|
||||
info!("Generating new embeddings for all chunks...");
|
||||
|
||||
for (i, chunk) in all_chunks.iter().enumerate() {
|
||||
if i > 0 && i % 100 == 0 {
|
||||
info!(progress = i, total = total_chunks, "Re-embedding progress");
|
||||
let mut processed = 0usize;
|
||||
for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) {
|
||||
let inputs: Vec<String> = batch.iter().map(|chunk| chunk.chunk.clone()).collect();
|
||||
let embeddings = provider.embed_batch(inputs).await?;
|
||||
if embeddings.len() != batch.len() {
|
||||
return Err(AppError::internal(format!(
|
||||
"embedding batch returned {} vectors for {} chunks",
|
||||
embeddings.len(),
|
||||
batch.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let embedding = provider.embed(&chunk.chunk).await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
for (chunk, embedding) in batch.iter().zip(embeddings) {
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
|
||||
);
|
||||
|
||||
processed = processed.saturating_add(batch.len());
|
||||
info!(progress = processed, total = total_chunks, "Re-embedding progress");
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
|
||||
@@ -119,6 +119,8 @@ pub struct AppConfig {
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
#[serde(default)]
|
||||
pub embedding_pool_size: Option<usize>,
|
||||
#[serde(default = "default_ingest_max_body_bytes")]
|
||||
pub ingest_max_body_bytes: usize,
|
||||
#[serde(default = "default_ingest_max_files")]
|
||||
@@ -225,6 +227,7 @@ impl Default for AppConfig {
|
||||
fastembed_show_download_progress: None,
|
||||
fastembed_max_length: None,
|
||||
embedding_backend: EmbeddingBackend::default(),
|
||||
embedding_pool_size: None,
|
||||
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
||||
ingest_max_files: default_ingest_max_files(),
|
||||
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||
|
||||
+126
-126
@@ -3,15 +3,16 @@ use std::{
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::{Arc, Mutex},
|
||||
thread::available_parallelism,
|
||||
};
|
||||
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tracing::debug;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
|
||||
use crate::{
|
||||
error::{AppError, EmbeddingError},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
error::EmbeddingError,
|
||||
storage::types::system_settings::SystemSettings,
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
@@ -45,8 +46,8 @@ enum EmbeddingInner {
|
||||
},
|
||||
/// Uses `FastEmbed` running locally.
|
||||
FastEmbed {
|
||||
/// Shared `FastEmbed` model (blocking; used only inside `spawn_blocking`).
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
/// Pool of `FastEmbed` engines providing bounded-concurrency local embedding.
|
||||
pool: Arc<FastEmbedPool>,
|
||||
/// Model metadata used for info logging.
|
||||
model_name: EmbeddingModel,
|
||||
/// Output vector length.
|
||||
@@ -54,19 +55,99 @@ enum EmbeddingInner {
|
||||
},
|
||||
}
|
||||
|
||||
/// Batch size used when re-embedding stored data in bulk. Bounds peak memory and preserves
/// progress logging while still amortising per-call lock/dispatch overhead.
pub const RE_EMBED_BATCH_SIZE: usize = 128;
|
||||
|
||||
/// Default `FastEmbed` pool size.
///
/// Kept small on purpose: the ONNX runtime already uses intra-op threads per inference, so
/// running many engines concurrently oversubscribes the CPU and each engine duplicates the
/// model weights in memory. Mirrors the reranker pool default.
///
/// The result is always 1 or 2: the detected parallelism capped at 2, or 2 when the
/// parallelism cannot be determined.
#[must_use]
pub fn default_embedding_pool_size() -> usize {
    // `available_parallelism` yields a `NonZeroUsize`, so the capped value is never below 1 and
    // no additional lower clamp is required.
    available_parallelism().map_or(2, |cores| cores.get().min(2))
}
|
||||
|
||||
/// Pool of `FastEmbed` engines enabling bounded-concurrency local embedding.
|
||||
///
|
||||
/// A single [`TextEmbedding`] embeds one batch at a time (`&mut self`), so the pool keeps
|
||||
/// several instances and hands out a distinct idle engine per checkout. The semaphore bounds
|
||||
/// total in-flight embeds (backpressure); the free list guarantees each active lease holds a
|
||||
/// different engine — unlike a round-robin index, which can hand the same engine to two callers.
|
||||
struct FastEmbedPool {
|
||||
/// Idle engines; one is popped on checkout and returned on lease drop.
|
||||
engines: Mutex<Vec<Arc<Mutex<TextEmbedding>>>>,
|
||||
/// Sized to the engine count; gates concurrent checkouts.
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl FastEmbedPool {
|
||||
fn new(engines: Vec<Arc<Mutex<TextEmbedding>>>) -> Self {
|
||||
let permits = engines.len().max(1);
|
||||
Self {
|
||||
engines: Mutex::new(engines),
|
||||
semaphore: Arc::new(Semaphore::new(permits)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a permit and borrow a distinct idle engine. The permit guarantees an engine is
|
||||
/// available, so the pop always succeeds for a correctly sized pool.
|
||||
async fn checkout(self: &Arc<Self>) -> Result<FastEmbedLease, EmbeddingError> {
|
||||
let permit = Arc::clone(&self.semaphore)
|
||||
.acquire_owned()
|
||||
.await
|
||||
.map_err(|_| EmbeddingError::Config("embedding pool is closed".into()))?;
|
||||
let engine = self
|
||||
.engines
|
||||
.lock()
|
||||
.map_err(EmbeddingError::mutex_poisoned)?
|
||||
.pop()
|
||||
.ok_or_else(|| EmbeddingError::Config("embedding pool unexpectedly empty".into()))?;
|
||||
Ok(FastEmbedLease {
|
||||
pool: Arc::clone(self),
|
||||
engine,
|
||||
_permit: permit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Active borrow of a single `FastEmbed` engine; returns it to the pool on drop.
|
||||
struct FastEmbedLease {
|
||||
pool: Arc<FastEmbedPool>,
|
||||
engine: Arc<Mutex<TextEmbedding>>,
|
||||
/// Released after the engine is returned, unblocking the next checkout.
|
||||
_permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
impl FastEmbedLease {
|
||||
async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let engine = Arc::clone(&self.engine);
|
||||
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
.map_err(EmbeddingError::from)?
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for FastEmbedLease {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut free) = self.pool.engines.lock() {
|
||||
free.push(Arc::clone(&self.engine));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_fastembed(
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
pool: &Arc<FastEmbedPool>,
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = model.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(join_error) => Err(EmbeddingError::from(join_error)),
|
||||
}
|
||||
let lease = pool.checkout().await?;
|
||||
lease.embed(texts).await
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
@@ -107,8 +188,8 @@ impl EmbeddingProvider {
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let embeddings = run_fastembed(Arc::clone(model), vec![text.to_owned()]).await?;
|
||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||
let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?;
|
||||
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
@@ -148,11 +229,11 @@ impl EmbeddingProvider {
|
||||
.into_iter()
|
||||
.map(|text| hashed_embedding(&text, *dimension))
|
||||
.collect()),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
run_fastembed(Arc::clone(model), texts).await
|
||||
run_fastembed(pool, texts).await
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -199,30 +280,46 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialise a local FastEmbed provider backed by a pool of `pool_size` engines.
|
||||
///
|
||||
/// `pool_size` is clamped to at least 1. Larger pools allow concurrent embeds at the cost of
|
||||
/// `pool_size`× model memory; see [`default_embedding_pool_size`] for guidance.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self, EmbeddingError> {
|
||||
pub async fn new_fastembed(
|
||||
model_override: Option<String>,
|
||||
pool_size: usize,
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let pool_size = pool_size.max(1);
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(EmbeddingError::UnknownModel)?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
|
||||
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) =
|
||||
let (engines, dimension) =
|
||||
match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info =
|
||||
EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for index in 0..pool_size {
|
||||
let options = TextInitOptions::new(model_name_for_task.clone())
|
||||
// Only the first engine reports download progress; the rest reuse the cache.
|
||||
.with_show_download_progress(index == 0);
|
||||
let model =
|
||||
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
engines.push(Arc::new(Mutex::new(model)));
|
||||
}
|
||||
Ok((engines, info.dim))
|
||||
})
|
||||
.await
|
||||
{
|
||||
@@ -232,7 +329,7 @@ impl EmbeddingProvider {
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
model: Arc::new(Mutex::new(model)),
|
||||
pool: Arc::new(FastEmbedPool::new(engines)),
|
||||
model_name,
|
||||
dimension,
|
||||
},
|
||||
@@ -275,7 +372,10 @@ impl EmbeddingProvider {
|
||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||
let pool_size = config
|
||||
.embedding_pool_size
|
||||
.unwrap_or_else(default_embedding_pool_size);
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone()), pool_size).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => {
|
||||
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||
@@ -329,106 +429,6 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
||||
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
|
||||
}
|
||||
|
||||
/// Generate an embedding using the given provider.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AppError::Embedding`] if the provider's embed call fails.
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider.embed(input).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
|
||||
/// comparisons, vector search, and other natural language processing tasks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `client`: The `OpenAI` client instance used to make API requests.
|
||||
/// * `input`: The text string to generate embeddings for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
|
||||
/// * `Err(ProcessingError)`: An error if the embedding generation fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function can return a `AppError` in the following cases:
|
||||
/// * If the `OpenAI` API request fails
|
||||
/// * If the request building fails
|
||||
/// * If no embedding data is received in the response
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let model = SystemSettings::get_current(db).await?;
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.embedding_model)
|
||||
.dimensions(model.embedding_dimensions)
|
||||
.input([input])
|
||||
.build()?;
|
||||
|
||||
// Send the request to OpenAI
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
// Extract the embedding vector
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector using a specific model and dimension.
|
||||
///
|
||||
/// This is used for the re-embedding process where the model and dimensions
|
||||
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError` if the OpenAI API request fails or returns no embedding data.
|
||||
pub async fn generate_embedding_with_params(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
model: &str,
|
||||
dimensions: u32,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model)
|
||||
.input([input])
|
||||
.dimensions(dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("no embedding data received from API".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
debug!(
|
||||
"Embedding was created with {:?} dimensions",
|
||||
embedding.len()
|
||||
);
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{
|
||||
},
|
||||
openai,
|
||||
};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
use common::utils::embedding::{default_embedding_pool_size, EmbeddingProvider};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
@@ -43,9 +43,12 @@ pub(crate) async fn prepare_db(
|
||||
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
||||
let embedding_provider = match config.embedding_backend {
|
||||
crate::args::EmbeddingBackend::FastEmbed => {
|
||||
EmbeddingProvider::new_fastembed(config.embedding_model.clone())
|
||||
.await
|
||||
.context("creating FastEmbed provider")?
|
||||
EmbeddingProvider::new_fastembed(
|
||||
config.embedding_model.clone(),
|
||||
default_embedding_pool_size(),
|
||||
)
|
||||
.await
|
||||
.context("creating FastEmbed provider")?
|
||||
}
|
||||
crate::args::EmbeddingBackend::Hashed => {
|
||||
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
||||
|
||||
@@ -136,12 +136,10 @@ pub(crate) async fn run_queries(
|
||||
let embedding_provider_for_queries = ctx.embedding_provider()?.clone();
|
||||
let rerank_pool_for_queries = rerank_pool.clone();
|
||||
let db = ctx.db()?.clone();
|
||||
let openai_client = ctx.openai_client()?;
|
||||
|
||||
let raw_results = stream::iter(cases_iter)
|
||||
.map(move |(idx, case)| {
|
||||
let db = db.clone();
|
||||
let openai_client = Arc::clone(&openai_client);
|
||||
let user_id = user_id.clone();
|
||||
let retrieval_config = Arc::clone(&retrieval_config);
|
||||
let embedding_provider = embedding_provider_for_queries.clone();
|
||||
@@ -180,8 +178,7 @@ pub(crate) async fn run_queries(
|
||||
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: Some(&embedding_provider),
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: &question,
|
||||
user_id: &user_id,
|
||||
config: (*retrieval_config).clone(),
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::types::ListModelResponse;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
@@ -11,17 +9,15 @@ use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
analytics::Analytics,
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
system_prompts::{
|
||||
DEFAULT_IMAGE_PROCESSING_PROMPT, DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT,
|
||||
DEFAULT_QUERY_SYSTEM_PROMPT,
|
||||
},
|
||||
system_settings::{SystemSettings, SystemSettingsPatch},
|
||||
text_chunk::TextChunk,
|
||||
},
|
||||
utils::embedding::EmbeddingBackend,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
@@ -209,41 +205,15 @@ pub async fn update_model_settings(
|
||||
.await?;
|
||||
|
||||
if reembedding_needed {
|
||||
info!("Embedding dimensions changed. Spawning background re-embedding task...");
|
||||
|
||||
let db_for_task = Arc::clone(&state.db);
|
||||
let openai_for_task = Arc::clone(&state.openai_client);
|
||||
let new_model_for_task = new_settings.embedding_model.clone();
|
||||
let new_dims_for_task = new_settings.embedding_dimensions;
|
||||
|
||||
tokio::spawn(async move {
|
||||
// First, update all text chunks
|
||||
if let Err(e) = TextChunk::update_all_embeddings(
|
||||
&db_for_task,
|
||||
&openai_for_task,
|
||||
&new_model_for_task,
|
||||
new_dims_for_task,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Background re-embedding task failed for TextChunks: {}", e);
|
||||
}
|
||||
|
||||
// Second, update all knowledge entities
|
||||
if let Err(e) = KnowledgeEntity::update_all_embeddings(
|
||||
&db_for_task,
|
||||
&openai_for_task,
|
||||
&new_model_for_task,
|
||||
new_dims_for_task,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!(
|
||||
"Background re-embedding task failed for KnowledgeEntities: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
});
|
||||
// Re-embedding is owned by startup (the worker/combined binary), not the admin request.
|
||||
// Doing it inline here would leave the live, startup-built embedding provider embedding
|
||||
// queries at the old dimension while stored vectors move to the new one — broken retrieval
|
||||
// until restart. Persisting the new settings is enough: on the next restart the maintainer
|
||||
// detects the index/dimension mismatch and re-embeds before rebuilding indexes.
|
||||
info!(
|
||||
new_dimensions = new_settings.embedding_dimensions,
|
||||
"Embedding dimensions changed; restart the worker/server to re-embed and apply"
|
||||
);
|
||||
}
|
||||
|
||||
let available_models = state
|
||||
|
||||
@@ -359,8 +359,7 @@ async fn prepare_chat_request(
|
||||
|
||||
let retrieval_result = match retrieval_pipeline::retrieve(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&*state.embedding_provider),
|
||||
&state.embedding_provider,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
config,
|
||||
|
||||
@@ -24,7 +24,7 @@ use common::{
|
||||
user::User,
|
||||
},
|
||||
},
|
||||
utils::embedding::{generate_embedding_with_provider, EmbeddingProvider},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
normalize_fts_terms, reciprocal_rank_fusion, RetrievalTuning, RrfConfig, Scored,
|
||||
@@ -187,8 +187,11 @@ pub async fn create_knowledge_entity(
|
||||
|
||||
let embedding_input =
|
||||
format!("name: {name}, description: {description}, type: {entity_type:?}");
|
||||
let embedding =
|
||||
generate_embedding_with_provider(&state.embedding_provider, &embedding_input).await?;
|
||||
let embedding = state
|
||||
.embedding_provider
|
||||
.embed(&embedding_input)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let source_id = format!("manual::{}", Uuid::new_v4());
|
||||
let new_entity = KnowledgeEntity::new(
|
||||
@@ -373,8 +376,7 @@ async fn suggest_related_entities(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
draft.name, draft.description, draft.entity_type
|
||||
);
|
||||
let embedding =
|
||||
generate_embedding_with_provider(embedding_provider, &embedding_input).await?;
|
||||
let embedding = embedding_provider.embed(&embedding_input).await?;
|
||||
|
||||
let take = MAX_RELATIONSHIP_SUGGESTIONS * 2;
|
||||
let tuning = RetrievalTuning::default();
|
||||
|
||||
@@ -171,8 +171,7 @@ async fn perform_search(
|
||||
|
||||
let result = retrieve(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&state.embedding_provider),
|
||||
&state.embedding_provider,
|
||||
trimmed_query,
|
||||
&user.id,
|
||||
config,
|
||||
|
||||
@@ -132,7 +132,7 @@ fn extract_authenticated_main(html: &str) -> &str {
|
||||
let start = html
|
||||
.find(AUTHENTICATED_MAIN_OPEN)
|
||||
.expect("authenticated page main column")
|
||||
+ AUTHENTICATED_MAIN_OPEN.len();
|
||||
.saturating_add(AUTHENTICATED_MAIN_OPEN.len());
|
||||
let rest = &html[start..];
|
||||
let end = rest
|
||||
.find("</main>")
|
||||
@@ -184,8 +184,12 @@ async fn create_scratchpad_and_get_id(app: &Router, cookie: &str, title: &str) -
|
||||
|
||||
let list = get_html(app, "/scratchpad", Some(cookie)).await;
|
||||
let marker = "/scratchpad/";
|
||||
let start = list.find(marker).expect("scratchpad link present") + marker.len();
|
||||
list[start..start + list[start..].find('/').expect("id terminator")].to_string()
|
||||
let start = list
|
||||
.find(marker)
|
||||
.expect("scratchpad link present")
|
||||
.saturating_add(marker.len());
|
||||
let end = start.saturating_add(list[start..].find('/').expect("id terminator"));
|
||||
list[start..end].to_string()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
|
||||
@@ -7,13 +7,12 @@ use serde::{Deserialize, Serialize};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
},
|
||||
},
|
||||
utils::{embedding::generate_embedding, embedding::EmbeddingProvider},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
|
||||
use crate::pipeline::context::EmbeddedKnowledgeEntity;
|
||||
@@ -46,10 +45,8 @@ impl LLMEnrichmentResult {
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
db_client: &SurrealDbClient,
|
||||
entity_concurrency: usize,
|
||||
embedding_provider: Option<&EmbeddingProvider>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
|
||||
let mapper = Arc::new(self.create_mapper());
|
||||
|
||||
@@ -58,8 +55,6 @@ impl LLMEnrichmentResult {
|
||||
source_id,
|
||||
user_id,
|
||||
Arc::clone(&mapper),
|
||||
openai_client,
|
||||
db_client,
|
||||
entity_concurrency,
|
||||
embedding_provider,
|
||||
)
|
||||
@@ -80,23 +75,18 @@ impl LLMEnrichmentResult {
|
||||
mapper
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn process_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<GraphMapper>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
db_client: &SurrealDbClient,
|
||||
entity_concurrency: usize,
|
||||
embedding_provider: Option<&EmbeddingProvider>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> {
|
||||
stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| {
|
||||
let mapper = Arc::clone(&mapper);
|
||||
let openai_client = openai_client.clone();
|
||||
let source_id = source_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
let db_client = db_client.clone();
|
||||
|
||||
async move {
|
||||
create_single_entity(
|
||||
@@ -104,8 +94,6 @@ impl LLMEnrichmentResult {
|
||||
&source_id,
|
||||
&user_id,
|
||||
mapper,
|
||||
&openai_client,
|
||||
&db_client,
|
||||
embedding_provider,
|
||||
)
|
||||
.await
|
||||
@@ -145,9 +133,7 @@ async fn create_single_entity(
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<GraphMapper>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
db_client: &SurrealDbClient,
|
||||
embedding_provider: Option<&EmbeddingProvider>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Result<EmbeddedKnowledgeEntity, AppError> {
|
||||
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
|
||||
|
||||
@@ -156,11 +142,7 @@ async fn create_single_entity(
|
||||
llm_entity.name, llm_entity.description, llm_entity.entity_type
|
||||
);
|
||||
|
||||
let embedding = if let Some(provider) = embedding_provider {
|
||||
provider.embed(&embedding_input).await?
|
||||
} else {
|
||||
generate_embedding(openai_client, &embedding_input, db_client).await?
|
||||
};
|
||||
let embedding = embedding_provider.embed(&embedding_input).await?;
|
||||
|
||||
let now = Utc::now();
|
||||
let entity = KnowledgeEntity {
|
||||
|
||||
@@ -187,8 +187,7 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
let config = retrieval_pipeline::RetrievalConfig::with_entities();
|
||||
match retrieval_pipeline::retrieve(
|
||||
&self.db,
|
||||
&self.openai_client,
|
||||
Some(&*self.embedding_provider),
|
||||
&self.embedding_provider,
|
||||
&input_text,
|
||||
&content.user_id,
|
||||
config,
|
||||
@@ -237,10 +236,8 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
.to_database_entities(
|
||||
content.id(),
|
||||
&content.user_id,
|
||||
&self.openai_client,
|
||||
&self.db,
|
||||
entity_concurrency,
|
||||
Some(&*self.embedding_provider),
|
||||
&self.embedding_provider,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -258,15 +255,30 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
overlap_tokens,
|
||||
)?;
|
||||
|
||||
if chunk_candidates.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Embed all chunks of this document in one batch: a single lock acquisition and one
|
||||
// blocking hop, letting the backend batch the inference internally.
|
||||
let embeddings = self
|
||||
.embedding_provider
|
||||
.embed_batch(chunk_candidates.clone())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
AppError::InternalError(format!("FastEmbed embedding for chunks failed: {e}"))
|
||||
})?;
|
||||
|
||||
if embeddings.len() != chunk_candidates.len() {
|
||||
return Err(AppError::InternalError(format!(
|
||||
"embedding batch returned {} vectors for {} chunks",
|
||||
embeddings.len(),
|
||||
chunk_candidates.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut chunks = Vec::with_capacity(chunk_candidates.len());
|
||||
for chunk_text in chunk_candidates {
|
||||
let embedding = self
|
||||
.embedding_provider
|
||||
.embed(&chunk_text)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
AppError::InternalError(format!("FastEmbed embedding for chunk failed: {e}"))
|
||||
})?;
|
||||
for (chunk_text, embedding) in chunk_candidates.into_iter().zip(embeddings) {
|
||||
let chunk_struct = TextChunk::new(
|
||||
content.id().to_string(),
|
||||
chunk_text,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
mod startup;
|
||||
pub mod wiring;
|
||||
|
||||
pub use startup::prepare_embedding_runtime;
|
||||
pub use startup::{prepare_embedding_runtime, EmbeddingRuntimeRole};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
|
||||
+210
-21
@@ -2,7 +2,7 @@ use anyhow::Context;
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::ensure_runtime,
|
||||
indexes::{embedding_index_dimension, ensure_runtime},
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, system_settings::SystemSettings,
|
||||
text_chunk::TextChunk,
|
||||
@@ -10,37 +10,129 @@ use common::{
|
||||
},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::SharedServices;
|
||||
|
||||
/// Syncs embedding settings, re-embeds stored vectors when dimensions change, and
|
||||
/// ensures runtime indexes match the active embedding dimension.
|
||||
pub async fn prepare_embedding_runtime(services: &SharedServices) -> anyhow::Result<SystemSettings> {
|
||||
let (settings, dimensions_changed) =
|
||||
/// The part a process plays in embedding-runtime maintenance.
///
/// A change of embedding configuration (model/dimension) only takes effect on restart, because
/// the active [`EmbeddingProvider`] is constructed once at startup. Stored vectors therefore have
/// to be reconciled to that provider before indexes are rebuilt. That re-embed can be long and
/// destructive, so exactly one maintainer should run it, while query-only servers remain
/// read-only so they do not race it.
// A given binary (main/worker/server) only ever constructs one of the variants, so within that
// single compilation unit the other one looks unused even though the binary set as a whole
// uses both.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbeddingRuntimeRole {
    /// Worker or combined binary: re-embeds stored data once it stops matching the provider.
    Maintainer,
    /// Server-only process: leaves stored embeddings untouched and aligns indexes to the
    /// data that already exists.
    ReadOnly,
}
|
||||
|
||||
/// Re-embed lock TTL, written as a duration literal because it is spliced verbatim into the
/// lock-acquisition query. Sized generously so that a slow re-embed of a large corpus is unlikely
|
||||
/// to expire out from under the maintainer that holds it; an abandoned lock (crashed maintainer)
/// is reaped by the next acquisition attempt once this TTL has elapsed.
// NOTE(review): no lock renewal is visible here, so a re-embed that runs longer than this TTL
// could see its lock reaped by a peer — confirm 30 minutes is enough for the largest corpus.
|
||||
const REEMBED_LOCK_TTL: &str = "30m";
|
||||
|
||||
/// Reconciles embeddings with the active provider and ensures runtime indexes are ready.
|
||||
///
|
||||
/// Detection is based on the stored chunk-embedding HNSW index dimension (a persisted marker of
|
||||
/// the embedding space actually in the database). When it differs from the active provider's
|
||||
/// dimension, a [`EmbeddingRuntimeRole::Maintainer`] re-embeds before indexes are (re)built;
|
||||
/// a [`EmbeddingRuntimeRole::ReadOnly`] server leaves indexes aligned to the existing data and
|
||||
/// serves in a degraded state until a maintainer reconciles.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if syncing settings, inspecting/building indexes, or re-embedding fails.
|
||||
pub async fn prepare_embedding_runtime(
|
||||
services: &SharedServices,
|
||||
role: EmbeddingRuntimeRole,
|
||||
) -> anyhow::Result<SystemSettings> {
|
||||
// Keep SystemSettings in sync with the active provider so the admin UI reflects the real
|
||||
// backend/model/dimension. This does not, by itself, decide whether a re-embed is needed.
|
||||
let (settings, _changed) =
|
||||
SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider)
|
||||
.await
|
||||
.context("sync system settings from embedding provider")?;
|
||||
|
||||
if dimensions_changed {
|
||||
re_embed_all(
|
||||
&services.db,
|
||||
&services.embedding_provider,
|
||||
settings.embedding_dimensions,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
let target_dim = services.embedding_provider.dimension();
|
||||
let stored_dim = embedding_index_dimension(&services.db)
|
||||
.await
|
||||
.context("inspect stored embedding index dimension")?;
|
||||
let mismatch = matches!(stored_dim, Some(dim) if dim != target_dim);
|
||||
|
||||
ensure_runtime(
|
||||
&services.db,
|
||||
settings.embedding_dimensions as usize,
|
||||
)
|
||||
.await
|
||||
.context("ensure runtime indexes")?;
|
||||
let index_dim = if mismatch {
|
||||
match role {
|
||||
EmbeddingRuntimeRole::Maintainer => {
|
||||
reconcile_embeddings(&services.db, &services.embedding_provider, target_dim).await?;
|
||||
target_dim
|
||||
}
|
||||
EmbeddingRuntimeRole::ReadOnly => {
|
||||
warn!(
|
||||
stored_dimension = stored_dim,
|
||||
target_dimension = target_dim,
|
||||
"Stored embeddings do not match the active embedding dimension. A maintainer \
|
||||
(worker) must re-embed; serving in a degraded state and keeping indexes \
|
||||
aligned to the existing data until then."
|
||||
);
|
||||
// Preserve the index that matches the vectors actually stored. Do not overwrite it
|
||||
// to the new dimension here — that would happen before the data is re-embedded and
|
||||
// would break retrieval entirely.
|
||||
stored_dim.unwrap_or(target_dim)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
target_dim
|
||||
};
|
||||
|
||||
ensure_runtime(&services.db, index_dim)
|
||||
.await
|
||||
.context("ensure runtime indexes")?;
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
/// Acquires the re-embed lock (so only one maintainer reconciles), re-embeds, then releases it.
|
||||
async fn reconcile_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
target_dim: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
let owner = reembed_lock_owner();
|
||||
|
||||
if !try_acquire_reembed_lock(db, &owner).await? {
|
||||
info!("Another maintainer holds the re-embed lock; skipping re-embed on this instance");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let result = reconcile_under_lock(db, embedding_provider, target_dim).await;
|
||||
release_reembed_lock(db, &owner).await;
|
||||
result
|
||||
}
|
||||
|
||||
/// Re-embed body executed while holding the lock, with a re-check to avoid duplicate work.
|
||||
async fn reconcile_under_lock(
|
||||
db: &SurrealDbClient,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
target_dim: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
// A peer may have finished re-embedding between detection and lock acquisition.
|
||||
let stored_dim = embedding_index_dimension(db)
|
||||
.await
|
||||
.context("re-check stored embedding dimension under lock")?;
|
||||
if !matches!(stored_dim, Some(dim) if dim != target_dim) {
|
||||
info!("Stored embeddings already match the active dimension; skipping re-embed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let target_dim_u32 = u32::try_from(target_dim)
|
||||
.map_err(|_| anyhow::anyhow!("embedding dimension {target_dim} exceeds u32::MAX"))?;
|
||||
re_embed_all(db, embedding_provider, target_dim_u32).await
|
||||
}
|
||||
|
||||
async fn re_embed_all(
|
||||
db: &SurrealDbClient,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
@@ -52,15 +144,112 @@ async fn re_embed_all(
|
||||
);
|
||||
|
||||
info!("Re-embedding TextChunks");
|
||||
TextChunk::update_all_embeddings_with_provider(db, embedding_provider)
|
||||
TextChunk::update_all_embeddings(db, embedding_provider)
|
||||
.await
|
||||
.context("re-embed text chunks after embedding dimension change")?;
|
||||
|
||||
info!("Re-embedding KnowledgeEntities");
|
||||
KnowledgeEntity::update_all_embeddings_with_provider(db, embedding_provider)
|
||||
KnowledgeEntity::update_all_embeddings(db, embedding_provider)
|
||||
.await
|
||||
.context("re-embed knowledge entities after embedding dimension change")?;
|
||||
|
||||
info!("Re-embedding complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Builds the owner token for one re-embed lock acquisition; the same token is needed to
/// release the lock.
///
/// The token has the form `reembed-<pid>-<nanoseconds since the Unix epoch>`. If the system
/// clock reads earlier than the epoch, the nanosecond part falls back to `0`.
fn reembed_lock_owner() -> String {
    let pid = std::process::id();
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    format!("reembed-{pid}-{since_epoch}")
}
|
||||
|
||||
/// Best-effort atomic mutex over the (potentially long) re-embed using a singleton record.
|
||||
///
|
||||
/// `CREATE` of a fixed record id fails if it already exists, which serializes concurrent
|
||||
/// maintainers. An expired lock is reaped first so a crashed maintainer cannot block forever.
|
||||
async fn try_acquire_reembed_lock(db: &SurrealDbClient, owner: &str) -> anyhow::Result<bool> {
|
||||
db.client
|
||||
.query("DEFINE TABLE IF NOT EXISTS maintenance_lock SCHEMALESS;")
|
||||
.await
|
||||
.and_then(surrealdb::Response::check)
|
||||
.context("define maintenance_lock table")?;
|
||||
|
||||
db.client
|
||||
.query("DELETE maintenance_lock:reembed WHERE expires_at < time::now();")
|
||||
.await
|
||||
.and_then(surrealdb::Response::check)
|
||||
.context("reap expired re-embed lock")?;
|
||||
|
||||
// `CREATE` of a fixed record id succeeds for the first caller and errors with an
|
||||
// "already exists" record conflict for any concurrent caller, giving us an atomic mutex.
|
||||
let acquired = db
|
||||
.client
|
||||
.query(format!(
|
||||
"CREATE maintenance_lock:reembed SET owner = $owner, expires_at = time::now() + {REEMBED_LOCK_TTL};"
|
||||
))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.and_then(surrealdb::Response::check)
|
||||
.is_ok();
|
||||
|
||||
Ok(acquired)
|
||||
}
|
||||
|
||||
async fn release_reembed_lock(db: &SurrealDbClient, owner: &str) {
|
||||
let released = db
|
||||
.client
|
||||
.query("DELETE maintenance_lock:reembed WHERE owner = $owner;")
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.and_then(surrealdb::Response::check);
|
||||
|
||||
if let Err(err) = released {
|
||||
warn!(error = %err, "Failed to release re-embed lock; it will expire automatically");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use common::storage::db::SurrealDbClient;

    /// Opens an in-memory database; a fresh lock-owner token is used as the database name.
    async fn test_db() -> SurrealDbClient {
        let database_name = reembed_lock_owner();
        SurrealDbClient::memory("reembed_lock_ns", &database_name)
            .await
            .expect("in-memory db")
    }

    #[tokio::test]
    async fn reembed_lock_is_exclusive_and_reusable_after_release() {
        let db = test_db().await;

        let holder = reembed_lock_owner();
        let contender = format!("{holder}-peer");

        let holder_won = try_acquire_reembed_lock(&db, &holder)
            .await
            .expect("acquire first");
        assert!(holder_won, "the first acquirer should win the lock");

        // While the lock is held, a second maintainer must be turned away.
        let contender_won = try_acquire_reembed_lock(&db, &contender)
            .await
            .expect("contend for lock");
        assert!(
            !contender_won,
            "a held lock must not be granted to another owner"
        );

        // Once the holder releases it, the next maintainer can take it.
        release_reembed_lock(&db, &holder).await;
        let contender_retry = try_acquire_reembed_lock(&db, &contender)
            .await
            .expect("re-acquire after release");
        assert!(
            contender_retry,
            "the lock should be grantable again once released"
        );
    }
}
|
||||
|
||||
+5
-2
@@ -6,6 +6,7 @@ use axum::extract::FromRef;
|
||||
use bootstrap::{
|
||||
init, prepare_embedding_runtime,
|
||||
wiring::{build_api_state, build_html_state, minne_routes},
|
||||
EmbeddingRuntimeRole,
|
||||
};
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use tracing::info;
|
||||
@@ -20,7 +21,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
"Embedding provider initialized"
|
||||
);
|
||||
|
||||
prepare_embedding_runtime(&services).await?;
|
||||
// The combined binary runs the worker in-process, so it owns re-embedding.
|
||||
prepare_embedding_runtime(&services, EmbeddingRuntimeRole::Maintainer).await?;
|
||||
|
||||
let html_state = build_html_state(&services).await?;
|
||||
let api_state = build_api_state(&services);
|
||||
@@ -88,6 +90,7 @@ mod tests {
|
||||
prepare_embedding_runtime,
|
||||
tests::init_smoke_services,
|
||||
wiring::{build_api_state, build_html_state, minne_routes},
|
||||
EmbeddingRuntimeRole,
|
||||
};
|
||||
use common::storage::types::{system_settings::SystemSettings, user::User};
|
||||
use tower::ServiceExt;
|
||||
@@ -97,7 +100,7 @@ mod tests {
|
||||
.await
|
||||
.expect("failed to init services");
|
||||
|
||||
prepare_embedding_runtime(&services)
|
||||
prepare_embedding_runtime(&services, EmbeddingRuntimeRole::Maintainer)
|
||||
.await
|
||||
.expect("failed to prepare embedding runtime");
|
||||
|
||||
|
||||
+3
-1
@@ -4,13 +4,15 @@ use axum::extract::FromRef;
|
||||
use bootstrap::{
|
||||
init, prepare_embedding_runtime,
|
||||
wiring::{build_api_state, build_html_state, minne_routes},
|
||||
EmbeddingRuntimeRole,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let services = init().await?;
|
||||
prepare_embedding_runtime(&services).await?;
|
||||
// The server never re-embeds; the worker owns that. It only ensures indexes are ready.
|
||||
prepare_embedding_runtime(&services, EmbeddingRuntimeRole::ReadOnly).await?;
|
||||
|
||||
let html_state = build_html_state(&services).await?;
|
||||
let api_state = build_api_state(&services);
|
||||
|
||||
+3
-2
@@ -2,14 +2,15 @@ mod bootstrap;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bootstrap::{init, prepare_embedding_runtime};
|
||||
use bootstrap::{init, prepare_embedding_runtime, EmbeddingRuntimeRole};
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use tracing::info;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let services = init().await?;
|
||||
prepare_embedding_runtime(&services).await?;
|
||||
// The worker owns re-embedding: it reconciles stored vectors to the active provider.
|
||||
prepare_embedding_runtime(&services, EmbeddingRuntimeRole::Maintainer).await?;
|
||||
|
||||
info!(
|
||||
embedding_backend = ?services.config.embedding_backend,
|
||||
|
||||
@@ -59,8 +59,7 @@ pub struct RetrievedEntity {
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
embedding_provider: &common::utils::embedding::EmbeddingProvider,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
@@ -68,7 +67,6 @@ pub async fn retrieve(
|
||||
) -> Result<RetrievalOutput, AppError> {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -82,12 +80,16 @@ pub async fn retrieve(
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self};
|
||||
use async_openai::Client;
|
||||
use common::storage::indexes::ensure_runtime;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::system_settings::SystemSettings;
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding_provider() -> EmbeddingProvider {
|
||||
EmbeddingProvider::new_hashed(3).unwrap_or_else(|_| unreachable!())
|
||||
}
|
||||
|
||||
/// Returns the fixed three-component embedding vector used as a test fixture.
fn test_embedding() -> Vec<f32> {
    vec![0.9, 0.1, 0.0]
}
|
||||
@@ -135,11 +137,10 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let embedding_provider = test_embedding_provider();
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
config: RetrievalConfig::default(),
|
||||
@@ -181,11 +182,10 @@ mod tests {
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let embedding_provider = test_embedding_provider();
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
config: RetrievalConfig::default(),
|
||||
@@ -236,11 +236,10 @@ mod tests {
|
||||
);
|
||||
db.store_item(entity).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let embedding_provider = test_embedding_provider();
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: "async rust programming",
|
||||
user_id,
|
||||
config: RetrievalConfig::with_entities(),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
|
||||
@@ -18,8 +17,7 @@ use super::{
|
||||
/// Mutable working state threaded through every retrieval stage.
|
||||
pub(crate) struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
pub embedding_provider: &'a EmbeddingProvider,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
@@ -36,7 +34,6 @@ impl<'a> PipelineContext<'a> {
|
||||
pub fn new(params: RetrievalParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
|
||||
@@ -7,7 +7,6 @@ pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
pub use diagnostics::Diagnostics;
|
||||
|
||||
use crate::{round_score, RetrievalOutput, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -91,8 +90,7 @@ pub struct RunOutput<T> {
|
||||
/// Inputs required to run a retrieval.
|
||||
pub struct RetrievalParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub embedding_provider: &'a common::utils::embedding::EmbeddingProvider,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
|
||||
@@ -2,7 +2,6 @@ use async_trait::async_trait;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use std::collections::HashMap;
|
||||
@@ -97,11 +96,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await?
|
||||
} else {
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||
};
|
||||
let embedding = ctx.embedding_provider.embed(&ctx.input_text).await?;
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user