chore: improve html-router auth, caching, and analytics while centralizing search labels in common.

small fix
This commit is contained in:
Per Stark
2026-05-29 14:42:20 +02:00
parent d3443d4153
commit 2aa92b6ad7
27 changed files with 510 additions and 344 deletions
+3 -78
View File
@@ -1,20 +1,16 @@
use axum::{
extract::{Path, State},
http::HeaderValue,
response::{IntoResponse, Redirect},
response::IntoResponse,
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::AppError,
storage::types::{
conversation::Conversation,
message::{Message, MessageRole},
user::User,
},
};
@@ -26,75 +22,12 @@ use crate::{
},
};
#[derive(Debug, Deserialize)]
pub struct ChatStartParams {
user_query: String,
llm_response: String,
#[serde(deserialize_with = "deserialize_references")]
references: Vec<String>,
}
// Custom deserializer function
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(serde::de::Error::custom)
}
#[derive(Serialize)]
pub struct ChatPageData {
history: Vec<Message>,
conversation: Option<Conversation>,
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn show_initialized_chat(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<ChatStartParams>,
) -> Result<impl IntoResponse, HtmlError> {
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
form.user_query,
None,
);
let ai_message = Message::new(
conversation.id.clone(),
MessageRole::AI,
form.llm_response,
Some(form.references),
);
state.db.store_item(conversation.clone()).await?;
state.db.store_item(ai_message.clone()).await?;
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
let messages = vec![user_message, ai_message];
let mut response = TemplateResponse::new_template(
"chat/base.html",
ChatPageData {
history: messages,
conversation: Some(conversation.clone()),
},
)
.into_response();
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
}
pub async fn show_chat_base(
State(_state): State<HtmlState>,
RequireUser(_user): RequireUser,
@@ -131,8 +64,6 @@ pub async fn show_existing_chat(
))
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_user_message(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
@@ -171,11 +102,9 @@ pub async fn new_user_message(
Ok(response)
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_chat_user_message(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
@@ -184,10 +113,6 @@ pub async fn new_chat_user_message(
conversation: Conversation,
}
let Some(user) = auth.current_user else {
return Ok(Redirect::to("/").into_response());
};
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
let user_message = Message::new(
conversation.id.clone(),
@@ -213,7 +138,7 @@ pub async fn new_chat_user_message(
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response.into_response())
Ok(response)
}
#[derive(Deserialize)]
@@ -25,7 +25,7 @@ use retrieval_pipeline::{
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::sync::Mutex;
use tokio::sync::Mutex;
use tokio::sync::mpsc::channel;
use tracing::{debug, error, info};
@@ -39,7 +39,10 @@ use common::storage::{
},
};
use crate::{html_state::HtmlState, AuthSessionType};
use crate::{
html_state::HtmlState,
middlewares::auth_middleware::RequireUser,
};
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
@@ -61,15 +64,9 @@ fn create_error_stream(message: impl Into<String>) -> EventStream {
async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
user: User,
message_id: &str,
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
let Some(user) = current_user else {
return Err(sse_with_keep_alive(create_error_stream(
"You must be signed in to use this feature",
)));
};
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
@@ -136,7 +133,7 @@ fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
messages
.iter()
.skip(user_message_index + 1)
.skip(user_message_index.saturating_add(1))
.take_while(|message| message.role != MessageRole::User)
.find(|message| message.role == MessageRole::AI)
.cloned()
@@ -202,11 +199,11 @@ fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
#[allow(clippy::too_many_lines)]
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
RequireUser(user): RequireUser,
Query(params): Query<QueryParams>,
) -> SseResponse {
let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
match get_message_and_user(&state.db, user, &params.message_id).await {
Ok((user_message, user, conversation, history, existing_ai_response)) => (
user_message,
user,
@@ -289,7 +286,7 @@ fn build_chat_event_stream(
let _ = tx_storage.send(content.clone()).await;
let display_content = {
let mut state = json_state.lock().expect("json parser mutex poisoned");
let mut state = json_state.lock().await;
state.process_chunk(&content)
};
if !display_content.is_empty() {
@@ -540,6 +537,8 @@ impl StreamParserState {
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used, clippy::missing_docs_in_private_items)]
use super::*;
use chrono::{Duration as ChronoDuration, Utc};
use common::storage::{
@@ -707,7 +706,7 @@ mod tests {
.expect("failed to store second user message");
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
get_message_and_user(&db, Some(user.clone()), &user_message.id)
get_message_and_user(&db, user.clone(), &user_message.id)
.await
.expect("expected first turn to load");
@@ -717,7 +716,7 @@ mod tests {
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
get_message_and_user(&db, Some(user), &second_user_message.id)
get_message_and_user(&db, user, &second_user_message.id)
.await
.expect("expected second turn to load");
+1 -3
View File
@@ -5,14 +5,13 @@ mod references;
use axum::{
extract::FromRef,
routing::{get, post},
routing::get,
Router,
};
pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_conversation_editing_title,
show_chat_base as show_base, show_existing_chat as show_existing,
show_initialized_chat as show_initialized,
};
use message_response_stream::get_response_stream;
use references::show_reference_tooltip;
@@ -37,7 +36,6 @@ where
get(show_conversation_editing_title).patch(patch_conversation_title),
)
.route("/chat/sidebar", get(reload_sidebar))
.route("/initialized-chat", post(show_initialized))
.route("/chat/response-stream", get(get_response_stream))
.route("/chat/reference/{id}", get(show_reference_tooltip))
}
@@ -51,21 +51,21 @@ impl ReferenceReasonStats {
match reason {
InvalidReferenceReason::Empty => self.empty = self.empty.saturating_add(1),
InvalidReferenceReason::UnsupportedPrefix => {
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1)
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1);
}
InvalidReferenceReason::MalformedUuid => {
self.malformed_uuid = self.malformed_uuid.saturating_add(1)
self.malformed_uuid = self.malformed_uuid.saturating_add(1);
}
InvalidReferenceReason::Duplicate => self.duplicate = self.duplicate.saturating_add(1),
InvalidReferenceReason::NotInContext => {
self.not_in_context = self.not_in_context.saturating_add(1)
self.not_in_context = self.not_in_context.saturating_add(1);
}
InvalidReferenceReason::NotFound => self.not_found = self.not_found.saturating_add(1),
InvalidReferenceReason::WrongUser => {
self.wrong_user = self.wrong_user.saturating_add(1)
self.wrong_user = self.wrong_user.saturating_add(1);
}
InvalidReferenceReason::OverLimit => {
self.over_limit = self.over_limit.saturating_add(1)
self.over_limit = self.over_limit.saturating_add(1);
}
}
}