chore: harden api-router errors and add router integration tests while slimming html handlers.

This commit is contained in:
Per Stark
2026-05-30 11:39:47 +02:00
parent 2aa92b6ad7
commit c70141de35
26 changed files with 814 additions and 260 deletions
+5
View File
@@ -20,3 +20,8 @@ futures = { workspace = true }
axum_typed_multipart = { workspace = true}
common = { path = "../common" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }
tower = "0.5"
uuid = { workspace = true }
-28
View File
@@ -11,31 +11,3 @@ pub struct ApiState {
pub config: AppConfig,
pub storage: StorageManager,
}
impl ApiState {
pub async fn new(
config: &AppConfig,
storage: StorageManager,
) -> anyhow::Result<Self> {
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
surreal_db_client.apply_migrations().await?;
let app_state = Self {
db: Arc::clone(&surreal_db_client),
config: config.clone(),
storage,
};
Ok(app_state)
}
}
+19 -9
View File
@@ -7,7 +7,7 @@ use common::error::AppError;
use serde::Serialize;
use thiserror::Error;
#[derive(Error, Debug, Serialize, Clone)]
#[derive(Error, Debug)]
pub enum ApiErr {
#[error("internal server error")]
InternalError(String),
@@ -28,14 +28,13 @@ pub enum ApiErr {
impl From<AppError> for ApiErr {
fn from(err: AppError) -> Self {
match err {
AppError::Database(_) | AppError::OpenAI(_) => {
tracing::error!("Internal error: {:?}", err);
Self::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => Self::NotFound(msg),
AppError::Validation(msg) => Self::ValidationError(msg),
AppError::Auth(msg) => Self::Unauthorized(msg),
_ => Self::InternalError("Internal server error".to_string()),
other => {
tracing::error!("internal API error: {other:?}");
Self::InternalError("Internal server error".to_string())
}
}
}
}
@@ -120,10 +119,21 @@ mod tests {
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
// Test for internal errors - create a mock error that doesn't require surrealdb
let internal_error =
AppError::Io(io::Error::other("io error"));
let internal_error = AppError::Io(io::Error::other("io error"));
let api_error = ApiErr::from(internal_error);
assert!(matches!(api_error, ApiErr::InternalError(_)));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
fn test_app_error_internal_error_is_sanitized() {
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
+2 -2
View File
@@ -6,7 +6,7 @@ use axum::{
Router,
};
use middleware_api_auth::api_auth;
use routes::{categories::list, ingest::ingest_data, liveness::live, readiness::ready};
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
pub mod api_state;
pub mod error;
@@ -28,7 +28,7 @@ where
let protected = Router::new()
.route(
"/ingest",
post(ingest_data).layer(DefaultBodyLimit::max(
post(handle).layer(DefaultBodyLimit::max(
app_state.config.ingest_max_body_bytes,
)),
)
+63 -4
View File
@@ -16,7 +16,7 @@ pub async fn api_auth(
let api_key = extract_api_key(&request)
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
let user = User::find_by_api_key(&api_key, &state.db).await?;
let user = User::find_by_api_key(api_key, &state.db).await?;
let user =
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
@@ -25,7 +25,7 @@ pub async fn api_auth(
Ok(next.run(request).await)
}
fn extract_api_key(request: &Request) -> Option<String> {
fn extract_api_key(request: &Request) -> Option<&str> {
request
.headers()
.get("X-API-Key")
@@ -35,7 +35,66 @@ fn extract_api_key(request: &Request) -> Option<String> {
.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
.and_then(|auth| auth.strip_prefix("Bearer "))
.map(str::trim)
})
.map(String::from)
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use axum::body::Body;
use axum::http::{HeaderValue, Request};
use super::extract_api_key;
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
let mut builder = Request::builder().method("GET").uri("/");
for (name, value) in headers {
builder = builder.header(*name, *value);
}
builder.body(Body::empty()).expect("test request")
}
#[test]
fn extract_api_key_from_x_api_key_header() {
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
}
#[test]
fn extract_api_key_from_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
}
#[test]
fn extract_api_key_prefers_x_api_key_over_authorization() {
let request = request_with_headers(&[
("X-API-Key", "sk_header"),
("Authorization", "Bearer sk_bearer"),
]);
assert_eq!(extract_api_key(&request), Some("sk_header"));
}
#[test]
fn extract_api_key_returns_none_when_missing() {
let request = request_with_headers(&[]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_non_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Basic abc")]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_invalid_header_values() {
let mut request = request_with_headers(&[]);
request
.headers_mut()
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header"));
assert_eq!(extract_api_key(&request), None);
}
}
+8 -12
View File
@@ -16,7 +16,7 @@ use tracing::info;
use crate::{api_state::ApiState, error::ApiErr};
#[derive(Debug, TryFromMultipart)]
pub struct IngestParams {
pub struct Params {
pub content: Option<String>,
pub context: String,
pub category: String,
@@ -25,24 +25,20 @@ pub struct IngestParams {
pub files: Vec<FieldData<NamedTempFile>>,
}
pub async fn ingest_data(
pub async fn handle(
State(state): State<ApiState>,
Extension(user): Extension<User>,
TypedMultipart(input): TypedMultipart<IngestParams>,
TypedMultipart(input): TypedMultipart<Params>,
) -> Result<impl IntoResponse, ApiErr> {
let user_id = user.id;
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input(
&state.config,
input.content.as_deref(),
&input.context,
&input.category,
file_count,
input.files.len(),
) {
Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => {
@@ -56,10 +52,10 @@ pub async fn ingest_data(
info!(
user_id = %user_id,
has_content,
content_bytes,
context_bytes,
category_bytes,
file_count,
content_len = input.content.as_ref().map_or(0, String::len),
context_len = input.context.len(),
category_len = input.category.len(),
file_count = input.files.len(),
"Received ingest request"
);
+11 -8
View File
@@ -1,5 +1,6 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde_json::json;
use tracing::error;
use crate::api_state::ApiState;
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
"checks": { "db": "ok" }
})),
),
Err(e) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"status": "error",
"checks": { "db": "fail" },
"reason": e.to_string()
})),
),
Err(e) => {
error!("readiness check failed: {e:?}");
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"status": "error",
"checks": { "db": "fail" }
})),
)
}
}
}
+167
View File
@@ -0,0 +1,167 @@
#![allow(clippy::expect_used)]
use std::sync::Arc;
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{
body::{to_bytes, Body},
http::{Request, StatusCode},
Router,
};
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::config::{AppConfig, StorageKind},
};
use tower::ServiceExt;
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
let namespace = "api_router_test";
let database = uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, &database)
.await
.expect("in-memory db"),
);
db.apply_migrations()
.await
.expect("migrations should apply");
let config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config)
.await
.expect("storage manager");
let state = ApiState {
db: Arc::clone(&db),
config,
storage,
};
let router = api_routes_v1(&state).with_state(state);
(router, db)
}
async fn response_body(response: axum::response::Response) -> String {
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
String::from_utf8(body.to_vec()).expect("utf-8 body")
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn live_probe_is_public() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/live")
.body(Body::empty())
.expect("live request"),
)
.await
.expect("live response");
assert_eq!(response.status(), StatusCode::OK);
assert!(response_body(response).await.contains("\"status\":\"ok\""));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ready_probe_is_public_and_reports_db_ok() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/ready")
.body(Body::empty())
.expect("ready request"),
)
.await
.expect("ready response");
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
assert!(body.contains("\"checks\":{\"db\":\"ok\"}") || body.contains("\"db\":\"ok\""));
assert!(!body.contains("reason"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_requires_api_key() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_rejects_invalid_api_key() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.header("X-API-Key", "sk_invalid")
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn authenticated_user_can_list_categories() {
let (app, db) = build_test_app().await;
let user = User::create_new(
"api_router_test@example.com".to_string(),
"test_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let api_key = User::set_api_key(&user.id, &db)
.await
.expect("api key");
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.header("X-API-Key", api_key)
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::OK);
}