# Patch summary
# -------------
# Extracts the start-up sequence shared by the `main`, `server` and `worker`
# binaries into a new module, main/src/bootstrap.rs:
#   * `SharedServices` bundles the DB client, OpenAI client, embedding
#     provider, storage manager, optional reranker pool and the config.
#   * `init()` installs the tracing subscriber (stderr writer), loads the
#     config and delegates to `init_with_config()`.
#   * `init_with_config()` connects to SurrealDB, applies migrations and
#     builds the remaining services.
# common/src/storage/db.rs now skips `signin` when the address starts with
# "mem://".
#
# Review notes (behaviour changes visible in this diff)
# ------------------------------------------------------
#   * worker.rs: the removed tracing setup used `fmt::layer()` without
#     `.with_writer(std::io::stderr)`; after this patch the worker gets the
#     stderr writer from `bootstrap::init()`.
#   * worker.rs: the removed code never called `apply_migrations()`; it is now
#     called on worker start-up via `init_with_config()`.
#   * main.rs: the removed `run_worker()` created its own `SurrealDbClient`,
#     OpenAI client and embedding provider; the worker loop now receives
#     `services.db` and the shared instances instead.
#   * server.rs: the "Embedding provider initialized" log line is removed and
#     not replaced.
#   * main.rs tests: `reranker_pool` was `None` and the embedding provider was
#     `EmbeddingProvider::new_hashed(384)`; both now come from the config.
#     NOTE(review): this relies on `smoke_test_config` selecting an in-memory
#     SurrealDB address and on the hashed backend's configured dimension;
#     neither is visible in this diff - confirm.
#   * db.rs: only the literal "mem://" prefix skips sign-in.
#     NOTE(review): confirm no other in-memory address spelling is in use.
#
# Reconstruction notes
# --------------------
# This patch was recovered from a copy whose line breaks were collapsed and
# whose generic type parameters had been stripped. Blank context lines were
# placed so that every hunk header count balances; indentation and the
# restored type parameters (e.g. `Result<Self, Error>` in db.rs) are inferred
# and should be checked against the working tree before applying. In two
# main.rs hunks the `let html_state = ...` line appears as both removed and
# added; the whitespace difference that caused this could not be recovered.

diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs
index f91135a..6c97c92 100644
--- a/common/src/storage/db.rs
+++ b/common/src/storage/db.rs
@@ -41,8 +41,10 @@ impl SurrealDbClient {
     ) -> Result<Self, Error> {
         let db = connect(address).await?;
 
-        // Sign in to database
-        db.signin(Root { username, password }).await?;
+        // Skip sign-in for in-memory engine (no auth support)
+        if !address.starts_with("mem://") {
+            db.signin(Root { username, password }).await?;
+        }
 
         // Set namespace
         db.use_ns(namespace).use_db(database).await?;
diff --git a/main/src/bootstrap.rs b/main/src/bootstrap.rs
new file mode 100644
index 0000000..4272ed6
--- /dev/null
+++ b/main/src/bootstrap.rs
@@ -0,0 +1,73 @@
+use std::sync::Arc;
+
+use async_openai::Client;
+use common::{
+    storage::{
+        db::SurrealDbClient,
+        store::StorageManager,
+    },
+    utils::{
+        config::{get_config, AppConfig},
+        embedding::EmbeddingProvider,
+    },
+};
+use retrieval_pipeline::reranking::RerankerPool;
+use tracing_subscriber::{fmt, prelude::*, EnvFilter};
+
+pub struct SharedServices {
+    pub db: Arc<SurrealDbClient>,
+    pub openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
+    pub embedding_provider: Arc<EmbeddingProvider>,
+    pub storage: StorageManager,
+    pub reranker_pool: Option<Arc<RerankerPool>>,
+    pub config: AppConfig,
+}
+
+pub async fn init() -> anyhow::Result<SharedServices> {
+    tracing_subscriber::registry()
+        .with(fmt::layer().with_writer(std::io::stderr))
+        .with(EnvFilter::from_default_env())
+        .try_init()
+        .ok();
+
+    let config = get_config()?;
+    init_with_config(config).await
+}
+
+pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<SharedServices> {
+    let db = Arc::new(
+        SurrealDbClient::new(
+            &config.surrealdb_address,
+            &config.surrealdb_username,
+            &config.surrealdb_password,
+            &config.surrealdb_namespace,
+            &config.surrealdb_database,
+        )
+        .await?,
+    );
+
+    db.apply_migrations().await?;
+
+    let openai_client = Arc::new(Client::with_config(
+        async_openai::config::OpenAIConfig::new()
+            .with_api_key(&config.openai_api_key)
+            .with_api_base(&config.openai_base_url),
+    ));
+
+    let embedding_provider = Arc::new(
+        EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?,
+    );
+
+    let reranker_pool = RerankerPool::maybe_from_config(&config)?;
+
+    let storage = StorageManager::new(&config).await?;
+
+    Ok(SharedServices {
+        db,
+        openai_client,
+        embedding_provider,
+        storage,
+        reranker_pool,
+        config,
+    })
+}
diff --git a/main/src/main.rs b/main/src/main.rs
index 3d1018e..9083d10 100644
--- a/main/src/main.rs
+++ b/main/src/main.rs
@@ -1,27 +1,24 @@
+mod bootstrap;
+
+use std::sync::Arc;
+
 use api_router::{api_routes_v1, api_state::ApiState};
 use axum::{extract::FromRef, Router};
 use common::{
     storage::{
-        db::SurrealDbClient,
         indexes::ensure_runtime,
-        store::StorageManager,
         types::{
             knowledge_entity::KnowledgeEntity, system_settings::SystemSettings,
             text_chunk::TextChunk,
         },
     },
-    utils::{config::get_config, embedding::EmbeddingProvider},
 };
 use html_router::{
     html_routes,
     html_state::{HtmlState, StateResources},
 };
 use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
-use retrieval_pipeline::reranking::RerankerPool;
-use std::sync::Arc;
 use tracing::{error, info, warn};
-use tracing_subscriber::{fmt, prelude::*, EnvFilter};
-
 use tokio::task::LocalSet;
 
 fn spawn_server_thread(
@@ -44,87 +41,21 @@ fn spawn_server_thread(
     })
 }
 
