breaking up query fn

This commit is contained in:
Per Stark
2024-11-27 12:38:39 +01:00
parent 8c6eae1f13
commit 2dd78189b3
2 changed files with 27 additions and 15 deletions

View File

@@ -58,7 +58,17 @@ pub async fn query_handler(
// Process response // Process response
let answer = process_llm_response(response).await?; let answer = process_llm_response(response).await?;
info!("{:?}", answer); debug!("{:?}", answer);
Ok(answer.answer.into_response()) let references: Vec<String> = answer
.references
.into_iter()
.map(|reference| reference.reference)
.collect();
info!("{:?}", references);
Ok(
Json(serde_json::json!({"answer": answer.answer, "references": references}))
.into_response(),
)
} }

View File

@@ -1,4 +1,9 @@
use serde_json::json; use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
};
use serde_json::{json, Value};
use crate::{error::ApiError, storage::types::knowledge_entity::KnowledgeEntity}; use crate::{error::ApiError, storage::types::knowledge_entity::KnowledgeEntity};
@@ -7,7 +12,7 @@ use super::{
LLMResponseFormat, LLMResponseFormat,
}; };
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> serde_json::Value { pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
json!(entities json!(entities
.iter() .iter()
.map(|entity| { .map(|entity| {
@@ -22,7 +27,7 @@ pub fn format_entities_json(entities: &[KnowledgeEntity]) -> serde_json::Value {
.collect::<Vec<_>>()) .collect::<Vec<_>>())
} }
pub fn create_user_message(entities_json: &serde_json::Value, query: &str) -> String { pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!( format!(
r#" r#"
Context Information: Context Information:
@@ -37,11 +42,9 @@ pub fn create_user_message(entities_json: &serde_json::Value, query: &str) -> St
) )
} }
pub fn create_chat_request( pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionRequest, ApiError> {
user_message: String, let response_format = ResponseFormat::JsonSchema {
) -> Result<async_openai::types::CreateChatCompletionRequest, ApiError> { json_schema: ResponseFormatJsonSchema {
let response_format = async_openai::types::ResponseFormat::JsonSchema {
json_schema: async_openai::types::ResponseFormatJsonSchema {
description: Some("Query answering AI".into()), description: Some("Query answering AI".into()),
name: "query_answering_with_uuids".into(), name: "query_answering_with_uuids".into(),
schema: Some(get_query_response_schema()), schema: Some(get_query_response_schema()),
@@ -49,14 +52,13 @@ pub fn create_chat_request(
}, },
}; };
async_openai::types::CreateChatCompletionRequestArgs::default() CreateChatCompletionRequestArgs::default()
.model("gpt-4o-mini") .model("gpt-4o-mini")
.temperature(0.2) .temperature(0.2)
.max_tokens(3048u32) .max_tokens(3048u32)
.messages([ .messages([
async_openai::types::ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT) ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT).into(),
.into(), ChatCompletionRequestUserMessage::from(user_message).into(),
async_openai::types::ChatCompletionRequestUserMessage::from(user_message).into(),
]) ])
.response_format(response_format) .response_format(response_format)
.build() .build()
@@ -64,7 +66,7 @@ pub fn create_chat_request(
} }
pub async fn process_llm_response( pub async fn process_llm_response(
response: async_openai::types::CreateChatCompletionResponse, response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, ApiError> { ) -> Result<LLMResponseFormat, ApiError> {
response response
.choices .choices