mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-25 02:08:30 +02:00
feat: refactoring complete?
This commit is contained in:
@@ -1 +0,0 @@
|
|||||||
pub mod ingress;
|
|
||||||
14
src/error.rs
14
src/error.rs
@@ -2,6 +2,8 @@ use async_openai::error::OpenAIError;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::task::JoinError;
|
use tokio::task::JoinError;
|
||||||
|
|
||||||
|
use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError};
|
||||||
|
|
||||||
/// Error types for processing `TextContent`.
|
/// Error types for processing `TextContent`.
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ProcessingError {
|
pub enum ProcessingError {
|
||||||
@@ -23,3 +25,15 @@ pub enum ProcessingError {
|
|||||||
#[error("Task join error: {0}")]
|
#[error("Task join error: {0}")]
|
||||||
JoinError(#[from] JoinError),
|
JoinError(#[from] JoinError),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum IngressConsumerError {
|
||||||
|
#[error("RabbitMQ error: {0}")]
|
||||||
|
RabbitMQ(#[from] RabbitMQError),
|
||||||
|
|
||||||
|
#[error("Processing error: {0}")]
|
||||||
|
Processing(#[from] ProcessingError),
|
||||||
|
|
||||||
|
#[error("Ingress content error: {0}")]
|
||||||
|
IngressContent(#[from] IngressContentError),
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
analysis::ingress::{
|
|
||||||
prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
|
||||||
types::llm_analysis_result::LLMGraphAnalysisResult,
|
|
||||||
},
|
|
||||||
error::ProcessingError,
|
error::ProcessingError,
|
||||||
|
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||||
retrieval::vector::find_items_by_vector_similarity,
|
retrieval::vector::find_items_by_vector_similarity,
|
||||||
storage::types::{knowledge_entity::KnowledgeEntity, StoredObject},
|
storage::types::{knowledge_entity::KnowledgeEntity, StoredObject},
|
||||||
};
|
};
|
||||||
@@ -15,7 +12,9 @@ use async_openai::types::{
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use surrealdb::engine::remote::ws::Client;
|
use surrealdb::engine::remote::ws::Client;
|
||||||
use surrealdb::Surreal;
|
use surrealdb::Surreal;
|
||||||
use tracing::{debug, instrument};
|
use tracing::debug;
|
||||||
|
|
||||||
|
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
|
||||||
|
|
||||||
pub struct IngressAnalyzer<'a> {
|
pub struct IngressAnalyzer<'a> {
|
||||||
db_client: &'a Surreal<Client>,
|
db_client: &'a Surreal<Client>,
|
||||||
@@ -33,7 +32,6 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
|
||||||
pub async fn analyze_content(
|
pub async fn analyze_content(
|
||||||
&self,
|
&self,
|
||||||
category: &str,
|
category: &str,
|
||||||
@@ -48,7 +46,6 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
self.perform_analysis(llm_request).await
|
self.perform_analysis(llm_request).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
|
||||||
async fn find_similar_entities(
|
async fn find_similar_entities(
|
||||||
&self,
|
&self,
|
||||||
category: &str,
|
category: &str,
|
||||||
@@ -70,7 +67,6 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
|
||||||
fn prepare_llm_request(
|
fn prepare_llm_request(
|
||||||
&self,
|
&self,
|
||||||
category: &str,
|
category: &str,
|
||||||
@@ -108,7 +104,7 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
CreateChatCompletionRequestArgs::default()
|
CreateChatCompletionRequestArgs::default()
|
||||||
.model("gpt-4-mini")
|
.model("gpt-4o-mini")
|
||||||
.temperature(0.2)
|
.temperature(0.2)
|
||||||
.max_tokens(2048u32)
|
.max_tokens(2048u32)
|
||||||
.messages([
|
.messages([
|
||||||
@@ -120,7 +116,6 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
.map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
|
.map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(self, request))]
|
|
||||||
async fn perform_analysis(
|
async fn perform_analysis(
|
||||||
&self,
|
&self,
|
||||||
request: CreateChatCompletionRequest,
|
request: CreateChatCompletionRequest,
|
||||||
@@ -5,14 +5,15 @@ use tokio::task;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ProcessingError,
|
error::ProcessingError,
|
||||||
models::graph_entities::GraphMapper,
|
|
||||||
storage::types::{
|
storage::types::{
|
||||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||||
knowledge_relationship::KnowledgeRelationship,
|
knowledge_relationship::KnowledgeRelationship,
|
||||||
},
|
},
|
||||||
utils::embedding::generate_embedding,
|
utils::embedding::generate_embedding,
|
||||||
};
|
};
|
||||||
use futures::future::try_join_all; // For future parallelization
|
use futures::future::try_join_all;
|
||||||
|
|
||||||
|
use super::graph_mapper::GraphMapper; // For future parallelization
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub struct LLMKnowledgeEntity {
|
pub struct LLMKnowledgeEntity {
|
||||||
@@ -1 +1,2 @@
|
|||||||
|
pub mod graph_mapper;
|
||||||
pub mod llm_analysis_result;
|
pub mod llm_analysis_result;
|
||||||
119
src/ingress/content_processor.rs
Normal file
119
src/ingress/content_processor.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
use text_splitter::TextSplitter;
|
||||||
|
use tracing::{debug, info};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
error::ProcessingError,
|
||||||
|
retrieval::vector::find_items_by_vector_similarity,
|
||||||
|
storage::{
|
||||||
|
db::{store_item, SurrealDbClient},
|
||||||
|
types::{
|
||||||
|
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||||
|
text_chunk::TextChunk, text_content::TextContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
utils::embedding::generate_embedding,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::analysis::{
|
||||||
|
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct ContentProcessor {
|
||||||
|
db_client: SurrealDbClient,
|
||||||
|
openai_client: async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContentProcessor {
|
||||||
|
pub async fn new() -> Result<Self, ProcessingError> {
|
||||||
|
Ok(Self {
|
||||||
|
db_client: SurrealDbClient::new().await?,
|
||||||
|
openai_client: async_openai::Client::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> {
|
||||||
|
// Store original content
|
||||||
|
store_item(&self.db_client, content.clone()).await?;
|
||||||
|
|
||||||
|
// Process in parallel where possible
|
||||||
|
let (analysis, _similar_chunks) = tokio::try_join!(
|
||||||
|
self.perform_semantic_analysis(content),
|
||||||
|
self.find_similar_content(content),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Convert and store entities
|
||||||
|
let (entities, relationships) = analysis
|
||||||
|
.to_database_entities(&content.id, &self.openai_client)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Store everything
|
||||||
|
tokio::try_join!(
|
||||||
|
self.store_graph_entities(entities, relationships),
|
||||||
|
self.store_vector_chunks(content),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
self.db_client.rebuild_indexes().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn perform_semantic_analysis(
|
||||||
|
&self,
|
||||||
|
content: &TextContent,
|
||||||
|
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
||||||
|
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
||||||
|
analyser
|
||||||
|
.analyze_content(&content.category, &content.instructions, &content.text)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn find_similar_content(
|
||||||
|
&self,
|
||||||
|
content: &TextContent,
|
||||||
|
) -> Result<Vec<TextChunk>, ProcessingError> {
|
||||||
|
find_items_by_vector_similarity(
|
||||||
|
3,
|
||||||
|
content.text.clone(),
|
||||||
|
&self.db_client,
|
||||||
|
"text_chunk".to_string(),
|
||||||
|
&self.openai_client,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn store_graph_entities(
|
||||||
|
&self,
|
||||||
|
entities: Vec<KnowledgeEntity>,
|
||||||
|
relationships: Vec<KnowledgeRelationship>,
|
||||||
|
) -> Result<(), ProcessingError> {
|
||||||
|
for entity in &entities {
|
||||||
|
debug!("Storing entity: {:?}", entity);
|
||||||
|
store_item(&self.db_client, entity.clone()).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for relationship in &relationships {
|
||||||
|
debug!("Storing relationship: {:?}", relationship);
|
||||||
|
store_item(&self.db_client, relationship.clone()).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Stored {} entities and {} relationships",
|
||||||
|
entities.len(),
|
||||||
|
relationships.len()
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> {
|
||||||
|
let splitter = TextSplitter::new(500..2000);
|
||||||
|
let chunks = splitter.chunks(&content.text);
|
||||||
|
|
||||||
|
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||||
|
for chunk in chunks {
|
||||||
|
let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?;
|
||||||
|
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
|
||||||
|
store_item(&self.db_client, text_chunk).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/ingress/mod.rs
Normal file
3
src/ingress/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod analysis;
|
||||||
|
pub mod content_processor;
|
||||||
|
pub mod types;
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
|
use super::ingress_object::IngressObject;
|
||||||
|
use crate::storage::{db::SurrealDbClient, types::file_info::FileInfo};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use crate::surrealdb::SurrealDbClient;
|
|
||||||
use super::{file_info::FileInfo, ingress_object::IngressObject };
|
|
||||||
|
|
||||||
/// Struct defining the expected body when ingressing content.
|
/// Struct defining the expected body when ingressing content.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
@@ -105,4 +105,3 @@ pub async fn create_ingress_objects(
|
|||||||
|
|
||||||
Ok(object_list)
|
Ok(object_list)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
use super::ingress_content::IngressContentError;
|
use crate::storage::types::{file_info::FileInfo, text_content::TextContent};
|
||||||
use crate::models::file_info::FileInfo;
|
|
||||||
use crate::storage::types::text_content::TextContent;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::ingress_input::IngressContentError;
|
||||||
|
|
||||||
/// Knowledge object type, containing the content or reference to it, as well as metadata
|
/// Knowledge object type, containing the content or reference to it, as well as metadata
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub enum IngressObject {
|
pub enum IngressObject {
|
||||||
2
src/ingress/types/mod.rs
Normal file
2
src/ingress/types/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pub mod ingress_input;
|
||||||
|
pub mod ingress_object;
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
pub mod analysis;
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod models;
|
pub mod ingress;
|
||||||
pub mod rabbitmq;
|
pub mod rabbitmq;
|
||||||
pub mod retrieval;
|
pub mod retrieval;
|
||||||
pub mod routes;
|
pub mod routes;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod surrealdb;
|
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
pub mod file_info;
|
|
||||||
pub mod graph_entities;
|
|
||||||
pub mod ingress_content;
|
|
||||||
pub mod ingress_object;
|
|
||||||
pub mod text_content;
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
use crate::analysis::ingress::ingress_analyser::IngressAnalyzer;
|
|
||||||
use crate::retrieval::graph::find_entities_by_source_id;
|
|
||||||
use crate::retrieval::vector::find_items_by_vector_similarity;
|
|
||||||
use crate::storage::db::store_item;
|
|
||||||
use crate::storage::types::knowledge_entity::KnowledgeEntity;
|
|
||||||
use crate::storage::types::knowledge_relationship::KnowledgeRelationship;
|
|
||||||
use crate::storage::types::text_chunk::TextChunk;
|
|
||||||
use crate::storage::types::text_content::TextContent;
|
|
||||||
use crate::storage::types::StoredObject;
|
|
||||||
use crate::utils::embedding::generate_embedding;
|
|
||||||
use crate::{error::ProcessingError, surrealdb::SurrealDbClient};
|
|
||||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
|
||||||
use text_splitter::TextSplitter;
|
|
||||||
use tracing::{debug, info};
|
|
||||||
|
|
||||||
impl TextContent {
|
|
||||||
/// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB.
|
|
||||||
pub async fn process(&self) -> Result<(), ProcessingError> {
|
|
||||||
let db_client = SurrealDbClient::new().await?;
|
|
||||||
let openai_client = async_openai::Client::new();
|
|
||||||
|
|
||||||
// Store TextContent
|
|
||||||
store_item(&db_client, self.clone()).await?;
|
|
||||||
|
|
||||||
// Get related nodes
|
|
||||||
let closest_text_content: Vec<TextChunk> = find_items_by_vector_similarity(
|
|
||||||
3,
|
|
||||||
self.text.clone(),
|
|
||||||
&db_client,
|
|
||||||
"text_chunk".to_string(),
|
|
||||||
&openai_client,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
for node in closest_text_content {
|
|
||||||
let related_nodes: Vec<KnowledgeEntity> = find_entities_by_source_id(
|
|
||||||
node.source_id.to_owned(),
|
|
||||||
KnowledgeEntity::table_name().to_string(),
|
|
||||||
&db_client,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
for related_node in related_nodes {
|
|
||||||
info!("{:?}", related_node.name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rebuild indexes
|
|
||||||
db_client.rebuild_indexes().await?;
|
|
||||||
|
|
||||||
// Step 1: Send to LLM for analysis
|
|
||||||
let analyser = IngressAnalyzer::new(&db_client, &openai_client);
|
|
||||||
let analysis = analyser
|
|
||||||
.analyze_content(&self.category, &self.instructions, &self.text)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Step 2: Convert LLM analysis to database entities
|
|
||||||
let (entities, relationships) = analysis
|
|
||||||
.to_database_entities(&self.id, &openai_client)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Step 3: Store in database
|
|
||||||
self.store_in_graph_db(entities, relationships, &db_client)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Step 4: Split text and store in Vector DB
|
|
||||||
self.store_in_vector_db(&db_client, &openai_client).await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn store_in_graph_db(
|
|
||||||
&self,
|
|
||||||
entities: Vec<KnowledgeEntity>,
|
|
||||||
relationships: Vec<KnowledgeRelationship>,
|
|
||||||
db_client: &Surreal<Client>,
|
|
||||||
) -> Result<(), ProcessingError> {
|
|
||||||
for entity in &entities {
|
|
||||||
debug!(
|
|
||||||
"{:?}, {:?}, {:?}",
|
|
||||||
&entity.id, &entity.name, &entity.description
|
|
||||||
);
|
|
||||||
|
|
||||||
store_item(db_client, entity.clone()).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
for relationship in &relationships {
|
|
||||||
debug!("{:?}", relationship);
|
|
||||||
|
|
||||||
store_item(db_client, relationship.clone()).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"Inserted to database: {:?} entities, {:?} relationships",
|
|
||||||
entities.len(),
|
|
||||||
relationships.len()
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Splits text and stores it in a vector database.
|
|
||||||
async fn store_in_vector_db(
|
|
||||||
&self,
|
|
||||||
db_client: &Surreal<Client>,
|
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
|
||||||
) -> Result<(), ProcessingError> {
|
|
||||||
let max_characters = 500..2000;
|
|
||||||
let splitter = TextSplitter::new(max_characters);
|
|
||||||
|
|
||||||
let chunks = splitter.chunks(self.text.as_str());
|
|
||||||
|
|
||||||
for chunk in chunks {
|
|
||||||
info!("Chunk: {}", chunk);
|
|
||||||
let embedding = generate_embedding(openai_client, chunk.to_string()).await?;
|
|
||||||
let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding);
|
|
||||||
|
|
||||||
store_item(db_client, text_chunk).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
use lapin::{
|
|
||||||
message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue
|
|
||||||
};
|
|
||||||
use futures_lite::stream::StreamExt;
|
use futures_lite::stream::StreamExt;
|
||||||
|
use lapin::{message::Delivery, options::*, types::FieldTable, Channel, Consumer, Queue};
|
||||||
|
|
||||||
use crate::models::{ingress_content::IngressContentError, ingress_object::IngressObject };
|
use crate::{
|
||||||
|
error::IngressConsumerError,
|
||||||
|
ingress::{
|
||||||
|
content_processor::ContentProcessor,
|
||||||
|
types::{ingress_input::IngressContentError, ingress_object::IngressObject},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
|
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
|
||||||
use tracing::{info, error};
|
use tracing::{error, info};
|
||||||
|
|
||||||
/// Struct to consume messages from RabbitMQ.
|
/// Struct to consume messages from RabbitMQ.
|
||||||
pub struct RabbitMQConsumer {
|
pub struct RabbitMQConsumer {
|
||||||
@@ -26,18 +30,22 @@ impl RabbitMQConsumer {
|
|||||||
/// * `Result<Self, RabbitMQError>` - The created client or an error.
|
/// * `Result<Self, RabbitMQError>` - The created client or an error.
|
||||||
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
||||||
let common = RabbitMQCommon::new(config).await?;
|
let common = RabbitMQCommon::new(config).await?;
|
||||||
|
|
||||||
// Passively declare the exchange (it should already exist)
|
// Passively declare the exchange (it should already exist)
|
||||||
common.declare_exchange(config, true).await?;
|
common.declare_exchange(config, true).await?;
|
||||||
|
|
||||||
// Declare queue and bind it to the channel
|
// Declare queue and bind it to the channel
|
||||||
let queue = Self::declare_queue(&common.channel, config).await?;
|
let queue = Self::declare_queue(&common.channel, config).await?;
|
||||||
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
|
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
|
||||||
|
|
||||||
// Initialize the consumer
|
// Initialize the consumer
|
||||||
let consumer = Self::initialize_consumer(&common.channel, config).await?;
|
let consumer = Self::initialize_consumer(&common.channel, config).await?;
|
||||||
|
|
||||||
Ok(Self { common, queue, consumer })
|
Ok(Self {
|
||||||
|
common,
|
||||||
|
queue,
|
||||||
|
consumer,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets up the consumer based on the channel and `RabbitMQConfig`.
|
/// Sets up the consumer based on the channel and `RabbitMQConfig`.
|
||||||
@@ -48,7 +56,10 @@ impl RabbitMQConsumer {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<Consumer, RabbitMQError>` - The initialized consumer or error
|
/// * `Result<Consumer, RabbitMQError>` - The initialized consumer or error
|
||||||
async fn initialize_consumer(channel: &Channel, config: &RabbitMQConfig) -> Result<Consumer, RabbitMQError> {
|
async fn initialize_consumer(
|
||||||
|
channel: &Channel,
|
||||||
|
config: &RabbitMQConfig,
|
||||||
|
) -> Result<Consumer, RabbitMQError> {
|
||||||
channel
|
channel
|
||||||
.basic_consume(
|
.basic_consume(
|
||||||
&config.queue,
|
&config.queue,
|
||||||
@@ -56,7 +67,8 @@ impl RabbitMQConsumer {
|
|||||||
BasicConsumeOptions::default(),
|
BasicConsumeOptions::default(),
|
||||||
FieldTable::default(),
|
FieldTable::default(),
|
||||||
)
|
)
|
||||||
.await.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
|
.await
|
||||||
|
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
|
||||||
}
|
}
|
||||||
/// Declares the queue based on the channel and `RabbitMQConfig`.
|
/// Declares the queue based on the channel and `RabbitMQConfig`.
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -65,7 +77,10 @@ impl RabbitMQConsumer {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<Queue, RabbitMQError>` - The initialized queue or error
|
/// * `Result<Queue, RabbitMQError>` - The initialized queue or error
|
||||||
async fn declare_queue(channel: &Channel, config: &RabbitMQConfig) -> Result<Queue, RabbitMQError> {
|
async fn declare_queue(
|
||||||
|
channel: &Channel,
|
||||||
|
config: &RabbitMQConfig,
|
||||||
|
) -> Result<Queue, RabbitMQError> {
|
||||||
channel
|
channel
|
||||||
.queue_declare(
|
.queue_declare(
|
||||||
&config.queue,
|
&config.queue,
|
||||||
@@ -88,7 +103,12 @@ impl RabbitMQConsumer {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<(), RabbitMQError>` - Ok or error
|
/// * `Result<(), RabbitMQError>` - Ok or error
|
||||||
async fn bind_queue(channel: &Channel, exchange: &str, queue: &Queue, config: &RabbitMQConfig) -> Result<(), RabbitMQError> {
|
async fn bind_queue(
|
||||||
|
channel: &Channel,
|
||||||
|
exchange: &str,
|
||||||
|
queue: &Queue,
|
||||||
|
config: &RabbitMQConfig,
|
||||||
|
) -> Result<(), RabbitMQError> {
|
||||||
channel
|
channel
|
||||||
.queue_bind(
|
.queue_bind(
|
||||||
queue.name().as_str(),
|
queue.name().as_str(),
|
||||||
@@ -111,7 +131,11 @@ impl RabbitMQConsumer {
|
|||||||
/// `Delivery` - A delivery receipt, required to ack or nack the delivery.
|
/// `Delivery` - A delivery receipt, required to ack or nack the delivery.
|
||||||
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
|
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
|
||||||
// Receive the next message
|
// Receive the next message
|
||||||
let delivery = self.consumer.clone().next().await
|
let delivery = self
|
||||||
|
.consumer
|
||||||
|
.clone()
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
.ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))?
|
.ok_or_else(|| RabbitMQError::ConsumeError("No message received".to_string()))?
|
||||||
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
|
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
|
||||||
|
|
||||||
@@ -131,7 +155,8 @@ impl RabbitMQConsumer {
|
|||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<(), RabbitMQError>` - Ok or error
|
/// * `Result<(), RabbitMQError>` - Ok or error
|
||||||
pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> {
|
pub async fn ack_delivery(&self, delivery: Delivery) -> Result<(), RabbitMQError> {
|
||||||
self.common.channel
|
self.common
|
||||||
|
.channel
|
||||||
.basic_ack(delivery.delivery_tag, BasicAckOptions::default())
|
.basic_ack(delivery.delivery_tag, BasicAckOptions::default())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
|
.map_err(|e| RabbitMQError::ConsumeError(e.to_string()))?;
|
||||||
@@ -139,33 +164,22 @@ impl RabbitMQConsumer {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
/// Function to continually consume messages as they come in
|
/// Function to continually consume messages as they come in
|
||||||
/// WIP
|
pub async fn process_messages(&self) -> Result<(), IngressConsumerError> {
|
||||||
pub async fn process_messages(&self) -> Result<(), RabbitMQError> {
|
|
||||||
loop {
|
loop {
|
||||||
match self.consume().await {
|
match self.consume().await {
|
||||||
Ok((ingress, delivery)) => {
|
Ok((ingress, delivery)) => {
|
||||||
info!("Received IngressObject: {:?}", ingress);
|
info!("Received IngressObject: {:?}", ingress);
|
||||||
let text_content = ingress.to_text_content().await.unwrap();
|
// Get the TextContent
|
||||||
text_content.process().await.unwrap();
|
let text_content = ingress.to_text_content().await?;
|
||||||
|
|
||||||
|
// Initialize ContentProcessor which handles LLM analysis and storage
|
||||||
|
let content_processor = ContentProcessor::new().await?;
|
||||||
|
|
||||||
|
// Begin processing of TextContent
|
||||||
|
content_processor.process(&text_content).await?;
|
||||||
|
|
||||||
|
// Remove from queue
|
||||||
self.ack_delivery(delivery).await?;
|
self.ack_delivery(delivery).await?;
|
||||||
// Process the IngressContent
|
|
||||||
// match self.handle_ingress_content(&ingress).await {
|
|
||||||
// Ok(_) => {
|
|
||||||
// info!("Successfully handled IngressContent");
|
|
||||||
// // Acknowledge the message
|
|
||||||
// if let Err(e) = self.ack_delivery(delivery).await {
|
|
||||||
// error!("Failed to acknowledge message: {:?}", e);
|
|
||||||
// }
|
|
||||||
// },
|
|
||||||
// Err(e) => {
|
|
||||||
// error!("Failed to handle IngressContent: {:?}", e);
|
|
||||||
// // For now, we'll acknowledge to remove it from the queue. Change to nack?
|
|
||||||
// if let Err(ack_err) = self.ack_delivery(delivery).await {
|
|
||||||
// error!("Failed to acknowledge message after handling error: {:?}", ack_err);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
Err(RabbitMQError::ConsumeError(e)) => {
|
Err(RabbitMQError::ConsumeError(e)) => {
|
||||||
error!("Error consuming message: {}", e);
|
error!("Error consuming message: {}", e);
|
||||||
@@ -182,7 +196,10 @@ impl RabbitMQConsumer {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_ingress_content(&self, ingress: &IngressObject) -> Result<(), IngressContentError> {
|
pub async fn handle_ingress_content(
|
||||||
|
&self,
|
||||||
|
ingress: &IngressObject,
|
||||||
|
) -> Result<(), IngressContentError> {
|
||||||
info!("Processing IngressContent: {:?}", ingress);
|
info!("Processing IngressContent: {:?}", ingress);
|
||||||
|
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
pub mod publisher;
|
|
||||||
pub mod consumer;
|
pub mod consumer;
|
||||||
|
pub mod publisher;
|
||||||
|
|
||||||
use axum::async_trait;
|
use axum::async_trait;
|
||||||
use lapin::{
|
use lapin::{
|
||||||
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties, ExchangeKind
|
options::ExchangeDeclareOptions, types::FieldTable, Channel, Connection, ConnectionProperties,
|
||||||
|
ExchangeKind,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
|
use crate::error::ProcessingError;
|
||||||
|
|
||||||
/// Possible errors related to RabbitMQ operations.
|
/// Possible errors related to RabbitMQ operations.
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum RabbitMQError {
|
pub enum RabbitMQError {
|
||||||
@@ -25,6 +28,8 @@ pub enum RabbitMQError {
|
|||||||
InitializeConsumerError(String),
|
InitializeConsumerError(String),
|
||||||
#[error("Queue error: {0}")]
|
#[error("Queue error: {0}")]
|
||||||
QueueError(String),
|
QueueError(String),
|
||||||
|
#[error("Processing error: {0}")]
|
||||||
|
ProcessingError(#[from] ProcessingError),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Struct containing the information required to set up a client and connection.
|
/// Struct containing the information required to set up a client and connection.
|
||||||
@@ -42,18 +47,21 @@ pub struct RabbitMQCommon {
|
|||||||
pub channel: Channel,
|
pub channel: Channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Defines the behavior for RabbitMQCommon client operations.
|
/// Defines the behavior for RabbitMQCommon client operations.
|
||||||
#[cfg_attr(test, mockall::automock)]
|
#[cfg_attr(test, mockall::automock)]
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait RabbitMQCommonTrait: Send + Sync {
|
pub trait RabbitMQCommonTrait: Send + Sync {
|
||||||
async fn create_connection(config: &RabbitMQConfig) -> Result<Connection, RabbitMQError>;
|
async fn create_connection(config: &RabbitMQConfig) -> Result<Connection, RabbitMQError>;
|
||||||
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError>;
|
async fn declare_exchange(
|
||||||
}
|
&self,
|
||||||
|
config: &RabbitMQConfig,
|
||||||
|
passive: bool,
|
||||||
|
) -> Result<(), RabbitMQError>;
|
||||||
|
}
|
||||||
|
|
||||||
impl RabbitMQCommon {
|
impl RabbitMQCommon {
|
||||||
/// Sets up a new RabbitMQ client or error
|
/// Sets up a new RabbitMQ client or error
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `RabbitMQConfig` - Configuration object with required information
|
/// * `RabbitMQConfig` - Configuration object with required information
|
||||||
///
|
///
|
||||||
@@ -62,7 +70,10 @@ impl RabbitMQCommon {
|
|||||||
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
||||||
let connection = Self::create_connection(config).await?;
|
let connection = Self::create_connection(config).await?;
|
||||||
let channel = connection.create_channel().await?;
|
let channel = connection.create_channel().await?;
|
||||||
Ok(Self { connection, channel })
|
Ok(Self {
|
||||||
|
connection,
|
||||||
|
channel,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +88,11 @@ impl RabbitMQCommonTrait for RabbitMQCommon {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Function to declare the exchange required
|
/// Function to declare the exchange required
|
||||||
async fn declare_exchange(&self, config: &RabbitMQConfig, passive: bool) -> Result<(), RabbitMQError> {
|
async fn declare_exchange(
|
||||||
|
&self,
|
||||||
|
config: &RabbitMQConfig,
|
||||||
|
passive: bool,
|
||||||
|
) -> Result<(), RabbitMQError> {
|
||||||
debug!("Declaring exchange");
|
debug!("Declaring exchange");
|
||||||
self.channel
|
self.channel
|
||||||
.exchange_declare(
|
.exchange_declare(
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
use lapin::{
|
use lapin::{options::*, publisher_confirm::Confirmation, BasicProperties};
|
||||||
options::*, publisher_confirm::Confirmation, BasicProperties,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::models::ingress_object::IngressObject;
|
use crate::ingress::types::ingress_object::IngressObject;
|
||||||
|
|
||||||
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
|
use super::{RabbitMQCommon, RabbitMQCommonTrait, RabbitMQConfig, RabbitMQError};
|
||||||
use tracing::{info, error};
|
use tracing::{error, info};
|
||||||
|
|
||||||
/// Struct to publish messages to RabbitMQ.
|
/// Struct to publish messages to RabbitMQ.
|
||||||
pub struct RabbitMQProducer {
|
pub struct RabbitMQProducer {
|
||||||
@@ -27,7 +25,7 @@ impl RabbitMQProducer {
|
|||||||
let common = RabbitMQCommon::new(config).await?;
|
let common = RabbitMQCommon::new(config).await?;
|
||||||
common.declare_exchange(config, false).await?;
|
common.declare_exchange(config, false).await?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
common,
|
common,
|
||||||
exchange_name: config.exchange.clone(),
|
exchange_name: config.exchange.clone(),
|
||||||
routing_key: config.routing_key.clone(),
|
routing_key: config.routing_key.clone(),
|
||||||
@@ -39,19 +37,23 @@ impl RabbitMQProducer {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `self` - Reference to self
|
/// * `self` - Reference to self
|
||||||
/// * `ingress_object` - An initialized IngressObject
|
/// * `ingress_object` - An initialized IngressObject
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error
|
/// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error
|
||||||
pub async fn publish(&self, ingress_object: &IngressObject) -> Result<Confirmation, RabbitMQError> {
|
pub async fn publish(
|
||||||
|
&self,
|
||||||
|
ingress_object: &IngressObject,
|
||||||
|
) -> Result<Confirmation, RabbitMQError> {
|
||||||
// Serialize IngressObject to JSON
|
// Serialize IngressObject to JSON
|
||||||
let payload = serde_json::to_vec(ingress_object)
|
let payload = serde_json::to_vec(ingress_object).map_err(|e| {
|
||||||
.map_err(|e| {
|
error!("Serialization Error: {}", e);
|
||||||
error!("Serialization Error: {}", e);
|
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
|
||||||
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
|
})?;
|
||||||
})?;
|
|
||||||
|
|
||||||
// Publish the serialized payload to RabbitMQ
|
// Publish the serialized payload to RabbitMQ
|
||||||
let confirmation = self.common.channel
|
let confirmation = self
|
||||||
|
.common
|
||||||
|
.channel
|
||||||
.basic_publish(
|
.basic_publish(
|
||||||
&self.exchange_name,
|
&self.exchange_name,
|
||||||
&self.routing_key,
|
&self.routing_key,
|
||||||
@@ -69,9 +71,12 @@ impl RabbitMQProducer {
|
|||||||
error!("Publish Confirmation Error: {}", e);
|
error!("Publish Confirmation Error: {}", e);
|
||||||
RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e))
|
RabbitMQError::PublishError(format!("Publish Confirmation Error: {}", e))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
info!("Published IngressObject to exchange '{}' with routing key '{}'", self.exchange_name, self.routing_key);
|
info!(
|
||||||
|
"Published IngressObject to exchange '{}' with routing key '{}'",
|
||||||
|
self.exchange_name, self.routing_key
|
||||||
|
);
|
||||||
|
|
||||||
Ok(confirmation)
|
Ok(confirmation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
use std::sync::Arc;
|
use crate::storage::{
|
||||||
use axum::{
|
db::SurrealDbClient,
|
||||||
extract::Path, response::IntoResponse, Extension, Json
|
types::file_info::{FileError, FileInfo},
|
||||||
};
|
};
|
||||||
use axum_typed_multipart::{TypedMultipart, FieldData, TryFromMultipart};
|
use axum::{extract::Path, response::IntoResponse, Extension, Json};
|
||||||
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use crate::{
|
|
||||||
models::file_info::{FileError, FileInfo},
|
|
||||||
surrealdb::SurrealDbClient,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Debug, TryFromMultipart)]
|
#[derive(Debug, TryFromMultipart)]
|
||||||
pub struct FileUploadRequest {
|
pub struct FileUploadRequest {
|
||||||
@@ -19,7 +17,7 @@ pub struct FileUploadRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handler to upload a new file.
|
/// Handler to upload a new file.
|
||||||
///
|
///
|
||||||
/// Route: POST /file
|
/// Route: POST /file
|
||||||
pub async fn upload_handler(
|
pub async fn upload_handler(
|
||||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||||
@@ -40,13 +38,12 @@ pub async fn upload_handler(
|
|||||||
|
|
||||||
info!("File uploaded successfully: {:?}", file_info);
|
info!("File uploaded successfully: {:?}", file_info);
|
||||||
|
|
||||||
|
|
||||||
// Return the response with HTTP 200
|
// Return the response with HTTP 200
|
||||||
Ok((axum::http::StatusCode::OK, Json(response)))
|
Ok((axum::http::StatusCode::OK, Json(response)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handler to retrieve file information by UUID.
|
/// Handler to retrieve file information by UUID.
|
||||||
///
|
///
|
||||||
/// Route: GET /file/:uuid
|
/// Route: GET /file/:uuid
|
||||||
pub async fn get_file_handler(
|
pub async fn get_file_handler(
|
||||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||||
@@ -73,7 +70,7 @@ pub async fn get_file_handler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handler to update an existing file by UUID.
|
/// Handler to update an existing file by UUID.
|
||||||
///
|
///
|
||||||
/// Route: PUT /file/:uuid
|
/// Route: PUT /file/:uuid
|
||||||
pub async fn update_file_handler(
|
pub async fn update_file_handler(
|
||||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||||
@@ -101,7 +98,7 @@ pub async fn update_file_handler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handler to delete a file by UUID.
|
/// Handler to delete a file by UUID.
|
||||||
///
|
///
|
||||||
/// Route: DELETE /file/:uuid
|
/// Route: DELETE /file/:uuid
|
||||||
pub async fn delete_file_handler(
|
pub async fn delete_file_handler(
|
||||||
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
Extension(db_client): Extension<Arc<SurrealDbClient>>,
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
use std::sync::Arc;
|
use crate::{
|
||||||
|
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
||||||
|
rabbitmq::publisher::RabbitMQProducer,
|
||||||
|
storage::db::SurrealDbClient,
|
||||||
|
};
|
||||||
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
|
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
|
||||||
|
use std::sync::Arc;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use crate::{models::ingress_content::{create_ingress_objects, IngressInput}, rabbitmq::publisher::RabbitMQProducer, surrealdb::SurrealDbClient};
|
|
||||||
|
|
||||||
pub async fn ingress_handler(
|
pub async fn ingress_handler(
|
||||||
Extension(producer): Extension<Arc<RabbitMQProducer>>,
|
Extension(producer): Extension<Arc<RabbitMQProducer>>,
|
||||||
@@ -23,7 +27,7 @@ pub async fn ingress_handler(
|
|||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"Failed to publish message",
|
"Failed to publish message",
|
||||||
)
|
)
|
||||||
.into_response();
|
.into_response();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -31,11 +35,7 @@ pub async fn ingress_handler(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to process input: {:?}", e);
|
error!("Failed to process input: {:?}", e);
|
||||||
(
|
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to process input").into_response()
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Failed to process input",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use zettle_db::{
|
|||||||
ingress::ingress_handler,
|
ingress::ingress_handler,
|
||||||
queue_length::queue_length_handler,
|
queue_length::queue_length_handler,
|
||||||
},
|
},
|
||||||
surrealdb::SurrealDbClient,
|
storage::db::SurrealDbClient,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
|||||||
@@ -1,8 +1,67 @@
|
|||||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
|
||||||
|
|
||||||
use crate::error::ProcessingError;
|
use crate::error::ProcessingError;
|
||||||
|
|
||||||
use super::types::StoredObject;
|
use super::types::StoredObject;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use surrealdb::{
|
||||||
|
engine::remote::ws::{Client, Ws},
|
||||||
|
opt::auth::Root,
|
||||||
|
Error, Surreal,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SurrealDbClient {
|
||||||
|
pub client: Surreal<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SurrealDbClient {
|
||||||
|
/// # Initialize a new datbase client
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * `SurrealDbClient` initialized
|
||||||
|
pub async fn new() -> Result<Self, Error> {
|
||||||
|
let db = Surreal::new::<Ws>("127.0.0.1:8000").await?;
|
||||||
|
|
||||||
|
// Sign in to database
|
||||||
|
db.signin(Root {
|
||||||
|
username: "root_user",
|
||||||
|
password: "root_password",
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Set namespace
|
||||||
|
db.use_ns("test").use_db("test").await?;
|
||||||
|
|
||||||
|
Ok(SurrealDbClient { client: db })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||||
|
self.client
|
||||||
|
.query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
|
||||||
|
.await?;
|
||||||
|
self.client
|
||||||
|
.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn drop_table<T>(&self) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: StoredObject + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let _deleted: Vec<T> = self.client.delete(T::table_name()).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for SurrealDbClient {
|
||||||
|
type Target = Surreal<Client>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Operation to store an object in SurrealDB, requires the struct to implement StoredObject
|
/// Operation to store an object in SurrealDB, requires the struct to implement StoredObject
|
||||||
///
|
///
|
||||||
|
|||||||
@@ -1,22 +1,26 @@
|
|||||||
use axum::{http::StatusCode, response::{IntoResponse, Response}, Json};
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
use axum_typed_multipart::FieldData;
|
use axum_typed_multipart::FieldData;
|
||||||
use mime_guess::from_path;
|
use mime_guess::from_path;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use surrealdb::RecordId;
|
|
||||||
use std::{
|
use std::{
|
||||||
io::{BufReader, Read},
|
io::{BufReader, Read},
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
|
use surrealdb::RecordId;
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::surrealdb::SurrealDbClient;
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
|
||||||
#[derive(Debug,Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct Record {
|
struct Record {
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
id: RecordId,
|
id: RecordId,
|
||||||
@@ -72,7 +76,6 @@ pub enum FileError {
|
|||||||
|
|
||||||
#[error("Deserialization error: {0}")]
|
#[error("Deserialization error: {0}")]
|
||||||
DeserializationError(String),
|
DeserializationError(String),
|
||||||
|
|
||||||
// Add more error variants as needed.
|
// Add more error variants as needed.
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,16 +85,30 @@ impl IntoResponse for FileError {
|
|||||||
FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
|
FileError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
|
||||||
FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"),
|
FileError::Utf8(_) => (StatusCode::BAD_REQUEST, "Invalid UTF-8 data"),
|
||||||
FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"),
|
FileError::MimeDetection(_) => (StatusCode::BAD_REQUEST, "MIME type detection failed"),
|
||||||
FileError::UnsupportedMime(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type"),
|
FileError::UnsupportedMime(_) => {
|
||||||
|
(StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported MIME type")
|
||||||
|
}
|
||||||
FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"),
|
FileError::FileNotFound(_) => (StatusCode::NOT_FOUND, "File not found"),
|
||||||
FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"),
|
FileError::DuplicateFile(_) => (StatusCode::CONFLICT, "Duplicate file detected"),
|
||||||
FileError::HashCollision => (StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected"),
|
FileError::HashCollision => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "Hash collision detected")
|
||||||
|
}
|
||||||
FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"),
|
FileError::InvalidUuid(_) => (StatusCode::BAD_REQUEST, "Invalid UUID format"),
|
||||||
FileError::MissingFileName => (StatusCode::BAD_REQUEST, "Missing file name in metadata"),
|
FileError::MissingFileName => {
|
||||||
FileError::PersistError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file"),
|
(StatusCode::BAD_REQUEST, "Missing file name in metadata")
|
||||||
FileError::SerializationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"),
|
}
|
||||||
FileError::DeserializationError(_) => (StatusCode::BAD_REQUEST, "Deserialization error"),
|
FileError::PersistError(_) => {
|
||||||
FileError::SurrealError(_) =>(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error"),
|
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to persist file")
|
||||||
|
}
|
||||||
|
FileError::SerializationError(_) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
|
||||||
|
}
|
||||||
|
FileError::DeserializationError(_) => {
|
||||||
|
(StatusCode::BAD_REQUEST, "Deserialization error")
|
||||||
|
}
|
||||||
|
FileError::SurrealError(_) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let body = Json(json!({
|
let body = Json(json!({
|
||||||
@@ -163,7 +180,11 @@ impl FileInfo {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<FileInfo, FileError>` - The updated `FileInfo` or an error.
|
/// * `Result<FileInfo, FileError>` - The updated `FileInfo` or an error.
|
||||||
pub async fn update(uuid: Uuid, new_field_data: FieldData<NamedTempFile>, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
pub async fn update(
|
||||||
|
uuid: Uuid,
|
||||||
|
new_field_data: FieldData<NamedTempFile>,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> Result<FileInfo, FileError> {
|
||||||
let new_file = new_field_data.contents;
|
let new_file = new_field_data.contents;
|
||||||
let new_metadata = new_field_data.metadata;
|
let new_metadata = new_field_data.metadata;
|
||||||
|
|
||||||
@@ -184,7 +205,8 @@ impl FileInfo {
|
|||||||
let sanitized_new_file_name = sanitize_file_name(&new_file_name);
|
let sanitized_new_file_name = sanitize_file_name(&new_file_name);
|
||||||
|
|
||||||
// Persist the new file
|
// Persist the new file
|
||||||
let new_persisted_path = Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
|
let new_persisted_path =
|
||||||
|
Self::persist_file(&uuid, new_file, &sanitized_new_file_name).await?;
|
||||||
|
|
||||||
// Guess the new MIME type
|
// Guess the new MIME type
|
||||||
let new_mime_type = Self::guess_mime_type(&new_persisted_path);
|
let new_mime_type = Self::guess_mime_type(&new_persisted_path);
|
||||||
@@ -202,7 +224,7 @@ impl FileInfo {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Save the new item
|
// Save the new item
|
||||||
Self::create_record(&updated_file_info,db_client).await?;
|
Self::create_record(&updated_file_info, db_client).await?;
|
||||||
|
|
||||||
// Optionally, delete the old file from the filesystem if it's no longer referenced
|
// Optionally, delete the old file from the filesystem if it's no longer referenced
|
||||||
// This requires reference counting or checking if other FileInfo entries point to the same SHA
|
// This requires reference counting or checking if other FileInfo entries point to the same SHA
|
||||||
@@ -221,26 +243,35 @@ impl FileInfo {
|
|||||||
/// * `Result<(), FileError>` - Empty result or an error.
|
/// * `Result<(), FileError>` - Empty result or an error.
|
||||||
pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> {
|
pub async fn delete(uuid: Uuid, db_client: &SurrealDbClient) -> Result<(), FileError> {
|
||||||
// Retrieve FileInfo to get SHA256 and path
|
// Retrieve FileInfo to get SHA256 and path
|
||||||
let file_info = Self::get_by_uuid(uuid, db_client).await?;
|
let file_info = Self::get_by_uuid(uuid, db_client).await?;
|
||||||
|
|
||||||
// Delete the file from the filesystem
|
// Delete the file from the filesystem
|
||||||
let file_path = Path::new(&file_info.path);
|
let file_path = Path::new(&file_info.path);
|
||||||
if file_path.exists() {
|
if file_path.exists() {
|
||||||
tokio::fs::remove_file(file_path).await.map_err(FileError::Io)?;
|
tokio::fs::remove_file(file_path)
|
||||||
|
.await
|
||||||
|
.map_err(FileError::Io)?;
|
||||||
info!("Deleted file at path: {}", file_info.path);
|
info!("Deleted file at path: {}", file_info.path);
|
||||||
} else {
|
} else {
|
||||||
info!("File path does not exist, skipping deletion: {}", file_info.path);
|
info!(
|
||||||
|
"File path does not exist, skipping deletion: {}",
|
||||||
|
file_info.path
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the FileInfo from database
|
// Delete the FileInfo from database
|
||||||
Self::delete_record(&file_info.sha256, db_client).await?;
|
Self::delete_record(&file_info.sha256, db_client).await?;
|
||||||
|
|
||||||
// Remove the UUID directory if empty
|
// Remove the UUID directory if empty
|
||||||
let uuid_dir = file_path.parent().ok_or(FileError::FileNotFound(uuid.to_string()))?;
|
let uuid_dir = file_path
|
||||||
|
.parent()
|
||||||
|
.ok_or(FileError::FileNotFound(uuid.to_string()))?;
|
||||||
if uuid_dir.exists() {
|
if uuid_dir.exists() {
|
||||||
let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?;
|
let mut entries = tokio::fs::read_dir(uuid_dir).await.map_err(FileError::Io)?;
|
||||||
if entries.next_entry().await?.is_none() {
|
if entries.next_entry().await?.is_none() {
|
||||||
tokio::fs::remove_dir(uuid_dir).await.map_err(FileError::Io)?;
|
tokio::fs::remove_dir(uuid_dir)
|
||||||
|
.await
|
||||||
|
.map_err(FileError::Io)?;
|
||||||
info!("Deleted empty UUID directory: {:?}", uuid_dir);
|
info!("Deleted empty UUID directory: {:?}", uuid_dir);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -257,19 +288,26 @@ impl FileInfo {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
|
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
|
||||||
async fn persist_file(uuid: &Uuid, file: NamedTempFile, file_name: &str) -> Result<PathBuf, FileError> {
|
async fn persist_file(
|
||||||
|
uuid: &Uuid,
|
||||||
|
file: NamedTempFile,
|
||||||
|
file_name: &str,
|
||||||
|
) -> Result<PathBuf, FileError> {
|
||||||
let base_dir = Path::new("./data");
|
let base_dir = Path::new("./data");
|
||||||
let uuid_dir = base_dir.join(uuid.to_string());
|
let uuid_dir = base_dir.join(uuid.to_string());
|
||||||
|
|
||||||
// Create the UUID directory if it doesn't exist
|
// Create the UUID directory if it doesn't exist
|
||||||
tokio::fs::create_dir_all(&uuid_dir).await.map_err(FileError::Io)?;
|
tokio::fs::create_dir_all(&uuid_dir)
|
||||||
|
.await
|
||||||
|
.map_err(FileError::Io)?;
|
||||||
|
|
||||||
// Define the final file path
|
// Define the final file path
|
||||||
let final_path = uuid_dir.join(file_name);
|
let final_path = uuid_dir.join(file_name);
|
||||||
info!("Final path: {:?}", final_path);
|
info!("Final path: {:?}", final_path);
|
||||||
|
|
||||||
// Persist the temporary file to the final path
|
// Persist the temporary file to the final path
|
||||||
file.persist(&final_path).map_err(|e| FileError::PersistError(e.to_string()))?;
|
file.persist(&final_path)
|
||||||
|
.map_err(|e| FileError::PersistError(e.to_string()))?;
|
||||||
|
|
||||||
info!("Persisted file to {:?}", final_path);
|
info!("Persisted file to {:?}", final_path);
|
||||||
|
|
||||||
@@ -313,7 +351,6 @@ impl FileInfo {
|
|||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Creates a new record in SurrealDB for the given `FileInfo`.
|
/// Creates a new record in SurrealDB for the given `FileInfo`.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -323,16 +360,19 @@ impl FileInfo {
|
|||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<(), FileError>` - Empty result or an error.
|
/// * `Result<(), FileError>` - Empty result or an error.
|
||||||
|
|
||||||
async fn create_record(file_info: &FileInfo, db_client: &SurrealDbClient) -> Result<(), FileError> {
|
async fn create_record(
|
||||||
|
file_info: &FileInfo,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> Result<(), FileError> {
|
||||||
// Create the record
|
// Create the record
|
||||||
let _created: Option<Record> = db_client
|
let _created: Option<Record> = db_client
|
||||||
.client
|
.client
|
||||||
.create(("file", &file_info.uuid ))
|
.create(("file", &file_info.uuid))
|
||||||
.content(file_info.clone())
|
.content(file_info.clone())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
debug!("{:?}",_created);
|
debug!("{:?}", _created);
|
||||||
|
|
||||||
info!("Created FileInfo record with SHA256: {}", file_info.sha256);
|
info!("Created FileInfo record with SHA256: {}", file_info.sha256);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -346,11 +386,17 @@ impl FileInfo {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<FileInfo, FileError>` - The `FileInfo` or `Error` if not found.
|
/// * `Result<FileInfo, FileError>` - The `FileInfo` or `Error` if not found.
|
||||||
pub async fn get_by_uuid(uuid: Uuid, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
pub async fn get_by_uuid(
|
||||||
|
uuid: Uuid,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> Result<FileInfo, FileError> {
|
||||||
let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid);
|
let query = format!("SELECT * FROM file WHERE uuid = '{}'", uuid);
|
||||||
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
|
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
|
||||||
|
|
||||||
response.into_iter().next().ok_or(FileError::FileNotFound(uuid.to_string()))
|
response
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or(FileError::FileNotFound(uuid.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieves a `FileInfo` by SHA256.
|
/// Retrieves a `FileInfo` by SHA256.
|
||||||
@@ -367,7 +413,11 @@ impl FileInfo {
|
|||||||
|
|
||||||
debug!("{:?}", response);
|
debug!("{:?}", response);
|
||||||
|
|
||||||
response.into_iter().next().ok_or(FileError::FileNotFound(sha256.to_string())) }
|
response
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or(FileError::FileNotFound(sha256.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Deletes a `FileInfo` record by SHA256.
|
/// Deletes a `FileInfo` record by SHA256.
|
||||||
///
|
///
|
||||||
@@ -381,10 +431,7 @@ impl FileInfo {
|
|||||||
let table = "file";
|
let table = "file";
|
||||||
let primary_key = sha256;
|
let primary_key = sha256;
|
||||||
|
|
||||||
let _created: Option<Record> = db_client
|
let _created: Option<Record> = db_client.client.delete((table, primary_key)).await?;
|
||||||
.client
|
|
||||||
.delete((table, primary_key))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
info!("Deleted FileInfo record with SHA256: {}", sha256);
|
info!("Deleted FileInfo record with SHA256: {}", sha256);
|
||||||
|
|
||||||
@@ -395,7 +442,14 @@ impl FileInfo {
|
|||||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
///
/// Every character that is not an ASCII letter, an ASCII digit, `.` or `_` is replaced
/// with `_`, so path separators never survive. A name made up solely of dots (such as
/// `.` or `..`) contains only allowed characters but refers to a directory once it is
/// joined onto a path, so in that case the dots are replaced with `_` as well.
///
/// # Arguments
/// * `file_name` - The raw file name to sanitize.
///
/// # Returns
/// * `String` - The sanitized name, with exactly one output character per input
///   character. An empty input yields an empty string.
fn sanitize_file_name(file_name: &str) -> String {
    // "." and ".." would otherwise pass through unchanged and name the current or
    // parent directory instead of a file.
    let only_dots = file_name.chars().all(|c| c == '.');

    file_name
        .chars()
        .map(|c| {
            let allowed = c.is_ascii_alphanumeric() || c == '_' || (c == '.' && !only_dots);
            if allowed {
                c
            } else {
                '_'
            }
        })
        .collect()
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
use axum::async_trait;
|
use axum::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
pub mod file_info;
|
||||||
pub mod knowledge_entity;
|
pub mod knowledge_entity;
|
||||||
pub mod knowledge_relationship;
|
pub mod knowledge_relationship;
|
||||||
pub mod text_chunk;
|
pub mod text_chunk;
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::models::file_info::FileInfo;
|
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
|
||||||
|
use super::file_info::FileInfo;
|
||||||
|
|
||||||
stored_object!(TextContent, "text_content", {
|
stored_object!(TextContent, "text_content", {
|
||||||
text: String,
|
text: String,
|
||||||
file_info: Option<FileInfo>,
|
file_info: Option<FileInfo>,
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
use std::ops::Deref;
|
|
||||||
use surrealdb::{
|
|
||||||
engine::remote::ws::{Client, Ws},
|
|
||||||
opt::auth::Root,
|
|
||||||
Error, Surreal,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::storage::types::StoredObject;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct SurrealDbClient {
|
|
||||||
pub client: Surreal<Client>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SurrealDbClient {
|
|
||||||
/// # Initialize a new datbase client
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// * `SurrealDbClient` initialized
|
|
||||||
pub async fn new() -> Result<Self, Error> {
|
|
||||||
let db = Surreal::new::<Ws>("127.0.0.1:8000").await?;
|
|
||||||
|
|
||||||
// Sign in to database
|
|
||||||
db.signin(Root {
|
|
||||||
username: "root_user",
|
|
||||||
password: "root_password",
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Set namespace
|
|
||||||
db.use_ns("test").use_db("test").await?;
|
|
||||||
|
|
||||||
Ok(SurrealDbClient { client: db })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
|
||||||
self.client
|
|
||||||
.query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
|
|
||||||
.await?;
|
|
||||||
self.client
|
|
||||||
.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn drop_table<T>(&self) -> Result<(), Error>
|
|
||||||
where
|
|
||||||
T: StoredObject + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
let _deleted: Vec<T> = self.client.delete(T::table_name()).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Deref for SurrealDbClient {
|
|
||||||
type Target = Surreal<Client>;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user