refactoring: new structure and mailer

This commit is contained in:
Per Stark
2024-12-19 23:15:12 +01:00
parent 0a5e37130d
commit ff1c8836ee
25 changed files with 659 additions and 379 deletions

View File

@@ -0,0 +1,4 @@
pub mod file;
pub mod ingress;
pub mod query;
pub mod queue_length;

View File

@@ -1,9 +1,8 @@
pub mod helper;
pub mod prompt;
use crate::{error::ApiError, server::AppState, storage::types::user::User};
use crate::{
error::ApiError, retrieval::query_helper::get_answer_with_references, server::AppState,
storage::types::user::User,
};
use axum::{extract::State, response::IntoResponse, Extension, Json};
use helper::get_answer_with_references;
use serde::Deserialize;
use tracing::info;
@@ -12,19 +11,6 @@ pub struct QueryInput {
query: String,
}
/// One cited reference in the LLM's structured response.
#[derive(Debug, Deserialize)]
pub struct Reference {
    /// Knowledge-entity id cited by the model.
    // Only ever populated via serde deserialization, hence the dead_code allow.
    #[allow(dead_code)]
    pub reference: String,
}
/// Shape of the JSON the LLM is constrained to return: an answer string
/// plus the references it cited.
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
    /// The natural-language answer text.
    pub answer: String,
    /// Cited references; only read via serde, hence the dead_code allow.
    #[allow(dead_code)]
    pub references: Vec<Reference>,
}
pub async fn query_handler(
State(state): State<AppState>,
Extension(user): Extension<User>,

View File

@@ -1,7 +1,6 @@
use axum::{
extract::State,
http::Response,
response::{Html, IntoResponse},
response::{IntoResponse, Redirect},
Form,
};
use axum_htmx::HxBoosted;
@@ -27,19 +26,23 @@ struct PageData {
pub async fn show_signup_form(
State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
HxBoosted(boosted): HxBoosted,
) -> Result<Html<String>, ApiError> {
) -> Result<impl IntoResponse, ApiError> {
if auth.is_authenticated() {
return Ok(Redirect::to("/").into_response());
}
let output = match boosted {
true => render_block(
"auth/signup_form.html",
"content",
"body",
PageData {},
state.templates,
)?,
false => render_template("auth/signup_form.html", PageData {}, state.templates)?,
};
Ok(output)
Ok(output.into_response())
}
pub async fn signup_handler(

View File

@@ -1,24 +1,19 @@
use axum::{extract::State, response::Html};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use minijinja::context;
use serde::Serialize;
use serde_json::json;
use surrealdb::{engine::any::Any, sql::Relation, Surreal};
use tera::Context;
// use tera::Context;
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use crate::{
error::ApiError,
server::{routes::render_template, AppState},
page_data,
server::{routes::html::render_template, AppState},
storage::types::user::User,
};
/// Template context for the index page.
#[derive(Serialize)]
struct PageData<'a> {
    // Rendered verbatim in the template; currently fed a hard-coded string.
    queue_length: &'a str,
}
page_data!(IndexData, {
queue_length: u32,
});
pub async fn index_handler(
State(state): State<AppState>,
@@ -30,13 +25,7 @@ pub async fn index_handler(
let queue_length = state.rabbitmq_consumer.get_queue_length().await?;
let output = render_template(
"index.html",
PageData {
queue_length: "1000",
},
state.templates,
)?;
let output = render_template("index.html", IndexData { queue_length }, state.templates)?;
Ok(output)
}

View File

@@ -0,0 +1,55 @@
use std::sync::Arc;
use axum::response::Html;
use minijinja_autoreload::AutoReloader;
pub mod auth;
pub mod index;
pub mod search_result;
pub fn render_template<T>(
template_name: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
T: serde::Serialize,
{
let env = templates.acquire_env()?;
let tmpl = env.get_template(template_name)?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl.render(context)?;
Ok(output.into())
}
pub fn render_block<T>(
template_name: &str,
block: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
T: serde::Serialize,
{
let env = templates.acquire_env()?;
let tmpl = env.get_template(template_name)?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl.eval_to_state(context)?.render_block(block)?;
Ok(output.into())
}
/// Declares a serializable page-data struct for template rendering.
///
/// Expands to a `pub struct $name` with all fields `pub`, deriving
/// `Debug`, `serde::Serialize` and `serde::Deserialize`.
///
/// The derives use fully-qualified `serde` paths so the macro can be
/// invoked several times in one module: the previous expansion emitted
/// `use serde::{Serialize, Deserialize};` at the call site, which caused
/// a duplicate-import error (E0252) on the second invocation and could
/// shadow the caller's own imports.
#[macro_export]
macro_rules! page_data {
    ($name:ident, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => {
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        pub struct $name {
            $($(#[$attr])* pub $field: $ty),*
        }
    };
}

View File

@@ -5,14 +5,11 @@ use axum::{
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::Deserialize;
use serde_json::json;
use surrealdb::{engine::any::Any, Surreal};
use tera::Context;
use tracing::info;
use crate::{
error::ApiError,
server::{routes::query::helper::get_answer_with_references, AppState},
error::ApiError, retrieval::query_helper::get_answer_with_references, server::AppState,
storage::types::user::User,
};
#[derive(Deserialize)]

View File

@@ -1,47 +1,2 @@
use std::sync::Arc;
use axum::response::Html;
use minijinja_autoreload::AutoReloader;
pub mod auth;
pub mod file;
pub mod index;
pub mod ingress;
pub mod query;
pub mod queue_length;
pub mod search_result;
/// Renders the named template with `context` and wraps the output in `Html`.
pub fn render_template<T>(
    template_name: &str,
    context: T,
    templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
    T: serde::Serialize,
{
    let env = templates.acquire_env()?;
    let ctx = minijinja::Value::from_serialize(&context);
    let body = env.get_template(template_name)?.render(ctx)?;
    Ok(body.into())
}
/// Renders only the named block of a template, for partial page updates.
pub fn render_block<T>(
    template_name: &str,
    block: &str,
    context: T,
    templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
    T: serde::Serialize,
{
    let env = templates.acquire_env()?;
    let ctx = minijinja::Value::from_serialize(&context);
    // Evaluate to a template state first so the single block can be rendered.
    let body = env
        .get_template(template_name)?
        .eval_to_state(ctx)?
        .render_block(block)?;
    Ok(body.into())
}
pub mod api;
pub mod html;

View File

@@ -1,182 +0,0 @@
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
};
use serde_json::{json, Value};
use tracing::debug;
use crate::{
error::ApiError,
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::{
prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT},
LLMResponseFormat,
};
// /// Orchestrator function that takes a query and clients and returns a answer with references
// ///
// /// # Arguments
// /// * `surreal_db_client` - Client for interacting with SurrealDB
// /// * `openai_client` - Client for interacting with openai
// /// * `query` - The query
// ///
// /// # Returns
// /// * `Result<(String, Vec<String>), ApiError>` - Will return the answer and the list of references, or an Error
// pub async fn get_answer_with_references(
// surreal_db_client: &SurrealDbClient,
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
// query: &str,
// ) -> Result<(String, Vec<String>), ApiError> {
// let entities =
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
// // Format entities and create message
// let entities_json = format_entities_json(&entities);
// let user_message = create_user_message(&entities_json, query);
// // Create and send request
// let request = create_chat_request(user_message)?;
// let response = openai_client
// .chat()
// .create(request)
// .await
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
// // Process response
// let answer = process_llm_response(response).await?;
// let references: Vec<String> = answer
// .references
// .into_iter()
// .map(|reference| reference.reference)
// .collect();
// Ok((answer.answer, references))
// }
/// An answer produced by the LLM together with the references it cited.
#[derive(Debug)]
pub struct Answer {
    /// The natural-language answer text.
    pub content: String,
    /// Knowledge-entity ids cited by the answer (UUIDs per the system prompt).
    pub references: Vec<String>,
}

/// Orchestrates query processing and returns an answer with references.
///
/// Takes a query and uses the provided clients to generate an answer with
/// supporting references.
///
/// # Arguments
///
/// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string
/// * `user_id` - The user's id
///
/// # Returns
///
/// Returns the [`Answer`] (content plus references), or an `ApiError`
// NOTE(review): this doc comment previously sat above `struct Answer`, so
// rustdoc attached the function's `# Arguments` section to the struct.
pub async fn get_answer_with_references(
    surreal_db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    query: &str,
    user_id: &str,
) -> Result<Answer, ApiError> {
    // Retrieve candidate knowledge entities scoped to this user.
    let entities =
        combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
            .await?;
    let entities_json = format_entities_json(&entities);
    debug!("{:?}", entities_json);

    // Build the chat request: entity context plus the raw user question.
    let user_message = create_user_message(&entities_json, query);
    let request = create_chat_request(user_message)?;
    let response = openai_client
        .chat()
        .create(request)
        .await
        .map_err(|e| ApiError::QueryError(e.to_string()))?;

    // Parse the schema-constrained LLM output into the public Answer shape.
    let llm_response = process_llm_response(response).await?;
    Ok(Answer {
        content: llm_response.answer,
        references: llm_response
            .references
            .into_iter()
            .map(|r| r.reference)
            .collect(),
    })
}
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
json!(entities
.iter()
.map(|entity| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
}
})
})
.collect::<Vec<_>>())
}
/// Builds the user-role chat message: the retrieved entity context followed
/// by the user's question, in labelled sections.
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
    // NOTE: the exact layout of this string is part of the prompt contract
    // with the LLM; change it together with QUERY_SYSTEM_PROMPT.
    format!(
        r#"
Context Information:
==================
{}
User Question:
==================
{}
"#,
        entities_json, query
    )
}
/// Assembles the chat-completion request: system prompt, user message, and
/// a strict JSON-schema response format so the model must return an answer
/// plus references.
pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionRequest, ApiError> {
    // Constrain the model output to the answer/references schema.
    let schema_format = ResponseFormat::JsonSchema {
        json_schema: ResponseFormatJsonSchema {
            description: Some("Query answering AI".into()),
            name: "query_answering_with_uuids".into(),
            schema: Some(get_query_response_schema()),
            strict: Some(true),
        },
    };
    CreateChatCompletionRequestArgs::default()
        .messages([
            ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT).into(),
            ChatCompletionRequestUserMessage::from(user_message).into(),
        ])
        .model("gpt-4o-mini")
        .temperature(0.2)
        .max_tokens(3048u32)
        .response_format(schema_format)
        .build()
        .map_err(|e| ApiError::QueryError(e.to_string()))
}
/// Extracts the first choice's content from the completion response and
/// deserializes it into the structured [`LLMResponseFormat`].
pub async fn process_llm_response(
    response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, ApiError> {
    // The API may return no choices, or a choice without text content.
    let content = match response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref())
    {
        Some(content) => content,
        None => {
            return Err(ApiError::QueryError(
                "No content found in LLM response".into(),
            ))
        }
    };
    serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
        ApiError::QueryError(format!("Failed to parse LLM response into analysis: {}", e))
    })
}

View File

@@ -1,48 +0,0 @@
use serde_json::{json, Value};
/// System prompt for the query-answering LLM call.
///
/// Instructs the model to answer strictly from the supplied knowledge
/// entities and to cite entity UUIDs in the structured `references` array
/// (whose shape is enforced by `get_query_response_schema`).
pub static QUERY_SYSTEM_PROMPT: &str = r#"
You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
Your task is to:
1. Carefully analyze the provided knowledge entities in the context
2. Answer user questions based on this information
3. Provide clear, concise, and accurate responses
4. When referencing information, briefly mention which knowledge entity it came from
5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this
6. If only partial information is available, explain what you can answer and what information is missing
7. Avoid making assumptions or providing information not supported by the context
8. Output the references to the documents. Use the UUIDs and make sure they are correct!
Remember:
- Be direct and honest about the limitations of your knowledge
- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array
- If you need to combine information from multiple entities, explain how they connect
- Don't speculate beyond what's provided in the context
Example response formats:
"Based on [Entity Name], [answer...]"
"I found relevant information in multiple entries: [explanation...]"
"I apologize, but the provided context doesn't contain information about [topic]"
"#;
/// JSON schema for the structured LLM response: an `answer` string plus a
/// `references` array of `{ "reference": string }` objects.
///
/// Passed to the OpenAI strict JSON-schema response format, so
/// `additionalProperties: false` and the `required` lists are load-bearing.
pub fn get_query_response_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "answer": { "type": "string" },
            "references": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "reference": { "type": "string" },
                    },
                    "required": ["reference"],
                    "additionalProperties": false,
                }
            }
        },
        "required": ["answer", "references"],
        "additionalProperties": false
    })
}