feat: streaming chat response and persistant chats

This commit is contained in:
Per Stark
2025-02-26 16:15:59 +01:00
parent c41e370b81
commit 4ce272d5be
12 changed files with 522 additions and 87 deletions

10
Cargo.lock generated
View File

@@ -2427,6 +2427,15 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "json-stream-parser"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a70ab2b05e827e0604229fcf11b24560b036a21286a41517a6cac271f12a6a9"
dependencies = [
"serde_json",
]
[[package]]
name = "json5"
version = "0.4.1"
@@ -6042,6 +6051,7 @@ dependencies = [
"chrono-tz",
"config",
"futures",
"json-stream-parser",
"lettre",
"mime",
"mime_guess",

View File

@@ -16,6 +16,7 @@ chrono = { version = "0.4.39", features = ["serde"] }
chrono-tz = "0.10.1"
config = "0.15.4"
futures = "0.3.31"
json-stream-parser = "0.1.4"
lettre = { version = "0.11.11", features = ["rustls-tls"] }
mime = "0.3.17"
mime_guess = "2.0.5"

View File

@@ -2541,6 +2541,21 @@
}
}
}
.z-0 {
z-index: 0;
}
.z-1 {
z-index: 1;
}
.z-2 {
z-index: 2;
}
.z-3 {
z-index: 3;
}
.z-5 {
z-index: 5;
}
.z-20 {
z-index: 20;
}
@@ -4074,6 +4089,9 @@
width: calc(var(--spacing) * 10);
height: calc(var(--spacing) * 10);
}
.h-4 {
height: calc(var(--spacing) * 4);
}
.h-5 {
height: calc(var(--spacing) * 5);
}
@@ -4433,6 +4451,9 @@
.p-2 {
padding: calc(var(--spacing) * 2);
}
.p-3 {
padding: calc(var(--spacing) * 3);
}
.p-4 {
padding: calc(var(--spacing) * 4);
}

View File

