refactoring: new structure and mailer

This commit is contained in:
Per Stark
2024-12-19 23:15:12 +01:00
parent e54533d005
commit 13608bc41e
25 changed files with 659 additions and 379 deletions

View File

@@ -19,17 +19,20 @@ use zettle_db::{
server::{
middleware_api_auth::api_auth,
routes::{
auth::{show_signup_form, signup_handler},
file::upload_handler,
index::index_handler,
ingress::ingress_handler,
query::query_handler,
queue_length::queue_length_handler,
search_result::search_result_handler,
api::{
file::upload_handler, ingress::ingress_handler, query::query_handler,
queue_length::queue_length_handler,
},
html::{
auth::{show_signup_form, signup_handler},
index::index_handler,
search_result::search_result_handler,
},
},
AppState,
},
storage::{db::SurrealDbClient, types::user::User},
utils::mailer::Mailer,
};
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
@@ -59,6 +62,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(env)
});
let mailer = Mailer::new();
let app_state = AppState {
rabbitmq_producer: Arc::new(RabbitMQProducer::new(&config).await?),
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config, false).await?),

View File

@@ -1,4 +1,6 @@
pub mod graph;
pub mod query_helper;
pub mod query_helper_prompt;
pub mod vector;
use crate::{

View File

@@ -3,6 +3,7 @@ use async_openai::types::{
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
};
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::debug;
@@ -12,10 +13,20 @@ use crate::{
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::{
prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT},
LLMResponseFormat,
};
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
#[derive(Debug, Deserialize)]
pub struct Reference {
#[allow(dead_code)]
pub reference: String,
}
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
pub answer: String,
#[allow(dead_code)]
pub references: Vec<Reference>,
}
// /// Orchestrator function that takes a query and clients and returns a answer with references
// ///

View File

@@ -0,0 +1,4 @@
pub mod file;
pub mod ingress;
pub mod query;
pub mod queue_length;

View File

@@ -1,9 +1,8 @@
pub mod helper;
pub mod prompt;
use crate::{error::ApiError, server::AppState, storage::types::user::User};
use crate::{
error::ApiError, retrieval::query_helper::get_answer_with_references, server::AppState,
storage::types::user::User,
};
use axum::{extract::State, response::IntoResponse, Extension, Json};
use helper::get_answer_with_references;
use serde::Deserialize;
use tracing::info;
@@ -12,19 +11,6 @@ pub struct QueryInput {
query: String,
}
#[derive(Debug, Deserialize)]
pub struct Reference {
#[allow(dead_code)]
pub reference: String,
}
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
pub answer: String,
#[allow(dead_code)]
pub references: Vec<Reference>,
}
pub async fn query_handler(
State(state): State<AppState>,
Extension(user): Extension<User>,

View File

@@ -1,7 +1,6 @@
use axum::{
extract::State,
http::Response,
response::{Html, IntoResponse},
response::{IntoResponse, Redirect},
Form,
};
use axum_htmx::HxBoosted;
@@ -27,19 +26,23 @@ struct PageData {
pub async fn show_signup_form(
State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
HxBoosted(boosted): HxBoosted,
) -> Result<Html<String>, ApiError> {
) -> Result<impl IntoResponse, ApiError> {
if auth.is_authenticated() {
return Ok(Redirect::to("/").into_response());
}
let output = match boosted {
true => render_block(
"auth/signup_form.html",
"content",
"body",
PageData {},
state.templates,
)?,
false => render_template("auth/signup_form.html", PageData {}, state.templates)?,
};
Ok(output)
Ok(output.into_response())
}
pub async fn signup_handler(

View File

@@ -1,24 +1,19 @@
use axum::{extract::State, response::Html};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use minijinja::context;
use serde::Serialize;
use serde_json::json;
use surrealdb::{engine::any::Any, sql::Relation, Surreal};
use tera::Context;
// use tera::Context;
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use crate::{
error::ApiError,
server::{routes::render_template, AppState},
page_data,
server::{routes::html::render_template, AppState},
storage::types::user::User,
};
#[derive(Serialize)]
struct PageData<'a> {
queue_length: &'a str,
}
page_data!(IndexData, {
queue_length: u32,
});
pub async fn index_handler(
State(state): State<AppState>,
@@ -30,13 +25,7 @@ pub async fn index_handler(
let queue_length = state.rabbitmq_consumer.get_queue_length().await?;
let output = render_template(
"index.html",
PageData {
queue_length: "1000",
},
state.templates,
)?;
let output = render_template("index.html", IndexData { queue_length }, state.templates)?;
Ok(output)
}

View File

@@ -0,0 +1,55 @@
use std::sync::Arc;
use axum::response::Html;
use minijinja_autoreload::AutoReloader;
pub mod auth;
pub mod index;
pub mod search_result;
pub fn render_template<T>(
template_name: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
T: serde::Serialize,
{
let env = templates.acquire_env()?;
let tmpl = env.get_template(template_name)?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl.render(context)?;
Ok(output.into())
}
pub fn render_block<T>(
template_name: &str,
block: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
T: serde::Serialize,
{
let env = templates.acquire_env()?;
let tmpl = env.get_template(template_name)?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl.eval_to_state(context)?.render_block(block)?;
Ok(output.into())
}
#[macro_export]
macro_rules! page_data {
($name:ident, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => {
use serde::{Serialize, Deserialize};
#[derive(Debug, Deserialize, Serialize)]
pub struct $name {
$($(#[$attr])* pub $field: $ty),*
}
};
}

View File

@@ -5,14 +5,11 @@ use axum::{
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::Deserialize;
use serde_json::json;
use surrealdb::{engine::any::Any, Surreal};
use tera::Context;
use tracing::info;
use crate::{
error::ApiError,
server::{routes::query::helper::get_answer_with_references, AppState},
error::ApiError, retrieval::query_helper::get_answer_with_references, server::AppState,
storage::types::user::User,
};
#[derive(Deserialize)]

View File

@@ -1,47 +1,2 @@
use std::sync::Arc;
use axum::response::Html;
use minijinja_autoreload::AutoReloader;
pub mod auth;
pub mod file;
pub mod index;
pub mod ingress;
pub mod query;
pub mod queue_length;
pub mod search_result;
pub fn render_template<T>(
template_name: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
T: serde::Serialize,
{
let env = templates.acquire_env()?;
let tmpl = env.get_template(template_name)?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl.render(context)?;
Ok(output.into())
}
pub fn render_block<T>(
template_name: &str,
block: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, minijinja::Error>
where
T: serde::Serialize,
{
let env = templates.acquire_env()?;
let tmpl = env.get_template(template_name)?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl.eval_to_state(context)?.render_block(block)?;
Ok(output.into())
}
pub mod api;
pub mod html;

23
src/utils/mailer.rs Normal file
View File

@@ -0,0 +1,23 @@
use std::env;
use lettre::{transport::smtp::authentication::Credentials, SmtpTransport};
pub struct Mailer {
pub mailer: SmtpTransport,
}
impl Mailer {
pub fn new() -> Self {
let creds = Credentials::new(
env::var("SMTP_USERNAME").unwrap().to_owned(),
env::var("SMTP_PASSWORD").unwrap().to_owned(),
);
let mailer = SmtpTransport::relay(env::var("SMTP_RELAYER").unwrap().as_str())
.unwrap()
.credentials(creds)
.build();
Mailer { mailer }
}
}

View File

@@ -1 +1,2 @@
pub mod embedding;
pub mod mailer;