mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-24 17:58:31 +02:00
feat: refactored error handling
This commit is contained in:
File diff suppressed because one or more lines are too long
300
src/error.rs
300
src/error.rs
@@ -1,107 +1,233 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_openai::error::OpenAIError;
|
use async_openai::error::OpenAIError;
|
||||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{Html, IntoResponse, Response},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use minijinja::context;
|
||||||
|
use minijinja_autoreload::AutoReloader;
|
||||||
|
use serde::Serialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::task::JoinError;
|
use tokio::task::JoinError;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
ingress::types::ingress_input::IngressContentError, rabbitmq::RabbitMQError,
|
rabbitmq::RabbitMQError, storage::types::file_info::FileError, utils::mailer::EmailError,
|
||||||
storage::types::file_info::FileError, utils::mailer::EmailError,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Core internal errors
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ProcessingError {
|
pub enum AppError {
|
||||||
#[error("SurrealDb error: {0}")]
|
#[error("Database error: {0}")]
|
||||||
SurrealDbError(#[from] surrealdb::Error),
|
Database(#[from] surrealdb::Error),
|
||||||
|
#[error("OpenAI error: {0}")]
|
||||||
#[error("LLM processing error: {0}")]
|
OpenAI(#[from] OpenAIError),
|
||||||
OpenAIerror(#[from] OpenAIError),
|
|
||||||
|
|
||||||
#[error("Embedding processing error: {0}")]
|
|
||||||
EmbeddingError(String),
|
|
||||||
|
|
||||||
#[error("Graph processing error: {0}")]
|
|
||||||
GraphProcessingError(String),
|
|
||||||
|
|
||||||
#[error("LLM parsing error: {0}")]
|
|
||||||
LLMParsingError(String),
|
|
||||||
|
|
||||||
#[error("Task join error: {0}")]
|
|
||||||
JoinError(#[from] JoinError),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum IngressConsumerError {
|
|
||||||
#[error("RabbitMQ error: {0}")]
|
#[error("RabbitMQ error: {0}")]
|
||||||
RabbitMQ(#[from] RabbitMQError),
|
RabbitMQ(#[from] RabbitMQError),
|
||||||
|
#[error("File error: {0}")]
|
||||||
#[error("Processing error: {0}")]
|
File(#[from] FileError),
|
||||||
Processing(#[from] ProcessingError),
|
#[error("Email error: {0}")]
|
||||||
|
Email(#[from] EmailError),
|
||||||
#[error("Ingress content error: {0}")]
|
#[error("Not found: {0}")]
|
||||||
IngressContent(#[from] IngressContentError),
|
NotFound(String),
|
||||||
|
#[error("Validation error: {0}")]
|
||||||
|
Validation(String),
|
||||||
|
#[error("Authorization error: {0}")]
|
||||||
|
Auth(String),
|
||||||
|
#[error("LLM parsing error: {0}")]
|
||||||
|
LLMParsing(String),
|
||||||
|
#[error("Task join error: {0}")]
|
||||||
|
Join(#[from] JoinError),
|
||||||
|
#[error("Graph mapper error: {0}")]
|
||||||
|
GraphMapper(String),
|
||||||
|
#[error("IoError: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
#[error("Minijina error: {0}")]
|
||||||
|
MiniJinja(#[from] minijinja::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
// API-specific errors
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
pub enum ApiError {
|
pub enum ApiError {
|
||||||
#[error("Processing error: {0}")]
|
InternalError(String),
|
||||||
ProcessingError(#[from] ProcessingError),
|
ValidationError(String),
|
||||||
#[error("Ingress content error: {0}")]
|
NotFound(String),
|
||||||
IngressContentError(#[from] IngressContentError),
|
Unauthorized(String),
|
||||||
#[error("Publishing error: {0}")]
|
}
|
||||||
PublishingError(String),
|
|
||||||
#[error("Database error: {0}")]
|
impl From<AppError> for ApiError {
|
||||||
DatabaseError(String),
|
fn from(err: AppError) -> Self {
|
||||||
#[error("Query error: {0}")]
|
match err {
|
||||||
QueryError(String),
|
AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
|
||||||
#[error("RabbitMQ error: {0}")]
|
tracing::error!("Internal error: {:?}", err);
|
||||||
RabbitMQError(#[from] RabbitMQError),
|
ApiError::InternalError("Internal server error".to_string())
|
||||||
#[error("LLM processing error: {0}")]
|
}
|
||||||
OpenAIerror(#[from] OpenAIError),
|
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
||||||
#[error("File error: {0}")]
|
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
||||||
FileError(#[from] FileError),
|
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
||||||
#[error("SurrealDb error: {0}")]
|
_ => ApiError::InternalError("Internal server error".to_string()),
|
||||||
SurrealDbError(#[from] surrealdb::Error),
|
}
|
||||||
#[error("User already exists")]
|
}
|
||||||
UserAlreadyExists,
|
|
||||||
#[error("User was not found")]
|
|
||||||
UserNotFound,
|
|
||||||
#[error("You must provide valid credentials")]
|
|
||||||
AuthRequired,
|
|
||||||
#[error("Templating error: {0}")]
|
|
||||||
TemplatingError(#[from] minijinja::Error),
|
|
||||||
#[error("Mail error: {0}")]
|
|
||||||
EmailError(#[from] EmailError),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for ApiError {
|
impl IntoResponse for ApiError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> Response {
|
||||||
let (status, error_message) = match &self {
|
let (status, body) = match self {
|
||||||
ApiError::ProcessingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
ApiError::InternalError(message) => (
|
||||||
ApiError::SurrealDbError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
ApiError::PublishingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
json!({
|
||||||
ApiError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
"error": message,
|
||||||
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
|
||||||
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
|
||||||
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
|
|
||||||
ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()),
|
|
||||||
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
|
|
||||||
ApiError::IngressContentError(_) => {
|
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
|
|
||||||
}
|
|
||||||
ApiError::RabbitMQError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
|
||||||
ApiError::FileError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
|
||||||
ApiError::TemplatingError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
|
||||||
ApiError::EmailError(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
|
||||||
};
|
|
||||||
|
|
||||||
(
|
|
||||||
status,
|
|
||||||
Json(json!({
|
|
||||||
"error": error_message,
|
|
||||||
"status": "error"
|
"status": "error"
|
||||||
})),
|
}),
|
||||||
)
|
),
|
||||||
.into_response()
|
ApiError::ValidationError(message) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
json!({
|
||||||
|
"error": message,
|
||||||
|
"status": "error"
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
ApiError::NotFound(message) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
json!({
|
||||||
|
"error": message,
|
||||||
|
"status": "error"
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
ApiError::Unauthorized(message) => (
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
json!({
|
||||||
|
"error": message,
|
||||||
|
"status": "error"
|
||||||
|
}),
|
||||||
|
), // ... other matches
|
||||||
|
};
|
||||||
|
(status, Json(body)).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ErrorContext {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
templates: Arc<AutoReloader>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorContext {
|
||||||
|
pub fn new(templates: Arc<AutoReloader>) -> Self {
|
||||||
|
Self { templates }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum HtmlError {
|
||||||
|
ServerError(Arc<AutoReloader>),
|
||||||
|
NotFound(Arc<AutoReloader>),
|
||||||
|
Unauthorized(Arc<AutoReloader>),
|
||||||
|
BadRequest(String, Arc<AutoReloader>),
|
||||||
|
Template(String, Arc<AutoReloader>),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement From<ApiError> for HtmlError
|
||||||
|
impl HtmlError {
|
||||||
|
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
|
||||||
|
match error {
|
||||||
|
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
|
||||||
|
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
|
||||||
|
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
|
||||||
|
_ => {
|
||||||
|
tracing::error!("Internal error: {:?}", error);
|
||||||
|
HtmlError::ServerError(templates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
|
||||||
|
tracing::error!("Template error: {:?}", error);
|
||||||
|
HtmlError::Template(error.to_string(), templates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for HtmlError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
let (status, context, templates) = match self {
|
||||||
|
HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
context! {
|
||||||
|
status_code => 500,
|
||||||
|
title => "Internal Server Error",
|
||||||
|
error => "Internal Server Error",
|
||||||
|
description => "Something went wrong on our end."
|
||||||
|
},
|
||||||
|
templates,
|
||||||
|
),
|
||||||
|
HtmlError::NotFound(templates) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
context! {
|
||||||
|
status_code => 404,
|
||||||
|
title => "Page Not Found",
|
||||||
|
error => "Not Found",
|
||||||
|
description => "The page you're looking for doesn't exist or was removed."
|
||||||
|
},
|
||||||
|
templates,
|
||||||
|
),
|
||||||
|
HtmlError::Unauthorized(templates) => (
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
context! {
|
||||||
|
status_code => 401,
|
||||||
|
title => "Unauthorized",
|
||||||
|
error => "Access Denied",
|
||||||
|
description => "You need to be logged in to access this page."
|
||||||
|
},
|
||||||
|
templates,
|
||||||
|
),
|
||||||
|
HtmlError::BadRequest(msg, templates) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
context! {
|
||||||
|
status_code => 400,
|
||||||
|
title => "Bad Request",
|
||||||
|
error => "Bad Request",
|
||||||
|
description => msg
|
||||||
|
},
|
||||||
|
templates,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let html = match templates.acquire_env() {
|
||||||
|
Ok(env) => match env.get_template("errors/error.html") {
|
||||||
|
Ok(tmpl) => match tmpl.render(context) {
|
||||||
|
Ok(output) => output,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Template render error: {:?}", e);
|
||||||
|
Self::fallback_html()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Template get error: {:?}", e);
|
||||||
|
Self::fallback_html()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Environment acquire error: {:?}", e);
|
||||||
|
Self::fallback_html()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
(status, Html(html)).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HtmlError {
|
||||||
|
fn fallback_html() -> String {
|
||||||
|
r#"
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
<div class="container mx-auto p-4">
|
||||||
|
<h1 class="text-4xl text-error">Error</h1>
|
||||||
|
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"#
|
||||||
|
.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
error::ProcessingError,
|
error::AppError,
|
||||||
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||||
retrieval::combined_knowledge_entity_retrieval,
|
retrieval::combined_knowledge_entity_retrieval,
|
||||||
storage::types::knowledge_entity::KnowledgeEntity,
|
storage::types::knowledge_entity::KnowledgeEntity,
|
||||||
};
|
};
|
||||||
use async_openai::types::{
|
use async_openai::{
|
||||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
error::OpenAIError,
|
||||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
types::{
|
||||||
ResponseFormatJsonSchema,
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||||
|
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||||
|
ResponseFormatJsonSchema,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use surrealdb::engine::any::Any;
|
use surrealdb::engine::any::Any;
|
||||||
@@ -38,7 +41,7 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
instructions: &str,
|
instructions: &str,
|
||||||
text: &str,
|
text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||||
let similar_entities = self
|
let similar_entities = self
|
||||||
.find_similar_entities(category, instructions, text, user_id)
|
.find_similar_entities(category, instructions, text, user_id)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -53,7 +56,7 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
instructions: &str,
|
instructions: &str,
|
||||||
text: &str,
|
text: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||||
let input_text = format!(
|
let input_text = format!(
|
||||||
"content: {}, category: {}, user_instructions: {}",
|
"content: {}, category: {}, user_instructions: {}",
|
||||||
text, category, instructions
|
text, category, instructions
|
||||||
@@ -74,7 +77,7 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
instructions: &str,
|
instructions: &str,
|
||||||
text: &str,
|
text: &str,
|
||||||
similar_entities: &[KnowledgeEntity],
|
similar_entities: &[KnowledgeEntity],
|
||||||
) -> Result<CreateChatCompletionRequest, ProcessingError> {
|
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||||
let entities_json = json!(similar_entities
|
let entities_json = json!(similar_entities
|
||||||
.iter()
|
.iter()
|
||||||
.map(|entity| {
|
.map(|entity| {
|
||||||
@@ -114,13 +117,12 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
])
|
])
|
||||||
.response_format(response_format)
|
.response_format(response_format)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| ProcessingError::LLMParsingError(e.to_string()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn perform_analysis(
|
async fn perform_analysis(
|
||||||
&self,
|
&self,
|
||||||
request: CreateChatCompletionRequest,
|
request: CreateChatCompletionRequest,
|
||||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||||
let response = self.openai_client.chat().create(request).await?;
|
let response = self.openai_client.chat().create(request).await?;
|
||||||
debug!("Received LLM response: {:?}", response);
|
debug!("Received LLM response: {:?}", response);
|
||||||
|
|
||||||
@@ -128,12 +130,12 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
.choices
|
.choices
|
||||||
.first()
|
.first()
|
||||||
.and_then(|choice| choice.message.content.as_ref())
|
.and_then(|choice| choice.message.content.as_ref())
|
||||||
.ok_or(ProcessingError::LLMParsingError(
|
.ok_or(AppError::LLMParsing(
|
||||||
"No content found in LLM response".into(),
|
"No content found in LLM response".to_string(),
|
||||||
))
|
))
|
||||||
.and_then(|content| {
|
.and_then(|content| {
|
||||||
serde_json::from_str(content).map_err(|e| {
|
serde_json::from_str(content).map_err(|e| {
|
||||||
ProcessingError::LLMParsingError(format!(
|
AppError::LLMParsing(format!(
|
||||||
"Failed to parse LLM response into analysis: {}",
|
"Failed to parse LLM response into analysis: {}",
|
||||||
e
|
e
|
||||||
))
|
))
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use tokio::task;
|
use tokio::task;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ProcessingError,
|
error::AppError,
|
||||||
storage::types::{
|
storage::types::{
|
||||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||||
knowledge_relationship::KnowledgeRelationship,
|
knowledge_relationship::KnowledgeRelationship,
|
||||||
@@ -49,13 +49,13 @@ impl LLMGraphAnalysisResult {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
|
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
|
||||||
pub async fn to_database_entities(
|
pub async fn to_database_entities(
|
||||||
&self,
|
&self,
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
|
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
|
||||||
// Create mapper and pre-assign IDs
|
// Create mapper and pre-assign IDs
|
||||||
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
|
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ impl LLMGraphAnalysisResult {
|
|||||||
Ok((entities, relationships))
|
Ok((entities, relationships))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_mapper(&self) -> Result<GraphMapper, ProcessingError> {
|
fn create_mapper(&self) -> Result<GraphMapper, AppError> {
|
||||||
let mut mapper = GraphMapper::new();
|
let mut mapper = GraphMapper::new();
|
||||||
|
|
||||||
// Pre-assign all IDs
|
// Pre-assign all IDs
|
||||||
@@ -87,7 +87,7 @@ impl LLMGraphAnalysisResult {
|
|||||||
user_id: &str,
|
user_id: &str,
|
||||||
mapper: Arc<Mutex<GraphMapper>>,
|
mapper: Arc<Mutex<GraphMapper>>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.knowledge_entities
|
.knowledge_entities
|
||||||
.iter()
|
.iter()
|
||||||
@@ -116,10 +116,10 @@ impl LLMGraphAnalysisResult {
|
|||||||
fn process_relationships(
|
fn process_relationships(
|
||||||
&self,
|
&self,
|
||||||
mapper: Arc<Mutex<GraphMapper>>,
|
mapper: Arc<Mutex<GraphMapper>>,
|
||||||
) -> Result<Vec<KnowledgeRelationship>, ProcessingError> {
|
) -> Result<Vec<KnowledgeRelationship>, AppError> {
|
||||||
let mut mapper_guard = mapper
|
let mut mapper_guard = mapper
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| ProcessingError::GraphProcessingError("Failed to lock mapper".into()))?;
|
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||||
self.relationships
|
self.relationships
|
||||||
.iter()
|
.iter()
|
||||||
.map(|rel| {
|
.map(|rel| {
|
||||||
@@ -142,18 +142,15 @@ async fn create_single_entity(
|
|||||||
user_id: &str,
|
user_id: &str,
|
||||||
mapper: Arc<Mutex<GraphMapper>>,
|
mapper: Arc<Mutex<GraphMapper>>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
) -> Result<KnowledgeEntity, ProcessingError> {
|
) -> Result<KnowledgeEntity, AppError> {
|
||||||
let assigned_id = {
|
let assigned_id = {
|
||||||
let mapper = mapper
|
let mapper = mapper
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| ProcessingError::GraphProcessingError("Failed to lock mapper".into()))?;
|
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||||
mapper
|
mapper
|
||||||
.get_id(&llm_entity.key)
|
.get_id(&llm_entity.key)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
ProcessingError::GraphProcessingError(format!(
|
AppError::GraphMapper(format!("ID not found for key: {}", llm_entity.key))
|
||||||
"ID not found for key: {}",
|
|
||||||
llm_entity.key
|
|
||||||
))
|
|
||||||
})?
|
})?
|
||||||
.to_string()
|
.to_string()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use text_splitter::TextSplitter;
|
|||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ProcessingError,
|
error::AppError,
|
||||||
storage::{
|
storage::{
|
||||||
db::{store_item, SurrealDbClient},
|
db::{store_item, SurrealDbClient},
|
||||||
types::{
|
types::{
|
||||||
@@ -25,7 +25,7 @@ pub struct ContentProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ContentProcessor {
|
impl ContentProcessor {
|
||||||
pub async fn new(app_config: &AppConfig) -> Result<Self, ProcessingError> {
|
pub async fn new(app_config: &AppConfig) -> Result<Self, AppError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
db_client: SurrealDbClient::new(
|
db_client: SurrealDbClient::new(
|
||||||
&app_config.surrealdb_address,
|
&app_config.surrealdb_address,
|
||||||
@@ -39,7 +39,7 @@ impl ContentProcessor {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn process(&self, content: &TextContent) -> Result<(), ProcessingError> {
|
pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
|
||||||
// Store original content
|
// Store original content
|
||||||
store_item(&self.db_client, content.clone()).await?;
|
store_item(&self.db_client, content.clone()).await?;
|
||||||
|
|
||||||
@@ -72,7 +72,7 @@ impl ContentProcessor {
|
|||||||
async fn perform_semantic_analysis(
|
async fn perform_semantic_analysis(
|
||||||
&self,
|
&self,
|
||||||
content: &TextContent,
|
content: &TextContent,
|
||||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||||
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
||||||
analyser
|
analyser
|
||||||
.analyze_content(
|
.analyze_content(
|
||||||
@@ -88,7 +88,7 @@ impl ContentProcessor {
|
|||||||
&self,
|
&self,
|
||||||
entities: Vec<KnowledgeEntity>,
|
entities: Vec<KnowledgeEntity>,
|
||||||
relationships: Vec<KnowledgeRelationship>,
|
relationships: Vec<KnowledgeRelationship>,
|
||||||
) -> Result<(), ProcessingError> {
|
) -> Result<(), AppError> {
|
||||||
for entity in &entities {
|
for entity in &entities {
|
||||||
debug!("Storing entity: {:?}", entity);
|
debug!("Storing entity: {:?}", entity);
|
||||||
store_item(&self.db_client, entity.clone()).await?;
|
store_item(&self.db_client, entity.clone()).await?;
|
||||||
@@ -107,7 +107,7 @@ impl ContentProcessor {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), ProcessingError> {
|
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
|
||||||
let splitter = TextSplitter::new(500..2000);
|
let splitter = TextSplitter::new(500..2000);
|
||||||
let chunks = splitter.chunks(&content.text);
|
let chunks = splitter.chunks(&content.text);
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
use super::ingress_object::IngressObject;
|
use super::ingress_object::IngressObject;
|
||||||
use crate::storage::{
|
use crate::{
|
||||||
db::{get_item, SurrealDbClient},
|
error::AppError,
|
||||||
types::file_info::FileInfo,
|
storage::{
|
||||||
|
db::{get_item, SurrealDbClient},
|
||||||
|
types::file_info::FileInfo,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
@@ -17,34 +19,6 @@ pub struct IngressInput {
|
|||||||
pub files: Option<Vec<String>>,
|
pub files: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error types for processing ingress content.
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum IngressContentError {
|
|
||||||
#[error("IO error occurred: {0}")]
|
|
||||||
Io(#[from] std::io::Error),
|
|
||||||
|
|
||||||
#[error("UTF-8 conversion error: {0}")]
|
|
||||||
Utf8(#[from] std::string::FromUtf8Error),
|
|
||||||
|
|
||||||
#[error("SurrealDb error: {0}")]
|
|
||||||
SurrealDbError(#[from] surrealdb::Error),
|
|
||||||
|
|
||||||
#[error("MIME type detection failed for input: {0}")]
|
|
||||||
MimeDetection(String),
|
|
||||||
|
|
||||||
#[error("Unsupported MIME type: {0}")]
|
|
||||||
UnsupportedMime(String),
|
|
||||||
|
|
||||||
#[error("URL parse error: {0}")]
|
|
||||||
UrlParse(#[from] url::ParseError),
|
|
||||||
|
|
||||||
#[error("UUID parse error: {0}")]
|
|
||||||
UuidParse(#[from] uuid::Error),
|
|
||||||
|
|
||||||
#[error("Redis error: {0}")]
|
|
||||||
RedisError(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Function to create ingress objects from input.
|
/// Function to create ingress objects from input.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -57,7 +31,7 @@ pub async fn create_ingress_objects(
|
|||||||
input: IngressInput,
|
input: IngressInput,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<IngressObject>, IngressContentError> {
|
) -> Result<Vec<IngressObject>, AppError> {
|
||||||
// Initialize list
|
// Initialize list
|
||||||
let mut object_list = Vec::new();
|
let mut object_list = Vec::new();
|
||||||
|
|
||||||
@@ -103,7 +77,7 @@ pub async fn create_ingress_objects(
|
|||||||
|
|
||||||
// If no objects are constructed, we return Err
|
// If no objects are constructed, we return Err
|
||||||
if object_list.is_empty() {
|
if object_list.is_empty() {
|
||||||
return Err(IngressContentError::MimeDetection(
|
return Err(AppError::NotFound(
|
||||||
"No valid content or files provided".into(),
|
"No valid content or files provided".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
use crate::storage::types::{file_info::FileInfo, text_content::TextContent};
|
use crate::{
|
||||||
|
error::AppError,
|
||||||
|
storage::types::{file_info::FileInfo, text_content::TextContent},
|
||||||
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use super::ingress_input::IngressContentError;
|
|
||||||
|
|
||||||
/// Knowledge object type, containing the content or reference to it, as well as metadata
|
/// Knowledge object type, containing the content or reference to it, as well as metadata
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub enum IngressObject {
|
pub enum IngressObject {
|
||||||
@@ -34,7 +35,7 @@ impl IngressObject {
|
|||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
|
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
|
||||||
pub async fn to_text_content(&self) -> Result<TextContent, IngressContentError> {
|
pub async fn to_text_content(&self) -> Result<TextContent, AppError> {
|
||||||
match self {
|
match self {
|
||||||
IngressObject::Url {
|
IngressObject::Url {
|
||||||
url,
|
url,
|
||||||
@@ -82,12 +83,12 @@ impl IngressObject {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Fetches and extracts text from a URL.
|
/// Fetches and extracts text from a URL.
|
||||||
async fn fetch_text_from_url(_url: &str) -> Result<String, IngressContentError> {
|
async fn fetch_text_from_url(_url: &str) -> Result<String, AppError> {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts text from a file based on its MIME type.
|
/// Extracts text from a file based on its MIME type.
|
||||||
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, IngressContentError> {
|
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> {
|
||||||
match file_info.mime_type.as_str() {
|
match file_info.mime_type.as_str() {
|
||||||
"text/plain" => {
|
"text/plain" => {
|
||||||
// Read the file and return its content
|
// Read the file and return its content
|
||||||
@@ -101,15 +102,11 @@ impl IngressObject {
|
|||||||
}
|
}
|
||||||
"application/pdf" => {
|
"application/pdf" => {
|
||||||
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
|
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
|
||||||
Err(IngressContentError::UnsupportedMime(
|
Err(AppError::NotFound(file_info.mime_type.clone()))
|
||||||
file_info.mime_type.clone(),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
"image/png" | "image/jpeg" => {
|
"image/png" | "image/jpeg" => {
|
||||||
// TODO: Implement OCR on image using a crate like `tesseract`
|
// TODO: Implement OCR on image using a crate like `tesseract`
|
||||||
Err(IngressContentError::UnsupportedMime(
|
Err(AppError::NotFound(file_info.mime_type.clone()))
|
||||||
file_info.mime_type.clone(),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
"application/octet-stream" => {
|
"application/octet-stream" => {
|
||||||
let content = tokio::fs::read_to_string(&file_info.path).await?;
|
let content = tokio::fs::read_to_string(&file_info.path).await?;
|
||||||
@@ -120,9 +117,7 @@ impl IngressObject {
|
|||||||
Ok(content)
|
Ok(content)
|
||||||
}
|
}
|
||||||
// Handle other MIME types as needed
|
// Handle other MIME types as needed
|
||||||
_ => Err(IngressContentError::UnsupportedMime(
|
_ => Err(AppError::NotFound(file_info.mime_type.clone())),
|
||||||
file_info.mime_type.clone(),
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ use lapin::{
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
use crate::error::ProcessingError;
|
|
||||||
|
|
||||||
/// Possible errors related to RabbitMQ operations.
|
/// Possible errors related to RabbitMQ operations.
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum RabbitMQError {
|
pub enum RabbitMQError {
|
||||||
@@ -28,8 +26,6 @@ pub enum RabbitMQError {
|
|||||||
InitializeConsumerError(String),
|
InitializeConsumerError(String),
|
||||||
#[error("Queue error: {0}")]
|
#[error("Queue error: {0}")]
|
||||||
QueueError(String),
|
QueueError(String),
|
||||||
#[error("Processing error: {0}")]
|
|
||||||
ProcessingError(#[from] ProcessingError),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Struct containing the information required to set up a client and connection.
|
/// Struct containing the information required to set up a client and connection.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ pub mod query_helper_prompt;
|
|||||||
pub mod vector;
|
pub mod vector;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ProcessingError,
|
error::AppError,
|
||||||
retrieval::{
|
retrieval::{
|
||||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||||
vector::find_items_by_vector_similarity,
|
vector::find_items_by_vector_similarity,
|
||||||
@@ -34,14 +34,14 @@ use surrealdb::{engine::any::Any, Surreal};
|
|||||||
/// * 'user_id' - The user id of the current user
|
/// * 'user_id' - The user id of the current user
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant
|
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||||
/// knowledge entities, or an error if the retrieval process fails
|
/// knowledge entities, or an error if the retrieval process fails
|
||||||
pub async fn combined_knowledge_entity_retrieval(
|
pub async fn combined_knowledge_entity_retrieval(
|
||||||
db_client: &Surreal<Any>,
|
db_client: &Surreal<Any>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
query: &str,
|
query: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||||
// info!("Received input: {:?}", query);
|
// info!("Received input: {:?}", query);
|
||||||
|
|
||||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||||
|
|||||||
@@ -1,14 +1,17 @@
|
|||||||
use async_openai::types::{
|
use async_openai::{
|
||||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
error::OpenAIError,
|
||||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
types::{
|
||||||
ResponseFormat, ResponseFormatJsonSchema,
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||||
|
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||||
|
ResponseFormat, ResponseFormatJsonSchema,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::AppError,
|
||||||
retrieval::combined_knowledge_entity_retrieval,
|
retrieval::combined_knowledge_entity_retrieval,
|
||||||
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
|
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
|
||||||
};
|
};
|
||||||
@@ -94,7 +97,7 @@ pub async fn get_answer_with_references(
|
|||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
query: &str,
|
query: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Answer, ApiError> {
|
) -> Result<Answer, AppError> {
|
||||||
let entities =
|
let entities =
|
||||||
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
|
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -104,11 +107,7 @@ pub async fn get_answer_with_references(
|
|||||||
let user_message = create_user_message(&entities_json, query);
|
let user_message = create_user_message(&entities_json, query);
|
||||||
|
|
||||||
let request = create_chat_request(user_message)?;
|
let request = create_chat_request(user_message)?;
|
||||||
let response = openai_client
|
let response = openai_client.chat().create(request).await?;
|
||||||
.chat()
|
|
||||||
.create(request)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ApiError::QueryError(e.to_string()))?;
|
|
||||||
|
|
||||||
let llm_response = process_llm_response(response).await?;
|
let llm_response = process_llm_response(response).await?;
|
||||||
|
|
||||||
@@ -152,7 +151,9 @@ pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionRequest, ApiError> {
|
pub fn create_chat_request(
|
||||||
|
user_message: String,
|
||||||
|
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||||
let response_format = ResponseFormat::JsonSchema {
|
let response_format = ResponseFormat::JsonSchema {
|
||||||
json_schema: ResponseFormatJsonSchema {
|
json_schema: ResponseFormatJsonSchema {
|
||||||
description: Some("Query answering AI".into()),
|
description: Some("Query answering AI".into()),
|
||||||
@@ -172,22 +173,21 @@ pub fn create_chat_request(user_message: String) -> Result<CreateChatCompletionR
|
|||||||
])
|
])
|
||||||
.response_format(response_format)
|
.response_format(response_format)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| ApiError::QueryError(e.to_string()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn process_llm_response(
|
pub async fn process_llm_response(
|
||||||
response: CreateChatCompletionResponse,
|
response: CreateChatCompletionResponse,
|
||||||
) -> Result<LLMResponseFormat, ApiError> {
|
) -> Result<LLMResponseFormat, AppError> {
|
||||||
response
|
response
|
||||||
.choices
|
.choices
|
||||||
.first()
|
.first()
|
||||||
.and_then(|choice| choice.message.content.as_ref())
|
.and_then(|choice| choice.message.content.as_ref())
|
||||||
.ok_or(ApiError::QueryError(
|
.ok_or(AppError::LLMParsing(
|
||||||
"No content found in LLM response".into(),
|
"No content found in LLM response".into(),
|
||||||
))
|
))
|
||||||
.and_then(|content| {
|
.and_then(|content| {
|
||||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||||
ApiError::QueryError(format!("Failed to parse LLM response into analysis: {}", e))
|
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
|
|
||||||
use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
use crate::{error::AppError, utils::embedding::generate_embedding};
|
||||||
|
|
||||||
/// Compares vectors and retrieves a number of items from the specified table.
|
/// Compares vectors and retrieves a number of items from the specified table.
|
||||||
///
|
///
|
||||||
@@ -30,7 +30,7 @@ pub async fn find_items_by_vector_similarity<T>(
|
|||||||
table: String,
|
table: String,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<T>, ProcessingError>
|
) -> Result<Vec<T>, AppError>
|
||||||
where
|
where
|
||||||
T: for<'de> serde::Deserialize<'de>,
|
T: for<'de> serde::Deserialize<'de>,
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -13,10 +13,14 @@ pub async fn api_auth(
|
|||||||
mut request: Request,
|
mut request: Request,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> Result<Response, ApiError> {
|
) -> Result<Response, ApiError> {
|
||||||
let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?;
|
let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
|
||||||
|
"You have to be authenticated".to_string(),
|
||||||
|
))?;
|
||||||
|
|
||||||
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
|
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
|
||||||
let user = user.ok_or(ApiError::UserNotFound)?;
|
let user = user.ok_or(ApiError::Unauthorized(
|
||||||
|
"You have to be authenticated".to_string(),
|
||||||
|
))?;
|
||||||
|
|
||||||
request.extensions_mut().insert(user);
|
request.extensions_mut().insert(user);
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
use crate::{error::ApiError, server::AppState, storage::types::file_info::FileInfo};
|
use crate::{
|
||||||
|
error::{ApiError, AppError},
|
||||||
|
server::AppState,
|
||||||
|
storage::types::file_info::FileInfo,
|
||||||
|
};
|
||||||
use axum::{extract::State, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
@@ -21,7 +25,9 @@ pub async fn upload_handler(
|
|||||||
info!("Received an upload request");
|
info!("Received an upload request");
|
||||||
|
|
||||||
// Process the file upload
|
// Process the file upload
|
||||||
let file_info = FileInfo::new(input.file, &state.surreal_db_client).await?;
|
let file_info = FileInfo::new(input.file, &state.surreal_db_client)
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
// Prepare the response JSON
|
// Prepare the response JSON
|
||||||
let response = json!({
|
let response = json!({
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::{ApiError, AppError},
|
||||||
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
||||||
server::AppState,
|
server::AppState,
|
||||||
storage::types::user::User,
|
storage::types::user::User,
|
||||||
@@ -22,7 +22,7 @@ pub async fn ingress_handler(
|
|||||||
.map(|object| state.rabbitmq_producer.publish(object))
|
.map(|object| state.rabbitmq_producer.publish(object))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
try_join_all(futures).await?;
|
try_join_all(futures).await.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(StatusCode::OK)
|
Ok(StatusCode::OK)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,28 @@
|
|||||||
use axum::{extract::State, http::StatusCode, response::IntoResponse};
|
use axum::{extract::State, http::StatusCode, response::IntoResponse};
|
||||||
use minijinja::context;
|
use tracing::info;
|
||||||
use tracing::{info, Instrument};
|
|
||||||
|
|
||||||
use crate::{error::ApiError, server::AppState};
|
use crate::{
|
||||||
|
error::{ApiError, AppError},
|
||||||
|
server::AppState,
|
||||||
|
};
|
||||||
|
|
||||||
pub async fn queue_length_handler(
|
pub async fn queue_length_handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
info!("Getting queue length");
|
info!("Getting queue length");
|
||||||
|
|
||||||
let queue_length = state.rabbitmq_consumer.get_queue_length().await?;
|
let queue_length = state
|
||||||
|
.rabbitmq_consumer
|
||||||
|
.get_queue_length()
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
info!("Queue length: {}", queue_length);
|
info!("Queue length: {}", queue_length);
|
||||||
|
|
||||||
state
|
state
|
||||||
.mailer
|
.mailer
|
||||||
.send_email_verification("per@starks.cloud", "1001010", &state.templates)?;
|
.send_email_verification("per@starks.cloud", "1001010", &state.templates)
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
// Return the queue length with a 200 OK status
|
// Return the queue length with a 200 OK status
|
||||||
Ok((StatusCode::OK, queue_length.to_string()))
|
Ok((StatusCode::OK, queue_length.to_string()))
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::State,
|
extract::State,
|
||||||
http::{Response, StatusCode, Uri},
|
http::{StatusCode, Uri},
|
||||||
response::{Html, IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use axum_htmx::HxRedirect;
|
use axum_htmx::HxRedirect;
|
||||||
use axum_session_auth::AuthSession;
|
use axum_session_auth::AuthSession;
|
||||||
@@ -9,7 +9,7 @@ use axum_session_surreal::SessionSurrealPool;
|
|||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::{AppError, HtmlError},
|
||||||
page_data,
|
page_data,
|
||||||
server::{routes::html::render_template, AppState},
|
server::{routes::html::render_template, AppState},
|
||||||
storage::{db::delete_item, types::user::User},
|
storage::{db::delete_item, types::user::User},
|
||||||
@@ -24,7 +24,7 @@ page_data!(AccountData, "auth/account.html", {
|
|||||||
pub async fn show_account_page(
|
pub async fn show_account_page(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
// Early return if the user is not authenticated
|
// Early return if the user is not authenticated
|
||||||
let user = match auth.current_user {
|
let user = match auth.current_user {
|
||||||
Some(user) => user,
|
Some(user) => user,
|
||||||
@@ -34,8 +34,9 @@ pub async fn show_account_page(
|
|||||||
let output = render_template(
|
let output = render_template(
|
||||||
AccountData::template_name(),
|
AccountData::template_name(),
|
||||||
AccountData { user },
|
AccountData { user },
|
||||||
state.templates,
|
state.templates.clone(),
|
||||||
)?;
|
)
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
|
|
||||||
Ok(output.into_response())
|
Ok(output.into_response())
|
||||||
}
|
}
|
||||||
@@ -43,12 +44,17 @@ pub async fn show_account_page(
|
|||||||
pub async fn set_api_key(
|
pub async fn set_api_key(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
// Early return if the user is not authenticated
|
// Early return if the user is not authenticated
|
||||||
let user = auth.current_user.as_ref().ok_or(ApiError::AuthRequired)?;
|
let user = match &auth.current_user {
|
||||||
|
Some(user) => user,
|
||||||
|
None => return Ok(Redirect::to("/").into_response()),
|
||||||
|
};
|
||||||
|
|
||||||
// Generate and set the API key
|
// Generate and set the API key
|
||||||
let api_key = User::set_api_key(&user.id, &state.surreal_db_client).await?;
|
let api_key = User::set_api_key(&user.id, &state.surreal_db_client)
|
||||||
|
.await
|
||||||
|
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||||
|
|
||||||
auth.cache_clear_user(user.id.to_string());
|
auth.cache_clear_user(user.id.to_string());
|
||||||
|
|
||||||
@@ -63,8 +69,9 @@ pub async fn set_api_key(
|
|||||||
AccountData::template_name(),
|
AccountData::template_name(),
|
||||||
"api_key_section",
|
"api_key_section",
|
||||||
AccountData { user: updated_user },
|
AccountData { user: updated_user },
|
||||||
state.templates,
|
state.templates.clone(),
|
||||||
)?;
|
)
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
|
|
||||||
Ok(output.into_response())
|
Ok(output.into_response())
|
||||||
}
|
}
|
||||||
@@ -72,11 +79,16 @@ pub async fn set_api_key(
|
|||||||
pub async fn delete_account(
|
pub async fn delete_account(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
// Early return if the user is not authenticated
|
// Early return if the user is not authenticated
|
||||||
let user = auth.current_user.as_ref().ok_or(ApiError::AuthRequired)?;
|
let user = match &auth.current_user {
|
||||||
|
Some(user) => user,
|
||||||
|
None => return Ok(Redirect::to("/").into_response()),
|
||||||
|
};
|
||||||
|
|
||||||
delete_item::<User>(&state.surreal_db_client, &user.id).await?;
|
delete_item::<User>(&state.surreal_db_client, &user.id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
|
|
||||||
auth.logout_user();
|
auth.logout_user();
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use axum::{extract::State, response::Html};
|
use axum::{extract::State, response::IntoResponse};
|
||||||
use axum_session_auth::AuthSession;
|
use axum_session_auth::AuthSession;
|
||||||
use axum_session_surreal::SessionSurrealPool;
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::{AppError, HtmlError},
|
||||||
page_data,
|
page_data,
|
||||||
server::{routes::html::render_template, AppState},
|
server::{routes::html::render_template, AppState},
|
||||||
storage::types::user::User,
|
storage::types::user::User,
|
||||||
@@ -19,10 +19,14 @@ page_data!(IndexData, "index/index.html", {
|
|||||||
pub async fn index_handler(
|
pub async fn index_handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<Html<String>, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
info!("Displaying index page");
|
info!("Displaying index page");
|
||||||
|
|
||||||
let queue_length = state.rabbitmq_consumer.get_queue_length().await?;
|
let queue_length = state
|
||||||
|
.rabbitmq_consumer
|
||||||
|
.get_queue_length()
|
||||||
|
.await
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
|
|
||||||
// let knowledge_entities = User::get_knowledge_entities(
|
// let knowledge_entities = User::get_knowledge_entities(
|
||||||
// &auth.current_user.clone().unwrap().id,
|
// &auth.current_user.clone().unwrap().id,
|
||||||
@@ -38,8 +42,9 @@ pub async fn index_handler(
|
|||||||
queue_length,
|
queue_length,
|
||||||
user: auth.current_user,
|
user: auth.current_user,
|
||||||
},
|
},
|
||||||
state.templates,
|
state.templates.clone(),
|
||||||
)?;
|
)
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
|
|
||||||
Ok(output)
|
Ok(output.into_response())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,22 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::State,
|
extract::State,
|
||||||
http::{StatusCode, Uri},
|
|
||||||
response::{Html, IntoResponse, Redirect},
|
response::{Html, IntoResponse, Redirect},
|
||||||
Form,
|
|
||||||
};
|
};
|
||||||
use axum_htmx::{HxBoosted, HxRedirect};
|
|
||||||
use axum_session_auth::AuthSession;
|
use axum_session_auth::AuthSession;
|
||||||
use axum_session_surreal::SessionSurrealPool;
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::Serialize;
|
||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::{AppError, HtmlError},
|
||||||
server::AppState,
|
server::AppState,
|
||||||
storage::types::{file_info::FileInfo, user::User},
|
storage::types::{file_info::FileInfo, user::User},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{render_block, render_template};
|
use super::render_template;
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct PageData {
|
struct PageData {
|
||||||
@@ -29,13 +26,17 @@ struct PageData {
|
|||||||
pub async fn show_ingress_form(
|
pub async fn show_ingress_form(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
if !auth.is_authenticated() {
|
if !auth.is_authenticated() {
|
||||||
return Ok(Redirect::to("/").into_response());
|
return Ok(Redirect::to("/").into_response());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(render_template("ingress_form.html", PageData {}, state.templates)?.into_response())
|
let output = render_template("ingress_form.html", PageData {}, state.templates.clone())
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
|
|
||||||
|
Ok(output.into_response())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, TryFromMultipart)]
|
#[derive(Debug, TryFromMultipart)]
|
||||||
pub struct IngressParams {
|
pub struct IngressParams {
|
||||||
pub content: Option<String>,
|
pub content: Option<String>,
|
||||||
@@ -49,8 +50,8 @@ pub async fn process_ingress_form(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
TypedMultipart(input): TypedMultipart<IngressParams>,
|
TypedMultipart(input): TypedMultipart<IngressParams>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
let user = match auth.current_user {
|
let _user = match auth.current_user {
|
||||||
Some(user) => user,
|
Some(user) => user,
|
||||||
None => return Ok(Redirect::to("/").into_response()),
|
None => return Ok(Redirect::to("/").into_response()),
|
||||||
};
|
};
|
||||||
@@ -60,7 +61,9 @@ pub async fn process_ingress_form(
|
|||||||
// Process files and create FileInfo objects
|
// Process files and create FileInfo objects
|
||||||
let mut file_infos = Vec::new();
|
let mut file_infos = Vec::new();
|
||||||
for file in input.files {
|
for file in input.files {
|
||||||
let file_info = FileInfo::new(file, &state.surreal_db_client).await?;
|
let file_info = FileInfo::new(file, &state.surreal_db_client)
|
||||||
|
.await
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||||
file_infos.push(file_info);
|
file_infos.push(file_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::{Query, State},
|
extract::{Query, State},
|
||||||
response::Html,
|
response::{Html, IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use axum_session_auth::AuthSession;
|
use axum_session_auth::AuthSession;
|
||||||
use axum_session_surreal::SessionSurrealPool;
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
@@ -9,7 +9,7 @@ use surrealdb::{engine::any::Any, Surreal};
|
|||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError, retrieval::query_helper::get_answer_with_references, server::AppState,
|
error::HtmlError, retrieval::query_helper::get_answer_with_references, server::AppState,
|
||||||
storage::types::user::User,
|
storage::types::user::User,
|
||||||
};
|
};
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -21,30 +21,22 @@ pub async fn search_result_handler(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Query(query): Query<SearchParams>,
|
Query(query): Query<SearchParams>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<Html<String>, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
info!("Displaying search results");
|
info!("Displaying search results");
|
||||||
|
|
||||||
let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id;
|
let user = match auth.current_user {
|
||||||
|
Some(user) => user,
|
||||||
|
None => return Ok(Redirect::to("/").into_response()),
|
||||||
|
};
|
||||||
|
|
||||||
let answer = get_answer_with_references(
|
let answer = get_answer_with_references(
|
||||||
&state.surreal_db_client,
|
&state.surreal_db_client,
|
||||||
&state.openai_client,
|
&state.openai_client,
|
||||||
&query.query,
|
&query.query,
|
||||||
&user_id,
|
&user.id,
|
||||||
)
|
)
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||||
|
|
||||||
Ok(Html(answer.content))
|
Ok(Html(answer.content).into_response())
|
||||||
// let output = state
|
|
||||||
// .tera
|
|
||||||
// .render(
|
|
||||||
// "search_result.html",
|
|
||||||
// &Context::from_value(
|
|
||||||
// json!({"result": answer.content, "references": answer.references}),
|
|
||||||
// )
|
|
||||||
// .unwrap(),
|
|
||||||
// )
|
|
||||||
// .unwrap();
|
|
||||||
|
|
||||||
// Ok(output.into())
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,12 @@ use axum_session_auth::AuthSession;
|
|||||||
use axum_session_surreal::SessionSurrealPool;
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
|
|
||||||
use crate::{error::ApiError, page_data, server::AppState, storage::types::user::User};
|
use crate::{
|
||||||
|
error::{AppError, HtmlError},
|
||||||
|
page_data,
|
||||||
|
server::AppState,
|
||||||
|
storage::types::user::User,
|
||||||
|
};
|
||||||
|
|
||||||
use super::{render_block, render_template};
|
use super::{render_block, render_template};
|
||||||
|
|
||||||
@@ -26,7 +31,7 @@ pub async fn show_signin_form(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
HxBoosted(boosted): HxBoosted,
|
HxBoosted(boosted): HxBoosted,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
if auth.is_authenticated() {
|
if auth.is_authenticated() {
|
||||||
return Ok(Redirect::to("/").into_response());
|
return Ok(Redirect::to("/").into_response());
|
||||||
}
|
}
|
||||||
@@ -35,13 +40,15 @@ pub async fn show_signin_form(
|
|||||||
ShowSignInForm::template_name(),
|
ShowSignInForm::template_name(),
|
||||||
"body",
|
"body",
|
||||||
ShowSignInForm {},
|
ShowSignInForm {},
|
||||||
state.templates,
|
state.templates.clone(),
|
||||||
)?,
|
)
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
|
||||||
false => render_template(
|
false => render_template(
|
||||||
ShowSignInForm::template_name(),
|
ShowSignInForm::template_name(),
|
||||||
ShowSignInForm {},
|
ShowSignInForm {},
|
||||||
state.templates,
|
state.templates.clone(),
|
||||||
)?,
|
)
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(output.into_response())
|
Ok(output.into_response())
|
||||||
@@ -51,11 +58,11 @@ pub async fn authenticate_user(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
Form(form): Form<SignupParams>,
|
Form(form): Form<SignupParams>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
let user = match User::authenticate(form.email, form.password, &state.surreal_db_client).await {
|
let user = match User::authenticate(form.email, form.password, &state.surreal_db_client).await {
|
||||||
Ok(user) => user,
|
Ok(user) => user,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
return Ok(Html("<p>Invalid email or password.</p>").into_response());
|
return Ok(Html("<p>Incorrect email or password </p>").into_response());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,11 @@ use axum_session_surreal::SessionSurrealPool;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
|
|
||||||
use crate::{error::ApiError, server::AppState, storage::types::user::User};
|
use crate::{
|
||||||
|
error::{AppError, HtmlError},
|
||||||
|
server::AppState,
|
||||||
|
storage::types::user::User,
|
||||||
|
};
|
||||||
|
|
||||||
use super::{render_block, render_template};
|
use super::{render_block, render_template};
|
||||||
|
|
||||||
@@ -29,7 +33,7 @@ pub async fn show_signup_form(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
HxBoosted(boosted): HxBoosted,
|
HxBoosted(boosted): HxBoosted,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
if auth.is_authenticated() {
|
if auth.is_authenticated() {
|
||||||
return Ok(Redirect::to("/").into_response());
|
return Ok(Redirect::to("/").into_response());
|
||||||
}
|
}
|
||||||
@@ -38,9 +42,15 @@ pub async fn show_signup_form(
|
|||||||
"auth/signup_form.html",
|
"auth/signup_form.html",
|
||||||
"body",
|
"body",
|
||||||
PageData {},
|
PageData {},
|
||||||
state.templates,
|
state.templates.clone(),
|
||||||
)?,
|
)
|
||||||
false => render_template("auth/signup_form.html", PageData {}, state.templates)?,
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
|
||||||
|
false => render_template(
|
||||||
|
"auth/signup_form.html",
|
||||||
|
PageData {},
|
||||||
|
state.templates.clone(),
|
||||||
|
)
|
||||||
|
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(output.into_response())
|
Ok(output.into_response())
|
||||||
@@ -50,7 +60,7 @@ pub async fn process_signup_and_show_verification(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
Form(form): Form<SignupParams>,
|
Form(form): Form<SignupParams>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
let user = match User::create_new(form.email, form.password, &state.surreal_db_client).await {
|
let user = match User::create_new(form.email, form.password, &state.surreal_db_client).await {
|
||||||
Ok(user) => user,
|
Ok(user) => user,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use crate::{error::ProcessingError, stored_object};
|
use crate::{error::AppError, stored_object};
|
||||||
use surrealdb::{engine::any::Any, Surreal};
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -26,10 +26,7 @@ impl KnowledgeRelationship {
|
|||||||
metadata,
|
metadata,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub async fn store_relationship(
|
pub async fn store_relationship(&self, db_client: &Surreal<Any>) -> Result<(), AppError> {
|
||||||
&self,
|
|
||||||
db_client: &Surreal<Any>,
|
|
||||||
) -> Result<(), ProcessingError> {
|
|
||||||
let query = format!(
|
let query = format!(
|
||||||
"RELATE knowledge_entity:`{}` -> relates_to -> knowledge_entity:`{}`",
|
"RELATE knowledge_entity:`{}` -> relates_to -> knowledge_entity:`{}`",
|
||||||
self.in_, self.out
|
self.in_, self.out
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::AppError,
|
||||||
storage::db::{get_item, SurrealDbClient},
|
storage::db::{get_item, SurrealDbClient},
|
||||||
stored_object,
|
stored_object,
|
||||||
};
|
};
|
||||||
@@ -41,10 +41,10 @@ impl User {
|
|||||||
email: String,
|
email: String,
|
||||||
password: String,
|
password: String,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<Self, ApiError> {
|
) -> Result<Self, AppError> {
|
||||||
// Check if user exists
|
// Check if user exists
|
||||||
if (Self::find_by_email(&email, db).await?).is_some() {
|
if (Self::find_by_email(&email, db).await?).is_some() {
|
||||||
return Err(ApiError::UserAlreadyExists);
|
return Err(AppError::Auth("User already exists".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
@@ -62,14 +62,14 @@ impl User {
|
|||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
user.ok_or(ApiError::UserAlreadyExists)
|
user.ok_or(AppError::Auth("User failed to create".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn authenticate(
|
pub async fn authenticate(
|
||||||
email: String,
|
email: String,
|
||||||
password: String,
|
password: String,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<Self, ApiError> {
|
) -> Result<Self, AppError> {
|
||||||
let user: Option<User> = db
|
let user: Option<User> = db
|
||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
@@ -81,13 +81,13 @@ impl User {
|
|||||||
.bind(("password", password))
|
.bind(("password", password))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
user.ok_or(ApiError::UserAlreadyExists)
|
user.ok_or(AppError::Auth("User failed to authenticate".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn find_by_email(
|
pub async fn find_by_email(
|
||||||
email: &str,
|
email: &str,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<Option<Self>, ApiError> {
|
) -> Result<Option<Self>, AppError> {
|
||||||
let user: Option<User> = db
|
let user: Option<User> = db
|
||||||
.client
|
.client
|
||||||
.query("SELECT * FROM user WHERE email = $email LIMIT 1")
|
.query("SELECT * FROM user WHERE email = $email LIMIT 1")
|
||||||
@@ -101,7 +101,7 @@ impl User {
|
|||||||
pub async fn find_by_api_key(
|
pub async fn find_by_api_key(
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<Option<Self>, ApiError> {
|
) -> Result<Option<Self>, AppError> {
|
||||||
let user: Option<User> = db
|
let user: Option<User> = db
|
||||||
.client
|
.client
|
||||||
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
|
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
|
||||||
@@ -112,7 +112,7 @@ impl User {
|
|||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, ApiError> {
|
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
|
||||||
// Generate a secure random API key
|
// Generate a secure random API key
|
||||||
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
|
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
|
||||||
|
|
||||||
@@ -133,11 +133,11 @@ impl User {
|
|||||||
if user.is_some() {
|
if user.is_some() {
|
||||||
Ok(api_key)
|
Ok(api_key)
|
||||||
} else {
|
} else {
|
||||||
Err(ApiError::UserNotFound)
|
Err(AppError::Auth("User not found".into()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), ApiError> {
|
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
let user: Option<User> = db
|
let user: Option<User> = db
|
||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
@@ -152,14 +152,14 @@ impl User {
|
|||||||
if user.is_some() {
|
if user.is_some() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
Err(ApiError::UserNotFound)
|
Err(AppError::Auth("User was not found".into()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_knowledge_entities(
|
pub async fn get_knowledge_entities(
|
||||||
id: &str,
|
id: &str,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ApiError> {
|
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||||
let entities: Vec<KnowledgeEntity> = db
|
let entities: Vec<KnowledgeEntity> = db
|
||||||
.client
|
.client
|
||||||
.query("SELECT * FROM knowledge_entity WHERE user_id = $user_id")
|
.query("SELECT * FROM knowledge_entity WHERE user_id = $user_id")
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
use async_openai::types::CreateEmbeddingRequestArgs;
|
use async_openai::types::CreateEmbeddingRequestArgs;
|
||||||
|
|
||||||
use crate::error::ProcessingError;
|
use crate::error::AppError;
|
||||||
|
|
||||||
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
||||||
///
|
///
|
||||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||||
@@ -21,14 +20,14 @@ use crate::error::ProcessingError;
|
|||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// This function can return a `ProcessingError` in the following cases:
|
/// This function can return a `AppError` in the following cases:
|
||||||
/// * If the OpenAI API request fails
|
/// * If the OpenAI API request fails
|
||||||
/// * If the request building fails
|
/// * If the request building fails
|
||||||
/// * If no embedding data is received in the response
|
/// * If no embedding data is received in the response
|
||||||
pub async fn generate_embedding(
|
pub async fn generate_embedding(
|
||||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
input: &str,
|
input: &str,
|
||||||
) -> Result<Vec<f32>, ProcessingError> {
|
) -> Result<Vec<f32>, AppError> {
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model("text-embedding-3-small")
|
.model("text-embedding-3-small")
|
||||||
.input([input])
|
.input([input])
|
||||||
@@ -41,7 +40,7 @@ pub async fn generate_embedding(
|
|||||||
let embedding: Vec<f32> = response
|
let embedding: Vec<f32> = response
|
||||||
.data
|
.data
|
||||||
.first()
|
.first()
|
||||||
.ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))?
|
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
|
||||||
.embedding
|
.embedding
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
|
|||||||
14
templates/errors/error.html
Normal file
14
templates/errors/error.html
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{% extends "body_base.html" %}
|
||||||
|
|
||||||
|
{% block main %}
|
||||||
|
<div class="container mx-auto px-4 flex items-center justify-center">
|
||||||
|
<div class="flex flex-col space-y-4 text-center">
|
||||||
|
<h1 class="text-2xl font-bold text-error">
|
||||||
|
{{ status_code }}
|
||||||
|
</h1>
|
||||||
|
<p class="text-2xl my-4">{{ error }}</p>
|
||||||
|
<p class="text-base-content/60">{{ description }}</p>
|
||||||
|
<a href="/" class="btn btn-primary mt-8">Go Home</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
Reference in New Issue
Block a user