@@ -0,0 +1,328 @@
use std::{pin::Pin, time::Duration};
use axum::{
extract::{Query, State},
response::{sse::Event, Sse},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use futures::{stream, Stream, StreamExt, TryStreamExt};
use json_stream_parser::JsonStreamParser;
use serde::Deserialize;
use surrealdb::{engine::any::Any, Surreal};
use tracing::{error, info};
use crate::{
retrieval::{
combined_knowledge_entity_retrieval,
query_helper::{
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
},
},
server::AppState,
storage::{
db::{get_item, store_item, SurrealDbClient},
types::{
message::{Message, MessageRole},
user::User,
},
},
};
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
// Helper function to get message and user
async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
message_id: &str,
) -> Result<(Message, User), Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>> {
// Check authentication
let user = match current_user {
Some(user) => user,
None => {
return Err(Sse::new(create_error_stream(
"You must be signed in to use this feature",
)))
}
};
// Retrieve message
let message = match get_item::<Message>(db, message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
return Err(Sse::new(create_error_stream(
"Message not found: the specified message does not exist",
)))
}
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
return Err(Sse::new(create_error_stream(
"Failed to retrieve message: database error",
)));
}
};
Ok((message, user))
}
#[derive(Deserialize)]
pub struct QueryParams {
message_id: String,
}
pub async fn get_response_stream(
State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
// 1. Authentication and initial data validation
let (user_message, user) = match get_message_and_user(
&state.surreal_db_client,
auth.current_user,
&params.message_id,
)
.await
{
Ok((user_message, user)) => (user_message, user),
Err(error_stream) => return error_stream,
};
// 2. Retrieve knowledge entities
let entities = match combined_knowledge_entity_retrieval(
&state.surreal_db_client,
&state.openai_client,
&user_message.content,
&user.id,
)
.await
{
Ok(entities) => entities,
Err(_e) => {
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
}
};
// 3. Create the OpenAI request
let entities_json = format_entities_json(&entities);
let formatted_user_message = create_user_message(&entities_json, &user_message.content);
let request = match create_chat_request(formatted_user_message) {
Ok(req) => req,
Err(..) => {
return Sse::new(create_error_stream("Failed to create chat request"));
}
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
}
};
// 5. Create channel for collecting complete response
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(1000);
let tx_clone = tx.clone();
// 6. Set up the collection task for DB storage
let db_client = state.surreal_db_client.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
// Collect full response
let mut full_json = String::new();
while let Some(chunk) = rx.recv().await {
full_json.push_str(&chunk);
}
// Try to extract structured data
if let Ok(response) = serde_json::from_str::<LLMResponseFormat>(&full_json) {
let references: Vec<String> = response
.references
.into_iter()
.map(|r| r.reference)
.collect();
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
response.answer,
Some(references),
);
match store_item(&db_client, ai_message).await {
Ok(_) => info!("Successfully stored AI message with references"),
Err(e) => error!("Failed to store AI message: {:?}", e),
}
} else {
error!("Failed to parse LLM response as structured format");
// Fallback - store raw response
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
full_json,
Some(vec![]),
);
let _ = store_item(&db_client, ai_message).await;
}
});
// Create a shared state for tracking the JSON parsing
let json_state = std::sync::Arc::new(tokio::sync::Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = json_state.clone();
async move {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
// Always send raw content to storage
let _ = tx_storage.send(content.clone()).await;
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
return Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// Empty or filtered content
Ok(Event::default().event("chat_message").data(""))
} else {
Ok(Event::default().event("chat_message").data(""))
}
}
Err(e) => Ok(Event::default()
.event("error")
.data(format!("Stream error: {}", e))),
}
}
})
.buffered(10)
.chain(stream::once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}));
info!("OpenAI streaming started");
Sse::new(event_stream.boxed()).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
}
// Replace JsonParseState with StreamParserState
struct StreamParserState {
parser: JsonStreamParser,
last_answer_content: String,
in_answer_field: bool,
}
impl StreamParserState {
fn new() -> Self {
Self {
parser: JsonStreamParser::new(),
last_answer_content: String::new(),
in_answer_field: false,
}
}
fn process_chunk(&mut self, chunk: &str) -> String {
// Feed all characters into the parser
for c in chunk.chars() {
let _ = self.parser.add_char(c);
}
// Get the current state of the JSON
// The get_result() method returns a &Value, not a Result
let json = self.parser.get_result();
// Check if we're in the answer field
if let Some(obj) = json.as_object() {
if let Some(answer) = obj.get("answer") {
self.in_answer_field = true;
// Get current answer content
let current_content = answer.as_str().unwrap_or_default().to_string();
// Calculate difference to send only new content
if current_content.len() > self.last_answer_content.len() {
let new_content = current_content[self.last_answer_content.len()..].to_string();
self.last_answer_content = current_content;
return new_content;
}
}
}
// No new content to return
String::new()
}
}
// 7. Create the response event stream
// let event_stream = openai_stream
// .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
// .map(move |result| {
// let tx = tx_clone.clone();
// async move {
// match result {
// Ok(response) => {
// let content = response
// .choices
// .first()
// .and_then(|choice| choice.delta.content.clone())
// .unwrap_or_default();
// if !content.is_empty() {
// let _ = tx.send(content.clone()).await;
// Ok(Event::default().event("chat_message").data(content))
// } else {
// Ok(Event::default().event("chat_message").data(""))
// }
// }
// Err(e) => Ok(Event::default()
// .event("error")
// .data(format!("Stream error: {}", e))),
// }
// }
// })
// .buffered(10)
// .chain(stream::once(async {
// Ok(Event::default()
// .event("close_stream")
// .data("Stream complete"))
// }));
// info!("OpenAI streaming started");
// Sse::new(event_stream.boxed()).keep_alive(
// axum::response::sse::KeepAlive::new()
// .interval(Duration::from_secs(15))
// .text("keep-alive"),
// )
// }

View File

