chore: harden api-router errors and add router integration tests while slimming html handlers.

This commit is contained in:
Per Stark
2026-05-30 11:39:47 +02:00
parent 2aa92b6ad7
commit c70141de35
26 changed files with 814 additions and 260 deletions
Generated
+3
View File
@@ -247,7 +247,9 @@ dependencies = [
"tempfile", "tempfile",
"thiserror 1.0.69", "thiserror 1.0.69",
"tokio", "tokio",
"tower",
"tracing", "tracing",
"uuid",
] ]
[[package]] [[package]]
@@ -2978,6 +2980,7 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
"tokio", "tokio",
"tokio-util", "tokio-util",
"tower",
"tower-http", "tower-http",
"tower-serve-static", "tower-serve-static",
"tracing", "tracing",
+5
View File
@@ -20,3 +20,8 @@ futures = { workspace = true }
axum_typed_multipart = { workspace = true} axum_typed_multipart = { workspace = true}
common = { path = "../common" } common = { path = "../common" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }
tower = "0.5"
uuid = { workspace = true }
-28
View File
@@ -11,31 +11,3 @@ pub struct ApiState {
pub config: AppConfig, pub config: AppConfig,
pub storage: StorageManager, pub storage: StorageManager,
} }
impl ApiState {
pub async fn new(
config: &AppConfig,
storage: StorageManager,
) -> anyhow::Result<Self> {
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
surreal_db_client.apply_migrations().await?;
let app_state = Self {
db: Arc::clone(&surreal_db_client),
config: config.clone(),
storage,
};
Ok(app_state)
}
}
+19 -9
View File
@@ -7,7 +7,7 @@ use common::error::AppError;
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug, Serialize, Clone)] #[derive(Error, Debug)]
pub enum ApiErr { pub enum ApiErr {
#[error("internal server error")] #[error("internal server error")]
InternalError(String), InternalError(String),
@@ -28,14 +28,13 @@ pub enum ApiErr {
impl From<AppError> for ApiErr { impl From<AppError> for ApiErr {
fn from(err: AppError) -> Self { fn from(err: AppError) -> Self {
match err { match err {
AppError::Database(_) | AppError::OpenAI(_) => {
tracing::error!("Internal error: {:?}", err);
Self::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => Self::NotFound(msg), AppError::NotFound(msg) => Self::NotFound(msg),
AppError::Validation(msg) => Self::ValidationError(msg), AppError::Validation(msg) => Self::ValidationError(msg),
AppError::Auth(msg) => Self::Unauthorized(msg), AppError::Auth(msg) => Self::Unauthorized(msg),
_ => Self::InternalError("Internal server error".to_string()), other => {
tracing::error!("internal API error: {other:?}");
Self::InternalError("Internal server error".to_string())
}
} }
} }
} }
@@ -120,10 +119,21 @@ mod tests {
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized")); assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
// Test for internal errors - create a mock error that doesn't require surrealdb // Test for internal errors - create a mock error that doesn't require surrealdb
let internal_error = let internal_error = AppError::Io(io::Error::other("io error"));
AppError::Io(io::Error::other("io error"));
let api_error = ApiErr::from(internal_error); let api_error = ApiErr::from(internal_error);
assert!(matches!(api_error, ApiErr::InternalError(_))); assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
}
#[test]
fn test_app_error_internal_error_is_sanitized() {
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
assert!(matches!(
api_error,
ApiErr::InternalError(msg) if msg == "Internal server error"
));
} }
#[test] #[test]
+2 -2
View File
@@ -6,7 +6,7 @@ use axum::{
Router, Router,
}; };
use middleware_api_auth::api_auth; use middleware_api_auth::api_auth;
use routes::{categories::list, ingest::ingest_data, liveness::live, readiness::ready}; use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
pub mod api_state; pub mod api_state;
pub mod error; pub mod error;
@@ -28,7 +28,7 @@ where
let protected = Router::new() let protected = Router::new()
.route( .route(
"/ingest", "/ingest",
post(ingest_data).layer(DefaultBodyLimit::max( post(handle).layer(DefaultBodyLimit::max(
app_state.config.ingest_max_body_bytes, app_state.config.ingest_max_body_bytes,
)), )),
) )
+63 -4
View File
@@ -16,7 +16,7 @@ pub async fn api_auth(
let api_key = extract_api_key(&request) let api_key = extract_api_key(&request)
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?; .ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
let user = User::find_by_api_key(&api_key, &state.db).await?; let user = User::find_by_api_key(api_key, &state.db).await?;
let user = let user =
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?; user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
@@ -25,7 +25,7 @@ pub async fn api_auth(
Ok(next.run(request).await) Ok(next.run(request).await)
} }
fn extract_api_key(request: &Request) -> Option<String> { fn extract_api_key(request: &Request) -> Option<&str> {
request request
.headers() .headers()
.get("X-API-Key") .get("X-API-Key")
@@ -35,7 +35,66 @@ fn extract_api_key(request: &Request) -> Option<String> {
.headers() .headers()
.get("Authorization") .get("Authorization")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim)) .and_then(|auth| auth.strip_prefix("Bearer "))
.map(str::trim)
}) })
.map(String::from) }
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use axum::body::Body;
use axum::http::{HeaderValue, Request};
use super::extract_api_key;
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
let mut builder = Request::builder().method("GET").uri("/");
for (name, value) in headers {
builder = builder.header(*name, *value);
}
builder.body(Body::empty()).expect("test request")
}
#[test]
fn extract_api_key_from_x_api_key_header() {
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
}
#[test]
fn extract_api_key_from_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
}
#[test]
fn extract_api_key_prefers_x_api_key_over_authorization() {
let request = request_with_headers(&[
("X-API-Key", "sk_header"),
("Authorization", "Bearer sk_bearer"),
]);
assert_eq!(extract_api_key(&request), Some("sk_header"));
}
#[test]
fn extract_api_key_returns_none_when_missing() {
let request = request_with_headers(&[]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_non_bearer_authorization() {
let request = request_with_headers(&[("Authorization", "Basic abc")]);
assert_eq!(extract_api_key(&request), None);
}
#[test]
fn extract_api_key_rejects_invalid_header_values() {
let mut request = request_with_headers(&[]);
request
.headers_mut()
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header"));
assert_eq!(extract_api_key(&request), None);
}
} }
+8 -12
View File
@@ -16,7 +16,7 @@ use tracing::info;
use crate::{api_state::ApiState, error::ApiErr}; use crate::{api_state::ApiState, error::ApiErr};
#[derive(Debug, TryFromMultipart)] #[derive(Debug, TryFromMultipart)]
pub struct IngestParams { pub struct Params {
pub content: Option<String>, pub content: Option<String>,
pub context: String, pub context: String,
pub category: String, pub category: String,
@@ -25,24 +25,20 @@ pub struct IngestParams {
pub files: Vec<FieldData<NamedTempFile>>, pub files: Vec<FieldData<NamedTempFile>>,
} }
pub async fn ingest_data( pub async fn handle(
State(state): State<ApiState>, State(state): State<ApiState>,
Extension(user): Extension<User>, Extension(user): Extension<User>,
TypedMultipart(input): TypedMultipart<IngestParams>, TypedMultipart(input): TypedMultipart<Params>,
) -> Result<impl IntoResponse, ApiErr> { ) -> Result<impl IntoResponse, ApiErr> {
let user_id = user.id; let user_id = user.id;
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty()); let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input( match validate_ingest_input(
&state.config, &state.config,
input.content.as_deref(), input.content.as_deref(),
&input.context, &input.context,
&input.category, &input.category,
file_count, input.files.len(),
) { ) {
Ok(()) => {} Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => { Err(IngestValidationError::PayloadTooLarge(message)) => {
@@ -56,10 +52,10 @@ pub async fn ingest_data(
info!( info!(
user_id = %user_id, user_id = %user_id,
has_content, has_content,
content_bytes, content_len = input.content.as_ref().map_or(0, String::len),
context_bytes, context_len = input.context.len(),
category_bytes, category_len = input.category.len(),
file_count, file_count = input.files.len(),
"Received ingest request" "Received ingest request"
); );
+11 -8
View File
@@ -1,5 +1,6 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde_json::json; use serde_json::json;
use tracing::error;
use crate::api_state::ApiState; use crate::api_state::ApiState;
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
"checks": { "db": "ok" } "checks": { "db": "ok" }
})), })),
), ),
Err(e) => ( Err(e) => {
StatusCode::SERVICE_UNAVAILABLE, error!("readiness check failed: {e:?}");
Json(json!({ (
"status": "error", StatusCode::SERVICE_UNAVAILABLE,
"checks": { "db": "fail" }, Json(json!({
"reason": e.to_string() "status": "error",
})), "checks": { "db": "fail" }
), })),
)
}
} }
} }
+167
View File
@@ -0,0 +1,167 @@
#![allow(clippy::expect_used)]
use std::sync::Arc;
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{
body::{to_bytes, Body},
http::{Request, StatusCode},
Router,
};
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::config::{AppConfig, StorageKind},
};
use tower::ServiceExt;
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
let namespace = "api_router_test";
let database = uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, &database)
.await
.expect("in-memory db"),
);
db.apply_migrations()
.await
.expect("migrations should apply");
let config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config)
.await
.expect("storage manager");
let state = ApiState {
db: Arc::clone(&db),
config,
storage,
};
let router = api_routes_v1(&state).with_state(state);
(router, db)
}
async fn response_body(response: axum::response::Response) -> String {
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
String::from_utf8(body.to_vec()).expect("utf-8 body")
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn live_probe_is_public() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/live")
.body(Body::empty())
.expect("live request"),
)
.await
.expect("live response");
assert_eq!(response.status(), StatusCode::OK);
assert!(response_body(response).await.contains("\"status\":\"ok\""));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ready_probe_is_public_and_reports_db_ok() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/ready")
.body(Body::empty())
.expect("ready request"),
)
.await
.expect("ready response");
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
assert!(body.contains("\"checks\":{\"db\":\"ok\"}") || body.contains("\"db\":\"ok\""));
assert!(!body.contains("reason"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_requires_api_key() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_rejects_invalid_api_key() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.header("X-API-Key", "sk_invalid")
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn authenticated_user_can_list_categories() {
let (app, db) = build_test_app().await;
let user = User::create_new(
"api_router_test@example.com".to_string(),
"test_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let api_key = User::set_api_key(&user.id, &db)
.await
.expect("api key");
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/categories")
.header("X-API-Key", api_key)
.body(Body::empty())
.expect("categories request"),
)
.await
.expect("categories response");
assert_eq!(response.status(), StatusCode::OK);
}
+1
View File
@@ -43,6 +43,7 @@ json-stream-parser = { path = "../json-stream-parser" }
[dev-dependencies] [dev-dependencies]
common = { path = "../common", features = ["test-utils"] } common = { path = "../common", features = ["test-utils"] }
tower = "0.5"
[build-dependencies] [build-dependencies]
minijinja-embed = { version = "2.8.0" } minijinja-embed = { version = "2.8.0" }
@@ -116,6 +116,29 @@ impl IntoResponse for TemplateResponse {
} }
} }
/// Typical handler return type when no extra response headers or body are needed.
pub type TemplateResult = Result<TemplateResponse, HtmlError>;
/// Handler return type when the response needs custom headers or a non-template body.
pub type ResponseResult = Result<Response, HtmlError>;
/// Converts a [`TemplateResponse`] for [`ResponseResult`] handlers that do not set extra headers.
pub fn template_as_response(template: TemplateResponse) -> Response {
template.into_response()
}
/// Wraps a [`TemplateResponse`] for the template middleware and applies outbound headers.
///
/// Headers listed in [`HTMX_HEADERS_TO_FORWARD`] are copied onto the rendered HTML response.
pub fn template_with_headers(
template: TemplateResponse,
apply: impl FnOnce(&mut axum::http::HeaderMap),
) -> Response {
let mut response = template.into_response();
apply(response.headers_mut());
response
}
#[derive(Serialize)] #[derive(Serialize)]
struct TemplateUser { struct TemplateUser {
id: String, id: String,
@@ -143,7 +166,7 @@ struct ContextWrapper<'a> {
initial_theme: &'a str, initial_theme: &'a str,
is_authenticated: bool, is_authenticated: bool,
user: Option<&'a TemplateUser>, user: Option<&'a TemplateUser>,
conversation_archive: Vec<SidebarConversation>, conversation_archive: &'a [SidebarConversation],
#[serde(flatten)] #[serde(flatten)]
context: HashMap<String, Value>, context: HashMap<String, Value>,
} }
@@ -213,18 +236,14 @@ where
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() { if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
let template_engine = state.template_engine(); let template_engine = state.template_engine();
let mut conversation_archive = Vec::new();
let should_load_conversation_archive = let should_load_conversation_archive =
matches!(&template_response.template_kind, TemplateKind::Full(_)); matches!(&template_response.template_kind, TemplateKind::Full(_));
if should_load_conversation_archive { let cached_archive = if should_load_conversation_archive {
if let Some(user_id) = current_user.as_ref().map(|u| &u.id) { if let Some(user_id) = current_user.as_ref().map(|u| &u.id) {
let html_state = state.html_state(); let html_state = state.html_state();
if let Some(cached_archive) = if let Some(cached) = html_state.get_cached_conversation_archive(user_id).await {
html_state.get_cached_conversation_archive(user_id).await Some(cached)
{
conversation_archive = cached_archive.to_vec();
} else if let Ok(archive) = } else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
{ {
@@ -232,10 +251,19 @@ where
html_state html_state
.set_cached_conversation_archive(user_id, Arc::clone(&cached)) .set_cached_conversation_archive(user_id, Arc::clone(&cached))
.await; .await;
conversation_archive = cached.to_vec(); Some(cached)
} else {
None
} }
} else {
None
} }
} } else {
None
};
let conversation_archive = cached_archive
.as_ref()
.map_or(&[][..], |archive| archive.as_ref());
let context_map = match context_to_map(&template_response.context) { let context_map = match context_to_map(&template_response.context) {
Ok(map) => map, Ok(map) => map,
+2 -2
View File
@@ -34,7 +34,7 @@ macro_rules! create_asset_service {
}}; }};
} }
pub type MiddleWareVecType<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>; pub type MiddlewareVec<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
/// Builder for composing public/protected HTML routes and middleware layers. /// Builder for composing public/protected HTML routes and middleware layers.
pub struct RouterFactory<S> { pub struct RouterFactory<S> {
@@ -43,7 +43,7 @@ pub struct RouterFactory<S> {
protected_routers: Vec<Router<S>>, protected_routers: Vec<Router<S>>,
nested_routes: Vec<(String, Router<S>)>, nested_routes: Vec<(String, Router<S>)>,
nested_protected_routes: Vec<(String, Router<S>)>, nested_protected_routes: Vec<(String, Router<S>)>,
custom_middleware: MiddleWareVecType<S>, custom_middleware: MiddlewareVec<S>,
public_assets_config: Option<AssetsConfig>, public_assets_config: Option<AssetsConfig>,
compression_enabled: bool, compression_enabled: bool,
} }
+9 -9
View File
@@ -1,11 +1,11 @@
use axum::{extract::State, response::IntoResponse, Form}; use axum::{extract::State, Form};
use chrono_tz::TZ_VARIANTS; use chrono_tz::TZ_VARIANTS;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{TemplateResponse, TemplateResult},
}, },
AuthSessionType, AuthSessionType,
}; };
@@ -28,7 +28,7 @@ pub struct AccountPageData {
pub async fn show_account_page( pub async fn show_account_page(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
State(_state): State<HtmlState>, State(_state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let timezones = TZ_VARIANTS let timezones = TZ_VARIANTS
.iter() .iter()
.map(std::string::ToString::to_string) .map(std::string::ToString::to_string)
@@ -57,7 +57,7 @@ pub async fn set_api_key(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
auth: AuthSessionType, auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
// Generate and set the API key // Generate and set the API key
let api_key = User::set_api_key(&user.id, &state.db).await?; let api_key = User::set_api_key(&user.id, &state.db).await?;
@@ -82,7 +82,7 @@ pub async fn delete_account(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
auth: AuthSessionType, auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
state.db.delete_item::<User>(&user.id).await?; state.db.delete_item::<User>(&user.id).await?;
auth.logout_user(); auth.logout_user();
@@ -102,7 +102,7 @@ pub async fn update_timezone(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
auth: AuthSessionType, auth: AuthSessionType,
Form(form): Form<UpdateTimezoneForm>, Form(form): Form<UpdateTimezoneForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
User::update_timezone(&user.id, &form.timezone, &state.db).await?; User::update_timezone(&user.id, &form.timezone, &state.db).await?;
// Clear the cache // Clear the cache
@@ -137,7 +137,7 @@ pub async fn update_theme(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
auth: AuthSessionType, auth: AuthSessionType,
Form(form): Form<UpdateThemeForm>, Form(form): Form<UpdateThemeForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
User::update_theme(&user.id, &form.theme, &state.db).await?; User::update_theme(&user.id, &form.theme, &state.db).await?;
// Clear the cache // Clear the cache
@@ -166,7 +166,7 @@ pub async fn update_theme(
pub async fn show_change_password( pub async fn show_change_password(
RequireUser(_user): RequireUser, RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"auth/change_password_form.html", "auth/change_password_form.html",
(), (),
@@ -184,7 +184,7 @@ pub async fn change_password(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
auth: AuthSessionType, auth: AuthSessionType,
Form(form): Form<NewPasswordForm>, Form(form): Form<NewPasswordForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
// Authenticate to make sure the password matches // Authenticate to make sure the password matches
let authenticated_user = User::authenticate(&user.email, &form.old_password, &state.db).await?; let authenticated_user = User::authenticate(&user.email, &form.old_password, &state.db).await?;
+10 -11
View File
@@ -3,7 +3,6 @@ use std::sync::Arc;
use async_openai::types::ListModelResponse; use async_openai::types::ListModelResponse;
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
response::IntoResponse,
Form, Form,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -26,7 +25,7 @@ use tracing::{error, info};
use crate::{ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse}, middlewares::response_middleware::{TemplateResponse, TemplateResult},
}; };
#[derive(Serialize)] #[derive(Serialize)]
@@ -57,7 +56,7 @@ pub struct AdminPanelQuery {
pub async fn show_admin_panel( pub async fn show_admin_panel(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Query(query): Query<AdminPanelQuery>, Query(query): Query<AdminPanelQuery>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let section = match query.section.as_deref() { let section = match query.section.as_deref() {
Some("models") => AdminSection::Models, Some("models") => AdminSection::Models,
_ => AdminSection::Overview, _ => AdminSection::Overview,
@@ -124,7 +123,7 @@ pub struct RegistrationToggleData {
pub async fn toggle_registration_status( pub async fn toggle_registration_status(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Form(input): Form<RegistrationToggleInput>, Form(input): Form<RegistrationToggleInput>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let new_settings = SystemSettingsPatch { let new_settings = SystemSettingsPatch {
registrations_enabled: Some(input.registration_open), registrations_enabled: Some(input.registration_open),
..Default::default() ..Default::default()
@@ -160,7 +159,7 @@ pub struct ModelSettingsData {
pub async fn update_model_settings( pub async fn update_model_settings(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Form(input): Form<ModelSettingsInput>, Form(input): Form<ModelSettingsInput>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let current_settings = SystemSettings::get_current(&state.db).await?; let current_settings = SystemSettings::get_current(&state.db).await?;
// Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI // Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
@@ -272,7 +271,7 @@ pub struct SystemPromptEditData {
pub async fn show_edit_system_prompt( pub async fn show_edit_system_prompt(
State(state): State<HtmlState>, State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?; let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -297,7 +296,7 @@ pub struct SystemPromptSectionData {
pub async fn patch_query_prompt( pub async fn patch_query_prompt(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Form(input): Form<SystemPromptUpdateInput>, Form(input): Form<SystemPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let new_settings = SystemSettingsPatch { let new_settings = SystemSettingsPatch {
query_system_prompt: Some(input.query_system_prompt), query_system_prompt: Some(input.query_system_prompt),
..Default::default() ..Default::default()
@@ -322,7 +321,7 @@ pub struct IngestionPromptEditData {
pub async fn show_edit_ingestion_prompt( pub async fn show_edit_ingestion_prompt(
State(state): State<HtmlState>, State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?; let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -342,7 +341,7 @@ pub struct IngestionPromptUpdateInput {
pub async fn patch_ingestion_prompt( pub async fn patch_ingestion_prompt(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Form(input): Form<IngestionPromptUpdateInput>, Form(input): Form<IngestionPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let new_settings = SystemSettingsPatch { let new_settings = SystemSettingsPatch {
ingestion_system_prompt: Some(input.ingestion_system_prompt), ingestion_system_prompt: Some(input.ingestion_system_prompt),
..Default::default() ..Default::default()
@@ -367,7 +366,7 @@ pub struct ImagePromptEditData {
pub async fn show_edit_image_prompt( pub async fn show_edit_image_prompt(
State(state): State<HtmlState>, State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?; let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -387,7 +386,7 @@ pub struct ImagePromptUpdateInput {
pub async fn patch_image_prompt( pub async fn patch_image_prompt(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Form(input): Form<ImagePromptUpdateInput>, Form(input): Form<ImagePromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let new_settings = SystemSettingsPatch { let new_settings = SystemSettingsPatch {
image_processing_prompt: Some(input.image_processing_prompt), image_processing_prompt: Some(input.image_processing_prompt),
..Default::default() ..Default::default()
+6 -6
View File
@@ -1,10 +1,10 @@
use axum::{extract::State, response::IntoResponse, Form}; use axum::{extract::State, Form};
use axum_htmx::HxBoosted; use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse}, middlewares::response_middleware::{TemplateResponse, TemplateResult},
AuthSessionType, AuthSessionType,
}; };
use common::storage::types::user::User; use common::storage::types::user::User;
@@ -19,7 +19,7 @@ pub struct SignInParams {
pub async fn show_signin_form( pub async fn show_signin_form(
auth: AuthSessionType, auth: AuthSessionType,
HxBoosted(boosted): HxBoosted, HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
if auth.current_user.is_some() { if auth.current_user.is_some() {
return Ok(TemplateResponse::redirect("/")); return Ok(TemplateResponse::redirect("/"));
} }
@@ -38,9 +38,9 @@ pub async fn authenticate_user(
State(state): State<HtmlState>, State(state): State<HtmlState>,
auth: AuthSessionType, auth: AuthSessionType,
Form(form): Form<SignInParams>, Form(form): Form<SignInParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else { let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else {
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response()); return Ok(TemplateResponse::bad_request("Incorrect email or password"));
}; };
auth.login_user(user.id); auth.login_user(user.id);
@@ -49,5 +49,5 @@ pub async fn authenticate_user(
auth.remember_user(true); auth.remember_user(true);
} }
Ok(TemplateResponse::redirect("/").into_response()) Ok(TemplateResponse::redirect("/"))
} }
+2 -4
View File
@@ -1,11 +1,9 @@
use axum::response::IntoResponse;
use crate::{ use crate::{
middlewares::response_middleware::{HtmlError, TemplateResponse}, middlewares::response_middleware::{TemplateResponse, TemplateResult},
AuthSessionType, AuthSessionType,
}; };
pub async fn sign_out_user(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> { pub async fn sign_out_user(auth: AuthSessionType) -> TemplateResult {
if !auth.is_authenticated() { if !auth.is_authenticated() {
return Ok(TemplateResponse::redirect("/")); return Ok(TemplateResponse::redirect("/"));
} }
+6 -6
View File
@@ -1,4 +1,4 @@
use axum::{extract::State, response::IntoResponse, Form}; use axum::{extract::State, Form};
use axum_htmx::HxBoosted; use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -6,7 +6,7 @@ use common::{error::AppError, storage::types::user::{Theme, User}};
use crate::{ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse}, middlewares::response_middleware::{TemplateResponse, TemplateResult},
AuthSessionType, AuthSessionType,
}; };
@@ -27,7 +27,7 @@ fn signup_error_message(err: &AppError) -> &str {
pub async fn show_signup_form( pub async fn show_signup_form(
auth: AuthSessionType, auth: AuthSessionType,
HxBoosted(boosted): HxBoosted, HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
if auth.current_user.is_some() { if auth.current_user.is_some() {
return Ok(TemplateResponse::redirect("/")); return Ok(TemplateResponse::redirect("/"));
} }
@@ -47,7 +47,7 @@ pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>, State(state): State<HtmlState>,
auth: AuthSessionType, auth: AuthSessionType,
Form(form): Form<Params>, Form(form): Form<Params>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let user = match User::create_new( let user = match User::create_new(
form.email, form.email,
form.password, form.password,
@@ -60,11 +60,11 @@ pub async fn process_signup_and_show_verification(
Ok(user) => user, Ok(user) => user,
Err(err) => { Err(err) => {
tracing::error!(?err, "signup failed"); tracing::error!(?err, "signup failed");
return Ok(TemplateResponse::bad_request(signup_error_message(&err)).into_response()); return Ok(TemplateResponse::bad_request(signup_error_message(&err)));
} }
}; };
auth.login_user(user.id); auth.login_user(user.id);
Ok(TemplateResponse::redirect("/").into_response()) Ok(TemplateResponse::redirect("/"))
} }
+45 -45
View File
@@ -1,7 +1,6 @@
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, State},
http::HeaderValue, http::HeaderValue,
response::IntoResponse,
Form, Form,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -18,7 +17,10 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{
template_as_response, template_with_headers, TemplateResponse, TemplateResult,
ResponseResult,
},
}, },
}; };
@@ -31,7 +33,7 @@ pub struct ChatPageData {
pub async fn show_chat_base( pub async fn show_chat_base(
State(_state): State<HtmlState>, State(_state): State<HtmlState>,
RequireUser(_user): RequireUser, RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"chat/base.html", "chat/base.html",
ChatPageData { ChatPageData {
@@ -50,7 +52,7 @@ pub async fn show_existing_chat(
Path(conversation_id): Path<String>, Path(conversation_id): Path<String>,
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let (conversation, messages) = let (conversation, messages) =
Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db) Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db)
.await?; .await?;
@@ -69,7 +71,7 @@ pub async fn new_user_message(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>, Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
#[derive(Serialize)] #[derive(Serialize)]
struct SSEResponseInitData { struct SSEResponseInitData {
user_message: Message, user_message: Message,
@@ -82,31 +84,32 @@ pub async fn new_user_message(
.ok_or_else(|| AppError::NotFound("conversation was not found".into()))?; .ok_or_else(|| AppError::NotFound("conversation was not found".into()))?;
if conversation.user_id != user.id { if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response()); return Ok(template_as_response(TemplateResponse::unauthorized()));
} }
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None); let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
state.db.store_item(user_message.clone()).await?; state.db.store_item(user_message.clone()).await?;
let mut response = TemplateResponse::new_template( let push_path = format!("/chat/{}", conversation.id);
"chat/streaming_response.html", Ok(template_with_headers(
SSEResponseInitData { user_message }, TemplateResponse::new_template(
) "chat/streaming_response.html",
.into_response(); SSEResponseInitData { user_message },
),
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) { |headers| {
response.headers_mut().insert("HX-Push", header_value); if let Ok(header_value) = HeaderValue::from_str(&push_path) {
} headers.insert("HX-Push", header_value);
}
Ok(response) },
))
} }
pub async fn new_chat_user_message( pub async fn new_chat_user_message(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>, Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
#[derive(Serialize)] #[derive(Serialize)]
struct SSEResponseInitData { struct SSEResponseInitData {
user_message: Message, user_message: Message,
@@ -125,20 +128,21 @@ pub async fn new_chat_user_message(
state.db.store_item(user_message.clone()).await?; state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await; state.invalidate_conversation_archive_cache(&user.id).await;
let mut response = TemplateResponse::new_template( let push_path = format!("/chat/{}", conversation.id);
"chat/new_chat_first_response.html", Ok(template_with_headers(
SSEResponseInitData { TemplateResponse::new_template(
user_message, "chat/new_chat_first_response.html",
conversation: conversation.clone(), SSEResponseInitData {
user_message,
conversation: conversation.clone(),
},
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&push_path) {
headers.insert("HX-Push", header_value);
}
}, },
) ))
.into_response();
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -155,7 +159,7 @@ pub async fn show_conversation_editing_title(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(conversation_id): Path<String>, Path(conversation_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let conversation: Conversation = state let conversation: Conversation = state
.db .db
.get_item(&conversation_id) .get_item(&conversation_id)
@@ -163,7 +167,7 @@ pub async fn show_conversation_editing_title(
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
if conversation.user_id != user.id { if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response()); return Ok(TemplateResponse::unauthorized());
} }
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -171,8 +175,7 @@ pub async fn show_conversation_editing_title(
DrawerContext { DrawerContext {
edit_conversation_id: Some(conversation_id), edit_conversation_id: Some(conversation_id),
}, },
) ))
.into_response())
} }
pub async fn patch_conversation_title( pub async fn patch_conversation_title(
@@ -180,7 +183,7 @@ pub async fn patch_conversation_title(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(conversation_id): Path<String>, Path(conversation_id): Path<String>,
Form(form): Form<PatchConversationTitle>, Form(form): Form<PatchConversationTitle>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?; Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?;
state.invalidate_conversation_archive_cache(&user.id).await; state.invalidate_conversation_archive_cache(&user.id).await;
@@ -189,15 +192,14 @@ pub async fn patch_conversation_title(
DrawerContext { DrawerContext {
edit_conversation_id: None, edit_conversation_id: None,
}, },
) ))
.into_response())
} }
pub async fn delete_conversation( pub async fn delete_conversation(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(conversation_id): Path<String>, Path(conversation_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let conversation: Conversation = state let conversation: Conversation = state
.db .db
.get_item(&conversation_id) .get_item(&conversation_id)
@@ -205,7 +207,7 @@ pub async fn delete_conversation(
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
if conversation.user_id != user.id { if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response()); return Ok(TemplateResponse::unauthorized());
} }
state state
@@ -219,18 +221,16 @@ pub async fn delete_conversation(
DrawerContext { DrawerContext {
edit_conversation_id: None, edit_conversation_id: None,
}, },
) ))
.into_response())
} }
pub async fn reload_sidebar( pub async fn reload_sidebar(
State(_state): State<HtmlState>, State(_state): State<HtmlState>,
RequireUser(_user): RequireUser, RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"sidebar.html", "sidebar.html",
DrawerContext { DrawerContext {
edit_conversation_id: None, edit_conversation_id: None,
}, },
) ))
.into_response())
} }
+2 -3
View File
@@ -2,7 +2,6 @@
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, State},
response::IntoResponse,
}; };
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
@@ -16,7 +15,7 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{TemplateResponse, TemplateResult},
}, },
}; };
@@ -45,7 +44,7 @@ pub async fn show_reference_tooltip(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(reference_id): Path<String>, Path(reference_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else { let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
return Ok(TemplateResponse::not_found()); return Ok(TemplateResponse::not_found());
}; };
+7 -8
View File
@@ -1,6 +1,5 @@
use axum::{ use axum::{
extract::{Path, Query, State}, extract::{Path, Query, State},
response::IntoResponse,
Form, Form,
}; };
use axum_htmx::{HxBoosted, HxRequest, HxTarget}; use axum_htmx::{HxBoosted, HxRequest, HxTarget};
@@ -15,7 +14,7 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{TemplateResponse, TemplateResult},
}, },
utils::pagination::{paginate_items, Pagination}, utils::pagination::{paginate_items, Pagination},
utils::text_content_preview::truncate_text_contents, utils::text_content_preview::truncate_text_contents,
@@ -50,7 +49,7 @@ pub async fn show_content_page(
Query(params): Query<FilterParams>, Query(params): Query<FilterParams>,
HxRequest(is_htmx): HxRequest, HxRequest(is_htmx): HxRequest,
HxBoosted(is_boosted): HxBoosted, HxBoosted(is_boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
// Normalize empty strings to None // Normalize empty strings to None
let category_filter = params let category_filter = params
.category .category
@@ -101,7 +100,7 @@ pub async fn show_text_content_edit_form(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
#[derive(Serialize)] #[derive(Serialize)]
pub struct TextContentEditModal { pub struct TextContentEditModal {
pub text_content: TextContent, pub text_content: TextContent,
@@ -127,7 +126,7 @@ pub async fn patch_text_content(
Path(id): Path<String>, Path(id): Path<String>,
HxTarget(target): HxTarget, HxTarget(target): HxTarget,
Form(form): Form<PatchTextContentParams>, Form(form): Form<PatchTextContentParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
User::get_and_validate_text_content(&id, &user.id, &state.db).await?; User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
TextContent::patch(&id, &form.context, &form.category, &form.text, &state.db).await?; TextContent::patch(&id, &form.context, &form.category, &form.text, &state.db).await?;
@@ -167,7 +166,7 @@ pub async fn delete_text_content(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
// Get and validate the text content // Get and validate the text content
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?; let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
@@ -213,7 +212,7 @@ pub async fn show_content_read_modal(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
#[derive(Serialize)] #[derive(Serialize)]
pub struct TextContentReadModalData { pub struct TextContentReadModalData {
pub text_content: TextContent, pub text_content: TextContent,
@@ -231,7 +230,7 @@ pub async fn show_content_read_modal(
pub async fn show_recent_content( pub async fn show_recent_content(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let text_contents = let text_contents =
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?); truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
+12 -10
View File
@@ -12,7 +12,9 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
},
}, },
utils::text_content_preview::truncate_text_contents, utils::text_content_preview::truncate_text_contents,
utils::truncate::with_ellipsis, utils::truncate::with_ellipsis,
@@ -37,7 +39,7 @@ pub struct IndexPageData {
pub async fn index_handler( pub async fn index_handler(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let (text_contents, dashboard_stats, active_jobs) = try_join!( let (text_contents, dashboard_stats, active_jobs) = try_join!(
User::get_latest_text_contents(&user.id, &state.db), User::get_latest_text_contents(&user.id, &state.db),
User::get_dashboard_stats(&user.id, &state.db), User::get_dashboard_stats(&user.id, &state.db),
@@ -65,7 +67,7 @@ pub async fn delete_text_content(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
// Get and validate TextContent // Get and validate TextContent
let text_content = get_and_validate_text_content(&state, &id, &user).await?; let text_content = get_and_validate_text_content(&state, &id, &user).await?;
@@ -154,7 +156,7 @@ pub async fn delete_job(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
User::validate_and_delete_job(&id, &user.id, &state.db).await?; User::validate_and_delete_job(&id, &user.id, &state.db).await?;
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?; let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
@@ -169,7 +171,7 @@ pub async fn delete_job(
pub async fn show_active_jobs( pub async fn show_active_jobs(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?; let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -181,7 +183,7 @@ pub async fn show_active_jobs(
pub async fn show_task_archive( pub async fn show_task_archive(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?; let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?;
let entries: Vec<TaskArchiveEntry> = tasks let entries: Vec<TaskArchiveEntry> = tasks
@@ -234,17 +236,17 @@ pub async fn serve_file(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(file_id): Path<String>, Path(file_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else { let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
return Ok(TemplateResponse::not_found().into_response()); return Ok(template_as_response(TemplateResponse::not_found()));
}; };
if file_info.user_id != user.id { if file_info.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response()); return Ok(template_as_response(TemplateResponse::unauthorized()));
} }
let Ok(stream) = state.storage.get_stream(&file_info.path).await else { let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
return Ok(TemplateResponse::server_error().into_response()); return Ok(template_as_response(TemplateResponse::server_error()));
}; };
let body = Body::from_stream(stream); let body = Body::from_stream(stream);
+14 -16
View File
@@ -5,7 +5,7 @@ use axum::{
http::StatusCode, http::StatusCode,
response::{ response::{
sse::{Event, KeepAlive, KeepAliveStream}, sse::{Event, KeepAlive, KeepAliveStream},
IntoResponse, Response, Sse, Sse,
}, },
}; };
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -31,7 +31,7 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{TemplateResponse, TemplateResult},
}, },
}; };
@@ -49,7 +49,7 @@ fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
pub async fn show_ingest_form( pub async fn show_ingest_form(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
#[derive(Serialize)] #[derive(Serialize)]
pub struct ShowIngestFormData { pub struct ShowIngestFormData {
user_categories: Vec<String>, user_categories: Vec<String>,
@@ -65,7 +65,7 @@ pub async fn show_ingest_form(
pub async fn hide_ingest_form( pub async fn hide_ingest_form(
RequireUser(_user): RequireUser, RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"ingestion/add_content_button.html", "ingestion/add_content_button.html",
(), (),
@@ -91,12 +91,11 @@ pub async fn process_ingest_form(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
TypedMultipart(input): TypedMultipart<IngestionParams>, TypedMultipart(input): TypedMultipart<IngestionParams>,
) -> Result<Response, HtmlError> { ) -> TemplateResult {
if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() { if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
return Ok( return Ok(TemplateResponse::bad_request(
TemplateResponse::bad_request("You need to either add files or content") "You need to either add files or content",
.into_response(), ));
);
} }
let content_bytes = input.content.as_ref().map_or(0, String::len); let content_bytes = input.content.as_ref().map_or(0, String::len);
@@ -118,11 +117,10 @@ pub async fn process_ingest_form(
StatusCode::PAYLOAD_TOO_LARGE, StatusCode::PAYLOAD_TOO_LARGE,
"Payload Too Large", "Payload Too Large",
&message, &message,
) ));
.into_response());
} }
Err(IngestValidationError::BadRequest(message)) => { Err(IngestValidationError::BadRequest(message)) => {
return Ok(TemplateResponse::bad_request(&message).into_response()); return Ok(TemplateResponse::bad_request(&message));
} }
} }
@@ -153,10 +151,10 @@ pub async fn process_ingest_form(
let tasks = let tasks =
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?; IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
Ok( Ok(TemplateResponse::new_template(
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks }) "dashboard/current_task.html",
.into_response(), NewTasksData { tasks },
) ))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
+25 -23
View File
@@ -31,7 +31,9 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
},
}, },
utils::pagination::{paginate_items, Pagination}, utils::pagination::{paginate_items, Pagination},
}; };
@@ -120,12 +122,12 @@ fn collect_relationship_type_options(relationships: &[KnowledgeRelationship]) ->
options options
} }
fn respond_with_graph_refresh(response: TemplateResponse) -> Response { fn graph_refresh_response(template: TemplateResponse) -> Response {
let mut response = response.into_response(); template_with_headers(template, |headers| {
if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) { if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) {
response.headers_mut().insert(HX_TRIGGER, value); headers.insert(HX_TRIGGER, value);
} }
response })
} }
#[derive(Deserialize, Default)] #[derive(Deserialize, Default)]
@@ -138,7 +140,7 @@ pub struct FilterParams {
pub async fn show_new_knowledge_entity_form( pub async fn show_new_knowledge_entity_form(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let entity_types: Vec<String> = KnowledgeEntityType::variants() let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter() .iter()
.map(ToString::to_string) .map(ToString::to_string)
@@ -170,7 +172,7 @@ pub async fn create_knowledge_entity(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<CreateKnowledgeEntityParams>, Form(form): Form<CreateKnowledgeEntityParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
let name = form.name.trim().to_string(); let name = form.name.trim().to_string();
if name.is_empty() { if name.is_empty() {
return Err(AppError::Validation("name is required".into()).into()); return Err(AppError::Validation("name is required".into()).into());
@@ -230,7 +232,7 @@ pub async fn create_knowledge_entity(
let default_params = FilterParams::default(); let default_params = FilterParams::default();
let kb_data = build_knowledge_base_data(&state, &user, &default_params).await?; let kb_data = build_knowledge_base_data(&state, &user, &default_params).await?;
Ok(respond_with_graph_refresh(TemplateResponse::new_partial( Ok(graph_refresh_response(TemplateResponse::new_partial(
"knowledge/base.html", "knowledge/base.html",
"main", "main",
kb_data, kb_data,
@@ -241,7 +243,7 @@ pub async fn suggest_knowledge_relationships(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<SuggestRelationshipsParams>, Form(form): Form<SuggestRelationshipsParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let entity_lookup: HashMap<String, KnowledgeEntity> = let entity_lookup: HashMap<String, KnowledgeEntity> =
User::get_knowledge_entities(&user.id, &state.db) User::get_knowledge_entities(&user.id, &state.db)
.await? .await?
@@ -723,7 +725,7 @@ pub async fn show_knowledge_page(
HxRequest(is_htmx): HxRequest, HxRequest(is_htmx): HxRequest,
HxBoosted(is_boosted): HxBoosted, HxBoosted(is_boosted): HxBoosted,
Query(mut params): Query<FilterParams>, Query(mut params): Query<FilterParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
// Normalize filters: treat empty or "none" as no filter // Normalize filters: treat empty or "none" as no filter
params.entity_type = normalize_filter(params.entity_type.take()); params.entity_type = normalize_filter(params.entity_type.take());
params.content_category = normalize_filter(params.content_category.take()); params.content_category = normalize_filter(params.content_category.take());
@@ -772,7 +774,7 @@ pub async fn get_knowledge_graph_json(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Query(mut params): Query<FilterParams>, Query(mut params): Query<FilterParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
// Normalize filters: treat empty or "none" as no filter // Normalize filters: treat empty or "none" as no filter
params.entity_type = normalize_filter(params.entity_type.take()); params.entity_type = normalize_filter(params.entity_type.take());
params.content_category = normalize_filter(params.content_category.take()); params.content_category = normalize_filter(params.content_category.take());
@@ -821,7 +823,7 @@ pub async fn get_knowledge_graph_json(
}) })
.collect(); .collect();
Ok(Json(GraphData { nodes, links })) Ok(Json(GraphData { nodes, links }).into_response())
} }
// Normalize filter parameters: convert empty strings or "none" (case-insensitive) to None // Normalize filter parameters: convert empty strings or "none" (case-insensitive) to None
fn normalize_filter(input: Option<String>) -> Option<String> { fn normalize_filter(input: Option<String>) -> Option<String> {
@@ -851,7 +853,7 @@ pub async fn show_edit_knowledge_entity_form(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
#[derive(Serialize)] #[derive(Serialize)]
pub struct EntityData { pub struct EntityData {
entity: KnowledgeEntity, entity: KnowledgeEntity,
@@ -899,7 +901,7 @@ pub async fn patch_knowledge_entity(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<PatchKnowledgeEntityParams>, Form(form): Form<PatchKnowledgeEntityParams>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
// Get the existing entity and validate that the user is allowed // Get the existing entity and validate that the user is allowed
User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db).await?; User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db).await?;
@@ -930,7 +932,7 @@ pub async fn patch_knowledge_entity(
let content_categories = User::get_user_categories(&user.id, &state.db).await?; let content_categories = User::get_user_categories(&user.id, &state.db).await?;
// Render updated list // Render updated list
Ok(respond_with_graph_refresh(TemplateResponse::new_template( Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/entity_list.html", "knowledge/entity_list.html",
EntityListData { EntityListData {
visible_entities, visible_entities,
@@ -948,7 +950,7 @@ pub async fn delete_knowledge_entity(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
// Get the existing entity and validate that the user is allowed // Get the existing entity and validate that the user is allowed
User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?; User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?;
@@ -968,7 +970,7 @@ pub async fn delete_knowledge_entity(
// Get content categories // Get content categories
let content_categories = User::get_user_categories(&user.id, &state.db).await?; let content_categories = User::get_user_categories(&user.id, &state.db).await?;
Ok(respond_with_graph_refresh(TemplateResponse::new_template( Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/entity_list.html", "knowledge/entity_list.html",
EntityListData { EntityListData {
visible_entities, visible_entities,
@@ -994,7 +996,7 @@ pub async fn delete_knowledge_relationship(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
KnowledgeRelationship::delete_relationship_by_id(&id, &user.id, &state.db).await?; KnowledgeRelationship::delete_relationship_by_id(&id, &user.id, &state.db).await?;
let entities = User::get_knowledge_entities(&user.id, &state.db).await?; let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
@@ -1003,7 +1005,7 @@ pub async fn delete_knowledge_relationship(
let table_data = build_relationship_table_data(entities, relationships); let table_data = build_relationship_table_data(entities, relationships);
// Render updated list // Render updated list
Ok(respond_with_graph_refresh(TemplateResponse::new_template( Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/relationship_table.html", "knowledge/relationship_table.html",
table_data, table_data,
))) )))
@@ -1020,7 +1022,7 @@ pub async fn save_knowledge_relationship(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Form(form): Form<SaveKnowledgeRelationshipInput>, Form(form): Form<SaveKnowledgeRelationshipInput>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
// Construct relationship // Construct relationship
let relationship_type = canonicalize_relationship_type(&form.relationship_type); let relationship_type = canonicalize_relationship_type(&form.relationship_type);
let relationship = KnowledgeRelationship::new( let relationship = KnowledgeRelationship::new(
@@ -1039,7 +1041,7 @@ pub async fn save_knowledge_relationship(
let table_data = build_relationship_table_data(entities, relationships); let table_data = build_relationship_table_data(entities, relationships);
// Render updated list // Render updated list
Ok(respond_with_graph_refresh(TemplateResponse::new_template( Ok(graph_refresh_response(TemplateResponse::new_template(
"knowledge/relationship_table.html", "knowledge/relationship_table.html",
table_data, table_data,
))) )))
+44 -41
View File
@@ -11,7 +11,9 @@ use serde::{Deserialize, Serialize};
use crate::html_state::HtmlState; use crate::html_state::HtmlState;
use crate::middlewares::{ use crate::middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
},
}; };
use common::storage::types::{ use common::storage::types::{
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, scratchpad::Scratchpad, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, scratchpad::Scratchpad,
@@ -127,7 +129,7 @@ pub async fn show_scratchpad_page(
HxRequest(is_htmx): HxRequest, HxRequest(is_htmx): HxRequest,
HxBoosted(is_boosted): HxBoosted, HxBoosted(is_boosted): HxBoosted,
State(state): State<HtmlState>, State(state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?; let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?; let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
@@ -165,7 +167,7 @@ pub async fn show_scratchpad_modal(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
Query(query): Query<EditTitleQuery>, Query(query): Query<EditTitleQuery>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?; let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
let scratchpad_detail = ScratchpadDetail::from(&scratchpad); let scratchpad_detail = ScratchpadDetail::from(&scratchpad);
@@ -186,7 +188,7 @@ pub async fn create_scratchpad(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
State(state): State<HtmlState>, State(state): State<HtmlState>,
Form(form): Form<CreateScratchpadForm>, Form(form): Form<CreateScratchpadForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let user_id = user.id.clone(); let user_id = user.id.clone();
let scratchpad = Scratchpad::new(user_id.clone(), form.title); let scratchpad = Scratchpad::new(user_id.clone(), form.title);
let _stored = state.db.store_item(scratchpad.clone()).await?; let _stored = state.db.store_item(scratchpad.clone()).await?;
@@ -217,7 +219,7 @@ pub async fn auto_save_scratchpad(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
Form(form): Form<UpdateScratchpadForm>, Form(form): Form<UpdateScratchpadForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
let updated = let updated =
Scratchpad::update_content(&scratchpad_id, &user.id, &form.content, &state.db).await?; Scratchpad::update_content(&scratchpad_id, &user.id, &form.content, &state.db).await?;
@@ -229,7 +231,8 @@ pub async fn auto_save_scratchpad(
.format("%Y-%m-%d %H:%M:%S") .format("%Y-%m-%d %H:%M:%S")
.to_string(), .to_string(),
last_saved_at_iso: updated.last_saved_at.to_rfc3339(), last_saved_at_iso: updated.last_saved_at.to_rfc3339(),
})) })
.into_response())
} }
pub async fn update_scratchpad_title( pub async fn update_scratchpad_title(
@@ -237,7 +240,7 @@ pub async fn update_scratchpad_title(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
Form(form): Form<UpdateTitleForm>, Form(form): Form<UpdateTitleForm>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Scratchpad::update_title(&scratchpad_id, &user.id, &form.title, &state.db).await?; Scratchpad::update_title(&scratchpad_id, &user.id, &form.title, &state.db).await?;
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?; let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
@@ -255,7 +258,7 @@ pub async fn delete_scratchpad(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Scratchpad::delete(&scratchpad_id, &user.id, &state.db).await?; Scratchpad::delete(&scratchpad_id, &user.id, &state.db).await?;
// Return the updated main section content // Return the updated main section content
@@ -284,7 +287,7 @@ pub async fn ingest_scratchpad(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?; let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
if scratchpad.content.trim().is_empty() { if scratchpad.content.trim().is_empty() {
@@ -347,29 +350,29 @@ pub async fn ingest_scratchpad(
r#"{"toast":{"title":"Ingestion queued","description":"Scratchpad archived and added to the ingestion queue.","type":"success"}}"#.to_string() r#"{"toast":{"title":"Ingestion queued","description":"Scratchpad archived and added to the ingestion queue.","type":"success"}}"#.to_string()
}); });
let template_response = TemplateResponse::new_partial( Ok(template_with_headers(
"scratchpad/base.html", TemplateResponse::new_partial(
"main", "scratchpad/base.html",
ScratchpadPageData { "main",
scratchpads: scratchpad_list, ScratchpadPageData {
archived_scratchpads: archived_list, scratchpads: scratchpad_list,
new_scratchpad: None, archived_scratchpads: archived_list,
new_scratchpad: None,
},
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
headers.insert(HX_TRIGGER, header_value);
}
}, },
); ))
let mut response = template_response.into_response();
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
response.headers_mut().insert(HX_TRIGGER, header_value);
}
Ok(response)
} }
pub async fn archive_scratchpad( pub async fn archive_scratchpad(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
Scratchpad::archive(&scratchpad_id, &user.id, &state.db, false).await?; Scratchpad::archive(&scratchpad_id, &user.id, &state.db, false).await?;
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?; let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
@@ -396,7 +399,7 @@ pub async fn restore_scratchpad(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
State(state): State<HtmlState>, State(state): State<HtmlState>,
Path(scratchpad_id): Path<String>, Path(scratchpad_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> ResponseResult {
Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?; Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?;
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?; let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
@@ -420,22 +423,22 @@ pub async fn restore_scratchpad(
r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string() r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string()
}); });
let template_response = TemplateResponse::new_partial( Ok(template_with_headers(
"scratchpad/base.html", TemplateResponse::new_partial(
"main", "scratchpad/base.html",
ScratchpadPageData { "main",
scratchpads: scratchpad_list, ScratchpadPageData {
archived_scratchpads: archived_list, scratchpads: scratchpad_list,
new_scratchpad: None, archived_scratchpads: archived_list,
new_scratchpad: None,
},
),
|headers| {
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
headers.insert(HX_TRIGGER, header_value);
}
}, },
); ))
let mut response = template_response.into_response();
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
response.headers_mut().insert(HX_TRIGGER, header_value);
}
Ok(response)
} }
#[cfg(test)] #[cfg(test)]
+2 -3
View File
@@ -2,7 +2,6 @@ use std::collections::HashSet;
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
response::IntoResponse,
}; };
use common::storage::types::{text_content::TextContent, user::User}; use common::storage::types::{text_content::TextContent, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput}; use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
@@ -13,7 +12,7 @@ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{HtmlError, TemplateResponse, TemplateResult},
}, },
}; };
@@ -79,7 +78,7 @@ pub async fn search_result_handler(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Query(params): Query<SearchParams>, Query(params): Query<SearchParams>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> TemplateResult {
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) = let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
params.query params.query
{ {
+311
View File
@@ -0,0 +1,311 @@
#![allow(clippy::expect_used)]
use std::sync::Arc;
use axum::{
body::{to_bytes, Body},
http::{header, Request, StatusCode},
response::Response,
Router,
};
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::{
config::{AppConfig, StorageKind},
embedding::EmbeddingProvider,
},
};
use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
};
use tower::ServiceExt;
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
let namespace = "html_router_test";
let database = &uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, database)
.await
.expect("in-memory db"),
);
db.apply_migrations()
.await
.expect("migrations should apply");
let session_store = Arc::new(
db.create_session_store()
.await
.expect("session store"),
);
let config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config)
.await
.expect("storage manager");
let embedding_provider = Arc::new(
EmbeddingProvider::new_hashed(8).expect("embedding provider"),
);
let state = HtmlState::new_with_resources(StateResources {
db: Arc::clone(&db),
openai_client: Arc::new(async_openai::Client::new()),
session_store,
storage,
config,
reranker_pool: None,
embedding_provider,
template_engine: None,
});
let router = html_routes(&state).with_state(state);
(router, db)
}
fn redirect_location(response: &Response) -> String {
response
.headers()
.get(header::LOCATION)
.or_else(|| response.headers().get("HX-Redirect"))
.expect("redirect response should include Location or HX-Redirect")
.to_str()
.expect("redirect header must be utf-8")
.to_string()
}
fn session_cookie(response: &Response) -> String {
response
.headers()
.get_all(header::SET_COOKIE)
.iter()
.map(|value| {
value
.to_str()
.expect("set-cookie must be utf-8")
.split(';')
.next()
.expect("cookie key=value")
.to_string()
})
.collect::<Vec<_>>()
.join("; ")
}
async fn response_body(response: Response) -> String {
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
String::from_utf8(body.to_vec()).expect("html body")
}
async fn sign_in(app: &Router, email: &str, password: &str) -> String {
let response = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/signin")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(format!("email={email}&password={password}")))
.expect("signin request"),
)
.await
.expect("signin response");
assert!(
response.status().is_redirection() || response.status() == StatusCode::OK,
"signin should redirect or return ok"
);
session_cookie(&response)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn protected_route_redirects_unauthenticated_users() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/")
.body(Body::empty())
.expect("dashboard request"),
)
.await
.expect("dashboard response");
assert!(
response.status().is_redirection() || response.status() == StatusCode::OK,
"unauthenticated access should redirect via template middleware"
);
assert_eq!(redirect_location(&response), "/signin");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn authenticated_user_receives_rendered_dashboard() {
let (app, db) = build_test_app().await;
User::create_new(
"router_test@example.com".to_string(),
"test_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let cookie = sign_in(&app, "router_test@example.com", "test_password").await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/")
.header(header::COOKIE, cookie)
.body(Body::empty())
.expect("authenticated dashboard request"),
)
.await
.expect("authenticated dashboard response");
assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body");
let html = String::from_utf8(body.to_vec()).expect("html body");
assert!(
html.contains("dashboard") || html.contains("Dashboard"),
"dashboard template should render html"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn signin_form_is_public() {
let (app, _db) = build_test_app().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/signin")
.body(Body::empty())
.expect("signin form request"),
)
.await
.expect("signin form response");
assert_eq!(response.status(), StatusCode::OK);
let html = response_body(response).await;
assert!(
html.contains("signin") || html.contains("Sign in") || html.contains("email"),
"signin page should render a form"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn signin_rejects_invalid_credentials() {
let (app, db) = build_test_app().await;
User::create_new(
"signin_test@example.com".to_string(),
"correct_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("test user");
let response = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/signin")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(
"email=signin_test@example.com&password=wrong_password",
))
.expect("invalid signin request"),
)
.await
.expect("invalid signin response");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let html = response_body(response).await;
assert!(
html.contains("Incorrect email or password"),
"signin failure should render a safe error message"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn admin_route_redirects_non_admin_user() {
let (app, db) = build_test_app().await;
User::create_new(
"admin_user@example.com".to_string(),
"admin_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("admin user");
User::create_new(
"member_user@example.com".to_string(),
"member_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("member user");
let member_cookie = sign_in(&app, "member_user@example.com", "member_password").await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.header(header::COOKIE, member_cookie)
.body(Body::empty())
.expect("non-admin admin request"),
)
.await
.expect("non-admin admin response");
assert!(
response.status().is_redirection() || response.status() == StatusCode::OK,
"non-admin should be redirected away from admin"
);
assert_eq!(redirect_location(&response), "/");
let admin_cookie = sign_in(&app, "admin_user@example.com", "admin_password").await;
let admin_response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.header(header::COOKIE, admin_cookie)
.body(Body::empty())
.expect("admin request"),
)
.await
.expect("admin response");
assert_eq!(admin_response.status(), StatusCode::OK);
}