state instead of extensions

This commit is contained in:
Per Stark
2024-12-09 11:30:39 +01:00
parent 3bde1d4b06
commit f067f2aaa8
8 changed files with 195 additions and 62 deletions

View File

@@ -1,20 +1,23 @@
use axum::{
extract::DefaultBodyLimit,
routing::{delete, get, post, put},
Extension, Router,
Router,
};
use std::sync::Arc;
use tera::Tera;
use tower_http::services::ServeDir;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use zettle_db::{
rabbitmq::{publisher::RabbitMQProducer, RabbitMQConfig},
server::routes::{
file::{delete_file_handler, get_file_handler, update_file_handler, upload_handler},
index::index_handler,
ingress::ingress_handler,
query::query_handler,
queue_length::queue_length_handler,
rabbitmq::{consumer::RabbitMQConsumer, publisher::RabbitMQProducer, RabbitMQConfig},
server::{
routes::{
file::{delete_file_handler, get_file_handler, update_file_handler, upload_handler},
index::index_handler,
ingress::ingress_handler,
query::query_handler,
queue_length::queue_length_handler,
},
AppState,
},
storage::db::SurrealDbClient,
};
@@ -28,8 +31,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.try_init()
.ok();
let tera = Tera::new("src/server/templates/**/*.html").unwrap();
// Set up RabbitMQ
let config = RabbitMQConfig {
amqp_addr: "amqp://localhost".to_string(),
@@ -38,18 +39,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
routing_key: "my_key".to_string(),
};
// Set up producer
let producer = Arc::new(RabbitMQProducer::new(&config).await?);
// Set up database client
let db_client = Arc::new(SurrealDbClient::new().await?);
let app_state = AppState {
rabbitmq_producer: Arc::new(RabbitMQProducer::new(&config).await?),
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config).await?),
surreal_db_client: Arc::new(SurrealDbClient::new().await?),
tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()),
};
// Create Axum router
let app = Router::new()
// Ingress routes
.route("/ingress", post(ingress_handler))
.route("/message_count", get(queue_length_handler))
.layer(Extension(producer))
// File routes
.route("/file", post(upload_handler))
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
@@ -58,10 +59,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.route("/file/:uuid", delete(delete_file_handler))
// Query routes
.route("/query", post(query_handler))
.layer(Extension(db_client))
// Html routes
.route("/", get(index_handler))
.layer(Extension(tera))
.with_state(app_state)
.nest_service("/assets", ServeDir::new("src/server/assets"));
tracing::info!("Listening on 0.0.0.0:3000");

View File

