feat: add support for other providers of AI models

This commit is contained in:
Per Stark
2025-06-06 23:16:41 +02:00
parent 811aaec554
commit a363c6cc05
22 changed files with 519 additions and 66 deletions

View File

@@ -38,6 +38,7 @@ sha2 = { workspace = true }
url = { workspace = true }
uuid = { workspace = true }
surrealdb-migrations = { workspace = true }
tokio-retry = { workspace = true }
[features]

View File

@@ -13,6 +13,8 @@ CREATE system_settings:current CONTENT {
require_email_verification: false,
query_model: "gpt-4o-mini",
processing_model: "gpt-4o-mini",
embedding_model: "text-embedding-3-small",
embedding_dimensions: 1536,
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. 
Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
};

View File

@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -98,7 +98,7 @@\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n # Defines the schema for the 'message' table.\n\n@@ -157,6 +157,8 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n","events":null}

View File

@@ -27,4 +27,4 @@ DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS emb
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;

View File

@@ -7,5 +7,7 @@ DEFINE FIELD IF NOT EXISTS registrations_enabled ON system_settings TYPE bool;
DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;
DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;
DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;

View File

@@ -11,6 +11,7 @@ use surrealdb::{
Error, Notification, Surreal,
};
use surrealdb_migrations::MigrationRunner;
use tracing::debug;
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
@@ -50,6 +51,7 @@ impl SurrealDbClient {
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
debug!("Creating session store");
SessionStore::new(
Some(self.client.clone().into()),
SessionConfig::default()
@@ -65,6 +67,7 @@ impl SurrealDbClient {
/// the database and selecting the appropriate namespace and database, but before
/// the application starts performing operations that rely on the schema.
pub async fn apply_migrations(&self) -> Result<(), AppError> {
debug!("Applying migrations");
MigrationRunner::new(&self.client)
.load_files(&MIGRATIONS_DIR)
.up()
@@ -76,6 +79,7 @@ impl SurrealDbClient {
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
debug!("Rebuilding indexes");
self.client
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
.await?;

View File

@@ -349,6 +349,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Test file creation
@@ -406,6 +408,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Store the original file
@@ -459,6 +463,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
@@ -508,6 +514,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
let field_data1 = create_test_file(content, file_name);
@@ -844,6 +852,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Test file creation

View File

@@ -1,8 +1,15 @@
use std::collections::HashMap;
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -94,7 +101,7 @@ impl KnowledgeEntity {
"name: {}, description: {}, type: {:?}",
name, description, entity_type
);
let embedding = generate_embedding(ai_client, &embedding_input).await?;
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
db_client
.client
@@ -118,6 +125,104 @@ impl KnowledgeEntity {
Ok(())
}
/// Re-creates embeddings for all knowledge entities in the database.
///
/// This is a costly operation that should be run in the background. It follows the same
/// pattern as the text chunk update:
/// 1. Fetches all existing entities.
/// 2. Generates every new embedding in memory, aborting before any DB write if one
///    fails (after retries) or has the wrong dimension.
/// 3. Applies all record updates and the index recreation in a single transaction.
///
/// # Errors
/// Returns an error if an embedding request fails after retries, if a generated
/// embedding does not match `new_dimensions`, or if the database transaction fails.
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
new_dimensions
);
// Fetch all entities first
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
let total_entities = all_entities.len();
if total_entities == 0 {
info!("No knowledge entities to update. Skipping.");
return Ok(());
}
info!("Found {} entities to process.", total_entities);
// Generate all new embeddings in memory. Nothing is written to the DB until
// every entity has a valid embedding, so a mid-run failure leaves the DB untouched.
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in all_entities.iter() {
// Same "name/description/type" text that was embedded at creation time.
let embedding_input = format!(
"name: {}, description: {:?}",
entity.name, entity.description, entity.entity_type
);
// Up to 3 retries with jittered exponential backoff (100ms base) per entity.
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&embedding_input,
new_model,
new_dimensions,
)
})
.await?;
// Check embedding lengths — a mismatched vector would corrupt the HNSW index.
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(entity.id.clone(), embedding);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
// We must properly serialize the vector for the SurrealQL query string
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
// NOTE(review): `id` is interpolated into the query string; assumes ids are
// trusted, DB-generated UUIDs with no quote characters — confirm.
transaction_query.push_str(&format!(
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
}
// Re-create the index after updating the data that it will index.
// NOTE(review): assumes SurrealDB allows REMOVE/DEFINE INDEX inside a
// transaction on the deployed version — confirm.
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for knowledge entities completed successfully.");
Ok(())
}
}
#[cfg(test)]

View File

@@ -11,6 +11,8 @@ pub struct SystemSettings {
pub require_email_verification: bool,
pub query_model: String,
pub processing_model: String,
pub embedding_model: String,
pub embedding_dimensions: u32,
pub query_system_prompt: String,
pub ingestion_system_prompt: String,
}
@@ -44,25 +46,12 @@ impl SystemSettings {
"Something went wrong updating the settings".into(),
))
}
/// Creates a `SystemSettings` instance pre-populated with application defaults.
///
/// The defaults mirror the values seeded by the initial migration
/// (`gpt-4o-mini` for both models, `text-embedding-3-small` at 1536 dimensions).
pub fn new() -> Self {
    Self {
        id: "current".to_string(),
        query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
            .to_string(),
        ingestion_system_prompt:
            crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
                .to_string(),
        query_model: "gpt-4o-mini".to_string(),
        processing_model: "gpt-4o-mini".to_string(),
        // Bug fix: the struct gained these two fields but `new()` did not
        // initialize them, which fails to compile. Keep in sync with the
        // seed migration (text-embedding-3-small / 1536).
        embedding_model: "text-embedding-3-small".to_string(),
        embedding_dimensions: 1536,
        registrations_enabled: true,
        require_email_verification: false,
    }
}
}
#[cfg(test)]
mod tests {
use crate::storage::types::text_chunk::TextChunk;
use super::*;
use uuid::Uuid;
@@ -157,7 +146,7 @@ mod tests {
.expect("Failed to apply migrations");
// Create updated settings
let mut updated_settings = SystemSettings::new();
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
updated_settings.id = "current".to_string();
updated_settings.registrations_enabled = false;
updated_settings.require_email_verification = true;
@@ -206,21 +195,60 @@ mod tests {
}
#[tokio::test]
async fn test_new_method() {
let settings = SystemSettings::new();
async fn test_migration_after_changing_embedding_length() {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.expect("Failed to start DB");
assert!(settings.id.len() > 0);
assert_eq!(settings.registrations_enabled, true);
assert_eq!(settings.require_email_verification, false);
assert_eq!(settings.query_model, "gpt-4o-mini");
assert_eq!(settings.processing_model, "gpt-4o-mini");
assert_eq!(
settings.query_system_prompt,
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
);
assert_eq!(
settings.ingestion_system_prompt,
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations()
.await
.expect("Initial migration failed");
let initial_chunk = TextChunk::new(
"source1".into(),
"This chunk has the original dimension".into(),
vec![0.1; 1536],
"user1".into(),
);
db.store_item(initial_chunk.clone())
.await
.expect("Failed to store initial chunk");
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
let new_embedding = vec![0.5; target_dimension];
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
simulate_reembedding(&db, 768, initial_chunk).await;
let migration_result = db.apply_migrations().await;
assert!(migration_result.is_ok(), "Migrations should not fail");
}
}

View File

@@ -1,4 +1,13 @@
use std::collections::HashMap;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info};
use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
@@ -35,6 +44,99 @@ impl TextChunk {
Ok(())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
///
/// This is a costly operation that should be run in the background. It performs these steps:
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
/// 2. **Generates All Embeddings**: Creates new embeddings for every chunk. If any fails or
///    has the wrong dimension, the entire operation is aborted before any DB changes are made.
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
///    performed in a single, all-or-nothing database transaction.
///
/// # Errors
/// Returns an error if an embedding request fails after retries, if a generated
/// embedding does not match `new_dimensions`, or if the database transaction fails.
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all text chunks. New dimensions: {}",
new_dimensions
);
// Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
let total_chunks = all_chunks.len();
if total_chunks == 0 {
info!("No text chunks to update. Skipping.");
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);
// Generate all new embeddings in memory — no DB writes happen until every
// chunk has a valid embedding of the expected dimension.
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for chunk in all_chunks.iter() {
// Up to 3 retries with jittered exponential backoff (100ms base) per chunk.
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&chunk.chunk,
new_model,
new_dimensions,
)
})
.await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(chunk.id.clone(), embedding);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
// Serialize the vector into a SurrealQL array literal.
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
// NOTE(review): `id` is interpolated into the query string; assumes ids are
// trusted, DB-generated UUIDs with no quote characters — confirm.
transaction_query.push_str(&format!(
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
}
// Re-create the index inside the same transaction.
// NOTE(review): assumes SurrealDB allows REMOVE/DEFINE INDEX inside a
// transaction on the deployed version — confirm.
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for text chunks completed successfully.");
Ok(())
}
}
#[cfg(test)]

