mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-01 02:21:34 +02:00
refactor: better separation of dependencies to crates
node stuff to html crate only
This commit is contained in:
@@ -0,0 +1,226 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderValue,
|
||||
response::{IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
conversation::Conversation,
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatStartParams {
|
||||
user_query: String,
|
||||
llm_response: String,
|
||||
#[serde(deserialize_with = "deserialize_references")]
|
||||
references: Vec<String>,
|
||||
}
|
||||
|
||||
// Custom deserializer function
|
||||
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
serde_json::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ChatPageData {
|
||||
user: User,
|
||||
history: Vec<Message>,
|
||||
conversation: Option<Conversation>,
|
||||
conversation_archive: Vec<Conversation>,
|
||||
}
|
||||
|
||||
pub async fn show_initialized_chat(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<ChatStartParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
|
||||
|
||||
let user_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
MessageRole::User,
|
||||
form.user_query,
|
||||
None,
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
MessageRole::AI,
|
||||
form.llm_response,
|
||||
Some(form.references),
|
||||
);
|
||||
|
||||
state.db.store_item(conversation.clone()).await?;
|
||||
state.db.store_item(ai_message.clone()).await?;
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
|
||||
let messages = vec![user_message, ai_message];
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/base.html",
|
||||
ChatPageData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation_archive,
|
||||
conversation: Some(conversation.clone()),
|
||||
},
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn show_chat_base(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"chat/base.html",
|
||||
ChatPageData {
|
||||
history: vec![],
|
||||
user,
|
||||
conversation_archive,
|
||||
conversation: None,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct NewMessageForm {
|
||||
content: String,
|
||||
}
|
||||
|
||||
pub async fn show_existing_chat(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
|
||||
let (conversation, messages) =
|
||||
Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db)
|
||||
.await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"chat/base.html",
|
||||
ChatPageData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation: Some(conversation.clone()),
|
||||
conversation_archive,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn new_user_message(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let conversation: Conversation = state
|
||||
.db
|
||||
.get_item(&conversation_id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Conversation was not found".into()))?;
|
||||
|
||||
if conversation.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
};
|
||||
|
||||
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
|
||||
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn new_chat_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id, "New chat".to_string());
|
||||
let user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
form.content,
|
||||
None,
|
||||
);
|
||||
|
||||
state.db.store_item(conversation.clone()).await?;
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
conversation: Conversation,
|
||||
}
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/new_chat_first_response.html",
|
||||
SSEResponseInitData {
|
||||
user_message,
|
||||
conversation: conversation.clone(),
|
||||
},
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -0,0 +1,355 @@
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::{
|
||||
sse::{Event, KeepAlive},
|
||||
Sse,
|
||||
},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use composite_retrieval::{
|
||||
answer_retrieval::{
|
||||
create_chat_request, create_user_message_with_history, format_entities_json,
|
||||
LLMResponseFormat,
|
||||
},
|
||||
retrieve_entities,
|
||||
};
|
||||
use futures::{
|
||||
stream::{self, once},
|
||||
Stream, StreamExt, TryStreamExt,
|
||||
};
|
||||
use json_stream_parser::JsonStreamParser;
|
||||
use minijinja::Value;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tokio::sync::{mpsc::channel, Mutex};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
conversation::Conversation,
|
||||
message::{Message, MessageRole},
|
||||
system_settings::SystemSettings,
|
||||
user::User,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
|
||||
// Helper function to get message and user
|
||||
async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<
|
||||
(Message, User, Conversation, Vec<Message>),
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
// Check authentication
|
||||
let user = match current_user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
let message = match db.get_item::<Message>(message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Message not found: the specified message does not exist",
|
||||
)))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error retrieving message {}: {:?}", message_id, e);
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Failed to retrieve message: database error",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Get conversation history
|
||||
let (conversation, mut history) =
|
||||
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
|
||||
{
|
||||
Err(e) => {
|
||||
error!("Database error retrieving message {}: {:?}", message_id, e);
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Failed to retrieve message: database error",
|
||||
)));
|
||||
}
|
||||
Ok((conversation, history)) => (conversation, history),
|
||||
};
|
||||
|
||||
// Remove the last message, its the same as the message
|
||||
history.pop();
|
||||
|
||||
Ok((message, user, conversation, history))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user, _conversation, history) =
|
||||
match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await {
|
||||
Ok((user_message, user, conversation, history)) => {
|
||||
(user_message, user, conversation, history)
|
||||
}
|
||||
Err(error_stream) => return error_stream,
|
||||
};
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let entities = match retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(entities) => entities,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||
}
|
||||
};
|
||||
|
||||
// 3. Create the OpenAI request
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&entities_json, &history, &user_message.content);
|
||||
let settings = match SystemSettings::get_current(&state.db).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||
}
|
||||
};
|
||||
let request = match create_chat_request(formatted_user_message, &settings) {
|
||||
Ok(req) => req,
|
||||
Err(..) => {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
|
||||
// 5. Create channel for collecting complete response
|
||||
let (tx, mut rx) = channel::<String>(1000);
|
||||
let tx_clone = tx.clone();
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
// Collect full response
|
||||
let mut full_json = String::new();
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
full_json.push_str(&chunk);
|
||||
}
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let references: Vec<String> = response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect();
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
response.answer,
|
||||
Some(references),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
|
||||
match db_client.store_item(ai_message).await {
|
||||
Ok(_) => debug!("Successfully stored AI message with references"),
|
||||
Err(e) => error!("Failed to store AI message: {:?}", e),
|
||||
}
|
||||
} else {
|
||||
error!("Failed to parse LLM response as structured format");
|
||||
|
||||
// Fallback - store raw response
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
full_json,
|
||||
None,
|
||||
);
|
||||
|
||||
let _ = db_client.store_item(ai_message).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create a shared state for tracking the JSON parsing
|
||||
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
|
||||
|
||||
// 7. Create the response event stream
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = json_state.clone();
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
// Always send raw content to storage
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
// Process through JSON parser
|
||||
let mut state = json_state.lock().await;
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
yield Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
// If display_content is empty, don't yield anything
|
||||
}
|
||||
// If content is empty, don't yield anything
|
||||
}
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.chain(stream::once(async move {
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
// Don't send any event if references is empty
|
||||
if message.references.as_ref().is_some_and(|x| x.is_empty()) {
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(ReferenceData { message }),
|
||||
) {
|
||||
Ok(html) => {
|
||||
// Return the rendered HTML
|
||||
Ok(Event::default().event("references").data(html))
|
||||
}
|
||||
Err(_) => {
|
||||
// Handle template rendering error
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to render references"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle case where no references were received
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to retrieve references"))
|
||||
}
|
||||
}))
|
||||
.chain(once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}));
|
||||
|
||||
Sse::new(event_stream.boxed()).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
}
|
||||
|
||||
struct StreamParserState {
|
||||
parser: JsonStreamParser,
|
||||
last_answer_content: String,
|
||||
in_answer_field: bool,
|
||||
}
|
||||
|
||||
impl StreamParserState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
parser: JsonStreamParser::new(),
|
||||
last_answer_content: String::new(),
|
||||
in_answer_field: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn process_chunk(&mut self, chunk: &str) -> String {
|
||||
// Feed all characters into the parser
|
||||
for c in chunk.chars() {
|
||||
let _ = self.parser.add_char(c);
|
||||
}
|
||||
|
||||
// Get the current state of the JSON
|
||||
let json = self.parser.get_result();
|
||||
|
||||
// Check if we're in the answer field
|
||||
if let Some(obj) = json.as_object() {
|
||||
if let Some(answer) = obj.get("answer") {
|
||||
self.in_answer_field = true;
|
||||
|
||||
// Get current answer content
|
||||
let current_content = answer.as_str().unwrap_or_default().to_string();
|
||||
|
||||
// Calculate difference to send only new content
|
||||
if current_content.len() > self.last_answer_content.len() {
|
||||
let new_content = current_content[self.last_answer_content.len()..].to_string();
|
||||
self.last_answer_content = current_content;
|
||||
return new_content;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No new content to return
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
mod chat_handlers;
|
||||
mod message_response_stream;
|
||||
mod references;
|
||||
|
||||
use axum::{
|
||||
extract::FromRef,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use chat_handlers::{
|
||||
new_chat_user_message, new_user_message, show_chat_base, show_existing_chat,
|
||||
show_initialized_chat,
|
||||
};
|
||||
use message_response_stream::get_response_stream;
|
||||
use references::show_reference_tooltip;
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
pub fn router<S>() -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/chat", get(show_chat_base).post(new_chat_user_message))
|
||||
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
|
||||
.route("/initialized-chat", post(show_initialized_chat))
|
||||
.route("/chat/response-stream", get(get_response_stream))
|
||||
.route("/chat/reference/:id", get(show_reference_tooltip))
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::Serialize;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
};
|
||||
|
||||
pub async fn show_reference_tooltip(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(reference_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let entity: KnowledgeEntity = state
|
||||
.db
|
||||
.get_item(&reference_id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
|
||||
|
||||
if entity.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceTooltipData {
|
||||
entity: KnowledgeEntity,
|
||||
user: User,
|
||||
}
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"chat/reference_tooltip.html",
|
||||
ReferenceTooltipData { entity, user },
|
||||
))
|
||||
}
|
||||
Reference in New Issue
Block a user