From d731e69bf9bbc960dc94cfb00be63d3ac163e445 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sat, 23 Nov 2024 15:35:44 +0100 Subject: [PATCH] error handling, and setting result as return --- src/error.rs | 43 +++++++++++++++++++++++++++++++++++- src/server/routes/ingress.rs | 34 ++++++++-------------------- src/server/routes/query.rs | 24 ++++++++++++++++---- 3 files changed, 71 insertions(+), 30 deletions(-) diff --git a/src/error.rs b/src/error.rs index eca465b..77c4bae 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,10 +1,11 @@ use async_openai::error::OpenAIError; +use axum::{http::StatusCode, response::IntoResponse, Json}; +use serde_json::json; use thiserror::Error; use tokio::task::JoinError; use crate::{ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError}; -/// Error types for processing `TextContent`. #[derive(Error, Debug)] pub enum ProcessingError { #[error("SurrealDb error: {0}")] @@ -37,3 +38,43 @@ pub enum IngressConsumerError { #[error("Ingress content error: {0}")] IngressContent(#[from] IngressContentError), } + +#[derive(Error, Debug)] +pub enum ApiError { + #[error("Processing error: {0}")] + ProcessingError(#[from] ProcessingError), + #[error("Ingress content error: {0}")] + IngressContentError(#[from] IngressContentError), + #[error("Publishing error: {0}")] + PublishingError(String), + #[error("Database error: {0}")] + DatabaseError(String), + #[error("Query error: {0}")] + QueryError(String), + #[error("RabbitMQ error: {0}")] + RabbitMQError(#[from] RabbitMQError), +} + +impl IntoResponse for ApiError { + fn into_response(self) -> axum::response::Response { + let (status, error_message) = match &self { + ApiError::ProcessingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + ApiError::PublishingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + ApiError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()), + ApiError::IngressContentError(_) => { + (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) + } + ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + }; + + ( + status, + Json(json!({ + "error": error_message, + "status": "error" + })), + ) + .into_response() + } +} diff --git a/src/server/routes/ingress.rs b/src/server/routes/ingress.rs index de187f2..05bd0f2 100644 --- a/src/server/routes/ingress.rs +++ b/src/server/routes/ingress.rs @@ -1,41 +1,25 @@ use crate::{ + error::ApiError, ingress::types::ingress_input::{create_ingress_objects, IngressInput}, rabbitmq::publisher::RabbitMQProducer, storage::db::SurrealDbClient, }; use axum::{http::StatusCode, response::IntoResponse, Extension, Json}; use std::sync::Arc; -use tracing::{error, info}; +use tracing::info; pub async fn ingress_handler( Extension(producer): Extension>, Extension(db_client): Extension>, Json(input): Json, -) -> impl IntoResponse { +) -> Result { info!("Received input: {:?}", input); - match create_ingress_objects(input, &db_client).await { - Ok(objects) => { - for object in objects { - match producer.publish(&object).await { - Ok(_) => { - info!("Message published successfully"); - } - Err(e) => { - error!("Failed to publish message: {:?}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to publish message", - ) - .into_response(); - } - } - } - StatusCode::OK.into_response() - } - Err(e) => { - error!("Failed to process input: {:?}", e); - (StatusCode::INTERNAL_SERVER_ERROR, "Failed to process input").into_response() - } + let ingress_objects = create_ingress_objects(input, &db_client).await?; + + for object in ingress_objects { + producer.publish(&object).await?; } + + Ok(StatusCode::OK) } diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs index daaa64d..89afed1 100644 --- a/src/server/routes/query.rs +++ b/src/server/routes/query.rs @@ -1,8 +1,12 @@ -use crate::storage::db::SurrealDbClient; -use axum::{http::StatusCode, response::IntoResponse, Extension, Json}; +use crate::{ + error::ApiError, + retrieval::vector::find_items_by_vector_similarity, + storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, +}; +use axum::{response::IntoResponse, Extension, Json}; use serde::Deserialize; use std::sync::Arc; -use tracing::{error, info}; +use tracing::info; #[derive(Debug, Deserialize)] pub struct QueryInput { @@ -12,6 +16,18 @@ pub struct QueryInput { pub async fn query_handler( Extension(db_client): Extension>, Json(query): Json, -) -> impl IntoResponse { +) -> Result { info!("Received input: {:?}", query); + let openai_client = async_openai::Client::new(); + + let closest_items: Vec = find_items_by_vector_similarity( + 10, + query.query, + &db_client, + "knowledge_entity".to_string(), + &openai_client, + ) + .await?; + + Ok(format!("{:?}", closest_items)) }