Mirror of https://github.com/perstarkse/minne.git, synced 2026-03-31 14:43:20 +02:00
refactor: async-stream and improved reference handling
@@ -1,15 +1,24 @@
-use std::{pin::Pin, time::Duration};
+use std::{pin::Pin, sync::Arc, time::Duration};
 
+use async_stream::stream;
 use axum::{
     extract::{Query, State},
-    response::{sse::Event, Sse},
+    response::{
+        sse::{Event, KeepAlive},
+        Sse,
+    },
 };
 use axum_session_auth::AuthSession;
 use axum_session_surreal::SessionSurrealPool;
-use futures::{stream, Stream, StreamExt, TryStreamExt};
+use futures::{
+    stream::{self, once},
+    Stream, StreamExt, TryStreamExt,
+};
 use json_stream_parser::JsonStreamParser;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
+use serde_json::from_str;
 use surrealdb::{engine::any::Any, Surreal};
+use tokio::sync::{mpsc::channel, Mutex};
 use tracing::{error, info};
 
 use crate::{
@@ -19,7 +28,7 @@ use crate::{
             create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
         },
     },
-    server::AppState,
+    server::{routes::html::render_template, AppState},
     storage::{
         db::{get_item, store_item, SurrealDbClient},
         types::{
@@ -29,6 +38,7 @@ use crate::{
     },
 };
 
+// Error handling function
 fn create_error_stream(
     message: impl Into<String>,
 ) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
@@ -127,8 +137,9 @@ pub async fn get_response_stream(
     };
 
     // 5. Create channel for collecting complete response
-    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(1000);
+    let (tx, mut rx) = channel::<String>(1000);
     let tx_clone = tx.clone();
+    let (tx_final, mut rx_final) = channel::<Vec<String>>(1);
 
     // 6. Set up the collection task for DB storage
     let db_client = state.surreal_db_client.clone();
@@ -142,13 +153,15 @@ pub async fn get_response_stream(
         }
 
         // Try to extract structured data
-        if let Ok(response) = serde_json::from_str::<LLMResponseFormat>(&full_json) {
+        if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
             let references: Vec<String> = response
                 .references
                 .into_iter()
                 .map(|r| r.reference)
                 .collect();
 
+            let _ = tx_final.send(references.clone()).await;
+
             let ai_message = Message::new(
                 user_message.conversation_id,
                 MessageRole::AI,
@@ -176,7 +189,7 @@ pub async fn get_response_stream(
     });
 
     // Create a shared state for tracking the JSON parsing
-    let json_state = std::sync::Arc::new(tokio::sync::Mutex::new(StreamParserState::new()));
+    let json_state = Arc::new(Mutex::new(StreamParserState::new()));
 
     // 7. Create the response event stream
     let event_stream = openai_stream
@@ -185,7 +198,7 @@ pub async fn get_response_stream(
             let tx_storage = tx_clone.clone();
             let json_state = json_state.clone();
 
-            async move {
+            stream! {
                 match result {
                     Ok(response) => {
                         let content = response
@@ -200,29 +213,71 @@ pub async fn get_response_stream(
 
                            // Process through JSON parser
                            let mut state = json_state.lock().await;
 
                            let display_content = state.process_chunk(&content);
+                            drop(state);
                            if !display_content.is_empty() {
-                                return Ok(Event::default()
+                                yield Ok(Event::default()
                                    .event("chat_message")
                                    .data(display_content));
                            }
-
-                            // Empty or filtered content
-                            Ok(Event::default().event("chat_message").data(""))
-                        } else {
-                            Ok(Event::default().event("chat_message").data(""))
+                            // If display_content is empty, don't yield anything
                        }
+                        // If content is empty, don't yield anything
                    }
-                Err(e) => Ok(Event::default()
-                    .event("error")
-                    .data(format!("Stream error: {}", e))),
+                Err(e) => {
+                    yield Ok(Event::default()
+                        .event("error")
+                        .data(format!("Stream error: {}", e)));
+                }
            }
        }
    })
-    .buffered(10)
-    .chain(stream::once(async {
+    .flatten()
+    .chain(stream::once(async move {
+        if let Some(references) = rx_final.recv().await {
+            // Don't send any event if references is empty
+            if references.is_empty() {
+                return Ok(Event::default().event("empty")); // This event won't be sent
+            }
+
+            // Prepare data for template
+            #[derive(Serialize)]
+            struct ReferenceData {
+                references: Vec<String>,
+                user_message_id: String,
+            }
+
+            // Render template with references
+            match render_template(
+                "chat/reference_list.html",
+                ReferenceData {
+                    references,
+                    user_message_id: user_message.id,
+                },
+                state.templates.clone(),
+            ) {
+                Ok(html) => {
+                    // Extract the String from Html<String>
+                    let html_string = html.0; // Convert Html<String> to String
+
+                    // Return the rendered HTML
+                    Ok(Event::default().event("references").data(html_string))
+                }
+                Err(_) => {
+                    // Handle template rendering error
+                    Ok(Event::default()
+                        .event("error")
+                        .data("Failed to render references"))
+                }
+            }
+        } else {
+            // Handle case where no references were received
+            Ok(Event::default()
+                .event("error")
+                .data("Failed to retrieve references"))
+        }
+    }))
+    .chain(once(async {
        Ok(Event::default()
            .event("close_stream")
            .data("Stream complete"))
@@ -230,7 +285,7 @@ pub async fn get_response_stream(
 
    info!("OpenAI streaming started");
    Sse::new(event_stream.boxed()).keep_alive(
-        axum::response::sse::KeepAlive::new()
+        KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("keep-alive"),
    )
@@ -259,7 +314,6 @@ impl StreamParserState {
        }
 
        // Get the current state of the JSON
-        // The get_result() method returns a &Value, not a Result
        let json = self.parser.get_result();
 
        // Check if we're in the answer field
@@ -283,46 +337,3 @@ impl StreamParserState {
        String::new()
    }
 }
-
-// 7. Create the response event stream
-// let event_stream = openai_stream
-// .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
-// .map(move |result| {
-// let tx = tx_clone.clone();
-// async move {
-// match result {
-// Ok(response) => {
-// let content = response
-// .choices
-// .first()
-// .and_then(|choice| choice.delta.content.clone())
-// .unwrap_or_default();
-
-// if !content.is_empty() {
-// let _ = tx.send(content.clone()).await;
-// Ok(Event::default().event("chat_message").data(content))
-// } else {
-// Ok(Event::default().event("chat_message").data(""))
-// }
-// }
-// Err(e) => Ok(Event::default()
-// .event("error")
-// .data(format!("Stream error: {}", e))),
-// }
-// }
-// })
-// .buffered(10)
-// .chain(stream::once(async {
-// Ok(Event::default()
-// .event("close_stream")
-// .data("Stream complete"))
-// }));
-
-// info!("OpenAI streaming started");
-
-// Sse::new(event_stream.boxed()).keep_alive(
-// axum::response::sse::KeepAlive::new()
-// .interval(Duration::from_secs(15))
-// .text("keep-alive"),
-// )
-// }
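Note: the following is a minimal, self-contained sketch (not code from this repository) of the streaming pattern the commit moves to: each upstream chunk is expanded into its own async_stream::stream! block so it can yield zero or more SSE events, the per-chunk streams are flattened back into a single stream, terminal events are chained on with stream::once, and the result is served with a keep-alive. The function name sse_from_chunks, the plain-String chunk type, and the Infallible error type are simplifying assumptions; the real handler additionally copies chunks into a channel for DB storage and renders a reference template at the end.

use std::{convert::Infallible, time::Duration};

use async_stream::stream;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::{stream::once, Stream, StreamExt};

// Hypothetical upstream: a stream of text chunks, e.g. deltas from an LLM client.
fn sse_from_chunks(
    chunks: impl Stream<Item = Result<String, String>> + Send + 'static,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let events = chunks
        .map(|result| {
            // `stream!` lets one input chunk produce zero, one, or many events;
            // `.flatten()` splices the per-chunk streams back into a single one.
            // This replaces an `async move { .. }` + `.buffered(10)` shape that
            // had to emit an empty event even when there was nothing to send.
            stream! {
                match result {
                    Ok(content) if !content.is_empty() => {
                        yield Ok::<_, Infallible>(
                            Event::default().event("chat_message").data(content),
                        );
                    }
                    Ok(_) => {} // empty chunk: yield nothing at all
                    Err(e) => {
                        yield Ok(Event::default()
                            .event("error")
                            .data(format!("Stream error: {}", e)));
                    }
                }
            }
        })
        .flatten()
        // Terminal events are appended with `chain(once(..))`, mirroring the
        // "references" and "close_stream" events in the real handler.
        .chain(once(async {
            Ok(Event::default().event("close_stream").data("Stream complete"))
        }));

    Sse::new(events).keep_alive(
        KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("keep-alive"),
    )
}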
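The reference handling added here also relies on a two-channel hand-off: every streamed chunk is copied into an mpsc channel, a spawned task accumulates and parses the full JSON once the senders are dropped, and a bounded channel of size 1 delivers the extracted references to the tail of the SSE stream. Below is a simplified sketch of that hand-off; the collect_references helper and the fake parse step (standing in for deserializing LLMResponseFormat) are illustrative assumptions, not repository code.

use tokio::sync::mpsc;

// Hypothetical payload: the references extracted from the final JSON answer.
type References = Vec<String>;

// Spawn a collector that drains the per-chunk channel and reports once at the
// end, mirroring the tx/rx and tx_final/rx_final split in the handler.
fn collect_references(mut rx: mpsc::Receiver<String>, tx_final: mpsc::Sender<References>) {
    tokio::spawn(async move {
        let mut full_json = String::new();
        while let Some(chunk) = rx.recv().await {
            full_json.push_str(&chunk);
        }
        // Illustrative parse: the real code deserializes the accumulated JSON
        // into a response struct and maps its `references` field.
        let references: References = full_json.lines().map(|l| l.to_string()).collect();
        let _ = tx_final.send(references).await;
    });
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<String>(1000);
    let (tx_final, mut rx_final) = mpsc::channel::<References>(1);
    collect_references(rx, tx_final);

    // Producer side: in the handler, each streamed chunk is also sent here.
    for chunk in ["ref-1\n", "ref-2\n"] {
        let _ = tx.send(chunk.to_string()).await;
    }
    drop(tx); // dropping the last sender lets the collector finish

    // Consumer side: the final chained SSE event awaits exactly one message.
    if let Some(references) = rx_final.recv().await {
        println!("extracted references: {:?}", references);
    }
}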