Files
minne/src/retrieval/query_helper.rs
2024-12-19 23:15:12 +01:00

194 lines
5.9 KiB
Rust

use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
};
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::debug;
use crate::{
error::ApiError,
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
/// A single reference entry inside the LLM's structured answer.
///
/// Instances are produced when `process_llm_response` deserializes the chat
/// completion content into [`LLMResponseFormat`].
#[derive(Debug, Deserialize)]
pub struct Reference {
/// The reference value exactly as returned by the LLM.
// NOTE(review): presumably the id of a knowledge entity supplied as context —
// confirm against the response schema. This field is read by
// `get_answer_with_references`, so the `dead_code` allowance below may be stale.
#[allow(dead_code)]
pub reference: String,
}
/// Structured answer expected back from the LLM.
///
/// `process_llm_response` parses the JSON content of the chat completion into
/// this type via `serde_json`.
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
/// The answer text.
pub answer: String,
/// The references that accompany the answer.
// NOTE(review): this field is consumed by `get_answer_with_references`, so the
// `dead_code` allowance below may be stale — confirm before removing it.
#[allow(dead_code)]
pub references: Vec<Reference>,
}
// /// Orchestrator function that takes a query and clients and returns an answer with references
// ///
// /// # Arguments
// /// * `surreal_db_client` - Client for interacting with SurrealDB
// /// * `openai_client` - Client for interacting with OpenAI
// /// * `query` - The query
// ///
// /// # Returns
// /// * `Result<(String, Vec<String>), ApiError>` - Will return the answer and the list of references, or an error
// pub async fn get_answer_with_references(
// surreal_db_client: &SurrealDbClient,
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
// query: &str,
// ) -> Result<(String, Vec<String>), ApiError> {
// let entities =
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
// // Format entities and create message
// let entities_json = format_entities_json(&entities);
// let user_message = create_user_message(&entities_json, query);
// // Create and send request
// let request = create_chat_request(user_message)?;
// let response = openai_client
// .chat()
// .create(request)
// .await
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
// // Process response
// let answer = process_llm_response(response).await?;
// let references: Vec<String> = answer
// .references
// .into_iter()
// .map(|reference| reference.reference)
// .collect();
// Ok((answer.answer, references))
// }
/// The final answer to a user query together with its supporting references.
///
/// Returned by `get_answer_with_references`, which fills it from the parsed
/// LLM response.
#[derive(Debug)]
pub struct Answer {
/// The answer text (taken from the `answer` field of the LLM response).
pub content: String,
/// The reference strings extracted from the LLM response's `references` list.
pub references: Vec<String>,
}
pub async fn get_answer_with_references(
surreal_db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,
) -> Result<Answer, ApiError> {
let entities =
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
.await?;
let entities_json = format_entities_json(&entities);
debug!("{:?}", entities_json);
let user_message = create_user_message(&entities_json, query);
let request = create_chat_request(user_message)?;
let response = openai_client
.chat()
.create(request)
.await
.map_err(|e| ApiError::QueryError(e.to_string()))?;
let llm_response = process_llm_response(response).await?;
Ok(Answer {
content: llm_response.answer,
references: llm_response
.references
.into_iter()
.map(|r| r.reference)
.collect(),
})
}
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
json!(entities
.iter()
.map(|entity| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
}
})
})
.collect::<Vec<_>>())
}
/// Renders the user message sent to the LLM.
///
/// The message has two sections, each introduced by a title line and a rule
/// of `=` characters: the context (`entities_json`, rendered via its `Display`
/// implementation) followed by the user's question (`query`).
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
    // The template is split per line; the concatenation must stay byte-for-byte
    // identical to the prompt layout, including the leading and trailing newline.
    format!(
        concat!(
            "\n",
            "Context Information:\n",
            "==================\n",
            "{context}\n",
            "User Question:\n",
            "==================\n",
            "{question}\n"
        ),
        context = entities_json,
        question = query
    )
}
/// Builds the chat completion request for answering a query.
///
/// The request targets the `gpt-4o-mini` model with temperature `0.2` and a
/// `3048` token limit, sends `QUERY_SYSTEM_PROMPT` as the system message and
/// `user_message` as the user message, and requests a strict JSON-schema
/// response using the schema from `get_query_response_schema`.
///
/// # Errors
///
/// Returns `ApiError::QueryError` if the request builder fails.
pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionRequest, ApiError> {
    // Structured-output contract the model's reply must satisfy.
    let json_schema = ResponseFormatJsonSchema {
        description: Some("Query answering AI".into()),
        name: "query_answering_with_uuids".into(),
        schema: Some(get_query_response_schema()),
        strict: Some(true),
    };

    let system_prompt = ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT);
    let user_prompt = ChatCompletionRequestUserMessage::from(user_message);

    CreateChatCompletionRequestArgs::default()
        .model("gpt-4o-mini")
        .temperature(0.2)
        .max_tokens(3048u32)
        .messages([system_prompt.into(), user_prompt.into()])
        .response_format(ResponseFormat::JsonSchema { json_schema })
        .build()
        .map_err(|e| ApiError::QueryError(e.to_string()))
}
/// Extracts and parses the structured answer from a chat completion response.
///
/// Only the first choice is inspected; its message content is parsed as JSON
/// into [`LLMResponseFormat`].
///
/// # Errors
///
/// Returns `ApiError::QueryError` when the response has no choices, when the
/// first choice has no content, or when the content does not deserialize into
/// [`LLMResponseFormat`].
pub async fn process_llm_response(
    response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, ApiError> {
    let first_content = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref());

    let content = match first_content {
        Some(text) => text,
        None => {
            return Err(ApiError::QueryError(
                "No content found in LLM response".into(),
            ))
        }
    };

    match serde_json::from_str::<LLMResponseFormat>(content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(ApiError::QueryError(format!(
            "Failed to parse LLM response into analysis: {}",
            e
        ))),
    }
}