feat: refactored error handling

This commit is contained in:
Per Stark
2025-01-01 23:26:41 +01:00
parent a3bb73646c
commit 9976fef5a3
25 changed files with 439 additions and 293 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,107 +1,233 @@
use std::sync::Arc;
use async_openai::error::OpenAIError; use async_openai::error::OpenAIError;
use axum::{http::StatusCode, response::IntoResponse, Json}; use axum::{
http::StatusCode,
response::{Html, IntoResponse, Response},
Json,
};
use minijinja::context;
use minijinja_autoreload::AutoReloader;
use serde::Serialize;
use serde_json::json; use serde_json::json;
use thiserror::Error; use thiserror::Error;
use tokio::task::JoinError; use tokio::task::JoinError;
use crate::{ use crate::{
ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError, rabbitmq::RabbitMQError, storage::types::file_info::FileError, utils::mailer::EmailError,
storage::types::file_info::FileError, utils::mailer::EmailError,
}; };
// Core internal errors
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ProcessingError { pub enum AppError {
#[error("SurrealDb error: {0}")] #[error("Database error: {0}")]
SurrealDbError(#[from] surrealdb::Error), Database(#[from] surrealdb::Error),
#[error("OpenAI error: {0}")]
#[error("LLM processing error: {0}")] OpenAI(#[from] OpenAIError),
OpenAIerror(#[from] OpenAIError),
#[error("Embedding processing error: {0}")]
EmbeddingError(String),
#[error("Graph processing error: {0}")]
GraphProcessingError(String),
#[error("LLM parsing error: {0}")]
LLMParsingError(String),
#[error("Task join error: {0}")]
JoinError(#[from] JoinError),
}
#[derive(Error, Debug)]
pub enum IngressConsumerError {
#[error("RabbitMQ error: {0}")] #[error("RabbitMQ error: {0}")]
RabbitMQ(#[from] RabbitMQError), RabbitMQ(#[from] RabbitMQError),
#[error("File error: {0}")]
#[error("Processing error: {0}")] File(#[from] FileError),
Processing(#[from] ProcessingError), #[error("Email error: {0}")]
Email(#[from] EmailError),
#[error("Ingress content error: {0}")] #[error("Not found: {0}")]
IngressContent(#[from] IngressContentError), NotFound(String),
#[error("Validation error: {0}")]
Validation(String),
#[error("Authorization error: {0}")]
Auth(String),
#[error("LLM parsing error: {0}")]
LLMParsing(String),
#[error("Task join error: {0}")]
Join(#[from] JoinError),
#[error("Graph mapper error: {0}")]
GraphMapper(String),
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
#[error("Minijina error: {0}")]
MiniJinja(#[from] minijinja::Error),
} }
#[derive(Error, Debug)] // API-specific errors
#[derive(Debug, Serialize)]
pub enum ApiError { pub enum ApiError {
#[error("Processing error: {0}")] InternalError(String),
ProcessingError(#[from] ProcessingError), ValidationError(String),
#[error("Ingress content error: {0}")] NotFound(String),
IngressContentError(#[from] IngressContentError), Unauthorized(String),
#[error("Publishing error: {0}")] }
PublishingError(String),
#[error("Database error: {0}")] impl From<AppError> for ApiError {
DatabaseError(String), fn from(err: AppError) -> Self {
#[error("Query error: {0}")] match err {
QueryError(String), AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
#[error("RabbitMQ error: {0}")] tracing::error!("Internal error: {:?}", err);
RabbitMQError(#[from] RabbitMQError), ApiError::InternalError("Internal server error".to_string())
#[error("LLM processing error: {0}")] }
OpenAIerror(#[from] OpenAIError), AppError::NotFound(msg) => ApiError::NotFound(msg),
#[error("File error: {0}")] AppError::Validation(msg) => ApiError::ValidationError(msg),
FileError(#[from] FileError), AppError::Auth(msg) => ApiError::Unauthorized(msg),
#[error("SurrealDb error: {0}")] _ => ApiError::InternalError("Internal server error".to_string()),
SurrealDbError(#[from] surrealdb::Error), }
#[error("User already exists")] }
UserAlreadyExists,
#[error("User was not found")]
UserNotFound,
#[error("You must provide valid credentials")]
AuthRequired,
#[error("Templating error: {0}")]
TemplatingError(#[from] minijinja::Error),
#[error("Mail error: {0}")]
EmailError(#[from] EmailError),
} }
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> Response {
let (status, error_message) = match &self { let (status, body) = match self {
ApiError::ProcessingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::InternalError(message) => (
ApiError::SurrealDbError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), StatusCode::INTERNAL_SERVER_ERROR,
ApiError::PublishingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), json!({
ApiError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), "error": message,
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::IngressContentError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
}
ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::FileError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::TemplatingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::EmailError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
};
(
status,
Json(json!({
"error": error_message,
"status": "error" "status": "error"
})), }),
) ),
.into_response() ApiError::ValidationError(message) => (
StatusCode::BAD_REQUEST,
json!({
"error": message,
"status": "error"
}),
),
ApiError::NotFound(message) => (
StatusCode::NOT_FOUND,
json!({
"error": message,
"status": "error"
}),
),
ApiError::Unauthorized(message) => (
StatusCode::UNAUTHORIZED,
json!({
"error": message,
"status": "error"
}),
), // ... other matches
};
(status, Json(body)).into_response()
}
}
#[derive(Clone)]
pub struct ErrorContext {
#[allow(dead_code)]
templates: Arc<AutoReloader>,
}
impl ErrorContext {
pub fn new(templates: Arc<AutoReloader>) -> Self {
Self { templates }
}
}
pub enum HtmlError {
ServerError(Arc<AutoReloader>),
NotFound(Arc<AutoReloader>),
Unauthorized(Arc<AutoReloader>),
BadRequest(String, Arc<AutoReloader>),
Template(String, Arc<AutoReloader>),
}
// Implement From<ApiError> for HtmlError
impl HtmlError {
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
match error {
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
_ => {
tracing::error!("Internal error: {:?}", error);
HtmlError::ServerError(templates)
}
}
}
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
tracing::error!("Template error: {:?}", error);
HtmlError::Template(error.to_string(), templates)
}
}
impl IntoResponse for HtmlError {
fn into_response(self) -> Response {
let (status, context, templates) = match self {
HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
StatusCode::INTERNAL_SERVER_ERROR,
context! {
status_code => 500,
title => "Internal Server Error",
error => "Internal Server Error",
description => "Something went wrong on our end."
},
templates,
),
HtmlError::NotFound(templates) => (
StatusCode::NOT_FOUND,
context! {
status_code => 404,
title => "Page Not Found",
error => "Not Found",
description => "The page you're looking for doesn't exist or was removed."
},
templates,
),
HtmlError::Unauthorized(templates) => (
StatusCode::UNAUTHORIZED,
context! {
status_code => 401,
title => "Unauthorized",
error => "Access Denied",
description => "You need to be logged in to access this page."
},
templates,
),
HtmlError::BadRequest(msg, templates) => (
StatusCode::BAD_REQUEST,
context! {
status_code => 400,
title => "Bad Request",
error => "Bad Request",
description => msg
},
templates,
),
};
let html = match templates.acquire_env() {
Ok(env) => match env.get_template("errors/error.html") {
Ok(tmpl) => match tmpl.render(context) {
Ok(output) => output,
Err(e) => {
tracing::error!("Template render error: {:?}", e);
Self::fallback_html()
}
},
Err(e) => {
tracing::error!("Template get error: {:?}", e);
Self::fallback_html()
}
},
Err(e) => {
tracing::error!("Environment acquire error: {:?}", e);
Self::fallback_html()
}
};
(status, Html(html)).into_response()
}
}
impl HtmlError {
fn fallback_html() -> String {
r#"
<html>
<body>
<div class="container mx-auto p-4">
<h1 class="text-4xl text-error">Error</h1>
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
</div>
</body>
</html>
"#
.to_string()
} }
} }

View File

@@ -1,13 +1,16 @@
use crate::{ use crate::{
error::ProcessingError, error::AppError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::combined_knowledge_entity_retrieval, retrieval::combined_knowledge_entity_retrieval,
storage::types::knowledge_entity::KnowledgeEntity, storage::types::knowledge_entity::KnowledgeEntity,
}; };
use async_openai::types::{ use async_openai::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, error::OpenAIError,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, types::{
ResponseFormatJsonSchema, ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
},
}; };
use serde_json::json; use serde_json::json;
use surrealdb::engine::any::Any; use surrealdb::engine::any::Any;
@@ -38,7 +41,7 @@ impl<'a> IngressAnalyzer<'a> {
instructions: &str, instructions: &str,
text: &str, text: &str,
user_id: &str, user_id: &str,
) -> Result<LLMGraphAnalysisResult, ProcessingError> { ) -> Result<LLMGraphAnalysisResult, AppError> {
let similar_entities = self let similar_entities = self
.find_similar_entities(category, instructions, text, user_id) .find_similar_entities(category, instructions, text, user_id)
.await?; .await?;
@@ -53,7 +56,7 @@ impl<'a> IngressAnalyzer<'a> {
instructions: &str, instructions: &str,
text: &str, text: &str,
user_id: &str, user_id: &str,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> { ) -> Result<Vec<KnowledgeEntity>, AppError> {
let input_text = format!( let input_text = format!(
"content: {}, category: {}, user_instructions: {}", "content: {}, category: {}, user_instructions: {}",
text, category, instructions text, category, instructions
@@ -74,7 +77,7 @@ impl<'a> IngressAnalyzer<'a> {
instructions: &str, instructions: &str,
text: &str, text: &str,
similar_entities: &[KnowledgeEntity], similar_entities: &[KnowledgeEntity],
) -> Result<CreateChatCompletionRequest, ProcessingError> { ) -> Result<CreateChatCompletionRequest, OpenAIError> {
let entities_json = json!(similar_entities let entities_json = json!(similar_entities
.iter() .iter()
.map(|entity| { .map(|entity| {
@@ -114,13 +117,12 @@ impl<'a> IngressAnalyzer<'a> {
]) ])
.response_format(response_format) .response_format(response_format)
.build() .build()
.map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
} }
async fn perform_analysis( async fn perform_analysis(
&self, &self,
request: CreateChatCompletionRequest, request: CreateChatCompletionRequest,
) -> Result<LLMGraphAnalysisResult, ProcessingError> { ) -> Result<LLMGraphAnalysisResult, AppError> {
let response = self.openai_client.chat().create(request).await?; let response = self.openai_client.chat().create(request).await?;
debug!("Received LLM response: {:?}", response); debug!("Received LLM response: {:?}", response);
@@ -128,12 +130,12 @@ impl<'a> IngressAnalyzer<'a> {
.choices .choices
.first() .first()
.and_then(|choice| choice.message.content.as_ref()) .and_then(|choice| choice.message.content.as_ref())
.ok_or(ProcessingError::LLMParsingError( .ok_or(AppError::LLMParsing(
"No content found in LLM response".into(), "No content found in LLM response".to_string(),
)) ))
.and_then(|content| { .and_then(|content| {
serde_json::from_str(content).map_err(|e| { serde_json::from_str(content).map_err(|e| {
ProcessingError::LLMParsingError(format!( AppError::LLMParsing(format!(
"Failed to parse LLM response into analysis: {}", "Failed to parse LLM response into analysis: {}",
e e
)) ))

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use tokio::task; use tokio::task;
use crate::{ use crate::{
error::ProcessingError, error::AppError,
storage::types::{ storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship, knowledge_relationship::KnowledgeRelationship,
@@ -49,13 +49,13 @@ impl LLMGraphAnalysisResult {
/// ///
/// # Returns /// # Returns
/// ///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`. /// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
pub async fn to_database_entities( pub async fn to_database_entities(
&self, &self,
source_id: &str, source_id: &str,
user_id: &str, user_id: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> { ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
// Create mapper and pre-assign IDs // Create mapper and pre-assign IDs
let mapper = Arc::new(Mutex::new(self.create_mapper()?)); let mapper = Arc::new(Mutex::new(self.create_mapper()?));
@@ -70,7 +70,7 @@ impl LLMGraphAnalysisResult {
Ok((entities, relationships)) Ok((entities, relationships))
} }
fn create_mapper(&self) -> Result<GraphMapper, ProcessingError> { fn create_mapper(&self) -> Result<GraphMapper, AppError> {
let mut mapper = GraphMapper::new(); let mut mapper = GraphMapper::new();
// Pre-assign all IDs // Pre-assign all IDs
@@ -87,7 +87,7 @@ impl LLMGraphAnalysisResult {
user_id: &str, user_id: &str,
mapper: Arc<Mutex<GraphMapper>>, mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> { ) -> Result<Vec<KnowledgeEntity>, AppError> {
let futures: Vec<_> = self let futures: Vec<_> = self
.knowledge_entities .knowledge_entities
.iter() .iter()
@@ -116,10 +116,10 @@ impl LLMGraphAnalysisResult {
fn process_relationships( fn process_relationships(
&self, &self,
mapper: Arc<Mutex<GraphMapper>>, mapper: Arc<Mutex<GraphMapper>>,
) -> Result<Vec<KnowledgeRelationship>, ProcessingError> { ) -> Result<Vec<KnowledgeRelationship>, AppError> {
let mut mapper_guard = mapper let mut mapper_guard = mapper
.lock() .lock()
.map_err(|_| ProcessingError::GraphProcessingError("Failed to lock mapper".into()))?; .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
self.relationships self.relationships
.iter() .iter()
.map(|rel| { .map(|rel| {
@@ -142,18 +142,15 @@ async fn create_single_entity(
user_id: &str, user_id: &str,
mapper: Arc<Mutex<GraphMapper>>, mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<KnowledgeEntity, ProcessingError> { ) -> Result<KnowledgeEntity, AppError> {
let assigned_id = { let assigned_id = {
let mapper = mapper let mapper = mapper
.lock() .lock()
.map_err(|_| ProcessingError::GraphProcessingError("Failed to lock mapper".into()))?; .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
mapper mapper
.get_id(&llm_entity.key) .get_id(&llm_entity.key)
.ok_or_else(|| { .ok_or_else(|| {
ProcessingError::GraphProcessingError(format!( AppError::GraphMapper(format!("ID not found for key: {}", llm_entity.key))
"ID not found for key: {}",
llm_entity.key
))
})? })?
.to_string() .to_string()
}; };

View File

@@ -4,7 +4,7 @@ use text_splitter::TextSplitter;
use tracing::{debug, info}; use tracing::{debug, info};
use crate::{ use crate::{
error::ProcessingError, error::AppError,
storage::{ storage::{
db::{store_item, SurrealDbClient}, db::{store_item, SurrealDbClient},
types::{ types::{
@@ -25,7 +25,7 @@ pub struct ContentProcessor {
} }
impl ContentProcessor { impl ContentProcessor {
pub async fn new(app_config: &AppConfig) -> Result<Self, ProcessingError> { pub async fn new(app_config: &AppConfig) -> Result<Self, AppError> {
Ok(Self { Ok(Self {
db_client: SurrealDbClient::new( db_client: SurrealDbClient::new(
&app_config.surrealdb_address, &app_config.surrealdb_address,
@@ -39,7 +39,7 @@ impl ContentProcessor {
}) })
} }
pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> { pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
// Store original content // Store original content
store_item(&self.db_client, content.clone()).await?; store_item(&self.db_client, content.clone()).await?;
@@ -72,7 +72,7 @@ impl ContentProcessor {
async fn perform_semantic_analysis( async fn perform_semantic_analysis(
&self, &self,
content: &TextContent, content: &TextContent,
) -> Result<LLMGraphAnalysisResult, ProcessingError> { ) -> Result<LLMGraphAnalysisResult, AppError> {
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client); let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
analyser analyser
.analyze_content( .analyze_content(
@@ -88,7 +88,7 @@ impl ContentProcessor {
&self, &self,
entities: Vec<KnowledgeEntity>, entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>, relationships: Vec<KnowledgeRelationship>,
) -> Result<(), ProcessingError> { ) -> Result<(), AppError> {
for entity in &entities { for entity in &entities {
debug!("Storing entity: {:?}", entity); debug!("Storing entity: {:?}", entity);
store_item(&self.db_client, entity.clone()).await?; store_item(&self.db_client, entity.clone()).await?;
@@ -107,7 +107,7 @@ impl ContentProcessor {
Ok(()) Ok(())
} }
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> { async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
let splitter = TextSplitter::new(500..2000); let splitter = TextSplitter::new(500..2000);
let chunks = splitter.chunks(&content.text); let chunks = splitter.chunks(&content.text);

View File

@@ -1,10 +1,12 @@
use super::ingress_object::IngressObject; use super::ingress_object::IngressObject;
use crate::storage::{ use crate::{
db::{get_item, SurrealDbClient}, error::AppError,
types::file_info::FileInfo, storage::{
db::{get_item, SurrealDbClient},
types::file_info::FileInfo,
},
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::info; use tracing::info;
use url::Url; use url::Url;
@@ -17,34 +19,6 @@ pub struct IngressInput {
pub files: Option<Vec<String>>, pub files: Option<Vec<String>>,
} }
/// Error types for processing ingress content.
#[derive(Error, Debug)]
pub enum IngressContentError {
#[error("IO error occurred: {0}")]
Io(#[from] std::io::Error),
#[error("UTF-8 conversion error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
#[error("SurrealDb error: {0}")]
SurrealDbError(#[from] surrealdb::Error),
#[error("MIME type detection failed for input: {0}")]
MimeDetection(String),
#[error("Unsupported MIME type: {0}")]
UnsupportedMime(String),
#[error("URL parse error: {0}")]
UrlParse(#[from] url::ParseError),
#[error("UUID parse error: {0}")]
UuidParse(#[from] uuid::Error),
#[error("Redis error: {0}")]
RedisError(String),
}
/// Function to create ingress objects from input. /// Function to create ingress objects from input.
/// ///
/// # Arguments /// # Arguments
@@ -57,7 +31,7 @@ pub async fn create_ingress_objects(
input: IngressInput, input: IngressInput,
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
user_id: &str, user_id: &str,
) -> Result<Vec<IngressObject>, IngressContentError> { ) -> Result<Vec<IngressObject>, AppError> {
// Initialize list // Initialize list
let mut object_list = Vec::new(); let mut object_list = Vec::new();
@@ -103,7 +77,7 @@ pub async fn create_ingress_objects(
// If no objects are constructed, we return Err // If no objects are constructed, we return Err
if object_list.is_empty() { if object_list.is_empty() {
return Err(IngressContentError::MimeDetection( return Err(AppError::NotFound(
"No valid content or files provided".into(), "No valid content or files provided".into(),
)); ));
} }

View File

@@ -1,8 +1,9 @@
use crate::storage::types::{file_info::FileInfo, text_content::TextContent}; use crate::{
error::AppError,
storage::types::{file_info::FileInfo, text_content::TextContent},
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::ingress_input::IngressContentError;
/// Knowledge object type, containing the content or reference to it, as well as metadata /// Knowledge object type, containing the content or reference to it, as well as metadata
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject { pub enum IngressObject {
@@ -34,7 +35,7 @@ impl IngressObject {
/// ///
/// # Returns /// # Returns
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc. /// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
pub async fn to_text_content(&self) -> Result<TextContent, IngressContentError> { pub async fn to_text_content(&self) -> Result<TextContent, AppError> {
match self { match self {
IngressObject::Url { IngressObject::Url {
url, url,
@@ -82,12 +83,12 @@ impl IngressObject {
} }
/// Fetches and extracts text from a URL. /// Fetches and extracts text from a URL.
async fn fetch_text_from_url(_url: &str) -> Result<String, IngressContentError> { async fn fetch_text_from_url(_url: &str) -> Result<String, AppError> {
unimplemented!() unimplemented!()
} }
/// Extracts text from a file based on its MIME type. /// Extracts text from a file based on its MIME type.
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, IngressContentError> { async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> {
match file_info.mime_type.as_str() { match file_info.mime_type.as_str() {
"text/plain" => { "text/plain" => {
// Read the file and return its content // Read the file and return its content
@@ -101,15 +102,11 @@ impl IngressObject {
} }
"application/pdf" => { "application/pdf" => {
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf` // TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
Err(IngressContentError::UnsupportedMime( Err(AppError::NotFound(file_info.mime_type.clone()))
file_info.mime_type.clone(),
))
} }
"image/png" | "image/jpeg" => { "image/png" | "image/jpeg" => {
// TODO: Implement OCR on image using a crate like `tesseract` // TODO: Implement OCR on image using a crate like `tesseract`
Err(IngressContentError::UnsupportedMime( Err(AppError::NotFound(file_info.mime_type.clone()))
file_info.mime_type.clone(),
))
} }
"application/octet-stream" => { "application/octet-stream" => {
let content = tokio::fs::read_to_string(&file_info.path).await?; let content = tokio::fs::read_to_string(&file_info.path).await?;
@@ -120,9 +117,7 @@ impl IngressObject {
Ok(content) Ok(content)
} }
// Handle other MIME types as needed // Handle other MIME types as needed
_ => Err(IngressContentError::UnsupportedMime( _ => Err(AppError::NotFound(file_info.mime_type.clone())),
file_info.mime_type.clone(),
)),
} }
} }
} }

View File

@@ -9,8 +9,6 @@ use lapin::{
use thiserror::Error; use thiserror::Error;
use tracing::debug; use tracing::debug;
use crate::error::ProcessingError;
/// Possible errors related to RabbitMQ operations. /// Possible errors related to RabbitMQ operations.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum RabbitMQError { pub enum RabbitMQError {
@@ -28,8 +26,6 @@ pub enum RabbitMQError {
InitializeConsumerError(String), InitializeConsumerError(String),
#[error("Queue error: {0}")] #[error("Queue error: {0}")]
QueueError(String), QueueError(String),
#[error("Processing error: {0}")]
ProcessingError(#[from] ProcessingError),
} }
/// Struct containing the information required to set up a client and connection. /// Struct containing the information required to set up a client and connection.

View File

@@ -4,7 +4,7 @@ pub mod query_helper_prompt;
pub mod vector; pub mod vector;
use crate::{ use crate::{
error::ProcessingError, error::AppError,
retrieval::{ retrieval::{
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
vector::find_items_by_vector_similarity, vector::find_items_by_vector_similarity,
@@ -34,14 +34,14 @@ use surrealdb::{engine::any::Any, Surreal};
/// * 'user_id' - The user id of the current user /// * 'user_id' - The user id of the current user
/// ///
/// # Returns /// # Returns
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant /// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
/// knowledge entities, or an error if the retrieval process fails /// knowledge entities, or an error if the retrieval process fails
pub async fn combined_knowledge_entity_retrieval( pub async fn combined_knowledge_entity_retrieval(
db_client: &Surreal<Any>, db_client: &Surreal<Any>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str, query: &str,
user_id: &str, user_id: &str,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> { ) -> Result<Vec<KnowledgeEntity>, AppError> {
// info!("Received input: {:?}", query); // info!("Received input: {:?}", query);
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join( let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(

View File

@@ -1,14 +1,17 @@
use async_openai::types::{ use async_openai::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, error::OpenAIError,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse, types::{
ResponseFormat, ResponseFormatJsonSchema, ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
},
}; };
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::debug; use tracing::debug;
use crate::{ use crate::{
error::ApiError, error::AppError,
retrieval::combined_knowledge_entity_retrieval, retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
}; };
@@ -94,7 +97,7 @@ pub async fn get_answer_with_references(
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str, query: &str,
user_id: &str, user_id: &str,
) -> Result<Answer, ApiError> { ) -> Result<Answer, AppError> {
let entities = let entities =
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id) combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
.await?; .await?;
@@ -104,11 +107,7 @@ pub async fn get_answer_with_references(
let user_message = create_user_message(&entities_json, query); let user_message = create_user_message(&entities_json, query);
let request = create_chat_request(user_message)?; let request = create_chat_request(user_message)?;
let response = openai_client let response = openai_client.chat().create(request).await?;
.chat()
.create(request)
.await
.map_err(|e| ApiError::QueryError(e.to_string()))?;
let llm_response = process_llm_response(response).await?; let llm_response = process_llm_response(response).await?;
@@ -152,7 +151,9 @@ pub fn create_user_message(entities_json: &Value, query: &str) -> String {
) )
} }
pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionRequest, ApiError> { pub fn create_chat_request(
user_message: String,
) -> Result<CreateChatCompletionRequest, OpenAIError> {
let response_format = ResponseFormat::JsonSchema { let response_format = ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema { json_schema: ResponseFormatJsonSchema {
description: Some("Query answering AI".into()), description: Some("Query answering AI".into()),
@@ -172,22 +173,21 @@ pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionR
]) ])
.response_format(response_format) .response_format(response_format)
.build() .build()
.map_err(|e| ApiError::QueryError(e.to_string()))
} }
pub async fn process_llm_response( pub async fn process_llm_response(
response: CreateChatCompletionResponse, response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, ApiError> { ) -> Result<LLMResponseFormat, AppError> {
response response
.choices .choices
.first() .first()
.and_then(|choice| choice.message.content.as_ref()) .and_then(|choice| choice.message.content.as_ref())
.ok_or(ApiError::QueryError( .ok_or(AppError::LLMParsing(
"No content found in LLM response".into(), "No content found in LLM response".into(),
)) ))
.and_then(|content| { .and_then(|content| {
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| { serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
ApiError::QueryError(format!("Failed to parse LLM response into analysis: {}", e)) AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
}) })
}) })
} }

View File

@@ -1,6 +1,6 @@
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use crate::{error::ProcessingError, utils::embedding::generate_embedding}; use crate::{error::AppError, utils::embedding::generate_embedding};
/// Compares vectors and retrieves a number of items from the specified table. /// Compares vectors and retrieves a number of items from the specified table.
/// ///
@@ -30,7 +30,7 @@ pub async fn find_items_by_vector_similarity<T>(
table: String, table: String,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str, user_id: &str,
) -> Result<Vec<T>, ProcessingError> ) -> Result<Vec<T>, AppError>
where where
T: for<'de> serde::Deserialize<'de>, T: for<'de> serde::Deserialize<'de>,
{ {

View File

@@ -13,10 +13,14 @@ pub async fn api_auth(
mut request: Request, mut request: Request,
next: Next, next: Next,
) -> Result<Response, ApiError> { ) -> Result<Response, ApiError> {
let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?; let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
"You have to be authenticated".to_string(),
))?;
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?; let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
let user = user.ok_or(ApiError::UserNotFound)?; let user = user.ok_or(ApiError::Unauthorized(
"You have to be authenticated".to_string(),
))?;
request.extensions_mut().insert(user); request.extensions_mut().insert(user);

View File

@@ -1,4 +1,8 @@
use crate::{error::ApiError, server::AppState, storage::types::file_info::FileInfo}; use crate::{
error::{ApiError, AppError},
server::AppState,
storage::types::file_info::FileInfo,
};
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde_json::json; use serde_json::json;
@@ -21,7 +25,9 @@ pub async fn upload_handler(
info!("Received an upload request"); info!("Received an upload request");
// Process the file upload // Process the file upload
let file_info = FileInfo::new(input.file, &state.surreal_db_client).await?; let file_info = FileInfo::new(input.file, &state.surreal_db_client)
.await
.map_err(AppError::from)?;
// Prepare the response JSON // Prepare the response JSON
let response = json!({ let response = json!({

View File

@@ -1,5 +1,5 @@
use crate::{ use crate::{
error::ApiError, error::{ApiError, AppError},
ingress::types::ingress_input::{create_ingress_objects, IngressInput}, ingress::types::ingress_input::{create_ingress_objects, IngressInput},
server::AppState, server::AppState,
storage::types::user::User, storage::types::user::User,
@@ -22,7 +22,7 @@ pub async fn ingress_handler(
.map(|object| state.rabbitmq_producer.publish(object)) .map(|object| state.rabbitmq_producer.publish(object))
.collect(); .collect();
try_join_all(futures).await?; try_join_all(futures).await.map_err(AppError::from)?;
Ok(StatusCode::OK) Ok(StatusCode::OK)
} }

View File

@@ -1,21 +1,28 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse}; use axum::{extract::State, http::StatusCode, response::IntoResponse};
use minijinja::context; use tracing::info;
use tracing::{info, Instrument};
use crate::{error::ApiError, server::AppState}; use crate::{
error::{ApiError, AppError},
server::AppState,
};
pub async fn queue_length_handler( pub async fn queue_length_handler(
State(state): State<AppState>, State(state): State<AppState>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, ApiError> {
info!("Getting queue length"); info!("Getting queue length");
let queue_length = state.rabbitmq_consumer.get_queue_length().await?; let queue_length = state
.rabbitmq_consumer
.get_queue_length()
.await
.map_err(AppError::from)?;
info!("Queue length: {}", queue_length); info!("Queue length: {}", queue_length);
state state
.mailer .mailer
.send_email_verification("per@starks.cloud", "1001010", &state.templates)?; .send_email_verification("per@starks.cloud", "1001010", &state.templates)
.map_err(AppError::from)?;
// Return the queue length with a 200 OK status // Return the queue length with a 200 OK status
Ok((StatusCode::OK, queue_length.to_string())) Ok((StatusCode::OK, queue_length.to_string()))

View File

@@ -1,7 +1,7 @@
use axum::{ use axum::{
extract::State, extract::State,
http::{Response, StatusCode, Uri}, http::{StatusCode, Uri},
response::{Html, IntoResponse, Redirect}, response::{IntoResponse, Redirect},
}; };
use axum_htmx::HxRedirect; use axum_htmx::HxRedirect;
use axum_session_auth::AuthSession; use axum_session_auth::AuthSession;
@@ -9,7 +9,7 @@ use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use crate::{ use crate::{
error::ApiError, error::{AppError, HtmlError},
page_data, page_data,
server::{routes::html::render_template, AppState}, server::{routes::html::render_template, AppState},
storage::{db::delete_item, types::user::User}, storage::{db::delete_item, types::user::User},
@@ -24,7 +24,7 @@ page_data!(AccountData, "auth/account.html", {
pub async fn show_account_page( pub async fn show_account_page(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated // Early return if the user is not authenticated
let user = match auth.current_user { let user = match auth.current_user {
Some(user) => user, Some(user) => user,
@@ -34,8 +34,9 @@ pub async fn show_account_page(
let output = render_template( let output = render_template(
AccountData::template_name(), AccountData::template_name(),
AccountData { user }, AccountData { user },
state.templates, state.templates.clone(),
)?; )
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
Ok(output.into_response()) Ok(output.into_response())
} }
@@ -43,12 +44,17 @@ pub async fn show_account_page(
pub async fn set_api_key( pub async fn set_api_key(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated // Early return if the user is not authenticated
let user = auth.current_user.as_ref().ok_or(ApiError::AuthRequired)?; let user = match &auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
// Generate and set the API key // Generate and set the API key
let api_key = User::set_api_key(&user.id, &state.surreal_db_client).await?; let api_key = User::set_api_key(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
auth.cache_clear_user(user.id.to_string()); auth.cache_clear_user(user.id.to_string());
@@ -63,8 +69,9 @@ pub async fn set_api_key(
AccountData::template_name(), AccountData::template_name(),
"api_key_section", "api_key_section",
AccountData { user: updated_user }, AccountData { user: updated_user },
state.templates, state.templates.clone(),
)?; )
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
Ok(output.into_response()) Ok(output.into_response())
} }
@@ -72,11 +79,16 @@ pub async fn set_api_key(
pub async fn delete_account( pub async fn delete_account(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated // Early return if the user is not authenticated
let user = auth.current_user.as_ref().ok_or(ApiError::AuthRequired)?; let user = match &auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
delete_item::<User>(&state.surreal_db_client, &user.id).await?; delete_item::<User>(&state.surreal_db_client, &user.id)
.await
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
auth.logout_user(); auth.logout_user();

View File

@@ -1,11 +1,11 @@
use axum::{extract::State, response::Html}; use axum::{extract::State, response::IntoResponse};
use axum_session_auth::AuthSession; use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool; use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use tracing::info; use tracing::info;
use crate::{ use crate::{
error::ApiError, error::{AppError, HtmlError},
page_data, page_data,
server::{routes::html::render_template, AppState}, server::{routes::html::render_template, AppState},
storage::types::user::User, storage::types::user::User,
@@ -19,10 +19,14 @@ page_data!(IndexData, "index/index.html", {
pub async fn index_handler( pub async fn index_handler(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<Html<String>, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
info!("Displaying index page"); info!("Displaying index page");
let queue_length = state.rabbitmq_consumer.get_queue_length().await?; let queue_length = state
.rabbitmq_consumer
.get_queue_length()
.await
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
// let knowledge_entities = User::get_knowledge_entities( // let knowledge_entities = User::get_knowledge_entities(
// &auth.current_user.clone().unwrap().id, // &auth.current_user.clone().unwrap().id,
@@ -38,8 +42,9 @@ pub async fn index_handler(
queue_length, queue_length,
user: auth.current_user, user: auth.current_user,
}, },
state.templates, state.templates.clone(),
)?; )
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
Ok(output) Ok(output.into_response())
} }

View File

@@ -1,25 +1,22 @@
use axum::{ use axum::{
extract::State, extract::State,
http::{StatusCode, Uri},
response::{Html, IntoResponse, Redirect}, response::{Html, IntoResponse, Redirect},
Form,
}; };
use axum_htmx::{HxBoosted, HxRedirect};
use axum_session_auth::AuthSession; use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool; use axum_session_surreal::SessionSurrealPool;
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use serde::{Deserialize, Serialize}; use serde::Serialize;
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use tracing::info; use tracing::info;
use crate::{ use crate::{
error::ApiError, error::{AppError, HtmlError},
server::AppState, server::AppState,
storage::types::{file_info::FileInfo, user::User}, storage::types::{file_info::FileInfo, user::User},
}; };
use super::{render_block, render_template}; use super::render_template;
#[derive(Serialize)] #[derive(Serialize)]
struct PageData { struct PageData {
@@ -29,13 +26,17 @@ struct PageData {
pub async fn show_ingress_form( pub async fn show_ingress_form(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
if !auth.is_authenticated() { if !auth.is_authenticated() {
return Ok(Redirect::to("/").into_response()); return Ok(Redirect::to("/").into_response());
} }
Ok(render_template("ingress_form.html", PageData {}, state.templates)?.into_response()) let output = render_template("ingress_form.html", PageData {}, state.templates.clone())
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
Ok(output.into_response())
} }
#[derive(Debug, TryFromMultipart)] #[derive(Debug, TryFromMultipart)]
pub struct IngressParams { pub struct IngressParams {
pub content: Option<String>, pub content: Option<String>,
@@ -49,8 +50,8 @@ pub async fn process_ingress_form(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
TypedMultipart(input): TypedMultipart<IngressParams>, TypedMultipart(input): TypedMultipart<IngressParams>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user { let _user = match auth.current_user {
Some(user) => user, Some(user) => user,
None => return Ok(Redirect::to("/").into_response()), None => return Ok(Redirect::to("/").into_response()),
}; };
@@ -60,7 +61,9 @@ pub async fn process_ingress_form(
// Process files and create FileInfo objects // Process files and create FileInfo objects
let mut file_infos = Vec::new(); let mut file_infos = Vec::new();
for file in input.files { for file in input.files {
let file_info = FileInfo::new(file, &state.surreal_db_client).await?; let file_info = FileInfo::new(file, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
file_infos.push(file_info); file_infos.push(file_info);
} }

View File

@@ -1,6 +1,6 @@
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
response::Html, response::{Html, IntoResponse, Redirect},
}; };
use axum_session_auth::AuthSession; use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool; use axum_session_surreal::SessionSurrealPool;
@@ -9,7 +9,7 @@ use surrealdb::{engine::any::Any, Surreal};
use tracing::info; use tracing::info;
use crate::{ use crate::{
error::ApiError, retrieval::query_helper::get_answer_with_references, server::AppState, error::HtmlError, retrieval::query_helper::get_answer_with_references, server::AppState,
storage::types::user::User, storage::types::user::User,
}; };
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -21,30 +21,22 @@ pub async fn search_result_handler(
State(state): State<AppState>, State(state): State<AppState>,
Query(query): Query<SearchParams>, Query(query): Query<SearchParams>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<Html<String>, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
info!("Displaying search results"); info!("Displaying search results");
let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id; let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
let answer = get_answer_with_references( let answer = get_answer_with_references(
&state.surreal_db_client, &state.surreal_db_client,
&state.openai_client, &state.openai_client,
&query.query, &query.query,
&user_id, &user.id,
) )
.await?; .await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
Ok(Html(answer.content)) Ok(Html(answer.content).into_response())
// let output = state
// .tera
// .render(
// "search_result.html",
// &Context::from_value(
// json!({"result": answer.content, "references": answer.references}),
// )
// .unwrap(),
// )
// .unwrap();
// Ok(output.into())
} }

View File

@@ -9,7 +9,12 @@ use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool; use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use crate::{error::ApiError, page_data, server::AppState, storage::types::user::User}; use crate::{
error::{AppError, HtmlError},
page_data,
server::AppState,
storage::types::user::User,
};
use super::{render_block, render_template}; use super::{render_block, render_template};
@@ -26,7 +31,7 @@ pub async fn show_signin_form(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
HxBoosted(boosted): HxBoosted, HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() { if auth.is_authenticated() {
return Ok(Redirect::to("/").into_response()); return Ok(Redirect::to("/").into_response());
} }
@@ -35,13 +40,15 @@ pub async fn show_signin_form(
ShowSignInForm::template_name(), ShowSignInForm::template_name(),
"body", "body",
ShowSignInForm {}, ShowSignInForm {},
state.templates, state.templates.clone(),
)?, )
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
false => render_template( false => render_template(
ShowSignInForm::template_name(), ShowSignInForm::template_name(),
ShowSignInForm {}, ShowSignInForm {},
state.templates, state.templates.clone(),
)?, )
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
}; };
Ok(output.into_response()) Ok(output.into_response())
@@ -51,11 +58,11 @@ pub async fn authenticate_user(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<SignupParams>, Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
let user = match User::authenticate(form.email, form.password, &state.surreal_db_client).await { let user = match User::authenticate(form.email, form.password, &state.surreal_db_client).await {
Ok(user) => user, Ok(user) => user,
Err(_) => { Err(_) => {
return Ok(Html("<p>Invalid email or password.</p>").into_response()); return Ok(Html("<p>Incorrect email or password </p>").into_response());
} }
}; };

View File

@@ -10,7 +10,11 @@ use axum_session_surreal::SessionSurrealPool;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use crate::{error::ApiError, server::AppState, storage::types::user::User}; use crate::{
error::{AppError, HtmlError},
server::AppState,
storage::types::user::User,
};
use super::{render_block, render_template}; use super::{render_block, render_template};
@@ -29,7 +33,7 @@ pub async fn show_signup_form(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
HxBoosted(boosted): HxBoosted, HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() { if auth.is_authenticated() {
return Ok(Redirect::to("/").into_response()); return Ok(Redirect::to("/").into_response());
} }
@@ -38,9 +42,15 @@ pub async fn show_signup_form(
"auth/signup_form.html", "auth/signup_form.html",
"body", "body",
PageData {}, PageData {},
state.templates, state.templates.clone(),
)?, )
false => render_template("auth/signup_form.html", PageData {}, state.templates)?, .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
false => render_template(
"auth/signup_form.html",
PageData {},
state.templates.clone(),
)
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
}; };
Ok(output.into_response()) Ok(output.into_response())
@@ -50,7 +60,7 @@ pub async fn process_signup_and_show_verification(
State(state): State<AppState>, State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>, auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<SignupParams>, Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, HtmlError> {
let user = match User::create_new(form.email, form.password, &state.surreal_db_client).await { let user = match User::create_new(form.email, form.password, &state.surreal_db_client).await {
Ok(user) => user, Ok(user) => user,
Err(_) => { Err(_) => {

View File

@@ -1,4 +1,4 @@
use crate::{error::ProcessingError, stored_object}; use crate::{error::AppError, stored_object};
use surrealdb::{engine::any::Any, Surreal}; use surrealdb::{engine::any::Any, Surreal};
use tracing::debug; use tracing::debug;
use uuid::Uuid; use uuid::Uuid;
@@ -26,10 +26,7 @@ impl KnowledgeRelationship {
metadata, metadata,
} }
} }
pub async fn store_relationship( pub async fn store_relationship(&self, db_client: &Surreal<Any>) -> Result<(), AppError> {
&self,
db_client: &Surreal<Any>,
) -> Result<(), ProcessingError> {
let query = format!( let query = format!(
"RELATE knowledge_entity:`{}` -> relates_to -> knowledge_entity:`{}`", "RELATE knowledge_entity:`{}` -> relates_to -> knowledge_entity:`{}`",
self.in_, self.out self.in_, self.out

View File

@@ -1,5 +1,5 @@
use crate::{ use crate::{
error::ApiError, error::AppError,
storage::db::{get_item, SurrealDbClient}, storage::db::{get_item, SurrealDbClient},
stored_object, stored_object,
}; };
@@ -41,10 +41,10 @@ impl User {
email: String, email: String,
password: String, password: String,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<Self, ApiError> { ) -> Result<Self, AppError> {
// Check if user exists // Check if user exists
if (Self::find_by_email(&email, db).await?).is_some() { if (Self::find_by_email(&email, db).await?).is_some() {
return Err(ApiError::UserAlreadyExists); return Err(AppError::Auth("User already exists".into()));
} }
let id = Uuid::new_v4().to_string(); let id = Uuid::new_v4().to_string();
@@ -62,14 +62,14 @@ impl User {
.await? .await?
.take(0)?; .take(0)?;
user.ok_or(ApiError::UserAlreadyExists) user.ok_or(AppError::Auth("User failed to create".into()))
} }
pub async fn authenticate( pub async fn authenticate(
email: String, email: String,
password: String, password: String,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<Self, ApiError> { ) -> Result<Self, AppError> {
let user: Option<User> = db let user: Option<User> = db
.client .client
.query( .query(
@@ -81,13 +81,13 @@ impl User {
.bind(("password", password)) .bind(("password", password))
.await? .await?
.take(0)?; .take(0)?;
user.ok_or(ApiError::UserAlreadyExists) user.ok_or(AppError::Auth("User failed to authenticate".into()))
} }
pub async fn find_by_email( pub async fn find_by_email(
email: &str, email: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<Option<Self>, ApiError> { ) -> Result<Option<Self>, AppError> {
let user: Option<User> = db let user: Option<User> = db
.client .client
.query("SELECT * FROM user WHERE email = $email LIMIT 1") .query("SELECT * FROM user WHERE email = $email LIMIT 1")
@@ -101,7 +101,7 @@ impl User {
pub async fn find_by_api_key( pub async fn find_by_api_key(
api_key: &str, api_key: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<Option<Self>, ApiError> { ) -> Result<Option<Self>, AppError> {
let user: Option<User> = db let user: Option<User> = db
.client .client
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1") .query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
@@ -112,7 +112,7 @@ impl User {
Ok(user) Ok(user)
} }
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, ApiError> { pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
// Generate a secure random API key // Generate a secure random API key
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", "")); let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
@@ -133,11 +133,11 @@ impl User {
if user.is_some() { if user.is_some() {
Ok(api_key) Ok(api_key)
} else { } else {
Err(ApiError::UserNotFound) Err(AppError::Auth("User not found".into()))
} }
} }
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), ApiError> { pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
let user: Option<User> = db let user: Option<User> = db
.client .client
.query( .query(
@@ -152,14 +152,14 @@ impl User {
if user.is_some() { if user.is_some() {
Ok(()) Ok(())
} else { } else {
Err(ApiError::UserNotFound) Err(AppError::Auth("User was not found".into()))
} }
} }
pub async fn get_knowledge_entities( pub async fn get_knowledge_entities(
id: &str, id: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<Vec<KnowledgeEntity>, ApiError> { ) -> Result<Vec<KnowledgeEntity>, AppError> {
let entities: Vec<KnowledgeEntity> = db let entities: Vec<KnowledgeEntity> = db
.client .client
.query("SELECT * FROM knowledge_entity WHERE user_id = $user_id") .query("SELECT * FROM knowledge_entity WHERE user_id = $user_id")

View File

@@ -1,7 +1,6 @@
use async_openai::types::CreateEmbeddingRequestArgs; use async_openai::types::CreateEmbeddingRequestArgs;
use crate::error::ProcessingError; use crate::error::AppError;
/// Generates an embedding vector for the given input text using OpenAI's embedding model. /// Generates an embedding vector for the given input text using OpenAI's embedding model.
/// ///
/// This function takes a text input and converts it into a numerical vector representation (embedding) /// This function takes a text input and converts it into a numerical vector representation (embedding)
@@ -21,14 +20,14 @@ use crate::error::ProcessingError;
/// ///
/// # Errors /// # Errors
/// ///
/// This function can return a `ProcessingError` in the following cases: /// This function can return a `AppError` in the following cases:
/// * If the OpenAI API request fails /// * If the OpenAI API request fails
/// * If the request building fails /// * If the request building fails
/// * If no embedding data is received in the response /// * If no embedding data is received in the response
pub async fn generate_embedding( pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>, client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str, input: &str,
) -> Result<Vec<f32>, ProcessingError> { ) -> Result<Vec<f32>, AppError> {
let request = CreateEmbeddingRequestArgs::default() let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small") .model("text-embedding-3-small")
.input([input]) .input([input])
@@ -41,7 +40,7 @@ pub async fn generate_embedding(
let embedding: Vec<f32> = response let embedding: Vec<f32> = response
.data .data
.first() .first()
.ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))? .ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
.embedding .embedding
.clone(); .clone();

View File

@@ -0,0 +1,14 @@
{% extends "body_base.html" %}
{% block main %}
<div class="container mx-auto px-4 flex items-center justify-center">
<div class="flex flex-col space-y-4 text-center">
<h1 class="text-2xl font-bold text-error">
{{ status_code }}
</h1>
<p class="text-2xl my-4">{{ error }}</p>
<p class="text-base-content/60">{{ description }}</p>
<a href="/" class="btn btn-primary mt-8">Go Home</a>
</div>
</div>
{% endblock %}