diff --git a/CHANGELOG.md b/CHANGELOG.md index 0335d70..0de8473 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +## Unreleased +- Fix: edge case where navigation back to a chat page could trigger a new response generation + ## 1.0.1 (2026-02-11) - Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments. - Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling. diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index b1f42d6..84a9e55 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -56,7 +56,7 @@ async fn get_message_and_user( current_user: Option, message_id: &str, ) -> Result< - (Message, User, Conversation, Vec), + (Message, User, Conversation, Vec, Option), Sse> + Send>>>, > { // Check authentication @@ -83,7 +83,7 @@ async fn get_message_and_user( }; // Get conversation history - let (conversation, mut history) = + let (conversation, history) = match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await { Err(e) => { @@ -95,10 +95,95 @@ async fn get_message_and_user( Ok((conversation, history)) => (conversation, history), }; - // Remove the last message, its the same as the message - history.pop(); + let Some(message_index) = find_message_index(&history, message_id) else { + return Err(Sse::new(create_error_stream( + "Message not found in conversation history", + ))); + }; - Ok((message, user, conversation, history)) + let Some(message_from_history) = history.get(message_index) else { + return Err(Sse::new(create_error_stream( + "Message not found in conversation history", + ))); + }; + + if message_from_history.role != MessageRole::User { + return Err(Sse::new(create_error_stream( + "Only user messages can be used to generate a response", + ))); + } + + let message = message_from_history.clone(); + + let history_before_message = history_before_message(&history, message_index); + let existing_ai_response = find_existing_ai_response(&history, message_index); + + Ok(( + message, + user, + conversation, + history_before_message, + existing_ai_response, + )) +} + +fn find_message_index(messages: &[Message], message_id: &str) -> Option { + messages.iter().position(|message| message.id == message_id) +} + +fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option { + messages + .iter() + .skip(user_message_index + 1) + .take_while(|message| message.role != MessageRole::User) + .find(|message| message.role == MessageRole::AI) + .cloned() +} + +fn history_before_message(messages: &[Message], message_index: usize) -> Vec { + messages.iter().take(message_index).cloned().collect() +} + +fn create_replayed_response_stream( + state: &HtmlState, + existing_ai_message: Message, +) -> Sse> + Send>>> { + let references_event = if existing_ai_message + .references + .as_ref() + .is_some_and(|references| !references.is_empty()) + { + state + .templates + .render( + "chat/reference_list.html", + &Value::from_serialize(ReferenceData { + message: existing_ai_message.clone(), + }), + ) + .ok() + .map(|html| Event::default().event("references").data(html)) + } else { + None + }; + + let answer = existing_ai_message.content; + + let event_stream = stream! { + yield Ok(Event::default().event("chat_message").data(answer)); + + if let Some(event) = references_event { + yield Ok(event); + } + + yield Ok(Event::default().event("close_stream").data("Stream complete")); + }; + + Sse::new(event_stream.boxed()).keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive"), + ) } #[derive(Deserialize)] @@ -123,18 +208,25 @@ fn extract_reference_strings(response: &LLMResponseFormat) -> Vec { pub async fn get_response_stream( State(state): State, auth: AuthSessionType, - // auth: AuthSession, Surreal>, Query(params): Query, ) -> Sse> + Send>>> { // 1. Authentication and initial data validation - let (user_message, user, _conversation, history) = + let (user_message, user, _conversation, history, existing_ai_response) = match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await { - Ok((user_message, user, conversation, history)) => { - (user_message, user, conversation, history) - } + Ok((user_message, user, conversation, history, existing_ai_response)) => ( + user_message, + user, + conversation, + history, + existing_ai_response, + ), Err(error_stream) => return error_stream, }; + if let Some(existing_ai_message) = existing_ai_response { + return create_replayed_response_stream(&state, existing_ai_message); + } + // 2. Retrieve knowledge entities let rerank_lease = match state.reranker_pool.as_ref() { Some(pool) => Some(pool.checkout().await), @@ -424,7 +516,39 @@ impl StreamParserState { #[cfg(test)] mod tests { use super::*; + use chrono::{Duration as ChronoDuration, Utc}; + use common::storage::{ + db::SurrealDbClient, + types::{conversation::Conversation, user::Theme}, + }; use retrieval_pipeline::answer_retrieval::Reference; + use uuid::Uuid; + + fn make_test_message(id: &str, role: MessageRole) -> Message { + let mut message = Message::new( + "conversation-1".to_string(), + role, + format!("content-{id}"), + None, + ); + message.id = id.to_string(); + message + } + + fn make_test_user(id: &str) -> User { + User { + id: id.to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + email: "test@example.com".to_string(), + password: "password".to_string(), + anonymous: false, + api_key: None, + admin: false, + timezone: "UTC".to_string(), + theme: Theme::System, + } + } #[test] fn extracts_reference_strings_in_order() { @@ -443,4 +567,140 @@ mod tests { let extracted = extract_reference_strings(&response); assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]); } + + #[test] + fn finds_message_index_for_existing_message() { + let messages = vec![ + make_test_message("m1", MessageRole::User), + make_test_message("m2", MessageRole::AI), + make_test_message("m3", MessageRole::User), + ]; + + assert_eq!(find_message_index(&messages, "m2"), Some(1)); + assert_eq!(find_message_index(&messages, "missing"), None); + } + + #[test] + fn finds_existing_ai_response_for_same_turn() { + let messages = vec![ + make_test_message("u1", MessageRole::User), + make_test_message("system", MessageRole::System), + make_test_message("a1", MessageRole::AI), + make_test_message("u2", MessageRole::User), + make_test_message("a2", MessageRole::AI), + ]; + + let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response"); + assert_eq!(ai_reply.id, "a1"); + + let ai_reply_second_turn = + find_existing_ai_response(&messages, 3).expect("expected AI response"); + assert_eq!(ai_reply_second_turn.id, "a2"); + } + + #[test] + fn does_not_replay_ai_response_from_later_turn() { + let messages = vec![ + make_test_message("u1", MessageRole::User), + make_test_message("u2", MessageRole::User), + make_test_message("a2", MessageRole::AI), + ]; + + assert!(find_existing_ai_response(&messages, 0).is_none()); + + let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response"); + assert_eq!(ai_reply.id, "a2"); + } + + #[test] + fn history_before_message_excludes_target_and_future_messages() { + let messages = vec![ + make_test_message("u1", MessageRole::User), + make_test_message("a1", MessageRole::AI), + make_test_message("u2", MessageRole::User), + make_test_message("a2", MessageRole::AI), + ]; + + let history_for_u2 = history_before_message(&messages, 2); + let history_ids: Vec = history_for_u2 + .into_iter() + .map(|message| message.id) + .collect(); + assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]); + } + + #[tokio::test] + async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() { + let namespace = "chat_stream_replay"; + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, &database) + .await + .expect("failed to create in-memory db"); + + let user = make_test_user("user-1"); + let conversation = Conversation::new(user.id.clone(), "Conversation".to_string()); + + let mut user_message = Message::new( + conversation.id.clone(), + MessageRole::User, + "Question one".to_string(), + None, + ); + user_message.id = "u1".to_string(); + + let mut ai_message = Message::new( + conversation.id.clone(), + MessageRole::AI, + "Answer one".to_string(), + Some(vec!["ref-1".to_string()]), + ); + ai_message.id = "a1".to_string(); + ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1); + ai_message.updated_at = ai_message.created_at; + + let mut second_user_message = Message::new( + conversation.id.clone(), + MessageRole::User, + "Question two".to_string(), + None, + ); + second_user_message.id = "u2".to_string(); + second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1); + second_user_message.updated_at = second_user_message.created_at; + + db.store_item(conversation.clone()) + .await + .expect("failed to store conversation"); + db.store_item(user_message.clone()) + .await + .expect("failed to store user message"); + db.store_item(ai_message.clone()) + .await + .expect("failed to store ai message"); + db.store_item(second_user_message.clone()) + .await + .expect("failed to store second user message"); + + let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) = + get_message_and_user(&db, Some(user.clone()), &user_message.id) + .await + .expect("expected first turn to load"); + + assert!(history_for_first_turn.is_empty()); + let existing_ai_for_first_turn = + existing_ai_for_first_turn.expect("expected first-turn AI response"); + assert_eq!(existing_ai_for_first_turn.id, ai_message.id); + + let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) = + get_message_and_user(&db, Some(user), &second_user_message.id) + .await + .expect("expected second turn to load"); + + let history_ids: Vec = history_for_second_turn + .into_iter() + .map(|message| message.id) + .collect(); + assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]); + assert!(existing_ai_for_second_turn.is_none()); + } } diff --git a/html-router/templates/chat/streaming_response.html b/html-router/templates/chat/streaming_response.html index b60d60f..c42bc5c 100644 --- a/html-router/templates/chat/streaming_response.html +++ b/html-router/templates/chat/streaming_response.html @@ -4,7 +4,8 @@
-
@@ -27,13 +28,22 @@ el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n')); if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom(); }); - document.body.addEventListener('htmx:sseClose', function () { + document.body.addEventListener('htmx:sseClose', function (e) { const msgId = '{{ user_message.id }}'; + const streamEl = document.getElementById('ai-stream-' + msgId); + if (streamEl && e.target !== streamEl) return; + const el = document.getElementById('ai-message-content-' + msgId); if (el && window.markdownBuffer[msgId]) { el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n')); delete window.markdownBuffer[msgId]; if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom(); } + + if (streamEl) { + streamEl.removeAttribute('sse-connect'); + streamEl.removeAttribute('sse-close'); + streamEl.removeAttribute('hx-ext'); + } });