mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-04 03:51:43 +02:00
fix: references bug
fix
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use async_stream::stream;
|
||||
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use tokio::sync::{mpsc::channel, Mutex};
|
||||
use tracing::{debug, error};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
@@ -38,6 +40,8 @@ use common::storage::{
|
||||
|
||||
use crate::{html_state::HtmlState, AuthSessionType};
|
||||
|
||||
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
@@ -56,13 +60,10 @@ async fn get_message_and_user(
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
// Check authentication
|
||||
let user = match current_user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)))
|
||||
}
|
||||
let Some(user) = current_user else {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)));
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
@@ -105,6 +106,20 @@ pub struct QueryParams {
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||
response
|
||||
.references
|
||||
.iter()
|
||||
.map(|reference| reference.reference.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
@@ -146,11 +161,13 @@ pub async fn get_response_stream(
|
||||
}
|
||||
};
|
||||
|
||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||
|
||||
// 3. Create the OpenAI request with appropriate context format
|
||||
let context_json = match retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
let context_json = match &retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
|
||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||
retrieved_entities_to_json(&entities)
|
||||
retrieved_entities_to_json(entities)
|
||||
}
|
||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||
// For chat, use chunks from the search result
|
||||
@@ -159,17 +176,11 @@ pub async fn get_response_stream(
|
||||
};
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&context_json, &history, &user_message.content);
|
||||
let settings = match SystemSettings::get_current(&state.db).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||
}
|
||||
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
|
||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||
};
|
||||
let request = match create_chat_request(formatted_user_message, &settings) {
|
||||
Ok(req) => req,
|
||||
Err(..) => {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
}
|
||||
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
@@ -186,7 +197,9 @@ pub async fn get_response_stream(
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = state.db.clone();
|
||||
let db_client = Arc::clone(&state.db);
|
||||
let user_id = user.id.clone();
|
||||
let allowed_reference_ids = allowed_reference_ids.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
@@ -198,17 +211,55 @@ pub async fn get_response_stream(
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let references: Vec<String> = response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect();
|
||||
let raw_references = extract_reference_strings(&response);
|
||||
let answer = response.answer;
|
||||
|
||||
let initial_validation = match validate_references(
|
||||
&user_id,
|
||||
raw_references,
|
||||
&allowed_reference_ids,
|
||||
&db_client,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
error!(error = %err, "Reference validation failed, storing answer without references");
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
answer,
|
||||
Some(Vec::new()),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
if let Err(store_err) = db_client.store_item(ai_message).await {
|
||||
error!(error = ?store_err, "Failed to store AI message after validation failure");
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
total_refs = initial_validation.reason_stats.total,
|
||||
valid_refs = initial_validation.valid_refs.len(),
|
||||
invalid_refs = initial_validation.invalid_refs.len(),
|
||||
invalid_empty = initial_validation.reason_stats.empty,
|
||||
invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
|
||||
invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
|
||||
invalid_duplicate = initial_validation.reason_stats.duplicate,
|
||||
invalid_not_in_context = initial_validation.reason_stats.not_in_context,
|
||||
invalid_not_found = initial_validation.reason_stats.not_found,
|
||||
invalid_wrong_user = initial_validation.reason_stats.wrong_user,
|
||||
invalid_over_limit = initial_validation.reason_stats.over_limit,
|
||||
"Post-LLM reference validation complete"
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
response.answer,
|
||||
Some(references),
|
||||
answer,
|
||||
Some(initial_validation.valid_refs),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
@@ -240,7 +291,7 @@ pub async fn get_response_stream(
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = json_state.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
@@ -288,12 +339,6 @@ pub async fn get_response_stream(
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
@@ -375,3 +420,27 @@ impl StreamParserState {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use retrieval_pipeline::answer_retrieval::Reference;

    /// Builds a `Reference` whose `reference` field is `raw`.
    fn make_reference(raw: &str) -> Reference {
        Reference {
            reference: String::from(raw),
        }
    }

    /// The extracted strings must mirror the order of `references` in the response.
    #[test]
    fn extracts_reference_strings_in_order() {
        let response = LLMResponseFormat {
            answer: String::from("answer"),
            references: vec![make_reference("a"), make_reference("b")],
        };

        let expected = vec![String::from("a"), String::from("b")];
        assert_eq!(extract_reference_strings(&response), expected);
    }
}
|
||||
|
||||
Reference in New Issue
Block a user