From 6efa8cd4ee5ddde08c7172e932325a4021158c18 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Mon, 25 Nov 2024 20:39:53 +0100 Subject: [PATCH] wip query --- src/error.rs | 3 + src/ingress/analysis/ingress_analyser.rs | 15 +-- src/ingress/content_processor.rs | 21 +--- src/retrieval/graph.rs | 16 ++- src/retrieval/mod.rs | 92 ++++++++++++++ src/server/routes/query.rs | 130 +++++++++++--------- src/storage/types/knowledge_relationship.rs | 4 +- 7 files changed, 179 insertions(+), 102 deletions(-) diff --git a/src/error.rs b/src/error.rs index 77c4bae..ae06928 100644 --- a/src/error.rs +++ b/src/error.rs @@ -53,6 +53,8 @@ pub enum ApiError { QueryError(String), #[error("RabbitMQ error: {0}")] RabbitMQError(#[from] RabbitMQError), + #[error("LLM processing error: {0}")] + OpenAIerror(#[from] OpenAIError), } impl IntoResponse for ApiError { @@ -61,6 +63,7 @@ impl IntoResponse for ApiError { ApiError::ProcessingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::PublishingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::IngressContentError(_) => { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) diff --git a/src/ingress/analysis/ingress_analyser.rs b/src/ingress/analysis/ingress_analyser.rs index aae0bde..16f9810 100644 --- a/src/ingress/analysis/ingress_analyser.rs +++ b/src/ingress/analysis/ingress_analyser.rs @@ -1,8 +1,8 @@ use crate::{ error::ProcessingError, ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, - retrieval::vector::find_items_by_vector_similarity, - storage::types::{knowledge_entity::KnowledgeEntity, StoredObject}, + retrieval::combined_knowledge_entity_retrieval, + storage::types::knowledge_entity::KnowledgeEntity, }; 
use async_openai::types::{ ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, @@ -57,14 +57,7 @@ impl<'a> IngressAnalyzer<'a> { text, category, instructions ); - find_items_by_vector_similarity( - 10, - input_text, - self.db_client, - KnowledgeEntity::table_name().to_string(), - self.openai_client, - ) - .await + combined_knowledge_entity_retrieval(self.db_client, self.openai_client, input_text).await } fn prepare_llm_request( @@ -106,7 +99,7 @@ impl<'a> IngressAnalyzer<'a> { CreateChatCompletionRequestArgs::default() .model("gpt-4o-mini") .temperature(0.2) - .max_tokens(2048u32) + .max_tokens(3048u32) .messages([ ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(), ChatCompletionRequestUserMessage::from(user_message).into(), diff --git a/src/ingress/content_processor.rs b/src/ingress/content_processor.rs index 0b03300..4a236ec 100644 --- a/src/ingress/content_processor.rs +++ b/src/ingress/content_processor.rs @@ -5,7 +5,6 @@ use tracing::{debug, info}; use crate::{ error::ProcessingError, - retrieval::vector::find_items_by_vector_similarity, storage::{ db::{store_item, SurrealDbClient}, types::{ @@ -39,10 +38,8 @@ impl ContentProcessor { let now = Instant::now(); // Process in parallel where possible - let (analysis, _similar_chunks) = tokio::try_join!( - self.perform_semantic_analysis(content), - self.find_similar_content(content), - )?; + let analysis = self.perform_semantic_analysis(content).await?; + let end = now.elapsed(); info!( "{:?} time elapsed during creation of entities and relationships", @@ -74,20 +71,6 @@ impl ContentProcessor { .await } - async fn find_similar_content( - &self, - content: &TextContent, - ) -> Result, ProcessingError> { - find_items_by_vector_similarity( - 3, - content.text.clone(), - &self.db_client, - "text_chunk".to_string(), - &self.openai_client, - ) - .await - } - async fn store_graph_entities( &self, entities: Vec, diff --git a/src/retrieval/graph.rs b/src/retrieval/graph.rs 
index d99587d..27aa961 100644 --- a/src/retrieval/graph.rs +++ b/src/retrieval/graph.rs @@ -1,5 +1,5 @@ use surrealdb::{engine::remote::ws::Client, Surreal}; -use tracing::info; +use tracing::debug; use crate::{error::ProcessingError, storage::types::knowledge_entity::KnowledgeEntity}; @@ -71,23 +71,21 @@ pub async fn find_entities_by_relationship_by_source_ids( db_client: &Surreal, source_ids: &[String], ) -> Result, ProcessingError> { - // Create a comma-separated list of IDs wrapped in backticks let ids = source_ids .iter() + // .map(|id| format!("`{}`", id)) .map(|id| format!("knowledge_entity:`{}`", id)) .collect::>() .join(", "); - info!("{:?}", ids); - - // let first = format!("knowledge_entity:`{}`", source_ids.first().unwrap()); + debug!("{:?}", ids); let query = format!( - "SELECT *, array::complement(<->relates_to<->knowledge_entity, [id]) AS related FROM [{}] FETCH related", + "SELECT *, <-> relates_to <-> knowledge_entity AS related FROM [{}]", ids ); - info!("{}", query); + debug!("{}", query); let result: Vec = db_client.query(query).await?.take(0)?; @@ -95,14 +93,14 @@ pub async fn find_entities_by_relationship_by_source_ids( } pub async fn find_entities_by_relationship_by_id( db_client: &Surreal, - source_id: &str, + source_id: String, ) -> Result, ProcessingError> { let query = format!( "SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`", source_id ); - info!("{}", query); + debug!("{}", query); let result: Vec = db_client.query(query).await?.take(0)?; diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs index a9e0919..4286aff 100644 --- a/src/retrieval/mod.rs +++ b/src/retrieval/mod.rs @@ -1,2 +1,94 @@ pub mod graph; pub mod vector; + +use crate::{ + error::ProcessingError, + retrieval::{ + graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, + vector::find_items_by_vector_similarity, + }, + storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, +}; +use 
futures::future::{try_join, try_join_all}; +use std::collections::HashMap; +use surrealdb::{engine::remote::ws::Client, Surreal}; +use tracing::info; + +/// Performs a comprehensive knowledge entity retrieval using multiple search strategies +/// to find the most relevant entities for a given query. +/// +/// # Strategy +/// The function employs a three-pronged approach to knowledge retrieval: +/// 1. Direct vector similarity search on knowledge entities +/// 2. Text chunk similarity search with source entity lookup +/// 3. Graph relationship traversal from related entities +/// +/// This combined approach ensures both semantic similarity matches and structurally +/// related content are included in the results. +/// +/// # Arguments +/// * `db_client` - SurrealDB client for database operations +/// * `openai_client` - OpenAI client for vector embeddings generation +/// * `query` - The search query string to find relevant knowledge entities +/// +/// # Returns +/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant +/// knowledge entities, or an error if the retrieval process fails +pub async fn combined_knowledge_entity_retrieval( + db_client: &Surreal, + openai_client: &async_openai::Client, + query: String, +) -> Result, ProcessingError> { + info!("Received input: {:?}", query); + + let (items_from_knowledge_entity_similarity, closest_chunks) = try_join( + find_items_by_vector_similarity( + 10, + query.clone(), + db_client, + "knowledge_entity".to_string(), + openai_client, + ), + find_items_by_vector_similarity( + 5, + query, + db_client, + "text_chunk".to_string(), + openai_client, + ), + ) + .await?; + + let source_ids = closest_chunks + .iter() + .map(|chunk: &TextChunk| chunk.source_id.clone()) + .collect::>(); + + let items_from_text_chunk_similarity: Vec = + find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?; + + let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity + .clone() 
+ .into_iter() + .map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone())) + .collect(); + + let items_from_relationships = try_join_all(items_from_relationships_futures) + .await? + .into_iter() + .flatten() + .collect::>(); + + let entities: Vec = items_from_knowledge_entity_similarity + .into_iter() + .chain(items_from_text_chunk_similarity.into_iter()) + .chain(items_from_relationships.into_iter()) + .fold(HashMap::new(), |mut map, entity| { + map.insert(entity.id.clone(), entity); + map + }) + .into_values() + .collect(); + + Ok(entities) +} diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs index 67e3fc8..4962a0c 100644 --- a/src/server/routes/query.rs +++ b/src/server/routes/query.rs @@ -1,21 +1,13 @@ use crate::{ - error::ApiError, - retrieval::{ - graph::{ - find_entities_by_relationship_by_id, find_entities_by_relationship_by_source_ids, - find_entities_by_source_ids, - }, - vector::find_items_by_vector_similarity, - }, - storage::{ - db::SurrealDbClient, - types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, - }, + error::ApiError, retrieval::combined_knowledge_entity_retrieval, storage::db::SurrealDbClient, +}; +use async_openai::types::{ + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, + CreateChatCompletionRequestArgs, }; use axum::{response::IntoResponse, Extension, Json}; use serde::Deserialize; use serde_json::json; -use std::collections::HashMap; use std::sync::Arc; use tracing::info; @@ -31,49 +23,9 @@ pub async fn query_handler( info!("Received input: {:?}", query); let openai_client = async_openai::Client::new(); - let test = find_entities_by_relationship_by_id(&db_client, &query.query).await?; - info!("{:?}", test); - - let items_from_knowledge_entity_similarity: Vec = - find_items_by_vector_similarity( - 10, - query.query.to_string(), - &db_client, - "knowledge_entity".to_string(), - &openai_client, - ) - .await?; - - let closest_chunks: Vec = 
find_items_by_vector_similarity( - 5, - query.query, - &db_client, - "text_chunk".to_string(), - &openai_client, - ) - .await?; - - let source_ids = closest_chunks - .iter() - .map(|chunk| chunk.source_id.clone()) - .collect::>(); - - let items_from_text_chunk_similarity: Vec = find_entities_by_source_ids( - source_ids.clone(), - "knowledge_entity".to_string(), - &db_client, - ) - .await?; - - let entities: Vec = items_from_knowledge_entity_similarity - .into_iter() - .chain(items_from_text_chunk_similarity.into_iter()) - .fold(HashMap::new(), |mut map, entity| { - map.insert(entity.id.clone(), entity); - map - }) - .into_values() - .collect(); + let entities = + combined_knowledge_entity_retrieval(&db_client, &openai_client, query.query.clone()) + .await?; let entities_json = json!(entities .iter() @@ -88,12 +40,68 @@ pub async fn query_handler( }) .collect::>()); - let graph_retrieval = - find_entities_by_relationship_by_source_ids(&db_client, &source_ids).await?; + let system_message = r#" + You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information. - info!("{:?}", graph_retrieval); + Your task is to: + 1. Carefully analyze the provided knowledge entities in the context + 2. Answer user questions based on this information + 3. Provide clear, concise, and accurate responses + 4. When referencing information, briefly mention which knowledge entity it came from + 5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this + 6. If only partial information is available, explain what you can answer and what information is missing + 7. 
Avoid making assumptions or providing information not supported by the context - // info!("{} Entities\n{:#?}", entities.len(), entities_json); + Remember: + - Be direct and honest about the limitations of your knowledge + - Cite the relevant knowledge entities when providing information + - If you need to combine information from multiple entities, explain how they connect + - Don't speculate beyond what's provided in the context - Ok("we got some stuff".to_string()) + Example response formats: + "Based on [Entity Name], [answer...]" + "I found relevant information in multiple entries: [explanation...]" + "I apologize, but the provided context doesn't contain information about [topic]" + "#; + + let user_message = format!( + r#" + Context Information: + ================== + {} + + User Question: + ================== + {} + "#, + entities_json, query.query + ); + + info!("{:?}", user_message); + + let request = CreateChatCompletionRequestArgs::default() + .model("gpt-4o-mini") + .temperature(0.2) + .max_tokens(3048u32) + .messages([ + ChatCompletionRequestSystemMessage::from(system_message).into(), + ChatCompletionRequestUserMessage::from(user_message).into(), + ]) + .build()?; + + let response = openai_client.chat().create(request).await?; + + let answer = response + .choices + .first() + .and_then(|choice| choice.message.content.as_ref()) + .ok_or(ApiError::QueryError( + "No content found in LLM response".to_string(), + ))?; + + info!("{:?}", answer); + + // info!("{:#?}", entities_json); + + Ok(answer.clone().into_response()) } diff --git a/src/storage/types/knowledge_relationship.rs b/src/storage/types/knowledge_relationship.rs index 8e8c985..aa7b6b1 100644 --- a/src/storage/types/knowledge_relationship.rs +++ b/src/storage/types/knowledge_relationship.rs @@ -1,6 +1,6 @@ use crate::{error::ProcessingError, stored_object}; use surrealdb::{engine::remote::ws::Client, Surreal}; -use tracing::info; +use tracing::debug; use uuid::Uuid; 
stored_object!(KnowledgeRelationship, "knowledge_relationship", { @@ -37,7 +37,7 @@ impl KnowledgeRelationship { let result = db_client.query(query).await?; - info!("{:?}", result); + debug!("{:?}", result); Ok(()) }