@@ -1,28 +1,27 @@
use std::time::Duration;
pub mod message_response_stream;
use axum::{
extract::{Path, Query, State},
response::{
sse::{Event, KeepAlive},
Html, IntoResponse, Redirect, Sse,
},
extract::{Path, State},
http::HeaderValue,
response::{IntoResponse, Redirect},
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use futures::{stream, Stream, StreamExt};
use surrealdb::{engine::any::Any, Surreal};
use tokio::time::sleep;
use tracing::info;
use uuid::Uuid;
use crate::{
error::HtmlError,
error::{AppError, HtmlError},
page_data,
server::{routes::html::render_template, AppState},
storage::types::{
message::{Message, MessageRole},
user::User,
storage::{
db::store_item,
types::{
conversation::Conversation,
message::{Message, MessageRole},
user::User,
},
},
};
@@ -47,7 +46,8 @@ where
page_data!(ChatData, "chat/base.html", {
user: User,
history: Vec<Message>,
conversation_id: String,
conversation: Conversation,
conversation_archive: Vec<Conversation>
});
pub async fn show_initialized_chat(
@@ -62,19 +62,37 @@ pub async fn show_initialized_chat(
None => return Ok(Redirect::to("/").into_response()),
};
info!("{:?}", form);
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let conversation_id = Uuid::new_v4().to_string();
let user_message = Message::new("test".to_string(), MessageRole::User, form.user_query, None);
let user_message = Message::new(
conversation.id.to_string(),
MessageRole::User,
form.user_query,
None,
);
let ai_message = Message::new(
"test".to_string(),
conversation.id.to_string(),
MessageRole::AI,
form.llm_response,
Some(form.references),
);
let (conversation_result, ai_message_result, user_message_result) = futures::join!(
store_item(&state.surreal_db_client, conversation.clone()),
store_item(&state.surreal_db_client, ai_message.clone()),
store_item(&state.surreal_db_client, user_message.clone())
);
// Check each result individually
conversation_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
user_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
ai_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let messages = vec![user_message, ai_message];
let output = render_template(
@@ -82,12 +100,18 @@ pub async fn show_initialized_chat(
ChatData {
history: messages,
user,
conversation_id,
conversation_archive,
conversation: conversation.clone(),
},
state.templates.clone(),
)?;
Ok(output.into_response())
let mut response = output.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
Ok(response)
}
pub async fn show_chat_base(
@@ -101,14 +125,19 @@ pub async fn show_chat_base(
None => return Ok(Redirect::to("/").into_response()),
};
let conversation_id = Uuid::new_v4().to_string();
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let conversation = Conversation::new(user.id.clone(), "New Chat".to_string());
let output = render_template(
ChatData::template_name(),
ChatData {
history: vec![],
user,
conversation_id,
conversation_archive,
conversation,
},
state.templates.clone(),
)?;
@@ -121,39 +150,37 @@ pub struct NewMessageForm {
content: String,
}
pub async fn new_user_message(
pub async fn show_existing_chat(
Path(conversation_id): Path<String>,
State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
info!("Displaying empty chat start");
info!("Displaying initialized chat with id: {}", conversation_id);
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
let query_id = Uuid::new_v4().to_string();
let user_message = form.content.clone();
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
// Save to database
// state
// .db
// .save(conversation_id, query_id.clone(), user_message)
// .await;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: String,
query_id: String,
}
let (conversation, messages) = Conversation::get_complete_conversation(
conversation_id.as_str(),
&user.id,
&state.surreal_db_client,
)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let output = render_template(
"chat/streaming_response.html",
SSEResponseInitData {
user_message,
query_id,
ChatData::template_name(),
ChatData {
history: messages,
user,
conversation: conversation.clone(),
conversation_archive,
},
state.templates.clone(),
)?;
@@ -161,38 +188,33 @@ pub async fn new_user_message(
Ok(output.into_response())
}
#[derive(Deserialize)]
pub struct QueryParams {
query_id: String,
}
pub async fn get_response_stream(
State(_state): State<AppState>,
pub async fn new_user_message(
Path(conversation_id): Path<String>,
State(state): State<AppState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<impl Stream<Item = Result<Event, axum::Error>>> {
let stream = stream::iter(vec![
Event::default()
.event("chat_message")
.data("Hello, starting stream!"),
Event::default()
.event("chat_message")
.data("This is message 2"),
Event::default().event("chat_message").data("Final message"),
Event::default()
.event("close_stream")
.data("Stream complete"), // Signal to close
])
.then(|event| async move {
sleep(Duration::from_millis(500)).await; // Delay between messages
Ok(event)
});
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
let _user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
info!("Streaming started");
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
store_item(&state.surreal_db_client, user_message.clone())
.await
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let output = render_template(
"chat/streaming_response.html",
SSEResponseInitData { user_message },
state.templates.clone(),
)?;
Ok(output.into_response())
}

View File

@@ -16,7 +16,10 @@ use axum_session_surreal::SessionSurrealPool;
use html::{
account::{delete_account, set_api_key, show_account_page, update_timezone},
admin_panel::{show_admin_panel, toggle_registration_status},
chat::{get_response_stream, new_user_message, show_chat_base, show_initialized_chat},
chat::{
message_response_stream::get_response_stream, new_user_message, show_chat_base,
show_existing_chat, show_initialized_chat,
},
content::{patch_text_content, show_content_page, show_text_content_edit_form},
documentation::{
show_documentation_index, show_get_started, show_mobile_friendly, show_privacy_policy,
@@ -65,7 +68,7 @@ pub fn html_routes(app_state: &AppState) -> Router<AppState> {
.route("/gdpr/deny", post(deny_gdpr))
.route("/search", get(search_result_handler))
.route("/chat", get(show_chat_base).post(show_initialized_chat))
.route("/chat/:id", post(new_user_message))
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
.route("/chat/response-stream", get(get_response_stream))
.route("/signout", get(sign_out_user))
.route("/signin", get(show_signin_form).post(authenticate_user))

View File

@@ -1,6 +1,12 @@
use uuid::Uuid;
use crate::stored_object;
use crate::{
error::AppError,
storage::db::{get_item, SurrealDbClient},
stored_object,
};
use super::message::Message;
stored_object!(Conversation, "conversation", {
user_id: String,
@@ -18,4 +24,29 @@ impl Conversation {
title,
}
}
pub async fn get_complete_conversation(
conversation_id: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<(Self, Vec<Message>), AppError> {
let conversation: Conversation = get_item(&db, conversation_id)
.await?
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
if conversation.user_id != user_id {
return Err(AppError::Auth(
"You don't have access to this conversation".to_string(),
));
}
let messages:Vec<Message> = db.client.
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
bind(("table_name", Message::table_name())).
bind(("conversation_id", conversation_id.to_string()))
.await?
.take(0)?;
Ok((conversation, messages))
}
}

View File

@@ -2,7 +2,7 @@ use uuid::Uuid;
use crate::stored_object;
#[derive(Deserialize, Debug, Serialize)]
#[derive(Deserialize, Debug, Clone, Serialize)]
pub enum MessageRole {
User,
AI,

View File

@@ -8,8 +8,9 @@ use surrealdb::{engine::any::Any, Surreal};
use uuid::Uuid;
use super::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
system_settings::SystemSettings, text_content::TextContent,
conversation::Conversation, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
text_content::TextContent,
};
#[derive(Deserialize)]
@@ -333,4 +334,19 @@ impl User {
Ok(text_content)
}
pub async fn get_user_conversations(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<Conversation>, AppError> {
let conversations: Vec<Conversation> = db
.client
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id")
.bind(("table_name", Conversation::table_name()))
.bind(("user_id", user_id.to_string()))
.await?
.take(0)?;
Ok(conversations)
}
}

View File

@@ -7,7 +7,10 @@
"icons/edit_icon.html" %}
</span></a></li>
<div class="divider"></div>
<li><a>Sidebar Item 1</a></li>
<li><a>Sidebar Item 2</a></li>
{% for conversation in conversation_archive %}
<li><a href="/chat/{{conversation.id}}" hx-boost="true">{{conversation.title}} - {{conversation.created_at}}</a>
</li>
{% endfor %}
{{conversation_archive}}
</ul>
</div>

View File

@@ -1,5 +1,5 @@
<div class="fixed w-full mx-auto max-w-3xl p-4 pb-0 sm:pb-4 left-0 right-0 bottom-0">
<form hx-post="/chat/{{conversation_id}}" hx-target="#chat_container" hx-swap="beforeend" class="relative flex gap-2"
<form hx-post="/chat/{{conversation.id}}" hx-target="#chat_container" hx-swap="beforeend" class="relative flex gap-2"
id="chat-form">
<textarea autofocus required name="content" placeholder="Type your message..." rows="2"
class="textarea textarea-ghost rounded-2xl rounded-b-none h-24 sm:rounded-b-2xl pr-8 bg-base-200 flex-grow resize-none"

View File

@@ -1,20 +1,20 @@
<div class="chat chat-end">
<div class="chat-header">User</div>
<div class="chat-bubble">
{{user_message}}
{{user_message.content}}
</div>
</div>
<div class="chat chat-start">
<div class="chat-header">AI</div>
<div class="chat-bubble" hx-ext="sse" sse-connect="/chat/response-stream?query_id={{query_id}}"
<div class="chat-bubble" hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}"
sse-swap="chat_message" sse-close="close_stream" hx-swap="beforeend">
<span class="loading loading-dots loading-sm loading-id-{{query_id}}"></span>
<span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
</div>
</div>
<script>
document.body.addEventListener('htmx:sseBeforeMessage', (e) => {
const targetElement = e.detail.elt;
const loadingSpinner = targetElement.querySelector('.loading-id-{{query_id}}');
const loadingSpinner = targetElement.querySelector('.loading-id-{{user_message.id}}');
// Hiding the loading spinner before data is swapped in
if (loadingSpinner) {