feat: support for other providers of ai models

This commit is contained in:
Per Stark
2025-06-06 23:16:41 +02:00
parent 811aaec554
commit a363c6cc05
22 changed files with 519 additions and 66 deletions

12
Cargo.lock generated
View File

@@ -1242,6 +1242,7 @@ dependencies = [
"tempfile",
"thiserror 1.0.69",
"tokio",
"tokio-retry",
"tracing",
"url",
"uuid",
@@ -5793,6 +5794,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-retry"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
dependencies = [
"pin-project",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.2"

View File

@@ -52,6 +52,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
url = { version = "2.5.2", features = ["serde"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
tokio-retry = "0.3.0"
[profile.dist]
inherits = "release"

View File

@@ -38,6 +38,7 @@ sha2 = { workspace = true }
url = { workspace = true }
uuid = { workspace = true }
surrealdb-migrations = { workspace = true }
tokio-retry = { workspace = true }
[features]

View File

@@ -13,6 +13,8 @@ CREATE system_settings:current CONTENT {
require_email_verification: false,
query_model: "gpt-4o-mini",
processing_model: "gpt-4o-mini",
embedding_model: "text-embedding-3-small",
embedding_dimensions: 1536,
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
};

View File

@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -98,7 +98,7 @@\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n # Defines the schema for the 'message' table.\n\n@@ -157,6 +157,8 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n","events":null}

View File

@@ -27,4 +27,4 @@ DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS emb
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;

View File

@@ -7,5 +7,7 @@ DEFINE FIELD IF NOT EXISTS registrations_enabled ON system_settings TYPE bool;
DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;
DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;
DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;

View File

@@ -11,6 +11,7 @@ use surrealdb::{
Error, Notification, Surreal,
};
use surrealdb_migrations::MigrationRunner;
use tracing::debug;
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
@@ -50,6 +51,7 @@ impl SurrealDbClient {
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
debug!("Creating session store");
SessionStore::new(
Some(self.client.clone().into()),
SessionConfig::default()
@@ -65,6 +67,7 @@ impl SurrealDbClient {
/// the database and selecting the appropriate namespace and database, but before
/// the application starts performing operations that rely on the schema.
pub async fn apply_migrations(&self) -> Result<(), AppError> {
debug!("Applying migrations");
MigrationRunner::new(&self.client)
.load_files(&MIGRATIONS_DIR)
.up()
@@ -76,6 +79,7 @@ impl SurrealDbClient {
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
debug!("Rebuilding indexes");
self.client
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
.await?;

View File

@@ -349,6 +349,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Test file creation
@@ -406,6 +408,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Store the original file
@@ -459,6 +463,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
@@ -508,6 +514,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
let field_data1 = create_test_file(content, file_name);
@@ -844,6 +852,8 @@ mod tests {
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
};
// Test file creation

View File

@@ -1,8 +1,15 @@
use std::collections::HashMap;
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -94,7 +101,7 @@ impl KnowledgeEntity {
"name: {}, description: {}, type: {:?}",
name, description, entity_type
);
let embedding = generate_embedding(ai_client, &embedding_input).await?;
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
db_client
.client
@@ -118,6 +125,104 @@ impl KnowledgeEntity {
Ok(())
}
/// Re-creates embeddings for all knowledge entities in the database.
///
/// This is a costly operation that should be run in the background. It follows the same
/// pattern as the text chunk update:
/// 1. Re-defines the vector index with the new dimensions.
/// 2. Fetches all existing entities.
/// 3. Sequentially regenerates the embedding for each and updates the record.
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
new_dimensions
);
// Fetch all entities first
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
let total_entities = all_entities.len();
if total_entities == 0 {
info!("No knowledge entities to update. Skipping.");
return Ok(());
}
info!("Found {} entities to process.", total_entities);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in all_entities.iter() {
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
entity.name, entity.description, entity.entity_type
);
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&embedding_input,
new_model,
new_dimensions,
)
})
.await?;
// Check embedding lengths
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(entity.id.clone(), embedding);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
// We must properly serialize the vector for the SurrealQL query string
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
}
// Re-create the index after updating the data that it will index
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for knowledge entities completed successfully.");
Ok(())
}
}
#[cfg(test)]

View File

@@ -11,6 +11,8 @@ pub struct SystemSettings {
pub require_email_verification: bool,
pub query_model: String,
pub processing_model: String,
pub embedding_model: String,
pub embedding_dimensions: u32,
pub query_system_prompt: String,
pub ingestion_system_prompt: String,
}
@@ -44,25 +46,12 @@ impl SystemSettings {
"Something went wrong updating the settings".into(),
))
}
pub fn new() -> Self {
Self {
id: "current".to_string(),
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
.to_string(),
ingestion_system_prompt:
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
.to_string(),
query_model: "gpt-4o-mini".to_string(),
processing_model: "gpt-4o-mini".to_string(),
registrations_enabled: true,
require_email_verification: false,
}
}
}
#[cfg(test)]
mod tests {
use crate::storage::types::text_chunk::TextChunk;
use super::*;
use uuid::Uuid;
@@ -157,7 +146,7 @@ mod tests {
.expect("Failed to apply migrations");
// Create updated settings
let mut updated_settings = SystemSettings::new();
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
updated_settings.id = "current".to_string();
updated_settings.registrations_enabled = false;
updated_settings.require_email_verification = true;
@@ -206,21 +195,60 @@ mod tests {
}
#[tokio::test]
async fn test_new_method() {
let settings = SystemSettings::new();
async fn test_migration_after_changing_embedding_length() {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.expect("Failed to start DB");
assert!(settings.id.len() > 0);
assert_eq!(settings.registrations_enabled, true);
assert_eq!(settings.require_email_verification, false);
assert_eq!(settings.query_model, "gpt-4o-mini");
assert_eq!(settings.processing_model, "gpt-4o-mini");
assert_eq!(
settings.query_system_prompt,
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
);
assert_eq!(
settings.ingestion_system_prompt,
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations()
.await
.expect("Initial migration failed");
let initial_chunk = TextChunk::new(
"source1".into(),
"This chunk has the original dimension".into(),
vec![0.1; 1536],
"user1".into(),
);
db.store_item(initial_chunk.clone())
.await
.expect("Failed to store initial chunk");
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
let new_embedding = vec![0.5; target_dimension];
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
simulate_reembedding(&db, 768, initial_chunk).await;
let migration_result = db.apply_migrations().await;
assert!(migration_result.is_ok(), "Migrations should not fail");
}
}

View File

@@ -1,4 +1,13 @@
use std::collections::HashMap;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{error, info};
use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
@@ -35,6 +44,99 @@ impl TextChunk {
Ok(())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
///
/// This is a costly operation that should be run in the background. It performs these steps:
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
/// 2. **Generates All Embeddings**: Creates new embeddings for every chunk. If any fails or
/// has the wrong dimension, the entire operation is aborted before any DB changes are made.
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
/// performed in a single, all-or-nothing database transaction.
pub async fn update_all_embeddings(
db: &SurrealDbClient,
openai_client: &Client<OpenAIConfig>,
new_model: &str,
new_dimensions: u32,
) -> Result<(), AppError> {
info!(
"Starting re-embedding process for all text chunks. New dimensions: {}",
new_dimensions
);
// Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
let total_chunks = all_chunks.len();
if total_chunks == 0 {
info!("No text chunks to update. Skipping.");
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for chunk in all_chunks.iter() {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params(
openai_client,
&chunk.chunk,
new_model,
new_dimensions,
)
})
.await?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions as usize {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(chunk.id.clone(), embedding);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
}
// Re-create the index inside the same transaction
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for text chunks completed successfully.");
Ok(())
}
}
#[cfg(test)]

View File

@@ -12,12 +12,18 @@ pub struct AppConfig {
#[serde(default = "default_data_dir")]
pub data_dir: String,
pub http_port: u16,
#[serde(default = "default_base_url")]
pub openai_base_url: String,
}
fn default_data_dir() -> String {
"./data".to_string()
}
fn default_base_url() -> String {
"https://api.openai.com/v1.".to_string()
}
pub fn get_config() -> Result<AppConfig, ConfigError> {
let config = Config::builder()
.add_source(File::with_name("config").required(false))

View File

@@ -1,6 +1,10 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use tracing::debug;
use crate::error::AppError;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
@@ -27,9 +31,13 @@ use crate::error::AppError;
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
db: &SurrealDbClient,
) -> Result<Vec<f32>, AppError> {
let model = SystemSettings::get_current(db).await?;
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small")
.model(model.embedding_model)
.dimensions(model.embedding_dimensions)
.input([input])
.build()?;
@@ -46,3 +54,36 @@ pub async fn generate_embedding(
Ok(embedding)
}
/// Generates an embedding vector using a specific model and dimension.
///
/// This is used for the re-embedding process where the model and dimensions
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
pub async fn generate_embedding_with_params(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
model: &str,
dimensions: u32,
) -> Result<Vec<f32>, AppError> {
let request = CreateEmbeddingRequestArgs::default()
.model(model)
.input([input])
.dimensions(dimensions as u32)
.build()?;
let response = client.embeddings().create(request).await?;
let embedding = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
.embedding
.clone();
debug!(
"Embedding was created with {:?} dimensions",
embedding.len()
);
Ok(embedding)
}

View File

@@ -1,6 +1,4 @@
use surrealdb::{engine::any::Any, Surreal};
use common::{error::AppError, utils::embedding::generate_embedding};
use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::generate_embedding};
/// Compares vectors and retrieves a number of items from the specified table.
///
@@ -26,7 +24,7 @@ use common::{error::AppError, utils::embedding::generate_embedding};
pub async fn find_items_by_vector_similarity<T>(
take: u8,
input_text: &str,
db_client: &Surreal<Any>,
db_client: &SurrealDbClient,
table: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str,
@@ -35,7 +33,7 @@ where
T: for<'de> serde::Deserialize<'de>,
{
// Generate embeddings
let input_embedding = generate_embedding(openai_client, input_text).await?;
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE user_id = '{}' AND embedding <|{},40|> {:?} ORDER BY distance", table, user_id, take, input_embedding);

File diff suppressed because one or more lines are too long

View File

@@ -1,13 +1,20 @@
use async_openai::types::ListModelResponse;
use axum::{extract::State, response::IntoResponse, Form};
use serde::{Deserialize, Serialize};
use common::storage::types::{
analytics::Analytics,
conversation::Conversation,
system_prompts::{DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT, DEFAULT_QUERY_SYSTEM_PROMPT},
system_settings::SystemSettings,
user::User,
use common::{
error::AppError,
storage::types::{
analytics::Analytics,
conversation::Conversation,
knowledge_entity::KnowledgeEntity,
system_prompts::{DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT, DEFAULT_QUERY_SYSTEM_PROMPT},
system_settings::SystemSettings,
text_chunk::TextChunk,
user::User,
},
};
use tracing::{error, info};
use crate::{
html_state::HtmlState,
@@ -25,6 +32,7 @@ pub struct AdminPanelData {
users: i64,
default_query_prompt: String,
conversation_archive: Vec<Conversation>,
available_models: ListModelResponse,
}
pub async fn show_admin_panel(
@@ -35,6 +43,12 @@ pub async fn show_admin_panel(
let analytics = Analytics::get_current(&state.db).await?;
let users_count = Analytics::get_users_amount(&state.db).await?;
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
let available_models = state
.openai_client
.models()
.list()
.await
.map_err(|_e| AppError::InternalError("OpenAI error".to_string()))?;
Ok(TemplateResponse::new_template(
"admin/base.html",
@@ -42,6 +56,7 @@ pub async fn show_admin_panel(
user,
settings,
analytics,
available_models,
users: users_count,
default_query_prompt: DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
conversation_archive,
@@ -103,11 +118,14 @@ pub async fn toggle_registration_status(
pub struct ModelSettingsInput {
query_model: String,
processing_model: String,
embedding_model: String,
embedding_dimensions: Option<u32>,
}
#[derive(Serialize)]
pub struct ModelSettingsData {
settings: SystemSettings,
available_models: ListModelResponse,
}
pub async fn update_model_settings(
@@ -122,19 +140,77 @@ pub async fn update_model_settings(
let current_settings = SystemSettings::get_current(&state.db).await?;
// --- Determine if re-embedding is required ---
let reembedding_needed = input
.embedding_dimensions
.is_some_and(|new_dims| new_dims != current_settings.embedding_dimensions);
// --- Create the new settings object immutably ---
let new_settings = SystemSettings {
query_model: input.query_model,
processing_model: input.processing_model,
..current_settings
embedding_model: input.embedding_model,
// Use new dimensions if provided, otherwise retain the current ones.
embedding_dimensions: input
.embedding_dimensions
.unwrap_or(current_settings.embedding_dimensions),
// Copy all other fields from the current settings.
..current_settings.clone()
};
SystemSettings::update(&state.db, new_settings.clone()).await?;
if reembedding_needed {
info!("Embedding dimensions changed. Spawning background re-embedding task...");
let db_for_task = state.db.clone();
let openai_for_task = state.openai_client.clone();
let new_model_for_task = new_settings.embedding_model.clone();
let new_dims_for_task = new_settings.embedding_dimensions;
tokio::spawn(async move {
// First, update all text chunks
if let Err(e) = TextChunk::update_all_embeddings(
&db_for_task,
&openai_for_task,
&new_model_for_task,
new_dims_for_task,
)
.await
{
error!("Background re-embedding task failed for TextChunks: {}", e);
}
// Second, update all knowledge entities
if let Err(e) = KnowledgeEntity::update_all_embeddings(
&db_for_task,
&openai_for_task,
&new_model_for_task,
new_dims_for_task,
)
.await
{
error!(
"Background re-embedding task failed for KnowledgeEntities: {}",
e
);
}
});
}
let available_models = state
.openai_client
.models()
.list()
.await
.map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?;
Ok(TemplateResponse::new_partial(
"admin/base.html",
"model_settings_form",
ModelSettingsData {
settings: new_settings,
available_models,
},
))
}

View File

@@ -138,7 +138,7 @@ pub async fn delete_job(
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"dashboard/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),

View File

@@ -50,36 +50,94 @@
<legend class="fieldset-legend">AI Models</legend>
{% block model_settings_form %}
<form hx-patch="/update-model-settings" hx-swap="outerHTML">
<!-- Query Model -->
<div class="form-control mb-4">
<label class="label">
<span class="label-text">Query Model</span>
</label>
<select name="query_model" class="select select-bordered w-full">
<option value="gpt-4o-mini" {% if settings.query_model=="gpt-4o-mini" %}selected{% endif %}>GPT-4o Mini
</option>
<option value="gpt-4.1" {% if settings.query_model=="gpt-4.1" %}selected{% endif %}>GPT-4.1</option>
<option value="gpt-4.1-mini" {% if settings.query_model=="gpt-4.1-mini" %}selected{% endif %}>GPT-4.1-mini
{% for model in available_models.data %}
<option value="{{model.id}}" {% if settings.query_model==model.id %} selected {% endif %}>{{model.id}}
</option>
{% endfor %}
</select>
<p class="text-xs text-gray-500 mt-1">Model used for answering user queries</p>
</div>
<div class="form-control my-4">
<!-- Processing Model -->
<div class="form-control mb-4">
<label class="label">
<span class="label-text">Processing Model</span>
</label>
<select name="processing_model" class="select select-bordered w-full">
<option value="gpt-4o-mini" {% if settings.query_model=="gpt-4o-mini" %}selected{% endif %}>GPT-4o Mini
</option>
<option value="gpt-4.1" {% if settings.query_model=="gpt-4.1" %}selected{% endif %}>GPT-4.1</option>
<option value="gpt-4.1-mini" {% if settings.query_model=="gpt-4.1-mini" %}selected{% endif %}>GPT-4.1-mini
{% for model in available_models.data %}
<option value="{{model.id}}" {% if settings.processing_model==model.id %} selected {% endif %}>{{model.id}}
</option>
{% endfor %}
</select>
<p class="text-xs text-gray-500 mt-1">Model used for content processing and ingestion</p>
</div>
<button type="submit" class="btn btn-primary btn-sm">Save Model Settings</button>
<!-- Embedding Model -->
<div class="form-control mb-4">
<label class="label">
<span class="label-text">Embedding Model</span>
</label>
<select name="embedding_model" class="select select-bordered w-full">
{% for model in available_models.data %}
<option value="{{model.id}}" {% if settings.embedding_model==model.id %} selected {% endif %}>{{model.id}}
</option>
{% endfor %}
</select>
<p class="text-xs text-gray-500 mt-1">
Current used:
<span class="font-mono">{{settings.embedding_model}} ({{settings.embedding_dimensions}} dims)</span>
</p>
</div>
<!-- Embedding Dimensions (Always Visible) -->
<div class="form-control mb-4">
<label class="label" for="embedding_dimensions">
<span class="label-text">Embedding Dimensions</span>
</label>
<input type="number" id="embedding_dimensions" name="embedding_dimensions" class="input input-bordered w-full"
value="{{ settings.embedding_dimensions }}" required />
</div>
<!-- Conditional Alert -->
<div id="embedding-change-alert" role="alert" class="alert alert-warning mt-2 hidden">
<svg xmlns="http://www.w3.org/2000/svg" class="stroke-current shrink-0 h-6 w-6" fill="none"
viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
</svg>
<span><strong>Warning:</strong> Changing dimensions will require re-creating all embeddings. Make sure you
look up what dimensions the model uses or use a model that allows specifying embedding dimensions</span>
</div>
<button type="submit" class="btn btn-primary btn-sm mt-4">Save Model Settings</button>
</form>
<script>
// Use a self-executing function to avoid polluting the global scope
// and to ensure it runs correctly after an HTMX swap.
(() => {
const dimensionInput = document.getElementById('embedding_dimensions');
const alertElement = document.getElementById('embedding-change-alert');
// The initial value is read directly from the template each time this script runs.
const initialDimensions = '{{ settings.embedding_dimensions }}';
if (dimensionInput && alertElement) {
// Use the 'input' event for immediate feedback as the user types.
dimensionInput.addEventListener('input', (event) => {
// Show alert if the current value is not the initial value. Hide it otherwise.
if (event.target.value !== initialDimensions) {
alertElement.classList.remove('hidden');
} else {
alertElement.classList.add('hidden');
}
});
}
})();
</script>
{% endblock %}
</fieldset>

View File

@@ -95,7 +95,7 @@ impl IngestionPipeline {
// Convert analysis to application objects
let (entities, relationships) = analysis
.to_database_entities(&content.id, &content.user_id, &self.openai_client)
.to_database_entities(&content.id, &content.user_id, &self.openai_client, &self.db)
.await?;
// Store everything
@@ -155,7 +155,7 @@ impl IngestionPipeline {
// Could potentially process chunks in parallel with a bounded concurrent limit
for chunk in chunks {
let embedding = generate_embedding(&self.openai_client, chunk).await?;
let embedding = generate_embedding(&self.openai_client, chunk, &self.db).await?;
let text_chunk = TextChunk::new(
content.id.to_string(),
chunk.to_string(),

View File

@@ -38,7 +38,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let session_store = Arc::new(db.create_session_store().await?);
let openai_client = Arc::new(async_openai::Client::with_config(
async_openai::config::OpenAIConfig::new().with_api_key(&config.openai_api_key),
async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url),
));
let html_state =
@@ -94,7 +96,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);
// Initialize worker components
let openai_client = Arc::new(async_openai::Client::new());
let openai_client = Arc::new(async_openai::Client::with_config(
async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url),
));
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(worker_db.clone(), openai_client.clone(), config.clone())
.await

View File

@@ -1,5 +1,5 @@
[] ollama and changing of openai_base_url
[] allow changing of port the server listens to
[x] ollama and changing of openai_base_url
[x] allow changing of port the server listens to
[] archive ingressed webpage, pdf would be easy
[] embed surrealdb for the main binary
[] three js graph explorer