@@ -554,17 +554,120 @@ video {
display: none;
}
.container {
width: 100%;
}
@media (min-width: 640px) {
.container {
max-width: 640px;
}
}
@media (min-width: 768px) {
.container {
max-width: 768px;
}
}
@media (min-width: 1024px) {
.container {
max-width: 1024px;
}
}
@media (min-width: 1280px) {
.container {
max-width: 1280px;
}
}
@media (min-width: 1536px) {
.container {
max-width: 1536px;
}
}
.h-full {
height: 100%;
}
.rounded {
border-radius: 0.25rem;
}
.rounded-s {
border-start-start-radius: 0.25rem;
border-end-start-radius: 0.25rem;
}
.rounded-b-md {
border-bottom-right-radius: 0.375rem;
border-bottom-left-radius: 0.375rem;
}
.bg-blue-300 {
--tw-bg-opacity: 1;
background-color: rgb(147 197 253 / var(--tw-bg-opacity, 1));
}
.text-3xl {
font-size: 1.875rem;
line-height: 2.25rem;
.bg-red-500 {
--tw-bg-opacity: 1;
background-color: rgb(239 68 68 / var(--tw-bg-opacity, 1));
}
.bg-slate-50 {
--tw-bg-opacity: 1;
background-color: rgb(248 250 252 / var(--tw-bg-opacity, 1));
}
.bg-slate-100 {
--tw-bg-opacity: 1;
background-color: rgb(241 245 249 / var(--tw-bg-opacity, 1));
}
.bg-blue-200 {
--tw-bg-opacity: 1;
background-color: rgb(191 219 254 / var(--tw-bg-opacity, 1));
}
.bg-blue-400 {
--tw-bg-opacity: 1;
background-color: rgb(96 165 250 / var(--tw-bg-opacity, 1));
}
.bg-blue-600 {
--tw-bg-opacity: 1;
background-color: rgb(37 99 235 / var(--tw-bg-opacity, 1));
}
.p-10 {
padding: 2.5rem;
}
.p-2 {
padding: 0.5rem;
}
.py-4 {
padding-top: 1rem;
padding-bottom: 1rem;
}
.font-mono {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
}
.text-4xl {
font-size: 2.25rem;
line-height: 2.5rem;
}
.text-lg {
font-size: 1.125rem;
line-height: 1.75rem;
}
.font-bold {
font-weight: 700;
}

View File

@@ -1 +1,15 @@
use crate::rabbitmq::consumer::RabbitMQConsumer;
use crate::rabbitmq::publisher::RabbitMQProducer;
use crate::storage::db::SurrealDbClient;
use std::sync::Arc;
use tera::Tera;
pub mod routes;
#[derive(Clone)]
pub struct AppState {
pub rabbitmq_producer: Arc<RabbitMQProducer>,
pub rabbitmq_consumer: Arc<RabbitMQConsumer>,
pub surreal_db_client: Arc<SurrealDbClient>,
pub tera: Arc<Tera>,
}

View File

@@ -1,11 +1,14 @@
use crate::storage::{
db::SurrealDbClient,
types::file_info::{FileError, FileInfo},
use crate::{
server::AppState,
storage::types::file_info::{FileError, FileInfo},
};
use axum::{
extract::{Path, State},
response::IntoResponse,
Json,
};
use axum::{extract::Path, response::IntoResponse, Extension, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde_json::json;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tracing::info;
use uuid::Uuid;
@@ -20,13 +23,13 @@ pub struct FileUploadRequest {
///
/// Route: POST /file
pub async fn upload_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
) -> Result<impl IntoResponse, FileError> {
info!("Received an upload request");
// Process the file upload
let file_info = FileInfo::new(input.file, &db_client).await?;
let file_info = FileInfo::new(input.file, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
@@ -46,14 +49,14 @@ pub async fn upload_handler(
///
/// Route: GET /file/:uuid
pub async fn get_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Path(uuid_str): Path<String>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Retrieve FileInfo
let file_info = FileInfo::get_by_uuid(uuid, &db_client).await?;
let file_info = FileInfo::get_by_uuid(uuid, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
@@ -73,7 +76,7 @@ pub async fn get_file_handler(
///
/// Route: PUT /file/:uuid
pub async fn update_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Path(uuid_str): Path<String>,
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
) -> Result<impl IntoResponse, FileError> {
@@ -81,7 +84,7 @@ pub async fn update_file_handler(
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Update the file
let updated_file_info = FileInfo::update(uuid, input.file, &db_client).await?;
let updated_file_info = FileInfo::update(uuid, input.file, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
@@ -101,14 +104,14 @@ pub async fn update_file_handler(
///
/// Route: DELETE /file/:uuid
pub async fn delete_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Path(uuid_str): Path<String>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Delete the file
FileInfo::delete(uuid, &db_client).await?;
FileInfo::delete(uuid, &state.surreal_db_client).await?;
info!("Deleted file with UUID: {}", uuid);

View File

@@ -1,14 +1,22 @@
use axum::{response::Html, Extension};
use axum::{extract::State, response::Html};
use serde_json::json;
use tera::{Context, Tera};
use tera::Context;
use tracing::info;
use crate::error::ApiError;
use crate::{error::ApiError, server::AppState};
pub async fn index_handler(Extension(tera): Extension<Tera>) -> Result<Html<String>, ApiError> {
let output = tera
pub async fn index_handler(State(state): State<AppState>) -> Result<Html<String>, ApiError> {
info!("Displaying index page");
// Now you can access the consumer directly from the state
let queue_length = state.rabbitmq_consumer.queue.message_count();
let output = state
.tera
.render(
"index.html",
&Context::from_value(json!({"adjective": "CRAYCRAY"})).unwrap(),
&Context::from_value(json!({"adjective": "CRAYCRAY", "queue_length": queue_length}))
.unwrap(),
)
.unwrap();

View File

@@ -1,26 +1,23 @@
use crate::{
error::ApiError,
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
rabbitmq::publisher::RabbitMQProducer,
storage::db::SurrealDbClient,
server::AppState,
};
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use futures::future::try_join_all;
use std::sync::Arc;
use tracing::info;
pub async fn ingress_handler(
Extension(producer): Extension<Arc<RabbitMQProducer>>,
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Json(input): Json<IngressInput>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input);
let ingress_objects = create_ingress_objects(input, &db_client).await?;
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?;
let futures: Vec<_> = ingress_objects
.into_iter()
.map(|object| producer.publish(object))
.map(|object| state.rabbitmq_producer.publish(object))
.collect();
try_join_all(futures).await?;

View File

@@ -1,12 +1,8 @@
pub mod helper;
pub mod prompt;
use std::sync::Arc;
use crate::{
error::ApiError, retrieval::combined_knowledge_entity_retrieval, storage::db::SurrealDbClient,
};
use axum::{response::IntoResponse, Extension, Json};
use crate::{error::ApiError, retrieval::combined_knowledge_entity_retrieval, server::AppState};
use axum::{extract::State, response::IntoResponse, Json};
use helper::{
create_chat_request, create_user_message, format_entities_json, process_llm_response,
};
@@ -32,16 +28,19 @@ pub struct LLMResponseFormat {
}
pub async fn query_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Json(query): Json<QueryInput>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", query);
let openai_client = async_openai::Client::new();
// Retrieve entities
let entities =
combined_knowledge_entity_retrieval(&db_client, &openai_client, query.query.clone())
.await?;
let entities = combined_knowledge_entity_retrieval(
&state.surreal_db_client,
&openai_client,
query.query.clone(),
)
.await?;
// Format entities and create message
let entities_json = format_entities_json(&entities);

View File

@@ -3,13 +3,22 @@
<head>
<link rel="stylesheet" type="text/css" href="assets/style.css">
<script src="https://unpkg.com/htmx.org@2.0.3"></script>
<title>
radien
</title>
</head>
<body class="h-full bg-slate-100">
<h1 class="text-4xl bg-blue-300 p-2 font-mono font-bold">radien</h1>
<div class="p-10">
I am {{adjective}}
<div class="py-4">
<h2 class="text-lg font-bold">Queue:</h2>
<p class="">There are {{queue_length}} items queued </p>
</div>
<body>
<h1 class="text-4xl bg-blue-300">Hello world!</h1>
I am {{adjective}}
<h1>WOW!</h1>
</div>
</body>
</html>