mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-18 15:09:49 +02:00
fix: browser back navigation from chat windows
addenum
This commit is contained in:
@@ -56,7 +56,7 @@ async fn get_message_and_user(
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<
|
||||
(Message, User, Conversation, Vec<Message>),
|
||||
(Message, User, Conversation, Vec<Message>, Option<Message>),
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
// Check authentication
|
||||
@@ -83,7 +83,7 @@ async fn get_message_and_user(
|
||||
};
|
||||
|
||||
// Get conversation history
|
||||
let (conversation, mut history) =
|
||||
let (conversation, history) =
|
||||
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
|
||||
{
|
||||
Err(e) => {
|
||||
@@ -95,10 +95,95 @@ async fn get_message_and_user(
|
||||
Ok((conversation, history)) => (conversation, history),
|
||||
};
|
||||
|
||||
// Remove the last message, its the same as the message
|
||||
history.pop();
|
||||
let Some(message_index) = find_message_index(&history, message_id) else {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Message not found in conversation history",
|
||||
)));
|
||||
};
|
||||
|
||||
Ok((message, user, conversation, history))
|
||||
let Some(message_from_history) = history.get(message_index) else {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Message not found in conversation history",
|
||||
)));
|
||||
};
|
||||
|
||||
if message_from_history.role != MessageRole::User {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Only user messages can be used to generate a response",
|
||||
)));
|
||||
}
|
||||
|
||||
let message = message_from_history.clone();
|
||||
|
||||
let history_before_message = history_before_message(&history, message_index);
|
||||
let existing_ai_response = find_existing_ai_response(&history, message_index);
|
||||
|
||||
Ok((
|
||||
message,
|
||||
user,
|
||||
conversation,
|
||||
history_before_message,
|
||||
existing_ai_response,
|
||||
))
|
||||
}
|
||||
|
||||
fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
|
||||
messages.iter().position(|message| message.id == message_id)
|
||||
}
|
||||
|
||||
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
|
||||
messages
|
||||
.iter()
|
||||
.skip(user_message_index + 1)
|
||||
.take_while(|message| message.role != MessageRole::User)
|
||||
.find(|message| message.role == MessageRole::AI)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn history_before_message(messages: &[Message], message_index: usize) -> Vec<Message> {
|
||||
messages.iter().take(message_index).cloned().collect()
|
||||
}
|
||||
|
||||
fn create_replayed_response_stream(
|
||||
state: &HtmlState,
|
||||
existing_ai_message: Message,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
let references_event = if existing_ai_message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(|references| !references.is_empty())
|
||||
{
|
||||
state
|
||||
.templates
|
||||
.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(ReferenceData {
|
||||
message: existing_ai_message.clone(),
|
||||
}),
|
||||
)
|
||||
.ok()
|
||||
.map(|html| Event::default().event("references").data(html))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let answer = existing_ai_message.content;
|
||||
|
||||
let event_stream = stream! {
|
||||
yield Ok(Event::default().event("chat_message").data(answer));
|
||||
|
||||
if let Some(event) = references_event {
|
||||
yield Ok(event);
|
||||
}
|
||||
|
||||
yield Ok(Event::default().event("close_stream").data("Stream complete"));
|
||||
};
|
||||
|
||||
Sse::new(event_stream.boxed()).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -123,18 +208,25 @@ fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
// auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user, _conversation, history) =
|
||||
let (user_message, user, _conversation, history, existing_ai_response) =
|
||||
match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await {
|
||||
Ok((user_message, user, conversation, history)) => {
|
||||
(user_message, user, conversation, history)
|
||||
}
|
||||
Ok((user_message, user, conversation, history, existing_ai_response)) => (
|
||||
user_message,
|
||||
user,
|
||||
conversation,
|
||||
history,
|
||||
existing_ai_response,
|
||||
),
|
||||
Err(error_stream) => return error_stream,
|
||||
};
|
||||
|
||||
if let Some(existing_ai_message) = existing_ai_response {
|
||||
return create_replayed_response_stream(&state, existing_ai_message);
|
||||
}
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
@@ -424,7 +516,39 @@ impl StreamParserState {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{conversation::Conversation, user::Theme},
|
||||
};
|
||||
use retrieval_pipeline::answer_retrieval::Reference;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn make_test_message(id: &str, role: MessageRole) -> Message {
|
||||
let mut message = Message::new(
|
||||
"conversation-1".to_string(),
|
||||
role,
|
||||
format!("content-{id}"),
|
||||
None,
|
||||
);
|
||||
message.id = id.to_string();
|
||||
message
|
||||
}
|
||||
|
||||
fn make_test_user(id: &str) -> User {
|
||||
User {
|
||||
id: id.to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
email: "test@example.com".to_string(),
|
||||
password: "password".to_string(),
|
||||
anonymous: false,
|
||||
api_key: None,
|
||||
admin: false,
|
||||
timezone: "UTC".to_string(),
|
||||
theme: Theme::System,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_reference_strings_in_order() {
|
||||
@@ -443,4 +567,140 @@ mod tests {
|
||||
let extracted = extract_reference_strings(&response);
|
||||
assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finds_message_index_for_existing_message() {
|
||||
let messages = vec![
|
||||
make_test_message("m1", MessageRole::User),
|
||||
make_test_message("m2", MessageRole::AI),
|
||||
make_test_message("m3", MessageRole::User),
|
||||
];
|
||||
|
||||
assert_eq!(find_message_index(&messages, "m2"), Some(1));
|
||||
assert_eq!(find_message_index(&messages, "missing"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finds_existing_ai_response_for_same_turn() {
|
||||
let messages = vec![
|
||||
make_test_message("u1", MessageRole::User),
|
||||
make_test_message("system", MessageRole::System),
|
||||
make_test_message("a1", MessageRole::AI),
|
||||
make_test_message("u2", MessageRole::User),
|
||||
make_test_message("a2", MessageRole::AI),
|
||||
];
|
||||
|
||||
let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response");
|
||||
assert_eq!(ai_reply.id, "a1");
|
||||
|
||||
let ai_reply_second_turn =
|
||||
find_existing_ai_response(&messages, 3).expect("expected AI response");
|
||||
assert_eq!(ai_reply_second_turn.id, "a2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn does_not_replay_ai_response_from_later_turn() {
|
||||
let messages = vec![
|
||||
make_test_message("u1", MessageRole::User),
|
||||
make_test_message("u2", MessageRole::User),
|
||||
make_test_message("a2", MessageRole::AI),
|
||||
];
|
||||
|
||||
assert!(find_existing_ai_response(&messages, 0).is_none());
|
||||
|
||||
let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response");
|
||||
assert_eq!(ai_reply.id, "a2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn history_before_message_excludes_target_and_future_messages() {
|
||||
let messages = vec![
|
||||
make_test_message("u1", MessageRole::User),
|
||||
make_test_message("a1", MessageRole::AI),
|
||||
make_test_message("u2", MessageRole::User),
|
||||
make_test_message("a2", MessageRole::AI),
|
||||
];
|
||||
|
||||
let history_for_u2 = history_before_message(&messages, 2);
|
||||
let history_ids: Vec<String> = history_for_u2
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() {
|
||||
let namespace = "chat_stream_replay";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("failed to create in-memory db");
|
||||
|
||||
let user = make_test_user("user-1");
|
||||
let conversation = Conversation::new(user.id.clone(), "Conversation".to_string());
|
||||
|
||||
let mut user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
"Question one".to_string(),
|
||||
None,
|
||||
);
|
||||
user_message.id = "u1".to_string();
|
||||
|
||||
let mut ai_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::AI,
|
||||
"Answer one".to_string(),
|
||||
Some(vec!["ref-1".to_string()]),
|
||||
);
|
||||
ai_message.id = "a1".to_string();
|
||||
ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1);
|
||||
ai_message.updated_at = ai_message.created_at;
|
||||
|
||||
let mut second_user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
"Question two".to_string(),
|
||||
None,
|
||||
);
|
||||
second_user_message.id = "u2".to_string();
|
||||
second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1);
|
||||
second_user_message.updated_at = second_user_message.created_at;
|
||||
|
||||
db.store_item(conversation.clone())
|
||||
.await
|
||||
.expect("failed to store conversation");
|
||||
db.store_item(user_message.clone())
|
||||
.await
|
||||
.expect("failed to store user message");
|
||||
db.store_item(ai_message.clone())
|
||||
.await
|
||||
.expect("failed to store ai message");
|
||||
db.store_item(second_user_message.clone())
|
||||
.await
|
||||
.expect("failed to store second user message");
|
||||
|
||||
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
|
||||
get_message_and_user(&db, Some(user.clone()), &user_message.id)
|
||||
.await
|
||||
.expect("expected first turn to load");
|
||||
|
||||
assert!(history_for_first_turn.is_empty());
|
||||
let existing_ai_for_first_turn =
|
||||
existing_ai_for_first_turn.expect("expected first-turn AI response");
|
||||
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
|
||||
|
||||
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
|
||||
get_message_and_user(&db, Some(user), &second_user_message.id)
|
||||
.await
|
||||
.expect("expected second turn to load");
|
||||
|
||||
let history_ids: Vec<String> = history_for_second_turn
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
|
||||
assert!(existing_ai_for_second_turn.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="chat chat-start">
|
||||
<div hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
|
||||
<div id="ai-stream-{{user_message.id}}" hx-ext="sse"
|
||||
sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
|
||||
hx-swap="beforeend">
|
||||
<div class="chat-bubble">
|
||||
<span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
|
||||
@@ -27,13 +28,22 @@
|
||||
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
|
||||
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
|
||||
});
|
||||
document.body.addEventListener('htmx:sseClose', function () {
|
||||
document.body.addEventListener('htmx:sseClose', function (e) {
|
||||
const msgId = '{{ user_message.id }}';
|
||||
const streamEl = document.getElementById('ai-stream-' + msgId);
|
||||
if (streamEl && e.target !== streamEl) return;
|
||||
|
||||
const el = document.getElementById('ai-message-content-' + msgId);
|
||||
if (el && window.markdownBuffer[msgId]) {
|
||||
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
|
||||
delete window.markdownBuffer[msgId];
|
||||
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
|
||||
}
|
||||
|
||||
if (streamEl) {
|
||||
streamEl.removeAttribute('sse-connect');
|
||||
streamEl.removeAttribute('sse-close');
|
||||
streamEl.removeAttribute('hx-ext');
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
Reference in New Issue
Block a user