From 7d79f468aa0aef8d8fbd51c729e1bff7fd36c613 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Wed, 5 Mar 2025 16:14:18 +0100 Subject: [PATCH] refactored queue into Job --- crates/api-router/src/api_state.rs | 8 +- crates/api-router/src/middleware_api_auth.rs | 2 +- crates/api-router/src/routes/ingress.rs | 13 +- .../common/src/ingress/content_processor.rs | 68 +++++-- crates/common/src/ingress/jobqueue.rs | 182 ------------------ crates/common/src/ingress/mod.rs | 2 - crates/common/src/ingress/queue_task.rs | 13 -- crates/common/src/storage/db.rs | 140 +++++++------- .../common/src/storage/types/conversation.rs | 9 +- crates/common/src/storage/types/file_info.rs | 11 +- crates/common/src/storage/types/job.rs | 69 ++++++- crates/common/src/storage/types/user.rs | 64 +++++- crates/html-router/src/html_state.rs | 7 +- crates/html-router/src/lib.rs | 2 +- .../html-router/src/middleware_analytics.rs | 4 +- crates/html-router/src/routes/account.rs | 10 +- crates/html-router/src/routes/admin_panel.rs | 10 +- .../routes/chat/message_response_stream.rs | 27 ++- crates/html-router/src/routes/chat/mod.rs | 51 ++--- .../html-router/src/routes/chat/references.rs | 9 +- crates/html-router/src/routes/content/mod.rs | 10 +- crates/html-router/src/routes/index.rs | 50 ++--- crates/html-router/src/routes/ingress_form.rs | 12 +- .../html-router/src/routes/knowledge/mod.rs | 43 ++--- crates/html-router/src/routes/signin.rs | 2 +- crates/html-router/src/routes/signup.rs | 9 +- crates/main/src/server.rs | 3 +- crates/main/src/worker.rs | 58 ++---- todo.md | 3 +- 29 files changed, 401 insertions(+), 490 deletions(-) delete mode 100644 crates/common/src/ingress/jobqueue.rs delete mode 100644 crates/common/src/ingress/queue_task.rs diff --git a/crates/api-router/src/api_state.rs b/crates/api-router/src/api_state.rs index 098843e..416b4ce 100644 --- a/crates/api-router/src/api_state.rs +++ b/crates/api-router/src/api_state.rs @@ -1,11 +1,10 @@ use std::sync::Arc; -use 
common::{ingress::jobqueue::JobQueue, storage::db::SurrealDbClient, utils::config::AppConfig}; +use common::{storage::db::SurrealDbClient, utils::config::AppConfig}; #[derive(Clone)] pub struct ApiState { - pub surreal_db_client: Arc, - pub job_queue: Arc, + pub db: Arc, } impl ApiState { @@ -24,8 +23,7 @@ impl ApiState { surreal_db_client.ensure_initialized().await?; let app_state = ApiState { - surreal_db_client: surreal_db_client.clone(), - job_queue: Arc::new(JobQueue::new(surreal_db_client)), + db: surreal_db_client.clone(), }; Ok(app_state) diff --git a/crates/api-router/src/middleware_api_auth.rs b/crates/api-router/src/middleware_api_auth.rs index f7cac7e..628e666 100644 --- a/crates/api-router/src/middleware_api_auth.rs +++ b/crates/api-router/src/middleware_api_auth.rs @@ -17,7 +17,7 @@ pub async fn api_auth( "You have to be authenticated".to_string(), ))?; - let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?; + let user = User::find_by_api_key(&api_key, &state.db).await?; let user = user.ok_or(ApiError::Unauthorized( "You have to be authenticated".to_string(), ))?; diff --git a/crates/api-router/src/routes/ingress.rs b/crates/api-router/src/routes/ingress.rs index f671ab0..34a5bae 100644 --- a/crates/api-router/src/routes/ingress.rs +++ b/crates/api-router/src/routes/ingress.rs @@ -3,7 +3,7 @@ use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use common::{ error::{ApiError, AppError}, ingress::ingress_input::{create_ingress_objects, IngressInput}, - storage::types::{file_info::FileInfo, user::User}, + storage::types::{file_info::FileInfo, job::Job, user::User}, }; use futures::{future::try_join_all, TryFutureExt}; use tempfile::NamedTempFile; @@ -28,9 +28,12 @@ pub async fn ingress_data( ) -> Result { info!("Received input: {:?}", input); - let file_infos = try_join_all(input.files.into_iter().map(|file| { - FileInfo::new(file, &state.surreal_db_client, &user.id).map_err(AppError::from) - })) + let 
file_infos = try_join_all( + input + .files + .into_iter() + .map(|file| FileInfo::new(file, &state.db, &user.id).map_err(AppError::from)), + ) .await?; debug!("Got file infos"); @@ -48,7 +51,7 @@ pub async fn ingress_data( let futures: Vec<_> = ingress_objects .into_iter() - .map(|object| state.job_queue.enqueue(object.clone(), user.id.clone())) + .map(|object| Job::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)) .collect(); try_join_all(futures).await.map_err(AppError::from)?; diff --git a/crates/common/src/ingress/content_processor.rs b/crates/common/src/ingress/content_processor.rs index 266162f..fe8ff8e 100644 --- a/crates/common/src/ingress/content_processor.rs +++ b/crates/common/src/ingress/content_processor.rs @@ -1,15 +1,19 @@ use std::{sync::Arc, time::Instant}; +use chrono::Utc; use text_splitter::TextSplitter; use tracing::{debug, info}; use crate::{ error::AppError, storage::{ - db::{store_item, SurrealDbClient}, + db::SurrealDbClient, types::{ - knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, - text_chunk::TextChunk, text_content::TextContent, + job::{Job, JobStatus, MAX_ATTEMPTS}, + knowledge_entity::KnowledgeEntity, + knowledge_relationship::KnowledgeRelationship, + text_chunk::TextChunk, + text_content::TextContent, }, }, utils::embedding::generate_embedding, @@ -20,19 +24,53 @@ use super::analysis::{ }; pub struct ContentProcessor { - db_client: Arc, + db: Arc, openai_client: Arc>, } impl ContentProcessor { pub async fn new( - surreal_db_client: Arc, + db: Arc, openai_client: Arc>, ) -> Result { - Ok(Self { - db_client: surreal_db_client, - openai_client, - }) + Ok(Self { db, openai_client }) + } + pub async fn process_job(&self, job: Job) -> Result<(), AppError> { + let current_attempts = match job.status { + JobStatus::InProgress { attempts, .. 
} => attempts + 1, + _ => 1, + }; + + // Update status to InProgress with attempt count + Job::update_status( + &job.id, + JobStatus::InProgress { + attempts: current_attempts, + last_attempt: Utc::now(), + }, + &self.db, + ) + .await?; + + let text_content = job.content.to_text_content(&self.openai_client).await?; + + match self.process(&text_content).await { + Ok(_) => { + Job::update_status(&job.id, JobStatus::Completed, &self.db).await?; + Ok(()) + } + Err(e) => { + if current_attempts >= MAX_ATTEMPTS { + Job::update_status( + &job.id, + JobStatus::Error(format!("Max attempts reached: {}", e)), + &self.db, + ) + .await?; + } + Err(AppError::Processing(e.to_string())) + } + } } pub async fn process(&self, content: &TextContent) -> Result<(), AppError> { @@ -59,9 +97,9 @@ impl ContentProcessor { )?; // Store original content - store_item(&self.db_client, content.to_owned()).await?; + self.db.store_item(content.to_owned()).await?; - self.db_client.rebuild_indexes().await?; + self.db.rebuild_indexes().await?; Ok(()) } @@ -69,7 +107,7 @@ impl ContentProcessor { &self, content: &TextContent, ) -> Result { - let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client); + let analyser = IngressAnalyzer::new(&self.db, &self.openai_client); analyser .analyze_content( &content.category, @@ -87,12 +125,12 @@ impl ContentProcessor { ) -> Result<(), AppError> { for entity in &entities { debug!("Storing entity: {:?}", entity); - store_item(&self.db_client, entity.clone()).await?; + self.db.store_item(entity.clone()).await?; } for relationship in &relationships { debug!("Storing relationship: {:?}", relationship); - relationship.store_relationship(&self.db_client).await?; + relationship.store_relationship(&self.db).await?; } info!( @@ -116,7 +154,7 @@ impl ContentProcessor { embedding, content.user_id.to_string(), ); - store_item(&self.db_client, text_chunk).await?; + self.db.store_item(text_chunk).await?; } Ok(()) diff --git 
a/crates/common/src/ingress/jobqueue.rs b/crates/common/src/ingress/jobqueue.rs deleted file mode 100644 index 22a57e1..0000000 --- a/crates/common/src/ingress/jobqueue.rs +++ /dev/null @@ -1,182 +0,0 @@ -use chrono::Utc; -use futures::Stream; -use std::sync::Arc; -use surrealdb::{opt::PatchOp, Error, Notification}; -use tracing::{debug, error, info}; - -use crate::{ - error::AppError, - storage::{ - db::{delete_item, get_item, store_item, SurrealDbClient}, - types::{ - job::{Job, JobStatus}, - StoredObject, - }, - }, -}; - -use super::{content_processor::ContentProcessor, ingress_object::IngressObject}; - -pub struct JobQueue { - pub db: Arc, -} - -pub const MAX_ATTEMPTS: u32 = 3; - -impl JobQueue { - pub fn new(db: Arc) -> Self { - Self { db } - } - - /// Creates a new job and stores it in the database - pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<(), AppError> { - let job = Job::new(content, user_id).await; - - info!("{:?}", job); - - store_item(&self.db, job).await?; - - Ok(()) - } - - /// Gets all jobs for a specific user - pub async fn get_user_jobs(&self, user_id: &str) -> Result, AppError> { - let jobs: Vec = self - .db - .query("SELECT * FROM job WHERE user_id = $user_id ORDER BY created_at DESC") - .bind(("user_id", user_id.to_owned())) - .await? - .take(0)?; - - debug!("{:?}", jobs); - - Ok(jobs) - } - - /// Gets all active jobs for a specific user - pub async fn get_unfinished_user_jobs(&self, user_id: &str) -> Result, AppError> { - let jobs: Vec = self - .db - .query( - "SELECT * FROM type::table($table) - WHERE user_id = $user_id - AND ( - status = 'Created' - OR ( - status.InProgress != NONE - AND status.InProgress.attempts < $max_attempts - ) - ) - ORDER BY created_at DESC", - ) - .bind(("table", Job::table_name())) - .bind(("user_id", user_id.to_owned())) - .bind(("max_attempts", MAX_ATTEMPTS)) - .await? 
- .take(0)?; - debug!("{:?}", jobs); - Ok(jobs) - } - - pub async fn delete_job(&self, id: &str, user_id: &str) -> Result<(), AppError> { - get_item::(&self.db.client, id) - .await? - .filter(|job| job.user_id == user_id) - .ok_or_else(|| { - error!("Unauthorized attempt to delete job {id} by user {user_id}"); - AppError::Auth("Not authorized to delete this job".into()) - })?; - - info!("Deleting job {id} for user {user_id}"); - delete_item::(&self.db.client, id) - .await - .map_err(AppError::Database)?; - - Ok(()) - } - - pub async fn update_status(&self, id: &str, status: JobStatus) -> Result<(), AppError> { - let _job: Option = self - .db - .update((Job::table_name(), id)) - .patch(PatchOp::replace("/status", status)) - .patch(PatchOp::replace( - "/updated_at", - surrealdb::sql::Datetime::default(), - )) - .await?; - - Ok(()) - } - - /// Listen for new jobs - pub async fn listen_for_jobs( - &self, - ) -> Result, Error>>, Error> { - self.db.select("job").live().await - } - - /// Get unfinished jobs, ie newly created and in progress up two times - pub async fn get_unfinished_jobs(&self) -> Result, AppError> { - let jobs: Vec = self - .db - .query( - "SELECT * FROM type::table($table) - WHERE - status = 'Created' - OR ( - status.InProgress != NONE - AND status.InProgress.attempts < $max_attempts - ) - ORDER BY created_at ASC", - ) - .bind(("table", Job::table_name())) - .bind(("max_attempts", MAX_ATTEMPTS)) - .await? - .take(0)?; - - Ok(jobs) - } - - // Method to process a single job - pub async fn process_job( - &self, - job: Job, - processor: &ContentProcessor, - openai_client: Arc>, - ) -> Result<(), AppError> { - let current_attempts = match job.status { - JobStatus::InProgress { attempts, .. 
} => attempts + 1, - _ => 1, - }; - - // Update status to InProgress with attempt count - self.update_status( - &job.id, - JobStatus::InProgress { - attempts: current_attempts, - last_attempt: Utc::now(), - }, - ) - .await?; - - let text_content = job.content.to_text_content(&openai_client).await?; - - match processor.process(&text_content).await { - Ok(_) => { - self.update_status(&job.id, JobStatus::Completed).await?; - Ok(()) - } - Err(e) => { - if current_attempts >= MAX_ATTEMPTS { - self.update_status( - &job.id, - JobStatus::Error(format!("Max attempts reached: {}", e)), - ) - .await?; - } - Err(AppError::Processing(e.to_string())) - } - } - } -} diff --git a/crates/common/src/ingress/mod.rs b/crates/common/src/ingress/mod.rs index fea78f0..34fee9c 100644 --- a/crates/common/src/ingress/mod.rs +++ b/crates/common/src/ingress/mod.rs @@ -2,5 +2,3 @@ pub mod analysis; pub mod content_processor; pub mod ingress_input; pub mod ingress_object; -pub mod jobqueue; -pub mod queue_task; diff --git a/crates/common/src/ingress/queue_task.rs b/crates/common/src/ingress/queue_task.rs deleted file mode 100644 index 79ee882..0000000 --- a/crates/common/src/ingress/queue_task.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::ingress::ingress_object::IngressObject; -use serde::Serialize; - -#[derive(Serialize)] -pub struct QueueTask { - pub delivery_tag: u64, - pub content: IngressObject, -} - -#[derive(Serialize)] -pub struct QueueTaskResponse { - pub tasks: Vec, -} diff --git a/crates/common/src/storage/db.rs b/crates/common/src/storage/db.rs index 515abf2..3883483 100644 --- a/crates/common/src/storage/db.rs +++ b/crates/common/src/storage/db.rs @@ -1,13 +1,14 @@ use crate::error::AppError; -use super::types::{analytics::Analytics, system_settings::SystemSettings, StoredObject}; +use super::types::{analytics::Analytics, job::Job, system_settings::SystemSettings, StoredObject}; use axum_session::{SessionConfig, SessionError, SessionStore}; use 
axum_session_surreal::SessionSurrealPool; +use futures::Stream; use std::ops::Deref; use surrealdb::{ engine::any::{connect, Any}, opt::auth::Root, - Error, Surreal, + Error, Notification, Surreal, }; #[derive(Clone)] @@ -53,8 +54,8 @@ impl SurrealDbClient { } pub async fn ensure_initialized(&self) -> Result<(), AppError> { - Self::build_indexes(&self).await?; - Self::setup_auth(&self).await?; + Self::build_indexes(self).await?; + Self::setup_auth(self).await?; Analytics::ensure_initialized(self).await?; SystemSettings::ensure_initialized(self).await?; @@ -107,6 +108,75 @@ impl SurrealDbClient { { self.client.delete(T::table_name()).await } + + /// Operation to store a object in SurrealDB, requires the struct to implement StoredObject + /// + /// # Arguments + /// * `item` - The item to be stored + /// + /// # Returns + /// * `Result` - Item or Error + pub async fn store_item(&self, item: T) -> Result, Error> + where + T: StoredObject + Send + Sync + 'static, + { + self.client + .create((T::table_name(), item.get_id())) + .content(item) + .await + } + + /// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject + /// + /// # Returns + /// * `Result` - Vec or Error + pub async fn get_all_stored_items(&self) -> Result, Error> + where + T: for<'de> StoredObject, + { + self.client.select(T::table_name()).await + } + + /// Operation to retrieve a single object by its ID, requires the struct to implement StoredObject + /// + /// # Arguments + /// * `id` - The ID of the item to retrieve + /// + /// # Returns + /// * `Result, Error>` - The found item or Error + pub async fn get_item(&self, id: &str) -> Result, Error> + where + T: for<'de> StoredObject, + { + self.client.select((T::table_name(), id)).await + } + + /// Operation to delete a single object by its ID, requires the struct to implement StoredObject + /// + /// # Arguments + /// * `id` - The ID of the item to delete + /// + /// # Returns + /// * `Result, Error>` - The 
deleted item or Error + pub async fn delete_item(&self, id: &str) -> Result, Error> + where + T: for<'de> StoredObject, + { + self.client.delete((T::table_name(), id)).await + } + + /// Operation to listen to a table for updates, requires the struct to implement StoredObject + /// + /// # Returns + /// * `Result, Error>` - The deleted item or Error + pub async fn listen( + &self, + ) -> Result, Error>>, Error> + where + T: for<'de> StoredObject, + { + self.client.select(T::table_name()).live().await + } } impl Deref for SurrealDbClient { @@ -116,65 +186,3 @@ impl Deref for SurrealDbClient { &self.client } } - -/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject -/// -/// # Arguments -/// * `db_client` - A initialized database client -/// * `item` - The item to be stored -/// -/// # Returns -/// * `Result` - Item or Error -pub async fn store_item(db_client: &Surreal, item: T) -> Result, Error> -where - T: StoredObject + Send + Sync + 'static, -{ - db_client - .create((T::table_name(), item.get_id())) - .content(item) - .await -} - -/// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject -/// -/// # Arguments -/// * `db_client` - A initialized database client -/// -/// # Returns -/// * `Result` - Vec or Error -pub async fn get_all_stored_items(db_client: &Surreal) -> Result, Error> -where - T: for<'de> StoredObject, -{ - db_client.select(T::table_name()).await -} - -/// Operation to retrieve a single object by its ID, requires the struct to implement StoredObject -/// -/// # Arguments -/// * `db_client` - An initialized database client -/// * `id` - The ID of the item to retrieve -/// -/// # Returns -/// * `Result, Error>` - The found item or Error -pub async fn get_item(db_client: &Surreal, id: &str) -> Result, Error> -where - T: for<'de> StoredObject, -{ - db_client.select((T::table_name(), id)).await -} - -/// Operation to delete a single object by its ID, requires the struct 
to implement StoredObject -/// -/// # Arguments -/// * `db_client` - An initialized database client -/// * `id` - The ID of the item to delete -/// -/// # Returns -/// * `Result, Error>` - The deleted item or Error -pub async fn delete_item(db_client: &Surreal, id: &str) -> Result, Error> -where - T: for<'de> StoredObject, -{ - db_client.delete((T::table_name(), id)).await -} diff --git a/crates/common/src/storage/types/conversation.rs b/crates/common/src/storage/types/conversation.rs index 7a85494..2a6675f 100644 --- a/crates/common/src/storage/types/conversation.rs +++ b/crates/common/src/storage/types/conversation.rs @@ -1,10 +1,6 @@ use uuid::Uuid; -use crate::{ - error::AppError, - storage::db::{get_item, SurrealDbClient}, - stored_object, -}; +use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use super::message::Message; @@ -30,7 +26,8 @@ impl Conversation { user_id: &str, db: &SurrealDbClient, ) -> Result<(Self, Vec), AppError> { - let conversation: Conversation = get_item(&db, conversation_id) + let conversation: Conversation = db + .get_item(conversation_id) .await? 
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?; diff --git a/crates/common/src/storage/types/file_info.rs b/crates/common/src/storage/types/file_info.rs index 688b0dd..8ff3342 100644 --- a/crates/common/src/storage/types/file_info.rs +++ b/crates/common/src/storage/types/file_info.rs @@ -11,10 +11,7 @@ use tokio::fs::remove_dir_all; use tracing::info; use uuid::Uuid; -use crate::{ - storage::db::{delete_item, get_item, store_item, SurrealDbClient}, - stored_object, -}; +use crate::{storage::db::SurrealDbClient, stored_object}; #[derive(Error, Debug)] pub enum FileError { @@ -89,7 +86,7 @@ impl FileInfo { }; // Store in database - store_item(&db_client.client, file_info.clone()).await?; + db_client.store_item(file_info.clone()).await?; Ok(file_info) } @@ -210,7 +207,7 @@ impl FileInfo { /// `Result<(), FileError>` pub async fn delete_by_id(id: &str, db_client: &SurrealDbClient) -> Result<(), FileError> { // Get the FileInfo from the database - let file_info = match get_item::(db_client, id).await? { + let file_info = match db_client.get_item::(id).await? 
{ Some(info) => info, None => { return Err(FileError::FileNotFound(format!( @@ -241,7 +238,7 @@ impl FileInfo { } // Delete the FileInfo from the database - delete_item::(db_client, id).await?; + db_client.delete_item::(id).await?; Ok(()) } diff --git a/crates/common/src/storage/types/job.rs b/crates/common/src/storage/types/job.rs index 510cf20..9239a00 100644 --- a/crates/common/src/storage/types/job.rs +++ b/crates/common/src/storage/types/job.rs @@ -1,6 +1,11 @@ +use futures::Stream; +use surrealdb::{opt::PatchOp, Notification}; use uuid::Uuid; -use crate::{ingress::ingress_object::IngressObject, stored_object}; +use crate::{ + error::AppError, ingress::ingress_object::IngressObject, storage::db::SurrealDbClient, + stored_object, +}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum JobStatus { @@ -20,6 +25,8 @@ stored_object!(Job, "job", { user_id: String }); +pub const MAX_ATTEMPTS: u32 = 3; + impl Job { pub async fn new(content: IngressObject, user_id: String) -> Self { let now = Utc::now(); @@ -33,4 +40,64 @@ impl Job { user_id, } } + + /// Creates a new job and stores it in the database + pub async fn create_and_add_to_db( + content: IngressObject, + user_id: String, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let job = Self::new(content, user_id).await; + + db.store_item(job).await?; + + Ok(()) + } + + // Update job status + pub async fn update_status( + id: &str, + status: JobStatus, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let _job: Option = db + .update((Self::table_name(), id)) + .patch(PatchOp::replace("/status", status)) + .patch(PatchOp::replace( + "/updated_at", + surrealdb::sql::Datetime::default(), + )) + .await?; + + Ok(()) + } + + /// Listen for new jobs + pub async fn listen_for_jobs( + db: &SurrealDbClient, + ) -> Result, surrealdb::Error>>, surrealdb::Error> + { + db.listen::().await + } + + /// Get all unfinished jobs, ie newly created and in progress up two times + pub async fn get_unfinished_jobs(db: 
&SurrealDbClient) -> Result, AppError> { + let jobs: Vec = db + .query( + "SELECT * FROM type::table($table) + WHERE + status = 'Created' + OR ( + status.InProgress != NONE + AND status.InProgress.attempts < $max_attempts + ) + ORDER BY created_at ASC", + ) + .bind(("table", Self::table_name())) + .bind(("max_attempts", MAX_ATTEMPTS)) + .await? + .take(0)?; + + Ok(jobs) + } } diff --git a/crates/common/src/storage/types/user.rs b/crates/common/src/storage/types/user.rs index f928452..2bffd6c 100644 --- a/crates/common/src/storage/types/user.rs +++ b/crates/common/src/storage/types/user.rs @@ -1,14 +1,10 @@ -use crate::{ - error::AppError, - storage::db::{get_item, SurrealDbClient}, - stored_object, -}; +use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use axum_session_auth::Authentication; use surrealdb::{engine::any::Any, Surreal}; use uuid::Uuid; use super::{ - conversation::Conversation, knowledge_entity::KnowledgeEntity, + conversation::Conversation, job::Job, knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings, text_content::TextContent, }; @@ -32,7 +28,10 @@ stored_object!(User, "user", { impl Authentication> for User { async fn load_user(userid: String, db: Option<&Surreal>) -> Result { let db = db.unwrap(); - Ok(get_item::(db, userid.as_str()).await?.unwrap()) + Ok(db + .select((Self::table_name(), userid.as_str())) + .await? + .unwrap()) } fn is_authenticated(&self) -> bool { @@ -308,7 +307,8 @@ impl User { user_id: &str, db: &SurrealDbClient, ) -> Result { - let entity: KnowledgeEntity = get_item(db, &id) + let entity: KnowledgeEntity = db + .get_item(id) .await? .ok_or_else(|| AppError::NotFound("Entity not found".into()))?; @@ -324,7 +324,8 @@ impl User { user_id: &str, db: &SurrealDbClient, ) -> Result { - let text_content: TextContent = get_item(db, &id) + let text_content: TextContent = db + .get_item(id) .await? 
.ok_or_else(|| AppError::NotFound("Content not found".into()))?; @@ -349,4 +350,49 @@ impl User { Ok(conversations) } + + /// Gets all active jobs for the specified user + pub async fn get_unfinished_jobs( + user_id: &str, + db: &SurrealDbClient, + ) -> Result, AppError> { + let jobs: Vec = db + .query( + "SELECT * FROM type::table($table) + WHERE user_id = $user_id + AND ( + status = 'Created' + OR ( + status.InProgress != NONE + AND status.InProgress.attempts < $max_attempts + ) + ) + ORDER BY created_at DESC", + ) + .bind(("table", Job::table_name())) + .bind(("user_id", user_id.to_owned())) + .bind(("max_attempts", 3)) + .await? + .take(0)?; + + Ok(jobs) + } + + /// Validate and delete job + pub async fn validate_and_delete_job( + id: &str, + user_id: &str, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + db.get_item::(id) + .await? + .filter(|job| job.user_id == user_id) + .ok_or_else(|| AppError::Auth("Not authorized to delete this job".into()))?; + + db.delete_item::(id) + .await + .map_err(AppError::Database)?; + + Ok(()) + } } diff --git a/crates/html-router/src/html_state.rs b/crates/html-router/src/html_state.rs index 25d4d7d..1ce1f51 100644 --- a/crates/html-router/src/html_state.rs +++ b/crates/html-router/src/html_state.rs @@ -1,6 +1,5 @@ use axum_session::SessionStore; use axum_session_surreal::SessionSurrealPool; -use common::ingress::jobqueue::JobQueue; use common::storage::db::SurrealDbClient; use common::utils::config::AppConfig; use common::utils::mailer::Mailer; @@ -12,11 +11,10 @@ use surrealdb::engine::any::Any; #[derive(Clone)] pub struct HtmlState { - pub surreal_db_client: Arc, + pub db: Arc, pub openai_client: Arc>, pub templates: Arc, pub mailer: Arc, - pub job_queue: Arc, pub session_store: Arc>>, } @@ -51,7 +49,7 @@ impl HtmlState { let session_store = Arc::new(surreal_db_client.create_session_store().await?); let app_state = HtmlState { - surreal_db_client: surreal_db_client.clone(), + db: surreal_db_client.clone(), templates: 
Arc::new(reloader), openai_client: openai_client.clone(), mailer: Arc::new(Mailer::new( @@ -59,7 +57,6 @@ impl HtmlState { &config.smtp_relayer, &config.smtp_password, )?), - job_queue: Arc::new(JobQueue::new(surreal_db_client)), session_store, }; diff --git a/crates/html-router/src/lib.rs b/crates/html-router/src/lib.rs index 4154d7a..28c47f4 100644 --- a/crates/html-router/src/lib.rs +++ b/crates/html-router/src/lib.rs @@ -102,7 +102,7 @@ where .layer(from_fn_with_state(app_state.clone(), analytics_middleware)) .layer( AuthSessionLayer::, Surreal>::new(Some( - app_state.surreal_db_client.client.clone(), + app_state.db.client.clone(), )) .with_config(AuthConfig::::default()), ) diff --git a/crates/html-router/src/middleware_analytics.rs b/crates/html-router/src/middleware_analytics.rs index a76bd08..4337404 100644 --- a/crates/html-router/src/middleware_analytics.rs +++ b/crates/html-router/src/middleware_analytics.rs @@ -22,11 +22,11 @@ pub async fn analytics_middleware( // Only count if it's a main page request (not assets or other resources) if !path.starts_with("/assets") && !path.starts_with("/_next") && !path.contains('.') { if !session.get::("counted_visitor").unwrap_or(false) { - let _ = Analytics::increment_visitors(&state.surreal_db_client).await; + let _ = Analytics::increment_visitors(&state.db).await; session.set("counted_visitor", true); } - let _ = Analytics::increment_page_loads(&state.surreal_db_client).await; + let _ = Analytics::increment_page_loads(&state.db).await; } next.run(request).await diff --git a/crates/html-router/src/routes/account.rs b/crates/html-router/src/routes/account.rs index 1c5e0ae..65d73b8 100644 --- a/crates/html-router/src/routes/account.rs +++ b/crates/html-router/src/routes/account.rs @@ -12,7 +12,7 @@ use surrealdb::{engine::any::Any, Surreal}; use common::{ error::{AppError, HtmlError}, - storage::{db::delete_item, types::user::User}, + storage::types::user::User, }; use crate::{html_state::HtmlState, page_data}; @@ 
-56,7 +56,7 @@ pub async fn set_api_key( }; // Generate and set the API key - let api_key = User::set_api_key(&user.id, &state.surreal_db_client) + let api_key = User::set_api_key(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -92,7 +92,9 @@ pub async fn delete_account( None => return Ok(Redirect::to("/").into_response()), }; - delete_item::(&state.surreal_db_client, &user.id) + state + .db + .delete_item::(&user.id) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; @@ -118,7 +120,7 @@ pub async fn update_timezone( None => return Ok(Redirect::to("/").into_response()), }; - User::update_timezone(&user.id, &form.timezone, &state.surreal_db_client) + User::update_timezone(&user.id, &form.timezone, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/html-router/src/routes/admin_panel.rs b/crates/html-router/src/routes/admin_panel.rs index 67d996b..f6b42f8 100644 --- a/crates/html-router/src/routes/admin_panel.rs +++ b/crates/html-router/src/routes/admin_panel.rs @@ -33,15 +33,15 @@ pub async fn show_admin_panel( _ => return Ok(Redirect::to("/").into_response()), }; - let settings = SystemSettings::get_current(&state.surreal_db_client) + let settings = SystemSettings::get_current(&state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let analytics = Analytics::get_current(&state.surreal_db_client) + let analytics = Analytics::get_current(&state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let users_count = Analytics::get_users_amount(&state.surreal_db_client) + let users_count = Analytics::get_users_amount(&state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -92,7 +92,7 @@ pub async fn toggle_registration_status( _ => return Ok(Redirect::to("/").into_response()), }; - let current_settings = SystemSettings::get_current(&state.surreal_db_client) + let current_settings = 
SystemSettings::get_current(&state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -101,7 +101,7 @@ pub async fn toggle_registration_status( ..current_settings.clone() }; - SystemSettings::update(&state.surreal_db_client, new_settings.clone()) + SystemSettings::update(&state.db, new_settings.clone()) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/html-router/src/routes/chat/message_response_stream.rs b/crates/html-router/src/routes/chat/message_response_stream.rs index 1e89881..b6ab5b0 100644 --- a/crates/html-router/src/routes/chat/message_response_stream.rs +++ b/crates/html-router/src/routes/chat/message_response_stream.rs @@ -29,7 +29,7 @@ use common::{ }, }, storage::{ - db::{get_item, store_item, SurrealDbClient}, + db::SurrealDbClient, types::{ message::{Message, MessageRole}, user::User, @@ -64,7 +64,7 @@ async fn get_message_and_user( }; // Retrieve message - let message = match get_item::(db, message_id).await { + let message = match db.get_item::(message_id).await { Ok(Some(message)) => message, Ok(None) => { return Err(Sse::new(create_error_stream( @@ -93,20 +93,15 @@ pub async fn get_response_stream( Query(params): Query, ) -> Sse> + Send>>> { // 1. Authentication and initial data validation - let (user_message, user) = match get_message_and_user( - &state.surreal_db_client, - auth.current_user, - ¶ms.message_id, - ) - .await - { - Ok((user_message, user)) => (user_message, user), - Err(error_stream) => return error_stream, - }; + let (user_message, user) = + match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await { + Ok((user_message, user)) => (user_message, user), + Err(error_stream) => return error_stream, + }; // 2. 
Retrieve knowledge entities let entities = match combined_knowledge_entity_retrieval( - &state.surreal_db_client, + &state.db, &state.openai_client, &user_message.content, &user.id, @@ -143,7 +138,7 @@ pub async fn get_response_stream( let (tx_final, mut rx_final) = channel::(1); // 6. Set up the collection task for DB storage - let db_client = state.surreal_db_client.clone(); + let db_client = state.db.clone(); tokio::spawn(async move { drop(tx); // Close sender when no longer needed @@ -170,7 +165,7 @@ pub async fn get_response_stream( let _ = tx_final.send(ai_message.clone()).await; - match store_item(&db_client, ai_message).await { + match db_client.store_item(ai_message).await { Ok(_) => info!("Successfully stored AI message with references"), Err(e) => error!("Failed to store AI message: {:?}", e), } @@ -185,7 +180,7 @@ pub async fn get_response_stream( None, ); - let _ = store_item(&db_client, ai_message).await; + let _ = db_client.store_item(ai_message).await; } }); diff --git a/crates/html-router/src/routes/chat/mod.rs b/crates/html-router/src/routes/chat/mod.rs index 906fd3b..f8c56aa 100644 --- a/crates/html-router/src/routes/chat/mod.rs +++ b/crates/html-router/src/routes/chat/mod.rs @@ -14,19 +14,15 @@ use tracing::info; use common::{ error::{AppError, HtmlError}, - storage::{ - db::{get_item, store_item}, - types::{ - conversation::Conversation, - message::{Message, MessageRole}, - user::User, - }, + storage::types::{ + conversation::Conversation, + message::{Message, MessageRole}, + user::User, }, }; use crate::{html_state::HtmlState, page_data, routes::render_template}; -// Update your ChatStartParams struct to properly deserialize the references #[derive(Debug, Deserialize)] pub struct ChatStartParams { user_query: String, @@ -80,9 +76,9 @@ pub async fn show_initialized_chat( ); let (conversation_result, ai_message_result, user_message_result) = futures::join!( - store_item(&state.surreal_db_client, conversation.clone()), - 
store_item(&state.surreal_db_client, ai_message.clone()), - store_item(&state.surreal_db_client, user_message.clone()) + state.db.store_item(conversation.clone()), + state.db.store_item(ai_message.clone()), + state.db.store_item(user_message.clone()) ); // Check each result individually @@ -90,7 +86,7 @@ pub async fn show_initialized_chat( user_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; ai_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; - let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client) + let conversation_archive = User::get_user_conversations(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -126,7 +122,7 @@ pub async fn show_chat_base( None => return Ok(Redirect::to("/").into_response()), }; - let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client) + let conversation_archive = User::get_user_conversations(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -161,17 +157,14 @@ pub async fn show_existing_chat( None => return Ok(Redirect::to("/").into_response()), }; - let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client) + let conversation_archive = User::get_user_conversations(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let (conversation, messages) = Conversation::get_complete_conversation( - conversation_id.as_str(), - &user.id, - &state.surreal_db_client, - ) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + let (conversation, messages) = + Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db) + .await + .map_err(|e| HtmlError::new(e, state.templates.clone()))?; let output = render_template( ChatData::template_name(), @@ -198,7 +191,9 @@ pub async fn new_user_message( None => return 
Ok(Redirect::to("/").into_response()), }; - let conversation: Conversation = get_item(&state.surreal_db_client, &conversation_id) + let conversation: Conversation = state + .db + .get_item(&conversation_id) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))? .ok_or_else(|| { @@ -217,7 +212,9 @@ pub async fn new_user_message( let user_message = Message::new(conversation_id, MessageRole::User, form.content, None); - store_item(&state.surreal_db_client, user_message.clone()) + state + .db + .store_item(user_message.clone()) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; @@ -258,10 +255,14 @@ pub async fn new_chat_user_message( None, ); - store_item(&state.surreal_db_client, conversation.clone()) + state + .db + .store_item(conversation.clone()) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; - store_item(&state.surreal_db_client, user_message.clone()) + state + .db + .store_item(user_message.clone()) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; diff --git a/crates/html-router/src/routes/chat/references.rs b/crates/html-router/src/routes/chat/references.rs index 15b4ded..6f5695c 100644 --- a/crates/html-router/src/routes/chat/references.rs +++ b/crates/html-router/src/routes/chat/references.rs @@ -10,10 +10,7 @@ use tracing::info; use common::{ error::{AppError, HtmlError}, - storage::{ - db::get_item, - types::{knowledge_entity::KnowledgeEntity, user::User}, - }, + storage::types::{knowledge_entity::KnowledgeEntity, user::User}, }; use crate::{html_state::HtmlState, routes::render_template}; @@ -30,7 +27,9 @@ pub async fn show_reference_tooltip( None => return Ok(Redirect::to("/").into_response()), }; - let entity: KnowledgeEntity = get_item(&state.surreal_db_client, &reference_id) + let entity: KnowledgeEntity = state + .db + .get_item(&reference_id) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))? 
.ok_or_else(|| { diff --git a/crates/html-router/src/routes/content/mod.rs b/crates/html-router/src/routes/content/mod.rs index 831c721..85eebce 100644 --- a/crates/html-router/src/routes/content/mod.rs +++ b/crates/html-router/src/routes/content/mod.rs @@ -30,7 +30,7 @@ pub async fn show_content_page( None => return Ok(Redirect::to("/signin").into_response()), }; - let text_contents = User::get_text_contents(&user.id, &state.surreal_db_client) + let text_contents = User::get_text_contents(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -63,7 +63,7 @@ pub async fn show_text_content_edit_form( None => return Ok(Redirect::to("/signin").into_response()), }; - let text_content = User::get_and_validate_text_content(&id, &user.id, &state.surreal_db_client) + let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -87,11 +87,13 @@ pub async fn patch_text_content( None => return Ok(Redirect::to("/signin").into_response()), }; - let text_content = User::get_and_validate_text_content(&id, &user.id, &state.surreal_db_client) + let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let text_contents = User::get_text_contents(&user.id, &state.surreal_db_client) + // ADD FUNCTION TO PATCH CONTENT + + let text_contents = User::get_text_contents(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/html-router/src/routes/index.rs b/crates/html-router/src/routes/index.rs index eebc070..a0f0f55 100644 --- a/crates/html-router/src/routes/index.rs +++ b/crates/html-router/src/routes/index.rs @@ -11,13 +11,10 @@ use tracing::info; use common::{ error::{AppError, HtmlError}, - storage::{ - db::{delete_item, get_item}, - types::{ - file_info::FileInfo, job::Job, knowledge_entity::KnowledgeEntity, - 
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, - text_content::TextContent, user::User, - }, + storage::types::{ + file_info::FileInfo, job::Job, knowledge_entity::KnowledgeEntity, + knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, + text_content::TextContent, user::User, }, }; @@ -42,9 +39,7 @@ pub async fn index_handler( let gdpr_accepted = auth.current_user.is_some() | session.get("gdpr_accepted").unwrap_or(false); let active_jobs = match auth.current_user.is_some() { - true => state - .job_queue - .get_unfinished_user_jobs(&auth.current_user.clone().unwrap().id) + true => User::get_unfinished_jobs(&auth.current_user.clone().unwrap().id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?, false => vec![], @@ -53,7 +48,7 @@ pub async fn index_handler( let latest_text_contents = match auth.current_user.clone().is_some() { true => User::get_latest_text_contents( auth.current_user.clone().unwrap().id.as_str(), - &state.surreal_db_client, + &state.db, ) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?, @@ -63,7 +58,7 @@ pub async fn index_handler( // let latest_knowledge_entities = match auth.current_user.is_some() { // true => User::get_latest_knowledge_entities( // auth.current_user.clone().unwrap().id.as_str(), - // &state.surreal_db_client, + // &state.db, // ) // .await // .map_err(|e| HtmlError::new(e, state.templates.clone()))?, @@ -107,18 +102,15 @@ pub async fn delete_text_content( let deletion_tasks = join!( async { if let Some(file_info) = text_content.file_info { - FileInfo::delete_by_id(&file_info.id, &state.surreal_db_client).await + FileInfo::delete_by_id(&file_info.id, &state.db).await } else { Ok(()) } }, - delete_item::(&state.surreal_db_client, &text_content.id), - TextChunk::delete_by_source_id(&text_content.id, &state.surreal_db_client), - KnowledgeEntity::delete_by_source_id(&text_content.id, &state.surreal_db_client), - 
KnowledgeRelationship::delete_relationships_by_source_id( - &text_content.id, - &state.surreal_db_client - ) + state.db.delete_item::(&text_content.id), + TextChunk::delete_by_source_id(&text_content.id, &state.db), + KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db), + KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db) ); // Handle potential errors from concurrent operations @@ -133,7 +125,7 @@ pub async fn delete_text_content( } // Render updated content - let latest_text_contents = User::get_latest_text_contents(&user.id, &state.surreal_db_client) + let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -156,7 +148,9 @@ async fn get_and_validate_text_content( id: &str, user: &User, ) -> Result { - let text_content = get_item::(&state.surreal_db_client, id) + let text_content = state + .db + .get_item::(id) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))? 
.ok_or_else(|| { @@ -192,15 +186,11 @@ pub async fn delete_job( None => return Ok(Redirect::to("/signin").into_response()), }; - state - .job_queue - .delete_job(&id, &user.id) + User::validate_and_delete_job(&id, &user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let active_jobs = state - .job_queue - .get_unfinished_user_jobs(&user.id) + let active_jobs = User::get_unfinished_jobs(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -226,9 +216,7 @@ pub async fn show_active_jobs( None => return Ok(Redirect::to("/signin").into_response()), }; - let active_jobs = state - .job_queue - .get_unfinished_user_jobs(&user.id) + let active_jobs = User::get_unfinished_jobs(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/html-router/src/routes/ingress_form.rs b/crates/html-router/src/routes/ingress_form.rs index 1bdd60f..43d0204 100644 --- a/crates/html-router/src/routes/ingress_form.rs +++ b/crates/html-router/src/routes/ingress_form.rs @@ -13,7 +13,7 @@ use tracing::info; use common::{ error::{AppError, HtmlError, IntoHtmlError}, ingress::ingress_input::{create_ingress_objects, IngressInput}, - storage::types::{file_info::FileInfo, user::User}, + storage::types::{file_info::FileInfo, job::Job, user::User}, }; use crate::{ @@ -37,7 +37,7 @@ pub async fn show_ingress_form( return Ok(Redirect::to("/").into_response()); } - let user_categories = User::get_user_categories(&auth.id, &state.surreal_db_client) + let user_categories = User::get_user_categories(&auth.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -107,7 +107,7 @@ pub async fn process_ingress_form( info!("{:?}", input); let file_infos = try_join_all(input.files.into_iter().map(|file| { - FileInfo::new(file, &state.surreal_db_client, &user.id) + FileInfo::new(file, &state.db, &user.id) .map_err(|e| HtmlError::new(AppError::from(e), 
state.templates.clone())) })) .await?; @@ -125,7 +125,7 @@ pub async fn process_ingress_form( let futures: Vec<_> = ingress_objects .into_iter() - .map(|object| state.job_queue.enqueue(object.clone(), user.id.clone())) + .map(|object| Job::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)) .collect(); try_join_all(futures) @@ -134,9 +134,7 @@ pub async fn process_ingress_form( .map_err(|e| HtmlError::new(e, state.templates.clone()))?; // Update the active jobs page with the newly created job - let active_jobs = state - .job_queue - .get_unfinished_user_jobs(&user.id) + let active_jobs = User::get_unfinished_jobs(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/html-router/src/routes/knowledge/mod.rs b/crates/html-router/src/routes/knowledge/mod.rs index b82b872..40f4293 100644 --- a/crates/html-router/src/routes/knowledge/mod.rs +++ b/crates/html-router/src/routes/knowledge/mod.rs @@ -15,13 +15,10 @@ use tracing::info; use common::{ error::{AppError, HtmlError}, - storage::{ - db::delete_item, - types::{ - knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, - knowledge_relationship::KnowledgeRelationship, - user::User, - }, + storage::types::{ + knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, + knowledge_relationship::KnowledgeRelationship, + user::User, }, }; @@ -44,13 +41,13 @@ pub async fn show_knowledge_page( None => return Ok(Redirect::to("/signin").into_response()), }; - let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client) + let entities = User::get_knowledge_entities(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; info!("Got entities ok"); - let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client) + let relationships = User::get_knowledge_relationships(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -169,7 +166,7 @@ pub async fn 
show_edit_knowledge_entity_form( .collect(); // Get the entity and validate ownership - let entity = User::get_and_validate_knowledge_entity(&id, &user.id, &state.surreal_db_client) + let entity = User::get_and_validate_knowledge_entity(&id, &user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -212,7 +209,7 @@ pub async fn patch_knowledge_entity( }; // Get the existing entity and validate that the user is allowed - User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.surreal_db_client) + User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -224,14 +221,14 @@ pub async fn patch_knowledge_entity( &form.name, &form.description, &entity_type, - &state.surreal_db_client, + &state.db, &state.openai_client, ) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; // Get updated list of entities - let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client) + let entities = User::get_knowledge_entities(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -257,17 +254,19 @@ pub async fn delete_knowledge_entity( }; // Get the existing entity and validate that the user is allowed - User::get_and_validate_knowledge_entity(&id, &user.id, &state.surreal_db_client) + User::get_and_validate_knowledge_entity(&id, &user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; // Delete the entity - delete_item::(&state.surreal_db_client, &id) + state + .db + .delete_item::(&id) .await .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; // Get updated list of entities - let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client) + let entities = User::get_knowledge_entities(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -300,15 +299,15 @@ pub async fn 
delete_knowledge_relationship( // GOTTA ADD AUTH VALIDATION - KnowledgeRelationship::delete_relationship_by_id(&id, &state.surreal_db_client) + KnowledgeRelationship::delete_relationship_by_id(&id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client) + let entities = User::get_knowledge_entities(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client) + let relationships = User::get_knowledge_relationships(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; @@ -353,15 +352,15 @@ pub async fn save_knowledge_relationship( ); relationship - .store_relationship(&state.surreal_db_client) + .store_relationship(&state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client) + let entities = User::get_knowledge_entities(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client) + let relationships = User::get_knowledge_relationships(&user.id, &state.db) .await .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/html-router/src/routes/signin.rs b/crates/html-router/src/routes/signin.rs index ca6b3ed..24c8cdf 100644 --- a/crates/html-router/src/routes/signin.rs +++ b/crates/html-router/src/routes/signin.rs @@ -54,7 +54,7 @@ pub async fn authenticate_user( auth: AuthSession, Surreal>, Form(form): Form, ) -> Result { - let user = match User::authenticate(form.email, form.password, &state.surreal_db_client).await { + let user = match User::authenticate(form.email, form.password, &state.db).await { Ok(user) => user, Err(_) => { return Ok(Html("

Incorrect email or password

").into_response()); diff --git a/crates/html-router/src/routes/signup.rs b/crates/html-router/src/routes/signup.rs index cf851c1..67993b4 100644 --- a/crates/html-router/src/routes/signup.rs +++ b/crates/html-router/src/routes/signup.rs @@ -44,14 +44,7 @@ pub async fn process_signup_and_show_verification( auth: AuthSession, Surreal>, Form(form): Form, ) -> Result { - let user = match User::create_new( - form.email, - form.password, - &state.surreal_db_client, - form.timezone, - ) - .await - { + let user = match User::create_new(form.email, form.password, &state.db, form.timezone).await { Ok(user) => user, Err(e) => { tracing::error!("{:?}", e); diff --git a/crates/main/src/server.rs b/crates/main/src/server.rs index 7baa6db..2ebadbc 100644 --- a/crates/main/src/server.rs +++ b/crates/main/src/server.rs @@ -20,8 +20,7 @@ async fn main() -> Result<(), Box> { // Set up router states let html_state = HtmlState::new(&config).await?; let api_state = ApiState { - surreal_db_client: html_state.surreal_db_client.clone(), - job_queue: html_state.job_queue.clone(), + db: html_state.db.clone(), }; // Create Axum router diff --git a/crates/main/src/worker.rs b/crates/main/src/worker.rs index 88225a3..f115953 100644 --- a/crates/main/src/worker.rs +++ b/crates/main/src/worker.rs @@ -1,12 +1,9 @@ use std::sync::Arc; use common::{ - ingress::{ - content_processor::ContentProcessor, - jobqueue::{JobQueue, MAX_ATTEMPTS}, - }, + ingress::content_processor::ContentProcessor, storage::{ - db::{get_item, SurrealDbClient}, + db::SurrealDbClient, types::job::{Job, JobStatus}, }, utils::config::get_config, @@ -40,27 +37,24 @@ async fn main() -> Result<(), Box> { let openai_client = Arc::new(async_openai::Client::new()); - let job_queue = JobQueue::new(surreal_db_client.clone()); - - let content_processor = ContentProcessor::new(surreal_db_client, openai_client.clone()).await?; + let content_processor = + ContentProcessor::new(surreal_db_client.clone(), openai_client.clone()).await?; loop 
{ // First, check for any unfinished jobs - let unfinished_jobs = job_queue.get_unfinished_jobs().await?; + let unfinished_jobs = Job::get_unfinished_jobs(&surreal_db_client).await?; if !unfinished_jobs.is_empty() { info!("Found {} unfinished jobs", unfinished_jobs.len()); for job in unfinished_jobs { - job_queue - .process_job(job, &content_processor, openai_client.clone()) - .await?; + content_processor.process_job(job).await?; } } // If no unfinished jobs, start listening for new ones info!("Listening for new jobs..."); - let mut job_stream = job_queue.listen_for_jobs().await?; + let mut job_stream = Job::listen_for_jobs(&surreal_db_client).await?; while let Some(notification) = job_stream.next().await { match notification { @@ -69,14 +63,7 @@ async fn main() -> Result<(), Box> { match notification.action { Action::Create => { - if let Err(e) = job_queue - .process_job( - notification.data, - &content_processor, - openai_client.clone(), - ) - .await - { + if let Err(e) = content_processor.process_job(notification.data).await { error!("Error processing job: {}", e); } } @@ -93,20 +80,18 @@ async fn main() -> Result<(), Box> { } JobStatus::InProgress { attempts, .. 
{ // Only process if this is a retry after an error, not our own update - if let Ok(Some(current_job)) = - get_item::(&job_queue.db.client, &notification.data.id) - .await + if let Ok(Some(current_job)) = surreal_db_client + .get_item::(&notification.data.id) + .await { match current_job.status { - JobStatus::Error(_) if attempts < MAX_ATTEMPTS => { + JobStatus::Error(_) + if attempts + < common::storage::types::job::MAX_ATTEMPTS => + { // This is a retry after an error - if let Err(e) = job_queue - .process_job( - current_job, - &content_processor, - openai_client.clone(), - ) - .await + if let Err(e) = + content_processor.process_job(current_job).await { error!("Error processing job retry: {}", e); } @@ -123,13 +108,8 @@ async fn main() -> Result<(), Box> { } JobStatus::Created => { // Shouldn't happen with Update action, but process if it does - if let Err(e) = job_queue - .process_job( - notification.data, - &content_processor, - openai_client.clone(), - ) - .await + if let Err(e) = + content_processor.process_job(notification.data).await { error!("Error processing job: {}", e); } diff --git a/todo.md b/todo.md index 5912488..071dde9 100644 --- a/todo.md +++ b/todo.md @@ -1,8 +1,9 @@ +\[\] fix patch_text_conent \[\] archive ingressed webpage \[x\] chat styling overhaul \[\] configs primarily get envs \[\] filtering on categories -\[\] link to ingressed urls or archives +\[x\] link to ingressed urls or archives \[\] three js graph explorer \[\] three js vector explorer \[x\] add user_id to ingress objects