diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index 426fe8f..e6e284e 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -1,9 +1,13 @@ +use common::storage::types::conversation::Conversation; use common::storage::{db::SurrealDbClient, store::StorageManager}; use common::utils::embedding::EmbeddingProvider; use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine}; use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig}; use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy}; +use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; use tracing::debug; use crate::{OpenAIClientType, SessionStoreType}; @@ -18,8 +22,17 @@ pub struct HtmlState { pub storage: StorageManager, pub reranker_pool: Option>, pub embedding_provider: Arc, + conversation_archive_cache: Arc>>, } +#[derive(Clone)] +struct ConversationArchiveCacheEntry { + conversations: Vec, + expires_at: Instant, +} + +const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30); + impl HtmlState { pub async fn new_with_resources( db: Arc, @@ -44,6 +57,7 @@ impl HtmlState { storage, reranker_pool, embedding_provider, + conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())), }) } @@ -54,6 +68,38 @@ impl HtmlState { .and_then(|value| value.parse().ok()) .unwrap_or(RetrievalStrategy::Default) } + + pub async fn get_cached_conversation_archive( + &self, + user_id: &str, + ) -> Option> { + let cache = self.conversation_archive_cache.read().await; + let entry = cache.get(user_id)?; + if entry.expires_at <= Instant::now() { + return None; + } + Some(entry.conversations.clone()) + } + + pub async fn set_cached_conversation_archive( + &self, + user_id: &str, + conversations: Vec, + ) { + let mut cache = self.conversation_archive_cache.write().await; + cache.insert( + user_id.to_string(), + ConversationArchiveCacheEntry { + conversations, + expires_at: Instant::now() + CONVERSATION_ARCHIVE_CACHE_TTL, + }, + ); + } + + pub async fn invalidate_conversation_archive_cache(&self, user_id: &str) { + let mut cache = self.conversation_archive_cache.write().await; + cache.remove(user_id); + } } impl ProvidesDb for HtmlState { fn db(&self) -> &Arc { diff --git a/html-router/src/middlewares/response_middleware.rs b/html-router/src/middlewares/response_middleware.rs index b7db176..7269a9b 100644 --- a/html-router/src/middlewares/response_middleware.rs +++ b/html-router/src/middlewares/response_middleware.rs @@ -27,7 +27,7 @@ pub trait ProvidesHtmlState { fn html_state(&self) -> &HtmlState; } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum TemplateKind { Full(String), Partial(String, String), @@ -114,12 +114,33 @@ impl IntoResponse for TemplateResponse { } } +#[derive(Serialize)] +struct TemplateUser { + id: String, + email: String, + admin: bool, + timezone: String, + theme: String, +} + +impl From<&User> for TemplateUser { + fn from(user: &User) -> Self { + Self { + id: user.id.clone(), + email: user.email.clone(), + admin: user.admin, + timezone: user.timezone.clone(), + theme: user.theme.as_str().to_string(), + } + } +} + #[derive(Serialize)] struct ContextWrapper<'a> { user_theme: &'a str, initial_theme: &'a str, is_authenticated: bool, - user: Option<&'a User>, + user: Option<&'a TemplateUser>, conversation_archive: Vec, #[serde(flatten)] context: HashMap, @@ -138,6 +159,7 @@ where let mut initial_theme = Theme::System.initial_theme(); let mut is_authenticated = false; let mut current_user_id = None; + let mut current_user = None; { if let Some(auth) = req.extensions().get::() { @@ -146,6 +168,7 @@ where current_user_id = Some(user.id.clone()); user_theme = user.theme.as_str(); initial_theme = user.theme.initial_theme(); + current_user = Some(TemplateUser::from(user)); } } } @@ -158,17 +181,48 @@ where if let Some(template_response) = response.extensions().get::().cloned() { let template_engine = state.template_engine(); - let mut current_user = None; let mut conversation_archive = Vec::new(); - if let Some(user_id) = current_user_id { - let html_state = state.html_state(); - if let Ok(Some(user)) = html_state.db.get_item::(&user_id).await { - // Fetch conversation archive globally for authenticated users - if let Ok(archive) = User::get_user_conversations(&user.id, &html_state.db).await { + let should_load_conversation_archive = + matches!(&template_response.template_kind, TemplateKind::Full(_)); + + if should_load_conversation_archive { + if let Some(user_id) = current_user_id { + let html_state = state.html_state(); + if let Some(cached_archive) = + html_state.get_cached_conversation_archive(&user_id).await + { + conversation_archive = cached_archive; + } else if let Ok(archive) = + User::get_user_conversations(&user_id, &html_state.db).await + { + html_state + .set_cached_conversation_archive(&user_id, archive.clone()) + .await; conversation_archive = archive; } - current_user = Some(user); + } + } + + fn context_to_map( + value: &Value, + ) -> Result, minijinja::value::ValueKind> { + match value.kind() { + minijinja::value::ValueKind::Map => { + let mut map = HashMap::new(); + if let Ok(keys) = value.try_iter() { + for key in keys { + if let Ok(val) = value.get_item(&key) { + map.insert(key.to_string(), val); + } + } + } + Ok(map) + } + minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => { + Ok(HashMap::new()) + } + other => Err(other), } } @@ -183,19 +237,15 @@ where } } - // Convert minijinja::Value to HashMap if it's a map, otherwise use empty HashMap - let context_map = if template_response.context.kind() == minijinja::value::ValueKind::Map { - let mut map = HashMap::new(); - if let Ok(keys) = template_response.context.try_iter() { - for key in keys { - if let Ok(val) = template_response.context.get_item(&key) { - map.insert(key.to_string(), val); - } - } + let context_map = match context_to_map(&template_response.context) { + Ok(map) => map, + Err(kind) => { + error!( + "Template context must be a map or unit, got kind={:?} for template_kind={:?}", + kind, template_response.template_kind + ); + return (StatusCode::INTERNAL_SERVER_ERROR, Html(fallback_error())).into_response(); } - map - } else { - HashMap::new() }; let context = ContextWrapper { diff --git a/html-router/src/routes/account/handlers.rs b/html-router/src/routes/account/handlers.rs index c0ccae8..99100fa 100644 --- a/html-router/src/routes/account/handlers.rs +++ b/html-router/src/routes/account/handlers.rs @@ -17,10 +17,16 @@ use crate::html_state::HtmlState; pub struct AccountPageData { timezones: Vec, theme_options: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + api_key: Option, + #[serde(skip_serializing_if = "Option::is_none")] + selected_timezone: Option, + #[serde(skip_serializing_if = "Option::is_none")] + selected_theme: Option, } pub async fn show_account_page( - RequireUser(_user): RequireUser, + RequireUser(user): RequireUser, State(_state): State, ) -> Result { let timezones = TZ_VARIANTS @@ -40,6 +46,9 @@ pub async fn show_account_page( AccountPageData { timezones, theme_options, + api_key: user.api_key, + selected_timezone: None, + selected_theme: None, }, )) } @@ -50,7 +59,7 @@ pub async fn set_api_key( auth: AuthSessionType, ) -> Result { // Generate and set the API key - User::set_api_key(&user.id, &state.db).await?; + let api_key = User::set_api_key(&user.id, &state.db).await?; // Clear the cache so new requests have access to the user with api key auth.cache_clear_user(user.id.to_string()); @@ -62,6 +71,9 @@ pub async fn set_api_key( AccountPageData { timezones: vec![], theme_options: vec![], + api_key: Some(api_key), + selected_timezone: None, + selected_theme: None, }, )) } @@ -108,6 +120,9 @@ pub async fn update_timezone( AccountPageData { timezones, theme_options: vec![], + api_key: None, + selected_timezone: Some(form.timezone), + selected_theme: None, }, )) } @@ -142,6 +157,9 @@ pub async fn update_theme( AccountPageData { timezones: vec![], theme_options, + api_key: None, + selected_timezone: None, + selected_theme: Some(form.theme), }, )) } diff --git a/html-router/src/routes/auth/signin.rs b/html-router/src/routes/auth/signin.rs index edb7237..6952308 100644 --- a/html-router/src/routes/auth/signin.rs +++ b/html-router/src/routes/auth/signin.rs @@ -1,8 +1,4 @@ -use axum::{ - extract::State, - response::{Html, IntoResponse}, - Form, -}; +use axum::{extract::State, response::IntoResponse, Form}; use axum_htmx::HxBoosted; use serde::{Deserialize, Serialize}; @@ -46,7 +42,7 @@ pub async fn authenticate_user( let user = match User::authenticate(&form.email, &form.password, &state.db).await { Ok(user) => user, Err(_) => { - return Ok(Html("

