Files
minne/src/utils/llm.rs
2024-11-12 21:14:29 +01:00

308 lines
11 KiB
Rust

use crate::models::graph_entities::{
GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
};
use crate::models::text_content::ProcessingError;
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, Embedding,
};
use futures::future::try_join_all;
use serde::{Deserialize, Serialize};
use serde_json::json;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;
use tokio::try_join;
use tracing::{debug, info};
use uuid::Uuid;
/// Represents a single knowledge entity from the LLM.
///
/// The `key` is a temporary identifier invented by the LLM; it is mapped to a
/// real database ID by `GraphMapper` before persisting.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {
    // Temporary identifier assigned by the LLM (not a database ID).
    pub key: String,
    // Human-readable entity name.
    pub name: String,
    // Detailed description of the entity.
    pub description: String,
    // Entity type as a string; converted via `KnowledgeEntityType::from`.
    // Should match KnowledgeEntityType variants.
    pub entity_type: String,
}
/// Represents a single relationship from the LLM.
///
/// `source` and `target` refer to `LLMKnowledgeEntity::key` values, not
/// database IDs.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMRelationship {
    // Relationship kind, e.g. RelatedTo, RelevantTo, SimilarTo. Serialized as
    // "type" because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub type_: String,
    // Key of the source entity.
    pub source: String,
    // Key of the target entity.
    pub target: String,
}
/// Represents the entire graph analysis result from the LLM.
///
/// This is the deserialized form of the JSON object the model is forced to
/// produce by the strict JSON schema in `create_json_ld`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMGraphAnalysisResult {
    // All entities extracted from the analyzed text.
    pub knowledge_entities: Vec<LLMKnowledgeEntity>,
    // Relationships between the extracted entities, keyed by entity `key`.
    pub relationships: Vec<LLMRelationship>,
}
/// Generates an embedding vector for `input` using OpenAI's
/// `text-embedding-3-small` model.
///
/// # Arguments
///
/// * `client` - A configured OpenAI client.
/// * `input` - The text to embed.
///
/// # Errors
///
/// Returns `ProcessingError::LLMError` if the request cannot be built, the
/// API call fails, or the response contains no embedding data.
async fn generate_embedding(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input: String,
) -> Result<Vec<f32>, ProcessingError> {
    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-3-small")
        .input(&[input])
        .build()
        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;

    // Send the request to OpenAI.
    let response = client
        .embeddings()
        .create(request)
        .await
        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;

    // Move the first embedding out of the owned response instead of cloning
    // the whole Vec<f32>.
    response
        .data
        .into_iter()
        .next()
        .map(|d| d.embedding)
        .ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))
}
impl LLMGraphAnalysisResult {
/// Converts the LLM graph analysis result into database entities and relationships.
/// Processes embeddings sequentially for simplicity.
///
/// # Arguments
///
/// * `source_id` - A UUID representing the source identifier.
///
/// # Returns
///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
pub async fn to_database_entities(
&self,
source_id: &Uuid,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
let mut mapper = GraphMapper::new();
// Step 1: Assign unique IDs to all knowledge entities upfront
for llm_entity in &self.knowledge_entities {
mapper.assign_id(&llm_entity.key);
}
let openai_client = async_openai::Client::new();
let mut entities = vec![];
// Step 2: Process each knowledge entity sequentially
for llm_entity in &self.knowledge_entities {
// Retrieve the assigned ID for the current entity
let assigned_id = mapper
.get_id(&llm_entity.key)
.ok_or_else(|| {
ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key))
})?
.clone();
// Prepare the embedding input
let embedding_input = format!(
"name: {}, description: {}, type: {}",
llm_entity.name, llm_entity.description, llm_entity.entity_type
);
// Generate embedding
let embedding = generate_embedding(&openai_client, embedding_input).await?;
// Construct the KnowledgeEntity with embedding
let knowledge_entity = KnowledgeEntity {
id: assigned_id.to_string(),
name: llm_entity.name.clone(),
description: llm_entity.description.clone(),
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
source_id: source_id.to_string(),
metadata: None,
embedding: Some(embedding),
};
entities.push(knowledge_entity);
}
// Step 3: Process relationships using the pre-assigned IDs
let relationships: Vec<KnowledgeRelationship> = self
.relationships
.iter()
.filter_map(|llm_rel| {
let source_db_id = mapper.get_id(&llm_rel.source)?;
let target_db_id = mapper.get_id(&llm_rel.target)?;
Some(KnowledgeRelationship {
id: Uuid::new_v4().to_string(),
out: source_db_id.to_string(),
in_: target_db_id.to_string(),
relationship_type: llm_rel.type_.clone(),
metadata: None,
})
})
.collect();
Ok((entities, relationships))
}
}
/// Sends text to an LLM for analysis and parses the structured graph result.
///
/// # Arguments
///
/// * `category` - User-supplied category for the content.
/// * `instructions` - User-supplied processing instructions.
/// * `text` - The document text to analyze.
/// * `db_client` - SurrealDB client used to inspect and clear existing graph data.
///
/// # Errors
///
/// Returns `ProcessingError` if a database operation fails, the OpenAI request
/// cannot be built or sent, or the response cannot be parsed into
/// [`LLMGraphAnalysisResult`].
pub async fn create_json_ld(
    category: &str,
    instructions: &str,
    text: &str,
    db_client: &Surreal<Client>,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
    // Log the existing entity names before wiping them.
    let entities: Vec<KnowledgeEntity> = db_client
        .query("SELECT * FROM knowledge_entity")
        .await?
        .take(0)?;
    for entity in entities {
        info!("{:?}", entity.name);
    }

    // NOTE(review): the following unconditionally deletes ALL
    // knowledge_entity and knowledge_relationship records on every call.
    // This looks like development scaffolding — confirm it is intended
    // before production use.
    let deleted: Vec<KnowledgeEntity> = db_client.delete("knowledge_entity").await?;
    info!("{:?} KnowledgeEntities deleted", deleted.len());

    let relationships: Vec<KnowledgeRelationship> =
        db_client.select("knowledge_relationship").await?;
    info!("{:?}", relationships);
    let relationships_deleted: Vec<KnowledgeRelationship> =
        db_client.delete("knowledge_relationship").await?;
    info!("{:?} Relationships deleted", relationships_deleted.len());

    let client = async_openai::Client::new();

    // Strict JSON schema for OpenAI structured outputs: with `strict: true`
    // the model can only reply with data matching this shape, so the
    // lowercase `entity_type` enum here is authoritative.
    let schema = json!({
        "type": "object",
        "properties": {
            "knowledge_entities": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "key": { "type": "string" },
                        "name": { "type": "string" },
                        "description": { "type": "string" },
                        "entity_type": {
                            "type": "string",
                            "enum": ["idea", "project", "document", "page", "textsnippet"]
                        }
                    },
                    "required": ["key", "name", "description", "entity_type"],
                    "additionalProperties": false
                }
            },
            "relationships": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "type": {
                            "type": "string",
                            "enum": ["RelatedTo", "RelevantTo", "SimilarTo"]
                        },
                        "source": { "type": "string" },
                        "target": { "type": "string" }
                    },
                    "required": ["type", "source", "target"],
                    "additionalProperties": false
                }
            }
        },
        "required": ["knowledge_entities", "relationships"],
        "additionalProperties": false
    });

    let response_format = async_openai::types::ResponseFormat::JsonSchema {
        json_schema: async_openai::types::ResponseFormatJsonSchema {
            description: Some("Structured analysis of the submitted content".into()),
            name: "content_analysis".into(),
            schema: Some(schema),
            strict: Some(true),
        },
    };

    // System prompt. Guideline 3 lists the lowercase type names the strict
    // schema actually permits, and guideline 5's stray trailing quote has
    // been removed.
    let system_message = r#"
