fix: browser back navigation from chat windows

addenum
This commit is contained in:
Per Stark
2026-02-12 20:31:11 +01:00
parent bbad91d55b
commit e5d2b6605f
3 changed files with 285 additions and 12 deletions

View File

@@ -56,7 +56,7 @@ async fn get_message_and_user(
current_user: Option<User>,
message_id: &str,
) -> Result<
(Message, User, Conversation, Vec<Message>),
(Message, User, Conversation, Vec<Message>, Option<Message>),
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
// Check authentication
@@ -83,7 +83,7 @@ async fn get_message_and_user(
};
// Get conversation history
let (conversation, mut history) =
let (conversation, history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{
Err(e) => {
@@ -95,10 +95,95 @@ async fn get_message_and_user(
Ok((conversation, history)) => (conversation, history),
};
// Remove the last message, its the same as the message
history.pop();
let Some(message_index) = find_message_index(&history, message_id) else {
return Err(Sse::new(create_error_stream(
"Message not found in conversation history",
)));
};
Ok((message, user, conversation, history))
let Some(message_from_history) = history.get(message_index) else {
return Err(Sse::new(create_error_stream(
"Message not found in conversation history",
)));
};
if message_from_history.role != MessageRole::User {
return Err(Sse::new(create_error_stream(
"Only user messages can be used to generate a response",
)));
}
let message = message_from_history.clone();
let history_before_message = history_before_message(&history, message_index);
let existing_ai_response = find_existing_ai_response(&history, message_index);
Ok((
message,
user,
conversation,
history_before_message,
existing_ai_response,
))
}
fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
messages.iter().position(|message| message.id == message_id)
}
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
messages
.iter()
.skip(user_message_index + 1)
.take_while(|message| message.role != MessageRole::User)
.find(|message| message.role == MessageRole::AI)
.cloned()
}
fn history_before_message(messages: &[Message], message_index: usize) -> Vec<Message> {
messages.iter().take(message_index).cloned().collect()
}
fn create_replayed_response_stream(
state: &HtmlState,
existing_ai_message: Message,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
let references_event = if existing_ai_message
.references
.as_ref()
.is_some_and(|references| !references.is_empty())
{
state
.templates
.render(
"chat/reference_list.html",
&Value::from_serialize(ReferenceData {
message: existing_ai_message.clone(),
}),
)
.ok()
.map(|html| Event::default().event("references").data(html))
} else {
None
};
let answer = existing_ai_message.content;
let event_stream = stream! {
yield Ok(Event::default().event("chat_message").data(answer));
if let Some(event) = references_event {
yield Ok(event);
}
yield Ok(Event::default().event("close_stream").data("Stream complete"));
};
Sse::new(event_stream.boxed()).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
}
#[derive(Deserialize)]
@@ -123,18 +208,25 @@ fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
// auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
// 1. Authentication and initial data validation
let (user_message, user, _conversation, history) =
let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
Ok((user_message, user, conversation, history)) => {
(user_message, user, conversation, history)
}
Ok((user_message, user, conversation, history, existing_ai_response)) => (
user_message,
user,
conversation,
history,
existing_ai_response,
),
Err(error_stream) => return error_stream,
};
if let Some(existing_ai_message) = existing_ai_response {
return create_replayed_response_stream(&state, existing_ai_message);
}
// 2. Retrieve knowledge entities
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
@@ -424,7 +516,39 @@ impl StreamParserState {
#[cfg(test)]
mod tests {
use super::*;
use chrono::{Duration as ChronoDuration, Utc};
use common::storage::{
db::SurrealDbClient,
types::{conversation::Conversation, user::Theme},
};
use retrieval_pipeline::answer_retrieval::Reference;
use uuid::Uuid;
fn make_test_message(id: &str, role: MessageRole) -> Message {
let mut message = Message::new(
"conversation-1".to_string(),
role,
format!("content-{id}"),
None,
);
message.id = id.to_string();
message
}
fn make_test_user(id: &str) -> User {
User {
id: id.to_string(),
created_at: Utc::now(),
updated_at: Utc::now(),
email: "test@example.com".to_string(),
password: "password".to_string(),
anonymous: false,
api_key: None,
admin: false,
timezone: "UTC".to_string(),
theme: Theme::System,
}
}
#[test]
fn extracts_reference_strings_in_order() {
@@ -443,4 +567,140 @@ mod tests {
let extracted = extract_reference_strings(&response);
assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]);
}
#[test]
fn finds_message_index_for_existing_message() {
let messages = vec![
make_test_message("m1", MessageRole::User),
make_test_message("m2", MessageRole::AI),
make_test_message("m3", MessageRole::User),
];
assert_eq!(find_message_index(&messages, "m2"), Some(1));
assert_eq!(find_message_index(&messages, "missing"), None);
}
#[test]
fn finds_existing_ai_response_for_same_turn() {
let messages = vec![
make_test_message("u1", MessageRole::User),
make_test_message("system", MessageRole::System),
make_test_message("a1", MessageRole::AI),
make_test_message("u2", MessageRole::User),
make_test_message("a2", MessageRole::AI),
];
let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response");
assert_eq!(ai_reply.id, "a1");
let ai_reply_second_turn =
find_existing_ai_response(&messages, 3).expect("expected AI response");
assert_eq!(ai_reply_second_turn.id, "a2");
}
#[test]
fn does_not_replay_ai_response_from_later_turn() {
let messages = vec![
make_test_message("u1", MessageRole::User),
make_test_message("u2", MessageRole::User),
make_test_message("a2", MessageRole::AI),
];
assert!(find_existing_ai_response(&messages, 0).is_none());
let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response");
assert_eq!(ai_reply.id, "a2");
}
#[test]
fn history_before_message_excludes_target_and_future_messages() {
let messages = vec![
make_test_message("u1", MessageRole::User),
make_test_message("a1", MessageRole::AI),
make_test_message("u2", MessageRole::User),
make_test_message("a2", MessageRole::AI),
];
let history_for_u2 = history_before_message(&messages, 2);
let history_ids: Vec<String> = history_for_u2
.into_iter()
.map(|message| message.id)
.collect();
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
}
#[tokio::test]
async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() {
let namespace = "chat_stream_replay";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("failed to create in-memory db");
let user = make_test_user("user-1");
let conversation = Conversation::new(user.id.clone(), "Conversation".to_string());
let mut user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
"Question one".to_string(),
None,
);
user_message.id = "u1".to_string();
let mut ai_message = Message::new(
conversation.id.clone(),
MessageRole::AI,
"Answer one".to_string(),
Some(vec!["ref-1".to_string()]),
);
ai_message.id = "a1".to_string();
ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1);
ai_message.updated_at = ai_message.created_at;
let mut second_user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
"Question two".to_string(),
None,
);
second_user_message.id = "u2".to_string();
second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1);
second_user_message.updated_at = second_user_message.created_at;
db.store_item(conversation.clone())
.await
.expect("failed to store conversation");
db.store_item(user_message.clone())
.await
.expect("failed to store user message");
db.store_item(ai_message.clone())
.await
.expect("failed to store ai message");
db.store_item(second_user_message.clone())
.await
.expect("failed to store second user message");
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
get_message_and_user(&db, Some(user.clone()), &user_message.id)
.await
.expect("expected first turn to load");
assert!(history_for_first_turn.is_empty());
let existing_ai_for_first_turn =
existing_ai_for_first_turn.expect("expected first-turn AI response");
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
get_message_and_user(&db, Some(user), &second_user_message.id)
.await
.expect("expected second turn to load");
let history_ids: Vec<String> = history_for_second_turn
.into_iter()
.map(|message| message.id)
.collect();
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
assert!(existing_ai_for_second_turn.is_none());
}
}

View File

@@ -4,7 +4,8 @@
</div>
</div>
<div class="chat chat-start">
<div hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
<div id="ai-stream-{{user_message.id}}" hx-ext="sse"
sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
hx-swap="beforeend">
<div class="chat-bubble">
<span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
@@ -27,13 +28,22 @@
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
});
document.body.addEventListener('htmx:sseClose', function () {
document.body.addEventListener('htmx:sseClose', function (e) {
const msgId = '{{ user_message.id }}';
const streamEl = document.getElementById('ai-stream-' + msgId);
if (streamEl && e.target !== streamEl) return;
const el = document.getElementById('ai-message-content-' + msgId);
if (el && window.markdownBuffer[msgId]) {
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
delete window.markdownBuffer[msgId];
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
}
if (streamEl) {
streamEl.removeAttribute('sse-connect');
streamEl.removeAttribute('sse-close');
streamEl.removeAttribute('hx-ext');
}
});
</script>