feat: refactoring complete?

This commit is contained in:
Per Stark
2024-11-21 21:23:49 +01:00
parent 94f328e542
commit 1e789e1153
27 changed files with 428 additions and 338 deletions

View File

@@ -0,0 +1,142 @@
use crate::{
error::ProcessingError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::vector::find_items_by_vector_similarity,
storage::types::{knowledge_entity::KnowledgeEntity, StoredObject},
};
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
};
use serde_json::json;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;
use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
/// Analyzes ingress content with an LLM, grounding the prompt in similar
/// `KnowledgeEntity` records already stored in SurrealDB so the model can
/// link to existing entities instead of duplicating them.
pub struct IngressAnalyzer<'a> {
    db_client: &'a Surreal<Client>,
    openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl<'a> IngressAnalyzer<'a> {
    /// Creates a new analyzer borrowing the database and OpenAI clients.
    pub fn new(
        db_client: &'a Surreal<Client>,
        openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Self {
        Self {
            db_client,
            openai_client,
        }
    }

    /// Runs the full ingress analysis: finds similar existing entities,
    /// builds the LLM request, and parses the structured JSON response.
    ///
    /// # Errors
    /// Returns `ProcessingError` if the similarity lookup, request building,
    /// the LLM call, or response parsing fails.
    pub async fn analyze_content(
        &self,
        category: &str,
        instructions: &str,
        text: &str,
    ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
        let similar_entities = self
            .find_similar_entities(category, instructions, text)
            .await?;
        let llm_request =
            self.prepare_llm_request(category, instructions, text, &similar_entities)?;
        self.perform_analysis(llm_request).await
    }

    /// Looks up the 10 most similar `KnowledgeEntity` records by vector
    /// similarity over a combined content/category/instructions string.
    async fn find_similar_entities(
        &self,
        category: &str,
        instructions: &str,
        text: &str,
    ) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
        let input_text = format!(
            "content: {}, category: {}, user_instructions: {}",
            text, category, instructions
        );
        find_items_by_vector_similarity(
            10,
            input_text,
            self.db_client,
            KnowledgeEntity::table_name().to_string(),
            self.openai_client,
        )
        .await
    }

    /// Builds the chat-completion request: system prompt, a user message
    /// carrying the content plus the similar entities, and a strict JSON
    /// schema response format so the output deserializes directly into
    /// `LLMGraphAnalysisResult`.
    fn prepare_llm_request(
        &self,
        category: &str,
        instructions: &str,
        text: &str,
        similar_entities: &[KnowledgeEntity],
    ) -> Result<CreateChatCompletionRequest, ProcessingError> {
        // Only expose id/name/description to the model; embeddings and other
        // fields would just be prompt noise.
        let entities_json = json!(similar_entities
            .iter()
            .map(|entity| {
                json!({
                    "KnowledgeEntity": {
                        "id": entity.id,
                        "name": entity.name,
                        "description": entity.description
                    }
                })
            })
            .collect::<Vec<_>>());
        let user_message = format!(
            "Category:\n{}\nInstructions:\n{}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
            category, instructions, text, entities_json
        );
        debug!("Prepared LLM request message: {}", user_message);
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: Some("Structured analysis of the submitted content".into()),
                name: "content_analysis".into(),
                schema: Some(get_ingress_analysis_schema()),
                strict: Some(true),
            },
        };
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o-mini")
            .temperature(0.2)
            .max_tokens(2048u32)
            .messages([
                ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
                ChatCompletionRequestUserMessage::from(user_message).into(),
            ])
            .response_format(response_format)
            .build()
            .map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
    }

    /// Sends the request and deserializes the first choice's message content
    /// into an `LLMGraphAnalysisResult`.
    async fn perform_analysis(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
        let response = self.openai_client.chat().create(request).await?;
        debug!("Received LLM response: {:?}", response);
        response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            // `ok_or_else` avoids allocating the error string on the success path.
            .ok_or_else(|| {
                ProcessingError::LLMParsingError("No content found in LLM response".into())
            })
            .and_then(|content| {
                serde_json::from_str(content).map_err(|e| {
                    ProcessingError::LLMParsingError(format!(
                        "Failed to parse LLM response into analysis: {}",
                        e
                    ))
                })
            })
    }
}

View File

@@ -0,0 +1,3 @@
pub mod ingress_analyser;
pub mod prompt;
pub mod types;

