feat: refactoring complete?

This commit is contained in:
Per Stark
2024-11-21 21:23:49 +01:00
parent cd7901eefe
commit 5f64fbd3fb
27 changed files with 428 additions and 338 deletions

View File

@@ -1 +0,0 @@
pub mod ingress;

View File

@@ -2,6 +2,8 @@ use async_openai::error::OpenAIError;
use thiserror::Error;
use tokio::task::JoinError;
use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError};
/// Error types for processing `TextContent`.
#[derive(Error, Debug)]
pub enum ProcessingError {
@@ -23,3 +25,15 @@ pub enum ProcessingError {
#[error("Task join error: {0}")]
JoinError(#[from] JoinError),
}
/// Aggregate error type for the ingress consumer loop, covering failures
/// from the message queue, content processing, and ingress-content
/// conversion. Each variant converts automatically via `#[from]`, so `?`
/// works on all three underlying error types inside the consumer.
#[derive(Error, Debug)]
pub enum IngressConsumerError {
    /// RabbitMQ connection/consume/ack failure.
    #[error("RabbitMQ error: {0}")]
    RabbitMQ(#[from] RabbitMQError),
    /// Failure while processing the ingested `TextContent`.
    #[error("Processing error: {0}")]
    Processing(#[from] ProcessingError),
    /// Failure while converting ingress input into content
    /// (e.g. `to_text_content` on an `IngressObject`).
    #[error("Ingress content error: {0}")]
    IngressContent(#[from] IngressContentError),
}

View File

@@ -1,9 +1,6 @@
use crate::{
analysis::ingress::{
prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
types::llm_analysis_result::LLMGraphAnalysisResult,
},
error::ProcessingError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::vector::find_items_by_vector_similarity,
storage::types::{knowledge_entity::KnowledgeEntity, StoredObject},
};
@@ -15,7 +12,9 @@ use async_openai::types::{
use serde_json::json;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;
use tracing::{debug, instrument};
use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
pub struct IngressAnalyzer<'a> {
db_client: &'a Surreal<Client>,
@@ -33,7 +32,6 @@ impl<'a> IngressAnalyzer<'a> {
}
}
#[instrument(skip(self))]
pub async fn analyze_content(
&self,
category: &str,
@@ -48,7 +46,6 @@ impl<'a> IngressAnalyzer<'a> {
self.perform_analysis(llm_request).await
}
#[instrument(skip(self))]
async fn find_similar_entities(
&self,
category: &str,
@@ -70,7 +67,6 @@ impl<'a> IngressAnalyzer<'a> {
.await
}
#[instrument(skip(self))]
fn prepare_llm_request(
&self,
category: &str,
@@ -108,7 +104,7 @@ impl<'a> IngressAnalyzer<'a> {
};
CreateChatCompletionRequestArgs::default()
.model("gpt-4-mini")
.model("gpt-4o-mini")
.temperature(0.2)
.max_tokens(2048u32)
.messages([
@@ -120,7 +116,6 @@ impl<'a> IngressAnalyzer<'a> {
.map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
}
#[instrument(skip(self, request))]
async fn perform_analysis(
&self,
request: CreateChatCompletionRequest,

View File

@@ -5,14 +5,15 @@ use tokio::task;
use crate::{
error::ProcessingError,
models::graph_entities::GraphMapper,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
},
utils::embedding::generate_embedding,
};
use futures::future::try_join_all; // For future parallelization
use futures::future::try_join_all;
use super::graph_mapper::GraphMapper; // For future parallelization
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {

View File

@@ -1 +1,2 @@
pub mod graph_mapper;
pub mod llm_analysis_result;

View File

@@ -0,0 +1,119 @@
use text_splitter::TextSplitter;
use tracing::{debug, info};
use crate::{
error::ProcessingError,
retrieval::vector::find_items_by_vector_similarity,
storage::{
db::{store_item, SurrealDbClient},
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
},
},
utils::embedding::generate_embedding,
};
use super::analysis::{
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
};
/// Orchestrates ingress of a `TextContent`: LLM graph analysis plus storage
/// of the resulting graph entities and embedded vector chunks.
pub struct ContentProcessor {
    // Handle to the SurrealDB backend used for all reads and writes.
    db_client: SurrealDbClient,
    // OpenAI client used for analysis requests and embedding generation.
    openai_client: async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl ContentProcessor {
    /// Builds a new processor, connecting to SurrealDB and constructing a
    /// default OpenAI client (presumably configured from the environment by
    /// `async_openai::Client::new()` — TODO confirm).
    ///
    /// # Errors
    /// Returns `ProcessingError` if the database connection cannot be
    /// established.
    pub async fn new() -> Result<Self, ProcessingError> {
        Ok(Self {
            db_client: SurrealDbClient::new().await?,
            openai_client: async_openai::Client::new(),
        })
    }

    /// End-to-end ingress pipeline for one `TextContent`:
    /// 1. persist the raw content,
    /// 2. run LLM analysis and similarity search concurrently,
    /// 3. convert the analysis into graph entities/relationships,
    /// 4. store graph entities and vector chunks concurrently,
    /// 5. rebuild indexes so the new data is searchable.
    ///
    /// # Errors
    /// Propagates the first `ProcessingError` from storage, analysis, or
    /// embedding; `try_join!` aborts the sibling task on failure.
    pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> {
        // Store original content
        store_item(&self.db_client, content.clone()).await?;

        // Process in parallel where possible
        // NOTE(review): the similarity-search result is currently discarded
        // (`_similar_chunks`) — confirm whether it is intended for later use
        // or should be removed to save an embedding call.
        let (analysis, _similar_chunks) = tokio::try_join!(
            self.perform_semantic_analysis(content),
            self.find_similar_content(content),
        )?;

        // Convert and store entities
        let (entities, relationships) = analysis
            .to_database_entities(&content.id, &self.openai_client)
            .await?;

        // Store everything
        tokio::try_join!(
            self.store_graph_entities(entities, relationships),
            self.store_vector_chunks(content),
        )?;

        // Refresh indexes after all writes have completed.
        self.db_client.rebuild_indexes().await?;

        Ok(())
    }

    /// Runs the LLM graph analysis over the content via `IngressAnalyzer`.
    async fn perform_semantic_analysis(
        &self,
        content: &TextContent,
    ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
        let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
        analyser
            .analyze_content(&content.category, &content.instructions, &content.text)
            .await
    }

    /// Fetches the 3 most similar stored text chunks by vector similarity
    /// against the `text_chunk` table.
    async fn find_similar_content(
        &self,
        content: &TextContent,
    ) -> Result<Vec<TextChunk>, ProcessingError> {
        find_items_by_vector_similarity(
            3,
            content.text.clone(),
            &self.db_client,
            "text_chunk".to_string(),
            &self.openai_client,
        )
        .await
    }

    /// Persists extracted entities first, then relationships, one item at a
    /// time, logging each write at debug level and a summary at info level.
    async fn store_graph_entities(
        &self,
        entities: Vec<KnowledgeEntity>,
        relationships: Vec<KnowledgeRelationship>,
    ) -> Result<(), ProcessingError> {
        for entity in &entities {
            debug!("Storing entity: {:?}", entity);
            store_item(&self.db_client, entity.clone()).await?;
        }

        for relationship in &relationships {
            debug!("Storing relationship: {:?}", relationship);
            store_item(&self.db_client, relationship.clone()).await?;
        }

        info!(
            "Stored {} entities and {} relationships",
            entities.len(),
            relationships.len()
        );

        Ok(())
    }

    /// Splits the text into 500–2000 character chunks, generates one
    /// embedding per chunk, and stores each resulting `TextChunk`
    /// sequentially (each chunk costs one embedding API round-trip).
    async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> {
        let splitter = TextSplitter::new(500..2000);
        let chunks = splitter.chunks(&content.text);

        // Could potentially process chunks in parallel with a bounded concurrent limit
        for chunk in chunks {
            let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?;
            let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
            store_item(&self.db_client, text_chunk).await?;
        }

        Ok(())
    }
}

3
src/ingress/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod analysis;
pub mod content_processor;
pub mod types;

View File

@@ -1,10 +1,10 @@
use super::ingress_object::IngressObject;
use crate::storage::{db::SurrealDbClient, types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::info;
use url::Url;
use uuid::Uuid;
use crate::surrealdb::SurrealDbClient;
use super::{file_info::FileInfo, ingress_object::IngressObject };
/// Struct defining the expected body when ingressing content.
#[derive(Serialize, Deserialize, Debug)]
@@ -105,4 +105,3 @@ pub async fn create_ingress_objects(
Ok(object_list)
}

View File

@@ -1,8 +1,8 @@
use super::ingress_content::IngressContentError;
use crate::models::file_info::FileInfo;
use crate::storage::types::text_content::TextContent;
use crate::storage::types::{file_info::FileInfo, text_content::TextContent};
use serde::{Deserialize, Serialize};
use super::ingress_input::IngressContentError;
/// Knowledge object type, containing the content or reference to it, as well as metadata
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject {

2
src/ingress/types/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod ingress_input;
pub mod ingress_object;

View File

@@ -1,9 +1,7 @@
pub mod analysis;
pub mod error;
pub mod models;
pub mod ingress;
pub mod rabbitmq;
pub mod retrieval;
pub mod routes;
pub mod storage;
pub mod surrealdb;
pub mod utils;

View File

@@ -1,5 +0,0 @@
pub mod file_info;
pub mod graph_entities;
pub mod ingress_content;
pub mod ingress_object;
pub mod text_content;

View File

@@ -1,122 +0,0 @@
use crate::analysis::ingress::ingress_analyser::IngressAnalyzer;
use crate::retrieval::graph::find_entities_by_source_id;
use crate::retrieval::vector::find_items_by_vector_similarity;
use crate::storage::db::store_item;
use crate::storage::types::knowledge_entity::KnowledgeEntity;
use crate::storage::types::knowledge_relationship::KnowledgeRelationship;
use crate::storage::types::text_chunk::TextChunk;
use crate::storage::types::text_content::TextContent;
use crate::storage::types::StoredObject;
use crate::utils::embedding::generate_embedding;
use crate::{error::ProcessingError, surrealdb::SurrealDbClient};
use surrealdb::{engine::remote::ws::Client, Surreal};
use text_splitter::TextSplitter;
use tracing::{debug, info};
impl TextContent {
    /// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB.
    ///
    /// Pipeline: store the raw content, log entities related to the nearest
    /// existing chunks, rebuild indexes, then run LLM analysis and persist
    /// the resulting graph entities and vector chunks.
    ///
    /// # Errors
    /// Propagates `ProcessingError` from any DB, LLM, or embedding step.
    pub async fn process(&self) -> Result<(), ProcessingError> {
        // Fresh connections per call; clients are not reused across invocations.
        let db_client = SurrealDbClient::new().await?;
        let openai_client = async_openai::Client::new();

        // Store TextContent
        store_item(&db_client, self.clone()).await?;

        // Get related nodes: the 3 nearest stored chunks by vector similarity.
        let closest_text_content: Vec<TextChunk> = find_items_by_vector_similarity(
            3,
            self.text.clone(),
            &db_client,
            "text_chunk".to_string(),
            &openai_client,
        )
        .await?;

        // For each nearby chunk, log the names of the knowledge entities that
        // share its source document (diagnostic only — results are not used).
        for node in closest_text_content {
            let related_nodes: Vec<KnowledgeEntity> = find_entities_by_source_id(
                node.source_id.to_owned(),
                KnowledgeEntity::table_name().to_string(),
                &db_client,
            )
            .await?;
            for related_node in related_nodes {
                info!("{:?}", related_node.name);
            }
        }

        // Rebuild indexes
        db_client.rebuild_indexes().await?;

        // Step 1: Send to LLM for analysis
        let analyser = IngressAnalyzer::new(&db_client, &openai_client);
        let analysis = analyser
            .analyze_content(&self.category, &self.instructions, &self.text)
            .await?;

        // Step 2: Convert LLM analysis to database entities
        let (entities, relationships) = analysis
            .to_database_entities(&self.id, &openai_client)
            .await?;

        // Step 3: Store in database
        self.store_in_graph_db(entities, relationships, &db_client)
            .await?;

        // Step 4: Split text and store in Vector DB
        self.store_in_vector_db(&db_client, &openai_client).await?;

        Ok(())
    }

    /// Stores entities then relationships in the graph DB, one write per
    /// item, logging each entity at debug level and a summary at info level.
    async fn store_in_graph_db(
        &self,
        entities: Vec<KnowledgeEntity>,
        relationships: Vec<KnowledgeRelationship>,
        db_client: &Surreal<Client>,
    ) -> Result<(), ProcessingError> {
        for entity in &entities {
            debug!(
                "{:?}, {:?}, {:?}",
                &entity.id, &entity.name, &entity.description
            );
            store_item(db_client, entity.clone()).await?;
        }

        for relationship in &relationships {
            debug!("{:?}", relationship);
            store_item(db_client, relationship.clone()).await?;
        }

        info!(
            "Inserted to database: {:?} entities, {:?} relationships",
            entities.len(),
            relationships.len()
        );
        Ok(())
    }

    /// Splits text and stores it in a vector database.
    ///
    /// Each chunk is 500–2000 characters; one embedding is generated per
    /// chunk, sequentially.
    async fn store_in_vector_db(
        &self,
        db_client: &Surreal<Client>,
        openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Result<(), ProcessingError> {
        let max_characters = 500..2000;
        let splitter = TextSplitter::new(max_characters);
        let chunks = splitter.chunks(self.text.as_str());

        for chunk in chunks {
            info!("Chunk: {}", chunk);
            let embedding = generate_embedding(openai_client, chunk.to_string()).await?;
            let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding);
            store_item(db_client, text_chunk).await?;
        }
        Ok(())
    }
}

View File

@@ -1,12 +1,16 @@
use lapin::{
message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue
};
use futures_lite::stream::StreamExt;
use lapin::{message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue};
use crate::models::{ingress_content::IngressContentError, ingress_object::IngressObject };
use crate::{
error::IngressConsumerError,
ingress::{
content_processor::ContentProcessor,
types::{ingress_input::IngressContentError, ingress_object::IngressObject},
},
};
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
use tracing::{info, error};
use tracing::{error, info};
/// Struct to consume messages from RabbitMQ.
pub struct RabbitMQConsumer {
@@ -26,18 +30,22 @@ impl RabbitMQConsumer {
/// * `Result<Self, RabbitMQError>` - The created client or an error.
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
let common = RabbitMQCommon::new(config).await?;
// Passively declare the exchange (it should already exist)
common.declare_exchange(config, true).await?;
// Declare queue and bind it to the channel
let queue = Self::declare_queue(&common.channel, config).await?;
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
// Initialize the consumer
let consumer = Self::initialize_consumer(&common.channel, config).await?;
Ok(Self { common, queue, consumer })
Ok(Self {
common,
queue,
consumer,
})
}
/// Sets up the consumer based on the channel and `RabbitMQConfig`.
@@ -48,7 +56,10 @@ impl RabbitMQConsumer {
///
/// # Returns
/// * `Result<Consumer, RabbitMQError>` - The initialized consumer or error
async fn initialize_consumer(channel: &Channel, config: &RabbitMQConfig) -> Result<Consumer, RabbitMQError> {
async fn initialize_consumer(
channel: &Channel,
config: &RabbitMQConfig,
) -> Result<Consumer, RabbitMQError> {
channel
.basic_consume(
&config.queue,
@@ -56,7 +67,8 @@ impl RabbitMQConsumer {
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
.await
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
}
/// Declares the queue based on the channel and `RabbitMQConfig`.
/// # Arguments
@@ -65,7 +77,10 @@ impl RabbitMQConsumer {
///
/// # Returns
/// * `Result<Queue, RabbitMQError>` - The initialized queue or error
async fn declare_queue(channel: &Channel, config: &RabbitMQConfig) -> Result<Queue, RabbitMQError> {
async fn declare_queue(
channel: &Channel,
config: &RabbitMQConfig,
) -> Result<Queue, RabbitMQError> {
channel
.queue_declare(
&config.queue,
@@ -88,7 +103,12 @@ impl RabbitMQConsumer {
///
/// # Returns
/// * `Result<(), RabbitMQError>` - Ok or error
async fn bind_queue(channel: &Channel, exchange: &str, queue: &Queue, config: &RabbitMQConfig) -> Result<(), RabbitMQError> {
async fn bind_queue(
channel: &Channel,
exchange: &str,
queue: &Queue,
config: &RabbitMQConfig,
) -> Result<(), RabbitMQError> {
channel
.queue_bind(
queue.name().as_str(),
@@ -111,7 +131,11 @@ impl RabbitMQConsumer {
/// `Delivery` - A delivery receipt, required to ack or nack the delivery.
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
// Receive the next message
let delivery = self.consumer.clone().next().await
let delivery = self
.consumer
.clone()
.next()
.await
.ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))?
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
@@ -131,7 +155,8 @@ impl RabbitMQConsumer {
/// # Returns
/// * `Result<(), RabbitMQError>` - Ok or error
pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> {
self.common.channel
self.common
.channel
.basic_ack(delivery.delivery_tag, BasicAckOptions::default())
.await
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
@@ -139,33 +164,22 @@ impl RabbitMQConsumer {
Ok(())
}
/// Function to continually consume messages as they come in
/// WIP
pub async fn process_messages(&self) -> Result<(), RabbitMQError> {
pub async fn process_messages(&self) -> Result<(), IngressConsumerError> {
loop {
match self.consume().await {
Ok((ingress, delivery)) => {
info!("Received IngressObject: {:?}", ingress);
let text_content = ingress.to_text_content().await.unwrap();
text_content.process().await.unwrap();
// Get the TextContent
let text_content = ingress.to_text_content().await?;
// Initialize ContentProcessor which handles LLM analysis and storage
let content_processor = ContentProcessor::new().await?;
// Begin processing of TextContent
content_processor.process(&text_content).await?;
// Remove from queue
self.ack_delivery(delivery).await?;
// Process the IngressContent
// match self.handle_ingress_content(&ingress).await {
// Ok(_) => {
// info!("Successfully handled IngressContent");
// // Acknowledge the message
// if let Err(e) = self.ack_delivery(delivery).await {
// error!("Failed to acknowledge message: {:?}", e);
// }
// },
// Err(e) => {
// error!("Failed to handle IngressContent: {:?}", e);
// // For now, we'll acknowledge to remove it from the queue. Change to nack?
// if let Err(ack_err) = self.ack_delivery(delivery).await {
// error!("Failed to acknowledge message after handling error: {:?}", ack_err);
// }
// }
// }
}
Err(RabbitMQError::ConsumeError(e)) => {
error!("Error consuming message: {}", e);
@@ -182,7 +196,10 @@ impl RabbitMQConsumer {
Ok(())
}
pub async fn handle_ingress_content(&self, ingress: &IngressObject) -> Result<(), IngressContentError> {
pub async fn handle_ingress_content(
&self,
ingress: &IngressObject,
) -> Result<(), IngressContentError> {
info!("Processing IngressContent: {:?}", ingress);
unimplemented!()

View File

@@ -1,13 +1,16 @@
pub mod publisher;
pub mod consumer;
pub mod publisher;
use axum::async_trait;
use lapin::{
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties, ExchangeKind
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties,
ExchangeKind,
};
use thiserror::Error;
use tracing::debug;
use crate::error::ProcessingError;
/// Possible errors related to RabbitMQ operations.
#[derive(Error, Debug)]
pub enum RabbitMQError {
@@ -25,6 +28,8 @@ pub enum RabbitMQError {
InitializeConsumerError(String),
#[error("Queue error: {0}")]
QueueError(String),
#[error("Processing error: {0}")]
ProcessingError(#[from] ProcessingError),
}
/// Struct containing the information required to set up a client and connection.
@@ -42,18 +47,21 @@ pub struct RabbitMQCommon {
pub channel: Channel,
}
/// Defines the behavior for RabbitMQCommon client operations.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait RabbitMQCommonTrait: Send + Sync {
async fn create_connection(config: &RabbitMQConfig) -> Result<Connection, RabbitMQError>;
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError>;
}
async fn declare_exchange(
&self,
config: &RabbitMQConfig,
passive: bool,
) -> Result<(), RabbitMQError>;
}
impl RabbitMQCommon {
/// Sets up a new RabbitMQ client or error
///
///
/// # Arguments
/// * `RabbitMQConfig` - Configuration object with required information
///
@@ -62,7 +70,10 @@ impl RabbitMQCommon {
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
let connection = Self::create_connection(config).await?;
let channel = connection.create_channel().await?;
Ok(Self { connection, channel })
Ok(Self {
connection,
channel,
})
}
}
@@ -77,7 +88,11 @@ impl RabbitMQCommonTrait for RabbitMQCommon {
}
/// Function to declare the exchange required
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError> {
async fn declare_exchange(
&self,
config: &RabbitMQConfig,
passive: bool,
) -> Result<(), RabbitMQError> {
debug!("Declaring exchange");
self.channel
.exchange_declare(

View File

@@ -1,11 +1,9 @@
use lapin::{
options::*, publisher_confirm::Confirmation, BasicProperties,
};
use lapin::{options::*, publisher_confirm::Confirmation, BasicProperties};
use crate::models::ingress_object::IngressObject;
use crate::ingress::types::ingress_object::IngressObject;
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
use tracing::{info, error};
use tracing::{error, info};
/// Struct to publish messages to RabbitMQ.
pub struct RabbitMQProducer {
@@ -27,7 +25,7 @@ impl RabbitMQProducer {
let common = RabbitMQCommon::new(config).await?;
common.declare_exchange(config, false).await?;
Ok(Self {
Ok(Self {
common,
exchange_name: config.exchange.clone(),
routing_key: config.routing_key.clone(),
@@ -39,19 +37,23 @@ impl RabbitMQProducer {
/// # Arguments
/// * `self` - Reference to self
/// * `ingress_object` - A initialized IngressObject
///
///
/// # Returns
/// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error
pub async fn publish(&self, ingress_object: &IngressObject) -> Result<Confirmation, RabbitMQError> {
pub async fn publish(
&self,
ingress_object: &IngressObject,
) -> Result<Confirmation, RabbitMQError> {
// Serialize IngressObject to JSON
let payload = serde_json::to_vec(ingress_object)
.map_err(|e| {
error!("Serialization Error: {}", e);
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
})?;
let payload = serde_json::to_vec(ingress_object).map_err(|e| {
error!("Serialization Error: {}", e);
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
})?;
// Publish the serialized payload to RabbitMQ
let confirmation = self.common.channel
let confirmation = self
.common
.channel
.basic_publish(
&self.exchange_name,
&self.routing_key,
@@ -69,9 +71,12 @@ impl RabbitMQProducer {
error!("Publish Confirmation Error: {}", e);
RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e))
})?;
info!("Published IngressObject to exchange '{}' with routing key '{}'", self.exchange_name, self.routing_key);
info!(
"Published IngressObject to exchange '{}' with routing key '{}'",
self.exchange_name, self.routing_key
);
Ok(confirmation)
}
}

View File

@@ -1,16 +1,14 @@
use std::sync::Arc;
use axum::{
extract::Path, response::IntoResponse, Extension, Json
use crate::storage::{
db::SurrealDbClient,
types::file_info::{FileError, FileInfo},
};
use axum_typed_multipart::{TypedMultipart, FieldData, TryFromMultipart};
use axum::{extract::Path, response::IntoResponse, Extension, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde_json::json;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tracing::info;
use uuid::Uuid;
use crate::{
models::file_info::{FileError, FileInfo},
surrealdb::SurrealDbClient,
};
#[derive(Debug, TryFromMultipart)]
pub struct FileUploadRequest {
@@ -19,7 +17,7 @@ pub struct FileUploadRequest {
}
/// Handler to upload a new file.
///
///
/// Route: POST /file
pub async fn upload_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
@@ -40,13 +38,12 @@ pub async fn upload_handler(
info!("File uploaded successfully: {:?}", file_info);
// Return the response with HTTP 200
Ok((axum::http::StatusCode::OK, Json(response)))
}
/// Handler to retrieve file information by UUID.
///
///
/// Route: GET /file/:uuid
pub async fn get_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
@@ -73,7 +70,7 @@ pub async fn get_file_handler(
}
/// Handler to update an existing file by UUID.
///
///
/// Route: PUT /file/:uuid
pub async fn update_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
@@ -101,7 +98,7 @@ pub async fn update_file_handler(
}
/// Handler to delete a file by UUID.
///
///
/// Route: DELETE /file/:uuid
pub async fn delete_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,

View File

@@ -1,7 +1,11 @@
use std::sync::Arc;
use crate::{
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
rabbitmq::publisher::RabbitMQProducer,
storage::db::SurrealDbClient,
};
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
use std::sync::Arc;
use tracing::{error, info};
use crate::{models::ingress_content::{create_ingress_objects, IngressInput}, rabbitmq::publisher::RabbitMQProducer, surrealdb::SurrealDbClient};
pub async fn ingress_handler(
Extension(producer): Extension<Arc<RabbitMQProducer>>,
@@ -23,7 +27,7 @@ pub async fn ingress_handler(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to publish message",
)
.into_response();
.into_response();
}
}
}
@@ -31,11 +35,7 @@ pub async fn ingress_handler(
}
Err(e) => {
error!("Failed to process input: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to process input",
)
.into_response()
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to process input").into_response()
}
}
}

View File

@@ -12,7 +12,7 @@ use zettle_db::{
ingress::ingress_handler,
queue_length::queue_length_handler,
},
surrealdb::SurrealDbClient,
storage::db::SurrealDbClient,
};
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]

View File

@@ -1,8 +1,67 @@
use surrealdb::{engine::remote::ws::Client, Surreal};
use crate::error::ProcessingError;
use super::types::StoredObject;
use std::ops::Deref;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Error, Surreal,
};
/// Thin wrapper around a SurrealDB WebSocket connection; dereferences to the
/// underlying `Surreal<Client>` so its query API is usable directly.
#[derive(Clone)]
pub struct SurrealDbClient {
    // Underlying SurrealDB handle; also reachable via `Deref`.
    pub client: Surreal<Client>,
}

impl SurrealDbClient {
    /// # Initialize a new database client
    ///
    /// Connects over WebSocket to `127.0.0.1:8000`, signs in as root, and
    /// selects the `test` namespace and `test` database.
    ///
    /// NOTE(review): address and root credentials are hard-coded — consider
    /// loading these from configuration/environment before production use.
    ///
    /// # Returns
    /// * `SurrealDbClient` initialized
    pub async fn new() -> Result<Self, Error> {
        let db = Surreal::new::<Ws>("127.0.0.1:8000").await?;

        // Sign in to database
        db.signin(Root {
            username: "root_user",
            password: "root_password",
        })
        .await?;

        // Set namespace
        db.use_ns("test").use_db("test").await?;

        Ok(SurrealDbClient { client: db })
    }

    /// Rebuilds the vector indexes on `text_chunk` and `knowledge_entity`
    /// so newly inserted rows become searchable.
    pub async fn rebuild_indexes(&self) -> Result<(), Error> {
        self.client
            .query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
            .await?;
        self.client
            .query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
            .await?;
        Ok(())
    }

    /// Deletes every record in the table associated with `T`
    /// (via `StoredObject::table_name`). The deleted rows are discarded.
    pub async fn drop_table<T>(&self) -> Result<(), Error>
    where
        T: StoredObject + Send + Sync + 'static,
    {
        let _deleted: Vec<T> = self.client.delete(T::table_name()).await?;
        Ok(())
    }
}

/// Lets callers invoke `Surreal<Client>` methods directly on the wrapper.
impl Deref for SurrealDbClient {
    type Target = Surreal<Client>;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
///

View File

@@ -1,22 +1,26 @@
use axum::{http::StatusCode, response::{IntoResponse, Response}, Json};
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use axum_typed_multipart::FieldData;
use mime_guess::from_path;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sha2::{Digest, Sha256};
use surrealdb::RecordId;
use std::{
io::{BufReader, Read},
path::{Path, PathBuf},
};
use surrealdb::RecordId;
use tempfile::NamedTempFile;
use thiserror::Error;
use tracing::{debug, info};
use uuid::Uuid;
use crate::surrealdb::SurrealDbClient;
use crate::storage::db::SurrealDbClient;
#[derive(Debug,Deserialize)]
#[derive(Debug, Deserialize)]
struct Record {
#[allow(dead_code)]
id: RecordId,
@@ -72,7 +76,6 @@ pub enum FileError {
#[error("Deserialization error: {0}")]
DeserializationError(String),
// Add more error variants as needed.
}
@@ -82,16 +85,30 @@ impl IntoResponse for FileError {
FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"),
FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"),
FileError::UnsupportedMime(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type"),
FileError::UnsupportedMime(_) => {
(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type")
}
FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"),
FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"),
FileError::HashCollision => (StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected"),
FileError::HashCollision => {
(StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected")
}
FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"),
FileError::MissingFileName => (StatusCode::BAD_REQUEST, "Missing file name in metadata"),
FileError::PersistError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file"),
FileError::SerializationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"),
FileError::DeserializationError(_) => (StatusCode::BAD_REQUEST, "Deserialization error"),
FileError::SurrealError(_) =>(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"),
FileError::MissingFileName => {
(StatusCode::BAD_REQUEST, "Missing file name in metadata")
}
FileError::PersistError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file")
}
FileError::SerializationError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
}
FileError::DeserializationError(_) => {
(StatusCode::BAD_REQUEST, "Deserialization error")
}
FileError::SurrealError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
}
};
let body = Json(json!({
@@ -163,7 +180,11 @@ impl FileInfo {
///
/// # Returns
/// * `Result<FileInfo, FileError>` - The updated `FileInfo` or an error.
pub async fn update(uuid: Uuid, new_field_data: FieldData<NamedTempFile>, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
pub async fn update(
uuid: Uuid,
new_field_data: FieldData<NamedTempFile>,
db_client: &SurrealDbClient,
) -> Result<FileInfo, FileError> {
let new_file = new_field_data.contents;
let new_metadata = new_field_data.metadata;
@@ -184,7 +205,8 @@ impl FileInfo {
let sanitized_new_file_name = sanitize_file_name(&new_file_name);
// Persist the new file
let new_persisted_path = Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
let new_persisted_path =
Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
// Guess the new MIME type
let new_mime_type = Self::guess_mime_type(&new_persisted_path);
@@ -202,7 +224,7 @@ impl FileInfo {
};
// Save the new item
Self::create_record(&updated_file_info,db_client).await?;
Self::create_record(&updated_file_info, db_client).await?;
// Optionally, delete the old file from the filesystem if it's no longer referenced
// This requires reference counting or checking if other FileInfo entries point to the same SHA
@@ -221,26 +243,35 @@ impl FileInfo {
/// * `Result<(), FileError>` - Empty result or an error.
pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> {
// Retrieve FileInfo to get SHA256 and path
let file_info = Self::get_by_uuid(uuid, db_client).await?;
let file_info = Self::get_by_uuid(uuid, db_client).await?;
// Delete the file from the filesystem
let file_path = Path::new(&file_info.path);
if file_path.exists() {
tokio::fs::remove_file(file_path).await.map_err(FileError::Io)?;
tokio::fs::remove_file(file_path)
.await
.map_err(FileError::Io)?;
info!("Deleted file at path: {}", file_info.path);
} else {
info!("File path does not exist, skipping deletion: {}", file_info.path);
info!(
"File path does not exist, skipping deletion: {}",
file_info.path
);
}
// Delete the FileInfo from database
Self::delete_record(&file_info.sha256, db_client).await?;
// Remove the UUID directory if empty
let uuid_dir = file_path.parent().ok_or(FileError::FileNotFound(uuid.to_string()))?;
let uuid_dir = file_path
.parent()
.ok_or(FileError::FileNotFound(uuid.to_string()))?;
if uuid_dir.exists() {
let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?;
if entries.next_entry().await?.is_none() {
tokio::fs::remove_dir(uuid_dir).await.map_err(FileError::Io)?;
tokio::fs::remove_dir(uuid_dir)
.await
.map_err(FileError::Io)?;
info!("Deleted empty UUID directory: {:?}", uuid_dir);
}
}
@@ -257,19 +288,26 @@ impl FileInfo {
///
/// # Returns
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
async fn persist_file(uuid: &Uuid, file: NamedTempFile, file_name: &str) -> Result<PathBuf, FileError> {
async fn persist_file(
uuid: &Uuid,
file: NamedTempFile,
file_name: &str,
) -> Result<PathBuf, FileError> {
let base_dir = Path::new("./data");
let uuid_dir = base_dir.join(uuid.to_string());
// Create the UUID directory if it doesn't exist
tokio::fs::create_dir_all(&uuid_dir).await.map_err(FileError::Io)?;
tokio::fs::create_dir_all(&uuid_dir)
.await
.map_err(FileError::Io)?;
// Define the final file path
let final_path = uuid_dir.join(file_name);
info!("Final path: {:?}", final_path);
// Persist the temporary file to the final path
file.persist(&final_path).map_err(|e| FileError::PersistError(e.to_string()))?;
file.persist(&final_path)
.map_err(|e| FileError::PersistError(e.to_string()))?;
info!("Persisted file to {:?}", final_path);
@@ -313,7 +351,6 @@ impl FileInfo {
.to_string()
}
/// Creates a new record in SurrealDB for the given `FileInfo`.
///
/// # Arguments
@@ -323,16 +360,19 @@ impl FileInfo {
/// # Returns
/// * `Result<(), FileError>` - Empty result or an error.
async fn create_record(file_info: &FileInfo, db_client: &SurrealDbClient) -> Result<(), FileError> {
async fn create_record(
file_info: &FileInfo,
db_client: &SurrealDbClient,
) -> Result<(), FileError> {
// Create the record
let _created: Option<Record> = db_client
.client
.create(("file", &file_info.uuid ))
.create(("file", &file_info.uuid))
.content(file_info.clone())
.await?;
debug!("{:?}",_created);
debug!("{:?}", _created);
info!("Created FileInfo record with SHA256: {}", file_info.sha256);
Ok(())
@@ -346,11 +386,17 @@ impl FileInfo {
///
/// # Returns
/// * `Result<FileInfo, FileError>` - The `FileInfo` or `Error` if not found.
pub async fn get_by_uuid(uuid: Uuid, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
pub async fn get_by_uuid(
uuid: Uuid,
db_client: &SurrealDbClient,
) -> Result<FileInfo, FileError> {
let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid);
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
response.into_iter().next().ok_or(FileError::FileNotFound(uuid.to_string()))
response
.into_iter()
.next()
.ok_or(FileError::FileNotFound(uuid.to_string()))
}
/// Retrieves a `FileInfo` by SHA256.
@@ -367,7 +413,11 @@ impl FileInfo {
debug!("{:?}", response);
response.into_iter().next().ok_or(FileError::FileNotFound(sha256.to_string())) }
response
.into_iter()
.next()
.ok_or(FileError::FileNotFound(sha256.to_string()))
}
/// Deletes a `FileInfo` record by SHA256.
///
@@ -381,10 +431,7 @@ impl FileInfo {
let table = "file";
let primary_key = sha256;
let _created: Option<Record> = db_client
.client
.delete((table, primary_key))
.await?;
let _created: Option<Record> = db_client.client.delete((table, primary_key)).await?;
info!("Deleted FileInfo record with SHA256: {}", sha256);
@@ -395,7 +442,14 @@ impl FileInfo {
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
fn sanitize_file_name(file_name: &str) -> String {
file_name.chars()
.map(|c| if c.is_ascii_alphanumeric() || c == '.' || c == '_' { c } else { '_' })
file_name
.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
c
} else {
'_'
}
})
.collect()
}

View File

@@ -1,5 +1,6 @@
use axum::async_trait;
use serde::{Deserialize, Serialize};
pub mod file_info;
pub mod knowledge_entity;
pub mod knowledge_relationship;
pub mod text_chunk;

View File

@@ -1,8 +1,9 @@
use uuid::Uuid;
use crate::models::file_info::FileInfo;
use crate::stored_object;
use super::file_info::FileInfo;
stored_object!(TextContent, "text_content", {
text: String,
file_info: Option<FileInfo>,

View File

@@ -1,63 +0,0 @@
use std::ops::Deref;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Error, Surreal,
};
use crate::storage::types::StoredObject;
/// Shared handle to a SurrealDB WebSocket connection.
///
/// Wraps the raw [`Surreal<Client>`] so the rest of the crate can pass one
/// client value around; it also implements `Deref` to the inner client, so
/// call sites can invoke the SurrealDB API directly on this wrapper.
#[derive(Clone)]
pub struct SurrealDbClient {
    // The underlying SurrealDB connection. Public so existing call sites can
    // reach `db_client.client` directly (many in this crate do).
    pub client: Surreal<Client>,
}
impl SurrealDbClient {
    /// Initializes a new database client.
    ///
    /// Connects to the local SurrealDB instance over WebSocket, signs in
    /// with root credentials, and scopes the session to the `test`
    /// namespace and database.
    ///
    /// # Returns
    /// * `Result<Self, Error>` - the initialized client, or a connection /
    ///   authentication error from the driver.
    pub async fn new() -> Result<Self, Error> {
        let client = Surreal::new::<Ws>("127.0.0.1:8000").await?;

        // Sign in to database.
        // NOTE(review): credentials and endpoint are hard-coded; presumably
        // fine for local development — confirm before deploying.
        let credentials = Root {
            username: "root_user",
            password: "root_password",
        };
        client.signin(credentials).await?;

        // Set namespace and database for all subsequent queries.
        client.use_ns("test").use_db("test").await?;

        Ok(Self { client })
    }

    /// Rebuilds the embedding indexes on `text_chunk` and
    /// `knowledge_entity`, if they exist. Each statement is issued as its
    /// own query, in order.
    pub async fn rebuild_indexes(&self) -> Result<(), Error> {
        let statements = [
            "REBUILD INDEX IF EXISTS idx_embedding ON text_chunk",
            "REBUILD INDEX IF EXISTS embeddings ON knowledge_entity",
        ];
        for statement in statements {
            self.client.query(statement).await?;
        }
        Ok(())
    }

    /// Deletes every record in the table backing `T` (as reported by
    /// `T::table_name()`). The removed records are discarded.
    pub async fn drop_table<T>(&self) -> Result<(), Error>
    where
        T: StoredObject + Send + Sync + 'static,
    {
        let _removed: Vec<T> = self.client.delete(T::table_name()).await?;
        Ok(())
    }
}
// Let a `SurrealDbClient` be used anywhere a `&Surreal<Client>` is expected,
// so call sites can invoke driver methods on the wrapper directly instead of
// going through `.client`.
impl Deref for SurrealDbClient {
    type Target = Surreal<Client>;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}