Incorrect email or password

").into_response()); + return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response()); } }; diff --git a/html-router/src/routes/auth/signup.rs b/html-router/src/routes/auth/signup.rs index 4b4d0e1..ba0a538 100644 --- a/html-router/src/routes/auth/signup.rs +++ b/html-router/src/routes/auth/signup.rs @@ -1,8 +1,4 @@ -use axum::{ - extract::State, - response::{Html, IntoResponse}, - Form, -}; +use axum::{extract::State, response::IntoResponse, Form}; use axum_htmx::HxBoosted; use serde::{Deserialize, Serialize}; @@ -57,7 +53,7 @@ pub async fn process_signup_and_show_verification( Ok(user) => user, Err(e) => { tracing::error!("{:?}", e); - return Ok(Html(format!("

{e}

")).into_response()); + return Ok(TemplateResponse::bad_request(&e.to_string()).into_response()); } }; diff --git a/html-router/src/routes/chat/chat_handlers.rs b/html-router/src/routes/chat/chat_handlers.rs index 07342d1..c45ec38 100644 --- a/html-router/src/routes/chat/chat_handlers.rs +++ b/html-router/src/routes/chat/chat_handlers.rs @@ -73,6 +73,7 @@ pub async fn show_initialized_chat( state.db.store_item(conversation.clone()).await?; state.db.store_item(ai_message.clone()).await?; state.db.store_item(user_message.clone()).await?; + state.invalidate_conversation_archive_cache(&user.id).await; let messages = vec![user_message, ai_message]; @@ -178,7 +179,7 @@ pub async fn new_chat_user_message( None => return Ok(Redirect::to("/").into_response()), }; - let conversation = Conversation::new(user.id, "New chat".to_string()); + let conversation = Conversation::new(user.id.clone(), "New chat".to_string()); let user_message = Message::new( conversation.id.clone(), MessageRole::User, @@ -188,6 +189,7 @@ pub async fn new_chat_user_message( state.db.store_item(conversation.clone()).await?; state.db.store_item(user_message.clone()).await?; + state.invalidate_conversation_archive_cache(&user.id).await; #[derive(Serialize)] struct SSEResponseInitData { @@ -252,6 +254,7 @@ pub async fn patch_conversation_title( Form(form): Form, ) -> Result { Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?; + state.invalidate_conversation_archive_cache(&user.id).await; Ok(TemplateResponse::new_template( "sidebar.html", @@ -281,6 +284,7 @@ pub async fn delete_conversation( .db .delete_item::(&conversation_id) .await?; + state.invalidate_conversation_archive_cache(&user.id).await; Ok(TemplateResponse::new_template( "sidebar.html", diff --git a/html-router/src/routes/ingestion/handlers.rs b/html-router/src/routes/ingestion/handlers.rs index a037745..9a8944d 100644 --- a/html-router/src/routes/ingestion/handlers.rs +++ b/html-router/src/routes/ingestion/handlers.rs @@ -5,7 +5,7 @@ use axum::{ http::StatusCode, response::{ sse::{Event, KeepAlive}, - Html, IntoResponse, Response, Sse, + IntoResponse, Response, Sse, }, }; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; @@ -56,12 +56,10 @@ pub async fn show_ingest_form( pub async fn hide_ingest_form( RequireUser(_user): RequireUser, ) -> Result { - Ok( - Html( - "Add Content", - ) - .into_response(), - ) + Ok(TemplateResponse::new_template( + "ingestion/add_content_button.html", + (), + )) } #[derive(Debug, TryFromMultipart)] @@ -80,11 +78,10 @@ pub async fn process_ingest_form( TypedMultipart(input): TypedMultipart, ) -> Result { if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() { - return Ok(( - StatusCode::BAD_REQUEST, - "You need to either add files or content", - ) - .into_response()); + return Ok( + TemplateResponse::bad_request("You need to either add files or content") + .into_response(), + ); } let content_bytes = input.content.as_ref().map_or(0, |c| c.len()); @@ -102,10 +99,15 @@ pub async fn process_ingest_form( ) { Ok(()) => {} Err(IngestValidationError::PayloadTooLarge(message)) => { - return Ok((StatusCode::PAYLOAD_TOO_LARGE, message).into_response()); + return Ok(TemplateResponse::error( + StatusCode::PAYLOAD_TOO_LARGE, + "Payload Too Large", + &message, + ) + .into_response()); } Err(IngestValidationError::BadRequest(message)) => { - return Ok((StatusCode::BAD_REQUEST, message).into_response()); + return Ok(TemplateResponse::bad_request(&message).into_response()); } } diff --git a/html-router/templates/auth/_account_settings_core.html b/html-router/templates/auth/_account_settings_core.html index dc159d0..c40de49 100644 --- a/html-router/templates/auth/_account_settings_core.html +++ b/html-router/templates/auth/_account_settings_core.html @@ -13,9 +13,9 @@