chore: harden api-router errors and add router integration tests while slimming html handlers.

This commit is contained in:
Per Stark
2026-05-30 11:39:47 +02:00
parent 2aa92b6ad7
commit c70141de35
26 changed files with 814 additions and 260 deletions
Generated
+3
View File
@@ -247,7 +247,9 @@ dependencies = [
"tempfile",
"thiserror 1.0.69",
"tokio",
"tower",
"tracing",
"uuid",
]
[[package]]
@@ -2978,6 +2980,7 @@ dependencies = [
"thiserror 1.0.69",
"tokio",
"tokio-util",
"tower",
"tower-http",
"tower-serve-static",
"tracing",
+5
View File
@@ -20,3 +20,8 @@ futures = { workspace = true }
axum_typed_multipart = { workspace = true}
common = { path = "../common" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }
tower = "0.5"
uuid = { workspace = true }
-28
View File
@@ -11,31 +11,3 @@ pub struct ApiState {
pub config: AppConfig,
pub storage: StorageManager,
}
impl ApiState {
pub async fn new(
config: &AppConfig,
storage: StorageManager,
) -> anyhow::Result<Self> {
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
surreal_db_client.apply_migrations().await?;
let app_state = Self {
db: Arc::clone(&surreal_db_client),
config: config.clone(),
storage,
};
Ok(app_state)
}
}
+19 -9
View File
@@ -7,7 +7,7 @@ use common::error::AppError;
use serde::Serialize;
use thiserror::Error;
#[derive(Error, Debug, Serialize, Clone)]
#[derive(Error, Debug)]
pub enum ApiErr {
#[error("internal server error")]
InternalError(String),
@@ -28,14 +28,13 @@ pub enum ApiErr {
impl From<AppError> for ApiErr {
fn from(err: AppError) -> Self {
match err {
AppError::Database(_) | AppError::OpenAI(_) => {
tracing::error!("Internal error: {:?}", err);
Self::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => Self::NotFound(msg),
AppError::Validation(msg) => Self::ValidationError(msg),
AppError::Auth(msg) => Self::Unauthorized(msg),
_ => Self::InternalError("Internal server error".to_string()),
other => {
tracing::error!("internal API error: {other:?}");
Self::InternalError("Internal server error".to_string())
}
}
}
}
@@ -120,10 +119,21 @@ mod tests {
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
// Test for internal errors - create a mock error that doesn't require surrealdb
let internal_error =
AppError::Io(io::Error::other("io error"));
let internal_error = AppError::Io(io::Error::other("io error"));
let api_error = ApiErr::from(internal_error);
assert!(matches!(api_error, ApiErr::InternalError(_)));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
fn test_app_error_internal_error_is_sanitized() {
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
+2 -2
View File
@@ -6,7 +6,7 @@ use axum::{
Router,
};
use middleware_api_auth::api_auth;
use routes::{categories::list, ingest::ingest_data, liveness::live, readiness::ready};
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
pub mod api_state;
pub mod error;
@@ -28,7 +28,7 @@ where
let protected = Router::new()
.route(
"/ingest",
post(ingest_data).layer(DefaultBodyLimit::max(
post(handle).layer(DefaultBodyLimit::max(
app_state.config.ingest_max_body_bytes,
)),
)
+63 -4
View File
@@ -16,7 +16,7 @@ pub async fn api_auth(
let api_key = extract_api_key(&request)
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
let user = User::find_by_api_key(&api_key, &state.db).await?;
let user = User::find_by_api_key(api_key, &state.db).await?;
let user =
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
@@ -25,7 +25,7 @@ pub async fn api_auth(
Ok(next.run(request).await)
}
fn extract_api_key(request: &Request) -> Option<String> {
fn extract_api_key(request: &Request) -> Option<&str> {
request
.headers()
.get("X-API-Key")
@@ -35,7 +35,66 @@ fn extract_api_key(request: &Request) -> Option<String> {
.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
.and_then(|auth| auth.strip_prefix("Bearer "))
.map(str::trim)
})
.map(String::from)
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use axum::body::Body;
use axum::http::{HeaderValue, Request};
use super::extract_api_key;
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
let mut builder = Request::builder().method("GET").uri("/");
for (name, value) in headers {
builder = builder.header(*name, *value);
}
builder.body(Body::empty()).expect("test request")
}
#[test]
fn extract_api_key_from_x_api_key_header() {
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
}
#[test]
fn extract_api_key_from_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
}
#[test]
fn extract_api_key_prefers_x_api_key_over_authorization() {
let request = request_with_headers(&[
("X-API-Key", "sk_header"),
("Authorization", "Bearer sk_bearer"),
]);
assert_eq!(extract_api_key(&request), Some("sk_header"));
}
#[test]
fn extract_api_key_returns_none_when_missing() {
let request = request_with_headers(&[]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_non_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Basic abc")]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_invalid_header_values() {
let mut request = request_with_headers(&[]);
request
.headers_mut()
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header"));
assert_eq!(extract_api_key(&request), None);
}
}
+8 -12
View File
@@ -16,7 +16,7 @@ use tracing::info;
use crate::{api_state::ApiState, error::ApiErr};
#[derive(Debug, TryFromMultipart)]
pub struct IngestParams {
pub struct Params {
pub content: Option<String>,
pub context: String,
pub category: String,
@@ -25,24 +25,20 @@ pub struct IngestParams {
pub files: Vec<FieldData<NamedTempFile>>,
}
pub async fn ingest_data(
pub async fn handle(
State(state): State<ApiState>,
Extension(user): Extension<User>,
TypedMultipart(input): TypedMultipart<IngestParams>,
TypedMultipart(input): TypedMultipart<Params>,
) -> Result<impl IntoResponse, ApiErr> {
let user_id = user.id;
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input(
&state.config,
input.content.as_deref(),
&input.context,
&input.category,
file_count,
input.files.len(),
) {
Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => {
@@ -56,10 +52,10 @@ pub async fn ingest_data(
info!(
user_id = %user_id,
has_content,
content_bytes,
context_bytes,
category_bytes,
file_count,
content_len = input.content.as_ref().map_or(0, String::len),
context_len = input.context.len(),
category_len = input.category.len(),
file_count = input.files.len(),
"Received ingest request"
);
+7 -4
View File
@@ -1,5 +1,6 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde_json::json;
use tracing::error;
use crate::api_state::ApiState;
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
"checks": { "db": "ok" }
})),
),
Err(e) => (
Err(e) => {
error!("readiness check failed: {e:?}");
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"status": "error",
"checks": { "db": "fail" },
"reason": e.to_string()
"checks": { "db": "fail" }
})),
),
)
}
}
}
+167
View File
@@ -0,0 +1,167 @@
#![allow(clippy::expect_used)]
use std::sync::Arc;
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{
body::{to_bytes, Body},
http::{Request, StatusCode},
Router,
};
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::config::{AppConfig, StorageKind},
};
use tower::ServiceExt;
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
let namespace = "api_router_test";
let database = uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, &database)
.await
.expect("in-memory db"),
);
db.apply_migrations()
.await
.expect("migrations should apply");
let config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config)
.await
.expect("storage manager");
let state = ApiState {
db: Arc::clone(&db),
config,
storage,
};
let router = api_routes_v1(&state).with_state(state);
(router, db)
}
async fn response_body(response: axum::response::Response) -> String {
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
String::from_utf8(body.to_vec()).expect("utf-8 body")
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn live_probe_is_public() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/live")
.body(Body::empty())
.expect("live request"),
)
.await
.expect("live response");
assert_eq!(response.status(), StatusCode::OK);
assert!(response_body(response).await.contains("\"status\":\"ok\""));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ready_probe_is_public_and_reports_db_ok() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/ready")
.body(Body::empty())
.expect("ready request"),
)
.await
.expect("ready response");
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
assert!(body.contains("\"checks\":{\"db\":\"ok\"}") || body.contains("\"db\":\"ok\""));
assert!(!body.contains("reason"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_requires_api_key() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_rejects_invalid_api_key() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.header("X-API-Key", "sk_invalid")
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn authenticated_user_can_list_categories() {
let (app, db) = build_test_app().await;
let user = User::create_new(
"api_router_test@example.com".to_string(),
"test_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let api_key = User::set_api_key(&user.id, &db)
.await
.expect("api key");
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.header("X-API-Key", api_key)
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::OK);
}
+1
View File
@@ -43,6 +43,7 @@ json-stream-parser = { path = "../json-stream-parser" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }
tower = "0.5"
[build-dependencies]
minijinja-embed = { version = "2.8.0" }
@@ -116,6 +116,29 @@ impl IntoResponse for TemplateResponse {
}
}
/// Typical handler return type when no extra response headers or body are needed.
pub type TemplateResult = Result<TemplateResponse, HtmlError>;
/// Handler return type when the response needs custom headers or a non-template body.
pub type ResponseResult = Result<Response, HtmlError>;
/// Converts a [`TemplateResponse`] for [`ResponseResult`] handlers that do not set extra headers.
pub fn template_as_response(template: TemplateResponse) -> Response {
template.into_response()
}
/// Wraps a [`TemplateResponse`] for the template middleware and applies outbound headers.
///
/// Headers listed in [`HTMX_HEADERS_TO_FORWARD`] are copied onto the rendered HTML response.
pub fn template_with_headers(
template: TemplateResponse,
apply: impl FnOnce(&mut axum::http::HeaderMap),
) -> Response {
let mut response = template.into_response();
apply(response.headers_mut());
response
}
#[derive(Serialize)]
struct TemplateUser {
id: String,
@@ -143,7 +166,7 @@ struct ContextWrapper<'a> {
initial_theme: &'a str,
is_authenticated: bool,
user: Option<&'a TemplateUser>,
conversation_archive: Vec<SidebarConversation>,
conversation_archive: &'a [SidebarConversation],
#[serde(flatten)]
context: HashMap<String, Value>,
}
@@ -213,18 +236,14 @@ where
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
let template_engine = state.template_engine();
let mut conversation_archive = Vec::new();
let should_load_conversation_archive =
matches!(&template_response.template_kind, TemplateKind::Full(_));
if should_load_conversation_archive {
let cached_archive = if should_load_conversation_archive {
if let Some(user_id) = current_user.as_ref().map(|u| &u.id) {
let html_state = state.html_state();
if let Some(cached_archive) =
html_state.get_cached_conversation_archive(user_id).await
{
conversation_archive = cached_archive.to_vec();
if let Some(cached) = html_state.get_cached_conversation_archive(user_id).await {
Some(cached)
} else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
{
@@ -232,10 +251,19 @@ where
html_state
.set_cached_conversation_archive(user_id, Arc::clone(&cached))
.await;
conversation_archive = cached.to_vec();
}
Some(cached)
} else {
None
}
} else {
None
}
} else {
None
};
let conversation_archive = cached_archive
.as_ref()
.map_or(&[][..], |archive| archive.as_ref());
let context_map = match context_to_map(&template_response.context) {
Ok(map) => map,
+2 -2
View File
@@ -34,7 +34,7 @@ macro_rules! create_asset_service {
}};
}
pub type MiddleWareVecType<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
pub type MiddlewareVec<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
/// Builder for composing public/protected HTML routes and middleware layers.
pub struct RouterFactory<S> {
@@ -43,7 +43,7 @@ pub struct RouterFactory<S> {
protected_routers: Vec<Router<S>>,
nested_routes: Vec<(String, Router<S>)>,
nested_protected_routes: Vec<(String, Router<S>)>,
custom_middleware: MiddleWareVecType<S>,
custom_middleware: MiddlewareVec<S>,
public_assets_config: Option<AssetsConfig>,
compression_enabled: bool,
}
+9 -9
View File
@@ -1,11 +1,11 @@
use axum::{extract::State, response::IntoResponse, Form};
use axum::{extract::State, Form};
use chrono_tz::TZ_VARIANTS;
use serde::{Deserialize, Serialize};
use crate::{
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{TemplateResponse, TemplateResult},
},
AuthSessionType,
};
@@ -28,7 +28,7 @@ pub struct AccountPageData {
pub async fn show_account_page(
RequireUser(user): RequireUser,
State(_state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let timezones = TZ_VARIANTS
.iter()
.map(std::string::ToString::to_string)
@@ -57,7 +57,7 @@ pub async fn set_api_key(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
// Generate and set the API key
let api_key = User::set_api_key(&user.id, &state.db).await?;
@@ -82,7 +82,7 @@ pub async fn delete_account(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
state.db.delete_item::<User>(&user.id).await?;
auth.logout_user();
@@ -102,7 +102,7 @@ pub async fn update_timezone(
RequireUser(user): RequireUser,
auth: AuthSessionType,
Form(form): Form<UpdateTimezoneForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
// Clear the cache
@@ -137,7 +137,7 @@ pub async fn update_theme(
RequireUser(user): RequireUser,
auth: AuthSessionType,
Form(form): Form<UpdateThemeForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
User::update_theme(&user.id, &form.theme, &state.db).await?;
// Clear the cache
@@ -166,7 +166,7 @@ pub async fn update_theme(
pub async fn show_change_password(
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Ok(TemplateResponse::new_template(
"auth/change_password_form.html",
(),
@@ -184,7 +184,7 @@ pub async fn change_password(
RequireUser(user): RequireUser,
auth: AuthSessionType,
Form(form): Form<NewPasswordForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
// Authenticate to make sure the password matches
let authenticated_user = User::authenticate(&user.email, &form.old_password, &state.db).await?;
+10 -11
View File
@@ -3,7 +3,6 @@ use std::sync::Arc;
use async_openai::types::ListModelResponse;
use axum::{
extract::{Query, State},
response::IntoResponse,
Form,
};
use serde::{Deserialize, Serialize};
@@ -26,7 +25,7 @@ use tracing::{error, info};
use crate::{
html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse},
middlewares::response_middleware::{TemplateResponse, TemplateResult},
};
#[derive(Serialize)]
@@ -57,7 +56,7 @@ pub struct AdminPanelQuery {
pub async fn show_admin_panel(
State(state): State<HtmlState>,
Query(query): Query<AdminPanelQuery>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let section = match query.section.as_deref() {
Some("models") => AdminSection::Models,
_ => AdminSection::Overview,
@@ -124,7 +123,7 @@ pub struct RegistrationToggleData {
pub async fn toggle_registration_status(
State(state): State<HtmlState>,
Form(input): Form<RegistrationToggleInput>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let new_settings = SystemSettingsPatch {
registrations_enabled: Some(input.registration_open),
..Default::default()
@@ -160,7 +159,7 @@ pub struct ModelSettingsData {
pub async fn update_model_settings(
State(state): State<HtmlState>,
Form(input): Form<ModelSettingsInput>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let current_settings = SystemSettings::get_current(&state.db).await?;
// Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
@@ -272,7 +271,7 @@ pub struct SystemPromptEditData {
pub async fn show_edit_system_prompt(
State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
@@ -297,7 +296,7 @@ pub struct SystemPromptSectionData {
pub async fn patch_query_prompt(
State(state): State<HtmlState>,
Form(input): Form<SystemPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let new_settings = SystemSettingsPatch {
query_system_prompt: Some(input.query_system_prompt),
..Default::default()
@@ -322,7 +321,7 @@ pub struct IngestionPromptEditData {
pub async fn show_edit_ingestion_prompt(
State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
@@ -342,7 +341,7 @@ pub struct IngestionPromptUpdateInput {
pub async fn patch_ingestion_prompt(
State(state): State<HtmlState>,
Form(input): Form<IngestionPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let new_settings = SystemSettingsPatch {
ingestion_system_prompt: Some(input.ingestion_system_prompt),
..Default::default()
@@ -367,7 +366,7 @@ pub struct ImagePromptEditData {
pub async fn show_edit_image_prompt(
State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
@@ -387,7 +386,7 @@ pub struct ImagePromptUpdateInput {
pub async fn patch_image_prompt(
State(state): State<HtmlState>,
Form(input): Form<ImagePromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let new_settings = SystemSettingsPatch {
image_processing_prompt: Some(input.image_processing_prompt),
..Default::default()
+6 -6
View File
@@ -1,10 +1,10 @@
use axum::{extract::State, response::IntoResponse, Form};
use axum::{extract::State, Form};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
use crate::{
html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse},
middlewares::response_middleware::{TemplateResponse, TemplateResult},
AuthSessionType,
};
use common::storage::types::user::User;
@@ -19,7 +19,7 @@ pub struct SignInParams {
pub async fn show_signin_form(
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
if auth.current_user.is_some() {
return Ok(TemplateResponse::redirect("/"));
}
@@ -38,9 +38,9 @@ pub async fn authenticate_user(
State(state): State<HtmlState>,
auth: AuthSessionType,
Form(form): Form<SignInParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else {
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
return Ok(TemplateResponse::bad_request("Incorrect email or password"));
};
auth.login_user(user.id);
@@ -49,5 +49,5 @@ pub async fn authenticate_user(
auth.remember_user(true);
}
Ok(TemplateResponse::redirect("/").into_response())
Ok(TemplateResponse::redirect("/"))
}
+2 -4
View File
@@ -1,11 +1,9 @@
use axum::response::IntoResponse;
use crate::{
middlewares::response_middleware::{HtmlError, TemplateResponse},
middlewares::response_middleware::{TemplateResponse, TemplateResult},
AuthSessionType,
};
pub async fn sign_out_user(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
pub async fn sign_out_user(auth: AuthSessionType) -> TemplateResult {
if !auth.is_authenticated() {
return Ok(TemplateResponse::redirect("/"));
}
+6 -6
View File
@@ -1,4 +1,4 @@
use axum::{extract::State, response::IntoResponse, Form};
use axum::{extract::State, Form};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
@@ -6,7 +6,7 @@ use common::{error::AppError, storage::types::user::{Theme, User}};
use crate::{
html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse},
middlewares::response_middleware::{TemplateResponse, TemplateResult},
AuthSessionType,
};
@@ -27,7 +27,7 @@ fn signup_error_message(err: &AppError) -> &str {
pub async fn show_signup_form(
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
if auth.current_user.is_some() {
return Ok(TemplateResponse::redirect("/"));
}
@@ -47,7 +47,7 @@ pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>,
auth: AuthSessionType,
Form(form): Form<Params>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let user = match User::create_new(
form.email,
form.password,
@@ -60,11 +60,11 @@ pub async fn process_signup_and_show_verification(
Ok(user) => user,
Err(err) => {
tracing::error!(?err, "signup failed");
return Ok(TemplateResponse::bad_request(signup_error_message(&err)).into_response());
return Ok(TemplateResponse::bad_request(signup_error_message(&err)));
}
};
auth.login_user(user.id);
Ok(TemplateResponse::redirect("/").into_response())
Ok(TemplateResponse::redirect("/"))
}
+37 -37
View File
@@ -1,7 +1,6 @@
use axum::{
extract::{Path, State},
http::HeaderValue,
response::IntoResponse,
Form,
};
use serde::{Deserialize, Serialize};
@@ -18,7 +17,10 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{
template_as_response, template_with_headers, TemplateResponse, TemplateResult,
ResponseResult,
},
},
};
@@ -31,7 +33,7 @@ pub struct ChatPageData {
pub async fn show_chat_base(
State(_state): State<HtmlState>,
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Ok(TemplateResponse::new_template(
"chat/base.html",
ChatPageData {
@@ -50,7 +52,7 @@ pub async fn show_existing_chat(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let (conversation, messages) =
Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db)
.await?;
@@ -69,7 +71,7 @@ pub async fn new_user_message(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
@@ -82,31 +84,32 @@ pub async fn new_user_message(
.ok_or_else(|| AppError::NotFound("conversation was not found".into()))?;
if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response());
return Ok(template_as_response(TemplateResponse::unauthorized()));
}
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
state.db.store_item(user_message.clone()).await?;
let mut response = TemplateResponse::new_template(
let push_path = format!("/chat/{}", conversation.id);
Ok(template_with_headers(
TemplateResponse::new_template(
"chat/streaming_response.html",
SSEResponseInitData { user_message },
)
.into_response();
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&push_path) {
headers.insert("HX-Push", header_value);
}
Ok(response)
},
))
}
pub async fn new_chat_user_message(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
@@ -125,20 +128,21 @@ pub async fn new_chat_user_message(
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
let mut response = TemplateResponse::new_template(
let push_path = format!("/chat/{}", conversation.id);
Ok(template_with_headers(
TemplateResponse::new_template(
"chat/new_chat_first_response.html",
SSEResponseInitData {
user_message,
conversation: conversation.clone(),
},
)
.into_response();
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&push_path) {
headers.insert("HX-Push", header_value);
}
Ok(response)
},
))
}
#[derive(Deserialize)]
@@ -155,7 +159,7 @@ pub async fn show_conversation_editing_title(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(conversation_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let conversation: Conversation = state
.db
.get_item(&conversation_id)
@@ -163,7 +167,7 @@ pub async fn show_conversation_editing_title(
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response());
return Ok(TemplateResponse::unauthorized());
}
Ok(TemplateResponse::new_template(
@@ -171,8 +175,7 @@ pub async fn show_conversation_editing_title(
DrawerContext {
edit_conversation_id: Some(conversation_id),
},
)
.into_response())
))
}
pub async fn patch_conversation_title(
@@ -180,7 +183,7 @@ pub async fn patch_conversation_title(
RequireUser(user): RequireUser,
Path(conversation_id): Path<String>,
Form(form): Form<PatchConversationTitle>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
@@ -189,15 +192,14 @@ pub async fn patch_conversation_title(
DrawerContext {
edit_conversation_id: None,
},
)
.into_response())
))
}
pub async fn delete_conversation(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(conversation_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let conversation: Conversation = state
.db
.get_item(&conversation_id)
@@ -205,7 +207,7 @@ pub async fn delete_conversation(
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response());
return Ok(TemplateResponse::unauthorized());
}
state
@@ -219,18 +221,16 @@ pub async fn delete_conversation(
DrawerContext {
edit_conversation_id: None,
},
)
.into_response())
))
}
pub async fn reload_sidebar(
State(_state): State<HtmlState>,
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Ok(TemplateResponse::new_template(
"sidebar.html",
DrawerContext {
edit_conversation_id: None,
},
)
.into_response())
))
}
+2 -3
View File
@@ -2,7 +2,6 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
};
use chrono::{DateTime, Utc};
use chrono_tz::Tz;
@@ -16,7 +15,7 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{TemplateResponse, TemplateResult},
},
};
@@ -45,7 +44,7 @@ pub async fn show_reference_tooltip(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(reference_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
return Ok(TemplateResponse::not_found());
};
+7 -8
View File
@@ -1,6 +1,5 @@
use axum::{
extract::{Path, Query, State},
response::IntoResponse,
Form,
};
use axum_htmx::{HxBoosted, HxRequest, HxTarget};
@@ -15,7 +14,7 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{TemplateResponse, TemplateResult},
},
utils::pagination::{paginate_items, Pagination},
utils::text_content_preview::truncate_text_contents,
@@ -50,7 +49,7 @@ pub async fn show_content_page(
Query(params): Query<FilterParams>,
HxRequest(is_htmx): HxRequest,
HxBoosted(is_boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
// Normalize empty strings to None
let category_filter = params
.category
@@ -101,7 +100,7 @@ pub async fn show_text_content_edit_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
#[derive(Serialize)]
pub struct TextContentEditModal {
pub text_content: TextContent,
@@ -127,7 +126,7 @@ pub async fn patch_text_content(
Path(id): Path<String>,
HxTarget(target): HxTarget,
Form(form): Form<PatchTextContentParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
TextContent::patch(&id, &form.context, &form.category, &form.text, &state.db).await?;
@@ -167,7 +166,7 @@ pub async fn delete_text_content(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
// Get and validate the text content
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
@@ -213,7 +212,7 @@ pub async fn show_content_read_modal(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
#[derive(Serialize)]
pub struct TextContentReadModalData {
pub text_content: TextContent,
@@ -231,7 +230,7 @@ pub async fn show_content_read_modal(
pub async fn show_recent_content(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let text_contents =
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
+12 -10
View File
@@ -12,7 +12,9 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
},
},
utils::text_content_preview::truncate_text_contents,
utils::truncate::with_ellipsis,
@@ -37,7 +39,7 @@ pub struct IndexPageData {
pub async fn index_handler(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let (text_contents, dashboard_stats, active_jobs) = try_join!(
User::get_latest_text_contents(&user.id, &state.db),
User::get_dashboard_stats(&user.id, &state.db),
@@ -65,7 +67,7 @@ pub async fn delete_text_content(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
// Get and validate TextContent
let text_content = get_and_validate_text_content(&state, &id, &user).await?;
@@ -154,7 +156,7 @@ pub async fn delete_job(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
User::validate_and_delete_job(&id, &user.id, &state.db).await?;
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
@@ -169,7 +171,7 @@ pub async fn delete_job(
pub async fn show_active_jobs(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
@@ -181,7 +183,7 @@ pub async fn show_active_jobs(
pub async fn show_task_archive(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?;
let entries: Vec<TaskArchiveEntry> = tasks
@@ -234,17 +236,17 @@ pub async fn serve_file(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(file_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
return Ok(TemplateResponse::not_found().into_response());
return Ok(template_as_response(TemplateResponse::not_found()));
};
if file_info.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response());
return Ok(template_as_response(TemplateResponse::unauthorized()));
}
let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
return Ok(TemplateResponse::server_error().into_response());
return Ok(template_as_response(TemplateResponse::server_error()));
};
let body = Body::from_stream(stream);
+14 -16
View File
@@ -5,7 +5,7 @@ use axum::{
http::StatusCode,
response::{
sse::{Event, KeepAlive, KeepAliveStream},
IntoResponse, Response, Sse,
Sse,
},
};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -31,7 +31,7 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{TemplateResponse, TemplateResult},
},
};
@@ -49,7 +49,7 @@ fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
pub async fn show_ingest_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
#[derive(Serialize)]
pub struct ShowIngestFormData {
user_categories: Vec<String>,
@@ -65,7 +65,7 @@ pub async fn show_ingest_form(
pub async fn hide_ingest_form(
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Ok(TemplateResponse::new_template(
"ingestion/add_content_button.html",
(),
@@ -91,12 +91,11 @@ pub async fn process_ingest_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
TypedMultipart(input): TypedMultipart<IngestionParams>,
) -> Result<Response, HtmlError> {
) -> TemplateResult {
if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
return Ok(
TemplateResponse::bad_request("You need to either add files or content")
.into_response(),
);
return Ok(TemplateResponse::bad_request(
"You need to either add files or content",
));
}
let content_bytes = input.content.as_ref().map_or(0, String::len);
@@ -118,11 +117,10 @@ pub async fn process_ingest_form(
StatusCode::PAYLOAD_TOO_LARGE,
"Payload Too Large",
&message,
)
.into_response());
));
}
Err(IngestValidationError::BadRequest(message)) => {
return Ok(TemplateResponse::bad_request(&message).into_response());
return Ok(TemplateResponse::bad_request(&message));
}
}
@@ -153,10 +151,10 @@ pub async fn process_ingest_form(
let tasks =
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
Ok(
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
.into_response(),
)
Ok(TemplateResponse::new_template(
"dashboard/current_task.html",
NewTasksData { tasks },
))
}
#[derive(Deserialize)]
+23 -21
View File
@@ -31,7 +31,9 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
},
},
utils::pagination::{paginate_items, Pagination},
};
@@ -120,12 +122,12 @@ fn collect_relationship_type_options(relationships: &[KnowledgeRelationship]) ->
options
}
fn respond_with_graph_refresh(response: TemplateResponse) -> Response {
let mut response = response.into_response();
fn graph_refresh_response(template: TemplateResponse) -> Response {
template_with_headers(template, |headers| {
if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) {
response.headers_mut().insert(HX_TRIGGER, value);
headers.insert(HX_TRIGGER, value);
}
response
})
}
#[derive(Deserialize, Default)]
@@ -138,7 +140,7 @@ pub struct FilterParams {
pub async fn show_new_knowledge_entity_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter()
.map(ToString::to_string)
@@ -170,7 +172,7 @@ pub async fn create_knowledge_entity(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<CreateKnowledgeEntityParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
let name = form.name.trim().to_string();
if name.is_empty() {
return Err(AppError::Validation("name is required".into()).into());
@@ -230,7 +232,7 @@ pub async fn create_knowledge_entity(
let default_params = FilterParams::default();
let kb_data = build_knowledge_base_data(&state, &user, &default_params).await?;
Ok(respond_with_graph_refresh(TemplateResponse::new_partial(
Ok(graph_refresh_response(TemplateResponse::new_partial(
"knowledge/base.html",
"main",
kb_data,
@@ -241,7 +243,7 @@ pub async fn suggest_knowledge_relationships(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<SuggestRelationshipsParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let entity_lookup: HashMap<String, KnowledgeEntity> =
User::get_knowledge_entities(&user.id, &state.db)
.await?
@@ -723,7 +725,7 @@ pub async fn show_knowledge_page(
HxRequest(is_htmx): HxRequest,
HxBoosted(is_boosted): HxBoosted,
Query(mut params): Query<FilterParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
// Normalize filters: treat empty or "none" as no filter
params.entity_type = normalize_filter(params.entity_type.take());
params.content_category = normalize_filter(params.content_category.take());
@@ -772,7 +774,7 @@ pub async fn get_knowledge_graph_json(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Query(mut params): Query<FilterParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
// Normalize filters: treat empty or "none" as no filter
params.entity_type = normalize_filter(params.entity_type.take());
params.content_category = normalize_filter(params.content_category.take());
@@ -821,7 +823,7 @@ pub async fn get_knowledge_graph_json(
})
.collect();
Ok(Json(GraphData { nodes, links }))
Ok(Json(GraphData { nodes, links }).into_response())
}
// Normalize filter parameters: convert empty strings or "none" (case-insensitive) to None
fn normalize_filter(input: Option<String>) -> Option<String> {
@@ -851,7 +853,7 @@ pub async fn show_edit_knowledge_entity_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
#[derive(Serialize)]
pub struct EntityData {
entity: KnowledgeEntity,
@@ -899,7 +901,7 @@ pub async fn patch_knowledge_entity(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<PatchKnowledgeEntityParams>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
// Get the existing entity and validate that the user is allowed
User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db).await?;
@@ -930,7 +932,7 @@ pub async fn patch_knowledge_entity(
let content_categories = User::get_user_categories(&user.id, &state.db).await?;
// Render updated list
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/entity_list.html",
EntityListData {
visible_entities,
@@ -948,7 +950,7 @@ pub async fn delete_knowledge_entity(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
// Get the existing entity and validate that the user is allowed
User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?;
@@ -968,7 +970,7 @@ pub async fn delete_knowledge_entity(
// Get content categories
let content_categories = User::get_user_categories(&user.id, &state.db).await?;
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/entity_list.html",
EntityListData {
visible_entities,
@@ -994,7 +996,7 @@ pub async fn delete_knowledge_relationship(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
KnowledgeRelationship::delete_relationship_by_id(&id, &user.id, &state.db).await?;
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
@@ -1003,7 +1005,7 @@ pub async fn delete_knowledge_relationship(
let table_data = build_relationship_table_data(entities, relationships);
// Render updated list
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/relationship_table.html",
table_data,
)))
@@ -1020,7 +1022,7 @@ pub async fn save_knowledge_relationship(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<SaveKnowledgeRelationshipInput>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
// Construct relationship
let relationship_type = canonicalize_relationship_type(&form.relationship_type);
let relationship = KnowledgeRelationship::new(
@@ -1039,7 +1041,7 @@ pub async fn save_knowledge_relationship(
let table_data = build_relationship_table_data(entities, relationships);
// Render updated list
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/relationship_table.html",
table_data,
)))
+28 -25
View File
@@ -11,7 +11,9 @@ use serde::{Deserialize, Serialize};
use crate::html_state::HtmlState;
use crate::middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
},
};
use common::storage::types::{
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, scratchpad::Scratchpad,
@@ -127,7 +129,7 @@ pub async fn show_scratchpad_page(
HxRequest(is_htmx): HxRequest,
HxBoosted(is_boosted): HxBoosted,
State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
@@ -165,7 +167,7 @@ pub async fn show_scratchpad_modal(
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
Query(query): Query<EditTitleQuery>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
let scratchpad_detail = ScratchpadDetail::from(&scratchpad);
@@ -186,7 +188,7 @@ pub async fn create_scratchpad(
RequireUser(user): RequireUser,
State(state): State<HtmlState>,
Form(form): Form<CreateScratchpadForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let user_id = user.id.clone();
let scratchpad = Scratchpad::new(user_id.clone(), form.title);
let _stored = state.db.store_item(scratchpad.clone()).await?;
@@ -217,7 +219,7 @@ pub async fn auto_save_scratchpad(
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
Form(form): Form<UpdateScratchpadForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
let updated =
Scratchpad::update_content(&scratchpad_id, &user.id, &form.content, &state.db).await?;
@@ -229,7 +231,8 @@ pub async fn auto_save_scratchpad(
.format("%Y-%m-%d %H:%M:%S")
.to_string(),
last_saved_at_iso: updated.last_saved_at.to_rfc3339(),
}))
})
.into_response())
}
pub async fn update_scratchpad_title(
@@ -237,7 +240,7 @@ pub async fn update_scratchpad_title(
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
Form(form): Form<UpdateTitleForm>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Scratchpad::update_title(&scratchpad_id, &user.id, &form.title, &state.db).await?;
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
@@ -255,7 +258,7 @@ pub async fn delete_scratchpad(
RequireUser(user): RequireUser,
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Scratchpad::delete(&scratchpad_id, &user.id, &state.db).await?;
// Return the updated main section content
@@ -284,7 +287,7 @@ pub async fn ingest_scratchpad(
RequireUser(user): RequireUser,
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
if scratchpad.content.trim().is_empty() {
@@ -347,7 +350,8 @@ pub async fn ingest_scratchpad(
r#"{"toast":{"title":"Ingestion queued","description":"Scratchpad archived and added to the ingestion queue.","type":"success"}}"#.to_string()
});
let template_response = TemplateResponse::new_partial(
Ok(template_with_headers(
TemplateResponse::new_partial(
"scratchpad/base.html",
"main",
ScratchpadPageData {
@@ -355,21 +359,20 @@ pub async fn ingest_scratchpad(
archived_scratchpads: archived_list,
new_scratchpad: None,
},
);
let mut response = template_response.into_response();
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
response.headers_mut().insert(HX_TRIGGER, header_value);
headers.insert(HX_TRIGGER, header_value);
}
Ok(response)
},
))
}
pub async fn archive_scratchpad(
RequireUser(user): RequireUser,
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
Scratchpad::archive(&scratchpad_id, &user.id, &state.db, false).await?;
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
@@ -396,7 +399,7 @@ pub async fn restore_scratchpad(
RequireUser(user): RequireUser,
State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
) -> ResponseResult {
Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?;
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
@@ -420,7 +423,8 @@ pub async fn restore_scratchpad(
r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string()
});
let template_response = TemplateResponse::new_partial(
Ok(template_with_headers(
TemplateResponse::new_partial(
"scratchpad/base.html",
"main",
ScratchpadPageData {
@@ -428,14 +432,13 @@ pub async fn restore_scratchpad(
archived_scratchpads: archived_list,
new_scratchpad: None,
},
);
let mut response = template_response.into_response();
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
response.headers_mut().insert(HX_TRIGGER, header_value);
headers.insert(HX_TRIGGER, header_value);
}
Ok(response)
},
))
}
#[cfg(test)]
+2 -3
View File
@@ -2,7 +2,6 @@ use std::collections::HashSet;
use axum::{
extract::{Query, State},
response::IntoResponse,
};
use common::storage::types::{text_content::TextContent, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
@@ -13,7 +12,7 @@ use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
response_middleware::{HtmlError, TemplateResponse, TemplateResult},
},
};
@@ -79,7 +78,7 @@ pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(params): Query<SearchParams>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
) -> TemplateResult {
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
params.query
{
+311
View File
@@ -0,0 +1,311 @@
#![allow(clippy::expect_used)]
use std::sync::Arc;
use axum::{
body::{to_bytes, Body},
http::{header, Request, StatusCode},
response::Response,
Router,
};
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::{
config::{AppConfig, StorageKind},
embedding::EmbeddingProvider,
},
};
use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
};
use tower::ServiceExt;
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
let namespace = "html_router_test";
let database = &uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, database)
.await
.expect("in-memory db"),
);
db.apply_migrations()
.await
.expect("migrations should apply");
let session_store = Arc::new(
db.create_session_store()
.await
.expect("session store"),
);
let config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config)
.await
.expect("storage manager");
let embedding_provider = Arc::new(
EmbeddingProvider::new_hashed(8).expect("embedding provider"),
);
let state = HtmlState::new_with_resources(StateResources {
db: Arc::clone(&db),
openai_client: Arc::new(async_openai::Client::new()),
session_store,
storage,
config,
reranker_pool: None,
embedding_provider,
template_engine: None,
});
let router = html_routes(&state).with_state(state);
(router, db)
}
fn redirect_location(response: &Response) -> String {
response
.headers()
.get(header::LOCATION)
.or_else(|| response.headers().get("HX-Redirect"))
.expect("redirect response should include Location or HX-Redirect")
.to_str()
.expect("redirect header must be utf-8")
.to_string()
}
fn session_cookie(response: &Response) -> String {
response
.headers()
.get_all(header::SET_COOKIE)
.iter()
.map(|value| {
value
.to_str()
.expect("set-cookie must be utf-8")
.split(';')
.next()
.expect("cookie key=value")
.to_string()
})
.collect::<Vec<_>>()
.join("; ")
}
async fn response_body(response: Response) -> String {
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
String::from_utf8(body.to_vec()).expect("html body")
}
async fn sign_in(app: &Router, email: &str, password: &str) -> String {
let response = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/signin")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(format!("email={email}&password={password}")))
.expect("signin request"),
)
.await
.expect("signin response");
assert!(
response.status().is_redirection() || response.status() == StatusCode::OK,
"signin should redirect or return ok"
);
session_cookie(&response)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_redirects_unauthenticated_users() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/")
.body(Body::empty())
.expect("dashboard request"),
)
.await
.expect("dashboard response");
assert!(
response.status().is_redirection() || response.status() == StatusCode::OK,
"unauthenticated access should redirect via template middleware"
);
assert_eq!(redirect_location(&response), "/signin");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn authenticated_user_receives_rendered_dashboard() {
let (app, db) = build_test_app().await;
User::create_new(
"router_test@example.com".to_string(),
"test_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let cookie = sign_in(&app, "router_test@example.com", "test_password").await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/")
.header(header::COOKIE, cookie)
.body(Body::empty())
.expect("authenticated dashboard request"),
)
.await
.expect("authenticated dashboard response");
assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
let html = String::from_utf8(body.to_vec()).expect("html body");
assert!(
html.contains("dashboard") || html.contains("Dashboard"),
"dashboard template should render html"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn signin_form_is_public() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/signin")
.body(Body::empty())
.expect("signin form request"),
)
.await
.expect("signin form response");
assert_eq!(response.status(), StatusCode::OK);
let html = response_body(response).await;
assert!(
html.contains("signin") || html.contains("Sign in") || html.contains("email"),
"signin page should render a form"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn signin_rejects_invalid_credentials() {
let (app, db) = build_test_app().await;
User::create_new(
"signin_test@example.com".to_string(),
"correct_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let response = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/signin")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(
"email=signin_test@example.com&password=wrong_password",
))
.expect("invalid signin request"),
)
.await
.expect("invalid signin response");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let html = response_body(response).await;
assert!(
html.contains("Incorrect email or password"),
"signin failure should render a safe error message"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn admin_route_redirects_non_admin_user() {
let (app, db) = build_test_app().await;
User::create_new(
"admin_user@example.com".to_string(),
"admin_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("admin user");
User::create_new(
"member_user@example.com".to_string(),
"member_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("member user");
let member_cookie = sign_in(&app, "member_user@example.com", "member_password").await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.header(header::COOKIE, member_cookie)
.body(Body::empty())
.expect("non-admin admin request"),
)
.await
.expect("non-admin admin response");
assert!(
response.status().is_redirection() || response.status() == StatusCode::OK,
"non-admin should be redirected away from admin"
);
assert_eq!(redirect_location(&response), "/");
let admin_cookie = sign_in(&app, "admin_user@example.com", "admin_password").await;
let admin_response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.header(header::COOKIE, admin_cookie)
.body(Body::empty())
.expect("admin request"),
)
.await
.expect("admin response");
assert_eq!(admin_response.status(), StatusCode::OK);
}