View File

@@ -0,0 +1,81 @@
use serde_json::{json, Value};
/// System prompt for the ingress analysis call. Kept in sync with
/// `get_ingress_analysis_schema`: the entity-type names listed under
/// guideline 3 match the schema's lowercase `entity_type` enum exactly.
pub static INGRESS_ANALYSIS_SYSTEM_MESSAGE: &str = r#"
You are an expert document analyzer. You will receive a document's text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these!
The JSON should have the following structure:
{
"knowledge_entities": [
{
"key": "unique-key-1",
"name": "Entity Name",
"description": "A detailed description of the entity.",
"entity_type": "TypeOfEntity"
},
// More entities...
],
"relationships": [
{
"type": "RelationshipType",
"source": "unique-key-1 or UUID from existing database",
"target": "unique-key-1 or UUID from existing database"
},
// More relationships...
]
}
Guidelines:
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
3. Define the type of each KnowledgeEntity using the following categories: idea, project, document, page, textsnippet.
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
7. Only create relationships between existing KnowledgeEntities.
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
9. A new relationship MUST include a newly created KnowledgeEntity.
"#;
/// JSON schema for the structured ingress-analysis response.
///
/// Mirrors `LLMGraphAnalysisResult`: an object with `knowledge_entities`
/// and `relationships` arrays; `additionalProperties` is disabled everywhere
/// so the model cannot emit unexpected fields under strict mode.
pub fn get_ingress_analysis_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "knowledge_entities": {
                "type": "array",
                "items": knowledge_entity_item_schema()
            },
            "relationships": {
                "type": "array",
                "items": relationship_item_schema()
            }
        },
        "required": ["knowledge_entities", "relationships"],
        "additionalProperties": false
    })
}

/// Schema for a single knowledge-entity object produced by the LLM.
fn knowledge_entity_item_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "key": { "type": "string" },
            "name": { "type": "string" },
            "description": { "type": "string" },
            "entity_type": {
                "type": "string",
                "enum": ["idea", "project", "document", "page", "textsnippet"]
            }
        },
        "required": ["key", "name", "description", "entity_type"],
        "additionalProperties": false
    })
}

/// Schema for a single relationship object produced by the LLM.
fn relationship_item_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["RelatedTo", "RelevantTo", "SimilarTo"]
            },
            "source": { "type": "string" },
            "target": { "type": "string" }
        },
        "required": ["type", "source", "target"],
        "additionalProperties": false
    })
}

View File

@@ -0,0 +1,42 @@
use std::collections::HashMap;
use uuid::Uuid;
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
///
/// Keys that are already valid UUID strings refer to existing database
/// entities; all other keys are temporary identifiers invented by the LLM
/// and must be assigned a fresh UUID via `assign_id`.
#[derive(Clone)]
pub struct GraphMapper {
    // Maps an LLM-provided temporary key to its generated v4 UUID.
    pub key_to_id: HashMap<String, Uuid>,
}
impl Default for GraphMapper {
fn default() -> Self {
GraphMapper::new()
}
}
impl GraphMapper {
    /// Creates an empty mapper with no key-to-ID assignments.
    pub fn new() -> Self {
        GraphMapper {
            key_to_id: HashMap::new(),
        }
    }

    /// Resolves `key` to a UUID: if `key` is itself a valid UUID string
    /// (an existing database entity), it is parsed and returned; otherwise
    /// the UUID previously assigned via [`GraphMapper::assign_id`] is used.
    ///
    /// # Panics
    /// Panics if `key` is neither a valid UUID nor a previously assigned key —
    /// callers must pre-assign IDs for every new entity first.
    pub fn get_or_parse_id(&mut self, key: &str) -> Uuid {
        if let Ok(parsed_uuid) = Uuid::parse_str(key) {
            parsed_uuid
        } else {
            // `Uuid` is `Copy`, so dereference instead of cloning; panic with
            // the offending key (instead of a bare `unwrap`) for debuggability.
            *self
                .key_to_id
                .get(key)
                .unwrap_or_else(|| panic!("GraphMapper: no ID assigned for key `{}`", key))
        }
    }

    /// Assigns a fresh v4 UUID to `key`, overwriting any previous assignment.
    pub fn assign_id(&mut self, key: &str) -> Uuid {
        let id = Uuid::new_v4();
        self.key_to_id.insert(key.to_string(), id);
        id
    }

