mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-31 03:40:38 +02:00
chore: harden api-router errors and add router integration tests while slimming html handlers.
This commit is contained in:
@@ -43,6 +43,7 @@ json-stream-parser = { path = "../json-stream-parser" }
|
||||
|
||||
[dev-dependencies]
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
tower = "0.5"
|
||||
|
||||
[build-dependencies]
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
|
||||
@@ -116,6 +116,29 @@ impl IntoResponse for TemplateResponse {
|
||||
}
|
||||
}
|
||||
|
||||
/// Typical handler return type when no extra response headers or body are needed.
|
||||
pub type TemplateResult = Result<TemplateResponse, HtmlError>;
|
||||
|
||||
/// Handler return type when the response needs custom headers or a non-template body.
|
||||
pub type ResponseResult = Result<Response, HtmlError>;
|
||||
|
||||
/// Converts a [`TemplateResponse`] for [`ResponseResult`] handlers that do not set extra headers.
|
||||
pub fn template_as_response(template: TemplateResponse) -> Response {
|
||||
template.into_response()
|
||||
}
|
||||
|
||||
/// Wraps a [`TemplateResponse`] for the template middleware and applies outbound headers.
|
||||
///
|
||||
/// Headers listed in [`HTMX_HEADERS_TO_FORWARD`] are copied onto the rendered HTML response.
|
||||
pub fn template_with_headers(
|
||||
template: TemplateResponse,
|
||||
apply: impl FnOnce(&mut axum::http::HeaderMap),
|
||||
) -> Response {
|
||||
let mut response = template.into_response();
|
||||
apply(response.headers_mut());
|
||||
response
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TemplateUser {
|
||||
id: String,
|
||||
@@ -143,7 +166,7 @@ struct ContextWrapper<'a> {
|
||||
initial_theme: &'a str,
|
||||
is_authenticated: bool,
|
||||
user: Option<&'a TemplateUser>,
|
||||
conversation_archive: Vec<SidebarConversation>,
|
||||
conversation_archive: &'a [SidebarConversation],
|
||||
#[serde(flatten)]
|
||||
context: HashMap<String, Value>,
|
||||
}
|
||||
@@ -213,18 +236,14 @@ where
|
||||
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
|
||||
let template_engine = state.template_engine();
|
||||
|
||||
let mut conversation_archive = Vec::new();
|
||||
|
||||
let should_load_conversation_archive =
|
||||
matches!(&template_response.template_kind, TemplateKind::Full(_));
|
||||
|
||||
if should_load_conversation_archive {
|
||||
let cached_archive = if should_load_conversation_archive {
|
||||
if let Some(user_id) = current_user.as_ref().map(|u| &u.id) {
|
||||
let html_state = state.html_state();
|
||||
if let Some(cached_archive) =
|
||||
html_state.get_cached_conversation_archive(user_id).await
|
||||
{
|
||||
conversation_archive = cached_archive.to_vec();
|
||||
if let Some(cached) = html_state.get_cached_conversation_archive(user_id).await {
|
||||
Some(cached)
|
||||
} else if let Ok(archive) =
|
||||
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
|
||||
{
|
||||
@@ -232,10 +251,19 @@ where
|
||||
html_state
|
||||
.set_cached_conversation_archive(user_id, Arc::clone(&cached))
|
||||
.await;
|
||||
conversation_archive = cached.to_vec();
|
||||
Some(cached)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let conversation_archive = cached_archive
|
||||
.as_ref()
|
||||
.map_or(&[][..], |archive| archive.as_ref());
|
||||
|
||||
let context_map = match context_to_map(&template_response.context) {
|
||||
Ok(map) => map,
|
||||
|
||||
@@ -34,7 +34,7 @@ macro_rules! create_asset_service {
|
||||
}};
|
||||
}
|
||||
|
||||
pub type MiddleWareVecType<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
|
||||
pub type MiddlewareVec<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
|
||||
|
||||
/// Builder for composing public/protected HTML routes and middleware layers.
|
||||
pub struct RouterFactory<S> {
|
||||
@@ -43,7 +43,7 @@ pub struct RouterFactory<S> {
|
||||
protected_routers: Vec<Router<S>>,
|
||||
nested_routes: Vec<(String, Router<S>)>,
|
||||
nested_protected_routes: Vec<(String, Router<S>)>,
|
||||
custom_middleware: MiddleWareVecType<S>,
|
||||
custom_middleware: MiddlewareVec<S>,
|
||||
public_assets_config: Option<AssetsConfig>,
|
||||
compression_enabled: bool,
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum::{extract::State, Form};
|
||||
use chrono_tz::TZ_VARIANTS;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{TemplateResponse, TemplateResult},
|
||||
},
|
||||
AuthSessionType,
|
||||
};
|
||||
@@ -28,7 +28,7 @@ pub struct AccountPageData {
|
||||
pub async fn show_account_page(
|
||||
RequireUser(user): RequireUser,
|
||||
State(_state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let timezones = TZ_VARIANTS
|
||||
.iter()
|
||||
.map(std::string::ToString::to_string)
|
||||
@@ -57,7 +57,7 @@ pub async fn set_api_key(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
auth: AuthSessionType,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
// Generate and set the API key
|
||||
let api_key = User::set_api_key(&user.id, &state.db).await?;
|
||||
|
||||
@@ -82,7 +82,7 @@ pub async fn delete_account(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
auth: AuthSessionType,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
state.db.delete_item::<User>(&user.id).await?;
|
||||
|
||||
auth.logout_user();
|
||||
@@ -102,7 +102,7 @@ pub async fn update_timezone(
|
||||
RequireUser(user): RequireUser,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<UpdateTimezoneForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
|
||||
|
||||
// Clear the cache
|
||||
@@ -137,7 +137,7 @@ pub async fn update_theme(
|
||||
RequireUser(user): RequireUser,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<UpdateThemeForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
User::update_theme(&user.id, &form.theme, &state.db).await?;
|
||||
|
||||
// Clear the cache
|
||||
@@ -166,7 +166,7 @@ pub async fn update_theme(
|
||||
|
||||
pub async fn show_change_password(
|
||||
RequireUser(_user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Ok(TemplateResponse::new_template(
|
||||
"auth/change_password_form.html",
|
||||
(),
|
||||
@@ -184,7 +184,7 @@ pub async fn change_password(
|
||||
RequireUser(user): RequireUser,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<NewPasswordForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
// Authenticate to make sure the password matches
|
||||
let authenticated_user = User::authenticate(&user.email, &form.old_password, &state.db).await?;
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::sync::Arc;
|
||||
use async_openai::types::ListModelResponse;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::IntoResponse,
|
||||
Form,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -26,7 +25,7 @@ use tracing::{error, info};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::response_middleware::{HtmlError, TemplateResponse},
|
||||
middlewares::response_middleware::{TemplateResponse, TemplateResult},
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -57,7 +56,7 @@ pub struct AdminPanelQuery {
|
||||
pub async fn show_admin_panel(
|
||||
State(state): State<HtmlState>,
|
||||
Query(query): Query<AdminPanelQuery>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let section = match query.section.as_deref() {
|
||||
Some("models") => AdminSection::Models,
|
||||
_ => AdminSection::Overview,
|
||||
@@ -124,7 +123,7 @@ pub struct RegistrationToggleData {
|
||||
pub async fn toggle_registration_status(
|
||||
State(state): State<HtmlState>,
|
||||
Form(input): Form<RegistrationToggleInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let new_settings = SystemSettingsPatch {
|
||||
registrations_enabled: Some(input.registration_open),
|
||||
..Default::default()
|
||||
@@ -160,7 +159,7 @@ pub struct ModelSettingsData {
|
||||
pub async fn update_model_settings(
|
||||
State(state): State<HtmlState>,
|
||||
Form(input): Form<ModelSettingsInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
// Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
|
||||
@@ -272,7 +271,7 @@ pub struct SystemPromptEditData {
|
||||
|
||||
pub async fn show_edit_system_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -297,7 +296,7 @@ pub struct SystemPromptSectionData {
|
||||
pub async fn patch_query_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
Form(input): Form<SystemPromptUpdateInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let new_settings = SystemSettingsPatch {
|
||||
query_system_prompt: Some(input.query_system_prompt),
|
||||
..Default::default()
|
||||
@@ -322,7 +321,7 @@ pub struct IngestionPromptEditData {
|
||||
|
||||
pub async fn show_edit_ingestion_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -342,7 +341,7 @@ pub struct IngestionPromptUpdateInput {
|
||||
pub async fn patch_ingestion_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
Form(input): Form<IngestionPromptUpdateInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let new_settings = SystemSettingsPatch {
|
||||
ingestion_system_prompt: Some(input.ingestion_system_prompt),
|
||||
..Default::default()
|
||||
@@ -367,7 +366,7 @@ pub struct ImagePromptEditData {
|
||||
|
||||
pub async fn show_edit_image_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -387,7 +386,7 @@ pub struct ImagePromptUpdateInput {
|
||||
pub async fn patch_image_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
Form(input): Form<ImagePromptUpdateInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let new_settings = SystemSettingsPatch {
|
||||
image_processing_prompt: Some(input.image_processing_prompt),
|
||||
..Default::default()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum::{extract::State, Form};
|
||||
use axum_htmx::HxBoosted;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::response_middleware::{HtmlError, TemplateResponse},
|
||||
middlewares::response_middleware::{TemplateResponse, TemplateResult},
|
||||
AuthSessionType,
|
||||
};
|
||||
use common::storage::types::user::User;
|
||||
@@ -19,7 +19,7 @@ pub struct SignInParams {
|
||||
pub async fn show_signin_form(
|
||||
auth: AuthSessionType,
|
||||
HxBoosted(boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
if auth.current_user.is_some() {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
@@ -38,9 +38,9 @@ pub async fn authenticate_user(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<SignInParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else {
|
||||
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
|
||||
return Ok(TemplateResponse::bad_request("Incorrect email or password"));
|
||||
};
|
||||
|
||||
auth.login_user(user.id);
|
||||
@@ -49,5 +49,5 @@ pub async fn authenticate_user(
|
||||
auth.remember_user(true);
|
||||
}
|
||||
|
||||
Ok(TemplateResponse::redirect("/").into_response())
|
||||
Ok(TemplateResponse::redirect("/"))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
use crate::{
|
||||
middlewares::response_middleware::{HtmlError, TemplateResponse},
|
||||
middlewares::response_middleware::{TemplateResponse, TemplateResult},
|
||||
AuthSessionType,
|
||||
};
|
||||
|
||||
pub async fn sign_out_user(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
|
||||
pub async fn sign_out_user(auth: AuthSessionType) -> TemplateResult {
|
||||
if !auth.is_authenticated() {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum::{extract::State, Form};
|
||||
use axum_htmx::HxBoosted;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -6,7 +6,7 @@ use common::{error::AppError, storage::types::user::{Theme, User}};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::response_middleware::{HtmlError, TemplateResponse},
|
||||
middlewares::response_middleware::{TemplateResponse, TemplateResult},
|
||||
AuthSessionType,
|
||||
};
|
||||
|
||||
@@ -27,7 +27,7 @@ fn signup_error_message(err: &AppError) -> &str {
|
||||
pub async fn show_signup_form(
|
||||
auth: AuthSessionType,
|
||||
HxBoosted(boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
if auth.current_user.is_some() {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
@@ -47,7 +47,7 @@ pub async fn process_signup_and_show_verification(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<Params>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let user = match User::create_new(
|
||||
form.email,
|
||||
form.password,
|
||||
@@ -60,11 +60,11 @@ pub async fn process_signup_and_show_verification(
|
||||
Ok(user) => user,
|
||||
Err(err) => {
|
||||
tracing::error!(?err, "signup failed");
|
||||
return Ok(TemplateResponse::bad_request(signup_error_message(&err)).into_response());
|
||||
return Ok(TemplateResponse::bad_request(signup_error_message(&err)));
|
||||
}
|
||||
};
|
||||
|
||||
auth.login_user(user.id);
|
||||
|
||||
Ok(TemplateResponse::redirect("/").into_response())
|
||||
Ok(TemplateResponse::redirect("/"))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderValue,
|
||||
response::IntoResponse,
|
||||
Form,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -18,7 +17,10 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{
|
||||
template_as_response, template_with_headers, TemplateResponse, TemplateResult,
|
||||
ResponseResult,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -31,7 +33,7 @@ pub struct ChatPageData {
|
||||
pub async fn show_chat_base(
|
||||
State(_state): State<HtmlState>,
|
||||
RequireUser(_user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Ok(TemplateResponse::new_template(
|
||||
"chat/base.html",
|
||||
ChatPageData {
|
||||
@@ -50,7 +52,7 @@ pub async fn show_existing_chat(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let (conversation, messages) =
|
||||
Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db)
|
||||
.await?;
|
||||
@@ -69,7 +71,7 @@ pub async fn new_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
@@ -82,31 +84,32 @@ pub async fn new_user_message(
|
||||
.ok_or_else(|| AppError::NotFound("conversation was not found".into()))?;
|
||||
|
||||
if conversation.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
return Ok(template_as_response(TemplateResponse::unauthorized()));
|
||||
}
|
||||
|
||||
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
|
||||
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
)
|
||||
.into_response();
|
||||
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
let push_path = format!("/chat/{}", conversation.id);
|
||||
Ok(template_with_headers(
|
||||
TemplateResponse::new_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
),
|
||||
|headers| {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&push_path) {
|
||||
headers.insert("HX-Push", header_value);
|
||||
}
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn new_chat_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
@@ -125,20 +128,21 @@ pub async fn new_chat_user_message(
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/new_chat_first_response.html",
|
||||
SSEResponseInitData {
|
||||
user_message,
|
||||
conversation: conversation.clone(),
|
||||
let push_path = format!("/chat/{}", conversation.id);
|
||||
Ok(template_with_headers(
|
||||
TemplateResponse::new_template(
|
||||
"chat/new_chat_first_response.html",
|
||||
SSEResponseInitData {
|
||||
user_message,
|
||||
conversation: conversation.clone(),
|
||||
},
|
||||
),
|
||||
|headers| {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&push_path) {
|
||||
headers.insert("HX-Push", header_value);
|
||||
}
|
||||
},
|
||||
)
|
||||
.into_response();
|
||||
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -155,7 +159,7 @@ pub async fn show_conversation_editing_title(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(conversation_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let conversation: Conversation = state
|
||||
.db
|
||||
.get_item(&conversation_id)
|
||||
@@ -163,7 +167,7 @@ pub async fn show_conversation_editing_title(
|
||||
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
|
||||
|
||||
if conversation.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -171,8 +175,7 @@ pub async fn show_conversation_editing_title(
|
||||
DrawerContext {
|
||||
edit_conversation_id: Some(conversation_id),
|
||||
},
|
||||
)
|
||||
.into_response())
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn patch_conversation_title(
|
||||
@@ -180,7 +183,7 @@ pub async fn patch_conversation_title(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(conversation_id): Path<String>,
|
||||
Form(form): Form<PatchConversationTitle>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
@@ -189,15 +192,14 @@ pub async fn patch_conversation_title(
|
||||
DrawerContext {
|
||||
edit_conversation_id: None,
|
||||
},
|
||||
)
|
||||
.into_response())
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn delete_conversation(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(conversation_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let conversation: Conversation = state
|
||||
.db
|
||||
.get_item(&conversation_id)
|
||||
@@ -205,7 +207,7 @@ pub async fn delete_conversation(
|
||||
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
|
||||
|
||||
if conversation.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
|
||||
state
|
||||
@@ -219,18 +221,16 @@ pub async fn delete_conversation(
|
||||
DrawerContext {
|
||||
edit_conversation_id: None,
|
||||
},
|
||||
)
|
||||
.into_response())
|
||||
))
|
||||
}
|
||||
pub async fn reload_sidebar(
|
||||
State(_state): State<HtmlState>,
|
||||
RequireUser(_user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Ok(TemplateResponse::new_template(
|
||||
"sidebar.html",
|
||||
DrawerContext {
|
||||
edit_conversation_id: None,
|
||||
},
|
||||
)
|
||||
.into_response())
|
||||
))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use chrono_tz::Tz;
|
||||
@@ -16,7 +15,7 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{TemplateResponse, TemplateResult},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -45,7 +44,7 @@ pub async fn show_reference_tooltip(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(reference_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
|
||||
return Ok(TemplateResponse::not_found());
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
response::IntoResponse,
|
||||
Form,
|
||||
};
|
||||
use axum_htmx::{HxBoosted, HxRequest, HxTarget};
|
||||
@@ -15,7 +14,7 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{TemplateResponse, TemplateResult},
|
||||
},
|
||||
utils::pagination::{paginate_items, Pagination},
|
||||
utils::text_content_preview::truncate_text_contents,
|
||||
@@ -50,7 +49,7 @@ pub async fn show_content_page(
|
||||
Query(params): Query<FilterParams>,
|
||||
HxRequest(is_htmx): HxRequest,
|
||||
HxBoosted(is_boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
// Normalize empty strings to None
|
||||
let category_filter = params
|
||||
.category
|
||||
@@ -101,7 +100,7 @@ pub async fn show_text_content_edit_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentEditModal {
|
||||
pub text_content: TextContent,
|
||||
@@ -127,7 +126,7 @@ pub async fn patch_text_content(
|
||||
Path(id): Path<String>,
|
||||
HxTarget(target): HxTarget,
|
||||
Form(form): Form<PatchTextContentParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
TextContent::patch(&id, &form.context, &form.category, &form.text, &state.db).await?;
|
||||
@@ -167,7 +166,7 @@ pub async fn delete_text_content(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
// Get and validate the text content
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
@@ -213,7 +212,7 @@ pub async fn show_content_read_modal(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentReadModalData {
|
||||
pub text_content: TextContent,
|
||||
@@ -231,7 +230,7 @@ pub async fn show_content_read_modal(
|
||||
pub async fn show_recent_content(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let text_contents =
|
||||
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{
|
||||
template_as_response, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
||||
},
|
||||
},
|
||||
utils::text_content_preview::truncate_text_contents,
|
||||
utils::truncate::with_ellipsis,
|
||||
@@ -37,7 +39,7 @@ pub struct IndexPageData {
|
||||
pub async fn index_handler(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let (text_contents, dashboard_stats, active_jobs) = try_join!(
|
||||
User::get_latest_text_contents(&user.id, &state.db),
|
||||
User::get_dashboard_stats(&user.id, &state.db),
|
||||
@@ -65,7 +67,7 @@ pub async fn delete_text_content(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
// Get and validate TextContent
|
||||
let text_content = get_and_validate_text_content(&state, &id, &user).await?;
|
||||
|
||||
@@ -154,7 +156,7 @@ pub async fn delete_job(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
User::validate_and_delete_job(&id, &user.id, &state.db).await?;
|
||||
|
||||
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
|
||||
@@ -169,7 +171,7 @@ pub async fn delete_job(
|
||||
pub async fn show_active_jobs(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -181,7 +183,7 @@ pub async fn show_active_jobs(
|
||||
pub async fn show_task_archive(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?;
|
||||
|
||||
let entries: Vec<TaskArchiveEntry> = tasks
|
||||
@@ -234,17 +236,17 @@ pub async fn serve_file(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
|
||||
return Ok(TemplateResponse::not_found().into_response());
|
||||
return Ok(template_as_response(TemplateResponse::not_found()));
|
||||
};
|
||||
|
||||
if file_info.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
return Ok(template_as_response(TemplateResponse::unauthorized()));
|
||||
}
|
||||
|
||||
let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
|
||||
return Ok(TemplateResponse::server_error().into_response());
|
||||
return Ok(template_as_response(TemplateResponse::server_error()));
|
||||
};
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use axum::{
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, KeepAliveStream},
|
||||
IntoResponse, Response, Sse,
|
||||
Sse,
|
||||
},
|
||||
};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
@@ -31,7 +31,7 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{TemplateResponse, TemplateResult},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -49,7 +49,7 @@ fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
|
||||
pub async fn show_ingest_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
#[derive(Serialize)]
|
||||
pub struct ShowIngestFormData {
|
||||
user_categories: Vec<String>,
|
||||
@@ -65,7 +65,7 @@ pub async fn show_ingest_form(
|
||||
|
||||
pub async fn hide_ingest_form(
|
||||
RequireUser(_user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Ok(TemplateResponse::new_template(
|
||||
"ingestion/add_content_button.html",
|
||||
(),
|
||||
@@ -91,12 +91,11 @@ pub async fn process_ingest_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
TypedMultipart(input): TypedMultipart<IngestionParams>,
|
||||
) -> Result<Response, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
|
||||
return Ok(
|
||||
TemplateResponse::bad_request("You need to either add files or content")
|
||||
.into_response(),
|
||||
);
|
||||
return Ok(TemplateResponse::bad_request(
|
||||
"You need to either add files or content",
|
||||
));
|
||||
}
|
||||
|
||||
let content_bytes = input.content.as_ref().map_or(0, String::len);
|
||||
@@ -118,11 +117,10 @@ pub async fn process_ingest_form(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"Payload Too Large",
|
||||
&message,
|
||||
)
|
||||
.into_response());
|
||||
));
|
||||
}
|
||||
Err(IngestValidationError::BadRequest(message)) => {
|
||||
return Ok(TemplateResponse::bad_request(&message).into_response());
|
||||
return Ok(TemplateResponse::bad_request(&message));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,10 +151,10 @@ pub async fn process_ingest_form(
|
||||
let tasks =
|
||||
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
|
||||
|
||||
Ok(
|
||||
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
|
||||
.into_response(),
|
||||
)
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/current_task.html",
|
||||
NewTasksData { tasks },
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
||||
@@ -31,7 +31,9 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{
|
||||
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
||||
},
|
||||
},
|
||||
utils::pagination::{paginate_items, Pagination},
|
||||
};
|
||||
@@ -120,12 +122,12 @@ fn collect_relationship_type_options(relationships: &[KnowledgeRelationship]) ->
|
||||
options
|
||||
}
|
||||
|
||||
fn respond_with_graph_refresh(response: TemplateResponse) -> Response {
|
||||
let mut response = response.into_response();
|
||||
if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) {
|
||||
response.headers_mut().insert(HX_TRIGGER, value);
|
||||
}
|
||||
response
|
||||
fn graph_refresh_response(template: TemplateResponse) -> Response {
|
||||
template_with_headers(template, |headers| {
|
||||
if let Ok(value) = HeaderValue::from_str(GRAPH_REFRESH_TRIGGER) {
|
||||
headers.insert(HX_TRIGGER, value);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
@@ -138,7 +140,7 @@ pub struct FilterParams {
|
||||
pub async fn show_new_knowledge_entity_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let entity_types: Vec<String> = KnowledgeEntityType::variants()
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
@@ -170,7 +172,7 @@ pub async fn create_knowledge_entity(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<CreateKnowledgeEntityParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
let name = form.name.trim().to_string();
|
||||
if name.is_empty() {
|
||||
return Err(AppError::Validation("name is required".into()).into());
|
||||
@@ -230,7 +232,7 @@ pub async fn create_knowledge_entity(
|
||||
|
||||
let default_params = FilterParams::default();
|
||||
let kb_data = build_knowledge_base_data(&state, &user, &default_params).await?;
|
||||
Ok(respond_with_graph_refresh(TemplateResponse::new_partial(
|
||||
Ok(graph_refresh_response(TemplateResponse::new_partial(
|
||||
"knowledge/base.html",
|
||||
"main",
|
||||
kb_data,
|
||||
@@ -241,7 +243,7 @@ pub async fn suggest_knowledge_relationships(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<SuggestRelationshipsParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let entity_lookup: HashMap<String, KnowledgeEntity> =
|
||||
User::get_knowledge_entities(&user.id, &state.db)
|
||||
.await?
|
||||
@@ -723,7 +725,7 @@ pub async fn show_knowledge_page(
|
||||
HxRequest(is_htmx): HxRequest,
|
||||
HxBoosted(is_boosted): HxBoosted,
|
||||
Query(mut params): Query<FilterParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
// Normalize filters: treat empty or "none" as no filter
|
||||
params.entity_type = normalize_filter(params.entity_type.take());
|
||||
params.content_category = normalize_filter(params.content_category.take());
|
||||
@@ -772,7 +774,7 @@ pub async fn get_knowledge_graph_json(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Query(mut params): Query<FilterParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
// Normalize filters: treat empty or "none" as no filter
|
||||
params.entity_type = normalize_filter(params.entity_type.take());
|
||||
params.content_category = normalize_filter(params.content_category.take());
|
||||
@@ -821,7 +823,7 @@ pub async fn get_knowledge_graph_json(
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(GraphData { nodes, links }))
|
||||
Ok(Json(GraphData { nodes, links }).into_response())
|
||||
}
|
||||
// Normalize filter parameters: convert empty strings or "none" (case-insensitive) to None
|
||||
fn normalize_filter(input: Option<String>) -> Option<String> {
|
||||
@@ -851,7 +853,7 @@ pub async fn show_edit_knowledge_entity_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
#[derive(Serialize)]
|
||||
pub struct EntityData {
|
||||
entity: KnowledgeEntity,
|
||||
@@ -899,7 +901,7 @@ pub async fn patch_knowledge_entity(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<PatchKnowledgeEntityParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
// Get the existing entity and validate that the user is allowed
|
||||
User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db).await?;
|
||||
|
||||
@@ -930,7 +932,7 @@ pub async fn patch_knowledge_entity(
|
||||
let content_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
// Render updated list
|
||||
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
|
||||
Ok(graph_refresh_response(TemplateResponse::new_template(
|
||||
"knowledge/entity_list.html",
|
||||
EntityListData {
|
||||
visible_entities,
|
||||
@@ -948,7 +950,7 @@ pub async fn delete_knowledge_entity(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
// Get the existing entity and validate that the user is allowed
|
||||
User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?;
|
||||
|
||||
@@ -968,7 +970,7 @@ pub async fn delete_knowledge_entity(
|
||||
// Get content categories
|
||||
let content_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
|
||||
Ok(graph_refresh_response(TemplateResponse::new_template(
|
||||
"knowledge/entity_list.html",
|
||||
EntityListData {
|
||||
visible_entities,
|
||||
@@ -994,7 +996,7 @@ pub async fn delete_knowledge_relationship(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
KnowledgeRelationship::delete_relationship_by_id(&id, &user.id, &state.db).await?;
|
||||
|
||||
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
|
||||
@@ -1003,7 +1005,7 @@ pub async fn delete_knowledge_relationship(
|
||||
let table_data = build_relationship_table_data(entities, relationships);
|
||||
|
||||
// Render updated list
|
||||
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
|
||||
Ok(graph_refresh_response(TemplateResponse::new_template(
|
||||
"knowledge/relationship_table.html",
|
||||
table_data,
|
||||
)))
|
||||
@@ -1020,7 +1022,7 @@ pub async fn save_knowledge_relationship(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<SaveKnowledgeRelationshipInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
// Construct relationship
|
||||
let relationship_type = canonicalize_relationship_type(&form.relationship_type);
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
@@ -1039,7 +1041,7 @@ pub async fn save_knowledge_relationship(
|
||||
let table_data = build_relationship_table_data(entities, relationships);
|
||||
|
||||
// Render updated list
|
||||
Ok(respond_with_graph_refresh(TemplateResponse::new_template(
|
||||
Ok(graph_refresh_response(TemplateResponse::new_template(
|
||||
"knowledge/relationship_table.html",
|
||||
table_data,
|
||||
)))
|
||||
|
||||
@@ -11,7 +11,9 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::html_state::HtmlState;
|
||||
use crate::middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{
|
||||
template_with_headers, HtmlError, TemplateResponse, TemplateResult, ResponseResult,
|
||||
},
|
||||
};
|
||||
use common::storage::types::{
|
||||
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, scratchpad::Scratchpad,
|
||||
@@ -127,7 +129,7 @@ pub async fn show_scratchpad_page(
|
||||
HxRequest(is_htmx): HxRequest,
|
||||
HxBoosted(is_boosted): HxBoosted,
|
||||
State(state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
|
||||
let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
|
||||
|
||||
@@ -165,7 +167,7 @@ pub async fn show_scratchpad_modal(
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
Query(query): Query<EditTitleQuery>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
|
||||
|
||||
let scratchpad_detail = ScratchpadDetail::from(&scratchpad);
|
||||
@@ -186,7 +188,7 @@ pub async fn create_scratchpad(
|
||||
RequireUser(user): RequireUser,
|
||||
State(state): State<HtmlState>,
|
||||
Form(form): Form<CreateScratchpadForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let user_id = user.id.clone();
|
||||
let scratchpad = Scratchpad::new(user_id.clone(), form.title);
|
||||
let _stored = state.db.store_item(scratchpad.clone()).await?;
|
||||
@@ -217,7 +219,7 @@ pub async fn auto_save_scratchpad(
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
Form(form): Form<UpdateScratchpadForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
let updated =
|
||||
Scratchpad::update_content(&scratchpad_id, &user.id, &form.content, &state.db).await?;
|
||||
|
||||
@@ -229,7 +231,8 @@ pub async fn auto_save_scratchpad(
|
||||
.format("%Y-%m-%d %H:%M:%S")
|
||||
.to_string(),
|
||||
last_saved_at_iso: updated.last_saved_at.to_rfc3339(),
|
||||
}))
|
||||
})
|
||||
.into_response())
|
||||
}
|
||||
|
||||
pub async fn update_scratchpad_title(
|
||||
@@ -237,7 +240,7 @@ pub async fn update_scratchpad_title(
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
Form(form): Form<UpdateTitleForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Scratchpad::update_title(&scratchpad_id, &user.id, &form.title, &state.db).await?;
|
||||
|
||||
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
|
||||
@@ -255,7 +258,7 @@ pub async fn delete_scratchpad(
|
||||
RequireUser(user): RequireUser,
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Scratchpad::delete(&scratchpad_id, &user.id, &state.db).await?;
|
||||
|
||||
// Return the updated main section content
|
||||
@@ -284,7 +287,7 @@ pub async fn ingest_scratchpad(
|
||||
RequireUser(user): RequireUser,
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
|
||||
|
||||
if scratchpad.content.trim().is_empty() {
|
||||
@@ -347,29 +350,29 @@ pub async fn ingest_scratchpad(
|
||||
r#"{"toast":{"title":"Ingestion queued","description":"Scratchpad archived and added to the ingestion queue.","type":"success"}}"#.to_string()
|
||||
});
|
||||
|
||||
let template_response = TemplateResponse::new_partial(
|
||||
"scratchpad/base.html",
|
||||
"main",
|
||||
ScratchpadPageData {
|
||||
scratchpads: scratchpad_list,
|
||||
archived_scratchpads: archived_list,
|
||||
new_scratchpad: None,
|
||||
Ok(template_with_headers(
|
||||
TemplateResponse::new_partial(
|
||||
"scratchpad/base.html",
|
||||
"main",
|
||||
ScratchpadPageData {
|
||||
scratchpads: scratchpad_list,
|
||||
archived_scratchpads: archived_list,
|
||||
new_scratchpad: None,
|
||||
},
|
||||
),
|
||||
|headers| {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
|
||||
headers.insert(HX_TRIGGER, header_value);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let mut response = template_response.into_response();
|
||||
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
|
||||
response.headers_mut().insert(HX_TRIGGER, header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn archive_scratchpad(
|
||||
RequireUser(user): RequireUser,
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
Scratchpad::archive(&scratchpad_id, &user.id, &state.db, false).await?;
|
||||
|
||||
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
|
||||
@@ -396,7 +399,7 @@ pub async fn restore_scratchpad(
|
||||
RequireUser(user): RequireUser,
|
||||
State(state): State<HtmlState>,
|
||||
Path(scratchpad_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> ResponseResult {
|
||||
Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?;
|
||||
|
||||
let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
|
||||
@@ -420,22 +423,22 @@ pub async fn restore_scratchpad(
|
||||
r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string()
|
||||
});
|
||||
|
||||
let template_response = TemplateResponse::new_partial(
|
||||
"scratchpad/base.html",
|
||||
"main",
|
||||
ScratchpadPageData {
|
||||
scratchpads: scratchpad_list,
|
||||
archived_scratchpads: archived_list,
|
||||
new_scratchpad: None,
|
||||
Ok(template_with_headers(
|
||||
TemplateResponse::new_partial(
|
||||
"scratchpad/base.html",
|
||||
"main",
|
||||
ScratchpadPageData {
|
||||
scratchpads: scratchpad_list,
|
||||
archived_scratchpads: archived_list,
|
||||
new_scratchpad: None,
|
||||
},
|
||||
),
|
||||
|headers| {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
|
||||
headers.insert(HX_TRIGGER, header_value);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let mut response = template_response.into_response();
|
||||
if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
|
||||
response.headers_mut().insert(HX_TRIGGER, header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -2,7 +2,6 @@ use std::collections::HashSet;
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use common::storage::types::{text_content::TextContent, user::User};
|
||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
||||
@@ -13,7 +12,7 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
response_middleware::{HtmlError, TemplateResponse, TemplateResult},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -79,7 +78,7 @@ pub async fn search_result_handler(
|
||||
State(state): State<HtmlState>,
|
||||
Query(params): Query<SearchParams>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
) -> TemplateResult {
|
||||
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
|
||||
params.query
|
||||
{
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::{to_bytes, Body},
|
||||
http::{header, Request, StatusCode},
|
||||
response::Response,
|
||||
Router,
|
||||
};
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
store::StorageManager,
|
||||
types::user::User,
|
||||
},
|
||||
utils::{
|
||||
config::{AppConfig, StorageKind},
|
||||
embedding::EmbeddingProvider,
|
||||
},
|
||||
};
|
||||
use html_router::{
|
||||
html_routes,
|
||||
html_state::{HtmlState, StateResources},
|
||||
};
|
||||
use tower::ServiceExt;
|
||||
|
||||
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
|
||||
let namespace = "html_router_test";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = Arc::new(
|
||||
SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db"),
|
||||
);
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should apply");
|
||||
|
||||
let session_store = Arc::new(
|
||||
db.create_session_store()
|
||||
.await
|
||||
.expect("session store"),
|
||||
);
|
||||
|
||||
let config = AppConfig {
|
||||
storage: StorageKind::Memory,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let storage = StorageManager::new(&config)
|
||||
.await
|
||||
.expect("storage manager");
|
||||
|
||||
let embedding_provider = Arc::new(
|
||||
EmbeddingProvider::new_hashed(8).expect("embedding provider"),
|
||||
);
|
||||
|
||||
let state = HtmlState::new_with_resources(StateResources {
|
||||
db: Arc::clone(&db),
|
||||
openai_client: Arc::new(async_openai::Client::new()),
|
||||
session_store,
|
||||
storage,
|
||||
config,
|
||||
reranker_pool: None,
|
||||
embedding_provider,
|
||||
template_engine: None,
|
||||
});
|
||||
|
||||
let router = html_routes(&state).with_state(state);
|
||||
(router, db)
|
||||
}
|
||||
|
||||
fn redirect_location(response: &Response) -> String {
|
||||
response
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
.or_else(|| response.headers().get("HX-Redirect"))
|
||||
.expect("redirect response should include Location or HX-Redirect")
|
||||
.to_str()
|
||||
.expect("redirect header must be utf-8")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn session_cookie(response: &Response) -> String {
|
||||
response
|
||||
.headers()
|
||||
.get_all(header::SET_COOKIE)
|
||||
.iter()
|
||||
.map(|value| {
|
||||
value
|
||||
.to_str()
|
||||
.expect("set-cookie must be utf-8")
|
||||
.split(';')
|
||||
.next()
|
||||
.expect("cookie key=value")
|
||||
.to_string()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("; ")
|
||||
}
|
||||
|
||||
async fn response_body(response: Response) -> String {
|
||||
let body = to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("response body");
|
||||
String::from_utf8(body.to_vec()).expect("html body")
|
||||
}
|
||||
|
||||
async fn sign_in(app: &Router, email: &str, password: &str) -> String {
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/signin")
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(Body::from(format!("email={email}&password={password}")))
|
||||
.expect("signin request"),
|
||||
)
|
||||
.await
|
||||
.expect("signin response");
|
||||
|
||||
assert!(
|
||||
response.status().is_redirection() || response.status() == StatusCode::OK,
|
||||
"signin should redirect or return ok"
|
||||
);
|
||||
session_cookie(&response)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn protected_route_redirects_unauthenticated_users() {
|
||||
let (app, _db) = build_test_app().await;
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
.body(Body::empty())
|
||||
.expect("dashboard request"),
|
||||
)
|
||||
.await
|
||||
.expect("dashboard response");
|
||||
|
||||
assert!(
|
||||
response.status().is_redirection() || response.status() == StatusCode::OK,
|
||||
"unauthenticated access should redirect via template middleware"
|
||||
);
|
||||
assert_eq!(redirect_location(&response), "/signin");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn authenticated_user_receives_rendered_dashboard() {
|
||||
let (app, db) = build_test_app().await;
|
||||
|
||||
User::create_new(
|
||||
"router_test@example.com".to_string(),
|
||||
"test_password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("test user");
|
||||
|
||||
let cookie = sign_in(&app, "router_test@example.com", "test_password").await;
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
.header(header::COOKIE, cookie)
|
||||
.body(Body::empty())
|
||||
.expect("authenticated dashboard request"),
|
||||
)
|
||||
.await
|
||||
.expect("authenticated dashboard response");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("response body");
|
||||
let html = String::from_utf8(body.to_vec()).expect("html body");
|
||||
assert!(
|
||||
html.contains("dashboard") || html.contains("Dashboard"),
|
||||
"dashboard template should render html"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn signin_form_is_public() {
|
||||
let (app, _db) = build_test_app().await;
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/signin")
|
||||
.body(Body::empty())
|
||||
.expect("signin form request"),
|
||||
)
|
||||
.await
|
||||
.expect("signin form response");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let html = response_body(response).await;
|
||||
assert!(
|
||||
html.contains("signin") || html.contains("Sign in") || html.contains("email"),
|
||||
"signin page should render a form"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn signin_rejects_invalid_credentials() {
|
||||
let (app, db) = build_test_app().await;
|
||||
|
||||
User::create_new(
|
||||
"signin_test@example.com".to_string(),
|
||||
"correct_password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("test user");
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/signin")
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(Body::from(
|
||||
"email=signin_test@example.com&password=wrong_password",
|
||||
))
|
||||
.expect("invalid signin request"),
|
||||
)
|
||||
.await
|
||||
.expect("invalid signin response");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
let html = response_body(response).await;
|
||||
assert!(
|
||||
html.contains("Incorrect email or password"),
|
||||
"signin failure should render a safe error message"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn admin_route_redirects_non_admin_user() {
|
||||
let (app, db) = build_test_app().await;
|
||||
|
||||
User::create_new(
|
||||
"admin_user@example.com".to_string(),
|
||||
"admin_password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("admin user");
|
||||
|
||||
User::create_new(
|
||||
"member_user@example.com".to_string(),
|
||||
"member_password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("member user");
|
||||
|
||||
let member_cookie = sign_in(&app, "member_user@example.com", "member_password").await;
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin")
|
||||
.header(header::COOKIE, member_cookie)
|
||||
.body(Body::empty())
|
||||
.expect("non-admin admin request"),
|
||||
)
|
||||
.await
|
||||
.expect("non-admin admin response");
|
||||
|
||||
assert!(
|
||||
response.status().is_redirection() || response.status() == StatusCode::OK,
|
||||
"non-admin should be redirected away from admin"
|
||||
);
|
||||
assert_eq!(redirect_location(&response), "/");
|
||||
|
||||
let admin_cookie = sign_in(&app, "admin_user@example.com", "admin_password").await;
|
||||
let admin_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin")
|
||||
.header(header::COOKIE, admin_cookie)
|
||||
.body(Body::empty())
|
||||
.expect("admin request"),
|
||||
)
|
||||
.await
|
||||
.expect("admin response");
|
||||
|
||||
assert_eq!(admin_response.status(), StatusCode::OK);
|
||||
}
|
||||
Reference in New Issue
Block a user