refactor: async-stream and improved reference handling

This commit is contained in:
Per Stark
2025-02-27 13:49:45 +01:00
parent 4ce272d5be
commit 21f0ebef33
7 changed files with 254 additions and 194 deletions

View File

@@ -1,15 +1,24 @@
use std::{pin::Pin, time::Duration};
use std::{pin::Pin, sync::Arc, time::Duration};
use async_stream::stream;
use axum::{
extract::{Query, State},
response::{sse::Event, Sse},
response::{
sse::{Event, KeepAlive},
Sse,
},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use futures::{stream, Stream, StreamExt, TryStreamExt};
use futures::{
stream::{self, once},
Stream, StreamExt, TryStreamExt,
};
use json_stream_parser::JsonStreamParser;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use surrealdb::{engine::any::Any, Surreal};
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{error, info};
use crate::{
@@ -19,7 +28,7 @@ use crate::{
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
},
},
server::AppState,
server::{routes::html::render_template, AppState},
storage::{
db::{get_item, store_item, SurrealDbClient},
types::{
@@ -29,6 +38,7 @@ use crate::{
},
};
// Error handling function
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
@@ -127,8 +137,9 @@ pub async fn get_response_stream(
};
// 5. Create channel for collecting complete response
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(1000);
let (tx, mut rx) = channel::<String>(1000);
let tx_clone = tx.clone();
let (tx_final, mut rx_final) = channel::<Vec<String>>(1);
// 6. Set up the collection task for DB storage
let db_client = state.surreal_db_client.clone();
@@ -142,13 +153,15 @@ pub async fn get_response_stream(
}
// Try to extract structured data
if let Ok(response) = serde_json::from_str::<LLMResponseFormat>(&full_json) {
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let references: Vec<String> = response
.references
.into_iter()
.map(|r| r.reference)
.collect();
let _ = tx_final.send(references.clone()).await;
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
@@ -176,7 +189,7 @@ pub async fn get_response_stream(
});
// Create a shared state for tracking the JSON parsing
let json_state = std::sync::Arc::new(tokio::sync::Mutex::new(StreamParserState::new()));
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
@@ -185,7 +198,7 @@ pub async fn get_response_stream(
let tx_storage = tx_clone.clone();
let json_state = json_state.clone();
async move {
stream! {
match result {
Ok(response) => {
let content = response
@@ -200,29 +213,71 @@ pub async fn get_response_stream(
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
return Ok(Event::default()
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// Empty or filtered content
Ok(Event::default().event("chat_message").data(""))
} else {
Ok(Event::default().event("chat_message").data(""))
// If display_content is empty, don't yield anything
}
// If content is empty, don't yield anything
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {}", e)));
}
Err(e) => Ok(Event::default()
.event("error")
.data(format!("Stream error: {}", e))),
}
}
})
.buffered(10)
.chain(stream::once(async {
.flatten()
.chain(stream::once(async move {
if let Some(references) = rx_final.recv().await {
// Don't send any event if references is empty
if references.is_empty() {
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Prepare data for template
#[derive(Serialize)]
struct ReferenceData {
references: Vec<String>,
user_message_id: String,
}
// Render template with references
match render_template(
"chat/reference_list.html",
ReferenceData {
references,
user_message_id: user_message.id,
},
state.templates.clone(),
) {
Ok(html) => {
// Extract the String from Html<String>
let html_string = html.0; // Convert Html<String> to String
// Return the rendered HTML
Ok(Event::default().event("references").data(html_string))
}
Err(_) => {
// Handle template rendering error
Ok(Event::default()
.event("error")
.data("Failed to render references"))
}
}
} else {
// Handle case where no references were received
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
@@ -230,7 +285,7 @@ pub async fn get_response_stream(
info!("OpenAI streaming started");
Sse::new(event_stream.boxed()).keep_alive(
axum::response::sse::KeepAlive::new()
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
@@ -259,7 +314,6 @@ impl StreamParserState {
}
// Get the current state of the JSON
// The get_result() method returns a &Value, not a Result
let json = self.parser.get_result();
// Check if we're in the answer field
@@ -283,46 +337,3 @@ impl StreamParserState {
String::new()
}
}
// 7. Create the response event stream
// let event_stream = openai_stream
// .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
// .map(move |result| {
// let tx = tx_clone.clone();
// async move {
// match result {
// Ok(response) => {
// let content = response
// .choices
// .first()
// .and_then(|choice| choice.delta.content.clone())
// .unwrap_or_default();
// if !content.is_empty() {
// let _ = tx.send(content.clone()).await;
// Ok(Event::default().event("chat_message").data(content))
// } else {
// Ok(Event::default().event("chat_message").data(""))
// }
// }
// Err(e) => Ok(Event::default()
// .event("error")
// .data(format!("Stream error: {}", e))),
// }
// }
// })
// .buffered(10)
// .chain(stream::once(async {
// Ok(Event::default()
// .event("close_stream")
// .data("Stream complete"))
// }));
// info!("OpenAI streaming started");
// Sse::new(event_stream.boxed()).keep_alive(
// axum::response::sse::KeepAlive::new()
// .interval(Duration::from_secs(15))
// .text("keep-alive"),
// )
// }