refactoring: working macro and generics

This commit is contained in:
Per Stark
2024-11-20 22:44:30 +01:00
parent 7222223c31
commit 41134cfa49
11 changed files with 198 additions and 167 deletions

21
src/error.rs Normal file
View File

@@ -0,0 +1,21 @@
use async_openai::error::OpenAIError;
use thiserror::Error;
/// Error types for processing `TextContent`.
///
/// Aggregates the failure modes of the ingestion pipeline: database storage,
/// LLM calls, embedding generation, graph construction, and response parsing.
#[derive(Error, Debug)]
pub enum ProcessingError {
/// Error returned by the SurrealDB client (converted via `#[from]`).
#[error("SurrealDb error: {0}")]
SurrealDbError(#[from] surrealdb::Error),
/// Error returned by the `async_openai` client (converted via `#[from]`).
#[error("LLM processing error: {0}")]
OpenAIerror(#[from] OpenAIError),
/// Embedding generation failed (e.g. no embedding data in the API response).
#[error("Embedding processing error: {0}")]
EmbeddingError(String),
/// Knowledge-graph construction failed (e.g. no id assigned for an entity key).
#[error("Graph processing error: {0}")]
GraphProcessingError(String),
/// The LLM response could not be parsed into the expected structure.
#[error("LLM parsing error: {0}")]
LLMParsingError(String),
}

View File

@@ -1,3 +1,4 @@
pub mod error;
pub mod models;
pub mod rabbitmq;
pub mod routes;

View File

@@ -1,9 +1,9 @@
use super::ingress_content::IngressContentError;
use crate::models::file_info::FileInfo;
use crate::storage::types::text_content::TextContent;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use super::{ingress_content::IngressContentError, text_content::TextContent};
/// Knowledge object type, containing the content or reference to it, as well as metadata
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject {
@@ -34,7 +34,11 @@ impl IngressObject {
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
pub async fn to_text_content(&self) -> Result<TextContent, IngressContentError> {
match self {
IngressObject::Url { url, instructions, category } => {
IngressObject::Url {
url,
instructions,
category,
} => {
let text = Self::fetch_text_from_url(url).await?;
let id = Uuid::new_v4();
Ok(TextContent {
@@ -44,8 +48,12 @@ impl IngressObject {
category: category.clone(),
file_info: None,
})
},
IngressObject::Text { text, instructions, category } => {
}
IngressObject::Text {
text,
instructions,
category,
} => {
let id = Uuid::new_v4();
Ok(TextContent {
id: id.to_string(),
@@ -54,8 +62,12 @@ impl IngressObject {
category: category.clone(),
file_info: None,
})
},
IngressObject::File { file_info, instructions, category } => {
}
IngressObject::File {
file_info,
instructions,
category,
} => {
let id = Uuid::new_v4();
let text = Self::extract_text_from_file(file_info).await?;
Ok(TextContent {
@@ -65,7 +77,7 @@ impl IngressObject {
category: category.clone(),
file_info: Some(file_info.clone()),
})
},
}
}
}
@@ -89,11 +101,15 @@ impl IngressObject {
}
"application/pdf" => {
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
Err(IngressContentError::UnsupportedMime(file_info.mime_type.clone()))
Err(IngressContentError::UnsupportedMime(
file_info.mime_type.clone(),
))
}
"image/png" | "image/jpeg" => {
// TODO: Implement OCR on image using a crate like `tesseract`
Err(IngressContentError::UnsupportedMime(file_info.mime_type.clone()))
Err(IngressContentError::UnsupportedMime(
file_info.mime_type.clone(),
))
}
"application/octet-stream" => {
let content = tokio::fs::read_to_string(&file_info.path).await?;
@@ -104,8 +120,9 @@ impl IngressObject {
Ok(content)
}
// Handle other MIME types as needed
_ => Err(IngressContentError::UnsupportedMime(file_info.mime_type.clone())),
_ => Err(IngressContentError::UnsupportedMime(
file_info.mime_type.clone(),
)),
}
}
}

View File

@@ -1,62 +1,38 @@
use crate::storage;
use crate::storage::db::store_item;
use crate::storage::types::text_chunk::TextChunk;
use crate::storage::types::text_content::TextContent;
use crate::{
models::file_info::FileInfo,
surrealdb::{SurrealDbClient, SurrealError},
error::ProcessingError,
surrealdb::SurrealDbClient,
utils::llm::{create_json_ld, generate_embedding},
};
use async_openai::error::OpenAIError;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::remote::ws::Client, sql::Thing, Surreal};
use surrealdb::{engine::remote::ws::Client, Surreal};
use text_splitter::TextSplitter;
use thiserror::Error;
use tracing::{debug, info};
use uuid::Uuid;
use super::graph_entities::{thing_to_string, KnowledgeEntity, KnowledgeRelationship};
use super::graph_entities::{KnowledgeEntity, KnowledgeRelationship};
#[derive(Serialize, Deserialize, Debug)]
struct TextChunk {
#[serde(deserialize_with = "thing_to_string")]
id: String,
source_id: String,
chunk: String,
embedding: Vec<f32>,
}
// #[derive(Serialize, Deserialize, Debug)]
// struct TextChunk {
// #[serde(deserialize_with = "thing_to_string")]
// id: String,
// source_id: String,
// chunk: String,
// embedding: Vec<f32>,
// }
/// Represents a single piece of text content extracted from various sources.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TextContent {
#[serde(deserialize_with = "thing_to_string")]
pub id: String,
pub text: String,
pub file_info: Option<FileInfo>,
pub instructions: String,
pub category: String,
}
/// Error types for processing `TextContent`.
#[derive(Error, Debug)]
pub enum ProcessingError {
#[error("LLM processing error: {0}")]
LLMError(String),
#[error("SurrealDB error: {0}")]
SurrealError(#[from] SurrealError),
#[error("SurrealDb error: {0}")]
SurrealDbError(#[from] surrealdb::Error),
#[error("Graph DB storage error: {0}")]
GraphDBError(String),
#[error("Vector DB storage error: {0}")]
VectorDBError(String),
#[error("Unknown processing error")]
Unknown,
#[error("LLM processing error: {0}")]
OpenAIerror(#[from] OpenAIError),
}
// #[derive(Debug, Serialize, Deserialize, Clone)]
// pub struct TextContent {
// #[serde(deserialize_with = "thing_to_string")]
// pub id: String,
// pub text: String,
// pub file_info: Option<FileInfo>,
// pub instructions: String,
// pub category: String,
// }
async fn vector_comparison<T>(
take: u8,
@@ -66,9 +42,9 @@ async fn vector_comparison<T>(
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<T>, ProcessingError>
where
T: for<'de> serde::Deserialize<'de>, // Add this trait bound for deserialization
T: for<'de> serde::Deserialize<'de>,
{
let input_embedding = generate_embedding(&openai_client, input_text).await?;
let input_embedding = generate_embedding(openai_client, input_text).await?;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance",table, take, input_embedding);
@@ -98,7 +74,9 @@ impl TextContent {
let db_client = SurrealDbClient::new().await?;
let openai_client = async_openai::Client::new();
self.store_text_content(&db_client).await?;
let create_operation = storage::db::store_item(&db_client, self.clone()).await?;
info!("{:?}", create_operation);
// self.store_text_content(&db_client).await?;
let closest_text_content: Vec<TextChunk> = vector_comparison(
3,
@@ -116,7 +94,7 @@ impl TextContent {
}
}
panic!("STOPPING");
// panic!("STOPPING");
// let deleted: Vec<TextChunk> = db_client.delete("text_chunk").await?;
// info! {"{:?} KnowledgeEntities deleted", deleted.len()};
@@ -230,35 +208,13 @@ impl TextContent {
for chunk in chunks {
info!("Chunk: {}", chunk);
let embedding = generate_embedding(&openai_client, chunk.to_string()).await?;
let text_chunk = TextChunk {
id: Uuid::new_v4().to_string(),
source_id: self.id.clone(),
chunk: chunk.to_string(),
embedding,
};
let text_chunk = TextChunk::new(self.id.to_string(), chunk.to_string(), embedding);
info!("{:?}", text_chunk);
let _created: Option<TextChunk> = db_client
.create(("text_chunk", text_chunk.id.clone()))
.content(text_chunk)
.await?;
debug!("{:?}", _created);
store_item(db_client, text_chunk).await?;
}
Ok(())
}
/// Stores text content in database
async fn store_text_content(&self, db_client: &Surreal<Client>) -> Result<(), ProcessingError> {
let _created: Option<TextContent> = db_client
.create(("text_content", self.id.clone()))
.content(self.clone())
.await?;
debug!("{:?}", _created);
Ok(())
}
}

27
src/storage/db.rs Normal file
View File

@@ -0,0 +1,27 @@
use surrealdb::{engine::remote::ws::Client, Surreal};
use crate::error::ProcessingError;
use super::types::StoredObject;
/// Operation to store an object in SurrealDB; requires the struct to implement `StoredObject`
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `item` - The item to be stored
///
/// # Returns
/// * `Result` - The created item (as returned by the database) or a `ProcessingError`
pub async fn store_item<T>(
db_client: &Surreal<Client>,
item: T,
) -> Result<Option<T>, ProcessingError>
where
T: StoredObject + Send + Sync + 'static,
{
// `T::table_name()` supplies the table and `item.get_id()` the record key;
// any SurrealDB error is converted into `ProcessingError` via its `From` impl.
db_client
.create((T::table_name(), item.get_id()))
.content(item)
.await
.map_err(ProcessingError::from)
}

View File

@@ -1 +1,2 @@
pub mod db;
pub mod types;

View File

@@ -1,11 +1,21 @@
use axum::async_trait;
use serde::{Deserialize, Serialize};
pub mod text_chunk;
pub mod text_content;
/// Contract for types that can be persisted in SurrealDB via `store_item`.
///
/// Implementors (typically generated by the `stored_object!` macro) declare
/// the table they belong to and expose their record id.
#[async_trait]
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
/// Name of the SurrealDB table this type is stored in.
fn table_name() -> &'static str;
/// Record id used as the SurrealDB record key.
fn get_id(&self) -> &str;
}
#[macro_export]
macro_rules! stored_entity {
macro_rules! stored_object {
($name:ident, $table:expr, {$($field:ident: $ty:ty),*}) => {
use axum::async_trait;
use serde::{Deserialize, Deserializer, Serialize};
use surrealdb::sql::Thing;
use $crate::storage::types::StoredObject;
fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
@@ -15,13 +25,8 @@ macro_rules! stored_entity {
Ok(thing.id.to_raw())
}
#[async_trait]
pub trait StoredEntity: Serialize + for<'de> Deserialize<'de> {
fn table_name() -> &'static str;
fn get_id(&self) -> &str;
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct $name {
#[serde(deserialize_with = "thing_to_string")]
pub id: String,
@@ -29,7 +34,7 @@ macro_rules! stored_entity {
}
#[async_trait]
impl StoredEntity for $name {
impl StoredObject for $name {
fn table_name() -> &'static str {
$table
}

View File

@@ -0,0 +1,19 @@
use crate::stored_object;
use uuid::Uuid;
// Expands to the `TextChunk` struct (with an added `id: String` field whose
// deserializer converts a SurrealDB `Thing` to a plain string) plus its
// `StoredObject` impl bound to the "text_chunk" table.
stored_object!(TextChunk, "text_chunk", {
source_id: String,
chunk: String,
embedding: Vec<f32>
});
impl TextChunk {
    /// Builds a `TextChunk` for the given source text, assigning a freshly
    /// generated UUID (v4) as the record id.
    pub fn new(source_id: String, chunk: String, embedding: Vec<f32>) -> Self {
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            source_id,
            chunk,
            embedding,
        }
    }
}

View File

@@ -1,9 +1,9 @@
use uuid::Uuid;
use crate::models::file_info::FileInfo;
use crate::stored_entity;
use crate::stored_object;
stored_entity!(TextContent, "text_content", {
stored_object!(TextContent, "text_content", {
text: String,
file_info: Option<FileInfo>,
instructions: String,
@@ -23,13 +23,3 @@ impl TextContent {
// Other methods...
}
fn test() {
let content = TextContent::new(
"hiho".to_string(),
"instructions".to_string(),
"cat".to_string(),
);
content.get_id();
}

View File

@@ -2,22 +2,14 @@ use std::ops::Deref;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Surreal,
Error, Surreal,
};
use thiserror::Error;
#[derive(Clone)]
pub struct SurrealDbClient {
pub client: Surreal<Client>,
}
#[derive(Error, Debug)]
pub enum SurrealError {
#[error("SurrealDb error: {0}")]
SurrealDbError(#[from] surrealdb::Error),
// Add more error variants as needed.
}
impl SurrealDbClient {
/// # Initialize a new database client
///
@@ -25,7 +17,7 @@ impl SurrealDbClient {
///
/// # Returns
/// * `SurrealDbClient` initialized
pub async fn new() -> Result<Self, SurrealError> {
pub async fn new() -> Result<Self, Error> {
let db = Surreal::new::<Ws>("127.0.0.1:8000").await?;
// Sign in to database

View File

@@ -1,10 +1,12 @@
use crate::models::graph_entities::{
GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
use crate::{
error::ProcessingError,
models::graph_entities::{
GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
},
};
use crate::models::text_content::ProcessingError;
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs
CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -45,21 +47,16 @@ pub async fn generate_embedding(
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small")
.input(&[input])
.build()
.map_err(|e| ProcessingError::LLMError(e.to_string()))?;
.build()?;
// Send the request to OpenAI
let response = client
.embeddings()
.create(request)
.await
.map_err(|e| ProcessingError::LLMError(e.to_string()))?;
let response = client.embeddings().create(request).await?;
// Extract the embedding vector
let embedding: Vec<f32> = response
.data
.first()
.ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))?
.ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))?
.embedding
.clone();
@@ -90,7 +87,6 @@ impl LLMGraphAnalysisResult {
mapper.assign_id(&llm_entity.key);
}
let mut entities = vec![];
// Step 2: Process each knowledge entity sequentially
@@ -99,7 +95,10 @@ impl LLMGraphAnalysisResult {
let assigned_id = mapper
.get_id(&llm_entity.key)
.ok_or_else(|| {
ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key))
ProcessingError::GraphProcessingError(format!(
"ID not found for key: {}",
llm_entity.key
))
})?
.clone();
@@ -158,37 +157,46 @@ pub async fn create_json_ld(
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
// Format the input for more cohesive comparison
let input_text = format!("content: {:?}, category: {:?}, user_instructions: {:?}", text, category, instructions);
let input_text = format!(
"content: {:?}, category: {:?}, user_instructions: {:?}",
text, category, instructions
);
// Generate embedding of the input
let input_embedding = generate_embedding(&openai_client, input_text).await?;
let number_of_entities_to_get = 10;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM knowledge_entity WHERE embedding <|{},40|> {:?} ORDER BY distance",number_of_entities_to_get, input_embedding);
// Perform query and deserialize to struct
// Perform query and deserialize to struct
let closest_entities: Vec<KnowledgeEntity> = db_client.query(closest_query).await?.take(0)?;
#[allow(dead_code)]
#[derive(Debug)]
struct KnowledgeEntityToLLM {
id: String,
name: String,
description: String
description: String,
}
info!("Number of KnowledgeEntities sent as context: {}", closest_entities.len());
info!(
"Number of KnowledgeEntities sent as context: {}",
closest_entities.len()
);
// Only keep most relevant information
let closest_entities_to_llm: Vec<KnowledgeEntityToLLM> = closest_entities.clone().into_iter().map(|entity| KnowledgeEntityToLLM {
id: entity.id,
name: entity.name,
description: entity.description
}).collect();
let closest_entities_to_llm: Vec<KnowledgeEntityToLLM> = closest_entities
.clone()
.into_iter()
.map(|entity| KnowledgeEntityToLLM {
id: entity.id,
name: entity.name,
description: entity.description,
})
.collect();
debug!("{:?}", closest_entities_to_llm);
let schema = json!({
"type": "object",
@@ -293,32 +301,26 @@ pub async fn create_json_ld(
ChatCompletionRequestUserMessage::from(user_message).into(),
])
.response_format(response_format)
.build()
.map_err(|e| ProcessingError::LLMError(e.to_string()))?;
.build()?;
// Send the request to OpenAI
let response = openai_client
.chat()
.create(request)
.await
.map_err(|e| ProcessingError::LLMError(format!("OpenAI API request failed: {}", e)))?;
let response = openai_client.chat().create(request).await?;
debug!("{:?}", response);
// Extract and parse the response
for choice in response.choices {
if let Some(content) = choice.message.content {
let analysis: LLMGraphAnalysisResult = serde_json::from_str(&content).map_err(|e| {
ProcessingError::LLMError(format!(
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.ok_or(ProcessingError::LLMParsingError(
"No content found in LLM response".into(),
))
.and_then(|content| {
serde_json::from_str(content).map_err(|e| {
ProcessingError::LLMParsingError(format!(
"Failed to parse LLM response into analysis: {}",
e
))
})?;
return Ok(analysis);
}
}
Err(ProcessingError::LLMError(
"No content found in LLM response".into(),
))
})
})
}