mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-20 16:44:12 +01:00
feat: streaming chat response and persistant chats
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -2427,6 +2427,15 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json-stream-parser"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a70ab2b05e827e0604229fcf11b24560b036a21286a41517a6cac271f12a6a9"
|
||||
dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.4.1"
|
||||
@@ -6042,6 +6051,7 @@ dependencies = [
|
||||
"chrono-tz",
|
||||
"config",
|
||||
"futures",
|
||||
"json-stream-parser",
|
||||
"lettre",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
|
||||
@@ -16,6 +16,7 @@ chrono = { version = "0.4.39", features = ["serde"] }
|
||||
chrono-tz = "0.10.1"
|
||||
config = "0.15.4"
|
||||
futures = "0.3.31"
|
||||
json-stream-parser = "0.1.4"
|
||||
lettre = { version = "0.11.11", features = ["rustls-tls"] }
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
|
||||
@@ -2541,6 +2541,21 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.z-0 {
|
||||
z-index: 0;
|
||||
}
|
||||
.z-1 {
|
||||
z-index: 1;
|
||||
}
|
||||
.z-2 {
|
||||
z-index: 2;
|
||||
}
|
||||
.z-3 {
|
||||
z-index: 3;
|
||||
}
|
||||
.z-5 {
|
||||
z-index: 5;
|
||||
}
|
||||
.z-20 {
|
||||
z-index: 20;
|
||||
}
|
||||
@@ -4074,6 +4089,9 @@
|
||||
width: calc(var(--spacing) * 10);
|
||||
height: calc(var(--spacing) * 10);
|
||||
}
|
||||
.h-4 {
|
||||
height: calc(var(--spacing) * 4);
|
||||
}
|
||||
.h-5 {
|
||||
height: calc(var(--spacing) * 5);
|
||||
}
|
||||
@@ -4433,6 +4451,9 @@
|
||||
.p-2 {
|
||||
padding: calc(var(--spacing) * 2);
|
||||
}
|
||||
.p-3 {
|
||||
padding: calc(var(--spacing) * 3);
|
||||
}
|
||||
.p-4 {
|
||||
padding: calc(var(--spacing) * 4);
|
||||
}
|
||||
|
||||
328
src/server/routes/html/chat/message_response_stream.rs
Normal file
328
src/server/routes/html/chat/message_response_stream.rs
Normal file
@@ -0,0 +1,328 @@
|
||||
use std::{pin::Pin, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::{sse::Event, Sse},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::{stream, Stream, StreamExt, TryStreamExt};
|
||||
use json_stream_parser::JsonStreamParser;
|
||||
use serde::Deserialize;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{
|
||||
retrieval::{
|
||||
combined_knowledge_entity_retrieval,
|
||||
query_helper::{
|
||||
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
|
||||
},
|
||||
},
|
||||
server::AppState,
|
||||
storage::{
|
||||
db::{get_item, store_item, SurrealDbClient},
|
||||
types::{
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
|
||||
// Helper function to get message and user
|
||||
async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<(Message, User), Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>> {
|
||||
// Check authentication
|
||||
let user = match current_user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
let message = match get_item::<Message>(db, message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Message not found: the specified message does not exist",
|
||||
)))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error retrieving message {}: {:?}", message_id, e);
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Failed to retrieve message: database error",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok((message, user))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user) = match get_message_and_user(
|
||||
&state.surreal_db_client,
|
||||
auth.current_user,
|
||||
¶ms.message_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok((user_message, user)) => (user_message, user),
|
||||
Err(error_stream) => return error_stream,
|
||||
};
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let entities = match combined_knowledge_entity_retrieval(
|
||||
&state.surreal_db_client,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(entities) => entities,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||
}
|
||||
};
|
||||
|
||||
// 3. Create the OpenAI request
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let formatted_user_message = create_user_message(&entities_json, &user_message.content);
|
||||
let request = match create_chat_request(formatted_user_message) {
|
||||
Ok(req) => req,
|
||||
Err(..) => {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
|
||||
// 5. Create channel for collecting complete response
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(1000);
|
||||
let tx_clone = tx.clone();
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = state.surreal_db_client.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
// Collect full response
|
||||
let mut full_json = String::new();
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
full_json.push_str(&chunk);
|
||||
}
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = serde_json::from_str::<LLMResponseFormat>(&full_json) {
|
||||
let references: Vec<String> = response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect();
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
response.answer,
|
||||
Some(references),
|
||||
);
|
||||
|
||||
match store_item(&db_client, ai_message).await {
|
||||
Ok(_) => info!("Successfully stored AI message with references"),
|
||||
Err(e) => error!("Failed to store AI message: {:?}", e),
|
||||
}
|
||||
} else {
|
||||
error!("Failed to parse LLM response as structured format");
|
||||
|
||||
// Fallback - store raw response
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
full_json,
|
||||
Some(vec![]),
|
||||
);
|
||||
|
||||
let _ = store_item(&db_client, ai_message).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create a shared state for tracking the JSON parsing
|
||||
let json_state = std::sync::Arc::new(tokio::sync::Mutex::new(StreamParserState::new()));
|
||||
|
||||
// 7. Create the response event stream
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = json_state.clone();
|
||||
|
||||
async move {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
// Always send raw content to storage
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
// Process through JSON parser
|
||||
let mut state = json_state.lock().await;
|
||||
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
return Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
|
||||
// Empty or filtered content
|
||||
Ok(Event::default().event("chat_message").data(""))
|
||||
} else {
|
||||
Ok(Event::default().event("chat_message").data(""))
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {}", e))),
|
||||
}
|
||||
}
|
||||
})
|
||||
.buffered(10)
|
||||
.chain(stream::once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}));
|
||||
|
||||
info!("OpenAI streaming started");
|
||||
Sse::new(event_stream.boxed()).keep_alive(
|
||||
axum::response::sse::KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
}
|
||||
|
||||
// Replace JsonParseState with StreamParserState
|
||||
struct StreamParserState {
|
||||
parser: JsonStreamParser,
|
||||
last_answer_content: String,
|
||||
in_answer_field: bool,
|
||||
}
|
||||
|
||||
impl StreamParserState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
parser: JsonStreamParser::new(),
|
||||
last_answer_content: String::new(),
|
||||
in_answer_field: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn process_chunk(&mut self, chunk: &str) -> String {
|
||||
// Feed all characters into the parser
|
||||
for c in chunk.chars() {
|
||||
let _ = self.parser.add_char(c);
|
||||
}
|
||||
|
||||
// Get the current state of the JSON
|
||||
// The get_result() method returns a &Value, not a Result
|
||||
let json = self.parser.get_result();
|
||||
|
||||
// Check if we're in the answer field
|
||||
if let Some(obj) = json.as_object() {
|
||||
if let Some(answer) = obj.get("answer") {
|
||||
self.in_answer_field = true;
|
||||
|
||||
// Get current answer content
|
||||
let current_content = answer.as_str().unwrap_or_default().to_string();
|
||||
|
||||
// Calculate difference to send only new content
|
||||
if current_content.len() > self.last_answer_content.len() {
|
||||
let new_content = current_content[self.last_answer_content.len()..].to_string();
|
||||
self.last_answer_content = current_content;
|
||||
return new_content;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No new content to return
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Create the response event stream
|
||||
// let event_stream = openai_stream
|
||||
// .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
// .map(move |result| {
|
||||
// let tx = tx_clone.clone();
|
||||
// async move {
|
||||
// match result {
|
||||
// Ok(response) => {
|
||||
// let content = response
|
||||
// .choices
|
||||
// .first()
|
||||
// .and_then(|choice| choice.delta.content.clone())
|
||||
// .unwrap_or_default();
|
||||
|
||||
// if !content.is_empty() {
|
||||
// let _ = tx.send(content.clone()).await;
|
||||
// Ok(Event::default().event("chat_message").data(content))
|
||||
// } else {
|
||||
// Ok(Event::default().event("chat_message").data(""))
|
||||
// }
|
||||
// }
|
||||
// Err(e) => Ok(Event::default()
|
||||
// .event("error")
|
||||
// .data(format!("Stream error: {}", e))),
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
// .buffered(10)
|
||||
// .chain(stream::once(async {
|
||||
// Ok(Event::default()
|
||||
// .event("close_stream")
|
||||
// .data("Stream complete"))
|
||||
// }));
|
||||
|
||||
// info!("OpenAI streaming started");
|
||||
|
||||
// Sse::new(event_stream.boxed()).keep_alive(
|
||||
// axum::response::sse::KeepAlive::new()
|
||||
// .interval(Duration::from_secs(15))
|
||||
// .text("keep-alive"),
|
||||
// )
|
||||
// }
|
||||
@@ -1,28 +1,27 @@
|
||||
use std::time::Duration;
|
||||
pub mod message_response_stream;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
response::{
|
||||
sse::{Event, KeepAlive},
|
||||
Html, IntoResponse, Redirect, Sse,
|
||||
},
|
||||
extract::{Path, State},
|
||||
http::HeaderValue,
|
||||
response::{IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::{stream, Stream, StreamExt};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tokio::time::sleep;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
error::HtmlError,
|
||||
error::{AppError, HtmlError},
|
||||
page_data,
|
||||
server::{routes::html::render_template, AppState},
|
||||
storage::types::{
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
storage::{
|
||||
db::store_item,
|
||||
types::{
|
||||
conversation::Conversation,
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -47,7 +46,8 @@ where
|
||||
page_data!(ChatData, "chat/base.html", {
|
||||
user: User,
|
||||
history: Vec<Message>,
|
||||
conversation_id: String,
|
||||
conversation: Conversation,
|
||||
conversation_archive: Vec<Conversation>
|
||||
});
|
||||
|
||||
pub async fn show_initialized_chat(
|
||||
@@ -62,19 +62,37 @@ pub async fn show_initialized_chat(
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
info!("{:?}", form);
|
||||
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
|
||||
|
||||
let conversation_id = Uuid::new_v4().to_string();
|
||||
|
||||
let user_message = Message::new("test".to_string(), MessageRole::User, form.user_query, None);
|
||||
let user_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
MessageRole::User,
|
||||
form.user_query,
|
||||
None,
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
"test".to_string(),
|
||||
conversation.id.to_string(),
|
||||
MessageRole::AI,
|
||||
form.llm_response,
|
||||
Some(form.references),
|
||||
);
|
||||
|
||||
let (conversation_result, ai_message_result, user_message_result) = futures::join!(
|
||||
store_item(&state.surreal_db_client, conversation.clone()),
|
||||
store_item(&state.surreal_db_client, ai_message.clone()),
|
||||
store_item(&state.surreal_db_client, user_message.clone())
|
||||
);
|
||||
|
||||
// Check each result individually
|
||||
conversation_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
user_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
ai_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let messages = vec![user_message, ai_message];
|
||||
|
||||
let output = render_template(
|
||||
@@ -82,12 +100,18 @@ pub async fn show_initialized_chat(
|
||||
ChatData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation_id,
|
||||
conversation_archive,
|
||||
conversation: conversation.clone(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
let mut response = output.into_response();
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn show_chat_base(
|
||||
@@ -101,14 +125,19 @@ pub async fn show_chat_base(
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation_id = Uuid::new_v4().to_string();
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let conversation = Conversation::new(user.id.clone(), "New Chat".to_string());
|
||||
|
||||
let output = render_template(
|
||||
ChatData::template_name(),
|
||||
ChatData {
|
||||
history: vec![],
|
||||
user,
|
||||
conversation_id,
|
||||
conversation_archive,
|
||||
conversation,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
@@ -121,39 +150,37 @@ pub struct NewMessageForm {
|
||||
content: String,
|
||||
}
|
||||
|
||||
pub async fn new_user_message(
|
||||
pub async fn show_existing_chat(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<AppState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Displaying empty chat start");
|
||||
info!("Displaying initialized chat with id: {}", conversation_id);
|
||||
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let query_id = Uuid::new_v4().to_string();
|
||||
let user_message = form.content.clone();
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Save to database
|
||||
// state
|
||||
// .db
|
||||
// .save(conversation_id, query_id.clone(), user_message)
|
||||
// .await;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: String,
|
||||
query_id: String,
|
||||
}
|
||||
let (conversation, messages) = Conversation::get_complete_conversation(
|
||||
conversation_id.as_str(),
|
||||
&user.id,
|
||||
&state.surreal_db_client,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData {
|
||||
user_message,
|
||||
query_id,
|
||||
ChatData::template_name(),
|
||||
ChatData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation: conversation.clone(),
|
||||
conversation_archive,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
@@ -161,38 +188,33 @@ pub async fn new_user_message(
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
query_id: String,
|
||||
}
|
||||
|
||||
pub async fn get_response_stream(
|
||||
State(_state): State<AppState>,
|
||||
pub async fn new_user_message(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<AppState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, axum::Error>>> {
|
||||
let stream = stream::iter(vec![
|
||||
Event::default()
|
||||
.event("chat_message")
|
||||
.data("Hello, starting stream!"),
|
||||
Event::default()
|
||||
.event("chat_message")
|
||||
.data("This is message 2"),
|
||||
Event::default().event("chat_message").data("Final message"),
|
||||
Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"), // Signal to close
|
||||
])
|
||||
.then(|event| async move {
|
||||
sleep(Duration::from_millis(500)).await; // Delay between messages
|
||||
Ok(event)
|
||||
});
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let _user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
info!("Streaming started");
|
||||
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
|
||||
|
||||
Sse::new(stream).keep_alive(
|
||||
axum::response::sse::KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
store_item(&state.surreal_db_client, user_message.clone())
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let output = render_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
@@ -16,7 +16,10 @@ use axum_session_surreal::SessionSurrealPool;
|
||||
use html::{
|
||||
account::{delete_account, set_api_key, show_account_page, update_timezone},
|
||||
admin_panel::{show_admin_panel, toggle_registration_status},
|
||||
chat::{get_response_stream, new_user_message, show_chat_base, show_initialized_chat},
|
||||
chat::{
|
||||
message_response_stream::get_response_stream, new_user_message, show_chat_base,
|
||||
show_existing_chat, show_initialized_chat,
|
||||
},
|
||||
content::{patch_text_content, show_content_page, show_text_content_edit_form},
|
||||
documentation::{
|
||||
show_documentation_index, show_get_started, show_mobile_friendly, show_privacy_policy,
|
||||
@@ -65,7 +68,7 @@ pub fn html_routes(app_state: &AppState) -> Router<AppState> {
|
||||
.route("/gdpr/deny", post(deny_gdpr))
|
||||
.route("/search", get(search_result_handler))
|
||||
.route("/chat", get(show_chat_base).post(show_initialized_chat))
|
||||
.route("/chat/:id", post(new_user_message))
|
||||
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
|
||||
.route("/chat/response-stream", get(get_response_stream))
|
||||
.route("/signout", get(sign_out_user))
|
||||
.route("/signin", get(show_signin_form).post(authenticate_user))
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::db::{get_item, SurrealDbClient},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
use super::message::Message;
|
||||
|
||||
stored_object!(Conversation, "conversation", {
|
||||
user_id: String,
|
||||
@@ -18,4 +24,29 @@ impl Conversation {
|
||||
title,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_complete_conversation(
|
||||
conversation_id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(Self, Vec<Message>), AppError> {
|
||||
let conversation: Conversation = get_item(&db, conversation_id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
|
||||
|
||||
if conversation.user_id != user_id {
|
||||
return Err(AppError::Auth(
|
||||
"You don't have access to this conversation".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let messages:Vec<Message> = db.client.
|
||||
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
|
||||
bind(("table_name", Message::table_name())).
|
||||
bind(("conversation_id", conversation_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok((conversation, messages))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
|
||||
#[derive(Deserialize, Debug, Serialize)]
|
||||
#[derive(Deserialize, Debug, Clone, Serialize)]
|
||||
pub enum MessageRole {
|
||||
User,
|
||||
AI,
|
||||
|
||||
@@ -8,8 +8,9 @@ use surrealdb::{engine::any::Any, Surreal};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
system_settings::SystemSettings, text_content::TextContent,
|
||||
conversation::Conversation, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
|
||||
text_content::TextContent,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -333,4 +334,19 @@ impl User {
|
||||
|
||||
Ok(text_content)
|
||||
}
|
||||
|
||||
pub async fn get_user_conversations(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<Conversation>, AppError> {
|
||||
let conversations: Vec<Conversation> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id")
|
||||
.bind(("table_name", Conversation::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
"icons/edit_icon.html" %}
|
||||
</span></a></li>
|
||||
<div class="divider"></div>
|
||||
<li><a>Sidebar Item 1</a></li>
|
||||
<li><a>Sidebar Item 2</a></li>
|
||||
{% for conversation in conversation_archive %}
|
||||
<li><a href="/chat/{{conversation.id}}" hx-boost="true">{{conversation.title}} - {{conversation.created_at}}</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
{{conversation_archive}}
|
||||
</ul>
|
||||
</div>
|
||||
@@ -1,5 +1,5 @@
|
||||
<div class="fixed w-full mx-auto max-w-3xl p-4 pb-0 sm:pb-4 left-0 right-0 bottom-0">
|
||||
<form hx-post="/chat/{{conversation_id}}" hx-target="#chat_container" hx-swap="beforeend" class="relative flex gap-2"
|
||||
<form hx-post="/chat/{{conversation.id}}" hx-target="#chat_container" hx-swap="beforeend" class="relative flex gap-2"
|
||||
id="chat-form">
|
||||
<textarea autofocus required name="content" placeholder="Type your message..." rows="2"
|
||||
class="textarea textarea-ghost rounded-2xl rounded-b-none h-24 sm:rounded-b-2xl pr-8 bg-base-200 flex-grow resize-none"
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
<div class="chat chat-end">
|
||||
<div class="chat-header">User</div>
|
||||
<div class="chat-bubble">
|
||||
{{user_message}}
|
||||
{{user_message.content}}
|
||||
</div>
|
||||
</div>
|
||||
<div class="chat chat-start">
|
||||
<div class="chat-header">AI</div>
|
||||
<div class="chat-bubble" hx-ext="sse" sse-connect="/chat/response-stream?query_id={{query_id}}"
|
||||
<div class="chat-bubble" hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}"
|
||||
sse-swap="chat_message" sse-close="close_stream" hx-swap="beforeend">
|
||||
<span class="loading loading-dots loading-sm loading-id-{{query_id}}"></span>
|
||||
<span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
document.body.addEventListener('htmx:sseBeforeMessage', (e) => {
|
||||
const targetElement = e.detail.elt;
|
||||
const loadingSpinner = targetElement.querySelector('.loading-id-{{query_id}}');
|
||||
const loadingSpinner = targetElement.querySelector('.loading-id-{{user_message.id}}');
|
||||
|
||||
// Hiding the loading spinner before data is swapped in
|
||||
if (loadingSpinner) {
|
||||
|
||||
Reference in New Issue
Block a user