    /// Retrieves the UUID previously assigned to `key`, if any.
    pub fn get_id(&self, key: &str) -> Option<&Uuid> {
        self.key_to_id.get(key)
    }
}

View File

@@ -0,0 +1,175 @@
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use tokio::task;
use crate::{
error::ProcessingError,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
},
utils::embedding::generate_embedding,
};
use futures::future::try_join_all;
use super::graph_mapper::GraphMapper; // For future parallelization
/// A single knowledge entity proposed by the LLM, before IDs and embeddings
/// are attached.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {
    pub key: String, // Temporary identifier invented by the LLM (not a UUID)
    pub name: String,
    pub description: String,
    pub entity_type: String, // Should match KnowledgeEntityType variants (lowercase in the schema enum)
}
/// Represents a single relationship from the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMRelationship {
    #[serde(rename = "type")]
    pub type_: String, // e.g., RelatedTo, RelevantTo, SimilarTo
    pub source: String, // Key of the source entity, or UUID of an existing one
    pub target: String, // Key of the target entity, or UUID of an existing one
}
/// Represents the entire graph analysis result from the LLM, matching the
/// JSON schema sent with the request.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMGraphAnalysisResult {
    pub knowledge_entities: Vec<LLMKnowledgeEntity>, // New entities to create
    pub relationships: Vec<LLMRelationship>,         // Edges between new/existing entities
}
impl LLMGraphAnalysisResult {
    /// Converts the LLM graph analysis result into database entities and
    /// relationships.
    ///
    /// NOTE(review): embeddings are generated concurrently via `task::spawn`
    /// (one task per entity), not sequentially as the previous comment said.
    ///
    /// # Arguments
    /// * `source_id` - Identifier of the source content, stored on each entity.
    /// * `openai_client` - OpenAI client used for embedding generation.
    ///
    /// # Returns
    /// A tuple `(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>)` ready to
    /// be stored.
    pub async fn to_database_entities(
        &self,
        source_id: &str,
        openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
        // Pre-assign an ID for every entity key so relationships can resolve
        // keys regardless of entity-processing order.
        let mapper = Arc::new(Mutex::new(self.create_mapper()?));
        // Build entities concurrently (one spawned task per entity).
        let entities = self
            .process_entities(source_id, Arc::clone(&mapper), openai_client)
            .await?;
        // Resolve relationship endpoints against the mapper.
        let relationships = self.process_relationships(Arc::clone(&mapper))?;
        Ok((entities, relationships))
    }

    /// Builds a `GraphMapper` with a fresh UUID pre-assigned for every
    /// LLM-provided entity key.
    fn create_mapper(&self) -> Result<GraphMapper, ProcessingError> {
        let mut mapper = GraphMapper::new();
        // Pre-assign all IDs
        for entity in &self.knowledge_entities {
            mapper.assign_id(&entity.key);
        }
        Ok(mapper)
    }

