From 764cd52c1295f92a97616ea90f5be4050b96c819 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sun, 24 Nov 2024 09:38:14 +0100 Subject: [PATCH] improved edge creation, wip graph retrieval --- src/ingress/content_processor.rs | 2 +- src/retrieval/graph.rs | 45 ++++++++++++++++++++- src/server/routes/query.rs | 26 ++++++++++-- src/storage/types/knowledge_relationship.rs | 19 ++++++++- 4 files changed, 85 insertions(+), 7 deletions(-) diff --git a/src/ingress/content_processor.rs b/src/ingress/content_processor.rs index 4e74aea..0b03300 100644 --- a/src/ingress/content_processor.rs +++ b/src/ingress/content_processor.rs @@ -100,7 +100,7 @@ impl ContentProcessor { for relationship in &relationships { debug!("Storing relationship: {:?}", relationship); - store_item(&self.db_client, relationship.clone()).await?; + relationship.store_relationship(&self.db_client).await?; } info!( diff --git a/src/retrieval/graph.rs b/src/retrieval/graph.rs index 87fa782..d99587d 100644 --- a/src/retrieval/graph.rs +++ b/src/retrieval/graph.rs @@ -1,6 +1,7 @@ use surrealdb::{engine::remote::ws::Client, Surreal}; +use tracing::info; -use crate::error::ProcessingError; +use crate::{error::ProcessingError, storage::types::knowledge_entity::KnowledgeEntity}; /// Retrieves database entries that match a specific source identifier. /// @@ -65,3 +66,45 @@ where Ok(matching_entities) } + +pub async fn find_entities_by_relationship_by_source_ids( + db_client: &Surreal, + source_ids: &[String], +) -> Result, ProcessingError> { + // Create a comma-separated list of IDs wrapped in backticks + let ids = source_ids + .iter() + .map(|id| format!("knowledge_entity:`{}`", id)) + .collect::>() + .join(", "); + + info!("{:?}", ids); + + // let first = format!("knowledge_entity:`{}`", source_ids.first().unwrap()); + + let query = format!( + "SELECT *, array::complement(<->relates_to<->knowledge_entity, [id]) AS related FROM [{}] FETCH related", + ids + ); + + info!("{}", query); + + let result: Vec = db_client.query(query).await?.take(0)?; + + Ok(result) +} +pub async fn find_entities_by_relationship_by_id( + db_client: &Surreal, + source_id: &str, +) -> Result, ProcessingError> { + let query = format!( + "SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`", + source_id + ); + + info!("{}", query); + + let result: Vec = db_client.query(query).await?.take(0)?; + + Ok(result) +} diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs index b6c383e..67e3fc8 100644 --- a/src/server/routes/query.rs +++ b/src/server/routes/query.rs @@ -1,6 +1,12 @@ use crate::{ error::ApiError, - retrieval::{graph::find_entities_by_source_ids, vector::find_items_by_vector_similarity}, + retrieval::{ + graph::{ + find_entities_by_relationship_by_id, find_entities_by_relationship_by_source_ids, + find_entities_by_source_ids, + }, + vector::find_items_by_vector_similarity, + }, storage::{ db::SurrealDbClient, types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, @@ -25,6 +31,9 @@ pub async fn query_handler( info!("Received input: {:?}", query); let openai_client = async_openai::Client::new(); + let test = find_entities_by_relationship_by_id(&db_client, &query.query).await?; + info!("{:?}", test); + let items_from_knowledge_entity_similarity: Vec = find_items_by_vector_similarity( 10, @@ -49,8 +58,12 @@ pub async fn query_handler( .map(|chunk| chunk.source_id.clone()) .collect::>(); - let items_from_text_chunk_similarity: Vec = - find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), &db_client).await?; + let items_from_text_chunk_similarity: Vec = find_entities_by_source_ids( + source_ids.clone(), + "knowledge_entity".to_string(), + &db_client, + ) + .await?; let entities: Vec = items_from_knowledge_entity_similarity .into_iter() @@ -75,7 +88,12 @@ pub async fn query_handler( }) .collect::>()); - info!("{} Entities\n{:#?}", entities.len(), entities_json); + let graph_retrieval = + find_entities_by_relationship_by_source_ids(&db_client, &source_ids).await?; + + info!("{:?}", graph_retrieval); + + // info!("{} Entities\n{:#?}", entities.len(), entities_json); Ok("we got some stuff".to_string()) } diff --git a/src/storage/types/knowledge_relationship.rs b/src/storage/types/knowledge_relationship.rs index 1311a88..8e8c985 100644 --- a/src/storage/types/knowledge_relationship.rs +++ b/src/storage/types/knowledge_relationship.rs @@ -1,4 +1,6 @@ -use crate::stored_object; +use crate::{error::ProcessingError, stored_object}; +use surrealdb::{engine::remote::ws::Client, Surreal}; +use tracing::info; use uuid::Uuid; stored_object!(KnowledgeRelationship, "knowledge_relationship", { @@ -24,4 +26,19 @@ impl KnowledgeRelationship { metadata, } } + pub async fn store_relationship( + &self, + db_client: &Surreal, + ) -> Result<(), ProcessingError> { + let query = format!( + "RELATE knowledge_entity:`{}` -> relates_to -> knowledge_entity:`{}`", + self.in_, self.out + ); + + let result = db_client.query(query).await?; + + info!("{:?}", result); + + Ok(()) + } }