refactor: file_info, rabbitmq, queue

This commit is contained in:
Per Stark
2024-12-10 16:02:40 +01:00
parent f788b3065a
commit 42e2aad2b0
10 changed files with 113 additions and 161 deletions

View File

@@ -1,6 +1,6 @@
use axum::{ use axum::{
extract::DefaultBodyLimit, extract::DefaultBodyLimit,
routing::{delete, get, post, put}, routing::{get, post},
Router, Router,
}; };
use std::sync::Arc; use std::sync::Arc;
@@ -11,11 +11,8 @@ use zettle_db::{
rabbitmq::{consumer::RabbitMQConsumer, publisher::RabbitMQProducer, RabbitMQConfig}, rabbitmq::{consumer::RabbitMQConsumer, publisher::RabbitMQProducer, RabbitMQConfig},
server::{ server::{
routes::{ routes::{
file::{delete_file_handler, get_file_handler, update_file_handler, upload_handler}, file::upload_handler, index::index_handler, ingress::ingress_handler,
index::index_handler, query::query_handler, queue_length::queue_length_handler,
ingress::ingress_handler,
query::query_handler,
queue_length::queue_length_handler,
}, },
AppState, AppState,
}, },
@@ -41,7 +38,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let app_state = AppState { let app_state = AppState {
rabbitmq_producer: Arc::new(RabbitMQProducer::new(&config).await?), rabbitmq_producer: Arc::new(RabbitMQProducer::new(&config).await?),
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config).await?), rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config, false).await?),
surreal_db_client: Arc::new(SurrealDbClient::new().await?), surreal_db_client: Arc::new(SurrealDbClient::new().await?),
tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()), tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()),
}; };
@@ -68,9 +65,6 @@ fn api_routes_v1() -> Router<AppState> {
// File routes // File routes
.route("/file", post(upload_handler)) .route("/file", post(upload_handler))
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024)) .layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
.route("/file/:uuid", get(get_file_handler))
.route("/file/:uuid", put(update_file_handler))
.route("/file/:uuid", delete(delete_file_handler))
// Query routes // Query routes
.route("/query", post(query_handler)) .route("/query", post(query_handler))
} }

View File

@@ -1,6 +1,6 @@
use tracing::info; use tracing::info;
use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use zettle_db::rabbitmq::{consumer::RabbitMQConsumer, RabbitMQConfig }; use zettle_db::rabbitmq::{consumer::RabbitMQConsumer, RabbitMQConfig};
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -12,7 +12,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.ok(); .ok();
info!("Starting RabbitMQ consumer"); info!("Starting RabbitMQ consumer");
// Set up RabbitMQ config // Set up RabbitMQ config
let config = RabbitMQConfig { let config = RabbitMQConfig {
amqp_addr: "amqp://localhost".to_string(), amqp_addr: "amqp://localhost".to_string(),
@@ -21,11 +21,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
routing_key: "my_key".to_string(), routing_key: "my_key".to_string(),
}; };
// Create a RabbitMQ consumer // Create a RabbitMQ consumer
let consumer = RabbitMQConsumer::new(&config).await?; let consumer = RabbitMQConsumer::new(&config, true).await?;
info!("Consumer connected to RabbitMQ");
// Start consuming messages // Start consuming messages
consumer.process_messages().await?; consumer.process_messages().await?;

View File

