feat: refactoring complete?

This commit is contained in:
Per Stark
2024-11-21 21:23:49 +01:00
parent 94f328e542
commit 1e789e1153
27 changed files with 428 additions and 338 deletions

View File

@@ -1 +0,0 @@
pub mod ingress;

View File

@@ -2,6 +2,8 @@ use async_openai::error::OpenAIError;
use thiserror::Error; use thiserror::Error;
use tokio::task::JoinError; use tokio::task::JoinError;
use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError};
/// Error types for processing `TextContent`. /// Error types for processing `TextContent`.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ProcessingError { pub enum ProcessingError {
@@ -23,3 +25,15 @@ pub enum ProcessingError {
#[error("Task join error: {0}")] #[error("Task join error: {0}")]
JoinError(#[from] JoinError), JoinError(#[from] JoinError),
} }
/// Errors that can occur while consuming and processing ingress messages.
///
/// Aggregates the lower-level error types so the consumer loop can use `?`
/// uniformly across transport, processing, and content-conversion calls.
#[derive(Error, Debug)]
pub enum IngressConsumerError {
    /// Failure in the RabbitMQ transport layer (connect/consume/ack).
    #[error("RabbitMQ error: {0}")]
    RabbitMQ(#[from] RabbitMQError),
    /// Failure while analyzing or storing the content.
    #[error("Processing error: {0}")]
    Processing(#[from] ProcessingError),
    /// Failure while converting an `IngressObject` into `TextContent`.
    #[error("Ingress content error: {0}")]
    IngressContent(#[from] IngressContentError),
}

View File

@@ -1,9 +1,6 @@
use crate::{ use crate::{
analysis::ingress::{
prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
types::llm_analysis_result::LLMGraphAnalysisResult,
},
error::ProcessingError, error::ProcessingError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::vector::find_items_by_vector_similarity, retrieval::vector::find_items_by_vector_similarity,
storage::types::{knowledge_entity::KnowledgeEntity, StoredObject}, storage::types::{knowledge_entity::KnowledgeEntity, StoredObject},
}; };
@@ -15,7 +12,9 @@ use async_openai::types::{
use serde_json::json; use serde_json::json;
use surrealdb::engine::remote::ws::Client; use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal; use surrealdb::Surreal;
use tracing::{debug, instrument}; use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
pub struct IngressAnalyzer<'a> { pub struct IngressAnalyzer<'a> {
db_client: &'a Surreal<Client>, db_client: &'a Surreal<Client>,
@@ -33,7 +32,6 @@ impl<'a> IngressAnalyzer<'a> {
} }
} }
#[instrument(skip(self))]
pub async fn analyze_content( pub async fn analyze_content(
&self, &self,
category: &str, category: &str,
@@ -48,7 +46,6 @@ impl<'a> IngressAnalyzer<'a> {
self.perform_analysis(llm_request).await self.perform_analysis(llm_request).await
} }
#[instrument(skip(self))]
async fn find_similar_entities( async fn find_similar_entities(
&self, &self,
category: &str, category: &str,
@@ -70,7 +67,6 @@ impl<'a> IngressAnalyzer<'a> {
.await .await
} }
#[instrument(skip(self))]
fn prepare_llm_request( fn prepare_llm_request(
&self, &self,
category: &str, category: &str,
@@ -108,7 +104,7 @@ impl<'a> IngressAnalyzer<'a> {
}; };
CreateChatCompletionRequestArgs::default() CreateChatCompletionRequestArgs::default()
.model("gpt-4-mini") .model("gpt-4o-mini")
.temperature(0.2) .temperature(0.2)
.max_tokens(2048u32) .max_tokens(2048u32)
.messages([ .messages([
@@ -120,7 +116,6 @@ impl<'a> IngressAnalyzer<'a> {
.map_err(|e| ProcessingError::LLMParsingError(e.to_string())) .map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
} }
#[instrument(skip(self, request))]
async fn perform_analysis( async fn perform_analysis(
&self, &self,
request: CreateChatCompletionRequest, request: CreateChatCompletionRequest,

View File

@@ -5,14 +5,15 @@ use tokio::task;
use crate::{ use crate::{
error::ProcessingError, error::ProcessingError,
models::graph_entities::GraphMapper,
storage::types::{ storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship, knowledge_relationship::KnowledgeRelationship,
}, },
utils::embedding::generate_embedding, utils::embedding::generate_embedding,
}; };
use futures::future::try_join_all; // For future parallelization use futures::future::try_join_all;
use super::graph_mapper::GraphMapper; // For future parallelization
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity { pub struct LLMKnowledgeEntity {

View File

@@ -1 +1,2 @@
pub mod graph_mapper;
pub mod llm_analysis_result; pub mod llm_analysis_result;

View File

@@ -0,0 +1,119 @@
use text_splitter::TextSplitter;
use tracing::{debug, info};
use crate::{
error::ProcessingError,
retrieval::vector::find_items_by_vector_similarity,
storage::{
db::{store_item, SurrealDbClient},
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
},
},
utils::embedding::generate_embedding,
};
use super::analysis::{
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
};
/// Orchestrates the ingress pipeline for a piece of text content:
/// storing the raw content, running LLM-based semantic analysis,
/// persisting graph entities, and indexing vector chunks.
pub struct ContentProcessor {
    /// Handle to the SurrealDB backing store (graph + vector data).
    db_client: SurrealDbClient,
    /// OpenAI client used for analysis requests and embedding generation.
    openai_client: async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl ContentProcessor {
pub async fn new() -> Result<Self, ProcessingError> {
Ok(Self {
db_client: SurrealDbClient::new().await?,
openai_client: async_openai::Client::new(),
})
}
pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> {
// Store original content
store_item(&self.db_client, content.clone()).await?;
// Process in parallel where possible
let (analysis, _similar_chunks) = tokio::try_join!(
self.perform_semantic_analysis(content),
self.find_similar_content(content),
)?;
// Convert and store entities
let (entities, relationships) = analysis
.to_database_entities(&content.id, &self.openai_client)
.await?;
// Store everything
tokio::try_join!(
self.store_graph_entities(entities, relationships),
self.store_vector_chunks(content),
)?;
self.db_client.rebuild_indexes().await?;
Ok(())
}
async fn perform_semantic_analysis(
&self,
content: &TextContent,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
analyser
.analyze_content(&content.category, &content.instructions, &content.text)
.await
}
async fn find_similar_content(
&self,
content: &TextContent,
) -> Result<Vec<TextChunk>, ProcessingError> {
find_items_by_vector_similarity(
3,
content.text.clone(),
&self.db_client,
"text_chunk".to_string(),
&self.openai_client,
)
.await
}
async fn store_graph_entities(
&self,
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
) -> Result<(), ProcessingError> {
for entity in &entities {
debug!("Storing entity: {:?}", entity);
store_item(&self.db_client, entity.clone()).await?;
}
for relationship in &relationships {
debug!("Storing relationship: {:?}", relationship);
store_item(&self.db_client, relationship.clone()).await?;
}
info!(
"Stored {} entities and {} relationships",
entities.len(),
relationships.len()
);
Ok(())
}
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> {
let splitter = TextSplitter::new(500..2000);
let chunks = splitter.chunks(&content.text);
// Could potentially process chunks in parallel with a bounded concurrent limit
for chunk in chunks {
let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?;
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
store_item(&self.db_client, text_chunk).await?;
}
Ok(())
}
}

3
src/ingress/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod analysis;
pub mod content_processor;
pub mod types;

View File

@@ -1,10 +1,10 @@
use super::ingress_object::IngressObject;
use crate::storage::{db::SurrealDbClient, types::file_info::FileInfo};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use tracing::info; use tracing::info;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::surrealdb::SurrealDbClient;
use super::{file_info::FileInfo, ingress_object::IngressObject };
/// Struct defining the expected body when ingressing content. /// Struct defining the expected body when ingressing content.
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@@ -105,4 +105,3 @@ pub async fn create_ingress_objects(
Ok(object_list) Ok(object_list)
} }

View File

@@ -1,8 +1,8 @@
use super::ingress_content::IngressContentError; use crate::storage::types::{file_info::FileInfo, text_content::TextContent};
use crate::models::file_info::FileInfo;
use crate::storage::types::text_content::TextContent;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::ingress_input::IngressContentError;
/// Knowledge object type, containing the content or reference to it, as well as metadata /// Knowledge object type, containing the content or reference to it, as well as metadata
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject { pub enum IngressObject {

2
src/ingress/types/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod ingress_input;
pub mod ingress_object;

View File

@@ -1,9 +1,7 @@
pub mod analysis;
pub mod error; pub mod error;
pub mod models; pub mod ingress;
pub mod rabbitmq; pub mod rabbitmq;
pub mod retrieval; pub mod retrieval;
pub mod routes; pub mod routes;
pub mod storage; pub mod storage;
pub mod surrealdb;
pub mod utils; pub mod utils;

View File

@@ -1,5 +0,0 @@
pub mod file_info;
pub mod graph_entities;
pub mod ingress_content;
pub mod ingress_object;
pub mod text_content;

View File

@@ -1,122 +0,0 @@
use crate::analysis::ingress::ingress_analyser::IngressAnalyzer;
use crate::retrieval::graph::find_entities_by_source_id;
use crate::retrieval::vector::find_items_by_vector_similarity;
use crate::storage::db::store_item;
use crate::storage::types::knowledge_entity::KnowledgeEntity;
use crate::storage::types::knowledge_relationship::KnowledgeRelationship;
use crate::storage::types::text_chunk::TextChunk;
use crate::storage::types::text_content::TextContent;
use crate::storage::types::StoredObject;
use crate::utils::embedding::generate_embedding;
use crate::{error::ProcessingError, surrealdb::SurrealDbClient};
use surrealdb::{engine::remote::ws::Client, Surreal};
use text_splitter::TextSplitter;
use tracing::{debug, info};
/// Legacy processing pipeline implemented directly on `TextContent`.
/// (This commit removes it in favor of `ContentProcessor`.)
impl TextContent {
    /// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB.
    ///
    /// Creates its own DB and OpenAI clients per call, then runs the pipeline
    /// strictly sequentially: store raw content → log related nodes →
    /// rebuild indexes → LLM analysis → store entities → store vector chunks.
    pub async fn process(&self) -> Result<(), ProcessingError> {
        let db_client = SurrealDbClient::new().await?;
        let openai_client = async_openai::Client::new();
        // Store TextContent
        store_item(&db_client, self.clone()).await?;
        // Get related nodes (top 3 nearest chunks from the "text_chunk" table)
        let closest_text_content: Vec<TextChunk> = find_items_by_vector_similarity(
            3,
            self.text.clone(),
            &db_client,
            "text_chunk".to_string(),
            &openai_client,
        )
        .await?;
        // For each nearby chunk, look up entities sharing its source and log
        // their names. Results are only logged, not used further.
        for node in closest_text_content {
            let related_nodes: Vec<KnowledgeEntity> = find_entities_by_source_id(
                node.source_id.to_owned(),
                KnowledgeEntity::table_name().to_string(),
                &db_client,
            )
            .await?;
            for related_node in related_nodes {
                info!("{:?}", related_node.name);
            }
        }
        // Rebuild indexes
        db_client.rebuild_indexes().await?;
        // Step 1: Send to LLM for analysis
        let analyser = IngressAnalyzer::new(&db_client, &openai_client);
        let analysis = analyser
            .analyze_content(&self.category, &self.instructions, &self.text)
            .await?;
        // Step 2: Convert LLM analysis to database entities
        let (entities, relationships) = analysis
            .to_database_entities(&self.id, &openai_client)
            .await?;
        // Step 3: Store in database
        self.store_in_graph_db(entities, relationships, &db_client)
            .await?;
        // Step 4: Split text and store in Vector DB
        self.store_in_vector_db(&db_client, &openai_client).await?;
        Ok(())
    }

    /// Stores the analyzed entities and relationships sequentially, logging
    /// each item at debug level and a summary at info level.
    async fn store_in_graph_db(
        &self,
        entities: Vec<KnowledgeEntity>,
        relationships: Vec<KnowledgeRelationship>,
        db_client: &Surreal<Client>,
    ) -> Result<(), ProcessingError> {
        for entity in &entities {
            debug!(
                "{:?}, {:?}, {:?}",
                &entity.id, &entity.name, &entity.description
            );
            store_item(db_client, entity.clone()).await?;
        }
        for relationship in &relationships {
            debug!("{:?}", relationship);
            store_item(db_client, relationship.clone()).await?;
        }
        info!(
            "Inserted to database: {:?} entities, {:?} relationships",
            entities.len(),
            relationships.len()
        );
        Ok(())
    }

    /// Splits text and stores it in a vector database.
    ///
    /// Each chunk (500–2000 characters) gets its own embedding and row.
    async fn store_in_vector_db(
        &self,
        db_client: &Surreal<Client>,
        openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Result<(), ProcessingError> {
        let max_characters = 500..2000;
        let splitter = TextSplitter::new(max_characters);
        let chunks = splitter.chunks(self.text.as_str());
        for chunk in chunks {
            info!("Chunk: {}", chunk);
            let embedding = generate_embedding(openai_client, chunk.to_string()).await?;
            let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding);
            store_item(db_client, text_chunk).await?;
        }
        Ok(())
    }
}

View File

@@ -1,12 +1,16 @@
use lapin::{
message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue
};
use futures_lite::stream::StreamExt; use futures_lite::stream::StreamExt;
use lapin::{message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue};
use crate::models::{ingress_content::IngressContentError, ingress_object::IngressObject }; use crate::{
error::IngressConsumerError,
ingress::{
content_processor::ContentProcessor,
types::{ingress_input::IngressContentError, ingress_object::IngressObject},
},
};
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError}; use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
use tracing::{info, error}; use tracing::{error, info};
/// Struct to consume messages from RabbitMQ. /// Struct to consume messages from RabbitMQ.
pub struct RabbitMQConsumer { pub struct RabbitMQConsumer {
@@ -26,18 +30,22 @@ impl RabbitMQConsumer {
/// * `Result<Self, RabbitMQError>` - The created client or an error. /// * `Result<Self, RabbitMQError>` - The created client or an error.
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> { pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
let common = RabbitMQCommon::new(config).await?; let common = RabbitMQCommon::new(config).await?;
// Passively declare the exchange (it should already exist) // Passively declare the exchange (it should already exist)
common.declare_exchange(config, true).await?; common.declare_exchange(config, true).await?;
// Declare queue and bind it to the channel // Declare queue and bind it to the channel
let queue = Self::declare_queue(&common.channel, config).await?; let queue = Self::declare_queue(&common.channel, config).await?;
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?; Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
// Initialize the consumer // Initialize the consumer
let consumer = Self::initialize_consumer(&common.channel, config).await?; let consumer = Self::initialize_consumer(&common.channel, config).await?;
Ok(Self { common, queue, consumer }) Ok(Self {
common,
queue,
consumer,
})
} }
/// Sets up the consumer based on the channel and `RabbitMQConfig`. /// Sets up the consumer based on the channel and `RabbitMQConfig`.
@@ -48,7 +56,10 @@ impl RabbitMQConsumer {
/// ///
/// # Returns /// # Returns
/// * `Result<Consumer, RabbitMQError>` - The initialized consumer or error /// * `Result<Consumer, RabbitMQError>` - The initialized consumer or error
async fn initialize_consumer(channel: &Channel, config: &RabbitMQConfig) -> Result<Consumer, RabbitMQError> { async fn initialize_consumer(
channel: &Channel,
config: &RabbitMQConfig,
) -> Result<Consumer, RabbitMQError> {
channel channel
.basic_consume( .basic_consume(
&config.queue, &config.queue,
@@ -56,7 +67,8 @@ impl RabbitMQConsumer {
BasicConsumeOptions::default(), BasicConsumeOptions::default(),
FieldTable::default(), FieldTable::default(),
) )
.await.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string())) .await
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
} }
/// Declares the queue based on the channel and `RabbitMQConfig`. /// Declares the queue based on the channel and `RabbitMQConfig`.
/// # Arguments /// # Arguments
@@ -65,7 +77,10 @@ impl RabbitMQConsumer {
/// ///
/// # Returns /// # Returns
/// * `Result<Queue, RabbitMQError>` - The initialized queue or error /// * `Result<Queue, RabbitMQError>` - The initialized queue or error
async fn declare_queue(channel: &Channel, config: &RabbitMQConfig) -> Result<Queue, RabbitMQError> { async fn declare_queue(
channel: &Channel,
config: &RabbitMQConfig,
) -> Result<Queue, RabbitMQError> {
channel channel
.queue_declare( .queue_declare(
&config.queue, &config.queue,
@@ -88,7 +103,12 @@ impl RabbitMQConsumer {
/// ///
/// # Returns /// # Returns
/// * `Result<(), RabbitMQError>` - Ok or error /// * `Result<(), RabbitMQError>` - Ok or error
async fn bind_queue(channel: &Channel, exchange: &str, queue: &Queue, config: &RabbitMQConfig) -> Result<(), RabbitMQError> { async fn bind_queue(
channel: &Channel,
exchange: &str,
queue: &Queue,
config: &RabbitMQConfig,
) -> Result<(), RabbitMQError> {
channel channel
.queue_bind( .queue_bind(
queue.name().as_str(), queue.name().as_str(),
@@ -111,7 +131,11 @@ impl RabbitMQConsumer {
/// `Delivery` - A delivery receipt, required to ack or nack the delivery. /// `Delivery` - A delivery receipt, required to ack or nack the delivery.
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> { pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
// Receive the next message // Receive the next message
let delivery = self.consumer.clone().next().await let delivery = self
.consumer
.clone()
.next()
.await
.ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))? .ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))?
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?; .map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
@@ -131,7 +155,8 @@ impl RabbitMQConsumer {
/// # Returns /// # Returns
/// * `Result<(), RabbitMQError>` - Ok or error /// * `Result<(), RabbitMQError>` - Ok or error
pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> { pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> {
self.common.channel self.common
.channel
.basic_ack(delivery.delivery_tag, BasicAckOptions::default()) .basic_ack(delivery.delivery_tag, BasicAckOptions::default())
.await .await
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?; .map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
@@ -139,33 +164,22 @@ impl RabbitMQConsumer {
Ok(()) Ok(())
} }
/// Function to continually consume messages as they come in /// Function to continually consume messages as they come in
/// WIP pub async fn process_messages(&self) -> Result<(), IngressConsumerError> {
pub async fn process_messages(&self) -> Result<(), RabbitMQError> {
loop { loop {
match self.consume().await { match self.consume().await {
Ok((ingress, delivery)) => { Ok((ingress, delivery)) => {
info!("Received IngressObject: {:?}", ingress); info!("Received IngressObject: {:?}", ingress);
let text_content = ingress.to_text_content().await.unwrap(); // Get the TextContent
text_content.process().await.unwrap(); let text_content = ingress.to_text_content().await?;
// Initialize ContentProcessor which handles LLM analysis and storage
let content_processor = ContentProcessor::new().await?;
// Begin processing of TextContent
content_processor.process(&text_content).await?;
// Remove from queue
self.ack_delivery(delivery).await?; self.ack_delivery(delivery).await?;
// Process the IngressContent
// match self.handle_ingress_content(&ingress).await {
// Ok(_) => {
// info!("Successfully handled IngressContent");
// // Acknowledge the message
// if let Err(e) = self.ack_delivery(delivery).await {
// error!("Failed to acknowledge message: {:?}", e);
// }
// },
// Err(e) => {
// error!("Failed to handle IngressContent: {:?}", e);
// // For now, we'll acknowledge to remove it from the queue. Change to nack?
// if let Err(ack_err) = self.ack_delivery(delivery).await {
// error!("Failed to acknowledge message after handling error: {:?}", ack_err);
// }
// }
// }
} }
Err(RabbitMQError::ConsumeError(e)) => { Err(RabbitMQError::ConsumeError(e)) => {
error!("Error consuming message: {}", e); error!("Error consuming message: {}", e);
@@ -182,7 +196,10 @@ impl RabbitMQConsumer {
Ok(()) Ok(())
} }
pub async fn handle_ingress_content(&self, ingress: &IngressObject) -> Result<(), IngressContentError> { pub async fn handle_ingress_content(
&self,
ingress: &IngressObject,
) -> Result<(), IngressContentError> {
info!("Processing IngressContent: {:?}", ingress); info!("Processing IngressContent: {:?}", ingress);
unimplemented!() unimplemented!()

View File

@@ -1,13 +1,16 @@
pub mod publisher;
pub mod consumer; pub mod consumer;
pub mod publisher;
use axum::async_trait; use axum::async_trait;
use lapin::{ use lapin::{
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties, ExchangeKind options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties,
ExchangeKind,
}; };
use thiserror::Error; use thiserror::Error;
use tracing::debug; use tracing::debug;
use crate::error::ProcessingError;
/// Possible errors related to RabbitMQ operations. /// Possible errors related to RabbitMQ operations.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum RabbitMQError { pub enum RabbitMQError {
@@ -25,6 +28,8 @@ pub enum RabbitMQError {
InitializeConsumerError(String), InitializeConsumerError(String),
#[error("Queue error: {0}")] #[error("Queue error: {0}")]
QueueError(String), QueueError(String),
#[error("Processing error: {0}")]
ProcessingError(#[from] ProcessingError),
} }
/// Struct containing the information required to set up a client and connection. /// Struct containing the information required to set up a client and connection.
@@ -42,18 +47,21 @@ pub struct RabbitMQCommon {
pub channel: Channel, pub channel: Channel,
} }
/// Defines the behavior for RabbitMQCommon client operations. /// Defines the behavior for RabbitMQCommon client operations.
#[cfg_attr(test, mockall::automock)] #[cfg_attr(test, mockall::automock)]
#[async_trait] #[async_trait]
pub trait RabbitMQCommonTrait: Send + Sync { pub trait RabbitMQCommonTrait: Send + Sync {
async fn create_connection(config: &RabbitMQConfig) -> Result<Connection, RabbitMQError>; async fn create_connection(config: &RabbitMQConfig) -> Result<Connection, RabbitMQError>;
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError>; async fn declare_exchange(
} &self,
config: &RabbitMQConfig,
passive: bool,
) -> Result<(), RabbitMQError>;
}
impl RabbitMQCommon { impl RabbitMQCommon {
/// Sets up a new RabbitMQ client or error /// Sets up a new RabbitMQ client or error
/// ///
/// # Arguments /// # Arguments
/// * `RabbitMQConfig` - Configuration object with required information /// * `RabbitMQConfig` - Configuration object with required information
/// ///
@@ -62,7 +70,10 @@ impl RabbitMQCommon {
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> { pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
let connection = Self::create_connection(config).await?; let connection = Self::create_connection(config).await?;
let channel = connection.create_channel().await?; let channel = connection.create_channel().await?;
Ok(Self { connection, channel }) Ok(Self {
connection,
channel,
})
} }
} }
@@ -77,7 +88,11 @@ impl RabbitMQCommonTrait for RabbitMQCommon {
} }
/// Function to declare the exchange required /// Function to declare the exchange required
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError> { async fn declare_exchange(
&self,
config: &RabbitMQConfig,
passive: bool,
) -> Result<(), RabbitMQError> {
debug!("Declaring exchange"); debug!("Declaring exchange");
self.channel self.channel
.exchange_declare( .exchange_declare(

View File

@@ -1,11 +1,9 @@
use lapin::{ use lapin::{options::*, publisher_confirm::Confirmation, BasicProperties};
options::*, publisher_confirm::Confirmation, BasicProperties,
};
use crate::models::ingress_object::IngressObject; use crate::ingress::types::ingress_object::IngressObject;
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError}; use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
use tracing::{info, error}; use tracing::{error, info};
/// Struct to publish messages to RabbitMQ. /// Struct to publish messages to RabbitMQ.
pub struct RabbitMQProducer { pub struct RabbitMQProducer {
@@ -27,7 +25,7 @@ impl RabbitMQProducer {
let common = RabbitMQCommon::new(config).await?; let common = RabbitMQCommon::new(config).await?;
common.declare_exchange(config, false).await?; common.declare_exchange(config, false).await?;
Ok(Self { Ok(Self {
common, common,
exchange_name: config.exchange.clone(), exchange_name: config.exchange.clone(),
routing_key: config.routing_key.clone(), routing_key: config.routing_key.clone(),
@@ -39,19 +37,23 @@ impl RabbitMQProducer {
/// # Arguments /// # Arguments
/// * `self` - Reference to self /// * `self` - Reference to self
/// * `ingress_object` - A initialized IngressObject /// * `ingress_object` - A initialized IngressObject
/// ///
/// # Returns /// # Returns
/// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error /// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error
pub async fn publish(&self, ingress_object: &IngressObject) -> Result<Confirmation, RabbitMQError> { pub async fn publish(
&self,
ingress_object: &IngressObject,
) -> Result<Confirmation, RabbitMQError> {
// Serialize IngressObject to JSON // Serialize IngressObject to JSON
let payload = serde_json::to_vec(ingress_object) let payload = serde_json::to_vec(ingress_object).map_err(|e| {
.map_err(|e| { error!("Serialization Error: {}", e);
error!("Serialization Error: {}", e); RabbitMQError::PublishError(format!("Serialization Error: {}", e))
RabbitMQError::PublishError(format!("Serialization Error: {}", e)) })?;
})?;
// Publish the serialized payload to RabbitMQ // Publish the serialized payload to RabbitMQ
let confirmation = self.common.channel let confirmation = self
.common
.channel
.basic_publish( .basic_publish(
&self.exchange_name, &self.exchange_name,
&self.routing_key, &self.routing_key,
@@ -69,9 +71,12 @@ impl RabbitMQProducer {
error!("Publish Confirmation Error: {}", e); error!("Publish Confirmation Error: {}", e);
RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e)) RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e))
})?; })?;
info!("Published IngressObject to exchange '{}' with routing key '{}'", self.exchange_name, self.routing_key); info!(
"Published IngressObject to exchange '{}' with routing key '{}'",
self.exchange_name, self.routing_key
);
Ok(confirmation) Ok(confirmation)
} }
} }

View File

@@ -1,16 +1,14 @@
use std::sync::Arc; use crate::storage::{
use axum::{ db::SurrealDbClient,
extract::Path, response::IntoResponse, Extension, Json types::file_info::{FileError, FileInfo},
}; };
use axum_typed_multipart::{TypedMultipart, FieldData, TryFromMultipart}; use axum::{extract::Path, response::IntoResponse, Extension, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde_json::json; use serde_json::json;
use std::sync::Arc;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
use crate::{
models::file_info::{FileError, FileInfo},
surrealdb::SurrealDbClient,
};
#[derive(Debug, TryFromMultipart)] #[derive(Debug, TryFromMultipart)]
pub struct FileUploadRequest { pub struct FileUploadRequest {
@@ -19,7 +17,7 @@ pub struct FileUploadRequest {
} }
/// Handler to upload a new file. /// Handler to upload a new file.
/// ///
/// Route: POST /file /// Route: POST /file
pub async fn upload_handler( pub async fn upload_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>, Extension(db_client): Extension<Arc<SurrealDbClient>>,
@@ -40,13 +38,12 @@ pub async fn upload_handler(
info!("File uploaded successfully: {:?}", file_info); info!("File uploaded successfully: {:?}", file_info);
// Return the response with HTTP 200 // Return the response with HTTP 200
Ok((axum::http::StatusCode::OK, Json(response))) Ok((axum::http::StatusCode::OK, Json(response)))
} }
/// Handler to retrieve file information by UUID. /// Handler to retrieve file information by UUID.
/// ///
/// Route: GET /file/:uuid /// Route: GET /file/:uuid
pub async fn get_file_handler( pub async fn get_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>, Extension(db_client): Extension<Arc<SurrealDbClient>>,
@@ -73,7 +70,7 @@ pub async fn get_file_handler(
} }
/// Handler to update an existing file by UUID. /// Handler to update an existing file by UUID.
/// ///
/// Route: PUT /file/:uuid /// Route: PUT /file/:uuid
pub async fn update_file_handler( pub async fn update_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>, Extension(db_client): Extension<Arc<SurrealDbClient>>,
@@ -101,7 +98,7 @@ pub async fn update_file_handler(
} }
/// Handler to delete a file by UUID. /// Handler to delete a file by UUID.
/// ///
/// Route: DELETE /file/:uuid /// Route: DELETE /file/:uuid
pub async fn delete_file_handler( pub async fn delete_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>, Extension(db_client): Extension<Arc<SurrealDbClient>>,

View File

@@ -1,7 +1,11 @@
use std::sync::Arc; use crate::{
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
rabbitmq::publisher::RabbitMQProducer,
storage::db::SurrealDbClient,
};
use axum::{http::StatusCode, response::IntoResponse, Extension, Json}; use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
use std::sync::Arc;
use tracing::{error, info}; use tracing::{error, info};
use crate::{models::ingress_content::{create_ingress_objects, IngressInput}, rabbitmq::publisher::RabbitMQProducer, surrealdb::SurrealDbClient};
pub async fn ingress_handler( pub async fn ingress_handler(
Extension(producer): Extension<Arc<RabbitMQProducer>>, Extension(producer): Extension<Arc<RabbitMQProducer>>,
@@ -23,7 +27,7 @@ pub async fn ingress_handler(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"Failed to publish message", "Failed to publish message",
) )
.into_response(); .into_response();
} }
} }
} }
@@ -31,11 +35,7 @@ pub async fn ingress_handler(
} }
Err(e) => { Err(e) => {
error!("Failed to process input: {:?}", e); error!("Failed to process input: {:?}", e);
( (StatusCode::INTERNAL_SERVER_ERROR, "Failed to process input").into_response()
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to process input",
)
.into_response()
} }
} }
} }

View File

@@ -12,7 +12,7 @@ use zettle_db::{
ingress::ingress_handler, ingress::ingress_handler,
queue_length::queue_length_handler, queue_length::queue_length_handler,
}, },
surrealdb::SurrealDbClient, storage::db::SurrealDbClient,
}; };
#[tokio::main(flavor = "multi_thread", worker_threads = 2)] #[tokio::main(flavor = "multi_thread", worker_threads = 2)]

View File

@@ -1,8 +1,67 @@
use surrealdb::{engine::remote::ws::Client, Surreal};
use crate::error::ProcessingError; use crate::error::ProcessingError;
use super::types::StoredObject; use super::types::StoredObject;
use std::ops::Deref;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Error, Surreal,
};
/// Wrapper around a SurrealDB WebSocket connection, bundling the connected
/// client with maintenance helpers (index rebuilds, table drops).
#[derive(Clone)]
pub struct SurrealDbClient {
    /// The underlying SurrealDB client; also reachable via the `Deref` impl.
    pub client: Surreal<Client>,
}
impl SurrealDbClient {
    /// Initialize a new database client.
    ///
    /// Connection details default to the local development instance but can be
    /// overridden with the `SURREALDB_ADDRESS`, `SURREALDB_USERNAME`, and
    /// `SURREALDB_PASSWORD` environment variables.
    ///
    /// # Errors
    /// Returns a `surrealdb::Error` if connecting, signing in, or selecting
    /// the namespace/database fails.
    ///
    /// # Returns
    /// * `SurrealDbClient` connected, authenticated, and scoped to the
    ///   `test` namespace and database.
    pub async fn new() -> Result<Self, Error> {
        // NOTE(review): address and root credentials were hard-coded; env
        // overrides keep the previous defaults so existing setups still work,
        // while deployments can avoid baked-in secrets.
        let address =
            std::env::var("SURREALDB_ADDRESS").unwrap_or_else(|_| "127.0.0.1:8000".to_string());
        let username =
            std::env::var("SURREALDB_USERNAME").unwrap_or_else(|_| "root_user".to_string());
        let password =
            std::env::var("SURREALDB_PASSWORD").unwrap_or_else(|_| "root_password".to_string());
        let db = Surreal::new::<Ws>(address.as_str()).await?;
        // Sign in to database
        db.signin(Root {
            username: &username,
            password: &password,
        })
        .await?;
        // Set namespace
        db.use_ns("test").use_db("test").await?;
        Ok(SurrealDbClient { client: db })
    }

    /// Rebuilds the vector-search indexes on the chunk and entity tables so
    /// that newly inserted embeddings become searchable.
    pub async fn rebuild_indexes(&self) -> Result<(), Error> {
        self.client
            .query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
            .await?;
        self.client
            .query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
            .await?;
        Ok(())
    }

    /// Deletes every record from the table backing `T` (the type's
    /// `StoredObject::table_name`). Intended for tests and resets.
    pub async fn drop_table<T>(&self) -> Result<(), Error>
    where
        T: StoredObject + Send + Sync + 'static,
    {
        let _deleted: Vec<T> = self.client.delete(T::table_name()).await?;
        Ok(())
    }
}
/// Lets a `SurrealDbClient` be used anywhere a `Surreal<Client>` is expected,
/// so callers can invoke query methods on the wrapper directly.
impl Deref for SurrealDbClient {
    type Target = Surreal<Client>;
    fn deref(&self) -> &Self::Target {
        // Borrow the wrapped connection; no extra state is involved.
        &self.client
    }
}
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject /// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
/// ///

View File

@@ -1,22 +1,26 @@
use axum::{http::StatusCode, response::{IntoResponse, Response}, Json}; use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use axum_typed_multipart::FieldData; use axum_typed_multipart::FieldData;
use mime_guess::from_path; use mime_guess::from_path;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use surrealdb::RecordId;
use std::{ use std::{
io::{BufReader, Read}, io::{BufReader, Read},
path::{Path, PathBuf}, path::{Path, PathBuf},
}; };
use surrealdb::RecordId;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use thiserror::Error; use thiserror::Error;
use tracing::{debug, info}; use tracing::{debug, info};
use uuid::Uuid; use uuid::Uuid;
use crate::surrealdb::SurrealDbClient; use crate::storage::db::SurrealDbClient;
#[derive(Debug,Deserialize)] #[derive(Debug, Deserialize)]
struct Record { struct Record {
#[allow(dead_code)] #[allow(dead_code)]
id: RecordId, id: RecordId,
@@ -72,7 +76,6 @@ pub enum FileError {
#[error("Deserialization error: {0}")] #[error("Deserialization error: {0}")]
DeserializationError(String), DeserializationError(String),
// Add more error variants as needed. // Add more error variants as needed.
} }
@@ -82,16 +85,30 @@ impl IntoResponse for FileError {
FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"), FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"), FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"),
FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"), FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"),
FileError::UnsupportedMime(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type"), FileError::UnsupportedMime(_) => {
(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type")
}
FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"), FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"),
FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"), FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"),
FileError::HashCollision => (StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected"), FileError::HashCollision => {
(StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected")
}
FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"), FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"),
FileError::MissingFileName => (StatusCode::BAD_REQUEST, "Missing file name in metadata"), FileError::MissingFileName => {
FileError::PersistError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file"), (StatusCode::BAD_REQUEST, "Missing file name in metadata")
FileError::SerializationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"), }
FileError::DeserializationError(_) => (StatusCode::BAD_REQUEST, "Deserialization error"), FileError::PersistError(_) => {
FileError::SurrealError(_) =>(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"), (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file")
}
FileError::SerializationError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
}
FileError::DeserializationError(_) => {
(StatusCode::BAD_REQUEST, "Deserialization error")
}
FileError::SurrealError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
}
}; };
let body = Json(json!({ let body = Json(json!({
@@ -163,7 +180,11 @@ impl FileInfo {
/// ///
/// # Returns /// # Returns
/// * `Result<FileInfo, FileError>` - The updated `FileInfo` or an error. /// * `Result<FileInfo, FileError>` - The updated `FileInfo` or an error.
pub async fn update(uuid: Uuid, new_field_data: FieldData<NamedTempFile>, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> { pub async fn update(
uuid: Uuid,
new_field_data: FieldData<NamedTempFile>,
db_client: &SurrealDbClient,
) -> Result<FileInfo, FileError> {
let new_file = new_field_data.contents; let new_file = new_field_data.contents;
let new_metadata = new_field_data.metadata; let new_metadata = new_field_data.metadata;
@@ -184,7 +205,8 @@ impl FileInfo {
let sanitized_new_file_name = sanitize_file_name(&new_file_name); let sanitized_new_file_name = sanitize_file_name(&new_file_name);
// Persist the new file // Persist the new file
let new_persisted_path = Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?; let new_persisted_path =
Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
// Guess the new MIME type // Guess the new MIME type
let new_mime_type = Self::guess_mime_type(&new_persisted_path); let new_mime_type = Self::guess_mime_type(&new_persisted_path);
@@ -202,7 +224,7 @@ impl FileInfo {
}; };
// Save the new item // Save the new item
Self::create_record(&updated_file_info,db_client).await?; Self::create_record(&updated_file_info, db_client).await?;
// Optionally, delete the old file from the filesystem if it's no longer referenced // Optionally, delete the old file from the filesystem if it's no longer referenced
// This requires reference counting or checking if other FileInfo entries point to the same SHA // This requires reference counting or checking if other FileInfo entries point to the same SHA
@@ -221,26 +243,35 @@ impl FileInfo {
/// * `Result<(), FileError>` - Empty result or an error. /// * `Result<(), FileError>` - Empty result or an error.
pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> { pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> {
// Retrieve FileInfo to get SHA256 and path // Retrieve FileInfo to get SHA256 and path
let file_info = Self::get_by_uuid(uuid, db_client).await?; let file_info = Self::get_by_uuid(uuid, db_client).await?;
// Delete the file from the filesystem // Delete the file from the filesystem
let file_path = Path::new(&file_info.path); let file_path = Path::new(&file_info.path);
if file_path.exists() { if file_path.exists() {
tokio::fs::remove_file(file_path).await.map_err(FileError::Io)?; tokio::fs::remove_file(file_path)
.await
.map_err(FileError::Io)?;
info!("Deleted file at path: {}", file_info.path); info!("Deleted file at path: {}", file_info.path);
} else { } else {
info!("File path does not exist, skipping deletion: {}", file_info.path); info!(
"File path does not exist, skipping deletion: {}",
file_info.path
);
} }
// Delete the FileInfo from database // Delete the FileInfo from database
Self::delete_record(&file_info.sha256, db_client).await?; Self::delete_record(&file_info.sha256, db_client).await?;
// Remove the UUID directory if empty // Remove the UUID directory if empty
let uuid_dir = file_path.parent().ok_or(FileError::FileNotFound(uuid.to_string()))?; let uuid_dir = file_path
.parent()
.ok_or(FileError::FileNotFound(uuid.to_string()))?;
if uuid_dir.exists() { if uuid_dir.exists() {
let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?; let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?;
if entries.next_entry().await?.is_none() { if entries.next_entry().await?.is_none() {
tokio::fs::remove_dir(uuid_dir).await.map_err(FileError::Io)?; tokio::fs::remove_dir(uuid_dir)
.await
.map_err(FileError::Io)?;
info!("Deleted empty UUID directory: {:?}", uuid_dir); info!("Deleted empty UUID directory: {:?}", uuid_dir);
} }
} }
@@ -257,19 +288,26 @@ impl FileInfo {
/// ///
/// # Returns /// # Returns
/// * `Result<PathBuf, FileError>` - The persisted file path or an error. /// * `Result<PathBuf, FileError>` - The persisted file path or an error.
async fn persist_file(uuid: &Uuid, file: NamedTempFile, file_name: &str) -> Result<PathBuf, FileError> { async fn persist_file(
uuid: &Uuid,
file: NamedTempFile,
file_name: &str,
) -> Result<PathBuf, FileError> {
let base_dir = Path::new("./data"); let base_dir = Path::new("./data");
let uuid_dir = base_dir.join(uuid.to_string()); let uuid_dir = base_dir.join(uuid.to_string());
// Create the UUID directory if it doesn't exist // Create the UUID directory if it doesn't exist
tokio::fs::create_dir_all(&uuid_dir).await.map_err(FileError::Io)?; tokio::fs::create_dir_all(&uuid_dir)
.await
.map_err(FileError::Io)?;
// Define the final file path // Define the final file path
let final_path = uuid_dir.join(file_name); let final_path = uuid_dir.join(file_name);
info!("Final path: {:?}", final_path); info!("Final path: {:?}", final_path);
// Persist the temporary file to the final path // Persist the temporary file to the final path
file.persist(&final_path).map_err(|e| FileError::PersistError(e.to_string()))?; file.persist(&final_path)
.map_err(|e| FileError::PersistError(e.to_string()))?;
info!("Persisted file to {:?}", final_path); info!("Persisted file to {:?}", final_path);
@@ -313,7 +351,6 @@ impl FileInfo {
.to_string() .to_string()
} }
/// Creates a new record in SurrealDB for the given `FileInfo`. /// Creates a new record in SurrealDB for the given `FileInfo`.
/// ///
/// # Arguments /// # Arguments
@@ -323,16 +360,19 @@ impl FileInfo {
/// # Returns /// # Returns
/// * `Result<(), FileError>` - Empty result or an error. /// * `Result<(), FileError>` - Empty result or an error.
async fn create_record(file_info: &FileInfo, db_client: &SurrealDbClient) -> Result<(), FileError> { async fn create_record(
file_info: &FileInfo,
db_client: &SurrealDbClient,
) -> Result<(), FileError> {
// Create the record // Create the record
let _created: Option<Record> = db_client let _created: Option<Record> = db_client
.client .client
.create(("file", &file_info.uuid )) .create(("file", &file_info.uuid))
.content(file_info.clone()) .content(file_info.clone())
.await?; .await?;
debug!("{:?}",_created); debug!("{:?}", _created);
info!("Created FileInfo record with SHA256: {}", file_info.sha256); info!("Created FileInfo record with SHA256: {}", file_info.sha256);
Ok(()) Ok(())
@@ -346,11 +386,17 @@ impl FileInfo {
/// ///
/// # Returns /// # Returns
/// * `Result<FileInfo, FileError>` - The `FileInfo` or `Error` if not found. /// * `Result<FileInfo, FileError>` - The `FileInfo` or `Error` if not found.
pub async fn get_by_uuid(uuid: Uuid, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> { pub async fn get_by_uuid(
uuid: Uuid,
db_client: &SurrealDbClient,
) -> Result<FileInfo, FileError> {
let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid); let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid);
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?; let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
response.into_iter().next().ok_or(FileError::FileNotFound(uuid.to_string())) response
.into_iter()
.next()
.ok_or(FileError::FileNotFound(uuid.to_string()))
} }
/// Retrieves a `FileInfo` by SHA256. /// Retrieves a `FileInfo` by SHA256.
@@ -367,7 +413,11 @@ impl FileInfo {
debug!("{:?}", response); debug!("{:?}", response);
response.into_iter().next().ok_or(FileError::FileNotFound(sha256.to_string())) } response
.into_iter()
.next()
.ok_or(FileError::FileNotFound(sha256.to_string()))
}
/// Deletes a `FileInfo` record by SHA256. /// Deletes a `FileInfo` record by SHA256.
/// ///
@@ -381,10 +431,7 @@ impl FileInfo {
let table = "file"; let table = "file";
let primary_key = sha256; let primary_key = sha256;
let _created: Option<Record> = db_client let _created: Option<Record> = db_client.client.delete((table, primary_key)).await?;
.client
.delete((table, primary_key))
.await?;
info!("Deleted FileInfo record with SHA256: {}", sha256); info!("Deleted FileInfo record with SHA256: {}", sha256);
@@ -395,7 +442,14 @@ impl FileInfo {
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal. /// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores. /// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
fn sanitize_file_name(file_name: &str) -> String { fn sanitize_file_name(file_name: &str) -> String {
file_name.chars() file_name
.map(|c| if c.is_ascii_alphanumeric() || c == '.' || c == '_' { c } else { '_' }) .chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
c
} else {
'_'
}
})
.collect() .collect()
} }

View File

@@ -1,5 +1,6 @@
use axum::async_trait; use axum::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod file_info;
pub mod knowledge_entity; pub mod knowledge_entity;
pub mod knowledge_relationship; pub mod knowledge_relationship;
pub mod text_chunk; pub mod text_chunk;

View File

@@ -1,8 +1,9 @@
use uuid::Uuid; use uuid::Uuid;
use crate::models::file_info::FileInfo;
use crate::stored_object; use crate::stored_object;
use super::file_info::FileInfo;
stored_object!(TextContent, "text_content", { stored_object!(TextContent, "text_content", {
text: String, text: String,
file_info: Option<FileInfo>, file_info: Option<FileInfo>,

View File

@@ -1,63 +0,0 @@
use std::ops::Deref;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Error, Surreal,
};
use crate::storage::types::StoredObject;
/// Handle to the SurrealDB database connection.
///
/// Wraps a WebSocket-backed `Surreal<Client>`; a `Deref` impl below forwards
/// to the inner client so query methods can be called on the wrapper.
#[derive(Clone)]
pub struct SurrealDbClient {
    // Underlying SurrealDB connection; public and also reachable via `Deref`.
    pub client: Surreal<Client>,
}
impl SurrealDbClient {
    /// Initializes a new database client.
    ///
    /// Connects to the local SurrealDB instance over WebSocket, signs in with
    /// root credentials, and selects the `test` namespace and database.
    ///
    /// # Errors
    /// Returns the underlying `surrealdb::Error` if connecting,
    /// authenticating, or selecting the namespace/database fails.
    pub async fn new() -> Result<Self, Error> {
        let client = Surreal::new::<Ws>("127.0.0.1:8000").await?;

        // Authenticate before issuing any queries.
        let credentials = Root {
            username: "root_user",
            password: "root_password",
        };
        client.signin(credentials).await?;

        // Scope all subsequent operations to the test namespace/database.
        client.use_ns("test").use_db("test").await?;

        Ok(Self { client })
    }

    /// Rebuilds the embedding indexes on `text_chunk` and `knowledge_entity`.
    ///
    /// # Errors
    /// Returns the underlying `surrealdb::Error` if either rebuild fails.
    pub async fn rebuild_indexes(&self) -> Result<(), Error> {
        // `IF EXISTS` keeps this safe to call before the indexes are created.
        for statement in [
            "REBUILD INDEX IF EXISTS idx_embedding ON text_chunk",
            "REBUILD INDEX IF EXISTS embeddings ON knowledge_entity",
        ] {
            self.client.query(statement).await?;
        }
        Ok(())
    }

    /// Deletes every record from the table backing `T`.
    ///
    /// # Errors
    /// Returns the underlying `surrealdb::Error` if the delete fails.
    pub async fn drop_table<T>(&self) -> Result<(), Error>
    where
        T: StoredObject + Send + Sync + 'static,
    {
        // Deleting by table name removes all records and returns them.
        let _removed: Vec<T> = self.client.delete(T::table_name()).await?;
        Ok(())
    }
}
/// Lets a `SurrealDbClient` be used anywhere a `Surreal<Client>` is expected,
/// so callers can invoke query methods on the wrapper directly.
impl Deref for SurrealDbClient {
    type Target = Surreal<Client>;
    fn deref(&self) -> &Self::Target {
        // Borrow the wrapped connection; no extra state is involved.
        &self.client
    }
}