-async fn run_worker(
-    config: common::utils::config::AppConfig,
-    reranker_pool: Option<Arc<RerankerPool>>,
-    storage: StorageManager,
-) -> anyhow::Result<()> {
-    let worker_db = Arc::new(
-        SurrealDbClient::new(
-            &config.surrealdb_address,
-            &config.surrealdb_username,
-            &config.surrealdb_password,
-            &config.surrealdb_namespace,
-            &config.surrealdb_database,
-        )
-        .await?,
-    );
-
-    let openai_client = Arc::new(async_openai::Client::with_config(
-        async_openai::config::OpenAIConfig::new()
-            .with_api_key(&config.openai_api_key)
-            .with_api_base(&config.openai_base_url),
-    ));
-
-    let embedding_provider = Arc::new(
-        EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?,
-    );
-
-    let ingestion_pipeline = Arc::new(
-        IngestionPipeline::new(
-            Arc::clone(&worker_db),
-            openai_client,
-            config,
-            reranker_pool,
-            storage,
-            embedding_provider,
-        )?,
-    );
-
-    info!("Starting worker process");
-    run_worker_loop(worker_db, ingestion_pipeline).await
-}
-
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    tracing_subscriber::registry()
-        .with(fmt::layer().with_writer(std::io::stderr))
-        .with(EnvFilter::from_default_env())
-        .try_init()
-        .ok();
+    let services = bootstrap::init().await?;
 
-    let config = get_config()?;
+    let session_store = Arc::new(services.db.create_session_store().await?);
 
-    let db = Arc::new(
-        SurrealDbClient::new(
-            &config.surrealdb_address,
-            &config.surrealdb_username,
-            &config.surrealdb_password,
-            &config.surrealdb_namespace,
-            &config.surrealdb_database,
-        )
-        .await?,
-    );
-
-    db.apply_migrations().await?;
-
-    let session_store = Arc::new(db.create_session_store().await?);
-    let openai_client = Arc::new(async_openai::Client::with_config(
-        async_openai::config::OpenAIConfig::new()
-            .with_api_key(&config.openai_api_key)
-            .with_api_base(&config.openai_base_url),
-    ));
-
-    let embedding_provider =
-        Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
     info!(
-        embedding_backend = ?config.embedding_backend,
-        embedding_dimension = embedding_provider.dimension(),
+        embedding_backend = ?services.config.embedding_backend,
+        embedding_dimension = services.embedding_provider.dimension(),
         "Embedding provider initialized"
     );
 
     let (settings, dimensions_changed) =
