From 1c4b3284bfb8a5124385ef8fe857b74461490588 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Tue, 10 Dec 2024 17:38:06 +0100 Subject: [PATCH] refactor: add openai client and improve reference handling --- src/bin/server.rs | 1 + src/ingress/analysis/ingress_analyser.rs | 2 +- .../analysis/types/llm_analysis_result.rs | 2 +- src/ingress/content_processor.rs | 2 +- src/retrieval/mod.rs | 7 +- src/retrieval/vector.rs | 2 +- src/server/mod.rs | 1 + src/server/routes/query.rs | 45 ++------- src/server/routes/query/helper.rs | 96 ++++++++++++++++++- src/server/routes/search_result.rs | 47 ++------- src/utils/embedding.rs | 4 +- 11 files changed, 121 insertions(+), 88 deletions(-) diff --git a/src/bin/server.rs b/src/bin/server.rs index 57cd448..85cdc0f 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -42,6 +42,7 @@ async fn main() -> Result<(), Box> { rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config, false).await?), surreal_db_client: Arc::new(SurrealDbClient::new().await?), tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()), + openai_client: Arc::new(async_openai::Client::new()), }; // Create Axum router diff --git a/src/ingress/analysis/ingress_analyser.rs b/src/ingress/analysis/ingress_analyser.rs index 16f9810..26bcc2f 100644 --- a/src/ingress/analysis/ingress_analyser.rs +++ b/src/ingress/analysis/ingress_analyser.rs @@ -57,7 +57,7 @@ impl<'a> IngressAnalyzer<'a> { text, category, instructions ); - combined_knowledge_entity_retrieval(self.db_client, self.openai_client, input_text).await + combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await } fn prepare_llm_request( diff --git a/src/ingress/analysis/types/llm_analysis_result.rs b/src/ingress/analysis/types/llm_analysis_result.rs index a54f5da..69c7923 100644 --- a/src/ingress/analysis/types/llm_analysis_result.rs +++ b/src/ingress/analysis/types/llm_analysis_result.rs @@ -158,7 +158,7 @@ async fn create_single_entity( llm_entity.name, 
llm_entity.description, llm_entity.entity_type ); - let embedding = generate_embedding(openai_client, embedding_input).await?; + let embedding = generate_embedding(openai_client, &embedding_input).await?; Ok(KnowledgeEntity { id: assigned_id, diff --git a/src/ingress/content_processor.rs b/src/ingress/content_processor.rs index 0c3e308..03dec71 100644 --- a/src/ingress/content_processor.rs +++ b/src/ingress/content_processor.rs @@ -101,7 +101,7 @@ impl ContentProcessor { // Could potentially process chunks in parallel with a bounded concurrent limit for chunk in chunks { - let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?; + let embedding = generate_embedding(&self.openai_client, chunk).await?; let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding); store_item(&self.db_client, text_chunk).await?; } diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs index 4286aff..b4a9788 100644 --- a/src/retrieval/mod.rs +++ b/src/retrieval/mod.rs @@ -12,7 +12,6 @@ use crate::{ use futures::future::{try_join, try_join_all}; use std::collections::HashMap; use surrealdb::{engine::remote::ws::Client, Surreal}; -use tracing::info; /// Performs a comprehensive knowledge entity retrieval using multiple search strategies /// to find the most relevant entities for a given query. 
@@ -37,14 +36,14 @@ use tracing::info; pub async fn combined_knowledge_entity_retrieval( db_client: &Surreal, openai_client: &async_openai::Client, - query: String, + query: &str, ) -> Result, ProcessingError> { - info!("Received input: {:?}", query); + // info!("Received input: {:?}", query); let (items_from_knowledge_entity_similarity, closest_chunks) = try_join( find_items_by_vector_similarity( 10, - query.clone(), + query, db_client, "knowledge_entity".to_string(), openai_client, diff --git a/src/retrieval/vector.rs b/src/retrieval/vector.rs index 0adb63a..53b7dbf 100644 --- a/src/retrieval/vector.rs +++ b/src/retrieval/vector.rs @@ -24,7 +24,7 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding}; /// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`. pub async fn find_items_by_vector_similarity( take: u8, - input_text: String, + input_text: &str, db_client: &Surreal, table: String, openai_client: &async_openai::Client, diff --git a/src/server/mod.rs b/src/server/mod.rs index 9986842..a6f8748 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -12,4 +12,5 @@ pub struct AppState { pub rabbitmq_consumer: Arc, pub surreal_db_client: Arc, pub tera: Arc, + pub openai_client: Arc>, } diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs index 97b2356..282cfc0 100644 --- a/src/server/routes/query.rs +++ b/src/server/routes/query.rs @@ -1,13 +1,11 @@ pub mod helper; pub mod prompt; -use crate::{error::ApiError, retrieval::combined_knowledge_entity_retrieval, server::AppState}; +use crate::{error::ApiError, server::AppState}; use axum::{extract::State, response::IntoResponse, Json}; -use helper::{ - create_chat_request, create_user_message, format_entities_json, process_llm_response, -}; +use helper::get_answer_with_references; use serde::Deserialize; -use tracing::{debug, info}; +use tracing::info; #[derive(Debug, Deserialize)] pub struct QueryInput { @@ -32,42 +30,13 @@ pub async fn 
query_handler( Json(query): Json, ) -> Result { info!("Received input: {:?}", query); - let openai_client = async_openai::Client::new(); - // Retrieve entities - let entities = combined_knowledge_entity_retrieval( - &state.surreal_db_client, - &openai_client, - query.query.clone(), - ) - .await?; - - // Format entities and create message - let entities_json = format_entities_json(&entities); - let user_message = create_user_message(&entities_json, &query.query); - debug!("{:?}", user_message); - - // Create and send request - let request = create_chat_request(user_message)?; - let response = openai_client - .chat() - .create(request) - .await - .map_err(|e| ApiError::QueryError(e.to_string()))?; - - // Process response - let answer = process_llm_response(response).await?; - debug!("{:?}", answer); - - let references: Vec = answer - .references - .into_iter() - .map(|reference| reference.reference) - .collect(); - info!("{:?}", references); + let answer = + get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query) + .await?; Ok( - Json(serde_json::json!({"answer": answer.answer, "references": references})) + Json(serde_json::json!({"answer": answer.content, "references": answer.references})) .into_response(), ) } diff --git a/src/server/routes/query/helper.rs b/src/server/routes/query/helper.rs index 4fd9e8f..f281b5a 100644 --- a/src/server/routes/query/helper.rs +++ b/src/server/routes/query/helper.rs @@ -5,13 +5,107 @@ use async_openai::types::{ }; use serde_json::{json, Value}; -use crate::{error::ApiError, storage::types::knowledge_entity::KnowledgeEntity}; +use crate::{ + error::ApiError, + retrieval::combined_knowledge_entity_retrieval, + storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, +}; use super::{ prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT}, LLMResponseFormat, }; +// /// Orchestrator function that takes a query and clients and returns a answer with references +// /// +// /// # Arguments 
+// /// * `surreal_db_client` - Client for interacting with SurrealDB +// /// * `openai_client` - Client for interacting with openai +// /// * `query` - The query +// /// +// /// # Returns +// /// * `Result<(String, Vec<String>), ApiError>` - Will return the answer, and the list of references or Error +// pub async fn get_answer_with_references( +// surreal_db_client: &SurrealDbClient, +// openai_client: &async_openai::Client, +// query: &str, +// ) -> Result<(String, Vec<String>), ApiError> { +// let entities = +// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?; + +// // Format entities and create message +// let entities_json = format_entities_json(&entities); +// let user_message = create_user_message(&entities_json, query); + +// // Create and send request +// let request = create_chat_request(user_message)?; +// let response = openai_client +// .chat() +// .create(request) +// .await +// .map_err(|e| ApiError::QueryError(e.to_string()))?; + +// // Process response +// let answer = process_llm_response(response).await?; + +// let references: Vec<String> = answer +// .references +// .into_iter() +// .map(|reference| reference.reference) +// .collect(); + +// Ok((answer.answer, references)) +// } + +/// Orchestrates query processing and returns an answer with references +/// +/// Takes a query and uses the provided clients to generate an answer with supporting references. 
+/// +/// # Arguments +/// +/// * `surreal_db_client` - Client for SurrealDB interactions +/// * `openai_client` - Client for OpenAI API calls +/// * `query` - The user's query string +/// +/// # Returns +/// +/// Returns a tuple of the answer and its references, or an API error +#[derive(Debug)] +pub struct Answer { + pub content: String, + pub references: Vec, +} + +pub async fn get_answer_with_references( + surreal_db_client: &SurrealDbClient, + openai_client: &async_openai::Client, + query: &str, +) -> Result { + let entities = + combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?; + + let entities_json = format_entities_json(&entities); + let user_message = create_user_message(&entities_json, query); + + let request = create_chat_request(user_message)?; + let response = openai_client + .chat() + .create(request) + .await + .map_err(|e| ApiError::QueryError(e.to_string()))?; + + let llm_response = process_llm_response(response).await?; + + Ok(Answer { + content: llm_response.answer, + references: llm_response + .references + .into_iter() + .map(|r| r.reference) + .collect(), + }) +} + pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value { json!(entities .iter() diff --git a/src/server/routes/search_result.rs b/src/server/routes/search_result.rs index c563207..7809076 100644 --- a/src/server/routes/search_result.rs +++ b/src/server/routes/search_result.rs @@ -9,13 +9,7 @@ use tracing::info; use crate::{ error::ApiError, - retrieval::combined_knowledge_entity_retrieval, - server::{ - routes::query::helper::{ - create_chat_request, create_user_message, format_entities_json, process_llm_response, - }, - AppState, - }, + server::{routes::query::helper::get_answer_with_references, AppState}, }; #[derive(Deserialize)] pub struct SearchParams { @@ -28,43 +22,18 @@ pub async fn search_result_handler( ) -> Result, ApiError> { info!("Displaying search results"); - let openai_client = async_openai::Client::new(); - - // 
Retrieve entities - let entities = combined_knowledge_entity_retrieval( - &state.surreal_db_client, - &openai_client, - query.query.clone(), - ) - .await?; - - // Format entities and create message - let entities_json = format_entities_json(&entities); - let user_message = create_user_message(&entities_json, &query.query); - - // Create and send request - let request = create_chat_request(user_message)?; - let response = openai_client - .chat() - .create(request) - .await - .map_err(|e| ApiError::QueryError(e.to_string()))?; - - // Process response - let answer = process_llm_response(response).await?; - - let references: Vec = answer - .references - .into_iter() - .map(|reference| reference.reference) - .collect(); + let answer = + get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query) + .await?; let output = state .tera .render( "search_result.html", - &Context::from_value(json!({"result": answer.answer, "references": references})) - .unwrap(), + &Context::from_value( + json!({"result": answer.content, "references": answer.references}), + ) + .unwrap(), ) .unwrap(); diff --git a/src/utils/embedding.rs b/src/utils/embedding.rs index 9ef662a..ca7e486 100644 --- a/src/utils/embedding.rs +++ b/src/utils/embedding.rs @@ -27,11 +27,11 @@ use crate::error::ProcessingError; /// * If no embedding data is received in the response pub async fn generate_embedding( client: &async_openai::Client, - input: String, + input: &str, ) -> Result, ProcessingError> { let request = CreateEmbeddingRequestArgs::default() .model("text-embedding-3-small") - .input(&[input]) + .input([input]) .build()?; // Send the request to OpenAI