From 41134cfa49b844e8d41186fb9caf53b473fb4009 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Wed, 20 Nov 2024 22:44:30 +0100 Subject: [PATCH] refactoring: working macro and generics --- src/error.rs | 21 ++++++ src/lib.rs | 1 + src/models/ingress_object.rs | 41 +++++++---- src/models/text_content.rs | 110 +++++++++--------------------- src/storage/db.rs | 27 ++++++++ src/storage/mod.rs | 1 + src/storage/types/mod.rs | 21 +++--- src/storage/types/text_chunk.rs | 19 ++++++ src/storage/types/text_content.rs | 14 +--- src/surrealdb/mod.rs | 12 +--- src/utils/llm.rs | 98 +++++++++++++------------- 11 files changed, 198 insertions(+), 167 deletions(-) create mode 100644 src/error.rs create mode 100644 src/storage/db.rs create mode 100644 src/storage/types/text_chunk.rs diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..7545043 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,21 @@ +use async_openai::error::OpenAIError; +use thiserror::Error; + +/// Error types for processing `TextContent`. +#[derive(Error, Debug)] +pub enum ProcessingError { + #[error("SurrealDb error: {0}")] + SurrealDbError(#[from] surrealdb::Error), + + #[error("LLM processing error: {0}")] + OpenAIerror(#[from] OpenAIError), + + #[error("Embedding processing error: {0}")] + EmbeddingError(String), + + #[error("Graph processing error: {0}")] + GraphProcessingError(String), + + #[error("LLM parsing error: {0}")] + LLMParsingError(String), +} diff --git a/src/lib.rs b/src/lib.rs index d906151..eceb2a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod error; pub mod models; pub mod rabbitmq; pub mod routes; diff --git a/src/models/ingress_object.rs b/src/models/ingress_object.rs index 2689f5a..b37e5ad 100644 --- a/src/models/ingress_object.rs +++ b/src/models/ingress_object.rs @@ -1,9 +1,9 @@ +use super::ingress_content::IngressContentError; use crate::models::file_info::FileInfo; +use crate::storage::types::text_content::TextContent; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{ingress_content::IngressContentError, text_content::TextContent}; - /// Knowledge object type, containing the content or reference to it, as well as metadata #[derive(Debug, Serialize, Deserialize, Clone)] pub enum IngressObject { @@ -34,7 +34,11 @@ impl IngressObject { /// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc. pub async fn to_text_content(&self) -> Result { match self { - IngressObject::Url { url, instructions, category } => { + IngressObject::Url { + url, + instructions, + category, + } => { let text = Self::fetch_text_from_url(url).await?; let id = Uuid::new_v4(); Ok(TextContent { @@ -44,8 +48,12 @@ impl IngressObject { category: category.clone(), file_info: None, }) - }, - IngressObject::Text { text, instructions, category } => { + } + IngressObject::Text { + text, + instructions, + category, + } => { let id = Uuid::new_v4(); Ok(TextContent { id: id.to_string(), @@ -54,8 +62,12 @@ impl IngressObject { category: category.clone(), file_info: None, }) - }, - IngressObject::File { file_info, instructions, category } => { + } + IngressObject::File { + file_info, + instructions, + category, + } => { let id = Uuid::new_v4(); let text = Self::extract_text_from_file(file_info).await?; Ok(TextContent { @@ -65,7 +77,7 @@ impl IngressObject { category: category.clone(), file_info: Some(file_info.clone()), }) - }, + } } } @@ -89,11 +101,15 @@ impl IngressObject { } "application/pdf" => { // TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf` - Err(IngressContentError::UnsupportedMime(file_info.mime_type.clone())) + Err(IngressContentError::UnsupportedMime( + file_info.mime_type.clone(), + )) } "image/png" | "image/jpeg" => { // TODO: Implement OCR on image using a crate like `tesseract` - Err(IngressContentError::UnsupportedMime(file_info.mime_type.clone())) + Err(IngressContentError::UnsupportedMime( + file_info.mime_type.clone(), + )) } "application/octet-stream" => { let content = tokio::fs::read_to_string(&file_info.path).await?; @@ -104,8 +120,9 @@ impl IngressObject { Ok(content) } // Handle other MIME types as needed - _ => Err(IngressContentError::UnsupportedMime(file_info.mime_type.clone())), + _ => Err(IngressContentError::UnsupportedMime( + file_info.mime_type.clone(), + )), } } } - diff --git a/src/models/text_content.rs b/src/models/text_content.rs index eb86833..5a42a1a 100644 --- a/src/models/text_content.rs +++ b/src/models/text_content.rs @@ -1,62 +1,38 @@ +use crate::storage; +use crate::storage::db::store_item; +use crate::storage::types::text_chunk::TextChunk; +use crate::storage::types::text_content::TextContent; use crate::{ - models::file_info::FileInfo, - surrealdb::{SurrealDbClient, SurrealError}, + error::ProcessingError, + surrealdb::SurrealDbClient, utils::llm::{create_json_ld, generate_embedding}, }; -use async_openai::error::OpenAIError; -use serde::{Deserialize, Serialize}; -use surrealdb::{engine::remote::ws::Client, sql::Thing, Surreal}; +use surrealdb::{engine::remote::ws::Client, Surreal}; use text_splitter::TextSplitter; -use thiserror::Error; use tracing::{debug, info}; use uuid::Uuid; -use super::graph_entities::{thing_to_string, KnowledgeEntity, KnowledgeRelationship}; +use super::graph_entities::{KnowledgeEntity, KnowledgeRelationship}; -#[derive(Serialize, Deserialize, Debug)] -struct TextChunk { - #[serde(deserialize_with = "thing_to_string")] - id: String, - source_id: String, - chunk: String, - embedding: Vec, -} +// #[derive(Serialize, Deserialize, Debug)] +// struct TextChunk { +// #[serde(deserialize_with = "thing_to_string")] +// id: String, +// source_id: String, +// chunk: String, +// embedding: Vec, +// } /// Represents a single piece of text content extracted from various sources. -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct TextContent { - #[serde(deserialize_with = "thing_to_string")] - pub id: String, - pub text: String, - pub file_info: Option, - pub instructions: String, - pub category: String, -} - -/// Error types for processing `TextContent`. -#[derive(Error, Debug)] -pub enum ProcessingError { - #[error("LLM processing error: {0}")] - LLMError(String), - - #[error("SurrealDB error: {0}")] - SurrealError(#[from] SurrealError), - - #[error("SurrealDb error: {0}")] - SurrealDbError(#[from] surrealdb::Error), - - #[error("Graph DB storage error: {0}")] - GraphDBError(String), - - #[error("Vector DB storage error: {0}")] - VectorDBError(String), - - #[error("Unknown processing error")] - Unknown, - - #[error("LLM processing error: {0}")] - OpenAIerror(#[from] OpenAIError), -} +// #[derive(Debug, Serialize, Deserialize, Clone)] +// pub struct TextContent { +// #[serde(deserialize_with = "thing_to_string")] +// pub id: String, +// pub text: String, +// pub file_info: Option, +// pub instructions: String, +// pub category: String, +// } async fn vector_comparison( take: u8, @@ -66,9 +42,9 @@ async fn vector_comparison( openai_client: &async_openai::Client, ) -> Result, ProcessingError> where - T: for<'de> serde::Deserialize<'de>, // Add this trait bound for deserialization + T: for<'de> serde::Deserialize<'de>, { - let input_embedding = generate_embedding(&openai_client, input_text).await?; + let input_embedding = generate_embedding(openai_client, input_text).await?; // Construct the query let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance",table, take, input_embedding); @@ -98,7 +74,9 @@ impl TextContent { let db_client = SurrealDbClient::new().await?; let openai_client = async_openai::Client::new(); - self.store_text_content(&db_client).await?; + let create_operation = storage::db::store_item(&db_client, self.clone()).await?; + info!("{:?}", create_operation); + // self.store_text_content(&db_client).await?; let closest_text_content: Vec = vector_comparison( 3, @@ -116,7 +94,7 @@ impl TextContent { } } - panic!("STOPPING"); + // panic!("STOPPING"); // let deleted: Vec = db_client.delete("text_chunk").await?; // info! {"{:?} KnowledgeEntities deleted", deleted.len()}; @@ -230,35 +208,13 @@ impl TextContent { for chunk in chunks { info!("Chunk: {}", chunk); let embedding = generate_embedding(&openai_client, chunk.to_string()).await?; - let text_chunk = TextChunk { - id: Uuid::new_v4().to_string(), - source_id: self.id.clone(), - chunk: chunk.to_string(), - embedding, - }; + let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding); info!("{:?}", text_chunk); - let _created: Option = db_client - .create(("text_chunk", text_chunk.id.clone())) - .content(text_chunk) - .await?; - - debug!("{:?}", _created); + store_item(db_client, text_chunk).await?; } Ok(()) } - - /// Stores text content in database - async fn store_text_content(&self, db_client: &Surreal) -> Result<(), ProcessingError> { - let _created: Option = db_client - .create(("text_content", self.id.clone())) - .content(self.clone()) - .await?; - - debug!("{:?}", _created); - - Ok(()) - } } diff --git a/src/storage/db.rs b/src/storage/db.rs new file mode 100644 index 0000000..90602c5 --- /dev/null +++ b/src/storage/db.rs @@ -0,0 +1,27 @@ +use surrealdb::{engine::remote::ws::Client, Surreal}; + +use crate::error::ProcessingError; + +use super::types::StoredObject; + +/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject +/// +/// # Arguments +/// * `db_client` - A initialized database client +/// * `item` - The item to be stored +/// +/// # Returns +/// * `Result` - Item or Error +pub async fn store_item( + db_client: &Surreal, + item: T, +) -> Result, ProcessingError> +where + T: StoredObject + Send + Sync + 'static, +{ + db_client + .create((T::table_name(), item.get_id())) + .content(item) + .await + .map_err(ProcessingError::from) +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index cd40856..48a6962 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1 +1,2 @@ +pub mod db; pub mod types; diff --git a/src/storage/types/mod.rs b/src/storage/types/mod.rs index d4db975..c1d619d 100644 --- a/src/storage/types/mod.rs +++ b/src/storage/types/mod.rs @@ -1,11 +1,21 @@ +use axum::async_trait; +use serde::{Deserialize, Serialize}; +pub mod text_chunk; pub mod text_content; +#[async_trait] +pub trait StoredObject: Serialize + for<'de> Deserialize<'de> { + fn table_name() -> &'static str; + fn get_id(&self) -> &str; +} + #[macro_export] -macro_rules! stored_entity { +macro_rules! stored_object { ($name:ident, $table:expr, {$($field:ident: $ty:ty),*}) => { use axum::async_trait; use serde::{Deserialize, Deserializer, Serialize}; use surrealdb::sql::Thing; + use $crate::storage::types::StoredObject; fn thing_to_string<'de, D>(deserializer: D) -> Result where @@ -15,13 +25,8 @@ macro_rules! stored_entity { Ok(thing.id.to_raw()) } - #[async_trait] - pub trait StoredEntity: Serialize + for<'de> Deserialize<'de> { - fn table_name() -> &'static str; - fn get_id(&self) -> &str; - } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct $name { #[serde(deserialize_with = "thing_to_string")] pub id: String, @@ -29,7 +34,7 @@ macro_rules! stored_entity { } #[async_trait] - impl StoredEntity for $name { + impl StoredObject for $name { fn table_name() -> &'static str { $table } diff --git a/src/storage/types/text_chunk.rs b/src/storage/types/text_chunk.rs new file mode 100644 index 0000000..691b15b --- /dev/null +++ b/src/storage/types/text_chunk.rs @@ -0,0 +1,19 @@ +use crate::stored_object; +use uuid::Uuid; + +stored_object!(TextChunk, "text_chunk", { + source_id: String, + chunk: String, + embedding: Vec +}); + +impl TextChunk { + pub fn new(source_id: String, chunk: String, embedding: Vec) -> Self { + Self { + id: Uuid::new_v4().to_string(), + source_id, + chunk, + embedding, + } + } +} diff --git a/src/storage/types/text_content.rs b/src/storage/types/text_content.rs index d842943..f3d4f77 100644 --- a/src/storage/types/text_content.rs +++ b/src/storage/types/text_content.rs @@ -1,9 +1,9 @@ use uuid::Uuid; use crate::models::file_info::FileInfo; -use crate::stored_entity; +use crate::stored_object; -stored_entity!(TextContent, "text_content", { +stored_object!(TextContent, "text_content", { text: String, file_info: Option, instructions: String, @@ -23,13 +23,3 @@ impl TextContent { // Other methods... } - -fn test() { - let content = TextContent::new( - "hiho".to_string(), - "instructions".to_string(), - "cat".to_string(), - ); - - content.get_id(); -} diff --git a/src/surrealdb/mod.rs b/src/surrealdb/mod.rs index 6c7eca1..21ef2ff 100644 --- a/src/surrealdb/mod.rs +++ b/src/surrealdb/mod.rs @@ -2,22 +2,14 @@ use std::ops::Deref; use surrealdb::{ engine::remote::ws::{Client, Ws}, opt::auth::Root, - Surreal, + Error, Surreal, }; -use thiserror::Error; #[derive(Clone)] pub struct SurrealDbClient { pub client: Surreal, } -#[derive(Error, Debug)] -pub enum SurrealError { - #[error("SurrealDb error: {0}")] - SurrealDbError(#[from] surrealdb::Error), - // Add more error variants as needed. -} - impl SurrealDbClient { /// # Initialize a new datbase client /// @@ -25,7 +17,7 @@ impl SurrealDbClient { /// /// # Returns /// * `SurrealDbClient` initialized - pub async fn new() -> Result { + pub async fn new() -> Result { let db = Surreal::new::("127.0.0.1:8000").await?; // Sign in to database diff --git a/src/utils/llm.rs b/src/utils/llm.rs index edf4157..aa4495d 100644 --- a/src/utils/llm.rs +++ b/src/utils/llm.rs @@ -1,10 +1,12 @@ -use crate::models::graph_entities::{ - GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship, +use crate::{ + error::ProcessingError, + models::graph_entities::{ + GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship, + }, }; -use crate::models::text_content::ProcessingError; use async_openai::types::{ ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, - CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs + CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -45,21 +47,16 @@ pub async fn generate_embedding( let request = CreateEmbeddingRequestArgs::default() .model("text-embedding-3-small") .input(&[input]) - .build() - .map_err(|e| ProcessingError::LLMError(e.to_string()))?; + .build()?; // Send the request to OpenAI - let response = client - .embeddings() - .create(request) - .await - .map_err(|e| ProcessingError::LLMError(e.to_string()))?; + let response = client.embeddings().create(request).await?; // Extract the embedding vector let embedding: Vec = response .data .first() - .ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))? + .ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))? .embedding .clone(); @@ -90,7 +87,6 @@ impl LLMGraphAnalysisResult { mapper.assign_id(&llm_entity.key); } - let mut entities = vec![]; // Step 2: Process each knowledge entity sequentially @@ -99,7 +95,10 @@ impl LLMGraphAnalysisResult { let assigned_id = mapper .get_id(&llm_entity.key) .ok_or_else(|| { - ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key)) + ProcessingError::GraphProcessingError(format!( + "ID not found for key: {}", + llm_entity.key + )) })? .clone(); @@ -158,37 +157,46 @@ pub async fn create_json_ld( openai_client: &async_openai::Client, ) -> Result { // Format the input for more cohesive comparison - let input_text = format!("content: {:?}, category: {:?}, user_instructions: {:?}", text, category, instructions); - + let input_text = format!( + "content: {:?}, category: {:?}, user_instructions: {:?}", + text, category, instructions + ); + // Generate embedding of the input let input_embedding = generate_embedding(&openai_client, input_text).await?; let number_of_entities_to_get = 10; - + // Construct the query let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM knowledge_entity WHERE embedding <|{},40|> {:?} ORDER BY distance",number_of_entities_to_get, input_embedding); - // Perform query and deserialize to struct + // Perform query and deserialize to struct let closest_entities: Vec = db_client.query(closest_query).await?.take(0)?; #[allow(dead_code)] #[derive(Debug)] struct KnowledgeEntityToLLM { id: String, name: String, - description: String + description: String, } - info!("Number of KnowledgeEntities sent as context: {}", closest_entities.len()); - + info!( + "Number of KnowledgeEntities sent as context: {}", + closest_entities.len() + ); + // Only keep most relevant information - let closest_entities_to_llm: Vec = closest_entities.clone().into_iter().map(|entity| KnowledgeEntityToLLM { - id: entity.id, - name: entity.name, - description: entity.description - }).collect(); - + let closest_entities_to_llm: Vec = closest_entities + .clone() + .into_iter() + .map(|entity| KnowledgeEntityToLLM { + id: entity.id, + name: entity.name, + description: entity.description, + }) + .collect(); + debug!("{:?}", closest_entities_to_llm); - let schema = json!({ "type": "object", @@ -293,32 +301,26 @@ pub async fn create_json_ld( ChatCompletionRequestUserMessage::from(user_message).into(), ]) .response_format(response_format) - .build() - .map_err(|e| ProcessingError::LLMError(e.to_string()))?; + .build()?; // Send the request to OpenAI - let response = openai_client - .chat() - .create(request) - .await - .map_err(|e| ProcessingError::LLMError(format!("OpenAI API request failed: {}", e)))?; + let response = openai_client.chat().create(request).await?; debug!("{:?}", response); - // Extract and parse the response - for choice in response.choices { - if let Some(content) = choice.message.content { - let analysis: LLMGraphAnalysisResult = serde_json::from_str(&content).map_err(|e| { - ProcessingError::LLMError(format!( + response + .choices + .first() + .and_then(|choice| choice.message.content.as_ref()) + .ok_or(ProcessingError::LLMParsingError( + "No content found in LLM response".into(), + )) + .and_then(|content| { + serde_json::from_str(content).map_err(|e| { + ProcessingError::LLMParsingError(format!( "Failed to parse LLM response into analysis: {}", e )) - })?; - return Ok(analysis); - } - } - - Err(ProcessingError::LLMError( - "No content found in LLM response".into(), - )) + }) + }) }