#![allow(clippy::missing_docs_in_private_items)] use std::{pin::Pin, sync::Arc, time::Duration}; use async_stream::stream; use axum::{ extract::{Query, State}, response::{ sse::{Event, KeepAlive}, Sse, }, }; use futures::{ stream::{self, once}, Stream, StreamExt, TryStreamExt, }; use json_stream_parser::JsonStreamParser; use minijinja::Value; use retrieval_pipeline::{ answer_retrieval::{ chunks_to_chat_context, create_chat_request, create_user_message_with_history, LLMResponseFormat, }, retrieved_entities_to_json, }; use serde::{Deserialize, Serialize}; use serde_json::from_str; use tokio::sync::{mpsc::channel, Mutex}; use tracing::{debug, error, info}; use common::storage::{ db::SurrealDbClient, types::{ conversation::Conversation, message::{Message, MessageRole}, system_settings::SystemSettings, user::User, }, }; use crate::{html_state::HtmlState, AuthSessionType}; use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references}; // Error handling function fn create_error_stream( message: impl Into, ) -> Pin> + Send>> { let message = message.into(); stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed() } // Helper function to get message and user async fn get_message_and_user( db: &SurrealDbClient, current_user: Option, message_id: &str, ) -> Result< (Message, User, Conversation, Vec, Option), Sse> + Send>>>, > { // Check authentication let Some(user) = current_user else { return Err(Sse::new(create_error_stream( "You must be signed in to use this feature", ))); }; // Retrieve message let message = match db.get_item::(message_id).await { Ok(Some(message)) => message, Ok(None) => { return Err(Sse::new(create_error_stream( "Message not found: the specified message does not exist", ))) } Err(e) => { error!("Database error retrieving message {}: {:?}", message_id, e); return Err(Sse::new(create_error_stream( "Failed to retrieve message: database error", ))); } }; // Get conversation history let (conversation, history) = match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await { Err(e) => { error!("Database error retrieving message {}: {:?}", message_id, e); return Err(Sse::new(create_error_stream( "Failed to retrieve message: database error", ))); } Ok((conversation, history)) => (conversation, history), }; let Some(message_index) = find_message_index(&history, message_id) else { return Err(Sse::new(create_error_stream( "Message not found in conversation history", ))); }; let Some(message_from_history) = history.get(message_index) else { return Err(Sse::new(create_error_stream( "Message not found in conversation history", ))); }; if message_from_history.role != MessageRole::User { return Err(Sse::new(create_error_stream( "Only user messages can be used to generate a response", ))); } let message = message_from_history.clone(); let history_before_message = history_before_message(&history, message_index); let existing_ai_response = find_existing_ai_response(&history, message_index); Ok(( message, user, conversation, history_before_message, existing_ai_response, )) } fn find_message_index(messages: &[Message], message_id: &str) -> Option { messages.iter().position(|message| message.id == message_id) } fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option { messages .iter() .skip(user_message_index + 1) .take_while(|message| message.role != MessageRole::User) .find(|message| message.role == MessageRole::AI) .cloned() } fn history_before_message(messages: &[Message], message_index: usize) -> Vec { messages.iter().take(message_index).cloned().collect() } fn create_replayed_response_stream( state: &HtmlState, existing_ai_message: Message, ) -> Sse> + Send>>> { let references_event = if existing_ai_message .references .as_ref() .is_some_and(|references| !references.is_empty()) { state .templates .render( "chat/reference_list.html", &Value::from_serialize(ReferenceData { message: existing_ai_message.clone(), }), ) .ok() .map(|html| Event::default().event("references").data(html)) } else { None }; let answer = existing_ai_message.content; let event_stream = stream! { yield Ok(Event::default().event("chat_message").data(answer)); if let Some(event) = references_event { yield Ok(event); } yield Ok(Event::default().event("close_stream").data("Stream complete")); }; Sse::new(event_stream.boxed()).keep_alive( KeepAlive::new() .interval(Duration::from_secs(15)) .text("keep-alive"), ) } #[derive(Deserialize)] pub struct QueryParams { message_id: String, } #[derive(Serialize)] struct ReferenceData { message: Message, } fn extract_reference_strings(response: &LLMResponseFormat) -> Vec { response .references .iter() .map(|reference| reference.reference.clone()) .collect() } #[allow(clippy::too_many_lines)] pub async fn get_response_stream( State(state): State, auth: AuthSessionType, Query(params): Query, ) -> Sse> + Send>>> { // 1. Authentication and initial data validation let (user_message, user, _conversation, history, existing_ai_response) = match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await { Ok((user_message, user, conversation, history, existing_ai_response)) => ( user_message, user, conversation, history, existing_ai_response, ), Err(error_stream) => return error_stream, }; if let Some(existing_ai_message) = existing_ai_response { return create_replayed_response_stream(&state, existing_ai_message); } // 2. Retrieve knowledge entities let rerank_lease = match state.reranker_pool.as_ref() { Some(pool) => Some(pool.checkout().await), None => None, }; let strategy = state.retrieval_strategy(); let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy); let retrieval_result = match retrieval_pipeline::retrieve_entities( &state.db, &state.openai_client, Some(&*state.embedding_provider), &user_message.content, &user.id, config, rerank_lease, ) .await { Ok(result) => result, Err(_e) => { return Sse::new(create_error_stream("Failed to retrieve knowledge")); } }; let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result); // 3. Create the OpenAI request with appropriate context format let context_json = match &retrieval_result { retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks), retrieval_pipeline::StrategyOutput::Entities(entities) => { retrieved_entities_to_json(entities) } retrieval_pipeline::StrategyOutput::Search(search_result) => { // For chat, use chunks from the search result chunks_to_chat_context(&search_result.chunks) } }; let formatted_user_message = create_user_message_with_history(&context_json, &history, &user_message.content); let Ok(settings) = SystemSettings::get_current(&state.db).await else { return Sse::new(create_error_stream("Failed to retrieve system settings")); }; let Ok(request) = create_chat_request(formatted_user_message, &settings) else { return Sse::new(create_error_stream("Failed to create chat request")); }; // 4. Set up the OpenAI stream let openai_stream = match state.openai_client.chat().create_stream(request).await { Ok(stream) => stream, Err(_e) => { return Sse::new(create_error_stream("Failed to create OpenAI stream")); } }; // 5. Create channel for collecting complete response let (tx, mut rx) = channel::(1000); let tx_clone = tx.clone(); let (tx_final, mut rx_final) = channel::(1); // 6. Set up the collection task for DB storage let db_client = Arc::clone(&state.db); let user_id = user.id.clone(); let allowed_reference_ids = allowed_reference_ids.clone(); tokio::spawn(async move { drop(tx); // Close sender when no longer needed // Collect full response let mut full_json = String::new(); while let Some(chunk) = rx.recv().await { full_json.push_str(&chunk); } // Try to extract structured data if let Ok(response) = from_str::(&full_json) { let raw_references = extract_reference_strings(&response); let answer = response.answer; let initial_validation = match validate_references( &user_id, raw_references, &allowed_reference_ids, &db_client, ) .await { Ok(result) => result, Err(err) => { error!(error = %err, "Reference validation failed, storing answer without references"); let ai_message = Message::new( user_message.conversation_id, MessageRole::AI, answer, Some(Vec::new()), ); let _ = tx_final.send(ai_message.clone()).await; if let Err(store_err) = db_client.store_item(ai_message).await { error!(error = ?store_err, "Failed to store AI message after validation failure"); } return; } }; info!( total_refs = initial_validation.reason_stats.total, valid_refs = initial_validation.valid_refs.len(), invalid_refs = initial_validation.invalid_refs.len(), invalid_empty = initial_validation.reason_stats.empty, invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix, invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid, invalid_duplicate = initial_validation.reason_stats.duplicate, invalid_not_in_context = initial_validation.reason_stats.not_in_context, invalid_not_found = initial_validation.reason_stats.not_found, invalid_wrong_user = initial_validation.reason_stats.wrong_user, invalid_over_limit = initial_validation.reason_stats.over_limit, "Post-LLM reference validation complete" ); let ai_message = Message::new( user_message.conversation_id, MessageRole::AI, answer, Some(initial_validation.valid_refs), ); let _ = tx_final.send(ai_message.clone()).await; match db_client.store_item(ai_message).await { Ok(_) => debug!("Successfully stored AI message with references"), Err(e) => error!("Failed to store AI message: {:?}", e), } } else { error!("Failed to parse LLM response as structured format"); // Fallback - store raw response let ai_message = Message::new( user_message.conversation_id, MessageRole::AI, full_json, None, ); let _ = db_client.store_item(ai_message).await; } }); // Create a shared state for tracking the JSON parsing let json_state = Arc::new(Mutex::new(StreamParserState::new())); // 7. Create the response event stream let event_stream = openai_stream .map_err(|e| Box::new(e) as Box) .map(move |result| { let tx_storage = tx_clone.clone(); let json_state = Arc::clone(&json_state); stream! { match result { Ok(response) => { let content = response .choices .first() .and_then(|choice| choice.delta.content.clone()) .unwrap_or_default(); if !content.is_empty() { // Always send raw content to storage let _ = tx_storage.send(content.clone()).await; // Process through JSON parser let mut state = json_state.lock().await; let display_content = state.process_chunk(&content); drop(state); if !display_content.is_empty() { yield Ok(Event::default() .event("chat_message") .data(display_content)); } // If display_content is empty, don't yield anything } // If content is empty, don't yield anything } Err(e) => { yield Ok(Event::default() .event("error") .data(format!("Stream error: {e}"))); } } } }) .flatten() .chain(stream::once(async move { if let Some(message) = rx_final.recv().await { // Don't send any event if references is empty if message .references .as_ref() .is_some_and(std::vec::Vec::is_empty) { return Ok(Event::default().event("empty")); // This event won't be sent } // Render template with references match state.templates.render( "chat/reference_list.html", &Value::from_serialize(ReferenceData { message }), ) { Ok(html) => { // Return the rendered HTML Ok(Event::default().event("references").data(html)) } Err(_) => { // Handle template rendering error Ok(Event::default() .event("error") .data("Failed to render references")) } } } else { // Handle case where no references were received Ok(Event::default() .event("error") .data("Failed to retrieve references")) } })) .chain(once(async { Ok(Event::default() .event("close_stream") .data("Stream complete")) })); Sse::new(event_stream.boxed()).keep_alive( KeepAlive::new() .interval(Duration::from_secs(15)) .text("keep-alive"), ) } struct StreamParserState { parser: JsonStreamParser, last_answer_content: String, in_answer_field: bool, } impl StreamParserState { fn new() -> Self { Self { parser: JsonStreamParser::new(), last_answer_content: String::new(), in_answer_field: false, } } fn process_chunk(&mut self, chunk: &str) -> String { // Feed all characters into the parser for c in chunk.chars() { let _ = self.parser.add_char(c); } // Get the current state of the JSON let json = self.parser.get_result(); // Check if we're in the answer field if let Some(obj) = json.as_object() { if let Some(answer) = obj.get("answer") { self.in_answer_field = true; // Get current answer content let current_content = answer.as_str().unwrap_or_default().to_string(); // Calculate difference to send only new content if current_content.len() > self.last_answer_content.len() { let new_content = current_content[self.last_answer_content.len()..].to_string(); self.last_answer_content = current_content; return new_content; } } } // No new content to return String::new() } } #[cfg(test)] mod tests { use super::*; use chrono::{Duration as ChronoDuration, Utc}; use common::storage::{ db::SurrealDbClient, types::{conversation::Conversation, user::Theme}, }; use retrieval_pipeline::answer_retrieval::Reference; use uuid::Uuid; fn make_test_message(id: &str, role: MessageRole) -> Message { let mut message = Message::new( "conversation-1".to_string(), role, format!("content-{id}"), None, ); message.id = id.to_string(); message } fn make_test_user(id: &str) -> User { User { id: id.to_string(), created_at: Utc::now(), updated_at: Utc::now(), email: "test@example.com".to_string(), password: "password".to_string(), anonymous: false, api_key: None, admin: false, timezone: "UTC".to_string(), theme: Theme::System, } } #[test] fn extracts_reference_strings_in_order() { let response = LLMResponseFormat { answer: "answer".to_string(), references: vec![ Reference { reference: "a".to_string(), }, Reference { reference: "b".to_string(), }, ], }; let extracted = extract_reference_strings(&response); assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]); } #[test] fn finds_message_index_for_existing_message() { let messages = vec![ make_test_message("m1", MessageRole::User), make_test_message("m2", MessageRole::AI), make_test_message("m3", MessageRole::User), ]; assert_eq!(find_message_index(&messages, "m2"), Some(1)); assert_eq!(find_message_index(&messages, "missing"), None); } #[test] fn finds_existing_ai_response_for_same_turn() { let messages = vec![ make_test_message("u1", MessageRole::User), make_test_message("system", MessageRole::System), make_test_message("a1", MessageRole::AI), make_test_message("u2", MessageRole::User), make_test_message("a2", MessageRole::AI), ]; let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response"); assert_eq!(ai_reply.id, "a1"); let ai_reply_second_turn = find_existing_ai_response(&messages, 3).expect("expected AI response"); assert_eq!(ai_reply_second_turn.id, "a2"); } #[test] fn does_not_replay_ai_response_from_later_turn() { let messages = vec![ make_test_message("u1", MessageRole::User), make_test_message("u2", MessageRole::User), make_test_message("a2", MessageRole::AI), ]; assert!(find_existing_ai_response(&messages, 0).is_none()); let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response"); assert_eq!(ai_reply.id, "a2"); } #[test] fn history_before_message_excludes_target_and_future_messages() { let messages = vec![ make_test_message("u1", MessageRole::User), make_test_message("a1", MessageRole::AI), make_test_message("u2", MessageRole::User), make_test_message("a2", MessageRole::AI), ]; let history_for_u2 = history_before_message(&messages, 2); let history_ids: Vec = history_for_u2 .into_iter() .map(|message| message.id) .collect(); assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]); } #[tokio::test] async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() { let namespace = "chat_stream_replay"; let database = Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, &database) .await .expect("failed to create in-memory db"); let user = make_test_user("user-1"); let conversation = Conversation::new(user.id.clone(), "Conversation".to_string()); let mut user_message = Message::new( conversation.id.clone(), MessageRole::User, "Question one".to_string(), None, ); user_message.id = "u1".to_string(); let mut ai_message = Message::new( conversation.id.clone(), MessageRole::AI, "Answer one".to_string(), Some(vec!["ref-1".to_string()]), ); ai_message.id = "a1".to_string(); ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1); ai_message.updated_at = ai_message.created_at; let mut second_user_message = Message::new( conversation.id.clone(), MessageRole::User, "Question two".to_string(), None, ); second_user_message.id = "u2".to_string(); second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1); second_user_message.updated_at = second_user_message.created_at; db.store_item(conversation.clone()) .await .expect("failed to store conversation"); db.store_item(user_message.clone()) .await .expect("failed to store user message"); db.store_item(ai_message.clone()) .await .expect("failed to store ai message"); db.store_item(second_user_message.clone()) .await .expect("failed to store second user message"); let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) = get_message_and_user(&db, Some(user.clone()), &user_message.id) .await .expect("expected first turn to load"); assert!(history_for_first_turn.is_empty()); let existing_ai_for_first_turn = existing_ai_for_first_turn.expect("expected first-turn AI response"); assert_eq!(existing_ai_for_first_turn.id, ai_message.id); let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) = get_message_and_user(&db, Some(user), &second_user_message.id) .await .expect("expected second turn to load"); let history_ids: Vec = history_for_second_turn .into_iter() .map(|message| message.id) .collect(); assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]); assert!(existing_ai_for_second_turn.is_none()); } }