mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-22 09:29:51 +01:00
feat: refactoring complete?
This commit is contained in:
@@ -1 +0,0 @@
|
||||
pub mod ingress;
|
||||
14
src/error.rs
14
src/error.rs
@@ -2,6 +2,8 @@ use async_openai::error::OpenAIError;
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError};
|
||||
|
||||
/// Error types for processing `TextContent`.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ProcessingError {
|
||||
@@ -23,3 +25,15 @@ pub enum ProcessingError {
|
||||
#[error("Task join error: {0}")]
|
||||
JoinError(#[from] JoinError),
|
||||
}
|
||||
|
||||
/// Aggregated error type for consuming and processing an ingress message.
///
/// Each variant wraps one source error and is created from it through the
/// `#[from]` conversion, so `?` can be used on any of the three source types.
#[derive(Error, Debug)]
|
||||
pub enum IngressConsumerError {
|
||||
/// A RabbitMQ operation failed.
#[error("RabbitMQ error: {0}")]
|
||||
RabbitMQ(#[from] RabbitMQError),
|
||||
|
||||
/// Processing of the content failed.
#[error("Processing error: {0}")]
|
||||
Processing(#[from] ProcessingError),
|
||||
|
||||
/// The ingress content itself produced an error.
#[error("Ingress content error: {0}")]
|
||||
IngressContent(#[from] IngressContentError),
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
use crate::{
|
||||
analysis::ingress::{
|
||||
prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||
types::llm_analysis_result::LLMGraphAnalysisResult,
|
||||
},
|
||||
error::ProcessingError,
|
||||
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||
retrieval::vector::find_items_by_vector_similarity,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, StoredObject},
|
||||
};
|
||||
@@ -15,7 +12,9 @@ use async_openai::types::{
|
||||
use serde_json::json;
|
||||
use surrealdb::engine::remote::ws::Client;
|
||||
use surrealdb::Surreal;
|
||||
use tracing::{debug, instrument};
|
||||
use tracing::debug;
|
||||
|
||||
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
|
||||
|
||||
pub struct IngressAnalyzer<'a> {
|
||||
db_client: &'a Surreal<Client>,
|
||||
@@ -33,7 +32,6 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn analyze_content(
|
||||
&self,
|
||||
category: &str,
|
||||
@@ -48,7 +46,6 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
self.perform_analysis(llm_request).await
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn find_similar_entities(
|
||||
&self,
|
||||
category: &str,
|
||||
@@ -70,7 +67,6 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
fn prepare_llm_request(
|
||||
&self,
|
||||
category: &str,
|
||||
@@ -108,7 +104,7 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model("gpt-4-mini")
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.2)
|
||||
.max_tokens(2048u32)
|
||||
.messages([
|
||||
@@ -120,7 +116,6 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
.map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
|
||||
}
|
||||
|
||||
#[instrument(skip(self, request))]
|
||||
async fn perform_analysis(
|
||||
&self,
|
||||
request: CreateChatCompletionRequest,
|
||||
@@ -5,14 +5,15 @@ use tokio::task;
|
||||
|
||||
use crate::{
|
||||
error::ProcessingError,
|
||||
models::graph_entities::GraphMapper,
|
||||
storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::future::try_join_all; // For future parallelization
|
||||
use futures::future::try_join_all;
|
||||
|
||||
use super::graph_mapper::GraphMapper; // For future parallelization
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMKnowledgeEntity {
|
||||
@@ -1 +1,2 @@
|
||||
pub mod graph_mapper;
|
||||
pub mod llm_analysis_result;
|
||||
119
src/ingress/content_processor.rs
Normal file
119
src/ingress/content_processor.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use text_splitter::TextSplitter;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{
|
||||
error::ProcessingError,
|
||||
retrieval::vector::find_items_by_vector_similarity,
|
||||
storage::{
|
||||
db::{store_item, SurrealDbClient},
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk, text_content::TextContent,
|
||||
},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
|
||||
use super::analysis::{
|
||||
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
|
||||
};
|
||||
|
||||
pub struct ContentProcessor {
|
||||
db_client: SurrealDbClient,
|
||||
openai_client: async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
}
|
||||
|
||||
impl ContentProcessor {
|
||||
pub async fn new() -> Result<Self, ProcessingError> {
|
||||
Ok(Self {
|
||||
db_client: SurrealDbClient::new().await?,
|
||||
openai_client: async_openai::Client::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> {
|
||||
// Store original content
|
||||
store_item(&self.db_client, content.clone()).await?;
|
||||
|
||||
// Process in parallel where possible
|
||||
let (analysis, _similar_chunks) = tokio::try_join!(
|
||||
self.perform_semantic_analysis(content),
|
||||
self.find_similar_content(content),
|
||||
)?;
|
||||
|
||||
// Convert and store entities
|
||||
let (entities, relationships) = analysis
|
||||
.to_database_entities(&content.id, &self.openai_client)
|
||||
.await?;
|
||||
|
||||
// Store everything
|
||||
tokio::try_join!(
|
||||
self.store_graph_entities(entities, relationships),
|
||||
self.store_vector_chunks(content),
|
||||
)?;
|
||||
|
||||
self.db_client.rebuild_indexes().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn perform_semantic_analysis(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
||||
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
||||
analyser
|
||||
.analyze_content(&content.category, &content.instructions, &content.text)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn find_similar_content(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
) -> Result<Vec<TextChunk>, ProcessingError> {
|
||||
find_items_by_vector_similarity(
|
||||
3,
|
||||
content.text.clone(),
|
||||
&self.db_client,
|
||||
"text_chunk".to_string(),
|
||||
&self.openai_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_graph_entities(
|
||||
&self,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
) -> Result<(), ProcessingError> {
|
||||
for entity in &entities {
|
||||
debug!("Storing entity: {:?}", entity);
|
||||
store_item(&self.db_client, entity.clone()).await?;
|
||||
}
|
||||
|
||||
for relationship in &relationships {
|
||||
debug!("Storing relationship: {:?}", relationship);
|
||||
store_item(&self.db_client, relationship.clone()).await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Stored {} entities and {} relationships",
|
||||
entities.len(),
|
||||
relationships.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> {
|
||||
let splitter = TextSplitter::new(500..2000);
|
||||
let chunks = splitter.chunks(&content.text);
|
||||
|
||||
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||
for chunk in chunks {
|
||||
let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?;
|
||||
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
|
||||
store_item(&self.db_client, text_chunk).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
3
src/ingress/mod.rs
Normal file
3
src/ingress/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod analysis;
|
||||
pub mod content_processor;
|
||||
pub mod types;
|
||||
@@ -1,10 +1,10 @@
|
||||
use super::ingress_object::IngressObject;
|
||||
use crate::storage::{db::SurrealDbClient, types::file_info::FileInfo};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
use crate::surrealdb::SurrealDbClient;
|
||||
use super::{file_info::FileInfo, ingress_object::IngressObject };
|
||||
|
||||
/// Struct defining the expected body when ingressing content.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
@@ -105,4 +105,3 @@ pub async fn create_ingress_objects(
|
||||
|
||||
Ok(object_list)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use super::ingress_content::IngressContentError;
|
||||
use crate::models::file_info::FileInfo;
|
||||
use crate::storage::types::text_content::TextContent;
|
||||
use crate::storage::types::{file_info::FileInfo, text_content::TextContent};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::ingress_input::IngressContentError;
|
||||
|
||||
/// Knowledge object type, containing the content or reference to it, as well as metadata
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub enum IngressObject {
|
||||
2
src/ingress/types/mod.rs
Normal file
2
src/ingress/types/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod ingress_input;
|
||||
pub mod ingress_object;
|
||||
@@ -1,9 +1,7 @@
|
||||
pub mod analysis;
|
||||
pub mod error;
|
||||
pub mod models;
|
||||
pub mod ingress;
|
||||
pub mod rabbitmq;
|
||||
pub mod retrieval;
|
||||
pub mod routes;
|
||||
pub mod storage;
|
||||
pub mod surrealdb;
|
||||
pub mod utils;
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
pub mod file_info;
|
||||
pub mod graph_entities;
|
||||
pub mod ingress_content;
|
||||
pub mod ingress_object;
|
||||
pub mod text_content;
|
||||
@@ -1,122 +0,0 @@
|
||||
use crate::analysis::ingress::ingress_analyser::IngressAnalyzer;
|
||||
use crate::retrieval::graph::find_entities_by_source_id;
|
||||
use crate::retrieval::vector::find_items_by_vector_similarity;
|
||||
use crate::storage::db::store_item;
|
||||
use crate::storage::types::knowledge_entity::KnowledgeEntity;
|
||||
use crate::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::storage::types::text_content::TextContent;
|
||||
use crate::storage::types::StoredObject;
|
||||
use crate::utils::embedding::generate_embedding;
|
||||
use crate::{error::ProcessingError, surrealdb::SurrealDbClient};
|
||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
||||
use text_splitter::TextSplitter;
|
||||
use tracing::{debug, info};
|
||||
|
||||
impl TextContent {
|
||||
/// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB.
|
||||
pub async fn process(&self) -> Result<(), ProcessingError> {
|
||||
let db_client = SurrealDbClient::new().await?;
|
||||
let openai_client = async_openai::Client::new();
|
||||
|
||||
// Store TextContent
|
||||
store_item(&db_client, self.clone()).await?;
|
||||
|
||||
// Get related nodes
|
||||
let closest_text_content: Vec<TextChunk> = find_items_by_vector_similarity(
|
||||
3,
|
||||
self.text.clone(),
|
||||
&db_client,
|
||||
"text_chunk".to_string(),
|
||||
&openai_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for node in closest_text_content {
|
||||
let related_nodes: Vec<KnowledgeEntity> = find_entities_by_source_id(
|
||||
node.source_id.to_owned(),
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
&db_client,
|
||||
)
|
||||
.await?;
|
||||
for related_node in related_nodes {
|
||||
info!("{:?}", related_node.name);
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild indexes
|
||||
db_client.rebuild_indexes().await?;
|
||||
|
||||
// Step 1: Send to LLM for analysis
|
||||
let analyser = IngressAnalyzer::new(&db_client, &openai_client);
|
||||
let analysis = analyser
|
||||
.analyze_content(&self.category, &self.instructions, &self.text)
|
||||
.await?;
|
||||
|
||||
// Step 2: Convert LLM analysis to database entities
|
||||
let (entities, relationships) = analysis
|
||||
.to_database_entities(&self.id, &openai_client)
|
||||
.await?;
|
||||
|
||||
// Step 3: Store in database
|
||||
self.store_in_graph_db(entities, relationships, &db_client)
|
||||
.await?;
|
||||
|
||||
// Step 4: Split text and store in Vector DB
|
||||
self.store_in_vector_db(&db_client, &openai_client).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_in_graph_db(
|
||||
&self,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
db_client: &Surreal<Client>,
|
||||
) -> Result<(), ProcessingError> {
|
||||
for entity in &entities {
|
||||
debug!(
|
||||
"{:?}, {:?}, {:?}",
|
||||
&entity.id, &entity.name, &entity.description
|
||||
);
|
||||
|
||||
store_item(db_client, entity.clone()).await?;
|
||||
}
|
||||
|
||||
for relationship in &relationships {
|
||||
debug!("{:?}", relationship);
|
||||
|
||||
store_item(db_client, relationship.clone()).await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Inserted to database: {:?} entities, {:?} relationships",
|
||||
entities.len(),
|
||||
relationships.len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Splits text and stores it in a vector database.
|
||||
async fn store_in_vector_db(
|
||||
&self,
|
||||
db_client: &Surreal<Client>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<(), ProcessingError> {
|
||||
let max_characters = 500..2000;
|
||||
let splitter = TextSplitter::new(max_characters);
|
||||
|
||||
let chunks = splitter.chunks(self.text.as_str());
|
||||
|
||||
for chunk in chunks {
|
||||
info!("Chunk: {}", chunk);
|
||||
let embedding = generate_embedding(openai_client, chunk.to_string()).await?;
|
||||
let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding);
|
||||
|
||||
store_item(db_client, text_chunk).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
use lapin::{
|
||||
message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue
|
||||
};
|
||||
use futures_lite::stream::StreamExt;
|
||||
use lapin::{message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue};
|
||||
|
||||
use crate::models::{ingress_content::IngressContentError, ingress_object::IngressObject };
|
||||
use crate::{
|
||||
error::IngressConsumerError,
|
||||
ingress::{
|
||||
content_processor::ContentProcessor,
|
||||
types::{ingress_input::IngressContentError, ingress_object::IngressObject},
|
||||
},
|
||||
};
|
||||
|
||||
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Struct to consume messages from RabbitMQ.
|
||||
pub struct RabbitMQConsumer {
|
||||
@@ -26,18 +30,22 @@ impl RabbitMQConsumer {
|
||||
/// * `Result<Self, RabbitMQError>` - The created client or an error.
|
||||
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
||||
let common = RabbitMQCommon::new(config).await?;
|
||||
|
||||
|
||||
// Passively declare the exchange (it should already exist)
|
||||
common.declare_exchange(config, true).await?;
|
||||
|
||||
|
||||
// Declare queue and bind it to the channel
|
||||
let queue = Self::declare_queue(&common.channel, config).await?;
|
||||
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
|
||||
|
||||
|
||||
// Initialize the consumer
|
||||
let consumer = Self::initialize_consumer(&common.channel, config).await?;
|
||||
|
||||
Ok(Self { common, queue, consumer })
|
||||
Ok(Self {
|
||||
common,
|
||||
queue,
|
||||
consumer,
|
||||
})
|
||||
}
|
||||
|
||||
/// Sets up the consumer based on the channel and `RabbitMQConfig`.
|
||||
@@ -48,7 +56,10 @@ impl RabbitMQConsumer {
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Consumer, RabbitMQError>` - The initialized consumer or error
|
||||
async fn initialize_consumer(channel: &Channel, config: &RabbitMQConfig) -> Result<Consumer, RabbitMQError> {
|
||||
async fn initialize_consumer(
|
||||
channel: &Channel,
|
||||
config: &RabbitMQConfig,
|
||||
) -> Result<Consumer, RabbitMQError> {
|
||||
channel
|
||||
.basic_consume(
|
||||
&config.queue,
|
||||
@@ -56,7 +67,8 @@ impl RabbitMQConsumer {
|
||||
BasicConsumeOptions::default(),
|
||||
FieldTable::default(),
|
||||
)
|
||||
.await.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
|
||||
.await
|
||||
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
|
||||
}
|
||||
/// Declares the queue based on the channel and `RabbitMQConfig`.
|
||||
/// # Arguments
|
||||
@@ -65,7 +77,10 @@ impl RabbitMQConsumer {
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Queue, RabbitMQError>` - The initialized queue or error
|
||||
async fn declare_queue(channel: &Channel, config: &RabbitMQConfig) -> Result<Queue, RabbitMQError> {
|
||||
async fn declare_queue(
|
||||
channel: &Channel,
|
||||
config: &RabbitMQConfig,
|
||||
) -> Result<Queue, RabbitMQError> {
|
||||
channel
|
||||
.queue_declare(
|
||||
&config.queue,
|
||||
@@ -88,7 +103,12 @@ impl RabbitMQConsumer {
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<(), RabbitMQError>` - Ok or error
|
||||
async fn bind_queue(channel: &Channel, exchange: &str, queue: &Queue, config: &RabbitMQConfig) -> Result<(), RabbitMQError> {
|
||||
async fn bind_queue(
|
||||
channel: &Channel,
|
||||
exchange: &str,
|
||||
queue: &Queue,
|
||||
config: &RabbitMQConfig,
|
||||
) -> Result<(), RabbitMQError> {
|
||||
channel
|
||||
.queue_bind(
|
||||
queue.name().as_str(),
|
||||
@@ -111,7 +131,11 @@ impl RabbitMQConsumer {
|
||||
/// `Delivery` - A delivery receipt, required to ack or nack the delivery.
|
||||
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
|
||||
// Receive the next message
|
||||
let delivery = self.consumer.clone().next().await
|
||||
let delivery = self
|
||||
.consumer
|
||||
.clone()
|
||||
.next()
|
||||
.await
|
||||
.ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))?
|
||||
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
|
||||
|
||||
@@ -131,7 +155,8 @@ impl RabbitMQConsumer {
|
||||
/// # Returns
|
||||
/// * `Result<(), RabbitMQError>` - Ok or error
|
||||
pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> {
|
||||
self.common.channel
|
||||
self.common
|
||||
.channel
|
||||
.basic_ack(delivery.delivery_tag, BasicAckOptions::default())
|
||||
.await
|
||||
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
|
||||
@@ -139,33 +164,22 @@ impl RabbitMQConsumer {
|
||||
Ok(())
|
||||
}
|
||||
/// Function to continually consume messages as they come in
|
||||
/// WIP
|
||||
pub async fn process_messages(&self) -> Result<(), RabbitMQError> {
|
||||
pub async fn process_messages(&self) -> Result<(), IngressConsumerError> {
|
||||
loop {
|
||||
match self.consume().await {
|
||||
Ok((ingress, delivery)) => {
|
||||
info!("Received IngressObject: {:?}", ingress);
|
||||
let text_content = ingress.to_text_content().await.unwrap();
|
||||
text_content.process().await.unwrap();
|
||||
|
||||
// Get the TextContent
|
||||
let text_content = ingress.to_text_content().await?;
|
||||
|
||||
// Initialize ContentProcessor which handles LLM analysis and storage
|
||||
let content_processor = ContentProcessor::new().await?;
|
||||
|
||||
// Begin processing of TextContent
|
||||
content_processor.process(&text_content).await?;
|
||||
|
||||
// Remove from queue
|
||||
self.ack_delivery(delivery).await?;
|
||||
// Process the IngressContent
|
||||
// match self.handle_ingress_content(&ingress).await {
|
||||
// Ok(_) => {
|
||||
// info!("Successfully handled IngressContent");
|
||||
// // Acknowledge the message
|
||||
// if let Err(e) = self.ack_delivery(delivery).await {
|
||||
// error!("Failed to acknowledge message: {:?}", e);
|
||||
// }
|
||||
// },
|
||||
// Err(e) => {
|
||||
// error!("Failed to handle IngressContent: {:?}", e);
|
||||
// // For now, we'll acknowledge to remove it from the queue. Change to nack?
|
||||
// if let Err(ack_err) = self.ack_delivery(delivery).await {
|
||||
// error!("Failed to acknowledge message after handling error: {:?}", ack_err);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
Err(RabbitMQError::ConsumeError(e)) => {
|
||||
error!("Error consuming message: {}", e);
|
||||
@@ -182,7 +196,10 @@ impl RabbitMQConsumer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_ingress_content(&self, ingress: &IngressObject) -> Result<(), IngressContentError> {
|
||||
pub async fn handle_ingress_content(
|
||||
&self,
|
||||
ingress: &IngressObject,
|
||||
) -> Result<(), IngressContentError> {
|
||||
info!("Processing IngressContent: {:?}", ingress);
|
||||
|
||||
unimplemented!()
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
pub mod publisher;
|
||||
pub mod consumer;
|
||||
pub mod publisher;
|
||||
|
||||
use axum::async_trait;
|
||||
use lapin::{
|
||||
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties, ExchangeKind
|
||||
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties,
|
||||
ExchangeKind,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::error::ProcessingError;
|
||||
|
||||
/// Possible errors related to RabbitMQ operations.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum RabbitMQError {
|
||||
@@ -25,6 +28,8 @@ pub enum RabbitMQError {
|
||||
InitializeConsumerError(String),
|
||||
#[error("Queue error: {0}")]
|
||||
QueueError(String),
|
||||
#[error("Processing error: {0}")]
|
||||
ProcessingError(#[from] ProcessingError),
|
||||
}
|
||||
|
||||
/// Struct containing the information required to set up a client and connection.
|
||||
@@ -42,18 +47,21 @@ pub struct RabbitMQCommon {
|
||||
pub channel: Channel,
|
||||
}
|
||||
|
||||
|
||||
/// Defines the behavior for RabbitMQCommon client operations.
|
||||
#[cfg_attr(test, mockall::automock)]
|
||||
#[async_trait]
|
||||
pub trait RabbitMQCommonTrait: Send + Sync {
|
||||
async fn create_connection(config: &RabbitMQConfig) -> Result<Connection, RabbitMQError>;
|
||||
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError>;
|
||||
}
|
||||
async fn declare_exchange(
|
||||
&self,
|
||||
config: &RabbitMQConfig,
|
||||
passive: bool,
|
||||
) -> Result<(), RabbitMQError>;
|
||||
}
|
||||
|
||||
impl RabbitMQCommon {
|
||||
/// Sets up a new RabbitMQ client or error
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `RabbitMQConfig` - Configuration object with required information
|
||||
///
|
||||
@@ -62,7 +70,10 @@ impl RabbitMQCommon {
|
||||
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
||||
let connection = Self::create_connection(config).await?;
|
||||
let channel = connection.create_channel().await?;
|
||||
Ok(Self { connection, channel })
|
||||
Ok(Self {
|
||||
connection,
|
||||
channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +88,11 @@ impl RabbitMQCommonTrait for RabbitMQCommon {
|
||||
}
|
||||
|
||||
/// Function to declare the exchange required
|
||||
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError> {
|
||||
async fn declare_exchange(
|
||||
&self,
|
||||
config: &RabbitMQConfig,
|
||||
passive: bool,
|
||||
) -> Result<(), RabbitMQError> {
|
||||
debug!("Declaring exchange");
|
||||
self.channel
|
||||
.exchange_declare(
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use lapin::{
|
||||
options::*, publisher_confirm::Confirmation, BasicProperties,
|
||||
};
|
||||
use lapin::{options::*, publisher_confirm::Confirmation, BasicProperties};
|
||||
|
||||
use crate::models::ingress_object::IngressObject;
|
||||
use crate::ingress::types::ingress_object::IngressObject;
|
||||
|
||||
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Struct to publish messages to RabbitMQ.
|
||||
pub struct RabbitMQProducer {
|
||||
@@ -27,7 +25,7 @@ impl RabbitMQProducer {
|
||||
let common = RabbitMQCommon::new(config).await?;
|
||||
common.declare_exchange(config, false).await?;
|
||||
|
||||
Ok(Self {
|
||||
Ok(Self {
|
||||
common,
|
||||
exchange_name: config.exchange.clone(),
|
||||
routing_key: config.routing_key.clone(),
|
||||
@@ -39,19 +37,23 @@ impl RabbitMQProducer {
|
||||
/// # Arguments
|
||||
/// * `self` - Reference to self
|
||||
/// * `ingress_object` - An initialized IngressObject
|
||||
///
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error
|
||||
pub async fn publish(&self, ingress_object: &IngressObject) -> Result<Confirmation, RabbitMQError> {
|
||||
pub async fn publish(
|
||||
&self,
|
||||
ingress_object: &IngressObject,
|
||||
) -> Result<Confirmation, RabbitMQError> {
|
||||
// Serialize IngressObject to JSON
|
||||
let payload = serde_json::to_vec(ingress_object)
|
||||
.map_err(|e| {
|
||||
error!("Serialization Error: {}", e);
|
||||
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
|
||||
})?;
|
||||
|
||||
let payload = serde_json::to_vec(ingress_object).map_err(|e| {
|
||||
error!("Serialization Error: {}", e);
|
||||
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
|
||||
})?;
|
||||
|
||||
// Publish the serialized payload to RabbitMQ
|
||||
let confirmation = self.common.channel
|
||||
let confirmation = self
|
||||
.common
|
||||
.channel
|
||||
.basic_publish(
|
||||
&self.exchange_name,
|
||||
&self.routing_key,
|
||||
@@ -69,9 +71,12 @@ impl RabbitMQProducer {
|
||||
error!("Publish Confirmation Error: {}", e);
|
||||
RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Published IngressObject to exchange '{}' with routing key '{}'", self.exchange_name, self.routing_key);
|
||||
|
||||
|
||||
info!(
|
||||
"Published IngressObject to exchange '{}' with routing key '{}'",
|
||||
self.exchange_name, self.routing_key
|
||||
);
|
||||
|
||||
Ok(confirmation)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
use std::sync::Arc;
|
||||
use axum::{
|
||||
extract::Path, response::IntoResponse, Extension, Json
|
||||
use crate::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::file_info::{FileError, FileInfo},
|
||||
};
|
||||
use axum_typed_multipart::{TypedMultipart, FieldData, TryFromMultipart};
|
||||
use axum::{extract::Path, response::IntoResponse, Extension, Json};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
use crate::{
|
||||
models::file_info::{FileError, FileInfo},
|
||||
surrealdb::SurrealDbClient,
|
||||
};
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct FileUploadRequest {
|
||||
@@ -19,7 +17,7 @@ pub struct FileUploadRequest {
|
||||
}
|
||||
|
||||
/// Handler to upload a new file.
|
||||
///
|
||||
///
|
||||
/// Route: POST /file
|
||||
pub async fn upload_handler(
|
||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||
@@ -40,13 +38,12 @@ pub async fn upload_handler(
|
||||
|
||||
info!("File uploaded successfully: {:?}", file_info);
|
||||
|
||||
|
||||
// Return the response with HTTP 200
|
||||
Ok((axum::http::StatusCode::OK, Json(response)))
|
||||
}
|
||||
|
||||
/// Handler to retrieve file information by UUID.
|
||||
///
|
||||
///
|
||||
/// Route: GET /file/:uuid
|
||||
pub async fn get_file_handler(
|
||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||
@@ -73,7 +70,7 @@ pub async fn get_file_handler(
|
||||
}
|
||||
|
||||
/// Handler to update an existing file by UUID.
|
||||
///
|
||||
///
|
||||
/// Route: PUT /file/:uuid
|
||||
pub async fn update_file_handler(
|
||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||
@@ -101,7 +98,7 @@ pub async fn update_file_handler(
|
||||
}
|
||||
|
||||
/// Handler to delete a file by UUID.
|
||||
///
|
||||
///
|
||||
/// Route: DELETE /file/:uuid
|
||||
pub async fn delete_file_handler(
|
||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
use std::sync::Arc;
|
||||
use crate::{
|
||||
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
||||
rabbitmq::publisher::RabbitMQProducer,
|
||||
storage::db::SurrealDbClient,
|
||||
};
|
||||
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
use crate::{models::ingress_content::{create_ingress_objects, IngressInput}, rabbitmq::publisher::RabbitMQProducer, surrealdb::SurrealDbClient};
|
||||
|
||||
pub async fn ingress_handler(
|
||||
Extension(producer): Extension<Arc<RabbitMQProducer>>,
|
||||
@@ -23,7 +27,7 @@ pub async fn ingress_handler(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to publish message",
|
||||
)
|
||||
.into_response();
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,11 +35,7 @@ pub async fn ingress_handler(
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to process input: {:?}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to process input",
|
||||
)
|
||||
.into_response()
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to process input").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use zettle_db::{
|
||||
ingress::ingress_handler,
|
||||
queue_length::queue_length_handler,
|
||||
},
|
||||
surrealdb::SurrealDbClient,
|
||||
storage::db::SurrealDbClient,
|
||||
};
|
||||
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||
|
||||
@@ -1,8 +1,67 @@
|
||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
||||
|
||||
use crate::error::ProcessingError;
|
||||
|
||||
use super::types::StoredObject;
|
||||
use std::ops::Deref;
|
||||
use surrealdb::{
|
||||
engine::remote::ws::{Client, Ws},
|
||||
opt::auth::Root,
|
||||
Error, Surreal,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SurrealDbClient {
|
||||
pub client: Surreal<Client>,
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// # Initialize a new datbase client
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// # Returns
|
||||
/// * `SurrealDbClient` initialized
|
||||
pub async fn new() -> Result<Self, Error> {
|
||||
let db = Surreal::new::<Ws>("127.0.0.1:8000").await?;
|
||||
|
||||
// Sign in to database
|
||||
db.signin(Root {
|
||||
username: "root_user",
|
||||
password: "root_password",
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Set namespace
|
||||
db.use_ns("test").use_db("test").await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn drop_table<T>(&self) -> Result<(), Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
{
|
||||
let _deleted: Vec<T> = self.client.delete(T::table_name()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
type Target = Surreal<Client>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.client
|
||||
}
|
||||
}
|
||||
|
||||
/// Operation to store an object in SurrealDB, requires the struct to implement StoredObject
|
||||
///
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
use axum::{http::StatusCode, response::{IntoResponse, Response}, Json};
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use axum_typed_multipart::FieldData;
|
||||
use mime_guess::from_path;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use sha2::{Digest, Sha256};
|
||||
use surrealdb::RecordId;
|
||||
use std::{
|
||||
io::{BufReader, Read},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use surrealdb::RecordId;
|
||||
use tempfile::NamedTempFile;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::surrealdb::SurrealDbClient;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
|
||||
#[derive(Debug,Deserialize)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Record {
|
||||
#[allow(dead_code)]
|
||||
id: RecordId,
|
||||
@@ -72,7 +76,6 @@ pub enum FileError {
|
||||
|
||||
#[error("Deserialization error: {0}")]
|
||||
DeserializationError(String),
|
||||
|
||||
// Add more error variants as needed.
|
||||
}
|
||||
|
||||
@@ -82,16 +85,30 @@ impl IntoResponse for FileError {
|
||||
FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
|
||||
FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"),
|
||||
FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"),
|
||||
FileError::UnsupportedMime(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type"),
|
||||
FileError::UnsupportedMime(_) => {
|
||||
(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type")
|
||||
}
|
||||
FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"),
|
||||
FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"),
|
||||
FileError::HashCollision => (StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected"),
|
||||
FileError::HashCollision => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected")
|
||||
}
|
||||
FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"),
|
||||
FileError::MissingFileName => (StatusCode::BAD_REQUEST, "Missing file name in metadata"),
|
||||
FileError::PersistError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file"),
|
||||
FileError::SerializationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"),
|
||||
FileError::DeserializationError(_) => (StatusCode::BAD_REQUEST, "Deserialization error"),
|
||||
FileError::SurrealError(_) =>(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"),
|
||||
FileError::MissingFileName => {
|
||||
(StatusCode::BAD_REQUEST, "Missing file name in metadata")
|
||||
}
|
||||
FileError::PersistError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file")
|
||||
}
|
||||
FileError::SerializationError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
|
||||
}
|
||||
FileError::DeserializationError(_) => {
|
||||
(StatusCode::BAD_REQUEST, "Deserialization error")
|
||||
}
|
||||
FileError::SurrealError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
|
||||
}
|
||||
};
|
||||
|
||||
let body = Json(json!({
|
||||
@@ -163,7 +180,11 @@ impl FileInfo {
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<FileInfo, FileError>` - The updated `FileInfo` or an error.
|
||||
pub async fn update(uuid: Uuid, new_field_data: FieldData<NamedTempFile>, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
||||
pub async fn update(
|
||||
uuid: Uuid,
|
||||
new_field_data: FieldData<NamedTempFile>,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<FileInfo, FileError> {
|
||||
let new_file = new_field_data.contents;
|
||||
let new_metadata = new_field_data.metadata;
|
||||
|
||||
@@ -184,7 +205,8 @@ impl FileInfo {
|
||||
let sanitized_new_file_name = sanitize_file_name(&new_file_name);
|
||||
|
||||
// Persist the new file
|
||||
let new_persisted_path = Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
|
||||
let new_persisted_path =
|
||||
Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
|
||||
|
||||
// Guess the new MIME type
|
||||
let new_mime_type = Self::guess_mime_type(&new_persisted_path);
|
||||
@@ -202,7 +224,7 @@ impl FileInfo {
|
||||
};
|
||||
|
||||
// Save the new item
|
||||
Self::create_record(&updated_file_info,db_client).await?;
|
||||
Self::create_record(&updated_file_info, db_client).await?;
|
||||
|
||||
// Optionally, delete the old file from the filesystem if it's no longer referenced
|
||||
// This requires reference counting or checking if other FileInfo entries point to the same SHA
|
||||
@@ -221,26 +243,35 @@ impl FileInfo {
|
||||
/// * `Result<(), FileError>` - Empty result or an error.
|
||||
pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> {
|
||||
// Retrieve FileInfo to get SHA256 and path
|
||||
let file_info = Self::get_by_uuid(uuid, db_client).await?;
|
||||
let file_info = Self::get_by_uuid(uuid, db_client).await?;
|
||||
|
||||
// Delete the file from the filesystem
|
||||
let file_path = Path::new(&file_info.path);
|
||||
if file_path.exists() {
|
||||
tokio::fs::remove_file(file_path).await.map_err(FileError::Io)?;
|
||||
tokio::fs::remove_file(file_path)
|
||||
.await
|
||||
.map_err(FileError::Io)?;
|
||||
info!("Deleted file at path: {}", file_info.path);
|
||||
} else {
|
||||
info!("File path does not exist, skipping deletion: {}", file_info.path);
|
||||
info!(
|
||||
"File path does not exist, skipping deletion: {}",
|
||||
file_info.path
|
||||
);
|
||||
}
|
||||
|
||||
// Delete the FileInfo from database
|
||||
Self::delete_record(&file_info.sha256, db_client).await?;
|
||||
|
||||
// Remove the UUID directory if empty
|
||||
let uuid_dir = file_path.parent().ok_or(FileError::FileNotFound(uuid.to_string()))?;
|
||||
let uuid_dir = file_path
|
||||
.parent()
|
||||
.ok_or(FileError::FileNotFound(uuid.to_string()))?;
|
||||
if uuid_dir.exists() {
|
||||
let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?;
|
||||
if entries.next_entry().await?.is_none() {
|
||||
tokio::fs::remove_dir(uuid_dir).await.map_err(FileError::Io)?;
|
||||
tokio::fs::remove_dir(uuid_dir)
|
||||
.await
|
||||
.map_err(FileError::Io)?;
|
||||
info!("Deleted empty UUID directory: {:?}", uuid_dir);
|
||||
}
|
||||
}
|
||||
@@ -257,19 +288,26 @@ impl FileInfo {
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
|
||||
async fn persist_file(uuid: &Uuid, file: NamedTempFile, file_name: &str) -> Result<PathBuf, FileError> {
|
||||
async fn persist_file(
|
||||
uuid: &Uuid,
|
||||
file: NamedTempFile,
|
||||
file_name: &str,
|
||||
) -> Result<PathBuf, FileError> {
|
||||
let base_dir = Path::new("./data");
|
||||
let uuid_dir = base_dir.join(uuid.to_string());
|
||||
|
||||
// Create the UUID directory if it doesn't exist
|
||||
tokio::fs::create_dir_all(&uuid_dir).await.map_err(FileError::Io)?;
|
||||
tokio::fs::create_dir_all(&uuid_dir)
|
||||
.await
|
||||
.map_err(FileError::Io)?;
|
||||
|
||||
// Define the final file path
|
||||
let final_path = uuid_dir.join(file_name);
|
||||
info!("Final path: {:?}", final_path);
|
||||
|
||||
// Persist the temporary file to the final path
|
||||
file.persist(&final_path).map_err(|e| FileError::PersistError(e.to_string()))?;
|
||||
file.persist(&final_path)
|
||||
.map_err(|e| FileError::PersistError(e.to_string()))?;
|
||||
|
||||
info!("Persisted file to {:?}", final_path);
|
||||
|
||||
@@ -313,7 +351,6 @@ impl FileInfo {
|
||||
.to_string()
|
||||
}
|
||||
|
||||
|
||||
/// Creates a new record in SurrealDB for the given `FileInfo`.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -323,16 +360,19 @@ impl FileInfo {
|
||||
/// # Returns
|
||||
/// * `Result<(), FileError>` - Empty result or an error.
|
||||
|
||||
async fn create_record(file_info: &FileInfo, db_client: &SurrealDbClient) -> Result<(), FileError> {
|
||||
async fn create_record(
|
||||
file_info: &FileInfo,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), FileError> {
|
||||
// Create the record
|
||||
let _created: Option<Record> = db_client
|
||||
.client
|
||||
.create(("file", &file_info.uuid ))
|
||||
.create(("file", &file_info.uuid))
|
||||
.content(file_info.clone())
|
||||
.await?;
|
||||
|
||||
debug!("{:?}",_created);
|
||||
|
||||
debug!("{:?}", _created);
|
||||
|
||||
info!("Created FileInfo record with SHA256: {}", file_info.sha256);
|
||||
|
||||
Ok(())
|
||||
@@ -346,11 +386,17 @@ impl FileInfo {
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<FileInfo, FileError>` - The `FileInfo` or `Error` if not found.
|
||||
pub async fn get_by_uuid(uuid: Uuid, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
||||
pub async fn get_by_uuid(
|
||||
uuid: Uuid,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<FileInfo, FileError> {
|
||||
let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid);
|
||||
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
|
||||
|
||||
response.into_iter().next().ok_or(FileError::FileNotFound(uuid.to_string()))
|
||||
response
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or(FileError::FileNotFound(uuid.to_string()))
|
||||
}
|
||||
|
||||
/// Retrieves a `FileInfo` by SHA256.
|
||||
@@ -367,7 +413,11 @@ impl FileInfo {
|
||||
|
||||
debug!("{:?}", response);
|
||||
|
||||
response.into_iter().next().ok_or(FileError::FileNotFound(sha256.to_string())) }
|
||||
response
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or(FileError::FileNotFound(sha256.to_string()))
|
||||
}
|
||||
|
||||
/// Deletes a `FileInfo` record by SHA256.
|
||||
///
|
||||
@@ -381,10 +431,7 @@ impl FileInfo {
|
||||
let table = "file";
|
||||
let primary_key = sha256;
|
||||
|
||||
let _created: Option<Record> = db_client
|
||||
.client
|
||||
.delete((table, primary_key))
|
||||
.await?;
|
||||
let _created: Option<Record> = db_client.client.delete((table, primary_key)).await?;
|
||||
|
||||
info!("Deleted FileInfo record with SHA256: {}", sha256);
|
||||
|
||||
@@ -395,7 +442,14 @@ impl FileInfo {
|
||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
///
/// A name that is empty or consists only of dots (such as `.` or `..`) would
/// refer to the containing directory or its parent once joined onto a path, so
/// it is rewritten to underscores (one per dot, at least one) instead.
///
/// # Arguments
/// * `file_name` - The client-supplied file name.
///
/// # Returns
/// * `String` - A name containing only ASCII alphanumerics, '.' and '_'.
fn sanitize_file_name(file_name: &str) -> String {
    let sanitized: String = file_name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    // `all` is vacuously true for the empty string, so this also covers an
    // empty name; `max(1)` guarantees a non-empty result in that case.
    if sanitized.chars().all(|c| c == '.') {
        return "_".repeat(sanitized.len().max(1));
    }

    sanitized
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub mod file_info;
|
||||
pub mod knowledge_entity;
|
||||
pub mod knowledge_relationship;
|
||||
pub mod text_chunk;
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::models::file_info::FileInfo;
|
||||
use crate::stored_object;
|
||||
|
||||
use super::file_info::FileInfo;
|
||||
|
||||
stored_object!(TextContent, "text_content", {
|
||||
text: String,
|
||||
file_info: Option<FileInfo>,
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
use std::ops::Deref;
|
||||
use surrealdb::{
|
||||
engine::remote::ws::{Client, Ws},
|
||||
opt::auth::Root,
|
||||
Error, Surreal,
|
||||
};
|
||||
|
||||
use crate::storage::types::StoredObject;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SurrealDbClient {
|
||||
pub client: Surreal<Client>,
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// # Initialize a new database client
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// # Returns
|
||||
/// * `SurrealDbClient` initialized
|
||||
pub async fn new() -> Result<Self, Error> {
|
||||
let db = Surreal::new::<Ws>("127.0.0.1:8000").await?;
|
||||
|
||||
// Sign in to database
|
||||
db.signin(Root {
|
||||
username: "root_user",
|
||||
password: "root_password",
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Set namespace
|
||||
db.use_ns("test").use_db("test").await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn drop_table<T>(&self) -> Result<(), Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
{
|
||||
let _deleted: Vec<T> = self.client.delete(T::table_name()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
type Target = Surreal<Client>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.client
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user