From c70141de35322e1d2c8ab185437c893457a50511 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sat, 30 May 2026 11:39:47 +0200 Subject: [PATCH] chore: harden api-router errors and add router integration tests while slimming html handlers. --- Cargo.lock | 3 + api-router/Cargo.toml | 5 + api-router/src/api_state.rs | 28 -- api-router/src/error.rs | 28 +- api-router/src/lib.rs | 4 +- api-router/src/middleware_api_auth.rs | 67 +++- api-router/src/routes/ingest.rs | 20 +- api-router/src/routes/readiness.rs | 19 +- api-router/tests/api_router_integration.rs | 167 ++++++++++ html-router/Cargo.toml | 1 + .../src/middlewares/response_middleware.rs | 48 ++- html-router/src/router_factory.rs | 4 +- html-router/src/routes/account/handlers.rs | 18 +- html-router/src/routes/admin/handlers.rs | 21 +- html-router/src/routes/auth/signin.rs | 12 +- html-router/src/routes/auth/signout.rs | 6 +- html-router/src/routes/auth/signup.rs | 12 +- html-router/src/routes/chat/chat_handlers.rs | 90 ++--- html-router/src/routes/chat/references.rs | 5 +- html-router/src/routes/content/handlers.rs | 15 +- html-router/src/routes/index/handlers.rs | 22 +- html-router/src/routes/ingestion/handlers.rs | 30 +- html-router/src/routes/knowledge/handlers.rs | 48 +-- html-router/src/routes/scratchpad/handlers.rs | 85 ++--- html-router/src/routes/search/handlers.rs | 5 +- html-router/tests/router_integration.rs | 311 ++++++++++++++++++ 26 files changed, 814 insertions(+), 260 deletions(-) create mode 100644 api-router/tests/api_router_integration.rs create mode 100644 html-router/tests/router_integration.rs diff --git a/Cargo.lock b/Cargo.lock index cc9a986..2fdb359 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -247,7 +247,9 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tokio", + "tower", "tracing", + "uuid", ] [[package]] @@ -2978,6 +2980,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-util", + "tower", "tower-http", "tower-serve-static", "tracing", diff --git a/api-router/Cargo.toml b/api-router/Cargo.toml index 42665a4..5d9125d 100644 --- a/api-router/Cargo.toml +++ b/api-router/Cargo.toml @@ -20,3 +20,8 @@ futures = { workspace = true } axum_typed_multipart = { workspace = true} common = { path = "../common" } + +[dev-dependencies] +common = { path = "../common", features = ["test-utils"] } +tower = "0.5" +uuid = { workspace = true } diff --git a/api-router/src/api_state.rs b/api-router/src/api_state.rs index 9cdcd32..f4a0728 100644 --- a/api-router/src/api_state.rs +++ b/api-router/src/api_state.rs @@ -11,31 +11,3 @@ pub struct ApiState { pub config: AppConfig, pub storage: StorageManager, } - -impl ApiState { - pub async fn new( - config: &AppConfig, - storage: StorageManager, - ) -> anyhow::Result { - let surreal_db_client = Arc::new( - SurrealDbClient::new( - &config.surrealdb_address, - &config.surrealdb_username, - &config.surrealdb_password, - &config.surrealdb_namespace, - &config.surrealdb_database, - ) - .await?, - ); - - surreal_db_client.apply_migrations().await?; - - let app_state = Self { - db: Arc::clone(&surreal_db_client), - config: config.clone(), - storage, - }; - - Ok(app_state) - } -} diff --git a/api-router/src/error.rs b/api-router/src/error.rs index 9aa16ad..e82e49b 100644 --- a/api-router/src/error.rs +++ b/api-router/src/error.rs @@ -7,7 +7,7 @@ use common::error::AppError; use serde::Serialize; use thiserror::Error; -#[derive(Error, Debug, Serialize, Clone)] +#[derive(Error, Debug)] pub enum ApiErr { #[error("internal server error")] InternalError(String), @@ -28,14 +28,13 @@ pub enum ApiErr { impl From for ApiErr { fn from(err: AppError) -> Self { match err { - AppError::Database(_) | AppError::OpenAI(_) => { - tracing::error!("Internal error: {:?}", err); - Self::InternalError("Internal server error".to_string()) - } AppError::NotFound(msg) => Self::NotFound(msg), AppError::Validation(msg) => Self::ValidationError(msg), AppError::Auth(msg) => Self::Unauthorized(msg), - _ => Self::InternalError("Internal server error".to_string()), + other => { + tracing::error!("internal API error: {other:?}"); + Self::InternalError("Internal server error".to_string()) + } } } } @@ -120,10 +119,21 @@ mod tests { assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized")); // Test for internal errors - create a mock error that doesn't require surrealdb - let internal_error = - AppError::Io(io::Error::other("io error")); + let internal_error = AppError::Io(io::Error::other("io error")); let api_error = ApiErr::from(internal_error); - assert!(matches!(api_error, ApiErr::InternalError(_))); + assert!(matches!( + api_error, + ApiErr::InternalError(msg) if msg == "Internal server error" + )); + } + + #[test] + fn test_app_error_internal_error_is_sanitized() { + let api_error = ApiErr::from(AppError::internal("db password incorrect")); + assert!(matches!( + api_error, + ApiErr::InternalError(msg) if msg == "Internal server error" + )); } #[test] diff --git a/api-router/src/lib.rs b/api-router/src/lib.rs index df24676..1bb9d66 100644 --- a/api-router/src/lib.rs +++ b/api-router/src/lib.rs @@ -6,7 +6,7 @@ use axum::{ Router, }; use middleware_api_auth::api_auth; -use routes::{categories::list, ingest::ingest_data, liveness::live, readiness::ready}; +use routes::{categories::list, ingest::handle, liveness::live, readiness::ready}; pub mod api_state; pub mod error; @@ -28,7 +28,7 @@ where let protected = Router::new() .route( "/ingest", - post(ingest_data).layer(DefaultBodyLimit::max( + post(handle).layer(DefaultBodyLimit::max( app_state.config.ingest_max_body_bytes, )), ) diff --git a/api-router/src/middleware_api_auth.rs b/api-router/src/middleware_api_auth.rs index 111d42d..69369d3 100644 --- a/api-router/src/middleware_api_auth.rs +++ b/api-router/src/middleware_api_auth.rs @@ -16,7 +16,7 @@ pub async fn api_auth( let api_key = extract_api_key(&request) .ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?; - let user = User::find_by_api_key(&api_key, &state.db).await?; + let user = User::find_by_api_key(api_key, &state.db).await?; let user = user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?; @@ -25,7 +25,7 @@ pub async fn api_auth( Ok(next.run(request).await) } -fn extract_api_key(request: &Request) -> Option { +fn extract_api_key(request: &Request) -> Option<&str> { request .headers() .get("X-API-Key") @@ -35,7 +35,66 @@ fn extract_api_key(request: &Request) -> Option { .headers() .get("Authorization") .and_then(|v| v.to_str().ok()) - .and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim)) + .and_then(|auth| auth.strip_prefix("Bearer ")) + .map(str::trim) }) - .map(String::from) +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use axum::body::Body; + use axum::http::{HeaderValue, Request}; + + use super::extract_api_key; + + fn request_with_headers(headers: &[(&str, &str)]) -> Request { + let mut builder = Request::builder().method("GET").uri("/"); + for (name, value) in headers { + builder = builder.header(*name, *value); + } + builder.body(Body::empty()).expect("test request") + } + + #[test] + fn extract_api_key_from_x_api_key_header() { + let request = request_with_headers(&[("X-API-Key", "sk_test_key")]); + assert_eq!(extract_api_key(&request), Some("sk_test_key")); + } + + #[test] + fn extract_api_key_from_bearer_authorization() { + let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]); + assert_eq!(extract_api_key(&request), Some("sk_bearer_key")); + } + + #[test] + fn extract_api_key_prefers_x_api_key_over_authorization() { + let request = request_with_headers(&[ + ("X-API-Key", "sk_header"), + ("Authorization", "Bearer sk_bearer"), + ]); + assert_eq!(extract_api_key(&request), Some("sk_header")); + } + + #[test] + fn extract_api_key_returns_none_when_missing() { + let request = request_with_headers(&[]); + assert_eq!(extract_api_key(&request), None); + } + + #[test] + fn extract_api_key_rejects_non_bearer_authorization() { + let request = request_with_headers(&[("Authorization", "Basic abc")]); + assert_eq!(extract_api_key(&request), None); + } + + #[test] + fn extract_api_key_rejects_invalid_header_values() { + let mut request = request_with_headers(&[]); + request + .headers_mut() + .insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header")); + assert_eq!(extract_api_key(&request), None); + } } diff --git a/api-router/src/routes/ingest.rs b/api-router/src/routes/ingest.rs index d269cf6..ea6ab86 100644 --- a/api-router/src/routes/ingest.rs +++ b/api-router/src/routes/ingest.rs @@ -16,7 +16,7 @@ use tracing::info; use crate::{api_state::ApiState, error::ApiErr}; #[derive(Debug, TryFromMultipart)] -pub struct IngestParams { +pub struct Params { pub content: Option, pub context: String, pub category: String, @@ -25,24 +25,20 @@ pub struct IngestParams { pub files: Vec>, } -pub async fn ingest_data( +pub async fn handle( State(state): State, Extension(user): Extension, - TypedMultipart(input): TypedMultipart, + TypedMultipart(input): TypedMultipart, ) -> Result { let user_id = user.id; - let content_bytes = input.content.as_ref().map_or(0, |c| c.len()); let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty()); - let context_bytes = input.context.len(); - let category_bytes = input.category.len(); - let file_count = input.files.len(); match validate_ingest_input( &state.config, input.content.as_deref(), &input.context, &input.category, - file_count, + input.files.len(), ) { Ok(()) => {} Err(IngestValidationError::PayloadTooLarge(message)) => { @@ -56,10 +52,10 @@ pub async fn ingest_data( info!( user_id = %user_id, has_content, - content_bytes, - context_bytes, - category_bytes, - file_count, + content_len = input.content.as_ref().map_or(0, String::len), + context_len = input.context.len(), + category_len = input.category.len(), + file_count = input.files.len(), "Received ingest request" ); diff --git a/api-router/src/routes/readiness.rs b/api-router/src/routes/readiness.rs index 628c208..7827d94 100644 --- a/api-router/src/routes/readiness.rs +++ b/api-router/src/routes/readiness.rs @@ -1,5 +1,6 @@ use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; use serde_json::json; +use tracing::error; use crate::api_state::ApiState; @@ -13,13 +14,15 @@ pub async fn ready(State(state): State) -> impl IntoResponse { "checks": { "db": "ok" } })), ), - Err(e) => ( - StatusCode::SERVICE_UNAVAILABLE, - Json(json!({ - "status": "error", - "checks": { "db": "fail" }, - "reason": e.to_string() - })), - ), + Err(e) => { + error!("readiness check failed: {e:?}"); + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "status": "error", + "checks": { "db": "fail" } + })), + ) + } } } diff --git a/api-router/tests/api_router_integration.rs b/api-router/tests/api_router_integration.rs new file mode 100644 index 0000000..11c7022 --- /dev/null +++ b/api-router/tests/api_router_integration.rs @@ -0,0 +1,167 @@ +#![allow(clippy::expect_used)] + +use std::sync::Arc; + +use api_router::{api_routes_v1, api_state::ApiState}; +use axum::{ + body::{to_bytes, Body}, + http::{Request, StatusCode}, + Router, +}; +use common::{ + storage::{ + db::SurrealDbClient, + store::StorageManager, + types::user::User, + }, + utils::config::{AppConfig, StorageKind}, +}; +use tower::ServiceExt; + +async fn build_test_app() -> (Router, Arc) { + let namespace = "api_router_test"; + let database = uuid::Uuid::new_v4().to_string(); + let db = Arc::new( + SurrealDbClient::memory(namespace, &database) + .await + .expect("in-memory db"), + ); + db.apply_migrations() + .await + .expect("migrations should apply"); + + let config = AppConfig { + storage: StorageKind::Memory, + ..Default::default() + }; + let storage = StorageManager::new(&config) + .await + .expect("storage manager"); + + let state = ApiState { + db: Arc::clone(&db), + config, + storage, + }; + + let router = api_routes_v1(&state).with_state(state); + + (router, db) +} + +async fn response_body(response: axum::response::Response) -> String { + let body = to_bytes(response.into_body(), usize::MAX) + .await + .expect("response body"); + String::from_utf8(body.to_vec()).expect("utf-8 body") +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn live_probe_is_public() { + let (app, _db) = build_test_app().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/live") + .body(Body::empty()) + .expect("live request"), + ) + .await + .expect("live response"); + + assert_eq!(response.status(), StatusCode::OK); + assert!(response_body(response).await.contains("\"status\":\"ok\"")); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn ready_probe_is_public_and_reports_db_ok() { + let (app, _db) = build_test_app().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/ready") + .body(Body::empty()) + .expect("ready request"), + ) + .await + .expect("ready response"); + + assert_eq!(response.status(), StatusCode::OK); + let body = response_body(response).await; + assert!(body.contains("\"checks\":{\"db\":\"ok\"}") || body.contains("\"db\":\"ok\"")); + assert!(!body.contains("reason")); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn protected_route_requires_api_key() { + let (app, _db) = build_test_app().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/categories") + .body(Body::empty()) + .expect("categories request"), + ) + .await + .expect("categories response"); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn protected_route_rejects_invalid_api_key() { + let (app, _db) = build_test_app().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/categories") + .header("X-API-Key", "sk_invalid") + .body(Body::empty()) + .expect("categories request"), + ) + .await + .expect("categories response"); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn authenticated_user_can_list_categories() { + let (app, db) = build_test_app().await; + + let user = User::create_new( + "api_router_test@example.com".to_string(), + "test_password".to_string(), + &db, + "UTC".to_string(), + "system".to_string(), + ) + .await + .expect("test user"); + + let api_key = User::set_api_key(&user.id, &db) + .await + .expect("api key"); + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/categories") + .header("X-API-Key", api_key) + .body(Body::empty()) + .expect("categories request"), + ) + .await + .expect("categories response"); + + assert_eq!(response.status(), StatusCode::OK); +} diff --git a/html-router/Cargo.toml b/html-router/Cargo.toml index dcfcbc4..6eddc96 100644 --- a/html-router/Cargo.toml +++ b/html-router/Cargo.toml @@ -43,6 +43,7 @@ json-stream-parser = { path = "../json-stream-parser" } [dev-dependencies] common = { path = "../common", features = ["test-utils"] } +tower = "0.5" [build-dependencies] minijinja-embed = { version = "2.8.0" } diff --git a/html-router/src/middlewares/response_middleware.rs b/html-router/src/middlewares/response_middleware.rs index 616db31..5cf3d0a 100644 --- a/html-router/src/middlewares/response_middleware.rs +++ b/html-router/src/middlewares/response_middleware.rs @@ -116,6 +116,29 @@ impl IntoResponse for TemplateResponse { } } +/// Typical handler return type when no extra response headers or body are needed. +pub type TemplateResult = Result; + +/// Handler return type when the response needs custom headers or a non-template body. +pub type ResponseResult = Result; + +/// Converts a [`TemplateResponse`] for [`ResponseResult`] handlers that do not set extra headers. +pub fn template_as_response(template: TemplateResponse) -> Response { + template.into_response() +} + +/// Wraps a [`TemplateResponse`] for the template middleware and applies outbound headers. +/// +/// Headers listed in [`HTMX_HEADERS_TO_FORWARD`] are copied onto the rendered HTML response. +pub fn template_with_headers( + template: TemplateResponse, + apply: impl FnOnce(&mut axum::http::HeaderMap), +) -> Response { + let mut response = template.into_response(); + apply(response.headers_mut()); + response +} + #[derive(Serialize)] struct TemplateUser { id: String, @@ -143,7 +166,7 @@ struct ContextWrapper<'a> { initial_theme: &'a str, is_authenticated: bool, user: Option<&'a TemplateUser>, - conversation_archive: Vec, + conversation_archive: &'a [SidebarConversation], #[serde(flatten)] context: HashMap, } @@ -213,18 +236,14 @@ where if let Some(template_response) = response.extensions().get::().cloned() { let template_engine = state.template_engine(); - let mut conversation_archive = Vec::new(); - let should_load_conversation_archive = matches!(&template_response.template_kind, TemplateKind::Full(_)); - if should_load_conversation_archive { + let cached_archive = if should_load_conversation_archive { if let Some(user_id) = current_user.as_ref().map(|u| &u.id) { let html_state = state.html_state(); - if let Some(cached_archive) = - html_state.get_cached_conversation_archive(user_id).await - { - conversation_archive = cached_archive.to_vec(); + if let Some(cached) = html_state.get_cached_conversation_archive(user_id).await { + Some(cached) } else if let Ok(archive) = Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await { @@ -232,10 +251,19 @@ where html_state .set_cached_conversation_archive(user_id, Arc::clone(&cached)) .await; - conversation_archive = cached.to_vec(); + Some(cached) + } else { + None } + } else { + None } - } + } else { + None + }; + let conversation_archive = cached_archive + .as_ref() + .map_or(&[][..], |archive| archive.as_ref()); let context_map = match context_to_map(&template_response.context) { Ok(map) => map, diff --git a/html-router/src/router_factory.rs b/html-router/src/router_factory.rs index 9ca9c7a..2c85166 100644 --- a/html-router/src/router_factory.rs +++ b/html-router/src/router_factory.rs @@ -34,7 +34,7 @@ macro_rules! create_asset_service { }}; } -pub type MiddleWareVecType = Vec) -> Router + Send>>; +pub type MiddlewareVec = Vec) -> Router + Send>>; /// Builder for composing public/protected HTML routes and middleware layers. pub struct RouterFactory { @@ -43,7 +43,7 @@ pub struct RouterFactory { protected_routers: Vec>, nested_routes: Vec<(String, Router)>, nested_protected_routes: Vec<(String, Router)>, - custom_middleware: MiddleWareVecType, + custom_middleware: MiddlewareVec, public_assets_config: Option, compression_enabled: bool, } diff --git a/html-router/src/routes/account/handlers.rs b/html-router/src/routes/account/handlers.rs index f189b3c..902be8a 100644 --- a/html-router/src/routes/account/handlers.rs +++ b/html-router/src/routes/account/handlers.rs @@ -1,11 +1,11 @@ -use axum::{extract::State, response::IntoResponse, Form}; +use axum::{extract::State, Form}; use chrono_tz::TZ_VARIANTS; use serde::{Deserialize, Serialize}; use crate::{ middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{TemplateResponse, TemplateResult}, }, AuthSessionType, }; @@ -28,7 +28,7 @@ pub struct AccountPageData { pub async fn show_account_page( RequireUser(user): RequireUser, State(_state): State, -) -> Result { +) -> TemplateResult { let timezones = TZ_VARIANTS .iter() .map(std::string::ToString::to_string) @@ -57,7 +57,7 @@ pub async fn set_api_key( State(state): State, RequireUser(user): RequireUser, auth: AuthSessionType, -) -> Result { +) -> TemplateResult { // Generate and set the API key let api_key = User::set_api_key(&user.id, &state.db).await?; @@ -82,7 +82,7 @@ pub async fn delete_account( State(state): State, RequireUser(user): RequireUser, auth: AuthSessionType, -) -> Result { +) -> TemplateResult { state.db.delete_item::(&user.id).await?; auth.logout_user(); @@ -102,7 +102,7 @@ pub async fn update_timezone( RequireUser(user): RequireUser, auth: AuthSessionType, Form(form): Form, -) -> Result { +) -> TemplateResult { User::update_timezone(&user.id, &form.timezone, &state.db).await?; // Clear the cache @@ -137,7 +137,7 @@ pub async fn update_theme( RequireUser(user): RequireUser, auth: AuthSessionType, Form(form): Form, -) -> Result { +) -> TemplateResult { User::update_theme(&user.id, &form.theme, &state.db).await?; // Clear the cache @@ -166,7 +166,7 @@ pub async fn update_theme( pub async fn show_change_password( RequireUser(_user): RequireUser, -) -> Result { +) -> TemplateResult { Ok(TemplateResponse::new_template( "auth/change_password_form.html", (), @@ -184,7 +184,7 @@ pub async fn change_password( RequireUser(user): RequireUser, auth: AuthSessionType, Form(form): Form, -) -> Result { +) -> TemplateResult { // Authenticate to make sure the password matches let authenticated_user = User::authenticate(&user.email, &form.old_password, &state.db).await?; diff --git a/html-router/src/routes/admin/handlers.rs b/html-router/src/routes/admin/handlers.rs index 52a9072..b0657ad 100644 --- a/html-router/src/routes/admin/handlers.rs +++ b/html-router/src/routes/admin/handlers.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use async_openai::types::ListModelResponse; use axum::{ extract::{Query, State}, - response::IntoResponse, Form, }; use serde::{Deserialize, Serialize}; @@ -26,7 +25,7 @@ use tracing::{error, info}; use crate::{ html_state::HtmlState, - middlewares::response_middleware::{HtmlError, TemplateResponse}, + middlewares::response_middleware::{TemplateResponse, TemplateResult}, }; #[derive(Serialize)] @@ -57,7 +56,7 @@ pub struct AdminPanelQuery { pub async fn show_admin_panel( State(state): State, Query(query): Query, -) -> Result { +) -> TemplateResult { let section = match query.section.as_deref() { Some("models") => AdminSection::Models, _ => AdminSection::Overview, @@ -124,7 +123,7 @@ pub struct RegistrationToggleData { pub async fn toggle_registration_status( State(state): State, Form(input): Form, -) -> Result { +) -> TemplateResult { let new_settings = SystemSettingsPatch { registrations_enabled: Some(input.registration_open), ..Default::default() @@ -160,7 +159,7 @@ pub struct ModelSettingsData { pub async fn update_model_settings( State(state): State, Form(input): Form, -) -> Result { +) -> TemplateResult { let current_settings = SystemSettings::get_current(&state.db).await?; // Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI @@ -272,7 +271,7 @@ pub struct SystemPromptEditData { pub async fn show_edit_system_prompt( State(state): State, -) -> Result { +) -> TemplateResult { let settings = SystemSettings::get_current(&state.db).await?; Ok(TemplateResponse::new_template( @@ -297,7 +296,7 @@ pub struct SystemPromptSectionData { pub async fn patch_query_prompt( State(state): State, Form(input): Form, -) -> Result { +) -> TemplateResult { let new_settings = SystemSettingsPatch { query_system_prompt: Some(input.query_system_prompt), ..Default::default() @@ -322,7 +321,7 @@ pub struct IngestionPromptEditData { pub async fn show_edit_ingestion_prompt( State(state): State, -) -> Result { +) -> TemplateResult { let settings = SystemSettings::get_current(&state.db).await?; Ok(TemplateResponse::new_template( @@ -342,7 +341,7 @@ pub struct IngestionPromptUpdateInput { pub async fn patch_ingestion_prompt( State(state): State, Form(input): Form, -) -> Result { +) -> TemplateResult { let new_settings = SystemSettingsPatch { ingestion_system_prompt: Some(input.ingestion_system_prompt), ..Default::default() @@ -367,7 +366,7 @@ pub struct ImagePromptEditData { pub async fn show_edit_image_prompt( State(state): State, -) -> Result { +) -> TemplateResult { let settings = SystemSettings::get_current(&state.db).await?; Ok(TemplateResponse::new_template( @@ -387,7 +386,7 @@ pub struct ImagePromptUpdateInput { pub async fn patch_image_prompt( State(state): State, Form(input): Form, -) -> Result { +) -> TemplateResult { let new_settings = SystemSettingsPatch { image_processing_prompt: Some(input.image_processing_prompt), ..Default::default() diff --git a/html-router/src/routes/auth/signin.rs b/html-router/src/routes/auth/signin.rs index f2a412c..14736a3 100644 --- a/html-router/src/routes/auth/signin.rs +++ b/html-router/src/routes/auth/signin.rs @@ -1,10 +1,10 @@ -use axum::{extract::State, response::IntoResponse, Form}; +use axum::{extract::State, Form}; use axum_htmx::HxBoosted; use serde::{Deserialize, Serialize}; use crate::{ html_state::HtmlState, - middlewares::response_middleware::{HtmlError, TemplateResponse}, + middlewares::response_middleware::{TemplateResponse, TemplateResult}, AuthSessionType, }; use common::storage::types::user::User; @@ -19,7 +19,7 @@ pub struct SignInParams { pub async fn show_signin_form( auth: AuthSessionType, HxBoosted(boosted): HxBoosted, -) -> Result { +) -> TemplateResult { if auth.current_user.is_some() { return Ok(TemplateResponse::redirect("/")); } @@ -38,9 +38,9 @@ pub async fn authenticate_user( State(state): State, auth: AuthSessionType, Form(form): Form, -) -> Result { +) -> TemplateResult { let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else { - return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response()); + return Ok(TemplateResponse::bad_request("Incorrect email or password")); }; auth.login_user(user.id); @@ -49,5 +49,5 @@ pub async fn authenticate_user( auth.remember_user(true); } - Ok(TemplateResponse::redirect("/").into_response()) + Ok(TemplateResponse::redirect("/")) } diff --git a/html-router/src/routes/auth/signout.rs b/html-router/src/routes/auth/signout.rs index f1dd4fa..1795d8f 100644 --- a/html-router/src/routes/auth/signout.rs +++ b/html-router/src/routes/auth/signout.rs @@ -1,11 +1,9 @@ -use axum::response::IntoResponse; - use crate::{ - middlewares::response_middleware::{HtmlError, TemplateResponse}, + middlewares::response_middleware::{TemplateResponse, TemplateResult}, AuthSessionType, }; -pub async fn sign_out_user(auth: AuthSessionType) -> Result { +pub async fn sign_out_user(auth: AuthSessionType) -> TemplateResult { if !auth.is_authenticated() { return Ok(TemplateResponse::redirect("/")); } diff --git a/html-router/src/routes/auth/signup.rs b/html-router/src/routes/auth/signup.rs index 0dff980..d195803 100644 --- a/html-router/src/routes/auth/signup.rs +++ b/html-router/src/routes/auth/signup.rs @@ -1,4 +1,4 @@ -use axum::{extract::State, response::IntoResponse, Form}; +use axum::{extract::State, Form}; use axum_htmx::HxBoosted; use serde::{Deserialize, Serialize}; @@ -6,7 +6,7 @@ use common::{error::AppError, storage::types::user::{Theme, User}}; use crate::{ html_state::HtmlState, - middlewares::response_middleware::{HtmlError, TemplateResponse}, + middlewares::response_middleware::{TemplateResponse, TemplateResult}, AuthSessionType, }; @@ -27,7 +27,7 @@ fn signup_error_message(err: &AppError) -> &str { pub async fn show_signup_form( auth: AuthSessionType, HxBoosted(boosted): HxBoosted, -) -> Result { +) -> TemplateResult { if auth.current_user.is_some() { return Ok(TemplateResponse::redirect("/")); } @@ -47,7 +47,7 @@ pub async fn process_signup_and_show_verification( State(state): State, auth: AuthSessionType, Form(form): Form, -) -> Result { +) -> TemplateResult { let user = match User::create_new( form.email, form.password, @@ -60,11 +60,11 @@ pub async fn process_signup_and_show_verification( Ok(user) => user, Err(err) => { tracing::error!(?err, "signup failed"); - return Ok(TemplateResponse::bad_request(signup_error_message(&err)).into_response()); + return Ok(TemplateResponse::bad_request(signup_error_message(&err))); } }; auth.login_user(user.id); - Ok(TemplateResponse::redirect("/").into_response()) + Ok(TemplateResponse::redirect("/")) } diff --git a/html-router/src/routes/chat/chat_handlers.rs b/html-router/src/routes/chat/chat_handlers.rs index 5591953..7e7e0ef 100644 --- a/html-router/src/routes/chat/chat_handlers.rs +++ b/html-router/src/routes/chat/chat_handlers.rs @@ -1,7 +1,6 @@ use axum::{ extract::{Path, State}, http::HeaderValue, - response::IntoResponse, Form, }; use serde::{Deserialize, Serialize}; @@ -18,7 +17,10 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{ + template_as_response, template_with_headers, TemplateResponse, TemplateResult, + ResponseResult, + }, }, }; @@ -31,7 +33,7 @@ pub struct ChatPageData { pub async fn show_chat_base( State(_state): State, RequireUser(_user): RequireUser, -) -> Result { +) -> TemplateResult { Ok(TemplateResponse::new_template( "chat/base.html", ChatPageData { @@ -50,7 +52,7 @@ pub async fn show_existing_chat( Path(conversation_id): Path, State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let (conversation, messages) = Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db) .await?; @@ -69,7 +71,7 @@ pub async fn new_user_message( State(state): State, RequireUser(user): RequireUser, Form(form): Form, -) -> Result { +) -> ResponseResult { #[derive(Serialize)] struct SSEResponseInitData { user_message: Message, @@ -82,31 +84,32 @@ pub async fn new_user_message( .ok_or_else(|| AppError::NotFound("conversation was not found".into()))?; if conversation.user_id != user.id { - return Ok(TemplateResponse::unauthorized().into_response()); + return Ok(template_as_response(TemplateResponse::unauthorized())); } let user_message = Message::new(conversation_id, MessageRole::User, form.content, None); state.db.store_item(user_message.clone()).await?; - let mut response = TemplateResponse::new_template( - "chat/streaming_response.html", - SSEResponseInitData { user_message }, - ) - .into_response(); - - if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) { - response.headers_mut().insert("HX-Push", header_value); - } - - Ok(response) + let push_path = format!("/chat/{}", conversation.id); + Ok(template_with_headers( + TemplateResponse::new_template( + "chat/streaming_response.html", + SSEResponseInitData { user_message }, + ), + |headers| { + if let Ok(header_value) = HeaderValue::from_str(&push_path) { + headers.insert("HX-Push", header_value); + } + }, + )) } pub async fn new_chat_user_message( State(state): State, RequireUser(user): RequireUser, Form(form): Form, -) -> Result { +) -> ResponseResult { #[derive(Serialize)] struct SSEResponseInitData { user_message: Message, @@ -125,20 +128,21 @@ pub async fn new_chat_user_message( state.db.store_item(user_message.clone()).await?; state.invalidate_conversation_archive_cache(&user.id).await; - let mut response = TemplateResponse::new_template( - "chat/new_chat_first_response.html", - SSEResponseInitData { - user_message, - conversation: conversation.clone(), + let push_path = format!("/chat/{}", conversation.id); + Ok(template_with_headers( + TemplateResponse::new_template( + "chat/new_chat_first_response.html", + SSEResponseInitData { + user_message, + conversation: conversation.clone(), + }, + ), + |headers| { + if let Ok(header_value) = HeaderValue::from_str(&push_path) { + headers.insert("HX-Push", header_value); + } }, - ) - .into_response(); - - if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) { - response.headers_mut().insert("HX-Push", header_value); - } - - Ok(response) + )) } #[derive(Deserialize)] @@ -155,7 +159,7 @@ pub async fn show_conversation_editing_title( State(state): State, RequireUser(user): RequireUser, Path(conversation_id): Path, -) -> Result { +) -> TemplateResult { let conversation: Conversation = state .db .get_item(&conversation_id) @@ -163,7 +167,7 @@ pub async fn show_conversation_editing_title( .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; if conversation.user_id != user.id { - return Ok(TemplateResponse::unauthorized().into_response()); + return Ok(TemplateResponse::unauthorized()); } Ok(TemplateResponse::new_template( @@ -171,8 +175,7 @@ pub async fn show_conversation_editing_title( DrawerContext { edit_conversation_id: Some(conversation_id), }, - ) - .into_response()) + )) } pub async fn patch_conversation_title( @@ -180,7 +183,7 @@ pub async fn patch_conversation_title( RequireUser(user): RequireUser, Path(conversation_id): Path, Form(form): Form, -) -> Result { +) -> TemplateResult { Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?; state.invalidate_conversation_archive_cache(&user.id).await; @@ -189,15 +192,14 @@ pub async fn patch_conversation_title( DrawerContext { edit_conversation_id: None, }, - ) - .into_response()) + )) } pub async fn delete_conversation( State(state): State, RequireUser(user): RequireUser, Path(conversation_id): Path, -) -> Result { +) -> TemplateResult { let conversation: Conversation = state .db .get_item(&conversation_id) @@ -205,7 +207,7 @@ pub async fn delete_conversation( .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; if conversation.user_id != user.id { - return Ok(TemplateResponse::unauthorized().into_response()); + return Ok(TemplateResponse::unauthorized()); } state @@ -219,18 +221,16 @@ pub async fn delete_conversation( DrawerContext { edit_conversation_id: None, }, - ) - .into_response()) + )) } pub async fn reload_sidebar( State(_state): State, RequireUser(_user): RequireUser, -) -> Result { +) -> TemplateResult { Ok(TemplateResponse::new_template( "sidebar.html", DrawerContext { edit_conversation_id: None, }, - ) - .into_response()) + )) } diff --git a/html-router/src/routes/chat/references.rs b/html-router/src/routes/chat/references.rs index fa206bc..f09c294 100644 --- a/html-router/src/routes/chat/references.rs +++ b/html-router/src/routes/chat/references.rs @@ -2,7 +2,6 @@ use axum::{ extract::{Path, State}, - response::IntoResponse, }; use chrono::{DateTime, Utc}; use chrono_tz::Tz; @@ -16,7 +15,7 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{TemplateResponse, TemplateResult}, }, }; @@ -45,7 +44,7 @@ pub async fn show_reference_tooltip( State(state): State, RequireUser(user): RequireUser, Path(reference_id): Path, -) -> Result { +) -> TemplateResult { let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else { return Ok(TemplateResponse::not_found()); }; diff --git a/html-router/src/routes/content/handlers.rs b/html-router/src/routes/content/handlers.rs index 57f258a..4a1d78c 100644 --- a/html-router/src/routes/content/handlers.rs +++ b/html-router/src/routes/content/handlers.rs @@ -1,6 +1,5 @@ use axum::{ extract::{Path, Query, State}, - response::IntoResponse, Form, }; use axum_htmx::{HxBoosted, HxRequest, HxTarget}; @@ -15,7 +14,7 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{TemplateResponse, TemplateResult}, }, utils::pagination::{paginate_items, Pagination}, utils::text_content_preview::truncate_text_contents, @@ -50,7 +49,7 @@ pub async fn show_content_page( Query(params): Query, HxRequest(is_htmx): HxRequest, HxBoosted(is_boosted): HxBoosted, -) -> Result { +) -> TemplateResult { // Normalize empty strings to None let category_filter = params .category @@ -101,7 +100,7 @@ pub async fn show_text_content_edit_form( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> TemplateResult { #[derive(Serialize)] pub struct TextContentEditModal { pub text_content: TextContent, @@ -127,7 +126,7 @@ pub async fn patch_text_content( Path(id): Path, HxTarget(target): HxTarget, Form(form): Form, -) -> Result { +) -> TemplateResult { User::get_and_validate_text_content(&id, &user.id, &state.db).await?; TextContent::patch(&id, &form.context, &form.category, &form.text, &state.db).await?; @@ -167,7 +166,7 @@ pub async fn delete_text_content( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> TemplateResult { // Get and validate the text content let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?; @@ -213,7 +212,7 @@ pub async fn show_content_read_modal( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> TemplateResult { #[derive(Serialize)] pub struct TextContentReadModalData { pub text_content: TextContent, @@ -231,7 +230,7 @@ pub async fn show_content_read_modal( pub async fn show_recent_content( State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let text_contents = truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?); diff --git a/html-router/src/routes/index/handlers.rs b/html-router/src/routes/index/handlers.rs index f04f6ac..2ec9b2b 100644 --- a/html-router/src/routes/index/handlers.rs +++ b/html-router/src/routes/index/handlers.rs @@ -12,7 +12,9 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{ + template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult, + }, }, utils::text_content_preview::truncate_text_contents, utils::truncate::with_ellipsis, @@ -37,7 +39,7 @@ pub struct IndexPageData { pub async fn index_handler( State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let (text_contents, dashboard_stats, active_jobs) = try_join!( User::get_latest_text_contents(&user.id, &state.db), User::get_dashboard_stats(&user.id, &state.db), @@ -65,7 +67,7 @@ pub async fn delete_text_content( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> TemplateResult { // Get and validate TextContent let text_content = get_and_validate_text_content(&state, &id, &user).await?; @@ -154,7 +156,7 @@ pub async fn delete_job( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> TemplateResult { User::validate_and_delete_job(&id, &user.id, &state.db).await?; let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?; @@ -169,7 +171,7 @@ pub async fn delete_job( pub async fn show_active_jobs( State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?; Ok(TemplateResponse::new_template( @@ -181,7 +183,7 @@ pub async fn show_active_jobs( pub async fn show_task_archive( State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?; let entries: Vec = tasks @@ -234,17 +236,17 @@ pub async fn serve_file( State(state): State, RequireUser(user): RequireUser, Path(file_id): Path, -) -> Result { +) -> ResponseResult { let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else { - return Ok(TemplateResponse::not_found().into_response()); + return Ok(template_as_response(TemplateResponse::not_found())); }; if file_info.user_id != user.id { - return Ok(TemplateResponse::unauthorized().into_response()); + return Ok(template_as_response(TemplateResponse::unauthorized())); } let Ok(stream) = state.storage.get_stream(&file_info.path).await else { - return Ok(TemplateResponse::server_error().into_response()); + return Ok(template_as_response(TemplateResponse::server_error())); }; let body = Body::from_stream(stream); diff --git a/html-router/src/routes/ingestion/handlers.rs b/html-router/src/routes/ingestion/handlers.rs index 3c1511c..49dd303 100644 --- a/html-router/src/routes/ingestion/handlers.rs +++ b/html-router/src/routes/ingestion/handlers.rs @@ -5,7 +5,7 @@ use axum::{ http::StatusCode, response::{ sse::{Event, KeepAlive, KeepAliveStream}, - IntoResponse, Response, Sse, + Sse, }, }; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; @@ -31,7 +31,7 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{TemplateResponse, TemplateResult}, }, }; @@ -49,7 +49,7 @@ fn sse_with_keep_alive(stream: EventStream) -> TaskSse { pub async fn show_ingest_form( State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { #[derive(Serialize)] pub struct ShowIngestFormData { user_categories: Vec, @@ -65,7 +65,7 @@ pub async fn show_ingest_form( pub async fn hide_ingest_form( RequireUser(_user): RequireUser, -) -> Result { +) -> TemplateResult { Ok(TemplateResponse::new_template( "ingestion/add_content_button.html", (), @@ -91,12 +91,11 @@ pub async fn process_ingest_form( State(state): State, RequireUser(user): RequireUser, TypedMultipart(input): TypedMultipart, -) -> Result { +) -> TemplateResult { if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() { - return Ok( - TemplateResponse::bad_request("You need to either add files or content") - .into_response(), - ); + return Ok(TemplateResponse::bad_request( + "You need to either add files or content", + )); } let content_bytes = input.content.as_ref().map_or(0, String::len); @@ -118,11 +117,10 @@ pub async fn process_ingest_form( StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large", &message, - ) - .into_response()); + )); } Err(IngestValidationError::BadRequest(message)) => { - return Ok(TemplateResponse::bad_request(&message).into_response()); + return Ok(TemplateResponse::bad_request(&message)); } } @@ -153,10 +151,10 @@ pub async fn process_ingest_form( let tasks = IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?; - Ok( - TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks }) - .into_response(), - ) + Ok(TemplateResponse::new_template( + "dashboard/current_task.html", + NewTasksData { tasks }, + )) } #[derive(Deserialize)] diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 970ae45..d288249 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -31,7 +31,9 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{ + template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult, + }, }, utils::pagination::{paginate_items, Pagination}, }; @@ -120,12 +122,12 @@ fn collect_relationship_type_options(relationships: &[KnowledgeRelationship]) -> options } -fn respond_with_graph_refresh(response: TemplateResponse) -> Response { - let mut response = response.into_response(); - if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) { - response.headers_mut().insert(HX_TRIGGER, value); - } - response +fn graph_refresh_response(template: TemplateResponse) -> Response { + template_with_headers(template, |headers| { + if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) { + headers.insert(HX_TRIGGER, value); + } + }) } #[derive(Deserialize, Default)] @@ -138,7 +140,7 @@ pub struct FilterParams { pub async fn show_new_knowledge_entity_form( State(state): State, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let entity_types: Vec = KnowledgeEntityType::variants() .iter() .map(ToString::to_string) @@ -170,7 +172,7 @@ pub async fn create_knowledge_entity( State(state): State, RequireUser(user): RequireUser, Form(form): Form, -) -> Result { +) -> ResponseResult { let name = form.name.trim().to_string(); if name.is_empty() { return Err(AppError::Validation("name is required".into()).into()); @@ -230,7 +232,7 @@ pub async fn create_knowledge_entity( let default_params = FilterParams::default(); let kb_data = build_knowledge_base_data(&state, &user, &default_params).await?; - Ok(respond_with_graph_refresh(TemplateResponse::new_partial( + Ok(graph_refresh_response(TemplateResponse::new_partial( "knowledge/base.html", "main", kb_data, @@ -241,7 +243,7 @@ pub async fn suggest_knowledge_relationships( State(state): State, RequireUser(user): RequireUser, Form(form): Form, -) -> Result { +) -> TemplateResult { let entity_lookup: HashMap = User::get_knowledge_entities(&user.id, &state.db) .await? @@ -723,7 +725,7 @@ pub async fn show_knowledge_page( HxRequest(is_htmx): HxRequest, HxBoosted(is_boosted): HxBoosted, Query(mut params): Query, -) -> Result { +) -> TemplateResult { // Normalize filters: treat empty or "none" as no filter params.entity_type = normalize_filter(params.entity_type.take()); params.content_category = normalize_filter(params.content_category.take()); @@ -772,7 +774,7 @@ pub async fn get_knowledge_graph_json( State(state): State, RequireUser(user): RequireUser, Query(mut params): Query, -) -> Result { +) -> ResponseResult { // Normalize filters: treat empty or "none" as no filter params.entity_type = normalize_filter(params.entity_type.take()); params.content_category = normalize_filter(params.content_category.take()); @@ -821,7 +823,7 @@ pub async fn get_knowledge_graph_json( }) .collect(); - Ok(Json(GraphData { nodes, links })) + Ok(Json(GraphData { nodes, links }).into_response()) } // Normalize filter parameters: convert empty strings or "none" (case-insensitive) to None fn normalize_filter(input: Option) -> Option { @@ -851,7 +853,7 @@ pub async fn show_edit_knowledge_entity_form( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> TemplateResult { #[derive(Serialize)] pub struct EntityData { entity: KnowledgeEntity, @@ -899,7 +901,7 @@ pub async fn patch_knowledge_entity( State(state): State, RequireUser(user): RequireUser, Form(form): Form, -) -> Result { +) -> ResponseResult { // Get the existing entity and validate that the user is allowed User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db).await?; @@ -930,7 +932,7 @@ pub async fn patch_knowledge_entity( let content_categories = User::get_user_categories(&user.id, &state.db).await?; // Render updated list - Ok(respond_with_graph_refresh(TemplateResponse::new_template( + Ok(graph_refresh_response(TemplateResponse::new_template( "knowledge/entity_list.html", EntityListData { visible_entities, @@ -948,7 +950,7 @@ pub async fn delete_knowledge_entity( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> ResponseResult { // Get the existing entity and validate that the user is allowed User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?; @@ -968,7 +970,7 @@ pub async fn delete_knowledge_entity( // Get content categories let content_categories = User::get_user_categories(&user.id, &state.db).await?; - Ok(respond_with_graph_refresh(TemplateResponse::new_template( + Ok(graph_refresh_response(TemplateResponse::new_template( "knowledge/entity_list.html", EntityListData { visible_entities, @@ -994,7 +996,7 @@ pub async fn delete_knowledge_relationship( State(state): State, RequireUser(user): RequireUser, Path(id): Path, -) -> Result { +) -> ResponseResult { KnowledgeRelationship::delete_relationship_by_id(&id, &user.id, &state.db).await?; let entities = User::get_knowledge_entities(&user.id, &state.db).await?; @@ -1003,7 +1005,7 @@ pub async fn delete_knowledge_relationship( let table_data = build_relationship_table_data(entities, relationships); // Render updated list - Ok(respond_with_graph_refresh(TemplateResponse::new_template( + Ok(graph_refresh_response(TemplateResponse::new_template( "knowledge/relationship_table.html", table_data, ))) @@ -1020,7 +1022,7 @@ pub async fn save_knowledge_relationship( State(state): State, RequireUser(user): RequireUser, Form(form): Form, -) -> Result { +) -> ResponseResult { // Construct relationship let relationship_type = canonicalize_relationship_type(&form.relationship_type); let relationship = KnowledgeRelationship::new( @@ -1039,7 +1041,7 @@ pub async fn save_knowledge_relationship( let table_data = build_relationship_table_data(entities, relationships); // Render updated list - Ok(respond_with_graph_refresh(TemplateResponse::new_template( + Ok(graph_refresh_response(TemplateResponse::new_template( "knowledge/relationship_table.html", table_data, ))) diff --git a/html-router/src/routes/scratchpad/handlers.rs b/html-router/src/routes/scratchpad/handlers.rs index c429780..991f248 100644 --- a/html-router/src/routes/scratchpad/handlers.rs +++ b/html-router/src/routes/scratchpad/handlers.rs @@ -11,7 +11,9 @@ use serde::{Deserialize, Serialize}; use crate::html_state::HtmlState; use crate::middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{ + template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult, + }, }; use common::storage::types::{ ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, scratchpad::Scratchpad, @@ -127,7 +129,7 @@ pub async fn show_scratchpad_page( HxRequest(is_htmx): HxRequest, HxBoosted(is_boosted): HxBoosted, State(state): State, -) -> Result { +) -> TemplateResult { let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?; let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?; @@ -165,7 +167,7 @@ pub async fn show_scratchpad_modal( State(state): State, Path(scratchpad_id): Path, Query(query): Query, -) -> Result { +) -> TemplateResult { let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?; let scratchpad_detail = ScratchpadDetail::from(&scratchpad); @@ -186,7 +188,7 @@ pub async fn create_scratchpad( RequireUser(user): RequireUser, State(state): State, Form(form): Form, -) -> Result { +) -> TemplateResult { let user_id = user.id.clone(); let scratchpad = Scratchpad::new(user_id.clone(), form.title); let _stored = state.db.store_item(scratchpad.clone()).await?; @@ -217,7 +219,7 @@ pub async fn auto_save_scratchpad( State(state): State, Path(scratchpad_id): Path, Form(form): Form, -) -> Result { +) -> ResponseResult { let updated = Scratchpad::update_content(&scratchpad_id, &user.id, &form.content, &state.db).await?; @@ -229,7 +231,8 @@ pub async fn auto_save_scratchpad( .format("%Y-%m-%d %H:%M:%S") .to_string(), last_saved_at_iso: updated.last_saved_at.to_rfc3339(), - })) + }) + .into_response()) } pub async fn update_scratchpad_title( @@ -237,7 +240,7 @@ pub async fn update_scratchpad_title( State(state): State, Path(scratchpad_id): Path, Form(form): Form, -) -> Result { +) -> TemplateResult { Scratchpad::update_title(&scratchpad_id, &user.id, &form.title, &state.db).await?; let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?; @@ -255,7 +258,7 @@ pub async fn delete_scratchpad( RequireUser(user): RequireUser, State(state): State, Path(scratchpad_id): Path, -) -> Result { +) -> TemplateResult { Scratchpad::delete(&scratchpad_id, &user.id, &state.db).await?; // Return the updated main section content @@ -284,7 +287,7 @@ pub async fn ingest_scratchpad( RequireUser(user): RequireUser, State(state): State, Path(scratchpad_id): Path, -) -> Result { +) -> ResponseResult { let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?; if scratchpad.content.trim().is_empty() { @@ -347,29 +350,29 @@ pub async fn ingest_scratchpad( r#"{"toast":{"title":"Ingestion queued","description":"Scratchpad archived and added to the ingestion queue.","type":"success"}}"#.to_string() }); - let template_response = TemplateResponse::new_partial( - "scratchpad/base.html", - "main", - ScratchpadPageData { - scratchpads: scratchpad_list, - archived_scratchpads: archived_list, - new_scratchpad: None, + Ok(template_with_headers( + TemplateResponse::new_partial( + "scratchpad/base.html", + "main", + ScratchpadPageData { + scratchpads: scratchpad_list, + archived_scratchpads: archived_list, + new_scratchpad: None, + }, + ), + |headers| { + if let Ok(header_value) = HeaderValue::from_str(&trigger_value) { + headers.insert(HX_TRIGGER, header_value); + } }, - ); - - let mut response = template_response.into_response(); - if let Ok(header_value) = HeaderValue::from_str(&trigger_value) { - response.headers_mut().insert(HX_TRIGGER, header_value); - } - - Ok(response) + )) } pub async fn archive_scratchpad( RequireUser(user): RequireUser, State(state): State, Path(scratchpad_id): Path, -) -> Result { +) -> TemplateResult { Scratchpad::archive(&scratchpad_id, &user.id, &state.db, false).await?; let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?; @@ -396,7 +399,7 @@ pub async fn restore_scratchpad( RequireUser(user): RequireUser, State(state): State, Path(scratchpad_id): Path, -) -> Result { +) -> ResponseResult { Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?; let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?; @@ -420,22 +423,22 @@ pub async fn restore_scratchpad( r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string() }); - let template_response = TemplateResponse::new_partial( - "scratchpad/base.html", - "main", - ScratchpadPageData { - scratchpads: scratchpad_list, - archived_scratchpads: archived_list, - new_scratchpad: None, + Ok(template_with_headers( + TemplateResponse::new_partial( + "scratchpad/base.html", + "main", + ScratchpadPageData { + scratchpads: scratchpad_list, + archived_scratchpads: archived_list, + new_scratchpad: None, + }, + ), + |headers| { + if let Ok(header_value) = HeaderValue::from_str(&trigger_value) { + headers.insert(HX_TRIGGER, header_value); + } }, - ); - - let mut response = template_response.into_response(); - if let Ok(header_value) = HeaderValue::from_str(&trigger_value) { - response.headers_mut().insert(HX_TRIGGER, header_value); - } - - Ok(response) + )) } #[cfg(test)] diff --git a/html-router/src/routes/search/handlers.rs b/html-router/src/routes/search/handlers.rs index 4021e2f..8716936 100644 --- a/html-router/src/routes/search/handlers.rs +++ b/html-router/src/routes/search/handlers.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use axum::{ extract::{Query, State}, - response::IntoResponse, }; use common::storage::types::{text_content::TextContent, user::User}; use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput}; @@ -13,7 +12,7 @@ use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, - response_middleware::{HtmlError, TemplateResponse}, + response_middleware::{HtmlError, TemplateResponse, TemplateResult}, }, }; @@ -79,7 +78,7 @@ pub async fn search_result_handler( State(state): State, Query(params): Query, RequireUser(user): RequireUser, -) -> Result { +) -> TemplateResult { let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) = params.query { diff --git a/html-router/tests/router_integration.rs b/html-router/tests/router_integration.rs new file mode 100644 index 0000000..6317f9b --- /dev/null +++ b/html-router/tests/router_integration.rs @@ -0,0 +1,311 @@ +#![allow(clippy::expect_used)] + +use std::sync::Arc; + +use axum::{ + body::{to_bytes, Body}, + http::{header, Request, StatusCode}, + response::Response, + Router, +}; +use common::{ + storage::{ + db::SurrealDbClient, + store::StorageManager, + types::user::User, + }, + utils::{ + config::{AppConfig, StorageKind}, + embedding::EmbeddingProvider, + }, +}; +use html_router::{ + html_routes, + html_state::{HtmlState, StateResources}, +}; +use tower::ServiceExt; + +async fn build_test_app() -> (Router, Arc) { + let namespace = "html_router_test"; + let database = &uuid::Uuid::new_v4().to_string(); + let db = Arc::new( + SurrealDbClient::memory(namespace, database) + .await + .expect("in-memory db"), + ); + db.apply_migrations() + .await + .expect("migrations should apply"); + + let session_store = Arc::new( + db.create_session_store() + .await + .expect("session store"), + ); + + let config = AppConfig { + storage: StorageKind::Memory, + ..Default::default() + }; + + let storage = StorageManager::new(&config) + .await + .expect("storage manager"); + + let embedding_provider = Arc::new( + EmbeddingProvider::new_hashed(8).expect("embedding provider"), + ); + + let state = HtmlState::new_with_resources(StateResources { + db: Arc::clone(&db), + openai_client: Arc::new(async_openai::Client::new()), + session_store, + storage, + config, + reranker_pool: None, + embedding_provider, + template_engine: None, + }); + + let router = html_routes(&state).with_state(state); + (router, db) +} + +fn redirect_location(response: &Response) -> String { + response + .headers() + .get(header::LOCATION) + .or_else(|| response.headers().get("HX-Redirect")) + .expect("redirect response should include Location or HX-Redirect") + .to_str() + .expect("redirect header must be utf-8") + .to_string() +} + +fn session_cookie(response: &Response) -> String { + response + .headers() + .get_all(header::SET_COOKIE) + .iter() + .map(|value| { + value + .to_str() + .expect("set-cookie must be utf-8") + .split(';') + .next() + .expect("cookie key=value") + .to_string() + }) + .collect::>() + .join("; ") +} + +async fn response_body(response: Response) -> String { + let body = to_bytes(response.into_body(), usize::MAX) + .await + .expect("response body"); + String::from_utf8(body.to_vec()).expect("html body") +} + +async fn sign_in(app: &Router, email: &str, password: &str) -> String { + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/signin") + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(Body::from(format!("email={email}&password={password}"))) + .expect("signin request"), + ) + .await + .expect("signin response"); + + assert!( + response.status().is_redirection() || response.status() == StatusCode::OK, + "signin should redirect or return ok" + ); + session_cookie(&response) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn protected_route_redirects_unauthenticated_users() { + let (app, _db) = build_test_app().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/") + .body(Body::empty()) + .expect("dashboard request"), + ) + .await + .expect("dashboard response"); + + assert!( + response.status().is_redirection() || response.status() == StatusCode::OK, + "unauthenticated access should redirect via template middleware" + ); + assert_eq!(redirect_location(&response), "/signin"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn authenticated_user_receives_rendered_dashboard() { + let (app, db) = build_test_app().await; + + User::create_new( + "router_test@example.com".to_string(), + "test_password".to_string(), + &db, + "UTC".to_string(), + "system".to_string(), + ) + .await + .expect("test user"); + + let cookie = sign_in(&app, "router_test@example.com", "test_password").await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/") + .header(header::COOKIE, cookie) + .body(Body::empty()) + .expect("authenticated dashboard request"), + ) + .await + .expect("authenticated dashboard response"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX) + .await + .expect("response body"); + let html = String::from_utf8(body.to_vec()).expect("html body"); + assert!( + html.contains("dashboard") || html.contains("Dashboard"), + "dashboard template should render html" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn signin_form_is_public() { + let (app, _db) = build_test_app().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/signin") + .body(Body::empty()) + .expect("signin form request"), + ) + .await + .expect("signin form response"); + + assert_eq!(response.status(), StatusCode::OK); + let html = response_body(response).await; + assert!( + html.contains("signin") || html.contains("Sign in") || html.contains("email"), + "signin page should render a form" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn signin_rejects_invalid_credentials() { + let (app, db) = build_test_app().await; + + User::create_new( + "signin_test@example.com".to_string(), + "correct_password".to_string(), + &db, + "UTC".to_string(), + "system".to_string(), + ) + .await + .expect("test user"); + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/signin") + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(Body::from( + "email=signin_test@example.com&password=wrong_password", + )) + .expect("invalid signin request"), + ) + .await + .expect("invalid signin response"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let html = response_body(response).await; + assert!( + html.contains("Incorrect email or password"), + "signin failure should render a safe error message" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn admin_route_redirects_non_admin_user() { + let (app, db) = build_test_app().await; + + User::create_new( + "admin_user@example.com".to_string(), + "admin_password".to_string(), + &db, + "UTC".to_string(), + "system".to_string(), + ) + .await + .expect("admin user"); + + User::create_new( + "member_user@example.com".to_string(), + "member_password".to_string(), + &db, + "UTC".to_string(), + "system".to_string(), + ) + .await + .expect("member user"); + + let member_cookie = sign_in(&app, "member_user@example.com", "member_password").await; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/admin") + .header(header::COOKIE, member_cookie) + .body(Body::empty()) + .expect("non-admin admin request"), + ) + .await + .expect("non-admin admin response"); + + assert!( + response.status().is_redirection() || response.status() == StatusCode::OK, + "non-admin should be redirected away from admin" + ); + assert_eq!(redirect_location(&response), "/"); + + let admin_cookie = sign_in(&app, "admin_user@example.com", "admin_password").await; + let admin_response = app + .clone() + .oneshot( + Request::builder() + .uri("/admin") + .header(header::COOKIE, admin_cookie) + .body(Body::empty()) + .expect("admin request"), + ) + .await + .expect("admin response"); + + assert_eq!(admin_response.status(), StatusCode::OK); +}