diff --git a/Cargo.lock b/Cargo.lock index 7cb7212..8e42aa1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2427,6 +2427,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json-stream-parser" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a70ab2b05e827e0604229fcf11b24560b036a21286a41517a6cac271f12a6a9" +dependencies = [ + "serde_json", +] + [[package]] name = "json5" version = "0.4.1" @@ -6042,6 +6051,7 @@ dependencies = [ "chrono-tz", "config", "futures", + "json-stream-parser", "lettre", "mime", "mime_guess", diff --git a/Cargo.toml b/Cargo.toml index 06426c4..b55c00e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ chrono = { version = "0.4.39", features = ["serde"] } chrono-tz = "0.10.1" config = "0.15.4" futures = "0.3.31" +json-stream-parser = "0.1.4" lettre = { version = "0.11.11", features = ["rustls-tls"] } mime = "0.3.17" mime_guess = "2.0.5" diff --git a/assets/style.css b/assets/style.css index 6bfc311..754c8c3 100644 --- a/assets/style.css +++ b/assets/style.css @@ -2541,6 +2541,21 @@ } } } + .z-0 { + z-index: 0; + } + .z-1 { + z-index: 1; + } + .z-2 { + z-index: 2; + } + .z-3 { + z-index: 3; + } + .z-5 { + z-index: 5; + } .z-20 { z-index: 20; } @@ -4074,6 +4089,9 @@ width: calc(var(--spacing) * 10); height: calc(var(--spacing) * 10); } + .h-4 { + height: calc(var(--spacing) * 4); + } .h-5 { height: calc(var(--spacing) * 5); } @@ -4433,6 +4451,9 @@ .p-2 { padding: calc(var(--spacing) * 2); } + .p-3 { + padding: calc(var(--spacing) * 3); + } .p-4 { padding: calc(var(--spacing) * 4); } diff --git a/src/server/routes/html/chat/message_response_stream.rs b/src/server/routes/html/chat/message_response_stream.rs new file mode 100644 index 0000000..659cf94 --- /dev/null +++ b/src/server/routes/html/chat/message_response_stream.rs @@ -0,0 +1,328 @@ +use std::{pin::Pin, time::Duration}; + +use axum::{ + extract::{Query, State}, + response::{sse::Event, Sse}, +}; +use axum_session_auth::AuthSession; +use axum_session_surreal::SessionSurrealPool; +use futures::{stream, Stream, StreamExt, TryStreamExt}; +use json_stream_parser::JsonStreamParser; +use serde::Deserialize; +use surrealdb::{engine::any::Any, Surreal}; +use tracing::{error, info}; + +use crate::{ + retrieval::{ + combined_knowledge_entity_retrieval, + query_helper::{ + create_chat_request, create_user_message, format_entities_json, LLMResponseFormat, + }, + }, + server::AppState, + storage::{ + db::{get_item, store_item, SurrealDbClient}, + types::{ + message::{Message, MessageRole}, + user::User, + }, + }, +}; + +fn create_error_stream( + message: impl Into, +) -> Pin> + Send>> { + let message = message.into(); + stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed() +} + +// Helper function to get message and user +async fn get_message_and_user( + db: &SurrealDbClient, + current_user: Option, + message_id: &str, +) -> Result<(Message, User), Sse> + Send>>>> { + // Check authentication + let user = match current_user { + Some(user) => user, + None => { + return Err(Sse::new(create_error_stream( + "You must be signed in to use this feature", + ))) + } + }; + + // Retrieve message + let message = match get_item::(db, message_id).await { + Ok(Some(message)) => message, + Ok(None) => { + return Err(Sse::new(create_error_stream( + "Message not found: the specified message does not exist", + ))) + } + Err(e) => { + error!("Database error retrieving message {}: {:?}", message_id, e); + return Err(Sse::new(create_error_stream( + "Failed to retrieve message: database error", + ))); + } + }; + + Ok((message, user)) +} + +#[derive(Deserialize)] +pub struct QueryParams { + message_id: String, +} + +pub async fn get_response_stream( + State(state): State, + auth: AuthSession, Surreal>, + Query(params): Query, +) -> Sse> + Send>>> { + // 1. Authentication and initial data validation + let (user_message, user) = match get_message_and_user( + &state.surreal_db_client, + auth.current_user, + ¶ms.message_id, + ) + .await + { + Ok((user_message, user)) => (user_message, user), + Err(error_stream) => return error_stream, + }; + + // 2. Retrieve knowledge entities + let entities = match combined_knowledge_entity_retrieval( + &state.surreal_db_client, + &state.openai_client, + &user_message.content, + &user.id, + ) + .await + { + Ok(entities) => entities, + Err(_e) => { + return Sse::new(create_error_stream("Failed to retrieve knowledge entities")); + } + }; + + // 3. Create the OpenAI request + let entities_json = format_entities_json(&entities); + let formatted_user_message = create_user_message(&entities_json, &user_message.content); + let request = match create_chat_request(formatted_user_message) { + Ok(req) => req, + Err(..) => { + return Sse::new(create_error_stream("Failed to create chat request")); + } + }; + + // 4. Set up the OpenAI stream + let openai_stream = match state.openai_client.chat().create_stream(request).await { + Ok(stream) => stream, + Err(_e) => { + return Sse::new(create_error_stream("Failed to create OpenAI stream")); + } + }; + + // 5. Create channel for collecting complete response + let (tx, mut rx) = tokio::sync::mpsc::channel::(1000); + let tx_clone = tx.clone(); + + // 6. Set up the collection task for DB storage + let db_client = state.surreal_db_client.clone(); + tokio::spawn(async move { + drop(tx); // Close sender when no longer needed + + // Collect full response + let mut full_json = String::new(); + while let Some(chunk) = rx.recv().await { + full_json.push_str(&chunk); + } + + // Try to extract structured data + if let Ok(response) = serde_json::from_str::(&full_json) { + let references: Vec = response + .references + .into_iter() + .map(|r| r.reference) + .collect(); + + let ai_message = Message::new( + user_message.conversation_id, + MessageRole::AI, + response.answer, + Some(references), + ); + + match store_item(&db_client, ai_message).await { + Ok(_) => info!("Successfully stored AI message with references"), + Err(e) => error!("Failed to store AI message: {:?}", e), + } + } else { + error!("Failed to parse LLM response as structured format"); + + // Fallback - store raw response + let ai_message = Message::new( + user_message.conversation_id, + MessageRole::AI, + full_json, + Some(vec![]), + ); + + let _ = store_item(&db_client, ai_message).await; + } + }); + + // Create a shared state for tracking the JSON parsing + let json_state = std::sync::Arc::new(tokio::sync::Mutex::new(StreamParserState::new())); + + // 7. Create the response event stream + let event_stream = openai_stream + .map_err(|e| Box::new(e) as Box) + .map(move |result| { + let tx_storage = tx_clone.clone(); + let json_state = json_state.clone(); + + async move { + match result { + Ok(response) => { + let content = response + .choices + .first() + .and_then(|choice| choice.delta.content.clone()) + .unwrap_or_default(); + + if !content.is_empty() { + // Always send raw content to storage + let _ = tx_storage.send(content.clone()).await; + + // Process through JSON parser + let mut state = json_state.lock().await; + + let display_content = state.process_chunk(&content); + drop(state); + if !display_content.is_empty() { + return Ok(Event::default() + .event("chat_message") + .data(display_content)); + } + + // Empty or filtered content + Ok(Event::default().event("chat_message").data("")) + } else { + Ok(Event::default().event("chat_message").data("")) + } + } + Err(e) => Ok(Event::default() + .event("error") + .data(format!("Stream error: {}", e))), + } + } + }) + .buffered(10) + .chain(stream::once(async { + Ok(Event::default() + .event("close_stream") + .data("Stream complete")) + })); + + info!("OpenAI streaming started"); + Sse::new(event_stream.boxed()).keep_alive( + axum::response::sse::KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive"), + ) +} + +// Replace JsonParseState with StreamParserState +struct StreamParserState { + parser: JsonStreamParser, + last_answer_content: String, + in_answer_field: bool, +} + +impl StreamParserState { + fn new() -> Self { + Self { + parser: JsonStreamParser::new(), + last_answer_content: String::new(), + in_answer_field: false, + } + } + + fn process_chunk(&mut self, chunk: &str) -> String { + // Feed all characters into the parser + for c in chunk.chars() { + let _ = self.parser.add_char(c); + } + + // Get the current state of the JSON + // The get_result() method returns a &Value, not a Result + let json = self.parser.get_result(); + + // Check if we're in the answer field + if let Some(obj) = json.as_object() { + if let Some(answer) = obj.get("answer") { + self.in_answer_field = true; + + // Get current answer content + let current_content = answer.as_str().unwrap_or_default().to_string(); + + // Calculate difference to send only new content + if current_content.len() > self.last_answer_content.len() { + let new_content = current_content[self.last_answer_content.len()..].to_string(); + self.last_answer_content = current_content; + return new_content; + } + } + } + + // No new content to return + String::new() + } +} + +// 7. Create the response event stream +// let event_stream = openai_stream +// .map_err(|e| Box::new(e) as Box) +// .map(move |result| { +// let tx = tx_clone.clone(); +// async move { +// match result { +// Ok(response) => { +// let content = response +// .choices +// .first() +// .and_then(|choice| choice.delta.content.clone()) +// .unwrap_or_default(); + +// if !content.is_empty() { +// let _ = tx.send(content.clone()).await; +// Ok(Event::default().event("chat_message").data(content)) +// } else { +// Ok(Event::default().event("chat_message").data("")) +// } +// } +// Err(e) => Ok(Event::default() +// .event("error") +// .data(format!("Stream error: {}", e))), +// } +// } +// }) +// .buffered(10) +// .chain(stream::once(async { +// Ok(Event::default() +// .event("close_stream") +// .data("Stream complete")) +// })); + +// info!("OpenAI streaming started"); + +// Sse::new(event_stream.boxed()).keep_alive( +// axum::response::sse::KeepAlive::new() +// .interval(Duration::from_secs(15)) +// .text("keep-alive"), +// ) +// } diff --git a/src/server/routes/html/chat/mod.rs b/src/server/routes/html/chat/mod.rs index 5deef5f..1e8207d 100644 --- a/src/server/routes/html/chat/mod.rs +++ b/src/server/routes/html/chat/mod.rs @@ -1,28 +1,27 @@ -use std::time::Duration; +pub mod message_response_stream; use axum::{ - extract::{Path, Query, State}, - response::{ - sse::{Event, KeepAlive}, - Html, IntoResponse, Redirect, Sse, - }, + extract::{Path, State}, + http::HeaderValue, + response::{IntoResponse, Redirect}, Form, }; use axum_session_auth::AuthSession; use axum_session_surreal::SessionSurrealPool; -use futures::{stream, Stream, StreamExt}; use surrealdb::{engine::any::Any, Surreal}; -use tokio::time::sleep; use tracing::info; -use uuid::Uuid; use crate::{ - error::HtmlError, + error::{AppError, HtmlError}, page_data, server::{routes::html::render_template, AppState}, - storage::types::{ - message::{Message, MessageRole}, - user::User, + storage::{ + db::store_item, + types::{ + conversation::Conversation, + message::{Message, MessageRole}, + user::User, + }, }, }; @@ -47,7 +46,8 @@ where page_data!(ChatData, "chat/base.html", { user: User, history: Vec, - conversation_id: String, + conversation: Conversation, + conversation_archive: Vec }); pub async fn show_initialized_chat( @@ -62,19 +62,37 @@ pub async fn show_initialized_chat( None => return Ok(Redirect::to("/").into_response()), }; - info!("{:?}", form); + let conversation = Conversation::new(user.id.clone(), "Test".to_owned()); - let conversation_id = Uuid::new_v4().to_string(); - - let user_message = Message::new("test".to_string(), MessageRole::User, form.user_query, None); + let user_message = Message::new( + conversation.id.to_string(), + MessageRole::User, + form.user_query, + None, + ); let ai_message = Message::new( - "test".to_string(), + conversation.id.to_string(), MessageRole::AI, form.llm_response, Some(form.references), ); + let (conversation_result, ai_message_result, user_message_result) = futures::join!( + store_item(&state.surreal_db_client, conversation.clone()), + store_item(&state.surreal_db_client, ai_message.clone()), + store_item(&state.surreal_db_client, user_message.clone()) + ); + + // Check each result individually + conversation_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; + user_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; + ai_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; + + let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client) + .await + .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + let messages = vec![user_message, ai_message]; let output = render_template( @@ -82,12 +100,18 @@ pub async fn show_initialized_chat( ChatData { history: messages, user, - conversation_id, + conversation_archive, + conversation: conversation.clone(), }, state.templates.clone(), )?; - Ok(output.into_response()) + let mut response = output.into_response(); + response.headers_mut().insert( + "HX-Push", + HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(), + ); + Ok(response) } pub async fn show_chat_base( @@ -101,14 +125,19 @@ pub async fn show_chat_base( None => return Ok(Redirect::to("/").into_response()), }; - let conversation_id = Uuid::new_v4().to_string(); + let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client) + .await + .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + + let conversation = Conversation::new(user.id.clone(), "New Chat".to_string()); let output = render_template( ChatData::template_name(), ChatData { history: vec![], user, - conversation_id, + conversation_archive, + conversation, }, state.templates.clone(), )?; @@ -121,39 +150,37 @@ pub struct NewMessageForm { content: String, } -pub async fn new_user_message( +pub async fn show_existing_chat( Path(conversation_id): Path, State(state): State, auth: AuthSession, Surreal>, - Form(form): Form, ) -> Result { - info!("Displaying empty chat start"); + info!("Displaying initialized chat with id: {}", conversation_id); let user = match auth.current_user { Some(user) => user, None => return Ok(Redirect::to("/").into_response()), }; - let query_id = Uuid::new_v4().to_string(); - let user_message = form.content.clone(); + let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client) + .await + .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - // Save to database - // state - // .db - // .save(conversation_id, query_id.clone(), user_message) - // .await; - - #[derive(Serialize)] - struct SSEResponseInitData { - user_message: String, - query_id: String, - } + let (conversation, messages) = Conversation::get_complete_conversation( + conversation_id.as_str(), + &user.id, + &state.surreal_db_client, + ) + .await + .map_err(|e| HtmlError::new(e, state.templates.clone()))?; let output = render_template( - "chat/streaming_response.html", - SSEResponseInitData { - user_message, - query_id, + ChatData::template_name(), + ChatData { + history: messages, + user, + conversation: conversation.clone(), + conversation_archive, }, state.templates.clone(), )?; @@ -161,38 +188,33 @@ pub async fn new_user_message( Ok(output.into_response()) } -#[derive(Deserialize)] -pub struct QueryParams { - query_id: String, -} - -pub async fn get_response_stream( - State(_state): State, +pub async fn new_user_message( + Path(conversation_id): Path, + State(state): State, auth: AuthSession, Surreal>, - Query(params): Query, -) -> Sse>> { - let stream = stream::iter(vec![ - Event::default() - .event("chat_message") - .data("Hello, starting stream!"), - Event::default() - .event("chat_message") - .data("This is message 2"), - Event::default().event("chat_message").data("Final message"), - Event::default() - .event("close_stream") - .data("Stream complete"), // Signal to close - ]) - .then(|event| async move { - sleep(Duration::from_millis(500)).await; // Delay between messages - Ok(event) - }); + Form(form): Form, +) -> Result { + let _user = match auth.current_user { + Some(user) => user, + None => return Ok(Redirect::to("/").into_response()), + }; - info!("Streaming started"); + let user_message = Message::new(conversation_id, MessageRole::User, form.content, None); - Sse::new(stream).keep_alive( - axum::response::sse::KeepAlive::new() - .interval(Duration::from_secs(15)) - .text("keep-alive"), - ) + store_item(&state.surreal_db_client, user_message.clone()) + .await + .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; + + #[derive(Serialize)] + struct SSEResponseInitData { + user_message: Message, + } + + let output = render_template( + "chat/streaming_response.html", + SSEResponseInitData { user_message }, + state.templates.clone(), + )?; + + Ok(output.into_response()) } diff --git a/src/server/routes/mod.rs b/src/server/routes/mod.rs index 06b7ef7..5dd5bd4 100644 --- a/src/server/routes/mod.rs +++ b/src/server/routes/mod.rs @@ -16,7 +16,10 @@ use axum_session_surreal::SessionSurrealPool; use html::{ account::{delete_account, set_api_key, show_account_page, update_timezone}, admin_panel::{show_admin_panel, toggle_registration_status}, - chat::{get_response_stream, new_user_message, show_chat_base, show_initialized_chat}, + chat::{ + message_response_stream::get_response_stream, new_user_message, show_chat_base, + show_existing_chat, show_initialized_chat, + }, content::{patch_text_content, show_content_page, show_text_content_edit_form}, documentation::{ show_documentation_index, show_get_started, show_mobile_friendly, show_privacy_policy, @@ -65,7 +68,7 @@ pub fn html_routes(app_state: &AppState) -> Router { .route("/gdpr/deny", post(deny_gdpr)) .route("/search", get(search_result_handler)) .route("/chat", get(show_chat_base).post(show_initialized_chat)) - .route("/chat/:id", post(new_user_message)) + .route("/chat/:id", get(show_existing_chat).post(new_user_message)) .route("/chat/response-stream", get(get_response_stream)) .route("/signout", get(sign_out_user)) .route("/signin", get(show_signin_form).post(authenticate_user)) diff --git a/src/storage/types/conversation.rs b/src/storage/types/conversation.rs index 01ffe18..7a85494 100644 --- a/src/storage/types/conversation.rs +++ b/src/storage/types/conversation.rs @@ -1,6 +1,12 @@ use uuid::Uuid; -use crate::stored_object; +use crate::{ + error::AppError, + storage::db::{get_item, SurrealDbClient}, + stored_object, +}; + +use super::message::Message; stored_object!(Conversation, "conversation", { user_id: String, @@ -18,4 +24,29 @@ impl Conversation { title, } } + + pub async fn get_complete_conversation( + conversation_id: &str, + user_id: &str, + db: &SurrealDbClient, + ) -> Result<(Self, Vec), AppError> { + let conversation: Conversation = get_item(&db, conversation_id) + .await? + .ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?; + + if conversation.user_id != user_id { + return Err(AppError::Auth( + "You don't have access to this conversation".to_string(), + )); + } + + let messages:Vec = db.client. + query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at"). + bind(("table_name", Message::table_name())). + bind(("conversation_id", conversation_id.to_string())) + .await? + .take(0)?; + + Ok((conversation, messages)) + } } diff --git a/src/storage/types/message.rs b/src/storage/types/message.rs index 3cd929f..ca8afbf 100644 --- a/src/storage/types/message.rs +++ b/src/storage/types/message.rs @@ -2,7 +2,7 @@ use uuid::Uuid; use crate::stored_object; -#[derive(Deserialize, Debug, Serialize)] +#[derive(Deserialize, Debug, Clone, Serialize)] pub enum MessageRole { User, AI, diff --git a/src/storage/types/user.rs b/src/storage/types/user.rs index 9a9a8d6..f928452 100644 --- a/src/storage/types/user.rs +++ b/src/storage/types/user.rs @@ -8,8 +8,9 @@ use surrealdb::{engine::any::Any, Surreal}; use uuid::Uuid; use super::{ - knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, - system_settings::SystemSettings, text_content::TextContent, + conversation::Conversation, knowledge_entity::KnowledgeEntity, + knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings, + text_content::TextContent, }; #[derive(Deserialize)] @@ -333,4 +334,19 @@ impl User { Ok(text_content) } + + pub async fn get_user_conversations( + user_id: &str, + db: &SurrealDbClient, + ) -> Result, AppError> { + let conversations: Vec = db + .client + .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id") + .bind(("table_name", Conversation::table_name())) + .bind(("user_id", user_id.to_string())) + .await? + .take(0)?; + + Ok(conversations) + } } diff --git a/templates/chat/drawer.html b/templates/chat/drawer.html index 49bf54f..893674e 100644 --- a/templates/chat/drawer.html +++ b/templates/chat/drawer.html @@ -7,7 +7,10 @@ "icons/edit_icon.html" %}
-
  • Sidebar Item 1
  • -
  • Sidebar Item 2
  • + {% for conversation in conversation_archive %} +
  • {{conversation.title}} - {{conversation.created_at}} +
  • + {% endfor %} + {{conversation_archive}} \ No newline at end of file diff --git a/templates/chat/new_message_form.html b/templates/chat/new_message_form.html index 1250dd7..b8ed3c6 100644 --- a/templates/chat/new_message_form.html +++ b/templates/chat/new_message_form.html @@ -1,5 +1,5 @@
    -