From ba0db2649c1980991173b73cc1706b51c53e3eaa Mon Sep 17 00:00:00 2001 From: Per Stark Date: Thu, 20 Mar 2025 21:31:20 +0100 Subject: [PATCH] feat: bin for combined server and worker --- Cargo.lock | 1 + .../src/answer_retrieval.rs | 1 - crates/ingestion-pipeline/Cargo.toml | 1 + crates/ingestion-pipeline/src/lib.rs | 99 +++++++++++++++++ crates/main/Cargo.toml | 4 + crates/main/src/main.rs | 102 +++++++++++++++++ crates/main/src/worker.rs | 104 +----------------- 7 files changed, 212 insertions(+), 100 deletions(-) create mode 100644 crates/main/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index f2c421f..feab823 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2443,6 +2443,7 @@ dependencies = [ "scraper", "serde", "serde_json", + "surrealdb", "text-splitter", "tiktoken-rs", "tokio", diff --git a/crates/composite-retrieval/src/answer_retrieval.rs b/crates/composite-retrieval/src/answer_retrieval.rs index b5aca58..d560ae3 100644 --- a/crates/composite-retrieval/src/answer_retrieval.rs +++ b/crates/composite-retrieval/src/answer_retrieval.rs @@ -68,7 +68,6 @@ pub async fn get_answer_with_references( let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?; let entities_json = format_entities_json(&entities); - debug!("{:?}", entities_json); let user_message = create_user_message(&entities_json, query); let request = create_chat_request(user_message)?; diff --git a/crates/ingestion-pipeline/Cargo.toml b/crates/ingestion-pipeline/Cargo.toml index 32d6850..ddb2d0b 100644 --- a/crates/ingestion-pipeline/Cargo.toml +++ b/crates/ingestion-pipeline/Cargo.toml @@ -12,6 +12,7 @@ tracing = { workspace = true } serde_json = { workspace = true } futures = { workspace = true } async-openai = { workspace = true } +surrealdb = { workspace = true } tiktoken-rs = "0.6.0" reqwest = {version = "0.12.12", features = ["charset", "json"]} diff --git a/crates/ingestion-pipeline/src/lib.rs b/crates/ingestion-pipeline/src/lib.rs index 60cf9ea..0c472b7 100644 --- a/crates/ingestion-pipeline/src/lib.rs +++ b/crates/ingestion-pipeline/src/lib.rs @@ -2,3 +2,102 @@ pub mod enricher; pub mod pipeline; pub mod types; pub mod utils; + +use common::storage::{ + db::SurrealDbClient, + types::ingestion_task::{IngestionTask, IngestionTaskStatus}, +}; +use futures::StreamExt; +use pipeline::IngestionPipeline; +use std::sync::Arc; +use surrealdb::Action; +use tracing::{error, info}; + +pub async fn run_worker_loop( + db: Arc, + ingestion_pipeline: Arc, +) -> Result<(), Box> { + loop { + // First, check for any unfinished tasks + let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db).await?; + if !unfinished_tasks.is_empty() { + info!("Found {} unfinished jobs", unfinished_tasks.len()); + for task in unfinished_tasks { + ingestion_pipeline.process_task(task).await?; + } + } + + // If no unfinished jobs, start listening for new ones + info!("Listening for new jobs..."); + let mut job_stream = IngestionTask::listen_for_tasks(&db).await?; + while let Some(notification) = job_stream.next().await { + match notification { + Ok(notification) => { + info!("Received notification: {:?}", notification); + match notification.action { + Action::Create => { + if let Err(e) = ingestion_pipeline.process_task(notification.data).await + { + error!("Error processing task: {}", e); + } + } + Action::Update => { + match notification.data.status { + IngestionTaskStatus::Completed + | IngestionTaskStatus::Error(_) + | IngestionTaskStatus::Cancelled => { + info!( + "Skipping already completed/error/cancelled task: {}", + notification.data.id + ); + continue; + } + IngestionTaskStatus::InProgress { attempts, .. } => { + // Only process if this is a retry after an error, not our own update + if let Ok(Some(current_task)) = + db.get_item::(¬ification.data.id).await + { + match current_task.status { + IngestionTaskStatus::Error(_) + if attempts + < common::storage::types::ingestion_task::MAX_ATTEMPTS => + { + // This is a retry after an error + if let Err(e) = + ingestion_pipeline.process_task(current_task).await + { + error!("Error processing task retry: {}", e); + } + } + _ => { + info!( + "Skipping in-progress update for task: {}", + notification.data.id + ); + continue; + } + } + } + } + IngestionTaskStatus::Created => { + // Shouldn't happen with Update action, but process if it does + if let Err(e) = + ingestion_pipeline.process_task(notification.data).await + { + error!("Error processing task: {}", e); + } + } + } + } + _ => {} // Ignore other actions + } + } + Err(e) => error!("Error in job notification: {}", e), + } + } + + // If we reach here, the stream has ended (connection lost?) + error!("Database stream ended unexpectedly, reconnecting..."); + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + } +} diff --git a/crates/main/Cargo.toml b/crates/main/Cargo.toml index 5542f39..275cba6 100644 --- a/crates/main/Cargo.toml +++ b/crates/main/Cargo.toml @@ -29,3 +29,7 @@ path = "src/server.rs" [[bin]] name = "worker" path = "src/worker.rs" + +[[bin]] +name = "main" +path = "src/main.rs" diff --git a/crates/main/src/main.rs b/crates/main/src/main.rs new file mode 100644 index 0000000..9175cd2 --- /dev/null +++ b/crates/main/src/main.rs @@ -0,0 +1,102 @@ +use api_router::{api_routes_v1, api_state::ApiState}; +use axum::{extract::FromRef, Router}; +use common::{storage::db::SurrealDbClient, utils::config::get_config}; +use html_router::{html_routes, html_state::HtmlState}; +use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; +use std::sync::Arc; +use tracing::{error, info}; +use tracing_subscriber::{fmt, prelude::*, EnvFilter}; + +use tokio::task::LocalSet; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Set up tracing + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .try_init() + .ok(); + + // Get config + let config = get_config()?; + + // Set up server components + let html_state = HtmlState::new(&config).await?; + let api_state = ApiState { + db: html_state.db.clone(), + }; + + // Create Axum router + let app = Router::new() + .nest("/api/v1", api_routes_v1(&api_state)) + .nest("/", html_routes(&html_state)) + .with_state(AppState { + api_state, + html_state, + }); + + info!("Starting server listening on 0.0.0.0:3000"); + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; + + // Start the server in a separate OS thread with its own runtime + let server_handle = std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + if let Err(e) = axum::serve(listener, app).await { + error!("Server error: {}", e); + } + }); + }); + + // Create a LocalSet for the worker + let local = LocalSet::new(); + + // Use a clone of the config for the worker + let worker_config = config.clone(); + + // Run the worker in the local set + local.spawn_local(async move { + // Create worker db connection + let worker_db = Arc::new( + SurrealDbClient::new( + &worker_config.surrealdb_address, + &worker_config.surrealdb_username, + &worker_config.surrealdb_password, + &worker_config.surrealdb_namespace, + &worker_config.surrealdb_database, + ) + .await + .unwrap(), + ); + + // Initialize worker components + let openai_client = Arc::new(async_openai::Client::new()); + let ingestion_pipeline = Arc::new( + IngestionPipeline::new(worker_db.clone(), openai_client.clone()) + .await + .unwrap(), + ); + + info!("Starting worker process"); + if let Err(e) = run_worker_loop(worker_db, ingestion_pipeline).await { + error!("Worker process error: {}", e); + } + }); + + // Run the local set on the main thread + local.await; + + // Wait for the server thread to finish (this likely won't be reached) + if let Err(e) = server_handle.join() { + error!("Server thread panicked: {:?}", e); + } + + Ok(()) +} + +#[derive(Clone, FromRef)] +struct AppState { + api_state: ApiState, + html_state: HtmlState, +} diff --git a/crates/main/src/worker.rs b/crates/main/src/worker.rs index 40800a8..78d327f 100644 --- a/crates/main/src/worker.rs +++ b/crates/main/src/worker.rs @@ -1,16 +1,7 @@ use std::sync::Arc; -use common::{ - storage::{ - db::SurrealDbClient, - types::ingestion_task::{IngestionTask, IngestionTaskStatus}, - }, - utils::config::get_config, -}; -use futures::StreamExt; -use ingestion_pipeline::pipeline::IngestionPipeline; -use surrealdb::Action; -use tracing::{error, info}; +use common::{storage::db::SurrealDbClient, utils::config::get_config}; +use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; #[tokio::main] @@ -37,93 +28,8 @@ async fn main() -> Result<(), Box> { let openai_client = Arc::new(async_openai::Client::new()); - let ingestion_pipeline = IngestionPipeline::new(db.clone(), openai_client.clone()).await?; + let ingestion_pipeline = + Arc::new(IngestionPipeline::new(db.clone(), openai_client.clone()).await?); - loop { - // First, check for any unfinished tasks - let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db).await?; - - if !unfinished_tasks.is_empty() { - info!("Found {} unfinished jobs", unfinished_tasks.len()); - - for task in unfinished_tasks { - ingestion_pipeline.process_task(task).await?; - } - } - - // If no unfinished jobs, start listening for new ones - info!("Listening for new jobs..."); - let mut job_stream = IngestionTask::listen_for_tasks(&db).await?; - - while let Some(notification) = job_stream.next().await { - match notification { - Ok(notification) => { - info!("Received notification: {:?}", notification); - - match notification.action { - Action::Create => { - if let Err(e) = ingestion_pipeline.process_task(notification.data).await - { - error!("Error processing task: {}", e); - } - } - Action::Update => { - match notification.data.status { - IngestionTaskStatus::Completed - | IngestionTaskStatus::Error(_) - | IngestionTaskStatus::Cancelled => { - info!( - "Skipping already completed/error/cancelled task: {}", - notification.data.id - ); - continue; - } - IngestionTaskStatus::InProgress { attempts, .. } => { - // Only process if this is a retry after an error, not our own update - if let Ok(Some(current_task)) = - db.get_item::(¬ification.data.id).await - { - match current_task.status { - IngestionTaskStatus::Error(_) - if attempts - < common::storage::types::ingestion_task::MAX_ATTEMPTS => - { - // This is a retry after an error - if let Err(e) = - ingestion_pipeline.process_task(current_task).await - { - error!("Error processing task retry: {}", e); - } - } - _ => { - info!( - "Skipping in-progress update for task: {}", - notification.data.id - ); - continue; - } - } - } - } - IngestionTaskStatus::Created => { - // Shouldn't happen with Update action, but process if it does - if let Err(e) = - ingestion_pipeline.process_task(notification.data).await - { - error!("Error processing task: {}", e); - } - } - } - } - _ => {} // Ignore other actions - } - } - Err(e) => error!("Error in job notification: {}", e), - } - } - - // If we reach here, the stream has ended (connection lost?) - error!("Database stream ended unexpectedly, reconnecting..."); - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - } + run_worker_loop(db, ingestion_pipeline).await }