feat: support other providers of AI models

This commit is contained in:
Per Stark
2025-06-06 23:16:41 +02:00
parent 811aaec554
commit a363c6cc05
22 changed files with 519 additions and 66 deletions

View File

@@ -11,6 +11,7 @@ use surrealdb::{
Error, Notification, Surreal,
};
use surrealdb_migrations::MigrationRunner;
use tracing::debug;
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
@@ -50,6 +51,7 @@ impl SurrealDbClient {
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
debug!("Creating session store");
SessionStore::new(
Some(self.client.clone().into()),
SessionConfig::default()
@@ -65,6 +67,7 @@ impl SurrealDbClient {
/// the database and selecting the appropriate namespace and database, but before
/// the application starts performing operations that rely on the schema.
pub async fn apply_migrations(&self) -> Result<(), AppError> {
debug!("Applying migrations");
MigrationRunner::new(&self.client)
.load_files(&MIGRATIONS_DIR)
.up()
@@ -76,6 +79,7 @@ impl SurrealDbClient {
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
debug!("Rebuilding indexes");
self.client
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
.await?;

View File

@@ -349,6 +349,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Test file creation
@@ -406,6 +408,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Store the original file
@@ -459,6 +463,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
@@ -508,6 +514,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
let field_data1 = create_test_file(content, file_name);
@@ -844,6 +852,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Test file creation

View File

@@ -1,8 +1,15 @@
use std::collections::HashMap;
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -94,7 +101,7 @@ impl KnowledgeEntity {
"name: {}, description: {}, type: {:?}",
name, description, entity_type
);
let embedding = generate_embedding(ai_client, &embedding_input).await?;
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
db_client
.client
@@ -118,6 +125,104 @@ impl KnowledgeEntity {
Ok(())
}
/// Re-creates embeddings for all knowledge entities in the database.
///
/// This is a costly operation that should be run in the background. It mirrors
/// the text chunk re-embedding flow:
/// 1. Fetches all existing entities into memory.
/// 2. Generates a new embedding for each (with retries), aborting before any
///    database change if a generation fails or returns the wrong dimension.
/// 3. Applies every record update plus the index re-creation in one
///    all-or-nothing transaction.
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
new_dimensions
);
// Fetch all entities first
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
let total_entities = all_entities.len();
if total_entities == 0 {
info!("No knowledge entities to update. Skipping.");
return Ok(());
}
info!("Found {} entities to process.", total_entities);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in all_entities.iter() {
// Same input format as the single-entity embedding path, so stored
// vectors remain comparable with query-time embeddings.
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
entity.name, entity.description, entity.entity_type
);
// Up to 3 attempts per entity: 100 ms exponential backoff with jitter.
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&embedding_input,
new_model,
new_dimensions,
)
})
.await?;
// Check embedding lengths
// Abort the whole run before any DB write if the provider returned a
// vector that does not match the requested dimension.
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(entity.id.clone(), embedding);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
// NOTE(review): HashMap iteration order is arbitrary, so the UPDATE
// statements are emitted in no particular order — harmless inside a
// single transaction.
for (id, embedding) in new_embeddings {
// We must properly serialize the vector for the SurrealQL query string
// NOTE(review): the id and the float values are interpolated directly
// into the query text; this assumes ids are plain UUIDs (no quotes)
// and embedding values are finite — confirm upstream guarantees.
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
}
// Re-create the index after updating the data that it will index
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for knowledge entities completed successfully.");
Ok(())
}
}
#[cfg(test)]

View File

