mirror of
https://github.com/perstarkse/minne.git
synced 2026-02-22 08:07:40 +01:00
fix: references bug
fix
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3820,7 +3820,7 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
|
||||
|
||||
[[package]]
|
||||
name = "main"
|
||||
version = "1.0.0"
|
||||
version = "1.0.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"api-router",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use async_stream::stream;
|
||||
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use tokio::sync::{mpsc::channel, Mutex};
|
||||
use tracing::{debug, error};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
@@ -38,6 +40,8 @@ use common::storage::{
|
||||
|
||||
use crate::{html_state::HtmlState, AuthSessionType};
|
||||
|
||||
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
@@ -56,13 +60,10 @@ async fn get_message_and_user(
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
// Check authentication
|
||||
let user = match current_user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)))
|
||||
}
|
||||
let Some(user) = current_user else {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)));
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
@@ -105,6 +106,20 @@ pub struct QueryParams {
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||
response
|
||||
.references
|
||||
.iter()
|
||||
.map(|reference| reference.reference.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
@@ -146,11 +161,13 @@ pub async fn get_response_stream(
|
||||
}
|
||||
};
|
||||
|
||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||
|
||||
// 3. Create the OpenAI request with appropriate context format
|
||||
let context_json = match retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
let context_json = match &retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
|
||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||
retrieved_entities_to_json(&entities)
|
||||
retrieved_entities_to_json(entities)
|
||||
}
|
||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||
// For chat, use chunks from the search result
|
||||
@@ -159,17 +176,11 @@ pub async fn get_response_stream(
|
||||
};
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&context_json, &history, &user_message.content);
|
||||
let settings = match SystemSettings::get_current(&state.db).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||
}
|
||||
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
|
||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||
};
|
||||
let request = match create_chat_request(formatted_user_message, &settings) {
|
||||
Ok(req) => req,
|
||||
Err(..) => {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
}
|
||||
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
@@ -186,7 +197,9 @@ pub async fn get_response_stream(
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = state.db.clone();
|
||||
let db_client = Arc::clone(&state.db);
|
||||
let user_id = user.id.clone();
|
||||
let allowed_reference_ids = allowed_reference_ids.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
@@ -198,17 +211,55 @@ pub async fn get_response_stream(
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let references: Vec<String> = response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect();
|
||||
let raw_references = extract_reference_strings(&response);
|
||||
let answer = response.answer;
|
||||
|
||||
let initial_validation = match validate_references(
|
||||
&user_id,
|
||||
raw_references,
|
||||
&allowed_reference_ids,
|
||||
&db_client,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
error!(error = %err, "Reference validation failed, storing answer without references");
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
answer,
|
||||
Some(Vec::new()),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
if let Err(store_err) = db_client.store_item(ai_message).await {
|
||||
error!(error = ?store_err, "Failed to store AI message after validation failure");
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
total_refs = initial_validation.reason_stats.total,
|
||||
valid_refs = initial_validation.valid_refs.len(),
|
||||
invalid_refs = initial_validation.invalid_refs.len(),
|
||||
invalid_empty = initial_validation.reason_stats.empty,
|
||||
invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
|
||||
invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
|
||||
invalid_duplicate = initial_validation.reason_stats.duplicate,
|
||||
invalid_not_in_context = initial_validation.reason_stats.not_in_context,
|
||||
invalid_not_found = initial_validation.reason_stats.not_found,
|
||||
invalid_wrong_user = initial_validation.reason_stats.wrong_user,
|
||||
invalid_over_limit = initial_validation.reason_stats.over_limit,
|
||||
"Post-LLM reference validation complete"
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
response.answer,
|
||||
Some(references),
|
||||
answer,
|
||||
Some(initial_validation.valid_refs),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
@@ -240,7 +291,7 @@ pub async fn get_response_stream(
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = json_state.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
@@ -288,12 +339,6 @@ pub async fn get_response_stream(
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
@@ -375,3 +420,27 @@ impl StreamParserState {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use retrieval_pipeline::answer_retrieval::Reference;
|
||||
|
||||
#[test]
|
||||
fn extracts_reference_strings_in_order() {
|
||||
let response = LLMResponseFormat {
|
||||
answer: "answer".to_string(),
|
||||
references: vec![
|
||||
Reference {
|
||||
reference: "a".to_string(),
|
||||
},
|
||||
Reference {
|
||||
reference: "b".to_string(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let extracted = extract_reference_strings(&response);
|
||||
assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod chat_handlers;
|
||||
mod message_response_stream;
|
||||
mod reference_validation;
|
||||
mod references;
|
||||
|
||||
use axum::{
|
||||
|
||||
477
html-router/src/routes/chat/reference_validation.rs
Normal file
477
html-router/src/routes/chat/reference_validation.rs
Normal file
@@ -0,0 +1,477 @@
|
||||
#![allow(clippy::arithmetic_side_effects, clippy::missing_docs_in_private_items)]
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
};
|
||||
use retrieval_pipeline::StrategyOutput;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum InvalidReferenceReason {
|
||||
Empty,
|
||||
UnsupportedPrefix,
|
||||
MalformedUuid,
|
||||
Duplicate,
|
||||
NotInContext,
|
||||
NotFound,
|
||||
WrongUser,
|
||||
OverLimit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct InvalidReference {
|
||||
pub raw: String,
|
||||
pub normalized: Option<String>,
|
||||
pub reason: InvalidReferenceReason,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub(crate) struct ReferenceReasonStats {
|
||||
pub total: usize,
|
||||
pub empty: usize,
|
||||
pub unsupported_prefix: usize,
|
||||
pub malformed_uuid: usize,
|
||||
pub duplicate: usize,
|
||||
pub not_in_context: usize,
|
||||
pub not_found: usize,
|
||||
pub wrong_user: usize,
|
||||
pub over_limit: usize,
|
||||
}
|
||||
|
||||
impl ReferenceReasonStats {
|
||||
fn record(&mut self, reason: &InvalidReferenceReason) {
|
||||
match reason {
|
||||
InvalidReferenceReason::Empty => self.empty += 1,
|
||||
InvalidReferenceReason::UnsupportedPrefix => self.unsupported_prefix += 1,
|
||||
InvalidReferenceReason::MalformedUuid => self.malformed_uuid += 1,
|
||||
InvalidReferenceReason::Duplicate => self.duplicate += 1,
|
||||
InvalidReferenceReason::NotInContext => self.not_in_context += 1,
|
||||
InvalidReferenceReason::NotFound => self.not_found += 1,
|
||||
InvalidReferenceReason::WrongUser => self.wrong_user += 1,
|
||||
InvalidReferenceReason::OverLimit => self.over_limit += 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub(crate) struct ReferenceValidationResult {
|
||||
pub valid_refs: Vec<String>,
|
||||
pub invalid_refs: Vec<InvalidReference>,
|
||||
pub reason_stats: ReferenceReasonStats,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum ReferenceLookupTarget {
|
||||
TextChunk,
|
||||
KnowledgeEntity,
|
||||
Any,
|
||||
}
|
||||
|
||||
pub(crate) fn collect_reference_ids_from_retrieval(
|
||||
retrieval_result: &StrategyOutput,
|
||||
) -> Vec<String> {
|
||||
let mut ids = Vec::new();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
match retrieval_result {
|
||||
StrategyOutput::Chunks(chunks) => {
|
||||
for chunk in chunks {
|
||||
let id = chunk.chunk.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
StrategyOutput::Entities(entities) => {
|
||||
for entity in entities {
|
||||
let id = entity.entity.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
StrategyOutput::Search(search) => {
|
||||
for chunk in &search.chunks {
|
||||
let id = chunk.chunk.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
for entity in &search.entities {
|
||||
let id = entity.entity.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ids
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_references(
|
||||
user_id: &str,
|
||||
refs: Vec<String>,
|
||||
allowed_ids: &[String],
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<ReferenceValidationResult, AppError> {
|
||||
let mut result = ReferenceValidationResult::default();
|
||||
result.reason_stats.total = refs.len();
|
||||
|
||||
let mut seen = HashSet::new();
|
||||
let allowed_set: HashSet<&str> = allowed_ids.iter().map(String::as_str).collect();
|
||||
let enforce_context = !allowed_set.is_empty();
|
||||
|
||||
for raw in refs {
|
||||
let (normalized, target) = match normalize_reference(&raw) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(reason) => {
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: None,
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !seen.insert(normalized.clone()) {
|
||||
let reason = InvalidReferenceReason::Duplicate;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if result.valid_refs.len() >= MAX_REFERENCE_COUNT {
|
||||
let reason = InvalidReferenceReason::OverLimit;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if enforce_context && !allowed_set.contains(normalized.as_str()) {
|
||||
let reason = InvalidReferenceReason::NotInContext;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
match lookup_reference_for_user(&normalized, &target, user_id, db).await? {
|
||||
LookupResult::Found => result.valid_refs.push(normalized),
|
||||
LookupResult::WrongUser => {
|
||||
let reason = InvalidReferenceReason::WrongUser;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
}
|
||||
LookupResult::NotFound => {
|
||||
let reason = InvalidReferenceReason::NotFound;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn normalize_reference(
|
||||
raw: &str,
|
||||
) -> Result<(String, ReferenceLookupTarget), InvalidReferenceReason> {
|
||||
let trimmed = raw.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Err(InvalidReferenceReason::Empty);
|
||||
}
|
||||
|
||||
let (candidate, target) = if let Some((prefix, rest)) = trimmed.split_once(':') {
|
||||
let lookup_target = if prefix.eq_ignore_ascii_case("knowledge_entity") {
|
||||
ReferenceLookupTarget::KnowledgeEntity
|
||||
} else if prefix.eq_ignore_ascii_case("text_chunk") {
|
||||
ReferenceLookupTarget::TextChunk
|
||||
} else {
|
||||
return Err(InvalidReferenceReason::UnsupportedPrefix);
|
||||
};
|
||||
|
||||
(rest.trim(), lookup_target)
|
||||
} else {
|
||||
(trimmed, ReferenceLookupTarget::Any)
|
||||
};
|
||||
|
||||
if candidate.is_empty() {
|
||||
return Err(InvalidReferenceReason::MalformedUuid);
|
||||
}
|
||||
|
||||
Uuid::parse_str(candidate)
|
||||
.map(|uuid| (uuid.to_string(), target))
|
||||
.map_err(|_| InvalidReferenceReason::MalformedUuid)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum LookupResult {
|
||||
Found,
|
||||
WrongUser,
|
||||
NotFound,
|
||||
}
|
||||
|
||||
async fn lookup_reference_for_user(
|
||||
id: &str,
|
||||
target: &ReferenceLookupTarget,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<LookupResult, AppError> {
|
||||
match target {
|
||||
ReferenceLookupTarget::TextChunk => lookup_single_type::<TextChunk>(id, user_id, db).await,
|
||||
ReferenceLookupTarget::KnowledgeEntity => {
|
||||
lookup_single_type::<KnowledgeEntity>(id, user_id, db).await
|
||||
}
|
||||
ReferenceLookupTarget::Any => {
|
||||
let chunk_result = lookup_single_type::<TextChunk>(id, user_id, db).await?;
|
||||
if chunk_result == LookupResult::Found {
|
||||
return Ok(LookupResult::Found);
|
||||
}
|
||||
|
||||
let entity_result = lookup_single_type::<KnowledgeEntity>(id, user_id, db).await?;
|
||||
if entity_result == LookupResult::Found {
|
||||
return Ok(LookupResult::Found);
|
||||
}
|
||||
|
||||
if chunk_result == LookupResult::WrongUser || entity_result == LookupResult::WrongUser {
|
||||
return Ok(LookupResult::WrongUser);
|
||||
}
|
||||
|
||||
Ok(LookupResult::NotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn lookup_single_type<T>(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<LookupResult, AppError>
|
||||
where
|
||||
T: StoredObject + for<'de> serde::Deserialize<'de> + HasUserId,
|
||||
{
|
||||
let item = db.get_item::<T>(id).await?;
|
||||
Ok(match item {
|
||||
Some(item) if item.user_id() == user_id => LookupResult::Found,
|
||||
Some(_) => LookupResult::WrongUser,
|
||||
None => LookupResult::NotFound,
|
||||
})
|
||||
}
|
||||
|
||||
trait HasUserId {
|
||||
fn user_id(&self) -> &str;
|
||||
}
|
||||
|
||||
impl HasUserId for TextChunk {
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl HasUserId for KnowledgeEntity {
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(
|
||||
clippy::cloned_ref_to_slice_refs,
|
||||
clippy::expect_used,
|
||||
clippy::indexing_slicing
|
||||
)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::KnowledgeEntityType;
|
||||
use surrealdb::engine::any::connect;
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let client = connect("mem://")
|
||||
.await
|
||||
.expect("failed to create in-memory surrealdb client");
|
||||
let namespace = format!("test_ns_{}", Uuid::new_v4());
|
||||
let database = format!("test_db_{}", Uuid::new_v4());
|
||||
client
|
||||
.use_ns(namespace)
|
||||
.use_db(database)
|
||||
.await
|
||||
.expect("failed to select namespace/db");
|
||||
|
||||
let db = SurrealDbClient { client };
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_uuid_exists_and_belongs_to_user() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user-a";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source-1".to_string(),
|
||||
"Entity A".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to store entity");
|
||||
|
||||
let result =
|
||||
validate_references(user_id, vec![entity.id.clone()], &[entity.id.clone()], &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert_eq!(result.valid_refs, vec![entity.id]);
|
||||
assert!(result.invalid_refs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_uuid_exists_but_wrong_user_is_rejected() {
|
||||
let db = setup_test_db().await;
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source-1".to_string(),
|
||||
"Entity B".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
"other-user".to_string(),
|
||||
);
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to store entity");
|
||||
|
||||
let result =
|
||||
validate_references("user-a", vec![entity.id.clone()], &[entity.id.clone()], &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert!(result.valid_refs.is_empty());
|
||||
assert_eq!(result.invalid_refs.len(), 1);
|
||||
assert_eq!(
|
||||
result.invalid_refs[0].reason,
|
||||
InvalidReferenceReason::WrongUser
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_uuid_is_rejected() {
|
||||
let db = setup_test_db().await;
|
||||
let result = validate_references(
|
||||
"user-a",
|
||||
vec!["not-a-uuid".to_string()],
|
||||
&["not-a-uuid".to_string()],
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert!(result.valid_refs.is_empty());
|
||||
assert_eq!(result.invalid_refs.len(), 1);
|
||||
assert_eq!(
|
||||
result.invalid_refs[0].reason,
|
||||
InvalidReferenceReason::MalformedUuid
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mixed_duplicates_are_deduped() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user-a";
|
||||
|
||||
let first = KnowledgeEntity::new(
|
||||
"source-1".to_string(),
|
||||
"Entity 1".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
let second = KnowledgeEntity::new(
|
||||
"source-2".to_string(),
|
||||
"Entity 2".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(first.clone())
|
||||
.await
|
||||
.expect("failed to store first entity");
|
||||
db.store_item(second.clone())
|
||||
.await
|
||||
.expect("failed to store second entity");
|
||||
|
||||
let refs = vec![
|
||||
first.id.clone(),
|
||||
format!("knowledge_entity:{}", first.id),
|
||||
second.id.clone(),
|
||||
second.id.clone(),
|
||||
];
|
||||
|
||||
let allowed = vec![first.id.clone(), second.id.clone()];
|
||||
let result = validate_references(user_id, refs, &allowed, &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert_eq!(result.valid_refs, vec![first.id, second.id]);
|
||||
assert_eq!(result.invalid_refs.len(), 2);
|
||||
assert!(result
|
||||
.invalid_refs
|
||||
.iter()
|
||||
.all(|entry| entry.reason == InvalidReferenceReason::Duplicate));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bare_uuid_prefers_chunk_lookup_before_entity() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user-a";
|
||||
let chunk = TextChunk::new(
|
||||
"source-1".to_string(),
|
||||
"Chunk body".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("failed to store chunk");
|
||||
|
||||
let result = validate_references(user_id, vec![chunk.id.clone()], &[chunk.id.clone()], &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert_eq!(result.valid_refs, vec![chunk.id]);
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use chrono_tz::Tz;
|
||||
use serde::Serialize;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
|
||||
use common::storage::types::{
|
||||
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, user::User,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -17,29 +20,101 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::reference_validation::{normalize_reference, ReferenceLookupTarget};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceTooltipData {
|
||||
text_chunk: Option<TextChunk>,
|
||||
text_chunk_updated_at: Option<String>,
|
||||
entity: Option<KnowledgeEntity>,
|
||||
entity_updated_at: Option<String>,
|
||||
user: User,
|
||||
}
|
||||
|
||||
fn format_datetime_for_user(datetime: DateTime<Utc>, timezone: &str) -> String {
|
||||
match timezone.parse::<Tz>() {
|
||||
Ok(tz) => datetime
|
||||
.with_timezone(&tz)
|
||||
.format("%Y-%m-%d %H:%M:%S")
|
||||
.to_string(),
|
||||
Err(_) => datetime.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn show_reference_tooltip(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(reference_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let entity: KnowledgeEntity = state
|
||||
.db
|
||||
.get_item(&reference_id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
|
||||
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
|
||||
return Ok(TemplateResponse::not_found());
|
||||
};
|
||||
|
||||
if entity.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
let lookup_order = match target {
|
||||
ReferenceLookupTarget::TextChunk | ReferenceLookupTarget::Any => [
|
||||
ReferenceLookupTarget::TextChunk,
|
||||
ReferenceLookupTarget::KnowledgeEntity,
|
||||
],
|
||||
ReferenceLookupTarget::KnowledgeEntity => [
|
||||
ReferenceLookupTarget::KnowledgeEntity,
|
||||
ReferenceLookupTarget::TextChunk,
|
||||
],
|
||||
};
|
||||
|
||||
let mut text_chunk: Option<TextChunk> = None;
|
||||
let mut knowledge_entity: Option<KnowledgeEntity> = None;
|
||||
|
||||
for lookup_target in lookup_order {
|
||||
match lookup_target {
|
||||
ReferenceLookupTarget::TextChunk => {
|
||||
if let Some(chunk) = state
|
||||
.db
|
||||
.get_item::<TextChunk>(&normalized_reference_id)
|
||||
.await?
|
||||
{
|
||||
if chunk.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
text_chunk = Some(chunk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ReferenceLookupTarget::KnowledgeEntity => {
|
||||
if let Some(entity) = state
|
||||
.db
|
||||
.get_item::<KnowledgeEntity>(&normalized_reference_id)
|
||||
.await?
|
||||
{
|
||||
if entity.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
knowledge_entity = Some(entity);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ReferenceLookupTarget::Any => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceTooltipData {
|
||||
entity: KnowledgeEntity,
|
||||
user: User,
|
||||
if text_chunk.is_none() && knowledge_entity.is_none() {
|
||||
return Ok(TemplateResponse::not_found());
|
||||
}
|
||||
|
||||
let text_chunk_updated_at = text_chunk
|
||||
.as_ref()
|
||||
.map(|chunk| format_datetime_for_user(chunk.updated_at, &user.timezone));
|
||||
let entity_updated_at = knowledge_entity
|
||||
.as_ref()
|
||||
.map(|entity| format_datetime_for_user(entity.updated_at, &user.timezone));
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"chat/reference_tooltip.html",
|
||||
ReferenceTooltipData { entity, user },
|
||||
ReferenceTooltipData {
|
||||
text_chunk,
|
||||
text_chunk_updated_at,
|
||||
entity: knowledge_entity,
|
||||
entity_updated_at,
|
||||
user,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
@@ -111,12 +111,23 @@
|
||||
// Load content if needed
|
||||
if (!tooltipContent) {
|
||||
fetch(`/chat/reference/${encodeURIComponent(reference)}`)
|
||||
.then(response => response.text())
|
||||
.then(response => {
|
||||
if (!response.ok) {
|
||||
throw new Error(`reference lookup failed with status ${response.status}`);
|
||||
}
|
||||
return response.text();
|
||||
})
|
||||
.then(html => {
|
||||
tooltipContent = html;
|
||||
if (document.getElementById(tooltipId)) {
|
||||
document.getElementById(tooltipId).innerHTML = html;
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
tooltipContent = '<div class="text-xs opacity-70">Reference unavailable.</div>';
|
||||
if (document.getElementById(tooltipId)) {
|
||||
document.getElementById(tooltipId).innerHTML = tooltipContent;
|
||||
}
|
||||
});
|
||||
} else if (tooltip) {
|
||||
// Set content if already loaded
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
<div>{{entity.name}}</div>
|
||||
<div>{{entity.description}}</div>
|
||||
<div>{{entity.updated_at|datetimeformat(format="short", tz=user.timezone)}} </div>
|
||||
{% if text_chunk %}
|
||||
<div class="font-semibold">Chunk Reference</div>
|
||||
<div class="text-sm whitespace-pre-wrap">{{text_chunk.chunk}}</div>
|
||||
<div class="text-xs opacity-70">{{text_chunk_updated_at}}</div>
|
||||
{% elif entity %}
|
||||
<div class="font-semibold">{{entity.name}}</div>
|
||||
<div class="text-sm">{{entity.description}}</div>
|
||||
<div class="text-xs opacity-70">{{entity_updated_at}}</div>
|
||||
{% else %}
|
||||
<div class="text-xs opacity-70">Reference unavailable.</div>
|
||||
{% endif %}
|
||||
|
||||
@@ -61,8 +61,8 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
.iter()
|
||||
.map(|chunk| {
|
||||
serde_json::json!({
|
||||
"id": chunk.chunk.id,
|
||||
"content": chunk.chunk.chunk,
|
||||
"source_id": chunk.chunk.source_id,
|
||||
"score": round_score(chunk.score),
|
||||
})
|
||||
})
|
||||
@@ -117,7 +117,7 @@ pub fn create_chat_request(
|
||||
.build()
|
||||
}
|
||||
|
||||
pub async fn process_llm_response(
|
||||
pub fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response
|
||||
|
||||
Reference in New Issue
Block a user