chore: harden api-router errors and add router integration tests while slimming html handlers.

This commit is contained in:
Per Stark
2026-05-30 11:39:47 +02:00
parent 2aa92b6ad7
commit c70141de35
26 changed files with 814 additions and 260 deletions
-28
View File
@@ -11,31 +11,3 @@ pub struct ApiState {
pub config: AppConfig,
pub storage: StorageManager,
}
impl ApiState {
pub async fn new(
config: &AppConfig,
storage: StorageManager,
) -> anyhow::Result<Self> {
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
surreal_db_client.apply_migrations().await?;
let app_state = Self {
db: Arc::clone(&surreal_db_client),
config: config.clone(),
storage,
};
Ok(app_state)
}
}
+19 -9
View File
@@ -7,7 +7,7 @@ use common::error::AppError;
use serde::Serialize;
use thiserror::Error;
#[derive(Error, Debug, Serialize, Clone)]
#[derive(Error, Debug)]
pub enum ApiErr {
#[error("internal server error")]
InternalError(String),
@@ -28,14 +28,13 @@ pub enum ApiErr {
impl From<AppError> for ApiErr {
fn from(err: AppError) -> Self {
match err {
AppError::Database(_) | AppError::OpenAI(_) => {
tracing::error!("Internal error: {:?}", err);
Self::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => Self::NotFound(msg),
AppError::Validation(msg) => Self::ValidationError(msg),
AppError::Auth(msg) => Self::Unauthorized(msg),
_ => Self::InternalError("Internal server error".to_string()),
other => {
tracing::error!("internal API error: {other:?}");
Self::InternalError("Internal server error".to_string())
}
}
}
}
@@ -120,10 +119,21 @@ mod tests {
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
// Test for internal errors - create a mock error that doesn't require surrealdb
let internal_error =
AppError::Io(io::Error::other("io error"));
let internal_error = AppError::Io(io::Error::other("io error"));
let api_error = ApiErr::from(internal_error);
assert!(matches!(api_error, ApiErr::InternalError(_)));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
fn test_app_error_internal_error_is_sanitized() {
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
+2 -2
View File
@@ -6,7 +6,7 @@ use axum::{
Router,
};
use middleware_api_auth::api_auth;
use routes::{categories::list, ingest::ingest_data, liveness::live, readiness::ready};
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
pub mod api_state;
pub mod error;
@@ -28,7 +28,7 @@ where
let protected = Router::new()
.route(
"/ingest",
post(ingest_data).layer(DefaultBodyLimit::max(
post(handle).layer(DefaultBodyLimit::max(
app_state.config.ingest_max_body_bytes,
)),
)
+63 -4
View File
@@ -16,7 +16,7 @@ pub async fn api_auth(
let api_key = extract_api_key(&request)
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
let user = User::find_by_api_key(&api_key, &state.db).await?;
let user = User::find_by_api_key(api_key, &state.db).await?;
let user =
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
@@ -25,7 +25,7 @@ pub async fn api_auth(
Ok(next.run(request).await)
}
fn extract_api_key(request: &Request) -> Option<String> {
fn extract_api_key(request: &Request) -> Option<&str> {
request
.headers()
.get("X-API-Key")
@@ -35,7 +35,66 @@ fn extract_api_key(request: &Request) -> Option<String> {
.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
.and_then(|auth| auth.strip_prefix("Bearer "))
.map(str::trim)
})
.map(String::from)
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use axum::body::Body;
use axum::http::{HeaderValue, Request};
use super::extract_api_key;
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
let mut builder = Request::builder().method("GET").uri("/");
for (name, value) in headers {
builder = builder.header(*name, *value);
}
builder.body(Body::empty()).expect("test request")
}
#[test]
fn extract_api_key_from_x_api_key_header() {
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
}
#[test]
fn extract_api_key_from_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
}
#[test]
fn extract_api_key_prefers_x_api_key_over_authorization() {
let request = request_with_headers(&[
("X-API-Key", "sk_header"),
("Authorization", "Bearer sk_bearer"),
]);
assert_eq!(extract_api_key(&request), Some("sk_header"));
}
#[test]
fn extract_api_key_returns_none_when_missing() {
let request = request_with_headers(&[]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_non_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Basic abc")]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_invalid_header_values() {
let mut request = request_with_headers(&[]);
request
.headers_mut()
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header"));
assert_eq!(extract_api_key(&request), None);
}
}
+8 -12
View File
@@ -16,7 +16,7 @@ use tracing::info;
use crate::{api_state::ApiState, error::ApiErr};
#[derive(Debug, TryFromMultipart)]
pub struct IngestParams {
pub struct Params {
pub content: Option<String>,
pub context: String,
pub category: String,
@@ -25,24 +25,20 @@ pub struct IngestParams {
pub files: Vec<FieldData<NamedTempFile>>,
}
pub async fn ingest_data(
pub async fn handle(
State(state): State<ApiState>,
Extension(user): Extension<User>,
TypedMultipart(input): TypedMultipart<IngestParams>,
TypedMultipart(input): TypedMultipart<Params>,
) -> Result<impl IntoResponse, ApiErr> {
let user_id = user.id;
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input(
&state.config,
input.content.as_deref(),
&input.context,
&input.category,
file_count,
input.files.len(),
) {
Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => {
@@ -56,10 +52,10 @@ pub async fn ingest_data(
info!(
user_id = %user_id,
has_content,
content_bytes,
context_bytes,
category_bytes,
file_count,
content_len = input.content.as_ref().map_or(0, String::len),
context_len = input.context.len(),
category_len = input.category.len(),
file_count = input.files.len(),
"Received ingest request"
);
+11 -8
View File
@@ -1,5 +1,6 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde_json::json;
use tracing::error;
use crate::api_state::ApiState;
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
"checks": { "db": "ok" }
})),
),
Err(e) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"status": "error",
"checks": { "db": "fail" },
"reason": e.to_string()
})),
),
Err(e) => {
error!("readiness check failed: {e:?}");
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"status": "error",
"checks": { "db": "fail" }
})),
)
}
}
}