@@ -11,6 +11,8 @@ pub struct SystemSettings {
pub require_email_verification: bool,
pub query_model: String,
pub processing_model: String,
pub embedding_model: String,
pub embedding_dimensions: u32,
pub query_system_prompt: String,
pub ingestion_system_prompt: String,
}
@@ -44,25 +46,12 @@ impl SystemSettings {
"Something went wrong updating the settings".into(),
))
}
pub fn new() -> Self {
Self {
id: "current".to_string(),
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
.to_string(),
ingestion_system_prompt:
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
.to_string(),
query_model: "gpt-4o-mini".to_string(),
processing_model: "gpt-4o-mini".to_string(),
registrations_enabled: true,
require_email_verification: false,
}
}
}
#[cfg(test)]
mod tests {
use crate::storage::types::text_chunk::TextChunk;
use super::*;
use uuid::Uuid;
@@ -157,7 +146,7 @@ mod tests {
.expect("Failed to apply migrations");
// Create updated settings
let mut updated_settings = SystemSettings::new();
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
updated_settings.id = "current".to_string();
updated_settings.registrations_enabled = false;
updated_settings.require_email_verification = true;
@@ -206,21 +195,60 @@ mod tests {
}
#[tokio::test]
async fn test_new_method() {
let settings = SystemSettings::new();
async fn test_migration_after_changing_embedding_length() {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.expect("Failed to start DB");
assert!(settings.id.len() > 0);
assert_eq!(settings.registrations_enabled, true);
assert_eq!(settings.require_email_verification, false);
assert_eq!(settings.query_model, "gpt-4o-mini");
assert_eq!(settings.processing_model, "gpt-4o-mini");
assert_eq!(
settings.query_system_prompt,
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
);
assert_eq!(
settings.ingestion_system_prompt,
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations()
.await
.expect("Initial migration failed");
let initial_chunk = TextChunk::new(
"source1".into(),
"This chunk has the original dimension".into(),
vec![0.1; 1536],
"user1".into(),
);
db.store_item(initial_chunk.clone())
.await
.expect("Failed to store initial chunk");
// Test-only helper: simulates an out-of-band re-embedding by re-defining the
// chunk index with a new dimension and overwriting the stored chunk's
// embedding to match, so the subsequent migration run is exercised against
// mismatched historical state.
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
// Drop the existing index so it can be re-defined with a different dimension.
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
// Overwrite the stored embedding with a vector of the new dimension.
let new_embedding = vec![0.5; target_dimension];
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
simulate_reembedding(&db, 768, initial_chunk).await;
let migration_result = db.apply_migrations().await;
assert!(migration_result.is_ok(), "Migrations should not fail");
}
}

View File

@@ -1,4 +1,13 @@
use std::collections::HashMap;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info};
use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
@@ -35,6 +44,99 @@ impl TextChunk {
Ok(())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
///
/// This is a costly operation that should be run in the background. It performs these steps:
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
/// 2. **Generates All Embeddings**: Creates new embeddings for every chunk. If any fails or
///    has the wrong dimension, the entire operation is aborted before any DB changes are made.
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
///    performed in a single, all-or-nothing database transaction.
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all text chunks. New dimensions: {}",
new_dimensions
);
// Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
let total_chunks = all_chunks.len();
if total_chunks == 0 {
info!("No text chunks to update. Skipping.");
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for chunk in all_chunks.iter() {
// Up to 3 attempts per chunk: 100 ms exponential backoff with jitter.
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&chunk.chunk,
new_model,
new_dimensions,
)
})
.await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(chunk.id.clone(), embedding);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
// NOTE(review): HashMap iteration order is arbitrary, so UPDATE statement
// order varies between runs — harmless inside one transaction.
for (id, embedding) in new_embeddings {
// NOTE(review): the id and float values are interpolated directly into
// the query text; this assumes ids are plain UUIDs (no quotes) and
// embedding values are finite — confirm upstream guarantees.
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
}
// Re-create the index inside the same transaction
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for text chunks completed successfully.");
Ok(())
}
}
#[cfg(test)]

View File

@@ -12,12 +12,18 @@ pub struct AppConfig {
#[serde(default = "default_data_dir")]
pub data_dir: String,
pub http_port: u16,
#[serde(default = "default_base_url")]
pub openai_base_url: String,
}
/// Fallback directory for application data when `data_dir` is absent from config.
fn default_data_dir() -> String {
    String::from("./data")
}
/// Fallback base URL for the OpenAI-compatible API when `openai_base_url`
/// is absent from config.
///
/// Fix: the previous value had a stray trailing period
/// (`"https://api.openai.com/v1."`), which is not a valid API base URL and
/// would break every request built from the default configuration.
fn default_base_url() -> String {
    "https://api.openai.com/v1".to_string()
}
pub fn get_config() -> Result<AppConfig, ConfigError> {
let config = Config::builder()
.add_source(File::with_name("config").required(false))

View File

@@ -1,6 +1,10 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use tracing::debug;
use crate::error::AppError;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
@@ -27,9 +31,13 @@ use crate::error::AppError;
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
db: &SurrealDbClient,
) -> Result<Vec<f32>, AppError> {
let model = SystemSettings::get_current(db).await?;
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small")
.model(model.embedding_model)
.dimensions(model.embedding_dimensions)
.input([input])
.build()?;
@@ -46,3 +54,36 @@ pub async fn generate_embedding(
Ok(embedding)
}
/// Generates an embedding vector using a specific model and dimension.
///
/// Unlike the settings-driven embedding path, this does not consult
/// `SystemSettings`; it is used by the re-embedding process where the model
/// and dimensions are known ahead of time and shouldn't be repeatedly fetched
/// from the database.
///
/// # Errors
/// Returns an error if the request cannot be built, the API call fails, or
/// the response contains no embedding data.
pub async fn generate_embedding_with_params(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input: &str,
    model: &str,
    dimensions: u32,
) -> Result<Vec<f32>, AppError> {
    let request = CreateEmbeddingRequestArgs::default()
        .model(model)
        .input([input])
        // `dimensions` is already `u32`; the previous `as u32` cast was a no-op.
        .dimensions(dimensions)
        .build()?;
    let response = client.embeddings().create(request).await?;
    // The API returns one embedding per input; we sent exactly one input.
    let embedding = response
        .data
        .first()
        .ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
        .embedding
        .clone();
    debug!(
        "Embedding was created with {:?} dimensions",
        embedding.len()
    );
    Ok(embedding)
}