From 2e064c7148576b3cc293e6795cd6b1ce97fae829 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sat, 23 Nov 2024 22:27:51 +0100 Subject: [PATCH] wip: query --- src/rabbitmq/publisher.rs | 4 +-- src/retrieval/graph.rs | 16 +++++----- src/server/routes/ingress.rs | 10 ++++-- src/server/routes/query.rs | 60 ++++++++++++++++++++++++++++++++---- 4 files changed, 72 insertions(+), 18 deletions(-) diff --git a/src/rabbitmq/publisher.rs b/src/rabbitmq/publisher.rs index c86ad5b..2cc5ebe 100644 --- a/src/rabbitmq/publisher.rs +++ b/src/rabbitmq/publisher.rs @@ -42,10 +42,10 @@ impl RabbitMQProducer { /// * `Result` - Confirmation of sent message or error pub async fn publish( &self, - ingress_object: &IngressObject, + ingress_object: IngressObject, ) -> Result { // Serialize IngressObject to JSON - let payload = serde_json::to_vec(ingress_object).map_err(|e| { + let payload = serde_json::to_vec(&ingress_object).map_err(|e| { error!("Serialization Error: {}", e); RabbitMQError::PublishError(format!("Serialization Error: {}", e)) })?; diff --git a/src/retrieval/graph.rs b/src/retrieval/graph.rs index e2b3338..87fa782 100644 --- a/src/retrieval/graph.rs +++ b/src/retrieval/graph.rs @@ -46,20 +46,22 @@ use crate::error::ProcessingError; /// &db_client /// ).await?; /// ``` -pub async fn find_entities_by_source_id( - source_id: String, +pub async fn find_entities_by_source_ids( + source_id: Vec, table_name: String, db_client: &Surreal, ) -> Result, ProcessingError> where T: for<'de> serde::Deserialize<'de>, { - let query = format!( - "SELECT * FROM {} WHERE source_id = '{}'", - table_name, source_id - ); + let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids"; - let matching_entities: Vec = db_client.query(query).await?.take(0)?; + let matching_entities: Vec = db_client + .query(query) + .bind(("table", table_name)) + .bind(("source_ids", source_id)) + .await? + .take(0)?; Ok(matching_entities) } diff --git a/src/server/routes/ingress.rs b/src/server/routes/ingress.rs index 05bd0f2..393995d 100644 --- a/src/server/routes/ingress.rs +++ b/src/server/routes/ingress.rs @@ -5,6 +5,7 @@ use crate::{ storage::db::SurrealDbClient, }; use axum::{http::StatusCode, response::IntoResponse, Extension, Json}; +use futures::future::try_join_all; use std::sync::Arc; use tracing::info; @@ -17,9 +18,12 @@ pub async fn ingress_handler( let ingress_objects = create_ingress_objects(input, &db_client).await?; - for object in ingress_objects { - producer.publish(&object).await?; - } + let futures: Vec<_> = ingress_objects + .into_iter() + .map(|object| producer.publish(object)) + .collect(); + + try_join_all(futures).await?; Ok(StatusCode::OK) } diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs index 89afed1..b6c383e 100644 --- a/src/server/routes/query.rs +++ b/src/server/routes/query.rs @@ -1,10 +1,15 @@ use crate::{ error::ApiError, - retrieval::vector::find_items_by_vector_similarity, - storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, + retrieval::{graph::find_entities_by_source_ids, vector::find_items_by_vector_similarity}, + storage::{ + db::SurrealDbClient, + types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, + }, }; use axum::{response::IntoResponse, Extension, Json}; use serde::Deserialize; +use serde_json::json; +use std::collections::HashMap; use std::sync::Arc; use tracing::info; @@ -20,14 +25,57 @@ pub async fn query_handler( info!("Received input: {:?}", query); let openai_client = async_openai::Client::new(); - let closest_items: Vec = find_items_by_vector_similarity( - 10, + let items_from_knowledge_entity_similarity: Vec = + find_items_by_vector_similarity( + 10, + query.query.to_string(), + &db_client, + "knowledge_entity".to_string(), + &openai_client, + ) + .await?; + + let closest_chunks: Vec = find_items_by_vector_similarity( + 5, query.query, &db_client, - "knowledge_entity".to_string(), + "text_chunk".to_string(), &openai_client, ) .await?; - Ok(format!("{:?}", closest_items)) + let source_ids = closest_chunks + .iter() + .map(|chunk| chunk.source_id.clone()) + .collect::>(); + + let items_from_text_chunk_similarity: Vec = + find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), &db_client).await?; + + let entities: Vec = items_from_knowledge_entity_similarity + .into_iter() + .chain(items_from_text_chunk_similarity.into_iter()) + .fold(HashMap::new(), |mut map, entity| { + map.insert(entity.id.clone(), entity); + map + }) + .into_values() + .collect(); + + let entities_json = json!(entities + .iter() + .map(|entity| { + json!({ + "KnowledgeEntity": { + "id": entity.id, + "name": entity.name, + "description": entity.description + } + }) + }) + .collect::>()); + + info!("{} Entities\n{:#?}", entities.len(), entities_json); + + Ok("we got some stuff".to_string()) }