mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-31 03:40:38 +02:00
chore: harden api-router errors and add router integration tests while slimming html handlers.
This commit is contained in:
@@ -11,31 +11,3 @@ pub struct ApiState {
|
||||
pub config: AppConfig,
|
||||
pub storage: StorageManager,
|
||||
}
|
||||
|
||||
impl ApiState {
|
||||
pub async fn new(
|
||||
config: &AppConfig,
|
||||
storage: StorageManager,
|
||||
) -> anyhow::Result<Self> {
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
&config.surrealdb_namespace,
|
||||
&config.surrealdb_database,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
surreal_db_client.apply_migrations().await?;
|
||||
|
||||
let app_state = Self {
|
||||
db: Arc::clone(&surreal_db_client),
|
||||
config: config.clone(),
|
||||
storage,
|
||||
};
|
||||
|
||||
Ok(app_state)
|
||||
}
|
||||
}
|
||||
|
||||
+19
-9
@@ -7,7 +7,7 @@ use common::error::AppError;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Serialize, Clone)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ApiErr {
|
||||
#[error("internal server error")]
|
||||
InternalError(String),
|
||||
@@ -28,14 +28,13 @@ pub enum ApiErr {
|
||||
impl From<AppError> for ApiErr {
|
||||
fn from(err: AppError) -> Self {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
||||
tracing::error!("Internal error: {:?}", err);
|
||||
Self::InternalError("Internal server error".to_string())
|
||||
}
|
||||
AppError::NotFound(msg) => Self::NotFound(msg),
|
||||
AppError::Validation(msg) => Self::ValidationError(msg),
|
||||
AppError::Auth(msg) => Self::Unauthorized(msg),
|
||||
_ => Self::InternalError("Internal server error".to_string()),
|
||||
other => {
|
||||
tracing::error!("internal API error: {other:?}");
|
||||
Self::InternalError("Internal server error".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,10 +119,21 @@ mod tests {
|
||||
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
|
||||
|
||||
// Test for internal errors - create a mock error that doesn't require surrealdb
|
||||
let internal_error =
|
||||
AppError::Io(io::Error::other("io error"));
|
||||
let internal_error = AppError::Io(io::Error::other("io error"));
|
||||
let api_error = ApiErr::from(internal_error);
|
||||
assert!(matches!(api_error, ApiErr::InternalError(_)));
|
||||
assert!(matches!(
|
||||
api_error,
|
||||
ApiErr::InternalError(msg) if msg == "Internal server error"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_app_error_internal_error_is_sanitized() {
|
||||
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
|
||||
assert!(matches!(
|
||||
api_error,
|
||||
ApiErr::InternalError(msg) if msg == "Internal server error"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -6,7 +6,7 @@ use axum::{
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::{categories::list, ingest::ingest_data, liveness::live, readiness::ready};
|
||||
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
|
||||
|
||||
pub mod api_state;
|
||||
pub mod error;
|
||||
@@ -28,7 +28,7 @@ where
|
||||
let protected = Router::new()
|
||||
.route(
|
||||
"/ingest",
|
||||
post(ingest_data).layer(DefaultBodyLimit::max(
|
||||
post(handle).layer(DefaultBodyLimit::max(
|
||||
app_state.config.ingest_max_body_bytes,
|
||||
)),
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ pub async fn api_auth(
|
||||
let api_key = extract_api_key(&request)
|
||||
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
||||
let user = User::find_by_api_key(api_key, &state.db).await?;
|
||||
let user =
|
||||
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
@@ -25,7 +25,7 @@ pub async fn api_auth(
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
fn extract_api_key(request: &Request) -> Option<String> {
|
||||
fn extract_api_key(request: &Request) -> Option<&str> {
|
||||
request
|
||||
.headers()
|
||||
.get("X-API-Key")
|
||||
@@ -35,7 +35,66 @@ fn extract_api_key(request: &Request) -> Option<String> {
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.map(str::trim)
|
||||
})
|
||||
.map(String::from)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::expect_used)]
|
||||
mod tests {
|
||||
use axum::body::Body;
|
||||
use axum::http::{HeaderValue, Request};
|
||||
|
||||
use super::extract_api_key;
|
||||
|
||||
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
|
||||
let mut builder = Request::builder().method("GET").uri("/");
|
||||
for (name, value) in headers {
|
||||
builder = builder.header(*name, *value);
|
||||
}
|
||||
builder.body(Body::empty()).expect("test request")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_api_key_from_x_api_key_header() {
|
||||
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
|
||||
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_api_key_from_bearer_authorization() {
|
||||
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
|
||||
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_api_key_prefers_x_api_key_over_authorization() {
|
||||
let request = request_with_headers(&[
|
||||
("X-API-Key", "sk_header"),
|
||||
("Authorization", "Bearer sk_bearer"),
|
||||
]);
|
||||
assert_eq!(extract_api_key(&request), Some("sk_header"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_api_key_returns_none_when_missing() {
|
||||
let request = request_with_headers(&[]);
|
||||
assert_eq!(extract_api_key(&request), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_api_key_rejects_non_bearer_authorization() {
|
||||
let request = request_with_headers(&[("Authorization", "Basic abc")]);
|
||||
assert_eq!(extract_api_key(&request), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_api_key_rejects_invalid_header_values() {
|
||||
let mut request = request_with_headers(&[]);
|
||||
request
|
||||
.headers_mut()
|
||||
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header"));
|
||||
assert_eq!(extract_api_key(&request), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use tracing::info;
|
||||
use crate::{api_state::ApiState, error::ApiErr};
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct IngestParams {
|
||||
pub struct Params {
|
||||
pub content: Option<String>,
|
||||
pub context: String,
|
||||
pub category: String,
|
||||
@@ -25,24 +25,20 @@ pub struct IngestParams {
|
||||
pub files: Vec<FieldData<NamedTempFile>>,
|
||||
}
|
||||
|
||||
pub async fn ingest_data(
|
||||
pub async fn handle(
|
||||
State(state): State<ApiState>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedMultipart(input): TypedMultipart<IngestParams>,
|
||||
TypedMultipart(input): TypedMultipart<Params>,
|
||||
) -> Result<impl IntoResponse, ApiErr> {
|
||||
let user_id = user.id;
|
||||
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
|
||||
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
||||
let context_bytes = input.context.len();
|
||||
let category_bytes = input.category.len();
|
||||
let file_count = input.files.len();
|
||||
|
||||
match validate_ingest_input(
|
||||
&state.config,
|
||||
input.content.as_deref(),
|
||||
&input.context,
|
||||
&input.category,
|
||||
file_count,
|
||||
input.files.len(),
|
||||
) {
|
||||
Ok(()) => {}
|
||||
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
||||
@@ -56,10 +52,10 @@ pub async fn ingest_data(
|
||||
info!(
|
||||
user_id = %user_id,
|
||||
has_content,
|
||||
content_bytes,
|
||||
context_bytes,
|
||||
category_bytes,
|
||||
file_count,
|
||||
content_len = input.content.as_ref().map_or(0, String::len),
|
||||
context_len = input.context.len(),
|
||||
category_len = input.category.len(),
|
||||
file_count = input.files.len(),
|
||||
"Received ingest request"
|
||||
);
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::api_state::ApiState;
|
||||
|
||||
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
|
||||
"checks": { "db": "ok" }
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"status": "error",
|
||||
"checks": { "db": "fail" },
|
||||
"reason": e.to_string()
|
||||
})),
|
||||
),
|
||||
Err(e) => {
|
||||
error!("readiness check failed: {e:?}");
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"status": "error",
|
||||
"checks": { "db": "fail" }
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user