@@ -4,7 +4,10 @@ use serde_json::json;
use thiserror::Error; use thiserror::Error;
use tokio::task::JoinError; use tokio::task::JoinError;
use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError}; use crate::{
ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError,
storage::types::file_info::FileError,
};
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ProcessingError { pub enum ProcessingError {
@@ -55,6 +58,8 @@ pub enum ApiError {
RabbitMQError(#[from] RabbitMQError), RabbitMQError(#[from] RabbitMQError),
#[error("LLM processing error: {0}")] #[error("LLM processing error: {0}")]
OpenAIerror(#[from] OpenAIError), OpenAIerror(#[from] OpenAIError),
#[error("File error: {0}")]
FileError(#[from] FileError),
} }
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
@@ -69,6 +74,7 @@ impl IntoResponse for ApiError {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) (StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
} }
ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::FileError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
}; };
( (

View File

@@ -83,16 +83,16 @@ pub async fn create_ingress_objects(
} }
// Look up FileInfo objects using the db and the submitted uuids in input.files // Look up FileInfo objects using the db and the submitted uuids in input.files
if let Some(file_uuids) = input.files { if let Some(file_ids) = input.files {
for uuid in file_uuids { for id in file_ids {
if let Some(file_info) = get_item::<FileInfo>(&db_client, &uuid).await? { if let Some(file_info) = get_item::<FileInfo>(db_client, &id).await? {
object_list.push(IngressObject::File { object_list.push(IngressObject::File {
file_info, file_info,
instructions: input.instructions.clone(), instructions: input.instructions.clone(),
category: input.category.clone(), category: input.category.clone(),
}); });
} else { } else {
info!("No file with UUID: {}", uuid); info!("No file with id: {}", id);
} }
} }
} }

View File

@@ -13,7 +13,7 @@ use tracing::{error, info};
pub struct RabbitMQConsumer { pub struct RabbitMQConsumer {
common: RabbitMQCommon, common: RabbitMQCommon,
pub queue: Queue, pub queue: Queue,
consumer: Consumer, consumer: Option<Consumer>,
} }
impl RabbitMQConsumer { impl RabbitMQConsumer {
@@ -22,10 +22,14 @@ impl RabbitMQConsumer {
/// ///
/// # Arguments /// # Arguments
/// * `config` - A initialized RabbitMQConfig containing required configurations /// * `config` - A initialized RabbitMQConfig containing required configurations
/// * `start_consuming` - Set to true to start consuming messages
/// ///
/// # Returns /// # Returns
/// * `Result<Self, RabbitMQError>` - The created client or an error. /// * `Result<Self, RabbitMQError>` - The created client or an error.
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> { pub async fn new(
config: &RabbitMQConfig,
start_consuming: bool,
) -> Result<Self, RabbitMQError> {
let common = RabbitMQCommon::new(config).await?; let common = RabbitMQCommon::new(config).await?;
// Passively declare the exchange (it should already exist) // Passively declare the exchange (it should already exist)
@@ -36,7 +40,11 @@ impl RabbitMQConsumer {
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?; Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
// Initialize the consumer // Initialize the consumer
let consumer = Self::initialize_consumer(&common.channel, config).await?; let consumer = if start_consuming {
Some(Self::initialize_consumer(&common.channel, config).await?)
} else {
None
};
Ok(Self { Ok(Self {
common, common,
@@ -67,6 +75,27 @@ impl RabbitMQConsumer {
.await .await
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string())) .map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
} }
/// Operation to get the current queue length
/// Will redeclare queue to get a updated number
pub async fn get_queue_length(&self) -> Result<u32, RabbitMQError> {
let queue_info = self
.common
.channel
.queue_declare(
&self.queue.name().to_string(),
QueueDeclareOptions {
durable: true,
..QueueDeclareOptions::default()
},
FieldTable::default(),
)
.await
.map_err(|e| RabbitMQError::QueueError(e.to_string()))?;
Ok(queue_info.message_count())
}
/// Declares the queue based on the channel and `RabbitMQConfig`. /// Declares the queue based on the channel and `RabbitMQConfig`.
/// # Arguments /// # Arguments
/// * `channel` - Lapin Channel. /// * `channel` - Lapin Channel.
@@ -127,9 +156,15 @@ impl RabbitMQConsumer {
/// `IngressObject` - The object containing content and metadata. /// `IngressObject` - The object containing content and metadata.
/// `Delivery` - A delivery reciept, required to ack or nack the delivery. /// `Delivery` - A delivery reciept, required to ack or nack the delivery.
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> { pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
// Get consumer or return error if not initialized
let consumer: &lapin::Consumer = self.consumer.as_ref().ok_or_else(|| {
RabbitMQError::ConsumeError(
"Consumer not initialized. Call new() with start_consuming=true".to_string(),
)
})?;
// Receive the next message // Receive the next message
let delivery = self let delivery = consumer
.consumer
.clone() .clone()
.next() .next()
.await .await

View File

@@ -1,21 +1,13 @@
use crate::{ use crate::{error::ApiError, server::AppState, storage::types::file_info::FileInfo};
server::AppState, use axum::{extract::State, response::IntoResponse, Json};
storage::types::file_info::{FileError, FileInfo},
};
use axum::{
extract::{Path, State},
response::IntoResponse,
Json,
};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde_json::json; use serde_json::json;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use tracing::info; use tracing::info;
use uuid::Uuid;
#[derive(Debug, TryFromMultipart)] #[derive(Debug, TryFromMultipart)]
pub struct FileUploadRequest { pub struct FileUploadRequest {
#[form_data(limit = "100000")] // Example limit: ~100 KB #[form_data(limit = "1000000")] // Example limit: ~1000 KB
pub file: FieldData<NamedTempFile>, pub file: FieldData<NamedTempFile>,
} }
@@ -25,7 +17,7 @@ pub struct FileUploadRequest {
pub async fn upload_handler( pub async fn upload_handler(
State(state): State<AppState>, State(state): State<AppState>,
TypedMultipart(input): TypedMultipart<FileUploadRequest>, TypedMultipart(input): TypedMultipart<FileUploadRequest>,
) -> Result<impl IntoResponse, FileError> { ) -> Result<impl IntoResponse, ApiError> {
info!("Received an upload request"); info!("Received an upload request");
// Process the file upload // Process the file upload
@@ -33,7 +25,7 @@ pub async fn upload_handler(
// Prepare the response JSON // Prepare the response JSON
let response = json!({ let response = json!({
"uuid": file_info.uuid, "id": file_info.id,
"sha256": file_info.sha256, "sha256": file_info.sha256,
"path": file_info.path, "path": file_info.path,
"mime_type": file_info.mime_type, "mime_type": file_info.mime_type,
@@ -44,82 +36,3 @@ pub async fn upload_handler(
// Return the response with HTTP 200 // Return the response with HTTP 200
Ok((axum::http::StatusCode::OK, Json(response))) Ok((axum::http::StatusCode::OK, Json(response)))
} }
/// Handler to retrieve file information by UUID.
///
/// Route: GET /file/:uuid
pub async fn get_file_handler(
State(state): State<AppState>,
Path(uuid_str): Path<String>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Retrieve FileInfo
let file_info = FileInfo::get_by_uuid(uuid, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
"uuid": file_info.uuid,
"sha256": file_info.sha256,
"path": file_info.path,
"mime_type": file_info.mime_type,
});
info!("Retrieved FileInfo: {:?}", file_info);
// Return the response with HTTP 200
Ok((axum::http::StatusCode::OK, Json(response)))
}
/// Handler to update an existing file by UUID.
///
/// Route: PUT /file/:uuid
pub async fn update_file_handler(
State(state): State<AppState>,
Path(uuid_str): Path<String>,
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Update the file
let updated_file_info = FileInfo::update(uuid, input.file, &state.surreal_db_client).await?;
// Prepare the response JSON
let response = json!({
"uuid": updated_file_info.uuid,
"sha256": updated_file_info.sha256,
"path": updated_file_info.path,
"mime_type": updated_file_info.mime_type,
});
info!("File updated successfully: {:?}", updated_file_info);
// Return the response with HTTP 200
Ok((axum::http::StatusCode::OK, Json(response)))
}
/// Handler to delete a file by UUID.
///
/// Route: DELETE /file/:uuid
pub async fn delete_file_handler(
State(state): State<AppState>,
Path(uuid_str): Path<String>,
) -> Result<impl IntoResponse, FileError> {
// Parse UUID
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
// Delete the file
FileInfo::delete(uuid, &state.surreal_db_client).await?;
info!("Deleted file with UUID: {}", uuid);
// Prepare the response JSON
let response = json!({
"message": "File deleted successfully",
});
// Return the response with HTTP 204 No Content
Ok((axum::http::StatusCode::NO_CONTENT, Json(response)))
}

View File

@@ -8,7 +8,6 @@ use crate::{error::ApiError, server::AppState};
pub async fn index_handler(State(state): State<AppState>) -> Result<Html<String>, ApiError> { pub async fn index_handler(State(state): State<AppState>) -> Result<Html<String>, ApiError> {
info!("Displaying index page"); info!("Displaying index page");
// Now you can access the consumer directly from the state
let queue_length = state.rabbitmq_consumer.queue.message_count(); let queue_length = state.rabbitmq_consumer.queue.message_count();
let output = state let output = state

View File

@@ -1,42 +1,17 @@
use axum::{ use axum::{extract::State, http::StatusCode, response::IntoResponse};
http::StatusCode, use tracing::info;
response::{IntoResponse, Response},
};
use tracing::{error, info};
use crate::rabbitmq::{consumer::RabbitMQConsumer, RabbitMQConfig}; use crate::{error::ApiError, server::AppState};
pub async fn queue_length_handler() -> Response { pub async fn queue_length_handler(
State(state): State<AppState>,
) -> Result<impl IntoResponse, ApiError> {
info!("Getting queue length"); info!("Getting queue length");
// Set up RabbitMQ config let queue_length = state.rabbitmq_consumer.get_queue_length().await?;
let config = RabbitMQConfig {
amqp_addr: "amqp://localhost".to_string(),
exchange: "my_exchange".to_string(),
queue: "my_queue".to_string(),
routing_key: "my_key".to_string(),
};
// Create a new consumer info!("Queue length: {}", queue_length);
match RabbitMQConsumer::new(&config).await {
Ok(consumer) => {
info!("Consumer connected to RabbitMQ");
// Get the queue length // Return the queue length with a 200 OK status
let queue_length = consumer.queue.message_count(); Ok((StatusCode::OK, queue_length.to_string()))
info!("Queue length: {}", queue_length);
// Return the queue length with a 200 OK status
(StatusCode::OK, queue_length.to_string()).into_response()
}
Err(e) => {
error!("Failed to create consumer: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to connect to RabbitMQ".to_string(),
)
.into_response()
}
}
} }

View File

@@ -111,5 +111,5 @@ pub async fn get_item<T>(db_client: &Surreal<Client>, id: &str) -> Result<Option
where where
T: for<'de> StoredObject, T: for<'de> StoredObject,
{ {
Ok(db_client.select((T::table_name(), id)).await?) db_client.select((T::table_name(), id)).await
} }

View File

@@ -14,24 +14,57 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
#[macro_export] #[macro_export]
macro_rules! stored_object { macro_rules! stored_object {
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => { ($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
use axum::async_trait; use axum::async_trait;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use surrealdb::sql::Thing; use surrealdb::sql::Thing;
use $crate::storage::types::StoredObject; use $crate::storage::types::StoredObject;
use serde::de::{self, Visitor};
use std::fmt;
fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error> struct FlexibleIdVisitor;
impl<'de> Visitor<'de> for FlexibleIdVisitor {
type Value = String;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string or a Thing")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_string())
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value)
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
// Try to deserialize as Thing
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
Ok(thing.id.to_raw())
}
}
fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let thing = Thing::deserialize(deserializer)?; deserializer.deserialize_any(FlexibleIdVisitor)
Ok(thing.id.to_raw())
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct $name { pub struct $name {
#[serde(deserialize_with = "thing_to_string")] #[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String, pub id: String,
$(pub $field: $ty),* $(pub $field: $ty),*
} }