    /// Spawns one task per entity to generate its embedding and build the
    /// `KnowledgeEntity`, then awaits them all.
    async fn process_entities(
        &self,
        source_id: &str,
        mapper: Arc<Mutex<GraphMapper>>,
        openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
        let futures: Vec<_> = self
            .knowledge_entities
            .iter()
            .map(|entity| {
                // Clone everything moved into the task: spawned futures must
                // be 'static.
                let mapper = Arc::clone(&mapper);
                let openai_client = openai_client.clone();
                let source_id = source_id.to_string();
                let entity = entity.clone();
                task::spawn(async move {
                    create_single_entity(&entity, &source_id, mapper, &openai_client).await
                })
            })
            .collect();
        // The first `?` surfaces task join errors; the inner collect
        // short-circuits on the first per-entity ProcessingError.
        let results = try_join_all(futures)
            .await?
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Maps every LLM relationship onto database IDs via the mapper.
    ///
    /// NOTE(review): `get_or_parse_id` panics on a key that is neither a
    /// UUID nor pre-assigned — this assumes the LLM only references keys it
    /// declared or existing UUIDs it was shown; confirm against the prompt.
    fn process_relationships(
        &self,
        mapper: Arc<Mutex<GraphMapper>>,
    ) -> Result<Vec<KnowledgeRelationship>, ProcessingError> {
        let mut mapper_guard = mapper
            .lock()
            .map_err(|_| ProcessingError::GraphProcessingError("Failed to lock mapper".into()))?;
        self.relationships
            .iter()
            .map(|rel| {
                let source_db_id = mapper_guard.get_or_parse_id(&rel.source);
                let target_db_id = mapper_guard.get_or_parse_id(&rel.target);
                Ok(KnowledgeRelationship::new(
                    source_db_id.to_string(),
                    target_db_id.to_string(),
                    rel.type_.clone(),
                    None,
                ))
            })
            .collect()
    }
}
/// Builds one `KnowledgeEntity` from an LLM-provided entity: resolves its
/// pre-assigned ID from the shared mapper, generates an embedding over a
/// "name/description/type" summary string, and assembles the record.
async fn create_single_entity(
    llm_entity: &LLMKnowledgeEntity,
    source_id: &str,
    mapper: Arc<Mutex<GraphMapper>>,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<KnowledgeEntity, ProcessingError> {
    // Resolve the pre-assigned ID inside a tight scope so the mutex guard is
    // dropped before the embedding call is awaited.
    let assigned_id = {
        let guard = mapper
            .lock()
            .map_err(|_| ProcessingError::GraphProcessingError("Failed to lock mapper".into()))?;
        match guard.get_id(&llm_entity.key) {
            Some(id) => id.to_string(),
            None => {
                return Err(ProcessingError::GraphProcessingError(format!(
                    "ID not found for key: {}",
                    llm_entity.key
                )))
            }
        }
    };
    let embedding_input = format!(
        "name: {}, description: {}, type: {}",
        llm_entity.name, llm_entity.description, llm_entity.entity_type
    );
    let embedding = generate_embedding(openai_client, embedding_input).await?;
    Ok(KnowledgeEntity {
        id: assigned_id,
        name: llm_entity.name.clone(),
        description: llm_entity.description.clone(),
        entity_type: KnowledgeEntityType::from(llm_entity.entity_type.to_string()),
        source_id: source_id.to_owned(),
        metadata: None,
        embedding,
    })
}

View File

@@ -0,0 +1,2 @@
pub mod graph_mapper;
pub mod llm_analysis_result;

View File

@@ -0,0 +1,119 @@
use text_splitter::TextSplitter;
use tracing::{debug, info};
use crate::{
error::ProcessingError,
retrieval::vector::find_items_by_vector_similarity,
storage::{
db::{store_item, SurrealDbClient},
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
},
},
utils::embedding::generate_embedding,
};
use super::analysis::{
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
};
/// Orchestrates the ingress pipeline for a piece of text content: storage of
/// the original, LLM graph analysis, entity/relationship persistence, and
/// vector chunking.
pub struct ContentProcessor {
    db_client: SurrealDbClient,
    openai_client: async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl ContentProcessor {
    /// Builds a processor with a fresh SurrealDB connection and a default
    /// OpenAI client.
    pub async fn new() -> Result<Self, ProcessingError> {
        Ok(Self {
            db_client: SurrealDbClient::new().await?,
            openai_client: async_openai::Client::new(),
        })
    }

    /// Runs the end-to-end pipeline for one `TextContent`: stores the raw
    /// content, performs LLM graph analysis, stores the resulting entities,
    /// relationships and vector chunks, then rebuilds the DB indexes.
    pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> {
        // Store original content
        store_item(&self.db_client, content.clone()).await?;
        // Process in parallel where possible.
        // NOTE(review): the similar-chunk lookup result is discarded below,
        // so this branch currently only costs an embedding call with no
        // effect — confirm whether it is still needed.
        let (analysis, _similar_chunks) = tokio::try_join!(
            self.perform_semantic_analysis(content),
            self.find_similar_content(content),
        )?;
        // Convert and store entities
        let (entities, relationships) = analysis
            .to_database_entities(&content.id, &self.openai_client)
            .await?;
        // Store everything
        tokio::try_join!(
            self.store_graph_entities(entities, relationships),
            self.store_vector_chunks(content),
        )?;
        self.db_client.rebuild_indexes().await?;
        Ok(())
    }

    /// Delegates to `IngressAnalyzer` for the LLM-based graph analysis.
    async fn perform_semantic_analysis(
        &self,
        content: &TextContent,
    ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
        let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
        analyser
            .analyze_content(&content.category, &content.instructions, &content.text)
            .await
    }

    /// Finds the 3 most similar stored text chunks by vector similarity.
    ///
    /// NOTE(review): table name is hard-coded as "text_chunk" here, while
    /// other call sites use `table_name()` — consider unifying.
    async fn find_similar_content(
        &self,
        content: &TextContent,
    ) -> Result<Vec<TextChunk>, ProcessingError> {
        find_items_by_vector_similarity(
            3,
            content.text.clone(),
            &self.db_client,
            "text_chunk".to_string(),
            &self.openai_client,
        )
        .await
    }

    /// Stores all entities, then all relationships, sequentially; logs totals.
    async fn store_graph_entities(
        &self,
        entities: Vec<KnowledgeEntity>,
        relationships: Vec<KnowledgeRelationship>,
    ) -> Result<(), ProcessingError> {
        for entity in &entities {
            debug!("Storing entity: {:?}", entity);
            store_item(&self.db_client, entity.clone()).await?;
        }
        for relationship in &relationships {
            debug!("Storing relationship: {:?}", relationship);
            store_item(&self.db_client, relationship.clone()).await?;
        }
        info!(
            "Stored {} entities and {} relationships",
            entities.len(),
            relationships.len()
        );
        Ok(())
    }

    /// Splits the content into 500–2000-character chunks, embeds each one,
    /// and stores them as `TextChunk` records linked to the content ID.
    async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> {
        let splitter = TextSplitter::new(500..2000);
        let chunks = splitter.chunks(&content.text);
        // Could potentially process chunks in parallel with a bounded concurrent limit
        for chunk in chunks {
            let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?;
            let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
            store_item(&self.db_client, text_chunk).await?;
        }
        Ok(())
    }
}

3
src/ingress/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod analysis;
pub mod content_processor;
pub mod types;

View File

@@ -0,0 +1,107 @@
use super::ingress_object::IngressObject;
use crate::storage::{db::SurrealDbClient, types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::info;
use url::Url;
use uuid::Uuid;
/// Struct defining the expected body when ingressing content.
#[derive(Serialize, Deserialize, Debug)]
pub struct IngressInput {
    pub content: Option<String>, // Raw text or a URL; which one is detected at ingest time
    pub instructions: String,    // User instructions forwarded to the analyzer
    pub category: String,        // User-chosen category forwarded to the analyzer
    pub files: Option<Vec<String>>, // UUID strings of previously uploaded files
}
/// Error types for processing ingress content.
#[derive(Error, Debug)]
pub enum IngressContentError {
    #[error("IO error occurred: {0}")]
    Io(#[from] std::io::Error),
    #[error("UTF-8 conversion error: {0}")]
    Utf8(#[from] std::string::FromUtf8Error),
    #[error("MIME type detection failed for input: {0}")]
    MimeDetection(String),
    #[error("Unsupported MIME type: {0}")]
    UnsupportedMime(String),
    #[error("URL parse error: {0}")]
    UrlParse(#[from] url::ParseError),
    #[error("UUID parse error: {0}")]
    UuidParse(#[from] uuid::Error),
    // NOTE(review): this module now uses SurrealDB, not Redis — confirm this
    // variant is still reachable before keeping it.
    #[error("Redis error: {0}")]
    RedisError(String),
}
/// Function to create ingress objects from input.
///
/// # Arguments
/// * `input` - IngressInput containing information needed to ingress content.
/// * `db_client` - Initialized SurrealDB client used to look up file information.
///
/// # Returns
/// * `Vec<IngressObject>` - One object per piece of content (URL, text, or file).
///
/// # Errors
/// Returns an error on a malformed file UUID, or when neither content nor any
/// resolvable file was provided.
pub async fn create_ingress_objects(
    input: IngressInput,
    db_client: &SurrealDbClient,
) -> Result<Vec<IngressObject>, IngressContentError> {
    // Initialize list
    let mut object_list = Vec::new();
    // Treat `content` as a URL when it parses as one, otherwise as plain text.
    if let Some(input_content) = input.content {
        match Url::parse(&input_content) {
            Ok(url) => {
                info!("Detected URL: {}", url);
                object_list.push(IngressObject::Url {
                    url: url.to_string(),
                    instructions: input.instructions.clone(),
                    category: input.category.clone(),
                });
            }
            Err(_) => {
                info!("Treating input as plain text");
                object_list.push(IngressObject::Text {
                    // Already an owned String — no need to re-allocate.
                    text: input_content,
                    instructions: input.instructions.clone(),
                    category: input.category.clone(),
                });
            }
        }
    }
    // Look up FileInfo records for each submitted UUID; unknown UUIDs are
    // logged and skipped (best-effort), malformed UUID strings are hard errors.
    if let Some(file_uuids) = input.files {
        for uuid_str in file_uuids {
            let uuid = Uuid::parse_str(&uuid_str)?;
            match FileInfo::get_by_uuid(uuid, db_client).await {
                Ok(file_info) => {
                    object_list.push(IngressObject::File {
                        file_info,
                        instructions: input.instructions.clone(),
                        category: input.category.clone(),
                    });
                }
                Err(_) => {
                    info!("No file with UUID: {}", uuid);
                }
            }
        }
    }
    // If no objects are constructed, we return Err
    if object_list.is_empty() {
        return Err(IngressContentError::MimeDetection(
            "No valid content or files provided".into(),
        ));
    }
    Ok(object_list)
}

View File

@@ -0,0 +1,119 @@
use crate::storage::types::{file_info::FileInfo, text_content::TextContent};
use serde::{Deserialize, Serialize};
use super::ingress_input::IngressContentError;
/// Knowledge object type, containing the content or reference to it, as well
/// as the user-supplied instructions and category metadata.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject {
    /// A URL whose page text will be fetched and ingested.
    Url {
        url: String,
        instructions: String,
        category: String,
    },
    /// Raw text submitted directly.
    Text {
        text: String,
        instructions: String,
        category: String,
    },
    /// A previously uploaded file, referenced by its `FileInfo` record.
    File {
        file_info: FileInfo,
        instructions: String,
        category: String,
    },
}
impl IngressObject {
    /// Creates a new `TextContent` instance from an `IngressObject`.
    ///
    /// # Returns
    /// `TextContent` containing a text representation of the object — fetched
    /// page text for a URL, the raw submitted text, or extracted file content.
    ///
    /// # Errors
    /// Fails if fetching/extraction fails or the file's MIME type is
    /// unsupported.
    pub async fn to_text_content(&self) -> Result<TextContent, IngressContentError> {
        match self {
            IngressObject::Url {
                url,
                instructions,
                category,
            } => {
                let text = Self::fetch_text_from_url(url).await?;
                Ok(TextContent::new(
                    text,
                    instructions.into(),
                    category.into(),
                    None,
                ))
            }
            IngressObject::Text {
                text,
                instructions,
                category,
            } => Ok(TextContent::new(
                text.into(),
                instructions.into(),
                category.into(),
                None,
            )),
            IngressObject::File {
                file_info,
                instructions,
                category,
            } => {
                let text = Self::extract_text_from_file(file_info).await?;
                Ok(TextContent::new(
                    text,
                    instructions.into(),
                    category.into(),
                    Some(file_info.to_owned()),
                ))
            }
        }
    }

    /// Fetches and extracts text from a URL.
    async fn fetch_text_from_url(_url: &str) -> Result<String, IngressContentError> {
        unimplemented!()
    }

    /// Extracts text from a file based on its MIME type.
    ///
    /// Plain-text-like MIME types are read verbatim; PDF and images are not
    /// implemented yet, and all other types are rejected as unsupported.
    async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, IngressContentError> {
        match file_info.mime_type.as_str() {
            // All of these are plain text on disk — one or-pattern arm
            // replaces four duplicated arms with identical bodies.
            "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
                Ok(tokio::fs::read_to_string(&file_info.path).await?)
            }
            "application/pdf" => {
                // TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
                Err(IngressContentError::UnsupportedMime(
                    file_info.mime_type.clone(),
                ))
            }
            "image/png" | "image/jpeg" => {
                // TODO: Implement OCR on image using a crate like `tesseract`
                Err(IngressContentError::UnsupportedMime(
                    file_info.mime_type.clone(),
                ))
            }
            // Handle other MIME types as needed
            _ => Err(IngressContentError::UnsupportedMime(
                file_info.mime_type.clone(),
            )),
        }
    }
}

2
src/ingress/types/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod ingress_input;
pub mod ingress_object;