-        SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?;
+        SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider)
+            .await?;
 
     if dimensions_changed {
         warn!(
@@ -134,7 +65,8 @@ async fn main() -> anyhow::Result<()> {
 
         info!("Re-embedding TextChunks");
         if let Err(e) =
-            TextChunk::update_all_embeddings_with_provider(&db, &embedding_provider).await
+            TextChunk::update_all_embeddings_with_provider(&services.db, &services.embedding_provider)
+                .await
         {
             error!(
                 "Failed to re-embed TextChunks: {}. Search results may be stale.",
@@ -144,7 +76,8 @@ async fn main() -> anyhow::Result<()> {
 
         info!("Re-embedding KnowledgeEntities");
         if let Err(e) =
-            KnowledgeEntity::update_all_embeddings_with_provider(&db, &embedding_provider).await
+            KnowledgeEntity::update_all_embeddings_with_provider(&services.db, &services.embedding_provider)
+                .await
         {
             error!(
                 "Failed to re-embed KnowledgeEntities: {}. Search results may be stale.",
@@ -155,24 +88,20 @@ async fn main() -> anyhow::Result<()> {
         info!("Re-embedding complete.");
     }
 
-    ensure_runtime(&db, settings.embedding_dimensions as usize).await?;
+    ensure_runtime(&services.db, settings.embedding_dimensions as usize).await?;
 
-    let reranker_pool = RerankerPool::maybe_from_config(&config)?;
-
-    let storage = StorageManager::new(&config).await?;
-
-    let html_state = HtmlState::new_with_resources(StateResources {
-        db,
-        openai_client,
+    let html_state = HtmlState::new_with_resources(StateResources {
+        db: Arc::clone(&services.db),
+        openai_client: Arc::clone(&services.openai_client),
         session_store,
-        storage: storage.clone(),
-        config: config.clone(),
-        reranker_pool: reranker_pool.clone(),
-        embedding_provider: Arc::clone(&embedding_provider),
+        storage: services.storage.clone(),
+        config: services.config.clone(),
+        reranker_pool: services.reranker_pool.clone(),
+        embedding_provider: Arc::clone(&services.embedding_provider),
         template_engine: None,
     });
 
-    let api_state = ApiState::new(&config, storage.clone()).await?;
+    let api_state = ApiState::new(&services.config, services.storage.clone()).await?;
 
     let app = Router::new()
         .nest("/api/v1", api_routes_v1(&api_state))
@@ -182,15 +111,28 @@ async fn main() -> anyhow::Result<()> {
             html_state,
         });
 
-    info!("Starting server listening on 0.0.0.0:{}", config.http_port);
-    let serve_address = format!("0.0.0.0:{}", config.http_port);
+    info!(
+        "Starting server listening on 0.0.0.0:{}",
+        services.config.http_port
+    );
+    let serve_address = format!("0.0.0.0:{}", services.config.http_port);
     let listener = tokio::net::TcpListener::bind(serve_address).await?;
 
     let server_handle = spawn_server_thread(listener, app);
 
+    let ingestion_pipeline = Arc::new(IngestionPipeline::new(
+        Arc::clone(&services.db),
+        Arc::clone(&services.openai_client),
+        services.config.clone(),
+        services.reranker_pool.clone(),
+        services.storage,
+        Arc::clone(&services.embedding_provider),
+    )?);
+
     let local = LocalSet::new();
     local.spawn_local(async move {
-        if let Err(e) = run_worker(config, reranker_pool, storage).await {
+        info!("Starting worker process");
+        if let Err(e) = run_worker_loop(services.db, ingestion_pipeline).await {
             error!("Worker error: {}", e);
         }
     });
@@ -222,7 +164,7 @@ mod tests {
         store::StorageManager,
         types::{system_settings::SystemSettings, user::User},
     };
-    use common::utils::config::{AppConfig, PdfIngestMode, StorageKind};
+    use common::utils::config::{AppConfig, EmbeddingBackend, PdfIngestMode, StorageKind};
     use std::{path::Path, sync::Arc};
     use tower::ServiceExt;
     use uuid::Uuid;
@@ -240,6 +182,7 @@ mod tests {
             openai_base_url: "https://example.com".into(),
             storage: StorageKind::Local,
             pdf_ingest_mode: PdfIngestMode::LlmFirst,
+            embedding_backend: EmbeddingBackend::Hashed,
             ..Default::default()
         }
     }
@@ -252,37 +195,25 @@ mod tests {
             .expect("failed to create temp data directory");
 
         let config = smoke_test_config(namespace, &database, &data_dir);
-        let db = Arc::new(SurrealDbClient::memory(namespace, &database).await?);
-        db.apply_migrations().await?;
+        let services = crate::bootstrap::init_with_config(config.clone()).await?;
 
-        let session_store = Arc::new(db.create_session_store().await?);
-        let openai_client = Arc::new(async_openai::Client::with_config(
-            async_openai::config::OpenAIConfig::new()
-                .with_api_key(&config.openai_api_key)
-                .with_api_base(&config.openai_base_url),
-        ));
+        let session_store = Arc::new(services.db.create_session_store().await?);
 
-        let storage = StorageManager::new(&config).await?;
-
-        let embedding_provider = Arc::new(
-            common::utils::embedding::EmbeddingProvider::new_hashed(384)?,
-        );
-
-        let html_state = HtmlState::new_with_resources(StateResources {
-            db: Arc::clone(&db),
-            openai_client,
+        let html_state = HtmlState::new_with_resources(StateResources {
+            db: Arc::clone(&services.db),
+            openai_client: Arc::clone(&services.openai_client),
             session_store,
-            storage: storage.clone(),
-            config: config.clone(),
-            reranker_pool: None,
-            embedding_provider,
+            storage: services.storage.clone(),
+            config: services.config.clone(),
+            reranker_pool: services.reranker_pool.clone(),
+            embedding_provider: Arc::clone(&services.embedding_provider),
             template_engine: None,
         });
 
         let api_state = ApiState {
-            db: Arc::clone(&db),
-            config: config.clone(),
-            storage,
+            db: Arc::clone(&services.db),
+            config: services.config.clone(),
+            storage: services.storage,
         };
 
         let app = Router::new()
diff --git a/main/src/server.rs b/main/src/server.rs
index 1a1e047..462fea9 100644
--- a/main/src/server.rs
+++ b/main/src/server.rs
@@ -1,85 +1,39 @@
+mod bootstrap;
+
 use std::sync::Arc;
 
 use api_router::{api_routes_v1, api_state::ApiState};
 use axum::{extract::FromRef, Router};
-use common::{
-    storage::{db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettings},
-    utils::{config::get_config, embedding::EmbeddingProvider},
-};
+use common::storage::types::system_settings::SystemSettings;
 use html_router::{
     html_routes,
     html_state::{HtmlState, StateResources},
 };
-use retrieval_pipeline::reranking::RerankerPool;
 use tracing::info;
-use tracing_subscriber::{fmt, prelude::*, EnvFilter};
 
 #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
 async fn main() -> anyhow::Result<()> {
-    // Set up tracing
-    tracing_subscriber::registry()
-        .with(fmt::layer().with_writer(std::io::stderr))
-        .with(EnvFilter::from_default_env())
-        .try_init()
-        .ok();
+    let services = bootstrap::init().await?;
 
-    // Get config
-    let config = get_config()?;
+    let session_store = Arc::new(services.db.create_session_store().await?);
 
-    // Set up router states
-    let db = Arc::new(
-        SurrealDbClient::new(
-            &config.surrealdb_address,
-            &config.surrealdb_username,
-            &config.surrealdb_password,
-            &config.surrealdb_namespace,
-            &config.surrealdb_database,
-        )
-        .await?,
-    );
-
-    // Ensure db is initialized
-    db.apply_migrations().await?;
-
-    let session_store = Arc::new(db.create_session_store().await?);
-    let openai_client = Arc::new(async_openai::Client::with_config(
-        async_openai::config::OpenAIConfig::new()
-            .with_api_key(&config.openai_api_key)
-            .with_api_base(&config.openai_base_url),
-    ));
-
-    let reranker_pool = RerankerPool::maybe_from_config(&config)?;
-
-    // Create global storage manager
-    let storage = StorageManager::new(&config).await?;
-
-    // Create embedding provider based on config
-    let embedding_provider =
-        Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
-    info!(
-        embedding_backend = ?config.embedding_backend,
-        embedding_dimension = embedding_provider.dimension(),
-        "Embedding provider initialized"
-    );
-
-    // Sync SystemSettings with provider's dimensions/backend for visibility
     let (_settings, _dimensions_changed) =
-        SystemSettings::sync_from_embedding_provider(&db, &embedding_provider).await?;
+        SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider)
+            .await?;
 
     let html_state = HtmlState::new_with_resources(StateResources {
-        db,
-        openai_client,
+        db: Arc::clone(&services.db),
+        openai_client: Arc::clone(&services.openai_client),
         session_store,
-        storage: storage.clone(),
-        config: config.clone(),
-        reranker_pool,
-        embedding_provider,
+        storage: services.storage.clone(),
+        config: services.config.clone(),
+        reranker_pool: services.reranker_pool.clone(),
+        embedding_provider: Arc::clone(&services.embedding_provider),
         template_engine: None,
     });
 
-    let api_state = ApiState::new(&config, storage).await?;
+    let api_state = ApiState::new(&services.config, services.storage).await?;
 
-    // Create Axum router
     let app = Router::new()
         .nest("/api/v1", api_routes_v1(&api_state))
         .merge(html_routes(&html_state))
@@ -88,8 +42,11 @@ async fn main() -> anyhow::Result<()> {
             html_state,
         });
 
-    info!("Starting server listening on 0.0.0.0:{}", config.http_port);
-    let serve_address = format!("0.0.0.0:{}", config.http_port);
+    info!(
+        "Starting server listening on 0.0.0.0:{}",
+        services.config.http_port
+    );
+    let serve_address = format!("0.0.0.0:{}", services.config.http_port);
     let listener = tokio::net::TcpListener::bind(serve_address).await?;
 
     axum::serve(listener, app).await?;
diff --git a/main/src/worker.rs b/main/src/worker.rs
index 1e076c6..4e34332 100644
--- a/main/src/worker.rs
+++ b/main/src/worker.rs
@@ -1,64 +1,27 @@
+mod bootstrap;
+
 use std::sync::Arc;
 
-use common::{
-    storage::db::SurrealDbClient,
-    storage::store::StorageManager,
-    utils::{config::get_config, embedding::EmbeddingProvider},
-};
 use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
-use retrieval_pipeline::reranking::RerankerPool;
 use tracing::info;
-use tracing_subscriber::{fmt, prelude::*, EnvFilter};
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    // Set up tracing
-    tracing_subscriber::registry()
-        .with(fmt::layer())
-        .with(EnvFilter::from_default_env())
-        .try_init()
-        .ok();
+    let services = bootstrap::init().await?;
 
-    let config = get_config()?;
-
-    let db = Arc::new(
-        SurrealDbClient::new(
-            &config.surrealdb_address,
-            &config.surrealdb_username,
-            &config.surrealdb_password,
-            &config.surrealdb_namespace,
-            &config.surrealdb_database,
-        )
-        .await?,
-    );
-
-    let openai_client = Arc::new(async_openai::Client::with_config(
-        async_openai::config::OpenAIConfig::new()
-            .with_api_key(&config.openai_api_key)
-            .with_api_base(&config.openai_base_url),
-    ));
-
-    let reranker_pool = RerankerPool::maybe_from_config(&config)?;
-
-    // Create embedding provider based on config
-    let embedding_provider =
-        Arc::new(EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?);
     info!(
-        embedding_backend = ?config.embedding_backend,
+        embedding_backend = ?services.config.embedding_backend,
         "Embedding provider initialized for worker"
     );
 
-    // Create global storage manager
-    let storage = StorageManager::new(&config).await?;
-
     let ingestion_pipeline = Arc::new(IngestionPipeline::new(
-        Arc::clone(&db),
-        Arc::clone(&openai_client),
-        config,
-        reranker_pool,
-        storage,
-        embedding_provider,
+        Arc::clone(&services.db),
+        Arc::clone(&services.openai_client),
+        services.config.clone(),
+        services.reranker_pool.clone(),
+        services.storage,
+        Arc::clone(&services.embedding_provider),
     )?);
 
-    run_worker_loop(db, ingestion_pipeline).await
+    run_worker_loop(services.db, ingestion_pipeline).await
 }