storing & retrieving embeddings

This commit is contained in:
Per Stark
2024-11-12 21:14:29 +01:00
parent 0cc4046cf5
commit 97a105082f
4 changed files with 129 additions and 45 deletions

View File

@@ -1,11 +1,18 @@
use crate::models::graph_entities::{GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship};
use crate::models::graph_entities::{
GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
};
use crate::models::text_content::ProcessingError;
use async_openai::types::{CreateChatCompletionRequestArgs, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage };
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, Embedding,
};
use futures::future::try_join_all;
use serde::{Deserialize, Serialize};
use serde_json::json;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;
use tracing::{info,debug};
use tokio::try_join;
use tracing::{debug, info};
use uuid::Uuid;
/// Represents a single knowledge entity from the LLM.
@@ -33,49 +40,112 @@ pub struct LLMGraphAnalysisResult {
pub relationships: Vec<LLMRelationship>,
}
/// Requests an embedding vector for `input` from the OpenAI embeddings endpoint.
///
/// # Arguments
///
/// * `client` - OpenAI client used to send the request.
/// * `input` - Text to embed; it is sent as a single-element input batch.
///
/// # Returns
///
/// The embedding of the first item in the response.
///
/// # Errors
///
/// Returns `ProcessingError::LLMError` if the request cannot be built, the API
/// call fails, or the response carries no embedding data.
async fn generate_embedding(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input: String,
) -> Result<Vec<f32>, ProcessingError> {
    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-3-small")
        .input(&[input])
        .build()
        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;

    // Send the request to OpenAI
    let response = client
        .embeddings()
        .create(request)
        .await
        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;

    // `response` is owned here, so move the first embedding out of it rather
    // than borrowing with `first()` and cloning the whole vector.
    response
        .data
        .into_iter()
        .next()
        .map(|item| item.embedding)
        .ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))
}
// NOTE(review): this span appears to be a rendered commit diff with the +/-
// markers stripped, so pre-commit and post-commit lines are interleaved (two
// signatures, two return types, paired `out`/`in_` assignments, two tail
// expressions). Confirm against the repository before treating it as
// compilable source.
impl LLMGraphAnalysisResult {
pub fn to_database_entities(
/// Converts the LLM graph analysis result into database entities and relationships.
/// Embeddings are requested one entity at a time; each request is awaited before
/// the next entity is processed.
///
/// # Arguments
///
/// * `source_id` - UUID of the source; its string form is stored in every entity's `source_id`.
///
/// # Returns
///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` (each with `embedding` set) and `KnowledgeRelationship`.
///
/// # Errors
///
/// * `ProcessingError::LLMError` when no ID is found for an entity key.
/// * Any `ProcessingError` returned by `generate_embedding` is propagated, which
///   aborts the whole conversion at the first failing entity.
pub async fn to_database_entities(
&self,
source_id: &Uuid,
) -> (Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>) {
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
let mut mapper = GraphMapper::new();
// First pass: Create all entities and map their keys to UUIDs
let entities: Vec<KnowledgeEntity> = self
.knowledge_entities
.iter()
.map(|llm_entity| {
let id = mapper.assign_id(&llm_entity.key);
KnowledgeEntity {
id: id.to_string(),
name: llm_entity.name.clone(),
description: llm_entity.description.clone(),
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
source_id: source_id.to_string(),
metadata: None,
}
})
.collect();
// Step 1: Assign unique IDs to all knowledge entities upfront, so every key
// is registered in the mapper before any entity or relationship is built.
for llm_entity in &self.knowledge_entities {
mapper.assign_id(&llm_entity.key);
}
// Second pass: Create relationships using mapped UUIDs
// NOTE(review): presumably the default client configuration reads its
// credentials from the environment — confirm against the async-openai docs.
let openai_client = async_openai::Client::new();
let mut entities = vec![];
// Step 2: Process each knowledge entity sequentially
for llm_entity in &self.knowledge_entities {
// Retrieve the assigned ID for the current entity; a missing ID is
// reported as an error rather than skipped.
let assigned_id = mapper
.get_id(&llm_entity.key)
.ok_or_else(|| {
ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key))
})?
.clone();
// Prepare the embedding input: name, description and type are combined
// into one string so a single embedding covers all three fields.
let embedding_input = format!(
"name: {}, description: {}, type: {}",
llm_entity.name, llm_entity.description, llm_entity.entity_type
);
// Generate embedding; `?` stops the loop on the first failure.
let embedding = generate_embedding(&openai_client, embedding_input).await?;
// Construct the KnowledgeEntity with embedding
let knowledge_entity = KnowledgeEntity {
id: assigned_id.to_string(),
name: llm_entity.name.clone(),
description: llm_entity.description.clone(),
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
source_id: source_id.to_string(),
metadata: None,
embedding: Some(embedding),
};
entities.push(knowledge_entity);
}
// Step 3: Process relationships using the pre-assigned IDs
let relationships: Vec<KnowledgeRelationship> = self
.relationships
.iter()
// `?` inside this closure yields `None`, so a relationship whose source or
// target key has no ID in the mapper is dropped without raising an error.
.filter_map(|llm_rel| {
let source_id = mapper.get_id(&llm_rel.source)?;
let target_id = mapper.get_id(&llm_rel.target)?;
let source_db_id = mapper.get_id(&llm_rel.source)?;
let target_db_id = mapper.get_id(&llm_rel.target)?;
Some(KnowledgeRelationship {
// Relationship IDs are fresh random (v4) UUIDs; entity IDs come from the mapper.
id: Uuid::new_v4().to_string(),
out: source_id.to_string(),
in_: target_id.to_string(),
out: source_db_id.to_string(),
in_: target_db_id.to_string(),
relationship_type: llm_rel.type_.clone(),
metadata: None,
})
})
.collect();
(entities, relationships)
Ok((entities, relationships))
}
}
@@ -87,16 +157,23 @@ pub async fn create_json_ld(
db_client: &Surreal<Client>,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
// Get the nodes from the database
let entities: Vec<KnowledgeEntity> = db_client.select("knowledge_entity").await?;
info!("{:?}", entities);
let entities: Vec<KnowledgeEntity> = db_client
.query("SELECT * FROM knowledge_entity")
.await?
.take(0)?;
for entity in entities {
info!("{:?}", entity.name);
}
let deleted: Vec<KnowledgeEntity> = db_client.delete("knowledge_entity").await?;
info! {"{:?} KnowledgeEntities deleted", deleted.len()};
let relationships: Vec<KnowledgeRelationship> = db_client.select("knowledge_relationship").await?;
let relationships: Vec<KnowledgeRelationship> =
db_client.select("knowledge_relationship").await?;
info!("{:?}", relationships);
let relationships_deleted: Vec<KnowledgeRelationship> = db_client.delete("knowledge_relationship").await?;
let relationships_deleted: Vec<KnowledgeRelationship> =
db_client.delete("knowledge_relationship").await?;
info!("{:?} Relationships deleted", relationships_deleted.len());
let client = async_openai::Client::new();