clippy: adhere to pedantic clippy, uniform test error handling

This commit is contained in:
Per Stark
2026-05-26 11:43:45 +02:00
parent e0068ebe26
commit 5ce7a76c75
68 changed files with 2468 additions and 2547 deletions
+30 -27
View File
@@ -49,6 +49,8 @@ pub struct ChatPageData {
conversation: Option<Conversation>,
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn show_initialized_chat(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
@@ -57,14 +59,14 @@ pub async fn show_initialized_chat(
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let user_message = Message::new(
conversation.id.to_string(),
conversation.id.clone(),
MessageRole::User,
form.user_query,
None,
);
let ai_message = Message::new(
conversation.id.to_string(),
conversation.id.clone(),
MessageRole::AI,
form.llm_response,
Some(form.references),
@@ -86,10 +88,9 @@ pub async fn show_initialized_chat(
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
}
@@ -130,12 +131,19 @@ pub async fn show_existing_chat(
))
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_user_message(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let conversation: Conversation = state
.db
.get_item(&conversation_id)
@@ -150,33 +158,34 @@ pub async fn new_user_message(
state.db.store_item(user_message.clone()).await?;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let mut response = TemplateResponse::new_template(
"chat/streaming_response.html",
SSEResponseInitData { user_message },
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_chat_user_message(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
conversation: Conversation,
}
let Some(user) = auth.current_user else {
return Ok(Redirect::to("/").into_response());
};
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
@@ -191,11 +200,6 @@ pub async fn new_chat_user_message(
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
conversation: Conversation,
}
let mut response = TemplateResponse::new_template(
"chat/new_chat_first_response.html",
SSEResponseInitData {
@@ -205,10 +209,9 @@ pub async fn new_chat_user_message(
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response.into_response())
}
@@ -53,26 +53,22 @@ fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
)
}
// Error handling function
fn create_error_stream(message: impl Into<String>) -> EventStream {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
// Helper function to get message and user
async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
message_id: &str,
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
// Check authentication
let Some(user) = current_user else {
return Err(sse_with_keep_alive(create_error_stream(
"You must be signed in to use this feature",
)));
};
// Retrieve message
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
@@ -88,7 +84,6 @@ async fn get_message_and_user(
}
};
// Get conversation history
let (conversation, history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{
@@ -209,7 +204,6 @@ pub async fn get_response_stream(
auth: AuthSessionType,
Query(params): Query<QueryParams>,
) -> SseResponse {
// 1. Authentication and initial data validation
let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
Ok((user_message, user, conversation, history, existing_ai_response)) => (
@@ -226,9 +220,123 @@ pub async fn get_response_stream(
return create_replayed_response_stream(&state, existing_ai_message);
}
// 2. Retrieve knowledge entities
let (request, allowed_reference_ids) = match prepare_chat_request(&state, &user_message, &user, &history).await {
Ok(result) => result,
Err(sse) => return sse,
};
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};
build_chat_event_stream(state, openai_stream, &user_message, user.id.clone(), allowed_reference_ids)
}
fn build_chat_event_stream(
state: HtmlState,
openai_stream: impl Stream<Item = Result<async_openai::types::CreateChatCompletionStreamResponse, async_openai::error::OpenAIError>> + Send + 'static,
user_message: &Message,
user_id: String,
allowed_reference_ids: Vec<String>,
) -> SseResponse {
let (tx, rx) = channel::<String>(1000);
let (tx_final, mut rx_final) = channel::<Message>(1);
spawn_storage_task(Arc::clone(&state.db), rx, tx_final, user_message, user_id, allowed_reference_ids);
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
let _ = tx_storage.send(content.clone()).await;
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
}
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {e}")));
}
}
}
})
.flatten()
.chain(stream::once(async move {
#[derive(Serialize)]
struct LocalReferenceData {
message: Message,
}
if let Some(message) = rx_final.recv().await {
if message
.references
.as_ref()
.is_some_and(std::vec::Vec::is_empty)
{
return Ok(Event::default().event("empty"));
}
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(LocalReferenceData { message }),
) {
Ok(html) => Ok(Event::default().event("references").data(html)),
Err(_) => Ok(Event::default()
.event("error")
.data("Failed to render references")),
}
} else {
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}))
.boxed();
sse_with_keep_alive(event_stream)
}
async fn prepare_chat_request(
state: &HtmlState,
user_message: &Message,
user: &User,
history: &[Message],
) -> Result<
(async_openai::types::CreateChatCompletionRequest, Vec<String>),
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
Some(pool) => pool.checkout().await,
None => None,
};
@@ -248,59 +356,49 @@ pub async fn get_response_stream(
{
Ok(result) => result,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
return Err(Sse::new(create_error_stream("Failed to retrieve knowledge")));
}
};
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
// 3. Create the OpenAI request with appropriate context format
let context_json = match &retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
chunks_to_chat_context(&search_result.chunks)
}
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);
create_user_message_with_history(&context_json, history, &user_message.content);
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
return Err(Sse::new(create_error_stream("Failed to retrieve system settings")));
};
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
return Err(Sse::new(create_error_stream("Failed to create chat request")));
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};
Ok((request, allowed_reference_ids))
}
// 5. Create channel for collecting complete response
let (tx, mut rx) = channel::<String>(1000);
let tx_clone = tx.clone();
let (tx_final, mut rx_final) = channel::<Message>(1);
fn spawn_storage_task(
db_client: Arc<SurrealDbClient>,
mut rx: tokio::sync::mpsc::Receiver<String>,
tx_final: tokio::sync::mpsc::Sender<Message>,
user_message: &Message,
user_id: String,
allowed_reference_ids: Vec<String>,
) {
let conversation_id = user_message.conversation_id.clone();
// 6. Set up the collection task for DB storage
let db_client = Arc::clone(&state.db);
let user_id = user.id.clone();
let allowed_reference_ids = allowed_reference_ids.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
// Collect full response
let mut full_json = String::new();
while let Some(chunk) = rx.recv().await {
full_json.push_str(&chunk);
}
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let raw_references = extract_reference_strings(&response);
let answer = response.answer;
@@ -347,7 +445,7 @@ pub async fn get_response_stream(
);
let ai_message = Message::new(
user_message.conversation_id,
conversation_id,
MessageRole::AI,
answer,
Some(initial_validation.valid_refs),
@@ -362,104 +460,11 @@ pub async fn get_response_stream(
} else {
error!("Failed to parse LLM response as structured format");
// Fallback - store raw response
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
full_json,
None,
);
let ai_message = Message::new(conversation_id, MessageRole::AI, full_json, None);
let _ = db_client.store_item(ai_message).await;
}
});
// Create a shared state for tracking the JSON parsing
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
// Always send raw content to storage
let _ = tx_storage.send(content.clone()).await;
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// If display_content is empty, don't yield anything
}
// If content is empty, don't yield anything
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {e}")));
}
}
}
})
.flatten()
.chain(stream::once(async move {
if let Some(message) = rx_final.recv().await {
// Don't send any event if references is empty
if message
.references
.as_ref()
.is_some_and(std::vec::Vec::is_empty)
{
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(ReferenceData { message }),
) {
Ok(html) => {
// Return the rendered HTML
Ok(Event::default().event("references").data(html))
}
Err(_) => {
// Handle template rendering error
Ok(Event::default()
.event("error")
.data("Failed to render references"))
}
}
} else {
// Handle case where no references were received
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}));
sse_with_keep_alive(event_stream.boxed())
}
struct StreamParserState {
@@ -478,23 +483,18 @@ impl StreamParserState {
}
fn process_chunk(&mut self, chunk: &str) -> String {
// Feed all characters into the parser
for c in chunk.chars() {
let _ = self.parser.add_char(c);
}
// Get the current state of the JSON
let json = self.parser.get_result();
// Check if we're in the answer field
if let Some(obj) = json.as_object() {
if let Some(answer) = obj.get("answer") {
self.in_answer_field = true;
// Get current answer content
let current_content = answer.as_str().unwrap_or_default().to_string();
// Calculate difference to send only new content
if current_content.len() > self.last_answer_content.len() {
let new_content = current_content[self.last_answer_content.len()..].to_string();
self.last_answer_content = current_content;
@@ -503,7 +503,6 @@ impl StreamParserState {
}
}
// No new content to return
String::new()
}
}
+6 -5
View File
@@ -10,8 +10,9 @@ use axum::{
};
pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat,
show_initialized_chat,
reload_sidebar, show_conversation_editing_title,
show_chat_base as show_base, show_existing_chat as show_existing,
show_initialized_chat as show_initialized,
};
use message_response_stream::get_response_stream;
use references::show_reference_tooltip;
@@ -24,10 +25,10 @@ where
HtmlState: FromRef<S>,
{
Router::new()
.route("/chat", get(show_chat_base).post(new_chat_user_message))
.route("/chat", get(show_base).post(new_chat_user_message))
.route(
"/chat/{id}",
get(show_existing_chat)
get(show_existing)
.post(new_user_message)
.delete(delete_conversation),
)
@@ -36,7 +37,7 @@ where
get(show_conversation_editing_title).patch(patch_conversation_title),
)
.route("/chat/sidebar", get(reload_sidebar))
.route("/initialized-chat", post(show_initialized_chat))
.route("/initialized-chat", post(show_initialized))
.route("/chat/response-stream", get(get_response_stream))
.route("/chat/reference/{id}", get(show_reference_tooltip))
}