mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-23 09:18:36 +02:00
refactor: file_info, rabbitmq, queue
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::DefaultBodyLimit,
|
extract::DefaultBodyLimit,
|
||||||
routing::{delete, get, post, put},
|
routing::{get, post},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -11,11 +11,8 @@ use zettle_db::{
|
|||||||
rabbitmq::{consumer::RabbitMQConsumer, publisher::RabbitMQProducer, RabbitMQConfig},
|
rabbitmq::{consumer::RabbitMQConsumer, publisher::RabbitMQProducer, RabbitMQConfig},
|
||||||
server::{
|
server::{
|
||||||
routes::{
|
routes::{
|
||||||
file::{delete_file_handler, get_file_handler, update_file_handler, upload_handler},
|
file::upload_handler, index::index_handler, ingress::ingress_handler,
|
||||||
index::index_handler,
|
query::query_handler, queue_length::queue_length_handler,
|
||||||
ingress::ingress_handler,
|
|
||||||
query::query_handler,
|
|
||||||
queue_length::queue_length_handler,
|
|
||||||
},
|
},
|
||||||
AppState,
|
AppState,
|
||||||
},
|
},
|
||||||
@@ -41,7 +38,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let app_state = AppState {
|
let app_state = AppState {
|
||||||
rabbitmq_producer: Arc::new(RabbitMQProducer::new(&config).await?),
|
rabbitmq_producer: Arc::new(RabbitMQProducer::new(&config).await?),
|
||||||
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config).await?),
|
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config, false).await?),
|
||||||
surreal_db_client: Arc::new(SurrealDbClient::new().await?),
|
surreal_db_client: Arc::new(SurrealDbClient::new().await?),
|
||||||
tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()),
|
tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()),
|
||||||
};
|
};
|
||||||
@@ -68,9 +65,6 @@ fn api_routes_v1() -> Router<AppState> {
|
|||||||
// File routes
|
// File routes
|
||||||
.route("/file", post(upload_handler))
|
.route("/file", post(upload_handler))
|
||||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
||||||
.route("/file/:uuid", get(get_file_handler))
|
|
||||||
.route("/file/:uuid", put(update_file_handler))
|
|
||||||
.route("/file/:uuid", delete(delete_file_handler))
|
|
||||||
// Query routes
|
// Query routes
|
||||||
.route("/query", post(query_handler))
|
.route("/query", post(query_handler))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use tracing::info;
|
use tracing::info;
|
||||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||||
use zettle_db::rabbitmq::{consumer::RabbitMQConsumer, RabbitMQConfig };
|
use zettle_db::rabbitmq::{consumer::RabbitMQConsumer, RabbitMQConfig};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@@ -12,7 +12,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
info!("Starting RabbitMQ consumer");
|
info!("Starting RabbitMQ consumer");
|
||||||
|
|
||||||
// Set up RabbitMQ config
|
// Set up RabbitMQ config
|
||||||
let config = RabbitMQConfig {
|
let config = RabbitMQConfig {
|
||||||
amqp_addr: "amqp://localhost".to_string(),
|
amqp_addr: "amqp://localhost".to_string(),
|
||||||
@@ -21,11 +21,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
routing_key: "my_key".to_string(),
|
routing_key: "my_key".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Create a RabbitMQ consumer
|
// Create a RabbitMQ consumer
|
||||||
let consumer = RabbitMQConsumer::new(&config).await?;
|
let consumer = RabbitMQConsumer::new(&config, true).await?;
|
||||||
|
|
||||||
info!("Consumer connected to RabbitMQ");
|
|
||||||
|
|
||||||
// Start consuming messages
|
// Start consuming messages
|
||||||
consumer.process_messages().await?;
|
consumer.process_messages().await?;
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ use serde_json::json;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::task::JoinError;
|
use tokio::task::JoinError;
|
||||||
|
|
||||||
use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError};
|
use crate::{
|
||||||
|
ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError,
|
||||||
|
storage::types::file_info::FileError,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ProcessingError {
|
pub enum ProcessingError {
|
||||||
@@ -55,6 +58,8 @@ pub enum ApiError {
|
|||||||
RabbitMQError(#[from] RabbitMQError),
|
RabbitMQError(#[from] RabbitMQError),
|
||||||
#[error("LLM processing error: {0}")]
|
#[error("LLM processing error: {0}")]
|
||||||
OpenAIerror(#[from] OpenAIError),
|
OpenAIerror(#[from] OpenAIError),
|
||||||
|
#[error("File error: {0}")]
|
||||||
|
FileError(#[from] FileError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for ApiError {
|
impl IntoResponse for ApiError {
|
||||||
@@ -69,6 +74,7 @@ impl IntoResponse for ApiError {
|
|||||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
|
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
|
||||||
}
|
}
|
||||||
ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
||||||
|
ApiError::FileError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -83,16 +83,16 @@ pub async fn create_ingress_objects(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Look up FileInfo objects using the db and the submitted uuids in input.files
|
// Look up FileInfo objects using the db and the submitted uuids in input.files
|
||||||
if let Some(file_uuids) = input.files {
|
if let Some(file_ids) = input.files {
|
||||||
for uuid in file_uuids {
|
for id in file_ids {
|
||||||
if let Some(file_info) = get_item::<FileInfo>(&db_client, &uuid).await? {
|
if let Some(file_info) = get_item::<FileInfo>(db_client, &id).await? {
|
||||||
object_list.push(IngressObject::File {
|
object_list.push(IngressObject::File {
|
||||||
file_info,
|
file_info,
|
||||||
instructions: input.instructions.clone(),
|
instructions: input.instructions.clone(),
|
||||||
category: input.category.clone(),
|
category: input.category.clone(),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
info!("No file with UUID: {}", uuid);
|
info!("No file with id: {}", id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use tracing::{error, info};
|
|||||||
pub struct RabbitMQConsumer {
|
pub struct RabbitMQConsumer {
|
||||||
common: RabbitMQCommon,
|
common: RabbitMQCommon,
|
||||||
pub queue: Queue,
|
pub queue: Queue,
|
||||||
consumer: Consumer,
|
consumer: Option<Consumer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RabbitMQConsumer {
|
impl RabbitMQConsumer {
|
||||||
@@ -22,10 +22,14 @@ impl RabbitMQConsumer {
|
|||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `config` - A initialized RabbitMQConfig containing required configurations
|
/// * `config` - A initialized RabbitMQConfig containing required configurations
|
||||||
|
/// * `start_consuming` - Set to true to start consuming messages
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<Self, RabbitMQError>` - The created client or an error.
|
/// * `Result<Self, RabbitMQError>` - The created client or an error.
|
||||||
pub async fn new(config: &RabbitMQConfig) -> Result<Self, RabbitMQError> {
|
pub async fn new(
|
||||||
|
config: &RabbitMQConfig,
|
||||||
|
start_consuming: bool,
|
||||||
|
) -> Result<Self, RabbitMQError> {
|
||||||
let common = RabbitMQCommon::new(config).await?;
|
let common = RabbitMQCommon::new(config).await?;
|
||||||
|
|
||||||
// Passively declare the exchange (it should already exist)
|
// Passively declare the exchange (it should already exist)
|
||||||
@@ -36,7 +40,11 @@ impl RabbitMQConsumer {
|
|||||||
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
|
Self::bind_queue(&common.channel, &config.exchange, &queue, config).await?;
|
||||||
|
|
||||||
// Initialize the consumer
|
// Initialize the consumer
|
||||||
let consumer = Self::initialize_consumer(&common.channel, config).await?;
|
let consumer = if start_consuming {
|
||||||
|
Some(Self::initialize_consumer(&common.channel, config).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
common,
|
common,
|
||||||
@@ -67,6 +75,27 @@ impl RabbitMQConsumer {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
|
.map_err(|e| RabbitMQError::InitializeConsumerError(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Operation to get the current queue length
|
||||||
|
/// Will redeclare queue to get a updated number
|
||||||
|
pub async fn get_queue_length(&self) -> Result<u32, RabbitMQError> {
|
||||||
|
let queue_info = self
|
||||||
|
.common
|
||||||
|
.channel
|
||||||
|
.queue_declare(
|
||||||
|
&self.queue.name().to_string(),
|
||||||
|
QueueDeclareOptions {
|
||||||
|
durable: true,
|
||||||
|
..QueueDeclareOptions::default()
|
||||||
|
},
|
||||||
|
FieldTable::default(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| RabbitMQError::QueueError(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(queue_info.message_count())
|
||||||
|
}
|
||||||
|
|
||||||
/// Declares the queue based on the channel and `RabbitMQConfig`.
|
/// Declares the queue based on the channel and `RabbitMQConfig`.
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `channel` - Lapin Channel.
|
/// * `channel` - Lapin Channel.
|
||||||
@@ -127,9 +156,15 @@ impl RabbitMQConsumer {
|
|||||||
/// `IngressObject` - The object containing content and metadata.
|
/// `IngressObject` - The object containing content and metadata.
|
||||||
/// `Delivery` - A delivery reciept, required to ack or nack the delivery.
|
/// `Delivery` - A delivery reciept, required to ack or nack the delivery.
|
||||||
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
|
pub async fn consume(&self) -> Result<(IngressObject, Delivery), RabbitMQError> {
|
||||||
|
// Get consumer or return error if not initialized
|
||||||
|
let consumer: &lapin::Consumer = self.consumer.as_ref().ok_or_else(|| {
|
||||||
|
RabbitMQError::ConsumeError(
|
||||||
|
"Consumer not initialized. Call new() with start_consuming=true".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
// Receive the next message
|
// Receive the next message
|
||||||
let delivery = self
|
let delivery = consumer
|
||||||
.consumer
|
|
||||||
.clone()
|
.clone()
|
||||||
.next()
|
.next()
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -1,21 +1,13 @@
|
|||||||
use crate::{
|
use crate::{error::ApiError, server::AppState, storage::types::file_info::FileInfo};
|
||||||
server::AppState,
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
storage::types::file_info::{FileError, FileInfo},
|
|
||||||
};
|
|
||||||
use axum::{
|
|
||||||
extract::{Path, State},
|
|
||||||
response::IntoResponse,
|
|
||||||
Json,
|
|
||||||
};
|
|
||||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[derive(Debug, TryFromMultipart)]
|
#[derive(Debug, TryFromMultipart)]
|
||||||
pub struct FileUploadRequest {
|
pub struct FileUploadRequest {
|
||||||
#[form_data(limit = "100000")] // Example limit: ~100 KB
|
#[form_data(limit = "1000000")] // Example limit: ~1000 KB
|
||||||
pub file: FieldData<NamedTempFile>,
|
pub file: FieldData<NamedTempFile>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,7 +17,7 @@ pub struct FileUploadRequest {
|
|||||||
pub async fn upload_handler(
|
pub async fn upload_handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
|
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
|
||||||
) -> Result<impl IntoResponse, FileError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
info!("Received an upload request");
|
info!("Received an upload request");
|
||||||
|
|
||||||
// Process the file upload
|
// Process the file upload
|
||||||
@@ -33,7 +25,7 @@ pub async fn upload_handler(
|
|||||||
|
|
||||||
// Prepare the response JSON
|
// Prepare the response JSON
|
||||||
let response = json!({
|
let response = json!({
|
||||||
"uuid": file_info.uuid,
|
"id": file_info.id,
|
||||||
"sha256": file_info.sha256,
|
"sha256": file_info.sha256,
|
||||||
"path": file_info.path,
|
"path": file_info.path,
|
||||||
"mime_type": file_info.mime_type,
|
"mime_type": file_info.mime_type,
|
||||||
@@ -44,82 +36,3 @@ pub async fn upload_handler(
|
|||||||
// Return the response with HTTP 200
|
// Return the response with HTTP 200
|
||||||
Ok((axum::http::StatusCode::OK, Json(response)))
|
Ok((axum::http::StatusCode::OK, Json(response)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handler to retrieve file information by UUID.
|
|
||||||
///
|
|
||||||
/// Route: GET /file/:uuid
|
|
||||||
pub async fn get_file_handler(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Path(uuid_str): Path<String>,
|
|
||||||
) -> Result<impl IntoResponse, FileError> {
|
|
||||||
// Parse UUID
|
|
||||||
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
|
|
||||||
|
|
||||||
// Retrieve FileInfo
|
|
||||||
let file_info = FileInfo::get_by_uuid(uuid, &state.surreal_db_client).await?;
|
|
||||||
|
|
||||||
// Prepare the response JSON
|
|
||||||
let response = json!({
|
|
||||||
"uuid": file_info.uuid,
|
|
||||||
"sha256": file_info.sha256,
|
|
||||||
"path": file_info.path,
|
|
||||||
"mime_type": file_info.mime_type,
|
|
||||||
});
|
|
||||||
|
|
||||||
info!("Retrieved FileInfo: {:?}", file_info);
|
|
||||||
|
|
||||||
// Return the response with HTTP 200
|
|
||||||
Ok((axum::http::StatusCode::OK, Json(response)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handler to update an existing file by UUID.
|
|
||||||
///
|
|
||||||
/// Route: PUT /file/:uuid
|
|
||||||
pub async fn update_file_handler(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Path(uuid_str): Path<String>,
|
|
||||||
TypedMultipart(input): TypedMultipart<FileUploadRequest>,
|
|
||||||
) -> Result<impl IntoResponse, FileError> {
|
|
||||||
// Parse UUID
|
|
||||||
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
|
|
||||||
|
|
||||||
// Update the file
|
|
||||||
let updated_file_info = FileInfo::update(uuid, input.file, &state.surreal_db_client).await?;
|
|
||||||
|
|
||||||
// Prepare the response JSON
|
|
||||||
let response = json!({
|
|
||||||
"uuid": updated_file_info.uuid,
|
|
||||||
"sha256": updated_file_info.sha256,
|
|
||||||
"path": updated_file_info.path,
|
|
||||||
"mime_type": updated_file_info.mime_type,
|
|
||||||
});
|
|
||||||
|
|
||||||
info!("File updated successfully: {:?}", updated_file_info);
|
|
||||||
|
|
||||||
// Return the response with HTTP 200
|
|
||||||
Ok((axum::http::StatusCode::OK, Json(response)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handler to delete a file by UUID.
|
|
||||||
///
|
|
||||||
/// Route: DELETE /file/:uuid
|
|
||||||
pub async fn delete_file_handler(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Path(uuid_str): Path<String>,
|
|
||||||
) -> Result<impl IntoResponse, FileError> {
|
|
||||||
// Parse UUID
|
|
||||||
let uuid = Uuid::parse_str(&uuid_str).map_err(|_| FileError::InvalidUuid(uuid_str.clone()))?;
|
|
||||||
|
|
||||||
// Delete the file
|
|
||||||
FileInfo::delete(uuid, &state.surreal_db_client).await?;
|
|
||||||
|
|
||||||
info!("Deleted file with UUID: {}", uuid);
|
|
||||||
|
|
||||||
// Prepare the response JSON
|
|
||||||
let response = json!({
|
|
||||||
"message": "File deleted successfully",
|
|
||||||
});
|
|
||||||
|
|
||||||
// Return the response with HTTP 204 No Content
|
|
||||||
Ok((axum::http::StatusCode::NO_CONTENT, Json(response)))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ use crate::{error::ApiError, server::AppState};
|
|||||||
pub async fn index_handler(State(state): State<AppState>) -> Result<Html<String>, ApiError> {
|
pub async fn index_handler(State(state): State<AppState>) -> Result<Html<String>, ApiError> {
|
||||||
info!("Displaying index page");
|
info!("Displaying index page");
|
||||||
|
|
||||||
// Now you can access the consumer directly from the state
|
|
||||||
let queue_length = state.rabbitmq_consumer.queue.message_count();
|
let queue_length = state.rabbitmq_consumer.queue.message_count();
|
||||||
|
|
||||||
let output = state
|
let output = state
|
||||||
|
|||||||
@@ -1,42 +1,17 @@
|
|||||||
use axum::{
|
use axum::{extract::State, http::StatusCode, response::IntoResponse};
|
||||||
http::StatusCode,
|
use tracing::info;
|
||||||
response::{IntoResponse, Response},
|
|
||||||
};
|
|
||||||
use tracing::{error, info};
|
|
||||||
|
|
||||||
use crate::rabbitmq::{consumer::RabbitMQConsumer, RabbitMQConfig};
|
use crate::{error::ApiError, server::AppState};
|
||||||
|
|
||||||
pub async fn queue_length_handler() -> Response {
|
pub async fn queue_length_handler(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
info!("Getting queue length");
|
info!("Getting queue length");
|
||||||
|
|
||||||
// Set up RabbitMQ config
|
let queue_length = state.rabbitmq_consumer.get_queue_length().await?;
|
||||||
let config = RabbitMQConfig {
|
|
||||||
amqp_addr: "amqp://localhost".to_string(),
|
|
||||||
exchange: "my_exchange".to_string(),
|
|
||||||
queue: "my_queue".to_string(),
|
|
||||||
routing_key: "my_key".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create a new consumer
|
info!("Queue length: {}", queue_length);
|
||||||
match RabbitMQConsumer::new(&config).await {
|
|
||||||
Ok(consumer) => {
|
|
||||||
info!("Consumer connected to RabbitMQ");
|
|
||||||
|
|
||||||
// Get the queue length
|
// Return the queue length with a 200 OK status
|
||||||
let queue_length = consumer.queue.message_count();
|
Ok((StatusCode::OK, queue_length.to_string()))
|
||||||
|
|
||||||
info!("Queue length: {}", queue_length);
|
|
||||||
|
|
||||||
// Return the queue length with a 200 OK status
|
|
||||||
(StatusCode::OK, queue_length.to_string()).into_response()
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to create consumer: {:?}", e);
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Failed to connect to RabbitMQ".to_string(),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,5 +111,5 @@ pub async fn get_item<T>(db_client: &Surreal<Client>, id: &str) -> Result<Option
|
|||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
{
|
{
|
||||||
Ok(db_client.select((T::table_name(), id)).await?)
|
db_client.select((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,24 +14,57 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
|||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! stored_object {
|
macro_rules! stored_object {
|
||||||
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
|
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||||
use axum::async_trait;
|
use axum::async_trait;
|
||||||
use serde::{Deserialize, Deserializer, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
use surrealdb::sql::Thing;
|
use surrealdb::sql::Thing;
|
||||||
use $crate::storage::types::StoredObject;
|
use $crate::storage::types::StoredObject;
|
||||||
|
use serde::de::{self, Visitor};
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
|
struct FlexibleIdVisitor;
|
||||||
|
|
||||||
|
impl<'de> Visitor<'de> for FlexibleIdVisitor {
|
||||||
|
type Value = String;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
formatter.write_str("a string or a Thing")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
Ok(value.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: de::MapAccess<'de>,
|
||||||
|
{
|
||||||
|
// Try to deserialize as Thing
|
||||||
|
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
|
||||||
|
Ok(thing.id.to_raw())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
let thing = Thing::deserialize(deserializer)?;
|
deserializer.deserialize_any(FlexibleIdVisitor)
|
||||||
Ok(thing.id.to_raw())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct $name {
|
pub struct $name {
|
||||||
#[serde(deserialize_with = "thing_to_string")]
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
pub id: String,
|
pub id: String,
|
||||||
$(pub $field: $ty),*
|
$(pub $field: $ty),*
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user