diff --git a/crates/api-router/src/routes/ingress.rs b/crates/api-router/src/routes/ingress.rs index 34a5bae..44fb72f 100644 --- a/crates/api-router/src/routes/ingress.rs +++ b/crates/api-router/src/routes/ingress.rs @@ -2,7 +2,7 @@ use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use common::{ error::{ApiError, AppError}, - ingress::ingress_input::{create_ingress_objects, IngressInput}, + ingress::ingress_object::IngressObject, storage::types::{file_info::FileInfo, job::Job, user::User}, }; use futures::{future::try_join_all, TryFutureExt}; @@ -38,13 +38,11 @@ pub async fn ingress_data( debug!("Got file infos"); - let ingress_objects = create_ingress_objects( - IngressInput { - content: input.content, - instructions: input.instructions, - category: input.category, - files: file_infos, - }, + let ingress_objects = IngressObject::create_ingress_objects( + input.content, + input.instructions, + input.category, + file_infos, user.id.as_str(), )?; debug!("Got ingress objects"); diff --git a/crates/common/src/ingress/analysis/ingress_analyser.rs b/crates/common/src/ingress/analysis/ingress_analyser.rs index 7647606..4a3cb37 100644 --- a/crates/common/src/ingress/analysis/ingress_analyser.rs +++ b/crates/common/src/ingress/analysis/ingress_analyser.rs @@ -2,7 +2,7 @@ use crate::{ error::AppError, ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, retrieval::combined_knowledge_entity_retrieval, - storage::types::knowledge_entity::KnowledgeEntity, + storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, }; use async_openai::{ error::OpenAIError, @@ -13,20 +13,18 @@ use async_openai::{ }, }; use serde_json::json; -use surrealdb::engine::any::Any; -use surrealdb::Surreal; use tracing::debug; use super::types::llm_analysis_result::LLMGraphAnalysisResult; pub struct IngressAnalyzer<'a> { - db_client: &'a Surreal, + db_client: &'a SurrealDbClient, openai_client: &'a async_openai::Client, } impl<'a> IngressAnalyzer<'a> { pub fn new( - db_client: &'a Surreal, + db_client: &'a SurrealDbClient, openai_client: &'a async_openai::Client, ) -> Self { Self { diff --git a/crates/common/src/ingress/ingress_input.rs b/crates/common/src/ingress/ingress_input.rs deleted file mode 100644 index a122f92..0000000 --- a/crates/common/src/ingress/ingress_input.rs +++ /dev/null @@ -1,74 +0,0 @@ -use super::ingress_object::IngressObject; -use crate::{error::AppError, storage::types::file_info::FileInfo}; -use serde::{Deserialize, Serialize}; -use tracing::info; -use url::Url; - -/// Struct defining the expected body when ingressing content. -#[derive(Serialize, Deserialize, Debug)] -pub struct IngressInput { - pub content: Option, - pub instructions: String, - pub category: String, - pub files: Vec, -} - -/// Function to create ingress objects from input. -/// -/// # Arguments -/// * `input` - IngressInput containing information needed to ingress content. -/// * `user_id` - User id of the ingressing user -/// -/// # Returns -/// * `Vec` - An array containing the ingressed objects, one file/contenttype per object. -pub fn create_ingress_objects( - input: IngressInput, - user_id: &str, -) -> Result, AppError> { - // Initialize list - let mut object_list = Vec::new(); - - // Create a IngressObject from input.content if it exists, checking for URL or text - if let Some(input_content) = input.content { - match Url::parse(&input_content) { - Ok(url) => { - info!("Detected URL: {}", url); - object_list.push(IngressObject::Url { - url: url.to_string(), - instructions: input.instructions.clone(), - category: input.category.clone(), - user_id: user_id.into(), - }); - } - Err(_) => { - if input_content.len() > 2 { - info!("Treating input as plain text"); - object_list.push(IngressObject::Text { - text: input_content.to_string(), - instructions: input.instructions.clone(), - category: input.category.clone(), - user_id: user_id.into(), - }); - } - } - } - } - - for file in input.files { - object_list.push(IngressObject::File { - file_info: file, - instructions: input.instructions.clone(), - category: input.category.clone(), - user_id: user_id.into(), - }) - } - - // If no objects are constructed, we return Err - if object_list.is_empty() { - return Err(AppError::NotFound( - "No valid content or files provided".into(), - )); - } - - Ok(object_list) -} diff --git a/crates/common/src/ingress/ingress_object.rs b/crates/common/src/ingress/ingress_object.rs index 96cd140..7ca9847 100644 --- a/crates/common/src/ingress/ingress_object.rs +++ b/crates/common/src/ingress/ingress_object.rs @@ -13,6 +13,8 @@ use scraper::{Html, Selector}; use serde::{Deserialize, Serialize}; use std::fmt::Write; use tiktoken_rs::{o200k_base, CoreBPE}; +use tracing::info; +use url::Url; #[derive(Debug, Serialize, Deserialize, Clone)] pub enum IngressObject { @@ -37,6 +39,72 @@ pub enum IngressObject { } impl IngressObject { + /// Creates ingress objects from the provided content, instructions, and files. + /// + /// # Arguments + /// * `content` - Optional textual content to be ingressed + /// * `instructions` - Instructions for processing the ingress content + /// * `category` - Category to classify the ingressed content + /// * `files` - Vector of `FileInfo` objects containing information about uploaded files + /// * `user_id` - Identifier of the user performing the ingress operation + /// + /// # Returns + /// * `Result, AppError>` - On success, returns a vector of ingress objects + /// (one per file/content type). On failure, returns an `AppError`. + pub fn create_ingress_objects( + content: Option, + instructions: String, + category: String, + files: Vec, + user_id: &str, + ) -> Result, AppError> { + // Initialize list + let mut object_list = Vec::new(); + + // Create a IngressObject from content if it exists, checking for URL or text + if let Some(input_content) = content { + match Url::parse(&input_content) { + Ok(url) => { + info!("Detected URL: {}", url); + object_list.push(IngressObject::Url { + url: url.to_string(), + instructions: instructions.clone(), + category: category.clone(), + user_id: user_id.into(), + }); + } + Err(_) => { + if input_content.len() > 2 { + info!("Treating input as plain text"); + object_list.push(IngressObject::Text { + text: input_content.to_string(), + instructions: instructions.clone(), + category: category.clone(), + user_id: user_id.into(), + }); + } + } + } + } + + for file in files { + object_list.push(IngressObject::File { + file_info: file, + instructions: instructions.clone(), + category: category.clone(), + user_id: user_id.into(), + }) + } + + // If no objects are constructed, we return Err + if object_list.is_empty() { + return Err(AppError::NotFound( + "No valid content or files provided".into(), + )); + } + + Ok(object_list) + } /// Creates a new `TextContent` instance from a `IngressObject`. /// /// # Arguments diff --git a/crates/common/src/ingress/mod.rs b/crates/common/src/ingress/mod.rs index 34fee9c..646645f 100644 --- a/crates/common/src/ingress/mod.rs +++ b/crates/common/src/ingress/mod.rs @@ -1,4 +1,3 @@ pub mod analysis; pub mod content_processor; -pub mod ingress_input; pub mod ingress_object; diff --git a/crates/common/src/retrieval/graph.rs b/crates/common/src/retrieval/graph.rs index 4a61bc3..99388b4 100644 --- a/crates/common/src/retrieval/graph.rs +++ b/crates/common/src/retrieval/graph.rs @@ -1,7 +1,7 @@ -use surrealdb::{engine::any::Any, Error, Surreal}; +use surrealdb::Error; use tracing::debug; -use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject}; +use crate::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}; /// Retrieves database entries that match a specific source identifier. /// @@ -33,15 +33,14 @@ use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject}; pub async fn find_entities_by_source_ids( source_id: Vec, table_name: String, - db_client: &Surreal, + db: &SurrealDbClient, ) -> Result, Error> where T: for<'de> serde::Deserialize<'de>, { let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids"; - db_client - .query(query) + db.query(query) .bind(("table", table_name)) .bind(("source_ids", source_id)) .await? @@ -50,7 +49,7 @@ where /// Find entities by their relationship to the id pub async fn find_entities_by_relationship_by_id( - db_client: &Surreal, + db: &SurrealDbClient, entity_id: String, ) -> Result, Error> { let query = format!( @@ -60,15 +59,5 @@ pub async fn find_entities_by_relationship_by_id( debug!("{}", query); - db_client.query(query).await?.take(0) -} - -/// Get a specific KnowledgeEntity by its id -pub async fn get_entity_by_id( - db_client: &Surreal, - entity_id: &str, -) -> Result, Error> { - db_client - .select((KnowledgeEntity::table_name(), entity_id)) - .await + db.query(query).await?.take(0) } diff --git a/crates/common/src/retrieval/mod.rs b/crates/common/src/retrieval/mod.rs index f14e447..863ab49 100644 --- a/crates/common/src/retrieval/mod.rs +++ b/crates/common/src/retrieval/mod.rs @@ -9,11 +9,13 @@ use crate::{ graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, vector::find_items_by_vector_similarity, }, - storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, + storage::{ + db::SurrealDbClient, + types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, + }, }; use futures::future::{try_join, try_join_all}; use std::collections::HashMap; -use surrealdb::{engine::any::Any, Surreal}; /// Performs a comprehensive knowledge entity retrieval using multiple search strategies /// to find the most relevant entities for a given query. @@ -37,7 +39,7 @@ use surrealdb::{engine::any::Any, Surreal}; /// * `Result, AppError>` - A deduplicated vector of relevant /// knowledge entities, or an error if the retrieval process fails pub async fn combined_knowledge_entity_retrieval( - db_client: &Surreal, + db_client: &SurrealDbClient, openai_client: &async_openai::Client, query: &str, user_id: &str, diff --git a/crates/html-router/src/routes/ingress_form.rs b/crates/html-router/src/routes/ingress_form.rs index 43d0204..5a1736b 100644 --- a/crates/html-router/src/routes/ingress_form.rs +++ b/crates/html-router/src/routes/ingress_form.rs @@ -12,7 +12,7 @@ use tracing::info; use common::{ error::{AppError, HtmlError, IntoHtmlError}, - ingress::ingress_input::{create_ingress_objects, IngressInput}, + ingress::ingress_object::IngressObject, storage::types::{file_info::FileInfo, job::Job, user::User}, }; @@ -112,13 +112,11 @@ pub async fn process_ingress_form( })) .await?; - let ingress_objects = create_ingress_objects( - IngressInput { - content: input.content, - instructions: input.instructions, - category: input.category, - files: file_infos, - }, + let ingress_objects = IngressObject::create_ingress_objects( + input.content, + input.instructions, + input.category, + file_infos, user.id.as_str(), ) .map_err(|e| HtmlError::new(e, state.templates.clone()))?; diff --git a/crates/main/src/worker.rs b/crates/main/src/worker.rs index f115953..04cc7b5 100644 --- a/crates/main/src/worker.rs +++ b/crates/main/src/worker.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), Box> { let config = get_config()?; - let surreal_db_client = Arc::new( + let db = Arc::new( SurrealDbClient::new( &config.surrealdb_address, &config.surrealdb_username, @@ -37,12 +37,11 @@ async fn main() -> Result<(), Box> { let openai_client = Arc::new(async_openai::Client::new()); - let content_processor = - ContentProcessor::new(surreal_db_client.clone(), openai_client.clone()).await?; + let content_processor = ContentProcessor::new(db.clone(), openai_client.clone()).await?; loop { // First, check for any unfinished jobs - let unfinished_jobs = Job::get_unfinished_jobs(&surreal_db_client).await?; + let unfinished_jobs = Job::get_unfinished_jobs(&db).await?; if !unfinished_jobs.is_empty() { info!("Found {} unfinished jobs", unfinished_jobs.len()); @@ -54,7 +53,7 @@ async fn main() -> Result<(), Box> { // If no unfinished jobs, start listening for new ones info!("Listening for new jobs..."); - let mut job_stream = Job::listen_for_jobs(&surreal_db_client).await?; + let mut job_stream = Job::listen_for_jobs(&db).await?; while let Some(notification) = job_stream.next().await { match notification { @@ -80,9 +79,8 @@ async fn main() -> Result<(), Box> { } JobStatus::InProgress { attempts, .. } => { // Only process if this is a retry after an error, not our own update - if let Ok(Some(current_job)) = surreal_db_client - .get_item::(¬ification.data.id) - .await + if let Ok(Some(current_job)) = + db.get_item::(¬ification.data.id).await { match current_job.status { JobStatus::Error(_)