mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-27 20:01:31 +01:00
feat: support for other providers of ai models
This commit is contained in:
@@ -11,6 +11,7 @@ use surrealdb::{
|
||||
Error, Notification, Surreal,
|
||||
};
|
||||
use surrealdb_migrations::MigrationRunner;
|
||||
use tracing::debug;
|
||||
|
||||
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
|
||||
|
||||
@@ -50,6 +51,7 @@ impl SurrealDbClient {
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
debug!("Creating session store");
|
||||
SessionStore::new(
|
||||
Some(self.client.clone().into()),
|
||||
SessionConfig::default()
|
||||
@@ -65,6 +67,7 @@ impl SurrealDbClient {
|
||||
/// the database and selecting the appropriate namespace and database, but before
|
||||
/// the application starts performing operations that rely on the schema.
|
||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||
debug!("Applying migrations");
|
||||
MigrationRunner::new(&self.client)
|
||||
.load_files(&MIGRATIONS_DIR)
|
||||
.up()
|
||||
@@ -76,6 +79,7 @@ impl SurrealDbClient {
|
||||
|
||||
/// Operation to rebuild indexes
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
debug!("Rebuilding indexes");
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
|
||||
.await?;
|
||||
|
||||
@@ -349,6 +349,8 @@ mod tests {
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
};
|
||||
|
||||
// Test file creation
|
||||
@@ -406,6 +408,8 @@ mod tests {
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
};
|
||||
|
||||
// Store the original file
|
||||
@@ -459,6 +463,8 @@ mod tests {
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
};
|
||||
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
|
||||
|
||||
@@ -508,6 +514,8 @@ mod tests {
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
};
|
||||
|
||||
let field_data1 = create_test_file(content, file_name);
|
||||
@@ -844,6 +852,8 @@ mod tests {
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
};
|
||||
|
||||
// Test file creation
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient, stored_object,
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
@@ -94,7 +101,7 @@ impl KnowledgeEntity {
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
name, description, entity_type
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input).await?;
|
||||
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
@@ -118,6 +125,104 @@ impl KnowledgeEntity {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities in the database.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It follows the same
|
||||
/// pattern as the text chunk update:
|
||||
/// 1. Re-defines the vector index with the new dimensions.
|
||||
/// 2. Fetches all existing entities.
|
||||
/// 3. Sequentially regenerates the embedding for each and updates the record.
|
||||
pub async fn update_all_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
openai_client: &Client<OpenAIConfig>,
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
|
||||
// Fetch all entities first
|
||||
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
|
||||
let total_entities = all_entities.len();
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Skipping.");
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} entities to process.", total_entities);
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
for entity in all_entities.iter() {
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
);
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
crate::utils::embedding::generate_embedding_with_params(
|
||||
openai_client,
|
||||
&embedding_input,
|
||||
new_model,
|
||||
new_dimensions,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Check embedding lengths
|
||||
if embedding.len() != new_dimensions as usize {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), embedding);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying schema and data changes in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements
|
||||
for (id, embedding) in new_embeddings {
|
||||
// We must properly serialize the vector for the SurrealQL query string
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(|f| f.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
transaction_query.push_str(&format!(
|
||||
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
|
||||
id, embedding_str
|
||||
));
|
||||
}
|
||||
|
||||
// Re-create the index after updating the data that it will index
|
||||
transaction_query
|
||||
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
// Execute the entire atomic operation
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for knowledge entities completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -11,6 +11,8 @@ pub struct SystemSettings {
|
||||
pub require_email_verification: bool,
|
||||
pub query_model: String,
|
||||
pub processing_model: String,
|
||||
pub embedding_model: String,
|
||||
pub embedding_dimensions: u32,
|
||||
pub query_system_prompt: String,
|
||||
pub ingestion_system_prompt: String,
|
||||
}
|
||||
@@ -44,25 +46,12 @@ impl SystemSettings {
|
||||
"Something went wrong updating the settings".into(),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: "current".to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
ingestion_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -157,7 +146,7 @@ mod tests {
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::new();
|
||||
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
|
||||
updated_settings.id = "current".to_string();
|
||||
updated_settings.registrations_enabled = false;
|
||||
updated_settings.require_email_verification = true;
|
||||
@@ -206,21 +195,60 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_method() {
|
||||
let settings = SystemSettings::new();
|
||||
async fn test_migration_after_changing_embedding_length() {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
|
||||
assert!(settings.id.len() > 0);
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
"This chunk has the original dimension".into(),
|
||||
vec![0.1; 1536],
|
||||
"user1".into(),
|
||||
);
|
||||
|
||||
db.store_item(initial_chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store initial chunk");
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) {
|
||||
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
|
||||
.await
|
||||
.unwrap();
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
target_dimension
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.expect("Re-defining index should succeed");
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
|
||||
|
||||
let update_result = db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await;
|
||||
|
||||
assert!(update_result.is_ok());
|
||||
}
|
||||
|
||||
simulate_reembedding(&db, 768, initial_chunk).await;
|
||||
|
||||
let migration_result = db.apply_migrations().await;
|
||||
|
||||
assert!(migration_result.is_ok(), "Migrations should not fail");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
@@ -35,6 +44,99 @@ impl TextChunk {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It performs these steps:
|
||||
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
|
||||
/// 2. **Generates All Embeddings**: Creates new embeddings for every chunk. If any fails or
|
||||
/// has the wrong dimension, the entire operation is aborted before any DB changes are made.
|
||||
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
|
||||
/// performed in a single, all-or-nothing database transaction.
|
||||
pub async fn update_all_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
openai_client: &Client<OpenAIConfig>,
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all text chunks. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
|
||||
// Fetch all chunks first
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
let total_chunks = all_chunks.len();
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Skipping.");
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
info!("Generating new embeddings for all chunks...");
|
||||
for chunk in all_chunks.iter() {
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
crate::utils::embedding::generate_embedding_with_params(
|
||||
openai_client,
|
||||
&chunk.chunk,
|
||||
new_model,
|
||||
new_dimensions,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions as usize {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(chunk.id.clone(), embedding);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying schema and data changes in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements
|
||||
for (id, embedding) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(|f| f.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
transaction_query.push_str(&format!(
|
||||
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
|
||||
id, embedding_str
|
||||
));
|
||||
}
|
||||
|
||||
// Re-create the index inside the same transaction
|
||||
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
// Execute the entire atomic operation
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for text chunks completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -12,12 +12,18 @@ pub struct AppConfig {
|
||||
#[serde(default = "default_data_dir")]
|
||||
pub data_dir: String,
|
||||
pub http_port: u16,
|
||||
#[serde(default = "default_base_url")]
|
||||
pub openai_base_url: String,
|
||||
}
|
||||
|
||||
fn default_data_dir() -> String {
|
||||
"./data".to_string()
|
||||
}
|
||||
|
||||
fn default_base_url() -> String {
|
||||
"https://api.openai.com/v1.".to_string()
|
||||
}
|
||||
|
||||
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
let config = Config::builder()
|
||||
.add_source(File::with_name("config").required(false))
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use async_openai::types::CreateEmbeddingRequestArgs;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::error::AppError;
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
@@ -27,9 +31,13 @@ use crate::error::AppError;
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let model = SystemSettings::get_current(db).await?;
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model("text-embedding-3-small")
|
||||
.model(model.embedding_model)
|
||||
.dimensions(model.embedding_dimensions)
|
||||
.input([input])
|
||||
.build()?;
|
||||
|
||||
@@ -46,3 +54,36 @@ pub async fn generate_embedding(
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector using a specific model and dimension.
|
||||
///
|
||||
/// This is used for the re-embedding process where the model and dimensions
|
||||
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
|
||||
pub async fn generate_embedding_with_params(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
model: &str,
|
||||
dimensions: u32,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model)
|
||||
.input([input])
|
||||
.dimensions(dimensions as u32)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
debug!(
|
||||
"Embedding was created with {:?} dimensions",
|
||||
embedding.len()
|
||||
);
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user