mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-23 09:18:36 +02:00
fix: references bug
fix
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3820,7 +3820,7 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "main"
|
name = "main"
|
||||||
version = "1.0.0"
|
version = "1.0.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"api-router",
|
"api-router",
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
#![allow(clippy::missing_docs_in_private_items)]
|
||||||
|
|
||||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||||
|
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
use tokio::sync::{mpsc::channel, Mutex};
|
use tokio::sync::{mpsc::channel, Mutex};
|
||||||
use tracing::{debug, error};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
use common::storage::{
|
use common::storage::{
|
||||||
db::SurrealDbClient,
|
db::SurrealDbClient,
|
||||||
@@ -38,6 +40,8 @@ use common::storage::{
|
|||||||
|
|
||||||
use crate::{html_state::HtmlState, AuthSessionType};
|
use crate::{html_state::HtmlState, AuthSessionType};
|
||||||
|
|
||||||
|
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
||||||
|
|
||||||
// Error handling function
|
// Error handling function
|
||||||
fn create_error_stream(
|
fn create_error_stream(
|
||||||
message: impl Into<String>,
|
message: impl Into<String>,
|
||||||
@@ -56,13 +60,10 @@ async fn get_message_and_user(
|
|||||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||||
> {
|
> {
|
||||||
// Check authentication
|
// Check authentication
|
||||||
let user = match current_user {
|
let Some(user) = current_user else {
|
||||||
Some(user) => user,
|
return Err(Sse::new(create_error_stream(
|
||||||
None => {
|
"You must be signed in to use this feature",
|
||||||
return Err(Sse::new(create_error_stream(
|
)));
|
||||||
"You must be signed in to use this feature",
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Retrieve message
|
// Retrieve message
|
||||||
@@ -105,6 +106,20 @@ pub struct QueryParams {
|
|||||||
message_id: String,
|
message_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Serializable wrapper around a single chat message.
/// NOTE(review): presumably the context handed to the
/// "chat/reference_list.html" template render — confirm at the call site.
#[derive(Serialize)]
|
||||||
|
struct ReferenceData {
|
||||||
|
// The message being wrapped for serialization.
message: Message,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||||
|
response
|
||||||
|
.references
|
||||||
|
.iter()
|
||||||
|
.map(|reference| reference.reference.clone())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn get_response_stream(
|
pub async fn get_response_stream(
|
||||||
State(state): State<HtmlState>,
|
State(state): State<HtmlState>,
|
||||||
auth: AuthSessionType,
|
auth: AuthSessionType,
|
||||||
@@ -146,11 +161,13 @@ pub async fn get_response_stream(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||||
|
|
||||||
// 3. Create the OpenAI request with appropriate context format
|
// 3. Create the OpenAI request with appropriate context format
|
||||||
let context_json = match retrieval_result {
|
let context_json = match &retrieval_result {
|
||||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
|
||||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||||
retrieved_entities_to_json(&entities)
|
retrieved_entities_to_json(entities)
|
||||||
}
|
}
|
||||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||||
// For chat, use chunks from the search result
|
// For chat, use chunks from the search result
|
||||||
@@ -159,17 +176,11 @@ pub async fn get_response_stream(
|
|||||||
};
|
};
|
||||||
let formatted_user_message =
|
let formatted_user_message =
|
||||||
create_user_message_with_history(&context_json, &history, &user_message.content);
|
create_user_message_with_history(&context_json, &history, &user_message.content);
|
||||||
let settings = match SystemSettings::get_current(&state.db).await {
|
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
|
||||||
Ok(s) => s,
|
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||||
Err(_) => {
|
|
||||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let request = match create_chat_request(formatted_user_message, &settings) {
|
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
|
||||||
Ok(req) => req,
|
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||||
Err(..) => {
|
|
||||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// 4. Set up the OpenAI stream
|
// 4. Set up the OpenAI stream
|
||||||
@@ -186,7 +197,9 @@ pub async fn get_response_stream(
|
|||||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||||
|
|
||||||
// 6. Set up the collection task for DB storage
|
// 6. Set up the collection task for DB storage
|
||||||
let db_client = state.db.clone();
|
let db_client = Arc::clone(&state.db);
|
||||||
|
let user_id = user.id.clone();
|
||||||
|
let allowed_reference_ids = allowed_reference_ids.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
drop(tx); // Close sender when no longer needed
|
drop(tx); // Close sender when no longer needed
|
||||||
|
|
||||||
@@ -198,17 +211,55 @@ pub async fn get_response_stream(
|
|||||||
|
|
||||||
// Try to extract structured data
|
// Try to extract structured data
|
||||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||||
let references: Vec<String> = response
|
let raw_references = extract_reference_strings(&response);
|
||||||
.references
|
let answer = response.answer;
|
||||||
.into_iter()
|
|
||||||
.map(|r| r.reference)
|
let initial_validation = match validate_references(
|
||||||
.collect();
|
&user_id,
|
||||||
|
raw_references,
|
||||||
|
&allowed_reference_ids,
|
||||||
|
&db_client,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(err) => {
|
||||||
|
error!(error = %err, "Reference validation failed, storing answer without references");
|
||||||
|
let ai_message = Message::new(
|
||||||
|
user_message.conversation_id,
|
||||||
|
MessageRole::AI,
|
||||||
|
answer,
|
||||||
|
Some(Vec::new()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = tx_final.send(ai_message.clone()).await;
|
||||||
|
if let Err(store_err) = db_client.store_item(ai_message).await {
|
||||||
|
error!(error = ?store_err, "Failed to store AI message after validation failure");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
total_refs = initial_validation.reason_stats.total,
|
||||||
|
valid_refs = initial_validation.valid_refs.len(),
|
||||||
|
invalid_refs = initial_validation.invalid_refs.len(),
|
||||||
|
invalid_empty = initial_validation.reason_stats.empty,
|
||||||
|
invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
|
||||||
|
invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
|
||||||
|
invalid_duplicate = initial_validation.reason_stats.duplicate,
|
||||||
|
invalid_not_in_context = initial_validation.reason_stats.not_in_context,
|
||||||
|
invalid_not_found = initial_validation.reason_stats.not_found,
|
||||||
|
invalid_wrong_user = initial_validation.reason_stats.wrong_user,
|
||||||
|
invalid_over_limit = initial_validation.reason_stats.over_limit,
|
||||||
|
"Post-LLM reference validation complete"
|
||||||
|
);
|
||||||
|
|
||||||
let ai_message = Message::new(
|
let ai_message = Message::new(
|
||||||
user_message.conversation_id,
|
user_message.conversation_id,
|
||||||
MessageRole::AI,
|
MessageRole::AI,
|
||||||
response.answer,
|
answer,
|
||||||
Some(references),
|
Some(initial_validation.valid_refs),
|
||||||
);
|
);
|
||||||
|
|
||||||
let _ = tx_final.send(ai_message.clone()).await;
|
let _ = tx_final.send(ai_message.clone()).await;
|
||||||
@@ -240,7 +291,7 @@ pub async fn get_response_stream(
|
|||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||||
.map(move |result| {
|
.map(move |result| {
|
||||||
let tx_storage = tx_clone.clone();
|
let tx_storage = tx_clone.clone();
|
||||||
let json_state = json_state.clone();
|
let json_state = Arc::clone(&json_state);
|
||||||
|
|
||||||
stream! {
|
stream! {
|
||||||
match result {
|
match result {
|
||||||
@@ -288,12 +339,6 @@ pub async fn get_response_stream(
|
|||||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare data for template
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ReferenceData {
|
|
||||||
message: Message,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render template with references
|
// Render template with references
|
||||||
match state.templates.render(
|
match state.templates.render(
|
||||||
"chat/reference_list.html",
|
"chat/reference_list.html",
|
||||||
@@ -375,3 +420,27 @@ impl StreamParserState {
|
|||||||
String::new()
|
String::new()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use retrieval_pipeline::answer_retrieval::Reference;

    /// Every reference string must come back, in the order the response
    /// listed them.
    #[test]
    fn extracts_reference_strings_in_order() {
        let references = ["a", "b"]
            .into_iter()
            .map(|id| Reference {
                reference: id.to_string(),
            })
            .collect();
        let response = LLMResponseFormat {
            answer: "answer".to_string(),
            references,
        };

        assert_eq!(
            extract_reference_strings(&response),
            vec!["a".to_string(), "b".to_string()]
        );
    }
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
mod chat_handlers;
|
mod chat_handlers;
|
||||||
mod message_response_stream;
|
mod message_response_stream;
|
||||||
|
mod reference_validation;
|
||||||
mod references;
|
mod references;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|||||||
477
html-router/src/routes/chat/reference_validation.rs
Normal file
477
html-router/src/routes/chat/reference_validation.rs
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
#![allow(clippy::arithmetic_side_effects, clippy::missing_docs_in_private_items)]
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use common::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{
|
||||||
|
db::SurrealDbClient,
|
||||||
|
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use retrieval_pipeline::StrategyOutput;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
||||||
|
|
||||||
|
/// Why a reference string was rejected by `validate_references`.
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) enum InvalidReferenceReason {
|
||||||
|
// The string was empty after trimming whitespace.
Empty,
|
||||||
|
// A `prefix:` was present but was neither `knowledge_entity` nor `text_chunk`.
UnsupportedPrefix,
|
||||||
|
// The id part was empty or could not be parsed as a UUID.
MalformedUuid,
|
||||||
|
// The same normalized id already appeared earlier in the list.
Duplicate,
|
||||||
|
// The id is not among the ids collected from the retrieval result.
NotInContext,
|
||||||
|
// No stored item exists for the id.
NotFound,
|
||||||
|
// A stored item exists but its `user_id` differs from the requesting user.
WrongUser,
|
||||||
|
// `MAX_REFERENCE_COUNT` references had already been accepted.
OverLimit,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One rejected reference together with the reason it was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) struct InvalidReference {
|
||||||
|
// The reference string exactly as it was passed to validation.
pub raw: String,
|
||||||
|
// The normalized id; `None` when normalization itself failed.
pub normalized: Option<String>,
|
||||||
|
// Which check rejected the reference.
pub reason: InvalidReferenceReason,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Counters summarising one validation run: the number of references
/// received plus one counter per `InvalidReferenceReason` variant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||||
|
pub(crate) struct ReferenceReasonStats {
|
||||||
|
// Number of references handed to validation (valid and invalid alike).
pub total: usize,
|
||||||
|
// Rejections with reason `Empty`.
pub empty: usize,
|
||||||
|
// Rejections with reason `UnsupportedPrefix`.
pub unsupported_prefix: usize,
|
||||||
|
// Rejections with reason `MalformedUuid`.
pub malformed_uuid: usize,
|
||||||
|
// Rejections with reason `Duplicate`.
pub duplicate: usize,
|
||||||
|
// Rejections with reason `NotInContext`.
pub not_in_context: usize,
|
||||||
|
// Rejections with reason `NotFound`.
pub not_found: usize,
|
||||||
|
// Rejections with reason `WrongUser`.
pub wrong_user: usize,
|
||||||
|
// Rejections with reason `OverLimit`.
pub over_limit: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReferenceReasonStats {
|
||||||
|
fn record(&mut self, reason: &InvalidReferenceReason) {
|
||||||
|
match reason {
|
||||||
|
InvalidReferenceReason::Empty => self.empty += 1,
|
||||||
|
InvalidReferenceReason::UnsupportedPrefix => self.unsupported_prefix += 1,
|
||||||
|
InvalidReferenceReason::MalformedUuid => self.malformed_uuid += 1,
|
||||||
|
InvalidReferenceReason::Duplicate => self.duplicate += 1,
|
||||||
|
InvalidReferenceReason::NotInContext => self.not_in_context += 1,
|
||||||
|
InvalidReferenceReason::NotFound => self.not_found += 1,
|
||||||
|
InvalidReferenceReason::WrongUser => self.wrong_user += 1,
|
||||||
|
InvalidReferenceReason::OverLimit => self.over_limit += 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outcome of `validate_references`: accepted ids, rejected entries and
/// per-reason counters.
#[derive(Debug, Clone, Default)]
|
||||||
|
pub(crate) struct ReferenceValidationResult {
|
||||||
|
// Normalized ids that passed every check, in input order.
pub valid_refs: Vec<String>,
|
||||||
|
// Every rejected reference with its rejection reason, in input order.
pub invalid_refs: Vec<InvalidReference>,
|
||||||
|
// Aggregate counters for this validation run.
pub reason_stats: ReferenceReasonStats,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Which stored type a normalized reference should be looked up as.
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) enum ReferenceLookupTarget {
|
||||||
|
// The reference carried a `text_chunk:` prefix.
TextChunk,
|
||||||
|
// The reference carried a `knowledge_entity:` prefix.
KnowledgeEntity,
|
||||||
|
// The reference had no prefix, so either type may match.
Any,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn collect_reference_ids_from_retrieval(
|
||||||
|
retrieval_result: &StrategyOutput,
|
||||||
|
) -> Vec<String> {
|
||||||
|
let mut ids = Vec::new();
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
|
match retrieval_result {
|
||||||
|
StrategyOutput::Chunks(chunks) => {
|
||||||
|
for chunk in chunks {
|
||||||
|
let id = chunk.chunk.id.clone();
|
||||||
|
if seen.insert(id.clone()) {
|
||||||
|
ids.push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StrategyOutput::Entities(entities) => {
|
||||||
|
for entity in entities {
|
||||||
|
let id = entity.entity.id.clone();
|
||||||
|
if seen.insert(id.clone()) {
|
||||||
|
ids.push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StrategyOutput::Search(search) => {
|
||||||
|
for chunk in &search.chunks {
|
||||||
|
let id = chunk.chunk.id.clone();
|
||||||
|
if seen.insert(id.clone()) {
|
||||||
|
ids.push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for entity in &search.entities {
|
||||||
|
let id = entity.entity.id.clone();
|
||||||
|
if seen.insert(id.clone()) {
|
||||||
|
ids.push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ids
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn validate_references(
|
||||||
|
user_id: &str,
|
||||||
|
refs: Vec<String>,
|
||||||
|
allowed_ids: &[String],
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<ReferenceValidationResult, AppError> {
|
||||||
|
let mut result = ReferenceValidationResult::default();
|
||||||
|
result.reason_stats.total = refs.len();
|
||||||
|
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
let allowed_set: HashSet<&str> = allowed_ids.iter().map(String::as_str).collect();
|
||||||
|
let enforce_context = !allowed_set.is_empty();
|
||||||
|
|
||||||
|
for raw in refs {
|
||||||
|
let (normalized, target) = match normalize_reference(&raw) {
|
||||||
|
Ok(parsed) => parsed,
|
||||||
|
Err(reason) => {
|
||||||
|
result.reason_stats.record(&reason);
|
||||||
|
result.invalid_refs.push(InvalidReference {
|
||||||
|
raw,
|
||||||
|
normalized: None,
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !seen.insert(normalized.clone()) {
|
||||||
|
let reason = InvalidReferenceReason::Duplicate;
|
||||||
|
result.reason_stats.record(&reason);
|
||||||
|
result.invalid_refs.push(InvalidReference {
|
||||||
|
raw,
|
||||||
|
normalized: Some(normalized),
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.valid_refs.len() >= MAX_REFERENCE_COUNT {
|
||||||
|
let reason = InvalidReferenceReason::OverLimit;
|
||||||
|
result.reason_stats.record(&reason);
|
||||||
|
result.invalid_refs.push(InvalidReference {
|
||||||
|
raw,
|
||||||
|
normalized: Some(normalized),
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if enforce_context && !allowed_set.contains(normalized.as_str()) {
|
||||||
|
let reason = InvalidReferenceReason::NotInContext;
|
||||||
|
result.reason_stats.record(&reason);
|
||||||
|
result.invalid_refs.push(InvalidReference {
|
||||||
|
raw,
|
||||||
|
normalized: Some(normalized),
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match lookup_reference_for_user(&normalized, &target, user_id, db).await? {
|
||||||
|
LookupResult::Found => result.valid_refs.push(normalized),
|
||||||
|
LookupResult::WrongUser => {
|
||||||
|
let reason = InvalidReferenceReason::WrongUser;
|
||||||
|
result.reason_stats.record(&reason);
|
||||||
|
result.invalid_refs.push(InvalidReference {
|
||||||
|
raw,
|
||||||
|
normalized: Some(normalized),
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
LookupResult::NotFound => {
|
||||||
|
let reason = InvalidReferenceReason::NotFound;
|
||||||
|
result.reason_stats.record(&reason);
|
||||||
|
result.invalid_refs.push(InvalidReference {
|
||||||
|
raw,
|
||||||
|
normalized: Some(normalized),
|
||||||
|
reason,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn normalize_reference(
|
||||||
|
raw: &str,
|
||||||
|
) -> Result<(String, ReferenceLookupTarget), InvalidReferenceReason> {
|
||||||
|
let trimmed = raw.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return Err(InvalidReferenceReason::Empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (candidate, target) = if let Some((prefix, rest)) = trimmed.split_once(':') {
|
||||||
|
let lookup_target = if prefix.eq_ignore_ascii_case("knowledge_entity") {
|
||||||
|
ReferenceLookupTarget::KnowledgeEntity
|
||||||
|
} else if prefix.eq_ignore_ascii_case("text_chunk") {
|
||||||
|
ReferenceLookupTarget::TextChunk
|
||||||
|
} else {
|
||||||
|
return Err(InvalidReferenceReason::UnsupportedPrefix);
|
||||||
|
};
|
||||||
|
|
||||||
|
(rest.trim(), lookup_target)
|
||||||
|
} else {
|
||||||
|
(trimmed, ReferenceLookupTarget::Any)
|
||||||
|
};
|
||||||
|
|
||||||
|
if candidate.is_empty() {
|
||||||
|
return Err(InvalidReferenceReason::MalformedUuid);
|
||||||
|
}
|
||||||
|
|
||||||
|
Uuid::parse_str(candidate)
|
||||||
|
.map(|uuid| (uuid.to_string(), target))
|
||||||
|
.map_err(|_| InvalidReferenceReason::MalformedUuid)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outcome of looking up a reference id on behalf of a specific user.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum LookupResult {
|
||||||
|
// An item exists and its `user_id` equals the requesting user.
Found,
|
||||||
|
// An item exists but its `user_id` differs from the requesting user.
WrongUser,
|
||||||
|
// No item exists for the id.
NotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn lookup_reference_for_user(
|
||||||
|
id: &str,
|
||||||
|
target: &ReferenceLookupTarget,
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<LookupResult, AppError> {
|
||||||
|
match target {
|
||||||
|
ReferenceLookupTarget::TextChunk => lookup_single_type::<TextChunk>(id, user_id, db).await,
|
||||||
|
ReferenceLookupTarget::KnowledgeEntity => {
|
||||||
|
lookup_single_type::<KnowledgeEntity>(id, user_id, db).await
|
||||||
|
}
|
||||||
|
ReferenceLookupTarget::Any => {
|
||||||
|
let chunk_result = lookup_single_type::<TextChunk>(id, user_id, db).await?;
|
||||||
|
if chunk_result == LookupResult::Found {
|
||||||
|
return Ok(LookupResult::Found);
|
||||||
|
}
|
||||||
|
|
||||||
|
let entity_result = lookup_single_type::<KnowledgeEntity>(id, user_id, db).await?;
|
||||||
|
if entity_result == LookupResult::Found {
|
||||||
|
return Ok(LookupResult::Found);
|
||||||
|
}
|
||||||
|
|
||||||
|
if chunk_result == LookupResult::WrongUser || entity_result == LookupResult::WrongUser {
|
||||||
|
return Ok(LookupResult::WrongUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(LookupResult::NotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn lookup_single_type<T>(
|
||||||
|
id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<LookupResult, AppError>
|
||||||
|
where
|
||||||
|
T: StoredObject + for<'de> serde::Deserialize<'de> + HasUserId,
|
||||||
|
{
|
||||||
|
let item = db.get_item::<T>(id).await?;
|
||||||
|
Ok(match item {
|
||||||
|
Some(item) if item.user_id() == user_id => LookupResult::Found,
|
||||||
|
Some(_) => LookupResult::WrongUser,
|
||||||
|
None => LookupResult::NotFound,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lets `lookup_single_type` read the owner of any stored type it is
/// generic over.
trait HasUserId {
|
||||||
|
// Returns the id of the user the item belongs to.
fn user_id(&self) -> &str;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exposes the chunk's `user_id` field as its owner id.
impl HasUserId for TextChunk {
|
||||||
|
fn user_id(&self) -> &str {
|
||||||
|
&self.user_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exposes the entity's `user_id` field as its owner id.
impl HasUserId for KnowledgeEntity {
|
||||||
|
fn user_id(&self) -> &str {
|
||||||
|
&self.user_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
#[allow(
    clippy::cloned_ref_to_slice_refs,
    clippy::expect_used,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use common::storage::types::knowledge_entity::KnowledgeEntityType;
    use surrealdb::engine::any::connect;

    /// Creates an in-memory SurrealDB client on a freshly named
    /// namespace/database pair and applies the migrations to it.
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = format!("test_ns_{}", Uuid::new_v4());
        let database = format!("test_db_{}", Uuid::new_v4());

        let client = connect("mem://")
            .await
            .expect("failed to create in-memory surrealdb client");
        client
            .use_ns(namespace)
            .use_db(database)
            .await
            .expect("failed to select namespace/db");

        let db = SurrealDbClient { client };
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        db
    }

    /// Builds (but does not store) a document-type entity owned by `owner`.
    fn document_entity(source_id: &str, name: &str, owner: &str) -> KnowledgeEntity {
        KnowledgeEntity::new(
            source_id.to_string(),
            name.to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            owner.to_string(),
        )
    }

    /// Rejection reasons of a validation result, in input order.
    fn rejection_reasons(outcome: &ReferenceValidationResult) -> Vec<InvalidReferenceReason> {
        outcome
            .invalid_refs
            .iter()
            .map(|entry| entry.reason.clone())
            .collect()
    }

    #[tokio::test]
    async fn valid_uuid_exists_and_belongs_to_user() {
        let db = setup_test_db().await;
        let owner = "user-a";
        let entity = document_entity("source-1", "Entity A", owner);
        db.store_item(entity.clone())
            .await
            .expect("failed to store entity");

        let allowed = [entity.id.clone()];
        let outcome = validate_references(owner, vec![entity.id.clone()], &allowed, &db)
            .await
            .expect("validation should not fail");

        assert_eq!(outcome.valid_refs, vec![entity.id]);
        assert!(outcome.invalid_refs.is_empty());
    }

    #[tokio::test]
    async fn valid_uuid_exists_but_wrong_user_is_rejected() {
        let db = setup_test_db().await;
        let entity = document_entity("source-1", "Entity B", "other-user");
        db.store_item(entity.clone())
            .await
            .expect("failed to store entity");

        let allowed = [entity.id.clone()];
        let outcome = validate_references("user-a", vec![entity.id.clone()], &allowed, &db)
            .await
            .expect("validation should not fail");

        assert!(outcome.valid_refs.is_empty());
        assert_eq!(
            rejection_reasons(&outcome),
            vec![InvalidReferenceReason::WrongUser]
        );
    }

    #[tokio::test]
    async fn malformed_uuid_is_rejected() {
        let db = setup_test_db().await;
        let bogus = "not-a-uuid".to_string();

        let allowed = [bogus.clone()];
        let outcome = validate_references("user-a", vec![bogus], &allowed, &db)
            .await
            .expect("validation should not fail");

        assert!(outcome.valid_refs.is_empty());
        assert_eq!(
            rejection_reasons(&outcome),
            vec![InvalidReferenceReason::MalformedUuid]
        );
    }

    #[tokio::test]
    async fn mixed_duplicates_are_deduped() {
        let db = setup_test_db().await;
        let owner = "user-a";

        let first = document_entity("source-1", "Entity 1", owner);
        let second = document_entity("source-2", "Entity 2", owner);
        db.store_item(first.clone())
            .await
            .expect("failed to store first entity");
        db.store_item(second.clone())
            .await
            .expect("failed to store second entity");

        // The bare and the prefixed form of `first` normalize to the same id,
        // so the prefixed one counts as a duplicate.
        let cited = vec![
            first.id.clone(),
            format!("knowledge_entity:{}", first.id),
            second.id.clone(),
            second.id.clone(),
        ];
        let allowed = vec![first.id.clone(), second.id.clone()];

        let outcome = validate_references(owner, cited, &allowed, &db)
            .await
            .expect("validation should not fail");

        assert_eq!(outcome.valid_refs, vec![first.id, second.id]);
        assert_eq!(
            rejection_reasons(&outcome),
            vec![
                InvalidReferenceReason::Duplicate,
                InvalidReferenceReason::Duplicate
            ]
        );
    }

    #[tokio::test]
    async fn bare_uuid_prefers_chunk_lookup_before_entity() {
        let db = setup_test_db().await;
        let owner = "user-a";
        let chunk = TextChunk::new(
            "source-1".to_string(),
            "Chunk body".to_string(),
            owner.to_string(),
        );
        db.store_item(chunk.clone())
            .await
            .expect("failed to store chunk");

        let allowed = [chunk.id.clone()];
        let outcome = validate_references(owner, vec![chunk.id.clone()], &allowed, &db)
            .await
            .expect("validation should not fail");

        assert_eq!(outcome.valid_refs, vec![chunk.id]);
    }
}
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
|
#![allow(clippy::missing_docs_in_private_items)]
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::{Path, State},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use chrono_tz::Tz;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
use common::{
|
use common::storage::types::{
|
||||||
error::AppError,
|
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, user::User,
|
||||||
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -17,29 +20,101 @@ use crate::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::reference_validation::{normalize_reference, ReferenceLookupTarget};
|
||||||
|
|
||||||
|
/// Serializable context for the "chat/reference_tooltip.html" template.
/// The handler fills in either the chunk fields or the entity fields,
/// depending on which lookup matched.
#[derive(Serialize)]
|
||||||
|
struct ReferenceTooltipData {
|
||||||
|
// The referenced text chunk, when the id resolved to one.
text_chunk: Option<TextChunk>,
|
||||||
|
// The chunk's `updated_at`, pre-formatted for the user's timezone.
text_chunk_updated_at: Option<String>,
|
||||||
|
// The referenced knowledge entity, when the id resolved to one.
entity: Option<KnowledgeEntity>,
|
||||||
|
// The entity's `updated_at`, pre-formatted for the user's timezone.
entity_updated_at: Option<String>,
|
||||||
|
// The authenticated user the tooltip is rendered for.
user: User,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_datetime_for_user(datetime: DateTime<Utc>, timezone: &str) -> String {
|
||||||
|
match timezone.parse::<Tz>() {
|
||||||
|
Ok(tz) => datetime
|
||||||
|
.with_timezone(&tz)
|
||||||
|
.format("%Y-%m-%d %H:%M:%S")
|
||||||
|
.to_string(),
|
||||||
|
Err(_) => datetime.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn show_reference_tooltip(
|
pub async fn show_reference_tooltip(
|
||||||
State(state): State<HtmlState>,
|
State(state): State<HtmlState>,
|
||||||
RequireUser(user): RequireUser,
|
RequireUser(user): RequireUser,
|
||||||
Path(reference_id): Path<String>,
|
Path(reference_id): Path<String>,
|
||||||
) -> Result<impl IntoResponse, HtmlError> {
|
) -> Result<impl IntoResponse, HtmlError> {
|
||||||
let entity: KnowledgeEntity = state
|
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
|
||||||
.db
|
return Ok(TemplateResponse::not_found());
|
||||||
.get_item(&reference_id)
|
};
|
||||||
.await?
|
|
||||||
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
|
|
||||||
|
|
||||||
if entity.user_id != user.id {
|
let lookup_order = match target {
|
||||||
return Ok(TemplateResponse::unauthorized());
|
ReferenceLookupTarget::TextChunk | ReferenceLookupTarget::Any => [
|
||||||
|
ReferenceLookupTarget::TextChunk,
|
||||||
|
ReferenceLookupTarget::KnowledgeEntity,
|
||||||
|
],
|
||||||
|
ReferenceLookupTarget::KnowledgeEntity => [
|
||||||
|
ReferenceLookupTarget::KnowledgeEntity,
|
||||||
|
ReferenceLookupTarget::TextChunk,
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut text_chunk: Option<TextChunk> = None;
|
||||||
|
let mut knowledge_entity: Option<KnowledgeEntity> = None;
|
||||||
|
|
||||||
|
for lookup_target in lookup_order {
|
||||||
|
match lookup_target {
|
||||||
|
ReferenceLookupTarget::TextChunk => {
|
||||||
|
if let Some(chunk) = state
|
||||||
|
.db
|
||||||
|
.get_item::<TextChunk>(&normalized_reference_id)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
if chunk.user_id != user.id {
|
||||||
|
return Ok(TemplateResponse::unauthorized());
|
||||||
|
}
|
||||||
|
text_chunk = Some(chunk);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ReferenceLookupTarget::KnowledgeEntity => {
|
||||||
|
if let Some(entity) = state
|
||||||
|
.db
|
||||||
|
.get_item::<KnowledgeEntity>(&normalized_reference_id)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
if entity.user_id != user.id {
|
||||||
|
return Ok(TemplateResponse::unauthorized());
|
||||||
|
}
|
||||||
|
knowledge_entity = Some(entity);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ReferenceLookupTarget::Any => {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
if text_chunk.is_none() && knowledge_entity.is_none() {
|
||||||
struct ReferenceTooltipData {
|
return Ok(TemplateResponse::not_found());
|
||||||
entity: KnowledgeEntity,
|
|
||||||
user: User,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let text_chunk_updated_at = text_chunk
|
||||||
|
.as_ref()
|
||||||
|
.map(|chunk| format_datetime_for_user(chunk.updated_at, &user.timezone));
|
||||||
|
let entity_updated_at = knowledge_entity
|
||||||
|
.as_ref()
|
||||||
|
.map(|entity| format_datetime_for_user(entity.updated_at, &user.timezone));
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
"chat/reference_tooltip.html",
|
"chat/reference_tooltip.html",
|
||||||
ReferenceTooltipData { entity, user },
|
ReferenceTooltipData {
|
||||||
|
text_chunk,
|
||||||
|
text_chunk_updated_at,
|
||||||
|
entity: knowledge_entity,
|
||||||
|
entity_updated_at,
|
||||||
|
user,
|
||||||
|
},
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,12 +111,23 @@
|
|||||||
// Load content if needed
|
// Load content if needed
|
||||||
if (!tooltipContent) {
|
if (!tooltipContent) {
|
||||||
fetch(`/chat/reference/${encodeURIComponent(reference)}`)
|
fetch(`/chat/reference/${encodeURIComponent(reference)}`)
|
||||||
.then(response => response.text())
|
.then(response => {
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`reference lookup failed with status ${response.status}`);
|
||||||
|
}
|
||||||
|
return response.text();
|
||||||
|
})
|
||||||
.then(html => {
|
.then(html => {
|
||||||
tooltipContent = html;
|
tooltipContent = html;
|
||||||
if (document.getElementById(tooltipId)) {
|
if (document.getElementById(tooltipId)) {
|
||||||
document.getElementById(tooltipId).innerHTML = html;
|
document.getElementById(tooltipId).innerHTML = html;
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
tooltipContent = '<div class="text-xs opacity-70">Reference unavailable.</div>';
|
||||||
|
if (document.getElementById(tooltipId)) {
|
||||||
|
document.getElementById(tooltipId).innerHTML = tooltipContent;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
} else if (tooltip) {
|
} else if (tooltip) {
|
||||||
// Set content if already loaded
|
// Set content if already loaded
|
||||||
|
|||||||
@@ -1,3 +1,11 @@
|
|||||||
<div>{{entity.name}}</div>
|
{% if text_chunk %}
|
||||||
<div>{{entity.description}}</div>
|
<div class="font-semibold">Chunk Reference</div>
|
||||||
<div>{{entity.updated_at|datetimeformat(format="short", tz=user.timezone)}} </div>
|
<div class="text-sm whitespace-pre-wrap">{{text_chunk.chunk}}</div>
|
||||||
|
<div class="text-xs opacity-70">{{text_chunk_updated_at}}</div>
|
||||||
|
{% elif entity %}
|
||||||
|
<div class="font-semibold">{{entity.name}}</div>
|
||||||
|
<div class="text-sm">{{entity.description}}</div>
|
||||||
|
<div class="text-xs opacity-70">{{entity_updated_at}}</div>
|
||||||
|
{% else %}
|
||||||
|
<div class="text-xs opacity-70">Reference unavailable.</div>
|
||||||
|
{% endif %}
|
||||||
|
|||||||
@@ -61,8 +61,8 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|chunk| {
|
.map(|chunk| {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
|
"id": chunk.chunk.id,
|
||||||
"content": chunk.chunk.chunk,
|
"content": chunk.chunk.chunk,
|
||||||
"source_id": chunk.chunk.source_id,
|
|
||||||
"score": round_score(chunk.score),
|
"score": round_score(chunk.score),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -117,7 +117,7 @@ pub fn create_chat_request(
|
|||||||
.build()
|
.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn process_llm_response(
|
pub fn process_llm_response(
|
||||||
response: CreateChatCompletionResponse,
|
response: CreateChatCompletionResponse,
|
||||||
) -> Result<LLMResponseFormat, AppError> {
|
) -> Result<LLMResponseFormat, AppError> {
|
||||||
response
|
response
|
||||||
|
|||||||
Reference in New Issue
Block a user