You are an expert document analyzer. You will receive a document's text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database.
The JSON should have the following structure:
{
"knowledge_entities": [
{
"key": "unique-key-1",
"name": "Entity Name",
"description": "A detailed description of the entity.",
"entity_type": "TypeOfEntity"
},
// More entities...
],
"relationships": [
{
"type": "RelationshipType",
"source": "unique-key-1",
"target": "unique-key-2"
},
// More relationships...
]
}
Guidelines:
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
3. Define the type of each KnowledgeEntity using the following categories: idea, project, document, page, textsnippet.
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.
6. Only create relationships between existing KnowledgeEntities.
"#;

    let user_message = format!(
        "Category: {}\nInstructions: {}\nContent:\n{}",
        category, instructions, text
    );

    // Build the chat completion request with the structured-output format.
    let request = CreateChatCompletionRequestArgs::default()
        .model("gpt-4o-mini")
        .max_tokens(2048u32)
        .messages([
            ChatCompletionRequestSystemMessage::from(system_message).into(),
            ChatCompletionRequestUserMessage::from(user_message).into(),
        ])
        .response_format(response_format)
        .build()
        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;

    // Send the request to OpenAI.
    let response = client
        .chat()
        .create(request)
        .await
        .map_err(|e| ProcessingError::LLMError(format!("OpenAI API request failed: {}", e)))?;
    debug!("{:?}", response);

    // Parse the first choice that carries content; strict schema mode means
    // the content, when present, must deserialize into the analysis type.
    for choice in response.choices {
        if let Some(content) = choice.message.content {
            let analysis: LLMGraphAnalysisResult = serde_json::from_str(&content).map_err(|e| {
                ProcessingError::LLMError(format!(
                    "Failed to parse LLM response into analysis: {}",
                    e
                ))
            })?;
            return Ok(analysis);
        }
    }

    Err(ProcessingError::LLMError(
        "No content found in LLM response".into(),
    ))
}