diff --git a/src/analysis/mod.rs b/src/analysis/mod.rs deleted file mode 100644 index d8d3864..0000000 --- a/src/analysis/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod ingress; diff --git a/src/error.rs b/src/error.rs index b25cb88..eca465b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,8 @@ use async_openai::error::OpenAIError; use thiserror::Error; use tokio::task::JoinError; +use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError}; + /// Error types for processing `TextContent`. #[derive(Error, Debug)] pub enum ProcessingError { @@ -23,3 +25,15 @@ pub enum ProcessingError { #[error("Task join error: {0}")] JoinError(#[from] JoinError), } + +#[derive(Error, Debug)] +pub enum IngressConsumerError { + #[error("RabbitMQ error: {0}")] + RabbitMQ(#[from] RabbitMQError), + + #[error("Processing error: {0}")] + Processing(#[from] ProcessingError), + + #[error("Ingress content error: {0}")] + IngressContent(#[from] IngressContentError), +} diff --git a/src/analysis/ingress/ingress_analyser.rs b/src/ingress/analysis/ingress_analyser.rs similarity index 92% rename from src/analysis/ingress/ingress_analyser.rs rename to src/ingress/analysis/ingress_analyser.rs index cbf65d1..aae0bde 100644 --- a/src/analysis/ingress/ingress_analyser.rs +++ b/src/ingress/analysis/ingress_analyser.rs @@ -1,9 +1,6 @@ use crate::{ - analysis::ingress::{ - prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, - types::llm_analysis_result::LLMGraphAnalysisResult, - }, error::ProcessingError, + ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, retrieval::vector::find_items_by_vector_similarity, storage::types::{knowledge_entity::KnowledgeEntity, StoredObject}, }; @@ -15,7 +12,9 @@ use async_openai::types::{ use serde_json::json; use surrealdb::engine::remote::ws::Client; use surrealdb::Surreal; -use tracing::{debug, instrument}; +use tracing::debug; + +use super::types::llm_analysis_result::LLMGraphAnalysisResult; pub struct IngressAnalyzer<'a> { db_client: &'a Surreal, @@ -33,7 +32,6 @@ impl<'a> IngressAnalyzer<'a> { } } - #[instrument(skip(self))] pub async fn analyze_content( &self, category: &str, @@ -48,7 +46,6 @@ impl<'a> IngressAnalyzer<'a> { self.perform_analysis(llm_request).await } - #[instrument(skip(self))] async fn find_similar_entities( &self, category: &str, @@ -70,7 +67,6 @@ impl<'a> IngressAnalyzer<'a> { .await } - #[instrument(skip(self))] fn prepare_llm_request( &self, category: &str, @@ -108,7 +104,7 @@ impl<'a> IngressAnalyzer<'a> { }; CreateChatCompletionRequestArgs::default() - .model("gpt-4-mini") + .model("gpt-4o-mini") .temperature(0.2) .max_tokens(2048u32) .messages([ @@ -120,7 +116,6 @@ impl<'a> IngressAnalyzer<'a> { .map_err(|e| ProcessingError::LLMParsingError(e.to_string())) } - #[instrument(skip(self, request))] async fn perform_analysis( &self, request: CreateChatCompletionRequest, diff --git a/src/analysis/ingress/mod.rs b/src/ingress/analysis/mod.rs similarity index 100% rename from src/analysis/ingress/mod.rs rename to src/ingress/analysis/mod.rs diff --git a/src/analysis/ingress/prompt.rs b/src/ingress/analysis/prompt.rs similarity index 100% rename from src/analysis/ingress/prompt.rs rename to src/ingress/analysis/prompt.rs diff --git a/src/models/graph_entities.rs b/src/ingress/analysis/types/graph_mapper.rs similarity index 100% rename from src/models/graph_entities.rs rename to src/ingress/analysis/types/graph_mapper.rs diff --git a/src/analysis/ingress/types/llm_analysis_result.rs b/src/ingress/analysis/types/llm_analysis_result.rs similarity index 98% rename from src/analysis/ingress/types/llm_analysis_result.rs rename to src/ingress/analysis/types/llm_analysis_result.rs index b97aaab..fb1bc80 100644 --- a/src/analysis/ingress/types/llm_analysis_result.rs +++ b/src/ingress/analysis/types/llm_analysis_result.rs @@ -5,14 +5,15 @@ use tokio::task; use crate::{ error::ProcessingError, - models::graph_entities::GraphMapper, storage::types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_relationship::KnowledgeRelationship, }, utils::embedding::generate_embedding, }; -use futures::future::try_join_all; // For future parallelization +use futures::future::try_join_all; + +use super::graph_mapper::GraphMapper; // For future parallelization #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LLMKnowledgeEntity { diff --git a/src/analysis/ingress/types/mod.rs b/src/ingress/analysis/types/mod.rs similarity index 56% rename from src/analysis/ingress/types/mod.rs rename to src/ingress/analysis/types/mod.rs index 3f2223e..f8def07 100644 --- a/src/analysis/ingress/types/mod.rs +++ b/src/ingress/analysis/types/mod.rs @@ -1 +1,2 @@ +pub mod graph_mapper; pub mod llm_analysis_result; diff --git a/src/ingress/content_processor.rs b/src/ingress/content_processor.rs new file mode 100644 index 0000000..57ff233 --- /dev/null +++ b/src/ingress/content_processor.rs @@ -0,0 +1,119 @@ +use text_splitter::TextSplitter; +use tracing::{debug, info}; + +use crate::{ + error::ProcessingError, + retrieval::vector::find_items_by_vector_similarity, + storage::{ + db::{store_item, SurrealDbClient}, + types::{ + knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, + text_chunk::TextChunk, text_content::TextContent, + }, + }, + utils::embedding::generate_embedding, +}; + +use super::analysis::{ + ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult, +}; + +pub struct ContentProcessor { + db_client: SurrealDbClient, + openai_client: async_openai::Client, +} + +impl ContentProcessor { + pub async fn new() -> Result { + Ok(Self { + db_client: SurrealDbClient::new().await?, + openai_client: async_openai::Client::new(), + }) + } + + pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> { + // Store original content + store_item(&self.db_client, content.clone()).await?; + + // Process in parallel where possible + let (analysis, _similar_chunks) = tokio::try_join!( + self.perform_semantic_analysis(content), + self.find_similar_content(content), + )?; + + // Convert and store entities + let (entities, relationships) = analysis + .to_database_entities(&content.id, &self.openai_client) + .await?; + + // Store everything + tokio::try_join!( + self.store_graph_entities(entities, relationships), + self.store_vector_chunks(content), + )?; + + self.db_client.rebuild_indexes().await?; + Ok(()) + } + + async fn perform_semantic_analysis( + &self, + content: &TextContent, + ) -> Result { + let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client); + analyser + .analyze_content(&content.category, &content.instructions, &content.text) + .await + } + + async fn find_similar_content( + &self, + content: &TextContent, + ) -> Result, ProcessingError> { + find_items_by_vector_similarity( + 3, + content.text.clone(), + &self.db_client, + "text_chunk".to_string(), + &self.openai_client, + ) + .await + } + + async fn store_graph_entities( + &self, + entities: Vec, + relationships: Vec, + ) -> Result<(), ProcessingError> { + for entity in &entities { + debug!("Storing entity: {:?}", entity); + store_item(&self.db_client, entity.clone()).await?; + } + + for relationship in &relationships { + debug!("Storing relationship: {:?}", relationship); + store_item(&self.db_client, relationship.clone()).await?; + } + + info!( + "Stored {} entities and {} relationships", + entities.len(), + relationships.len() + ); + Ok(()) + } + + async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> { + let splitter = TextSplitter::new(500..2000); + let chunks = splitter.chunks(&content.text); + + // Could potentially process chunks in parallel with a bounded concurrent limit + for chunk in chunks { + let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?; + let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding); + store_item(&self.db_client, text_chunk).await?; + } + + Ok(()) + } +} diff --git a/src/ingress/mod.rs b/src/ingress/mod.rs new file mode 100644 index 0000000..3959ef7 --- /dev/null +++ b/src/ingress/mod.rs @@ -0,0 +1,3 @@ +pub mod analysis; +pub mod content_processor; +pub mod types; diff --git a/src/models/ingress_content.rs b/src/ingress/types/ingress_input.rs similarity index 96% rename from src/models/ingress_content.rs rename to src/ingress/types/ingress_input.rs index ed37371..30fd248 100644 --- a/src/models/ingress_content.rs +++ b/src/ingress/types/ingress_input.rs @@ -1,10 +1,10 @@ +use super::ingress_object::IngressObject; +use crate::storage::{db::SurrealDbClient, types::file_info::FileInfo}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::info; use url::Url; use uuid::Uuid; -use crate::surrealdb::SurrealDbClient; -use super::{file_info::FileInfo, ingress_object::IngressObject }; /// Struct defining the expected body when ingressing content. #[derive(Serialize, Deserialize, Debug)] @@ -105,4 +105,3 @@ pub async fn create_ingress_objects( Ok(object_list) } - diff --git a/src/models/ingress_object.rs b/src/ingress/types/ingress_object.rs similarity index 96% rename from src/models/ingress_object.rs rename to src/ingress/types/ingress_object.rs index 2b1297a..bd45279 100644 --- a/src/models/ingress_object.rs +++ b/src/ingress/types/ingress_object.rs @@ -1,8 +1,8 @@ -use super::ingress_content::IngressContentError; -use crate::models::file_info::FileInfo; -use crate::storage::types::text_content::TextContent; +use crate::storage::types::{file_info::FileInfo, text_content::TextContent}; use serde::{Deserialize, Serialize}; +use super::ingress_input::IngressContentError; + /// Knowledge object type, containing the content or reference to it, as well as metadata #[derive(Debug, Serialize, Deserialize, Clone)] pub enum IngressObject { diff --git a/src/ingress/types/mod.rs b/src/ingress/types/mod.rs new file mode 100644 index 0000000..8cfb70e --- /dev/null +++ b/src/ingress/types/mod.rs @@ -0,0 +1,2 @@ +pub mod ingress_input; +pub mod ingress_object; diff --git a/src/lib.rs b/src/lib.rs index 380e0b4..1be0aca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,7 @@ -pub mod analysis; pub mod error; -pub mod models; +pub mod ingress; pub mod rabbitmq; pub mod retrieval; pub mod routes; pub mod storage; -pub mod surrealdb; pub mod utils; diff --git a/src/models/mod.rs b/src/models/mod.rs deleted file mode 100644 index 9392443..0000000 --- a/src/models/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod file_info; -pub mod graph_entities; -pub mod ingress_content; -pub mod ingress_object; -pub mod text_content; diff --git a/src/models/text_content.rs b/src/models/text_content.rs deleted file mode 100644 index e3cf7b5..0000000 --- a/src/models/text_content.rs +++ /dev/null @@ -1,122 +0,0 @@ -use crate::analysis::ingress::ingress_analyser::IngressAnalyzer; -use crate::retrieval::graph::find_entities_by_source_id; -use crate::retrieval::vector::find_items_by_vector_similarity; -use crate::storage::db::store_item; -use crate::storage::types::knowledge_entity::KnowledgeEntity; -use crate::storage::types::knowledge_relationship::KnowledgeRelationship; -use crate::storage::types::text_chunk::TextChunk; -use crate::storage::types::text_content::TextContent; -use crate::storage::types::StoredObject; -use crate::utils::embedding::generate_embedding; -use crate::{error::ProcessingError, surrealdb::SurrealDbClient}; -use surrealdb::{engine::remote::ws::Client, Surreal}; -use text_splitter::TextSplitter; -use tracing::{debug, info}; - -impl TextContent { - /// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB. - pub async fn process(&self) -> Result<(), ProcessingError> { - let db_client = SurrealDbClient::new().await?; - let openai_client = async_openai::Client::new(); - - // Store TextContent - store_item(&db_client, self.clone()).await?; - - // Get related nodes - let closest_text_content: Vec = find_items_by_vector_similarity( - 3, - self.text.clone(), - &db_client, - "text_chunk".to_string(), - &openai_client, - ) - .await?; - - for node in closest_text_content { - let related_nodes: Vec = find_entities_by_source_id( - node.source_id.to_owned(), - KnowledgeEntity::table_name().to_string(), - &db_client, - ) - .await?; - for related_node in related_nodes { - info!("{:?}", related_node.name); - } - } - - // Rebuild indexes - db_client.rebuild_indexes().await?; - - // Step 1: Send to LLM for analysis - let analyser = IngressAnalyzer::new(&db_client, &openai_client); - let analysis = analyser - .analyze_content(&self.category, &self.instructions, &self.text) - .await?; - - // Step 2: Convert LLM analysis to database entities - let (entities, relationships) = analysis - .to_database_entities(&self.id, &openai_client) - .await?; - - // Step 3: Store in database - self.store_in_graph_db(entities, relationships, &db_client) - .await?; - - // Step 4: Split text and store in Vector DB - self.store_in_vector_db(&db_client, &openai_client).await?; - - Ok(()) - } - - async fn store_in_graph_db( - &self, - entities: Vec, - relationships: Vec, - db_client: &Surreal, - ) -> Result<(), ProcessingError> { - for entity in &entities { - debug!( - "{:?}, {:?}, {:?}", - &entity.id, &entity.name, &entity.description - ); - - store_item(db_client, entity.clone()).await?; - } - - for relationship in &relationships { - debug!("{:?}", relationship); - - store_item(db_client, relationship.clone()).await?; - } - - info!( - "Inserted to database: {:?} entities, {:?} relationships", - entities.len(), - relationships.len() - ); - - Ok(()) - } - - /// Splits text and stores it in a vector database. - async fn store_in_vector_db( - &self, - db_client: &Surreal, - openai_client: &async_openai::Client, - ) -> Result<(), ProcessingError> { - let max_characters = 500..2000; - let splitter = TextSplitter::new(max_characters); - - let chunks = splitter.chunks(self.text.as_str()); - - for chunk in chunks { - info!("Chunk: {}", chunk); - let embedding = generate_embedding(openai_client, chunk.to_string()).await?; - let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding); - - store_item(db_client, text_chunk).await?; - } - - Ok(()) - } -} diff --git a/src/rabbitmq/consumer.rs b/src/rabbitmq/consumer.rs index 38fdd47..5f2827c 100644 --- a/src/rabbitmq/consumer.rs +++ b/src/rabbitmq/consumer.rs @@ -1,12 +1,16 @@ -use lapin::{ - message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue -}; use futures_lite::stream::StreamExt; +use lapin::{message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue}; -use crate::models::{ingress_content::IngressContentError, ingress_object::IngressObject }; +use crate::{ + error::IngressConsumerError, + ingress::{ + content_processor::ContentProcessor, + types::{ingress_input::IngressContentError, ingress_object::IngressObject}, + }, +}; use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError}; -use tracing::{info, error}; +use tracing::{error, info}; /// Struct to consume messages from RabbitMQ. pub struct RabbitMQConsumer { @@ -26,18 +30,22 @@ impl RabbitMQConsumer { /// * `Result` - The created client or an error. pub async fn new(config: &RabbitMQConfig) -> Result { let common = RabbitMQCommon::new(config).await?; - + // Passively declare the exchange (it should already exist) common.declare_exchange(config, true).await?; - + // Declare queue and bind it to the channel let queue = Self::declare_queue(&common.channel, config).await?; Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?; - + // Initialize the consumer let consumer = Self::initialize_consumer(&common.channel, config).await?; - Ok(Self { common, queue, consumer }) + Ok(Self { + common, + queue, + consumer, + }) } /// Sets up the consumer based on the channel and `RabbitMQConfig`. @@ -48,7 +56,10 @@ impl RabbitMQConsumer { /// /// # Returns /// * `Result` - The initialized consumer or error - async fn initialize_consumer(channel: &Channel, config: &RabbitMQConfig) -> Result { + async fn initialize_consumer( + channel: &Channel, + config: &RabbitMQConfig, + ) -> Result { channel .basic_consume( &config.queue, @@ -56,7 +67,8 @@ impl RabbitMQConsumer { BasicConsumeOptions::default(), FieldTable::default(), ) - .await.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string())) + .await + .map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string())) } /// Declares the queue based on the channel and `RabbitMQConfig`. /// # Arguments @@ -65,7 +77,10 @@ impl RabbitMQConsumer { /// /// # Returns /// * `Result` - The initialized queue or error - async fn declare_queue(channel: &Channel, config: &RabbitMQConfig) -> Result { + async fn declare_queue( + channel: &Channel, + config: &RabbitMQConfig, + ) -> Result { channel .queue_declare( &config.queue, @@ -88,7 +103,12 @@ impl RabbitMQConsumer { /// /// # Returns /// * `Result<(), RabbitMQError>` - Ok or error - async fn bind_queue(channel: &Channel, exchange: &str, queue: &Queue, config: &RabbitMQConfig) -> Result<(), RabbitMQError> { + async fn bind_queue( + channel: &Channel, + exchange: &str, + queue: &Queue, + config: &RabbitMQConfig, + ) -> Result<(), RabbitMQError> { channel .queue_bind( queue.name().as_str(), @@ -111,7 +131,11 @@ impl RabbitMQConsumer { /// `Delivery` - A delivery reciept, required to ack or nack the delivery. pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> { // Receive the next message - let delivery = self.consumer.clone().next().await + let delivery = self + .consumer + .clone() + .next() + .await .ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))? .map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?; @@ -131,7 +155,8 @@ impl RabbitMQConsumer { /// # Returns /// * `Result<(), RabbitMQError>` - Ok or error pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> { - self.common.channel + self.common + .channel .basic_ack(delivery.delivery_tag, BasicAckOptions::default()) .await .map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?; @@ -139,33 +164,22 @@ impl RabbitMQConsumer { Ok(()) } /// Function to continually consume messages as they come in - /// WIP - pub async fn process_messages(&self) -> Result<(), RabbitMQError> { + pub async fn process_messages(&self) -> Result<(), IngressConsumerError> { loop { match self.consume().await { Ok((ingress, delivery)) => { info!("Received IngressObject: {:?}", ingress); - let text_content = ingress.to_text_content().await.unwrap(); - text_content.process().await.unwrap(); - + // Get the TextContent + let text_content = ingress.to_text_content().await?; + + // Initialize ContentProcessor which handles LLM analysis and storage + let content_processor = ContentProcessor::new().await?; + + // Begin processing of TextContent + content_processor.process(&text_content).await?; + + // Remove from queue self.ack_delivery(delivery).await?; - // Process the IngressContent - // match self.handle_ingress_content(&ingress).await { - // Ok(_) => { - // info!("Successfully handled IngressContent"); - // // Acknowledge the message - // if let Err(e) = self.ack_delivery(delivery).await { - // error!("Failed to acknowledge message: {:?}", e); - // } - // }, - // Err(e) => { - // error!("Failed to handle IngressContent: {:?}", e); - // // For now, we'll acknowledge to remove it from the queue. Change to nack? - // if let Err(ack_err) = self.ack_delivery(delivery).await { - // error!("Failed to acknowledge message after handling error: {:?}", ack_err); - // } - // } - // } } Err(RabbitMQError::ConsumeError(e)) => { error!("Error consuming message: {}", e); @@ -182,7 +196,10 @@ impl RabbitMQConsumer { Ok(()) } - pub async fn handle_ingress_content(&self, ingress: &IngressObject) -> Result<(), IngressContentError> { + pub async fn handle_ingress_content( + &self, + ingress: &IngressObject, + ) -> Result<(), IngressContentError> { info!("Processing IngressContent: {:?}", ingress); unimplemented!() diff --git a/src/rabbitmq/mod.rs b/src/rabbitmq/mod.rs index f4c2e78..38f0e89 100644 --- a/src/rabbitmq/mod.rs +++ b/src/rabbitmq/mod.rs @@ -1,13 +1,16 @@ -pub mod publisher; pub mod consumer; +pub mod publisher; use axum::async_trait; use lapin::{ - options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties, ExchangeKind + options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties, + ExchangeKind, }; use thiserror::Error; use tracing::debug; +use crate::error::ProcessingError; + /// Possible errors related to RabbitMQ operations. #[derive(Error, Debug)] pub enum RabbitMQError { @@ -25,6 +28,8 @@ pub enum RabbitMQError { InitializeConsumerError(String), #[error("Queue error: {0}")] QueueError(String), + #[error("Processing error: {0}")] + ProcessingError(#[from] ProcessingError), } /// Struct containing the information required to set up a client and connection. @@ -42,18 +47,21 @@ pub struct RabbitMQCommon { pub channel: Channel, } - /// Defines the behavior for RabbitMQCommon client operations. #[cfg_attr(test, mockall::automock)] #[async_trait] pub trait RabbitMQCommonTrait: Send + Sync { async fn create_connection(config: &RabbitMQConfig) -> Result; - async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError>; -} + async fn declare_exchange( + &self, + config: &RabbitMQConfig, + passive: bool, + ) -> Result<(), RabbitMQError>; +} impl RabbitMQCommon { /// Sets up a new RabbitMQ client or error - /// + /// /// # Arguments /// * `RabbitMQConfig` - Configuration object with required information /// @@ -62,7 +70,10 @@ impl RabbitMQCommon { pub async fn new(config: &RabbitMQConfig) -> Result { let connection = Self::create_connection(config).await?; let channel = connection.create_channel().await?; - Ok(Self { connection, channel }) + Ok(Self { + connection, + channel, + }) } } @@ -77,7 +88,11 @@ impl RabbitMQCommonTrait for RabbitMQCommon { } /// Function to declare the exchange required - async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError> { + async fn declare_exchange( + &self, + config: &RabbitMQConfig, + passive: bool, + ) -> Result<(), RabbitMQError> { debug!("Declaring exchange"); self.channel .exchange_declare( diff --git a/src/rabbitmq/publisher.rs b/src/rabbitmq/publisher.rs index 2d9f125..c86ad5b 100644 --- a/src/rabbitmq/publisher.rs +++ b/src/rabbitmq/publisher.rs @@ -1,11 +1,9 @@ -use lapin::{ - options::*, publisher_confirm::Confirmation, BasicProperties, -}; +use lapin::{options::*, publisher_confirm::Confirmation, BasicProperties}; -use crate::models::ingress_object::IngressObject; +use crate::ingress::types::ingress_object::IngressObject; use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError}; -use tracing::{info, error}; +use tracing::{error, info}; /// Struct to publish messages to RabbitMQ. pub struct RabbitMQProducer { @@ -27,7 +25,7 @@ impl RabbitMQProducer { let common = RabbitMQCommon::new(config).await?; common.declare_exchange(config, false).await?; - Ok(Self { + Ok(Self { common, exchange_name: config.exchange.clone(), routing_key: config.routing_key.clone(), @@ -39,19 +37,23 @@ impl RabbitMQProducer { /// # Arguments /// * `self` - Reference to self /// * `ingress_object` - A initialized IngressObject - /// + /// /// # Returns /// * `Result` - Confirmation of sent message or error - pub async fn publish(&self, ingress_object: &IngressObject) -> Result { + pub async fn publish( + &self, + ingress_object: &IngressObject, + ) -> Result { // Serialize IngressObject to JSON - let payload = serde_json::to_vec(ingress_object) - .map_err(|e| { - error!("Serialization Error: {}", e); - RabbitMQError::PublishError(format!("Serialization Error: {}", e)) - })?; - + let payload = serde_json::to_vec(ingress_object).map_err(|e| { + error!("Serialization Error: {}", e); + RabbitMQError::PublishError(format!("Serialization Error: {}", e)) + })?; + // Publish the serialized payload to RabbitMQ - let confirmation = self.common.channel + let confirmation = self + .common + .channel .basic_publish( &self.exchange_name, &self.routing_key, @@ -69,9 +71,12 @@ impl RabbitMQProducer { error!("Publish Confirmation Error: {}", e); RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e)) })?; - - info!("Published IngressObject to exchange '{}' with routing key '{}'", self.exchange_name, self.routing_key); - + + info!( + "Published IngressObject to exchange '{}' with routing key '{}'", + self.exchange_name, self.routing_key + ); + Ok(confirmation) } } diff --git a/src/routes/file.rs b/src/routes/file.rs index 103916c..7b21c7e 100644 --- a/src/routes/file.rs +++ b/src/routes/file.rs @@ -1,16 +1,14 @@ -use std::sync::Arc; -use axum::{ - extract::Path, response::IntoResponse, Extension, Json +use crate::storage::{ + db::SurrealDbClient, + types::file_info::{FileError, FileInfo}, }; -use axum_typed_multipart::{TypedMultipart, FieldData, TryFromMultipart}; +use axum::{extract::Path, response::IntoResponse, Extension, Json}; +use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use serde_json::json; +use std::sync::Arc; use tempfile::NamedTempFile; use tracing::info; use uuid::Uuid; -use crate::{ - models::file_info::{FileError, FileInfo}, - surrealdb::SurrealDbClient, -}; #[derive(Debug, TryFromMultipart)] pub struct FileUploadRequest { @@ -19,7 +17,7 @@ pub struct FileUploadRequest { } /// Handler to upload a new file. -/// +/// /// Route: POST /file pub async fn upload_handler( Extension(db_client): Extension>, @@ -40,13 +38,12 @@ pub async fn upload_handler( info!("File uploaded successfully: {:?}", file_info); - // Return the response with HTTP 200 Ok((axum::http::StatusCode::OK, Json(response))) } /// Handler to retrieve file information by UUID. -/// +/// /// Route: GET /file/:uuid pub async fn get_file_handler( Extension(db_client): Extension>, @@ -73,7 +70,7 @@ pub async fn get_file_handler( } /// Handler to update an existing file by UUID. -/// +/// /// Route: PUT /file/:uuid pub async fn update_file_handler( Extension(db_client): Extension>, @@ -101,7 +98,7 @@ pub async fn update_file_handler( } /// Handler to delete a file by UUID. -/// +/// /// Route: DELETE /file/:uuid pub async fn delete_file_handler( Extension(db_client): Extension>, diff --git a/src/routes/ingress.rs b/src/routes/ingress.rs index 2bb9d6b..de187f2 100644 --- a/src/routes/ingress.rs +++ b/src/routes/ingress.rs @@ -1,7 +1,11 @@ -use std::sync::Arc; +use crate::{ + ingress::types::ingress_input::{create_ingress_objects, IngressInput}, + rabbitmq::publisher::RabbitMQProducer, + storage::db::SurrealDbClient, +}; use axum::{http::StatusCode, response::IntoResponse, Extension, Json}; +use std::sync::Arc; use tracing::{error, info}; -use crate::{models::ingress_content::{create_ingress_objects, IngressInput}, rabbitmq::publisher::RabbitMQProducer, surrealdb::SurrealDbClient}; pub async fn ingress_handler( Extension(producer): Extension>, @@ -23,7 +27,7 @@ pub async fn ingress_handler( StatusCode::INTERNAL_SERVER_ERROR, "Failed to publish message", ) - .into_response(); + .into_response(); } } } @@ -31,11 +35,7 @@ pub async fn ingress_handler( } Err(e) => { error!("Failed to process input: {:?}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to process input", - ) - .into_response() + (StatusCode::INTERNAL_SERVER_ERROR, "Failed to process input").into_response() } } } diff --git a/src/server.rs b/src/server.rs index 48f7616..bbe1a74 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,7 +12,7 @@ use zettle_db::{ ingress::ingress_handler, queue_length::queue_length_handler, }, - surrealdb::SurrealDbClient, + storage::db::SurrealDbClient, }; #[tokio::main(flavor = "multi_thread", worker_threads = 2)] diff --git a/src/storage/db.rs b/src/storage/db.rs index 90602c5..1757a5b 100644 --- a/src/storage/db.rs +++ b/src/storage/db.rs @@ -1,8 +1,67 @@ -use surrealdb::{engine::remote::ws::Client, Surreal}; - use crate::error::ProcessingError; use super::types::StoredObject; +use std::ops::Deref; +use surrealdb::{ + engine::remote::ws::{Client, Ws}, + opt::auth::Root, + Error, Surreal, +}; + +#[derive(Clone)] +pub struct SurrealDbClient { + pub client: Surreal, +} + +impl SurrealDbClient { + /// # Initialize a new datbase client + /// + /// # Arguments + /// + /// # Returns + /// * `SurrealDbClient` initialized + pub async fn new() -> Result { + let db = Surreal::new::("127.0.0.1:8000").await?; + + // Sign in to database + db.signin(Root { + username: "root_user", + password: "root_password", + }) + .await?; + + // Set namespace + db.use_ns("test").use_db("test").await?; + + Ok(SurrealDbClient { client: db }) + } + + pub async fn rebuild_indexes(&self) -> Result<(), Error> { + self.client + .query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk") + .await?; + self.client + .query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity") + .await?; + Ok(()) + } + + pub async fn drop_table(&self) -> Result<(), Error> + where + T: StoredObject + Send + Sync + 'static, + { + let _deleted: Vec = self.client.delete(T::table_name()).await?; + Ok(()) + } +} + +impl Deref for SurrealDbClient { + type Target = Surreal; + + fn deref(&self) -> &Self::Target { + &self.client + } +} /// Operation to store a object in SurrealDB, requires the struct to implement StoredObject /// diff --git a/src/models/file_info.rs b/src/storage/types/file_info.rs similarity index 78% rename from src/models/file_info.rs rename to src/storage/types/file_info.rs index 75789d2..014a3ba 100644 --- a/src/models/file_info.rs +++ b/src/storage/types/file_info.rs @@ -1,22 +1,26 @@ -use axum::{http::StatusCode, response::{IntoResponse, Response}, Json}; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; use axum_typed_multipart::FieldData; use mime_guess::from_path; use serde::{Deserialize, Serialize}; use serde_json::json; use sha2::{Digest, Sha256}; -use surrealdb::RecordId; use std::{ io::{BufReader, Read}, path::{Path, PathBuf}, }; +use surrealdb::RecordId; use tempfile::NamedTempFile; use thiserror::Error; use tracing::{debug, info}; use uuid::Uuid; -use crate::surrealdb::SurrealDbClient; +use crate::storage::db::SurrealDbClient; -#[derive(Debug,Deserialize)] +#[derive(Debug, Deserialize)] struct Record { #[allow(dead_code)] id: RecordId, @@ -72,7 +76,6 @@ pub enum FileError { #[error("Deserialization error: {0}")] DeserializationError(String), - // Add more error variants as needed. } @@ -82,16 +85,30 @@ impl IntoResponse for FileError { FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"), FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"), FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"), - FileError::UnsupportedMime(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type"), + FileError::UnsupportedMime(_) => { + (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type") + } FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"), FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"), - FileError::HashCollision => (StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected"), + FileError::HashCollision => { + (StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected") + } FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"), - FileError::MissingFileName => (StatusCode::BAD_REQUEST, "Missing file name in metadata"), - FileError::PersistError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file"), - FileError::SerializationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"), - FileError::DeserializationError(_) => (StatusCode::BAD_REQUEST, "Deserialization error"), - FileError::SurrealError(_) =>(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"), + FileError::MissingFileName => { + (StatusCode::BAD_REQUEST, "Missing file name in metadata") + } + FileError::PersistError(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file") + } + FileError::SerializationError(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error") + } + FileError::DeserializationError(_) => { + (StatusCode::BAD_REQUEST, "Deserialization error") + } + FileError::SurrealError(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error") + } }; let body = Json(json!({ @@ -163,7 +180,11 @@ impl FileInfo { /// /// # Returns /// * `Result` - The updated `FileInfo` or an error. - pub async fn update(uuid: Uuid, new_field_data: FieldData, db_client: &SurrealDbClient) -> Result { + pub async fn update( + uuid: Uuid, + new_field_data: FieldData, + db_client: &SurrealDbClient, + ) -> Result { let new_file = new_field_data.contents; let new_metadata = new_field_data.metadata; @@ -184,7 +205,8 @@ impl FileInfo { let sanitized_new_file_name = sanitize_file_name(&new_file_name); // Persist the new file - let new_persisted_path = Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?; + let new_persisted_path = + Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?; // Guess the new MIME type let new_mime_type = Self::guess_mime_type(&new_persisted_path); @@ -202,7 +224,7 @@ impl FileInfo { }; // Save the new item - Self::create_record(&updated_file_info,db_client).await?; + Self::create_record(&updated_file_info, db_client).await?; // Optionally, delete the old file from the filesystem if it's no longer referenced // This requires reference counting or checking if other FileInfo entries point to the same SHA @@ -221,26 +243,35 @@ impl FileInfo { /// * `Result<(), FileError>` - Empty result or an error. pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> { // Retrieve FileInfo to get SHA256 and path - let file_info = Self::get_by_uuid(uuid, db_client).await?; + let file_info = Self::get_by_uuid(uuid, db_client).await?; // Delete the file from the filesystem let file_path = Path::new(&file_info.path); if file_path.exists() { - tokio::fs::remove_file(file_path).await.map_err(FileError::Io)?; + tokio::fs::remove_file(file_path) + .await + .map_err(FileError::Io)?; info!("Deleted file at path: {}", file_info.path); } else { - info!("File path does not exist, skipping deletion: {}", file_info.path); + info!( + "File path does not exist, skipping deletion: {}", + file_info.path + ); } // Delete the FileInfo from database Self::delete_record(&file_info.sha256, db_client).await?; // Remove the UUID directory if empty - let uuid_dir = file_path.parent().ok_or(FileError::FileNotFound(uuid.to_string()))?; + let uuid_dir = file_path + .parent() + .ok_or(FileError::FileNotFound(uuid.to_string()))?; if uuid_dir.exists() { let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?; if entries.next_entry().await?.is_none() { - tokio::fs::remove_dir(uuid_dir).await.map_err(FileError::Io)?; + tokio::fs::remove_dir(uuid_dir) + .await + .map_err(FileError::Io)?; info!("Deleted empty UUID directory: {:?}", uuid_dir); } } @@ -257,19 +288,26 @@ impl FileInfo { /// /// # Returns /// * `Result` - The persisted file path or an error. - async fn persist_file(uuid: &Uuid, file: NamedTempFile, file_name: &str) -> Result { + async fn persist_file( + uuid: &Uuid, + file: NamedTempFile, + file_name: &str, + ) -> Result { let base_dir = Path::new("./data"); let uuid_dir = base_dir.join(uuid.to_string()); // Create the UUID directory if it doesn't exist - tokio::fs::create_dir_all(&uuid_dir).await.map_err(FileError::Io)?; + tokio::fs::create_dir_all(&uuid_dir) + .await + .map_err(FileError::Io)?; // Define the final file path let final_path = uuid_dir.join(file_name); info!("Final path: {:?}", final_path); // Persist the temporary file to the final path - file.persist(&final_path).map_err(|e| FileError::PersistError(e.to_string()))?; + file.persist(&final_path) + .map_err(|e| FileError::PersistError(e.to_string()))?; info!("Persisted file to {:?}", final_path); @@ -313,7 +351,6 @@ impl FileInfo { .to_string() } - /// Creates a new record in SurrealDB for the given `FileInfo`. /// /// # Arguments @@ -323,16 +360,19 @@ impl FileInfo { /// # Returns /// * `Result<(), FileError>` - Empty result or an error. - async fn create_record(file_info: &FileInfo, db_client: &SurrealDbClient) -> Result<(), FileError> { + async fn create_record( + file_info: &FileInfo, + db_client: &SurrealDbClient, + ) -> Result<(), FileError> { // Create the record let _created: Option = db_client .client - .create(("file", &file_info.uuid )) + .create(("file", &file_info.uuid)) .content(file_info.clone()) .await?; - debug!("{:?}",_created); - + debug!("{:?}", _created); + info!("Created FileInfo record with SHA256: {}", file_info.sha256); Ok(()) @@ -346,11 +386,17 @@ impl FileInfo { /// /// # Returns /// * `Result` - The `FileInfo` or `Error` if not found. - pub async fn get_by_uuid(uuid: Uuid, db_client: &SurrealDbClient) -> Result { + pub async fn get_by_uuid( + uuid: Uuid, + db_client: &SurrealDbClient, + ) -> Result { let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid); let response: Vec = db_client.client.query(query).await?.take(0)?; - response.into_iter().next().ok_or(FileError::FileNotFound(uuid.to_string())) + response + .into_iter() + .next() + .ok_or(FileError::FileNotFound(uuid.to_string())) } /// Retrieves a `FileInfo` by SHA256. @@ -367,7 +413,11 @@ impl FileInfo { debug!("{:?}", response); - response.into_iter().next().ok_or(FileError::FileNotFound(sha256.to_string())) } + response + .into_iter() + .next() + .ok_or(FileError::FileNotFound(sha256.to_string())) + } /// Deletes a `FileInfo` record by SHA256. /// @@ -381,10 +431,7 @@ impl FileInfo { let table = "file"; let primary_key = sha256; - let _created: Option = db_client - .client - .delete((table, primary_key)) - .await?; + let _created: Option = db_client.client.delete((table, primary_key)).await?; info!("Deleted FileInfo record with SHA256: {}", sha256); @@ -395,7 +442,14 @@ impl FileInfo { /// Sanitizes the file name to prevent security vulnerabilities like directory traversal. /// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores. fn sanitize_file_name(file_name: &str) -> String { - file_name.chars() - .map(|c| if c.is_ascii_alphanumeric() || c == '.' || c == '_' { c } else { '_' }) + file_name + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '.' || c == '_' { + c + } else { + '_' + } + }) .collect() } diff --git a/src/storage/types/mod.rs b/src/storage/types/mod.rs index 4c40e90..e443785 100644 --- a/src/storage/types/mod.rs +++ b/src/storage/types/mod.rs @@ -1,5 +1,6 @@ use axum::async_trait; use serde::{Deserialize, Serialize}; +pub mod file_info; pub mod knowledge_entity; pub mod knowledge_relationship; pub mod text_chunk; diff --git a/src/storage/types/text_content.rs b/src/storage/types/text_content.rs index 314de66..b1a86f4 100644 --- a/src/storage/types/text_content.rs +++ b/src/storage/types/text_content.rs @@ -1,8 +1,9 @@ use uuid::Uuid; -use crate::models::file_info::FileInfo; use crate::stored_object; +use super::file_info::FileInfo; + stored_object!(TextContent, "text_content", { text: String, file_info: Option, diff --git a/src/surrealdb/mod.rs b/src/surrealdb/mod.rs deleted file mode 100644 index 3fd60bc..0000000 --- a/src/surrealdb/mod.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::ops::Deref; -use surrealdb::{ - engine::remote::ws::{Client, Ws}, - opt::auth::Root, - Error, Surreal, -}; - -use crate::storage::types::StoredObject; - -#[derive(Clone)] -pub struct SurrealDbClient { - pub client: Surreal, -} - -impl SurrealDbClient { - /// # Initialize a new datbase client - /// - /// # Arguments - /// - /// # Returns - /// * `SurrealDbClient` initialized - pub async fn new() -> Result { - let db = Surreal::new::("127.0.0.1:8000").await?; - - // Sign in to database - db.signin(Root { - username: "root_user", - password: "root_password", - }) - .await?; - - // Set namespace - db.use_ns("test").use_db("test").await?; - - Ok(SurrealDbClient { client: db }) - } - - pub async fn rebuild_indexes(&self) -> Result<(), Error> { - self.client - .query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk") - .await?; - self.client - .query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity") - .await?; - Ok(()) - } - - pub async fn drop_table(&self) -> Result<(), Error> - where - T: StoredObject + Send + Sync + 'static, - { - let _deleted: Vec = self.client.delete(T::table_name()).await?; - Ok(()) - } -} - -impl Deref for SurrealDbClient { - type Target = Surreal; - - fn deref(&self) -> &Self::Target { - &self.client - } -}