storing & retrieving embeddings

This commit is contained in:
Per Stark
2024-11-12 21:14:29 +01:00
parent 0f491f88d3
commit dfa597758f
4 changed files with 129 additions and 45 deletions

View File

@@ -15,6 +15,7 @@ pub struct KnowledgeEntity {
pub entity_type: KnowledgeEntityType,
pub source_id: String,
pub metadata: Option<serde_json::Value>,
pub embedding: Option<Vec<f32>>,
}
fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
@@ -61,6 +62,7 @@ pub struct KnowledgeRelationship {
}
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
#[derive(Clone)]
pub struct GraphMapper {
pub key_to_id: HashMap<String, Uuid>,
}

View File

@@ -1,3 +1,4 @@
use async_openai::{error::OpenAIError, types::{CreateEmbeddingRequest, CreateEmbeddingRequestArgs}};
use serde::{Deserialize, Serialize};
use surrealdb::{engine::remote::ws::Client, Surreal};
use tracing::{debug, info};
@@ -37,6 +38,9 @@ pub enum ProcessingError {
#[error("Unknown processing error")]
Unknown,
#[error("LLM processing error: {0}")]
OpenAIerror(#[from] OpenAIError),
}
@@ -45,14 +49,17 @@ impl TextContent {
pub async fn process(&self) -> Result<(), ProcessingError> {
// Store TextContent
let db_client = SurrealDbClient::new().await?;
db_client.query("REMOVE INDEX embeddings ON knowledge_entity").await?;
// db_client.query("DEFINE INDEX embeddings ON knowledge_entity FIELDS embedding UNIQUE").await?;
// db_client.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity").await?;
// Step 1: Send to LLM for analysis
let analysis = create_json_ld(&self.category, &self.instructions, &self.text, &db_client).await?;
// info!("{:#?}", &analysis);
// Step 2: Convert LLM analysis to database entities
let (entities, relationships) = analysis.to_database_entities(&self.id);
let (entities, relationships) = analysis.to_database_entities(&self.id).await?;
// Step 3: Store in database
self.store_in_graph_db(entities, relationships, &db_client).await?;
@@ -70,19 +77,22 @@ impl TextContent {
relationships: Vec<KnowledgeRelationship>,
db_client: &Surreal<Client>,
) -> Result<(), ProcessingError> {
for entity in entities {
info!("{:?}", entity);
for entity in entities {
// info!("{:?}", &entity);
let _created: Option<KnowledgeEntity> = db_client
.create(("knowledge_entity", &entity.id.to_string()))
.content(entity)
.content(entity.clone())
.await?;
debug!("{:?}",_created);
}
for relationship in relationships {
info!("{:?}", relationship);
// info!("{:?}", relationship);
let _created: Option<KnowledgeRelationship> = db_client
.insert(("knowledge_relationship", &relationship.id.to_string()))

View File

@@ -1,11 +1,18 @@
use crate::models::graph_entities::{GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship};
use crate::models::graph_entities::{
GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
};
use crate::models::text_content::ProcessingError;
use async_openai::types::{CreateChatCompletionRequestArgs, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage };
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, Embedding,
};
use futures::future::try_join_all;
use serde::{Deserialize, Serialize};
use serde_json::json;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;
use tracing::{info,debug};
use tokio::try_join;
use tracing::{debug, info};
use uuid::Uuid;
/// Represents a single knowledge entity from the LLM.
@@ -33,49 +40,112 @@ pub struct LLMGraphAnalysisResult {
pub relationships: Vec<LLMRelationship>,
}
async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: String,
) -> Result<Vec<f32>, ProcessingError> {
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small")
.input(&[input])
.build()
.map_err(|e| ProcessingError::LLMError(e.to_string()))?;
// Send the request to OpenAI
let response = client
.embeddings()
.create(request)
.await
.map_err(|e| ProcessingError::LLMError(e.to_string()))?;
// Extract the embedding vector
let embedding: Vec<f32> = response
.data
.first()
.ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))?
.embedding
.clone();
Ok(embedding)
}
impl LLMGraphAnalysisResult {
pub fn to_database_entities(
/// Converts the LLM graph analysis result into database entities and relationships.
/// Processes embeddings sequentially for simplicity.
///
/// # Arguments
///
/// * `source_id` - A UUID representing the source identifier.
///
/// # Returns
///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
pub async fn to_database_entities(
&self,
source_id: &Uuid,
) -> (Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>) {
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
let mut mapper = GraphMapper::new();
// First pass: Create all entities and map their keys to UUIDs
let entities: Vec<KnowledgeEntity> = self
.knowledge_entities
.iter()
.map(|llm_entity| {
let id = mapper.assign_id(&llm_entity.key);
KnowledgeEntity {
id: id.to_string(),
name: llm_entity.name.clone(),
description: llm_entity.description.clone(),
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
source_id: source_id.to_string(),
metadata: None,
}
})
.collect();
// Step 1: Assign unique IDs to all knowledge entities upfront
for llm_entity in &self.knowledge_entities {
mapper.assign_id(&llm_entity.key);
}
// Second pass: Create relationships using mapped UUIDs
let openai_client = async_openai::Client::new();
let mut entities = vec![];
// Step 2: Process each knowledge entity sequentially
for llm_entity in &self.knowledge_entities {
// Retrieve the assigned ID for the current entity
let assigned_id = mapper
.get_id(&llm_entity.key)
.ok_or_else(|| {
ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key))
})?
.clone();
// Prepare the embedding input
let embedding_input = format!(
"name: {}, description: {}, type: {}",
llm_entity.name, llm_entity.description, llm_entity.entity_type
);
// Generate embedding
let embedding = generate_embedding(&openai_client, embedding_input).await?;
// Construct the KnowledgeEntity with embedding
let knowledge_entity = KnowledgeEntity {
id: assigned_id.to_string(),
name: llm_entity.name.clone(),
description: llm_entity.description.clone(),
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
source_id: source_id.to_string(),
metadata: None,
embedding: Some(embedding),
};
entities.push(knowledge_entity);
}
// Step 3: Process relationships using the pre-assigned IDs
let relationships: Vec<KnowledgeRelationship> = self
.relationships
.iter()
.filter_map(|llm_rel| {
let source_id = mapper.get_id(&llm_rel.source)?;
let target_id = mapper.get_id(&llm_rel.target)?;
let source_db_id = mapper.get_id(&llm_rel.source)?;
let target_db_id = mapper.get_id(&llm_rel.target)?;
Some(KnowledgeRelationship {
id: Uuid::new_v4().to_string(),
out: source_id.to_string(),
in_: target_id.to_string(),
out: source_db_id.to_string(),
in_: target_db_id.to_string(),
relationship_type: llm_rel.type_.clone(),
metadata: None,
})
})
.collect();
(entities, relationships)
Ok((entities, relationships))
}
}
@@ -87,16 +157,23 @@ pub async fn create_json_ld(
db_client: &Surreal<Client>,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
// Get the nodes from the database
let entities: Vec<KnowledgeEntity> = db_client.select("knowledge_entity").await?;
info!("{:?}", entities);
let entities: Vec<KnowledgeEntity> = db_client
.query("SELECT * FROM knowledge_entity")
.await?
.take(0)?;
for entity in entities {
info!("{:?}", entity.name);
}
let deleted: Vec<KnowledgeEntity> = db_client.delete("knowledge_entity").await?;
info! {"{:?} KnowledgeEntities deleted", deleted.len()};
let relationships: Vec<KnowledgeRelationship> = db_client.select("knowledge_relationship").await?;
let relationships: Vec<KnowledgeRelationship> =
db_client.select("knowledge_relationship").await?;
info!("{:?}", relationships);
let relationships_deleted: Vec<KnowledgeRelationship> = db_client.delete("knowledge_relationship").await?;
let relationships_deleted: Vec<KnowledgeRelationship> =
db_client.delete("knowledge_relationship").await?;
info!("{:?} Relationships deleted", relationships_deleted.len());
let client = async_openai::Client::new();