fix: references bug

fix
This commit is contained in:
Per Stark
2026-02-11 21:45:20 +01:00
parent 96846ad664
commit bbad91d55b
8 changed files with 699 additions and 58 deletions
@@ -1,3 +1,5 @@
#![allow(clippy::missing_docs_in_private_items)]
use std::{pin::Pin, sync::Arc, time::Duration};
use async_stream::stream;
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{debug, error};
use tracing::{debug, error, info};
use common::storage::{
db::SurrealDbClient,
@@ -38,6 +40,8 @@ use common::storage::{
use crate::{html_state::HtmlState, AuthSessionType};
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
// Error handling function
fn create_error_stream(
message: impl Into<String>,
@@ -56,13 +60,10 @@ async fn get_message_and_user(
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
// Check authentication
let user = match current_user {
Some(user) => user,
None => {
return Err(Sse::new(create_error_stream(
"You must be signed in to use this feature",
)))
}
let Some(user) = current_user else {
return Err(Sse::new(create_error_stream(
"You must be signed in to use this feature",
)));
};
// Retrieve message
@@ -105,6 +106,20 @@ pub struct QueryParams {
message_id: String,
}
/// Serializable wrapper around a single [`Message`].
///
/// NOTE(review): presumably the context handed to the
/// `chat/reference_list.html` template render in `get_response_stream` —
/// confirm against the render call.
#[derive(Serialize)]
struct ReferenceData {
/// The message exposed to the serializer under the `message` key.
message: Message,
}
/// Copies the raw reference string out of every entry in
/// `response.references`, keeping the entries in their original order.
///
/// The response is only borrowed, so the caller can still use it afterwards.
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
    let entries = &response.references;
    let mut raw_strings = Vec::with_capacity(entries.len());
    for entry in entries {
        raw_strings.push(entry.reference.clone());
    }
    raw_strings
}
#[allow(clippy::too_many_lines)]
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
@@ -146,11 +161,13 @@ pub async fn get_response_stream(
}
};
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
// 3. Create the OpenAI request with appropriate context format
let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
let context_json = match &retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities)
retrieved_entities_to_json(entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
@@ -159,17 +176,11 @@ pub async fn get_response_stream(
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);
let settings = match SystemSettings::get_current(&state.db).await {
Ok(s) => s,
Err(_) => {
return Sse::new(create_error_stream("Failed to retrieve system settings"));
}
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
return Sse::new(create_error_stream("Failed to retrieve system settings"));
};
let request = match create_chat_request(formatted_user_message, &settings) {
Ok(req) => req,
Err(..) => {
return Sse::new(create_error_stream("Failed to create chat request"));
}
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
return Sse::new(create_error_stream("Failed to create chat request"));
};
// 4. Set up the OpenAI stream
@@ -186,7 +197,9 @@ pub async fn get_response_stream(
let (tx_final, mut rx_final) = channel::<Message>(1);
// 6. Set up the collection task for DB storage
let db_client = state.db.clone();
let db_client = Arc::clone(&state.db);
let user_id = user.id.clone();
let allowed_reference_ids = allowed_reference_ids.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
@@ -198,17 +211,55 @@ pub async fn get_response_stream(
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let references: Vec<String> = response
.references
.into_iter()
.map(|r| r.reference)
.collect();
let raw_references = extract_reference_strings(&response);
let answer = response.answer;
let initial_validation = match validate_references(
&user_id,
raw_references,
&allowed_reference_ids,
&db_client,
)
.await
{
Ok(result) => result,
Err(err) => {
error!(error = %err, "Reference validation failed, storing answer without references");
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
answer,
Some(Vec::new()),
);
let _ = tx_final.send(ai_message.clone()).await;
if let Err(store_err) = db_client.store_item(ai_message).await {
error!(error = ?store_err, "Failed to store AI message after validation failure");
}
return;
}
};
info!(
total_refs = initial_validation.reason_stats.total,
valid_refs = initial_validation.valid_refs.len(),
invalid_refs = initial_validation.invalid_refs.len(),
invalid_empty = initial_validation.reason_stats.empty,
invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
invalid_duplicate = initial_validation.reason_stats.duplicate,
invalid_not_in_context = initial_validation.reason_stats.not_in_context,
invalid_not_found = initial_validation.reason_stats.not_found,
invalid_wrong_user = initial_validation.reason_stats.wrong_user,
invalid_over_limit = initial_validation.reason_stats.over_limit,
"Post-LLM reference validation complete"
);
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
response.answer,
Some(references),
answer,
Some(initial_validation.valid_refs),
);
let _ = tx_final.send(ai_message.clone()).await;
@@ -240,7 +291,7 @@ pub async fn get_response_stream(
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = json_state.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
@@ -288,12 +339,6 @@ pub async fn get_response_stream(
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Prepare data for template
#[derive(Serialize)]
struct ReferenceData {
message: Message,
}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
@@ -375,3 +420,27 @@ impl StreamParserState {
String::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use retrieval_pipeline::answer_retrieval::Reference;

    /// The extracted strings must come back in the same order as the
    /// references listed in the response.
    #[test]
    fn extracts_reference_strings_in_order() {
        let ids = ["a", "b"];
        let response = LLMResponseFormat {
            answer: String::from("answer"),
            references: ids
                .iter()
                .map(|id| Reference {
                    reference: (*id).to_string(),
                })
                .collect(),
        };

        assert_eq!(extract_reference_strings(&response), ["a", "b"]);
    }
}