From dfa597758f4191b28c76f56ced82991122dedbe3 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Tue, 12 Nov 2024 21:14:29 +0100 Subject: [PATCH] storing & retrieving embeddings --- flake.nix | 7 +- src/models/graph_entities.rs | 2 + src/models/text_content.rs | 22 ++++-- src/utils/llm.rs | 143 +++++++++++++++++++++++++++-------- 4 files changed, 129 insertions(+), 45 deletions(-) diff --git a/flake.nix b/flake.nix index 96e11ba..32f679c 100644 --- a/flake.nix +++ b/flake.nix @@ -41,20 +41,15 @@ ''; packages = [ - pkgs.neo4j ]; languages.rust.enable = true; processes = { - # start-neo4j.exec = "NEO4J_HOME=$(mktemp -d) neo4j console"; - surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest start rocksdb:/database/database.db --user root_user --pass root_password"; + surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password"; }; services = { - redis = { - enable = true; - }; rabbitmq = { enable = true; # plugins = ["tracing"]; diff --git a/src/models/graph_entities.rs b/src/models/graph_entities.rs index 866b277..4d6bb1c 100644 --- a/src/models/graph_entities.rs +++ b/src/models/graph_entities.rs @@ -15,6 +15,7 @@ pub struct KnowledgeEntity { pub entity_type: KnowledgeEntityType, pub source_id: String, pub metadata: Option, + pub embedding: Option>, } fn thing_to_string<'de, D>(deserializer: D) -> Result @@ -61,6 +62,7 @@ pub struct KnowledgeRelationship { } /// Intermediate struct to hold mapping between LLM keys and generated IDs. +#[derive(Clone)] pub struct GraphMapper { pub key_to_id: HashMap, } diff --git a/src/models/text_content.rs b/src/models/text_content.rs index 5d02db8..978a70d 100644 --- a/src/models/text_content.rs +++ b/src/models/text_content.rs @@ -1,3 +1,4 @@ +use async_openai::{error::OpenAIError, types::{CreateEmbeddingRequest, CreateEmbeddingRequestArgs}}; use serde::{Deserialize, Serialize}; use surrealdb::{engine::remote::ws::Client, Surreal}; use tracing::{debug, info}; @@ -37,6 +38,9 @@ pub enum ProcessingError { #[error("Unknown processing error")] Unknown, + + #[error("LLM processing error: {0}")] + OpenAIerror(#[from] OpenAIError), } @@ -45,14 +49,17 @@ impl TextContent { pub async fn process(&self) -> Result<(), ProcessingError> { // Store TextContent let db_client = SurrealDbClient::new().await?; + + db_client.query("REMOVE INDEX embeddings ON knowledge_entity").await?; + // db_client.query("DEFINE INDEX embeddings ON knowledge_entity FIELDS embedding UNIQUE").await?; + // db_client.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity").await?; // Step 1: Send to LLM for analysis let analysis = create_json_ld(&self.category, &self.instructions, &self.text, &db_client).await?; // info!("{:#?}", &analysis); - // Step 2: Convert LLM analysis to database entities - let (entities, relationships) = analysis.to_database_entities(&self.id); + let (entities, relationships) = analysis.to_database_entities(&self.id).await?; // Step 3: Store in database self.store_in_graph_db(entities, relationships, &db_client).await?; @@ -70,19 +77,22 @@ impl TextContent { relationships: Vec, db_client: &Surreal, ) -> Result<(), ProcessingError> { - for entity in entities { - info!("{:?}", entity); + for entity in entities { + // info!("{:?}", &entity); let _created: Option = db_client .create(("knowledge_entity", &entity.id.to_string())) - .content(entity) + .content(entity.clone()) .await?; debug!("{:?}",_created); + + + } for relationship in relationships { - info!("{:?}", relationship); + // info!("{:?}", relationship); let _created: Option = db_client .insert(("knowledge_relationship", &relationship.id.to_string())) diff --git a/src/utils/llm.rs b/src/utils/llm.rs index c98a51e..50eed10 100644 --- a/src/utils/llm.rs +++ b/src/utils/llm.rs @@ -1,11 +1,18 @@ -use crate::models::graph_entities::{GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship}; +use crate::models::graph_entities::{ + GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship, +}; use crate::models::text_content::ProcessingError; -use async_openai::types::{CreateChatCompletionRequestArgs, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage }; +use async_openai::types::{ + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, + CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, Embedding, +}; +use futures::future::try_join_all; use serde::{Deserialize, Serialize}; use serde_json::json; use surrealdb::engine::remote::ws::Client; use surrealdb::Surreal; -use tracing::{info,debug}; +use tokio::try_join; +use tracing::{debug, info}; use uuid::Uuid; /// Represents a single knowledge entity from the LLM. @@ -33,49 +40,112 @@ pub struct LLMGraphAnalysisResult { pub relationships: Vec, } +async fn generate_embedding( + client: &async_openai::Client, + input: String, +) -> Result, ProcessingError> { + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-3-small") + .input(&[input]) + .build() + .map_err(|e| ProcessingError::LLMError(e.to_string()))?; + + // Send the request to OpenAI + let response = client + .embeddings() + .create(request) + .await + .map_err(|e| ProcessingError::LLMError(e.to_string()))?; + + // Extract the embedding vector + let embedding: Vec = response + .data + .first() + .ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))? + .embedding + .clone(); + + Ok(embedding) +} + impl LLMGraphAnalysisResult { - pub fn to_database_entities( + /// Converts the LLM graph analysis result into database entities and relationships. + /// Processes embeddings sequentially for simplicity. + /// + /// # Arguments + /// + /// * `source_id` - A UUID representing the source identifier. + /// + /// # Returns + /// + /// * `Result<(Vec, Vec), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`. + pub async fn to_database_entities( &self, source_id: &Uuid, - ) -> (Vec, Vec) { + ) -> Result<(Vec, Vec), ProcessingError> { let mut mapper = GraphMapper::new(); - // First pass: Create all entities and map their keys to UUIDs - let entities: Vec = self - .knowledge_entities - .iter() - .map(|llm_entity| { - let id = mapper.assign_id(&llm_entity.key); - KnowledgeEntity { - id: id.to_string(), - name: llm_entity.name.clone(), - description: llm_entity.description.clone(), - entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()), - source_id: source_id.to_string(), - metadata: None, - } - }) - .collect(); + // Step 1: Assign unique IDs to all knowledge entities upfront + for llm_entity in &self.knowledge_entities { + mapper.assign_id(&llm_entity.key); + } - // Second pass: Create relationships using mapped UUIDs + let openai_client = async_openai::Client::new(); + + let mut entities = vec![]; + + // Step 2: Process each knowledge entity sequentially + for llm_entity in &self.knowledge_entities { + // Retrieve the assigned ID for the current entity + let assigned_id = mapper + .get_id(&llm_entity.key) + .ok_or_else(|| { + ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key)) + })? + .clone(); + + // Prepare the embedding input + let embedding_input = format!( + "name: {}, description: {}, type: {}", + llm_entity.name, llm_entity.description, llm_entity.entity_type + ); + + // Generate embedding + let embedding = generate_embedding(&openai_client, embedding_input).await?; + + // Construct the KnowledgeEntity with embedding + let knowledge_entity = KnowledgeEntity { + id: assigned_id.to_string(), + name: llm_entity.name.clone(), + description: llm_entity.description.clone(), + entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()), + source_id: source_id.to_string(), + metadata: None, + embedding: Some(embedding), + }; + + entities.push(knowledge_entity); + } + + // Step 3: Process relationships using the pre-assigned IDs let relationships: Vec = self .relationships .iter() .filter_map(|llm_rel| { - let source_id = mapper.get_id(&llm_rel.source)?; - let target_id = mapper.get_id(&llm_rel.target)?; + let source_db_id = mapper.get_id(&llm_rel.source)?; + let target_db_id = mapper.get_id(&llm_rel.target)?; Some(KnowledgeRelationship { id: Uuid::new_v4().to_string(), - out: source_id.to_string(), - in_: target_id.to_string(), + out: source_db_id.to_string(), + in_: target_db_id.to_string(), relationship_type: llm_rel.type_.clone(), metadata: None, }) }) .collect(); - (entities, relationships) + Ok((entities, relationships)) } } @@ -87,16 +157,23 @@ pub async fn create_json_ld( db_client: &Surreal, ) -> Result { // Get the nodes from the database - let entities: Vec = db_client.select("knowledge_entity").await?; - info!("{:?}", entities); + let entities: Vec = db_client + .query("SELECT * FROM knowledge_entity") + .await? + .take(0)?; + for entity in entities { + info!("{:?}", entity.name); + } let deleted: Vec = db_client.delete("knowledge_entity").await?; info! {"{:?} KnowledgeEntities deleted", deleted.len()}; - - let relationships: Vec = db_client.select("knowledge_relationship").await?; + + let relationships: Vec = + db_client.select("knowledge_relationship").await?; info!("{:?}", relationships); - - let relationships_deleted: Vec = db_client.delete("knowledge_relationship").await?; + + let relationships_deleted: Vec = + db_client.delete("knowledge_relationship").await?; info!("{:?} Relationships deleted", relationships_deleted.len()); let client = async_openai::Client::new();