storing & retrieving embeddings

Per Stark
2024-11-12 21:14:29 +01:00
parent 0cc4046cf5
commit 97a105082f
4 changed files with 129 additions and 45 deletions


@@ -41,20 +41,15 @@
   '';

   packages = [
-    pkgs.neo4j
   ];

   languages.rust.enable = true;

   processes = {
-    # start-neo4j.exec = "NEO4J_HOME=$(mktemp -d) neo4j console";
-    surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest start rocksdb:/database/database.db --user root_user --pass root_password";
+    surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
   };

   services = {
-    redis = {
-      enable = true;
-    };
     rabbitmq = {
       enable = true;
       # plugins = ["tracing"];


@@ -15,6 +15,7 @@ pub struct KnowledgeEntity {
     pub entity_type: KnowledgeEntityType,
     pub source_id: String,
     pub metadata: Option<serde_json::Value>,
+    pub embedding: Option<Vec<f32>>,
 }

 fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
@@ -61,6 +62,7 @@ pub struct KnowledgeRelationship {
 }

 /// Intermediate struct to hold mapping between LLM keys and generated IDs.
+#[derive(Clone)]
 pub struct GraphMapper {
     pub key_to_id: HashMap<String, Uuid>,
 }
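GraphMapper's method bodies are not part of this diff. A minimal implementation consistent with how assign_id and get_id are used in the hunks below might look like this sketch (not the repo's actual code):

    impl GraphMapper {
        pub fn new() -> Self {
            Self { key_to_id: HashMap::new() }
        }

        /// Map an LLM-provided key to a fresh UUID, reusing any existing mapping.
        pub fn assign_id(&mut self, key: &str) -> Uuid {
            *self
                .key_to_id
                .entry(key.to_string())
                .or_insert_with(Uuid::new_v4)
        }

        /// Look up the UUID previously assigned to a key.
        pub fn get_id(&self, key: &str) -> Option<&Uuid> {
            self.key_to_id.get(key)
        }
    }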


@@ -1,3 +1,4 @@
+use async_openai::{error::OpenAIError, types::{CreateEmbeddingRequest, CreateEmbeddingRequestArgs}};
 use serde::{Deserialize, Serialize};
 use surrealdb::{engine::remote::ws::Client, Surreal};
 use tracing::{debug, info};
@@ -37,6 +38,9 @@ pub enum ProcessingError {
     #[error("Unknown processing error")]
     Unknown,
+
+    #[error("LLM processing error: {0}")]
+    OpenAIerror(#[from] OpenAIError),
 }
@@ -45,14 +49,17 @@ impl TextContent {
     pub async fn process(&self) -> Result<(), ProcessingError> {
         // Store TextContent
         let db_client = SurrealDbClient::new().await?;
+        db_client.query("REMOVE INDEX embeddings ON knowledge_entity").await?;
+        // db_client.query("DEFINE INDEX embeddings ON knowledge_entity FIELDS embedding UNIQUE").await?;
+        // db_client.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity").await?;

         // Step 1: Send to LLM for analysis
         let analysis = create_json_ld(&self.category, &self.instructions, &self.text, &db_client).await?;
         // info!("{:#?}", &analysis);

         // Step 2: Convert LLM analysis to database entities
-        let (entities, relationships) = analysis.to_database_entities(&self.id);
+        let (entities, relationships) = analysis.to_database_entities(&self.id).await?;

         // Step 3: Store in database
         self.store_in_graph_db(entities, relationships, &db_client).await?;
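The commented-out queries above hint at index management for the new embedding field. Actual similarity retrieval would go through SurrealDB's vector index and KNN operator, roughly as sketched below. This is an assumption, not part of the commit: 1536 dimensions matches text-embedding-3-small, MTREE indexes need a recent SurrealDB build (which may be why the process above pins latest-dev), and query_embedding is a hypothetical Vec<f32> holding the query vector.

    // Sketch: define a vector index over the embedding field (assumed syntax).
    db_client
        .query("DEFINE INDEX embeddings ON knowledge_entity FIELDS embedding MTREE DIMENSION 1536 DIST COSINE")
        .await?;

    // Sketch: K-nearest-neighbour lookup scored by cosine similarity.
    let mut result = db_client
        .query("SELECT id, name, vector::similarity::cosine(embedding, $query) AS score FROM knowledge_entity WHERE embedding <|5|> $query")
        .bind(("query", query_embedding))
        .await?;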
@@ -70,19 +77,22 @@
         relationships: Vec<KnowledgeRelationship>,
         db_client: &Surreal<Client>,
     ) -> Result<(), ProcessingError> {
         for entity in entities {
-            info!("{:?}", entity);
+            // info!("{:?}", &entity);
             let _created: Option<KnowledgeEntity> = db_client
                 .create(("knowledge_entity", &entity.id.to_string()))
-                .content(entity)
+                .content(entity.clone())
                 .await?;
             debug!("{:?}", _created);
         }
         for relationship in relationships {
-            info!("{:?}", relationship);
+            // info!("{:?}", relationship);
             let _created: Option<KnowledgeRelationship> = db_client
                 .insert(("knowledge_relationship", &relationship.id.to_string()))


@@ -1,11 +1,18 @@
-use crate::models::graph_entities::{GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship};
+use crate::models::graph_entities::{
+    GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
+};
 use crate::models::text_content::ProcessingError;
-use async_openai::types::{CreateChatCompletionRequestArgs, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage };
+use async_openai::types::{
+    ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
+    CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, Embedding,
+};
+use futures::future::try_join_all;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use surrealdb::engine::remote::ws::Client;
 use surrealdb::Surreal;
-use tracing::{info,debug};
+use tokio::try_join;
+use tracing::{debug, info};
 use uuid::Uuid;

 /// Represents a single knowledge entity from the LLM.
@@ -33,49 +40,112 @@ pub struct LLMGraphAnalysisResult {
     pub relationships: Vec<LLMRelationship>,
 }

+async fn generate_embedding(
+    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
+    input: String,
+) -> Result<Vec<f32>, ProcessingError> {
+    let request = CreateEmbeddingRequestArgs::default()
+        .model("text-embedding-3-small")
+        .input(&[input])
+        .build()
+        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;
+
+    // Send the request to OpenAI
+    let response = client
+        .embeddings()
+        .create(request)
+        .await
+        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;
+
+    // Extract the embedding vector
+    let embedding: Vec<f32> = response
+        .data
+        .first()
+        .ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))?
+        .embedding
+        .clone();
+
+    Ok(embedding)
+}
+
 impl LLMGraphAnalysisResult {
-    pub fn to_database_entities(
+    /// Converts the LLM graph analysis result into database entities and relationships.
+    /// Processes embeddings sequentially for simplicity.
+    ///
+    /// # Arguments
+    ///
+    /// * `source_id` - A UUID representing the source identifier.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
+    pub async fn to_database_entities(
         &self,
         source_id: &Uuid,
-    ) -> (Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>) {
+    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
         let mut mapper = GraphMapper::new();
-        // First pass: Create all entities and map their keys to UUIDs
-        let entities: Vec<KnowledgeEntity> = self
-            .knowledge_entities
-            .iter()
-            .map(|llm_entity| {
-                let id = mapper.assign_id(&llm_entity.key);
-                KnowledgeEntity {
-                    id: id.to_string(),
-                    name: llm_entity.name.clone(),
-                    description: llm_entity.description.clone(),
-                    entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
-                    source_id: source_id.to_string(),
-                    metadata: None,
-                }
-            })
-            .collect();
-        // Second pass: Create relationships using mapped UUIDs
+        // Step 1: Assign unique IDs to all knowledge entities upfront
+        for llm_entity in &self.knowledge_entities {
+            mapper.assign_id(&llm_entity.key);
+        }
+
+        let openai_client = async_openai::Client::new();
+        let mut entities = vec![];
+
+        // Step 2: Process each knowledge entity sequentially
+        for llm_entity in &self.knowledge_entities {
+            // Retrieve the assigned ID for the current entity
+            let assigned_id = mapper
+                .get_id(&llm_entity.key)
+                .ok_or_else(|| {
+                    ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key))
+                })?
+                .clone();
+
+            // Prepare the embedding input
+            let embedding_input = format!(
+                "name: {}, description: {}, type: {}",
+                llm_entity.name, llm_entity.description, llm_entity.entity_type
+            );
+
+            // Generate embedding
+            let embedding = generate_embedding(&openai_client, embedding_input).await?;
+
+            // Construct the KnowledgeEntity with embedding
+            let knowledge_entity = KnowledgeEntity {
+                id: assigned_id.to_string(),
+                name: llm_entity.name.clone(),
+                description: llm_entity.description.clone(),
+                entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
+                source_id: source_id.to_string(),
+                metadata: None,
+                embedding: Some(embedding),
+            };
+            entities.push(knowledge_entity);
+        }
+
+        // Step 3: Process relationships using the pre-assigned IDs
         let relationships: Vec<KnowledgeRelationship> = self
             .relationships
             .iter()
             .filter_map(|llm_rel| {
-                let source_id = mapper.get_id(&llm_rel.source)?;
-                let target_id = mapper.get_id(&llm_rel.target)?;
+                let source_db_id = mapper.get_id(&llm_rel.source)?;
+                let target_db_id = mapper.get_id(&llm_rel.target)?;
                 Some(KnowledgeRelationship {
                     id: Uuid::new_v4().to_string(),
-                    out: source_id.to_string(),
-                    in_: target_id.to_string(),
+                    out: source_db_id.to_string(),
+                    in_: target_db_id.to_string(),
                     relationship_type: llm_rel.type_.clone(),
                     metadata: None,
                 })
             })
             .collect();
-        (entities, relationships)
+        Ok((entities, relationships))
     }
 }
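The new imports bring in futures::future::try_join_all and tokio::try_join, yet the loop above awaits each embedding one at a time, as its doc comment notes. A hedged sketch of the concurrent variant those imports suggest, reusing the generate_embedding helper; try_join_all short-circuits on the first error:

    // Build one embedding future per entity, then await them all concurrently.
    let inputs: Vec<String> = self
        .knowledge_entities
        .iter()
        .map(|e| {
            format!(
                "name: {}, description: {}, type: {}",
                e.name, e.description, e.entity_type
            )
        })
        .collect();

    let embeddings: Vec<Vec<f32>> = try_join_all(
        inputs
            .into_iter()
            .map(|input| generate_embedding(&openai_client, input)),
    )
    .await?;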
@@ -87,16 +157,23 @@ pub async fn create_json_ld(
     db_client: &Surreal<Client>,
 ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
     // Get the nodes from the database
-    let entities: Vec<KnowledgeEntity> = db_client.select("knowledge_entity").await?;
-    info!("{:?}", entities);
+    let entities: Vec<KnowledgeEntity> = db_client
+        .query("SELECT * FROM knowledge_entity")
+        .await?
+        .take(0)?;
+    for entity in entities {
+        info!("{:?}", entity.name);
+    }
     let deleted: Vec<KnowledgeEntity> = db_client.delete("knowledge_entity").await?;
     info! {"{:?} KnowledgeEntities deleted", deleted.len()};
-    let relationships: Vec<KnowledgeRelationship> = db_client.select("knowledge_relationship").await?;
+    let relationships: Vec<KnowledgeRelationship> =
+        db_client.select("knowledge_relationship").await?;
     info!("{:?}", relationships);
-    let relationships_deleted: Vec<KnowledgeRelationship> = db_client.delete("knowledge_relationship").await?;
+    let relationships_deleted: Vec<KnowledgeRelationship> =
+        db_client.delete("knowledge_relationship").await?;
     info!("{:?} Relationships deleted", relationships_deleted.len());
     let client = async_openai::Client::new();