mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-18 15:34:16 +01:00
refactor: better separation of dependencies to crates
node stuff to html crate only
This commit is contained in:
31
api-router/src/api_state.rs
Normal file
31
api-router/src/api_state.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{storage::db::SurrealDbClient, utils::config::AppConfig};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiState {
|
||||
pub db: Arc<SurrealDbClient>,
|
||||
}
|
||||
|
||||
impl ApiState {
|
||||
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
&config.surrealdb_namespace,
|
||||
&config.surrealdb_database,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
surreal_db_client.ensure_initialized().await?;
|
||||
|
||||
let app_state = ApiState {
|
||||
db: surreal_db_client.clone(),
|
||||
};
|
||||
|
||||
Ok(app_state)
|
||||
}
|
||||
}
|
||||
168
api-router/src/error.rs
Normal file
168
api-router/src/error.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use common::error::AppError;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Serialize, Clone)]
|
||||
pub enum ApiError {
|
||||
#[error("Internal server error")]
|
||||
InternalError(String),
|
||||
|
||||
#[error("Validation error: {0}")]
|
||||
ValidationError(String),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
}
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
fn from(err: AppError) -> Self {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
||||
tracing::error!("Internal error: {:?}", err);
|
||||
ApiError::InternalError("Internal server error".to_string())
|
||||
}
|
||||
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
||||
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
||||
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
||||
_ => ApiError::InternalError("Internal server error".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_response) = match self {
|
||||
ApiError::InternalError(message) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::ValidationError(message) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::NotFound(message) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::Unauthorized(message) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
(status, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
status: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::error::AppError;
|
||||
use std::fmt::Debug;
|
||||
|
||||
// Helper to check status code
|
||||
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
|
||||
let response = response.into_response();
|
||||
assert_eq!(response.status(), expected_status);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_app_error_to_api_error_conversion() {
|
||||
// Test NotFound error conversion
|
||||
let not_found = AppError::NotFound("resource not found".to_string());
|
||||
let api_error = ApiError::from(not_found);
|
||||
assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found"));
|
||||
|
||||
// Test Validation error conversion
|
||||
let validation = AppError::Validation("invalid input".to_string());
|
||||
let api_error = ApiError::from(validation);
|
||||
assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input"));
|
||||
|
||||
// Test Auth error conversion
|
||||
let auth = AppError::Auth("unauthorized".to_string());
|
||||
let api_error = ApiError::from(auth);
|
||||
assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized"));
|
||||
|
||||
// Test for internal errors - create a mock error that doesn't require surrealdb
|
||||
let internal_error =
|
||||
AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error"));
|
||||
let api_error = ApiError::from(internal_error);
|
||||
assert!(matches!(api_error, ApiError::InternalError(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_error_response_status_codes() {
|
||||
// Test internal error status
|
||||
let error = ApiError::InternalError("server error".to_string());
|
||||
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
// Test not found status
|
||||
let error = ApiError::NotFound("not found".to_string());
|
||||
assert_status_code(error, StatusCode::NOT_FOUND);
|
||||
|
||||
// Test validation error status
|
||||
let error = ApiError::ValidationError("invalid input".to_string());
|
||||
assert_status_code(error, StatusCode::BAD_REQUEST);
|
||||
|
||||
// Test unauthorized status
|
||||
let error = ApiError::Unauthorized("not allowed".to_string());
|
||||
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// Alternative approach that doesn't try to parse the response body
|
||||
#[test]
|
||||
fn test_error_messages() {
|
||||
// For validation errors
|
||||
let message = "invalid data format";
|
||||
let error = ApiError::ValidationError(message.to_string());
|
||||
|
||||
// Check that the error itself contains the message
|
||||
assert_eq!(error.to_string(), format!("Validation error: {}", message));
|
||||
|
||||
// For not found errors
|
||||
let message = "user not found";
|
||||
let error = ApiError::NotFound(message.to_string());
|
||||
assert_eq!(error.to_string(), format!("Not found: {}", message));
|
||||
}
|
||||
|
||||
// Alternative approach for internal error test
|
||||
#[test]
|
||||
fn test_internal_error_sanitization() {
|
||||
// Create a sensitive error message
|
||||
let sensitive_info = "db password incorrect";
|
||||
|
||||
// Create ApiError with sensitive info
|
||||
let api_error = ApiError::InternalError(sensitive_info.to_string());
|
||||
|
||||
// Check the error message is correctly set
|
||||
assert_eq!(api_error.to_string(), "Internal server error");
|
||||
|
||||
// Also verify correct status code
|
||||
assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
26
api-router/src/lib.rs
Normal file
26
api-router/src/lib.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use api_state::ApiState;
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, FromRef},
|
||||
middleware::from_fn_with_state,
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::ingress::ingest_data;
|
||||
|
||||
pub mod api_state;
|
||||
pub mod error;
|
||||
mod middleware_api_auth;
|
||||
mod routes;
|
||||
|
||||
/// Router for API functionality, version 1
|
||||
pub fn api_routes_v1<S>(app_state: &ApiState) -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
ApiState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/ingress", post(ingest_data))
|
||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth))
|
||||
}
|
||||
43
api-router/src/middleware_api_auth.rs
Normal file
43
api-router/src/middleware_api_auth.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
use common::storage::types::user::User;
|
||||
|
||||
use crate::{api_state::ApiState, error::ApiError};
|
||||
|
||||
pub async fn api_auth(
|
||||
State(state): State<ApiState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
||||
let user = user.ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
fn extract_api_key(request: &Request) -> Option<String> {
|
||||
request
|
||||
.headers()
|
||||
.get("X-API-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.or_else(|| {
|
||||
request
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(|s| s.trim()))
|
||||
})
|
||||
.map(String::from)
|
||||
}
|
||||
59
api-router/src/routes/ingress.rs
Normal file
59
api-router/src/routes/ingress.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
||||
user::User,
|
||||
},
|
||||
};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{api_state::ApiState, error::ApiError};
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct IngestParams {
|
||||
pub content: Option<String>,
|
||||
pub instructions: String,
|
||||
pub category: String,
|
||||
#[form_data(limit = "10000000")] // Adjust limit as needed
|
||||
#[form_data(default)]
|
||||
pub files: Vec<FieldData<NamedTempFile>>,
|
||||
}
|
||||
|
||||
pub async fn ingest_data(
|
||||
State(state): State<ApiState>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedMultipart(input): TypedMultipart<IngestParams>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
info!("Received input: {:?}", input);
|
||||
|
||||
let file_infos = try_join_all(
|
||||
input
|
||||
.files
|
||||
.into_iter()
|
||||
.map(|file| FileInfo::new(file, &state.db, &user.id).map_err(AppError::from)),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let payloads = IngestionPayload::create_ingestion_payload(
|
||||
input.content,
|
||||
input.instructions,
|
||||
input.category,
|
||||
file_infos,
|
||||
user.id.as_str(),
|
||||
)?;
|
||||
|
||||
let futures: Vec<_> = payloads
|
||||
.into_iter()
|
||||
.map(|object| {
|
||||
IngestionTask::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)
|
||||
})
|
||||
.collect();
|
||||
|
||||
try_join_all(futures).await.map_err(AppError::from)?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
1
api-router/src/routes/mod.rs
Normal file
1
api-router/src/routes/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod ingress;
|
||||
Reference in New Issue
Block a user