View File

@@ -12,12 +12,18 @@ pub struct AppConfig {
#[serde(default = "default_data_dir")]
pub data_dir: String,
pub http_port: u16,
#[serde(default = "default_base_url")]
pub openai_base_url: String,
}
/// Default directory for application data when the configuration
/// does not specify `data_dir`.
fn default_data_dir() -> String {
    String::from("./data")
}
/// Default base URL for the OpenAI-compatible API when `openai_base_url`
/// is not set in the configuration.
fn default_base_url() -> String {
    // Bug fix: the previous value had a stray trailing period
    // ("https://api.openai.com/v1.") which produced invalid request URLs.
    "https://api.openai.com/v1".to_string()
}
pub fn get_config() -> Result<AppConfig, ConfigError> {
let config = Config::builder()
.add_source(File::with_name("config").required(false))

View File

@@ -1,6 +1,10 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use tracing::debug;
use crate::error::AppError;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
@@ -27,9 +31,13 @@ use crate::error::AppError;
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
db: &SurrealDbClient,
) -> Result<Vec<f32>, AppError> {
let model = SystemSettings::get_current(db).await?;
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small")
.model(model.embedding_model)
.dimensions(model.embedding_dimensions)
.input([input])
.build()?;
@@ -46,3 +54,36 @@ pub async fn generate_embedding(
Ok(embedding)
}
/// Generates an embedding vector using a specific model and dimension.
///
/// This is used for the re-embedding process where the model and dimensions
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
///
/// # Arguments
/// * `client` - OpenAI-compatible client used to issue the embedding request.
/// * `input` - Text to embed.
/// * `model` - Embedding model identifier to request.
/// * `dimensions` - Requested dimensionality of the returned vector.
///
/// # Errors
/// Returns an error if the request cannot be built, the API call fails, or
/// the response contains no embedding data.
pub async fn generate_embedding_with_params(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input: &str,
    model: &str,
    dimensions: u32,
) -> Result<Vec<f32>, AppError> {
    let request = CreateEmbeddingRequestArgs::default()
        .model(model)
        .input([input])
        // `dimensions` is already u32 — the previous `as u32` cast was redundant.
        .dimensions(dimensions)
        .build()?;
    let response = client.embeddings().create(request).await?;
    // The API returns one embedding per input; we sent exactly one input.
    let embedding = response
        .data
        .first()
        .ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
        .embedding
        .clone();
    debug!(
        "Embedding was created with {:?} dimensions",
        embedding.len()
    );
    Ok(embedding)
}