refactor: add openai client and improve reference handling

This commit is contained in:
Per Stark
2024-12-10 17:38:06 +01:00
parent 15edc8aa6b
commit 766653030d
11 changed files with 121 additions and 88 deletions

View File

@@ -12,4 +12,5 @@ pub struct AppState {
pub rabbitmq_consumer: Arc<RabbitMQConsumer>,
pub surreal_db_client: Arc<SurrealDbClient>,
pub tera: Arc<Tera>,
pub openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
}

View File

@@ -1,13 +1,11 @@
pub mod helper;
pub mod prompt;
use crate::{error::ApiError, retrieval::combined_knowledge_entity_retrieval, server::AppState};
use crate::{error::ApiError, server::AppState};
use axum::{extract::State, response::IntoResponse, Json};
use helper::{
create_chat_request, create_user_message, format_entities_json, process_llm_response,
};
use helper::get_answer_with_references;
use serde::Deserialize;
use tracing::{debug, info};
use tracing::info;
#[derive(Debug, Deserialize)]
pub struct QueryInput {
@@ -32,42 +30,13 @@ pub async fn query_handler(
Json(query): Json<QueryInput>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", query);
let openai_client = async_openai::Client::new();
// Retrieve entities
let entities = combined_knowledge_entity_retrieval(
&state.surreal_db_client,
&openai_client,
query.query.clone(),
)
.await?;
// Format entities and create message
let entities_json = format_entities_json(&entities);
let user_message = create_user_message(&entities_json, &query.query);
debug!("{:?}", user_message);
// Create and send request
let request = create_chat_request(user_message)?;
let response = openai_client
.chat()
.create(request)
.await
.map_err(|e| ApiError::QueryError(e.to_string()))?;
// Process response
let answer = process_llm_response(response).await?;
debug!("{:?}", answer);
let references: Vec<String> = answer
.references
.into_iter()
.map(|reference| reference.reference)
.collect();
info!("{:?}", references);
let answer =
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
.await?;
Ok(
Json(serde_json::json!({"answer": answer.answer, "references": references}))
Json(serde_json::json!({"answer": answer.content, "references": answer.references}))
.into_response(),
)
}

View File

@@ -5,13 +5,107 @@ use async_openai::types::{
};
use serde_json::{json, Value};
use crate::{error::ApiError, storage::types::knowledge_entity::KnowledgeEntity};
use crate::{
error::ApiError,
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::{
prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT},
LLMResponseFormat,
};
// /// Orchestrator function that takes a query and clients and returns an answer with references
// ///
// /// # Arguments
// /// * `surreal_db_client` - Client for interacting with SurrealDB
// /// * `openai_client` - Client for interacting with openai
// /// * `query` - The query
// ///
// /// # Returns
// /// * `Result<(String, Vec<String>), ApiError>` - Will return the answer, and the list of references or Error
// pub async fn get_answer_with_references(
// surreal_db_client: &SurrealDbClient,
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
// query: &str,
// ) -> Result<(String, Vec<String>), ApiError> {
// let entities =
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
// // Format entities and create message
// let entities_json = format_entities_json(&entities);
// let user_message = create_user_message(&entities_json, query);
// // Create and send request
// let request = create_chat_request(user_message)?;
// let response = openai_client
// .chat()
// .create(request)
// .await
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
// // Process response
// let answer = process_llm_response(response).await?;
// let references: Vec<String> = answer
// .references
// .into_iter()
// .map(|reference| reference.reference)
// .collect();
// Ok((answer.answer, references))
// }
/// Orchestrates query processing and returns an answer with references
///
/// Takes a query and uses the provided clients to generate an answer with supporting references.
///
/// # Arguments
///
/// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string
///
/// # Returns
///
/// Returns an `Answer` holding the answer text and its supporting references, or an API error
/// Structured result of a query: the generated answer text plus the
/// references that support it.
#[derive(Debug)]
pub struct Answer {
// The answer text produced by the LLM.
pub content: String,
// References extracted from the LLM response that back the answer.
pub references: Vec<String>,
}
/// Produces an [`Answer`] with supporting references for `query`.
///
/// Retrieves matching knowledge entities from SurrealDB, packages them with
/// the query into a chat request, sends it to the OpenAI API, and parses the
/// model output into an [`Answer`].
///
/// # Arguments
///
/// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string
///
/// # Errors
///
/// Returns `ApiError::QueryError` when the OpenAI request fails; errors from
/// entity retrieval, request construction, and response processing propagate.
pub async fn get_answer_with_references(
surreal_db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
) -> Result<Answer, ApiError> {
// Gather context entities relevant to the query.
let retrieved =
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?;

// Build the chat request from the formatted entities and the raw query.
let message = create_user_message(&format_entities_json(&retrieved), query);
let chat_request = create_chat_request(message)?;

let chat_response = openai_client
.chat()
.create(chat_request)
.await
.map_err(|err| ApiError::QueryError(err.to_string()))?;

// Parse the structured LLM output and flatten its references.
let parsed = process_llm_response(chat_response).await?;
let references: Vec<String> = parsed.references.into_iter().map(|r| r.reference).collect();

Ok(Answer {
content: parsed.answer,
references,
})
}
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
json!(entities
.iter()

View File

@@ -9,13 +9,7 @@ use tracing::info;
use crate::{
error::ApiError,
retrieval::combined_knowledge_entity_retrieval,
server::{
routes::query::helper::{
create_chat_request, create_user_message, format_entities_json, process_llm_response,
},
AppState,
},
server::{routes::query::helper::get_answer_with_references, AppState},
};
#[derive(Deserialize)]
pub struct SearchParams {
@@ -28,43 +22,18 @@ pub async fn search_result_handler(
) -> Result<Html<String>, ApiError> {
info!("Displaying search results");
let openai_client = async_openai::Client::new();
// Retrieve entities
let entities = combined_knowledge_entity_retrieval(
&state.surreal_db_client,
&openai_client,
query.query.clone(),
)
.await?;
// Format entities and create message
let entities_json = format_entities_json(&entities);
let user_message = create_user_message(&entities_json, &query.query);
// Create and send request
let request = create_chat_request(user_message)?;
let response = openai_client
.chat()
.create(request)
.await
.map_err(|e| ApiError::QueryError(e.to_string()))?;
// Process response
let answer = process_llm_response(response).await?;
let references: Vec<String> = answer
.references
.into_iter()
.map(|reference| reference.reference)
.collect();
let answer =
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
.await?;
let output = state
.tera
.render(
"search_result.html",
&Context::from_value(json!({"result": answer.answer, "references": references}))
.unwrap(),
&Context::from_value(
json!({"result": answer.content, "references": answer.references}),
)
.unwrap(),
)
.unwrap();