state instead of extensions

This commit is contained in:
Per Stark
2024-12-09 11:30:39 +01:00
parent 50994c7df9
commit 620530bce1
8 changed files with 195 additions and 62 deletions

View File

@@ -1,11 +1,14 @@
use crate::storage::{
db::SurrealDbClient,
types::file_info::{FileError, FileInfo},
use crate::{
server::AppState,
storage::types::file_info::{FileError, FileInfo},
};
use axum::{
extract::{Path, State},
response::IntoResponse,
Json,
};
use axum::{extract::Path, response::IntoResponse, Extension, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde_json::json;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tracing::info;
use uuid::Uuid;
@@ -20,13 +23,13 @@ pub struct FileUploadRequest {
///
/// Route: POST /file
pub async fn upload_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
) -> Result<impl IntoResponse, FileError> {
info!("Received an upload request");
// Process the file upload
let file_info = FileInfo::new(input.file, &db_client).await?;
let file_info = FileInfo::new(input.file, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
@@ -46,14 +49,14 @@ pub async fn upload_handler(
///
/// Route: GET /file/:uuid
pub async fn get_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Path(uuid_str): Path<String>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Retrieve FileInfo
let file_info = FileInfo::get_by_uuid(uuid, &db_client).await?;
let file_info = FileInfo::get_by_uuid(uuid, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
@@ -73,7 +76,7 @@ pub async fn get_file_handler(
///
/// Route: PUT /file/:uuid
pub async fn update_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Path(uuid_str): Path<String>,
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
) -> Result<impl IntoResponse, FileError> {
@@ -81,7 +84,7 @@ pub async fn update_file_handler(
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Update the file
let updated_file_info = FileInfo::update(uuid, input.file, &db_client).await?;
let updated_file_info = FileInfo::update(uuid, input.file, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
@@ -101,14 +104,14 @@ pub async fn update_file_handler(
///
/// Route: DELETE /file/:uuid
pub async fn delete_file_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Path(uuid_str): Path<String>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Delete the file
FileInfo::delete(uuid, &db_client).await?;
FileInfo::delete(uuid, &state.surreal_db_client).await?;
info!("Deleted file with UUID: {}", uuid);

View File

@@ -1,14 +1,22 @@
use axum::{response::Html, Extension};
use axum::{extract::State, response::Html};
use serde_json::json;
use tera::{Context, Tera};
use tera::Context;
use tracing::info;
use crate::error::ApiError;
use crate::{error::ApiError, server::AppState};
pub async fn index_handler(Extension(tera): Extension<Tera>) -> Result<Html<String>, ApiError> {
let output = tera
pub async fn index_handler(State(state): State<AppState>) -> Result<Html<String>, ApiError> {
info!("Displaying index page");
// Now you can access the consumer directly from the state
let queue_length = state.rabbitmq_consumer.queue.message_count();
let output = state
.tera
.render(
"index.html",
&Context::from_value(json!({"adjective": "CRAYCRAY"})).unwrap(),
&Context::from_value(json!({"adjective": "CRAYCRAY", "queue_length": queue_length}))
.unwrap(),
)
.unwrap();

View File

@@ -1,26 +1,23 @@
use crate::{
error::ApiError,
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
rabbitmq::publisher::RabbitMQProducer,
storage::db::SurrealDbClient,
server::AppState,
};
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use futures::future::try_join_all;
use std::sync::Arc;
use tracing::info;
pub async fn ingress_handler(
Extension(producer): Extension<Arc<RabbitMQProducer>>,
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Json(input): Json<IngressInput>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input);
let ingress_objects = create_ingress_objects(input, &db_client).await?;
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?;
let futures: Vec<_> = ingress_objects
.into_iter()
.map(|object| producer.publish(object))
.map(|object| state.rabbitmq_producer.publish(object))
.collect();
try_join_all(futures).await?;

View File

@@ -1,12 +1,8 @@
pub mod helper;
pub mod prompt;
use std::sync::Arc;
use crate::{
error::ApiError, retrieval::combined_knowledge_entity_retrieval, storage::db::SurrealDbClient,
};
use axum::{response::IntoResponse, Extension, Json};
use crate::{error::ApiError, retrieval::combined_knowledge_entity_retrieval, server::AppState};
use axum::{extract::State, response::IntoResponse, Json};
use helper::{
create_chat_request, create_user_message, format_entities_json, process_llm_response,
};
@@ -32,16 +28,19 @@ pub struct LLMResponseFormat {
}
pub async fn query_handler(
Extension(db_client): Extension<Arc<SurrealDbClient>>,
State(state): State<AppState>,
Json(query): Json<QueryInput>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", query);
let openai_client = async_openai::Client::new();
// Retrieve entities
let entities =
combined_knowledge_entity_retrieval(&db_client, &openai_client, query.query.clone())
.await?;
let entities = combined_knowledge_entity_retrieval(
&state.surreal_db_client,
&openai_client,
query.query.clone(),
)
.await?;
// Format entities and create message
let entities_json = format_entities_json(&entities);