// Mirror of https://github.com/perstarkse/minne.git
// Synced 2026-06-30 10:01:40 +02:00 (488 lines, 15 KiB, Rust).
#![allow(clippy::missing_docs_in_private_items)]
|
|
|
|
use std::collections::HashSet;
|
|
|
|
use common::{
|
|
error::AppError,
|
|
storage::{
|
|
db::SurrealDbClient,
|
|
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
|
},
|
|
};
|
|
use retrieval_pipeline::StrategyOutput;
|
|
use uuid::Uuid;
|
|
|
|
/// Maximum number of references a single `validate_references` call accepts as
/// valid; further candidates are rejected as `InvalidReferenceReason::OverLimit`.
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
|
|
|
/// Why a reference string was rejected during validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum InvalidReferenceReason {
    /// The reference was empty after trimming whitespace.
    Empty,
    /// The text before the first `:` was neither `knowledge_entity` nor `text_chunk`.
    UnsupportedPrefix,
    /// The identifier part was missing or could not be parsed as a UUID.
    MalformedUuid,
    /// The same normalized identifier already appeared earlier in the input.
    Duplicate,
    /// The identifier is not among the allowed (retrieved-context) identifiers.
    NotInContext,
    /// No matching record was found in the database.
    NotFound,
    /// A matching record exists but belongs to a different user.
    WrongUser,
    /// `MAX_REFERENCE_COUNT` valid references had already been accepted.
    OverLimit,
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub(crate) struct InvalidReference {
|
|
pub raw: String,
|
|
pub normalized: Option<String>,
|
|
pub reason: InvalidReferenceReason,
|
|
}
|
|
|
|
/// Per-reason counters describing the outcome of one validation run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub(crate) struct ReferenceReasonStats {
    /// Number of references submitted for validation (valid and invalid).
    pub total: usize,
    /// References rejected as `InvalidReferenceReason::Empty`.
    pub empty: usize,
    /// References rejected as `InvalidReferenceReason::UnsupportedPrefix`.
    pub unsupported_prefix: usize,
    /// References rejected as `InvalidReferenceReason::MalformedUuid`.
    pub malformed_uuid: usize,
    /// References rejected as `InvalidReferenceReason::Duplicate`.
    pub duplicate: usize,
    /// References rejected as `InvalidReferenceReason::NotInContext`.
    pub not_in_context: usize,
    /// References rejected as `InvalidReferenceReason::NotFound`.
    pub not_found: usize,
    /// References rejected as `InvalidReferenceReason::WrongUser`.
    pub wrong_user: usize,
    /// References rejected as `InvalidReferenceReason::OverLimit`.
    pub over_limit: usize,
}
|
|
|
|
impl ReferenceReasonStats {
|
|
fn record(&mut self, reason: &InvalidReferenceReason) {
|
|
match reason {
|
|
InvalidReferenceReason::Empty => self.empty = self.empty.saturating_add(1),
|
|
InvalidReferenceReason::UnsupportedPrefix => {
|
|
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1)
|
|
}
|
|
InvalidReferenceReason::MalformedUuid => {
|
|
self.malformed_uuid = self.malformed_uuid.saturating_add(1)
|
|
}
|
|
InvalidReferenceReason::Duplicate => self.duplicate = self.duplicate.saturating_add(1),
|
|
InvalidReferenceReason::NotInContext => {
|
|
self.not_in_context = self.not_in_context.saturating_add(1)
|
|
}
|
|
InvalidReferenceReason::NotFound => self.not_found = self.not_found.saturating_add(1),
|
|
InvalidReferenceReason::WrongUser => {
|
|
self.wrong_user = self.wrong_user.saturating_add(1)
|
|
}
|
|
InvalidReferenceReason::OverLimit => {
|
|
self.over_limit = self.over_limit.saturating_add(1)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Default)]
|
|
pub(crate) struct ReferenceValidationResult {
|
|
pub valid_refs: Vec<String>,
|
|
pub invalid_refs: Vec<InvalidReference>,
|
|
pub reason_stats: ReferenceReasonStats,
|
|
}
|
|
|
|
/// Which record type a normalized reference should be looked up as.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ReferenceLookupTarget {
    /// The reference carried a `text_chunk:` prefix.
    TextChunk,
    /// The reference carried a `knowledge_entity:` prefix.
    KnowledgeEntity,
    /// The reference had no prefix, so either record type may match.
    Any,
}
|
|
|
|
pub(crate) fn collect_reference_ids_from_retrieval(
|
|
retrieval_result: &StrategyOutput,
|
|
) -> Vec<String> {
|
|
let mut ids = Vec::new();
|
|
let mut seen = HashSet::new();
|
|
|
|
match retrieval_result {
|
|
StrategyOutput::Chunks(chunks) => {
|
|
for chunk in chunks {
|
|
let id = chunk.chunk.id.clone();
|
|
if seen.insert(id.clone()) {
|
|
ids.push(id);
|
|
}
|
|
}
|
|
}
|
|
StrategyOutput::Entities(entities) => {
|
|
for entity in entities {
|
|
let id = entity.entity.id.clone();
|
|
if seen.insert(id.clone()) {
|
|
ids.push(id);
|
|
}
|
|
}
|
|
}
|
|
StrategyOutput::Search(search) => {
|
|
for chunk in &search.chunks {
|
|
let id = chunk.chunk.id.clone();
|
|
if seen.insert(id.clone()) {
|
|
ids.push(id);
|
|
}
|
|
}
|
|
for entity in &search.entities {
|
|
let id = entity.entity.id.clone();
|
|
if seen.insert(id.clone()) {
|
|
ids.push(id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
ids
|
|
}
|
|
|
|
pub(crate) async fn validate_references(
|
|
user_id: &str,
|
|
refs: Vec<String>,
|
|
allowed_ids: &[String],
|
|
db: &SurrealDbClient,
|
|
) -> Result<ReferenceValidationResult, AppError> {
|
|
let mut result = ReferenceValidationResult::default();
|
|
result.reason_stats.total = refs.len();
|
|
|
|
let mut seen = HashSet::new();
|
|
let allowed_set: HashSet<&str> = allowed_ids.iter().map(String::as_str).collect();
|
|
let enforce_context = !allowed_set.is_empty();
|
|
|
|
for raw in refs {
|
|
let (normalized, target) = match normalize_reference(&raw) {
|
|
Ok(parsed) => parsed,
|
|
Err(reason) => {
|
|
result.reason_stats.record(&reason);
|
|
result.invalid_refs.push(InvalidReference {
|
|
raw,
|
|
normalized: None,
|
|
reason,
|
|
});
|
|
continue;
|
|
}
|
|
};
|
|
|
|
if !seen.insert(normalized.clone()) {
|
|
let reason = InvalidReferenceReason::Duplicate;
|
|
result.reason_stats.record(&reason);
|
|
result.invalid_refs.push(InvalidReference {
|
|
raw,
|
|
normalized: Some(normalized),
|
|
reason,
|
|
});
|
|
continue;
|
|
}
|
|
|
|
if result.valid_refs.len() >= MAX_REFERENCE_COUNT {
|
|
let reason = InvalidReferenceReason::OverLimit;
|
|
result.reason_stats.record(&reason);
|
|
result.invalid_refs.push(InvalidReference {
|
|
raw,
|
|
normalized: Some(normalized),
|
|
reason,
|
|
});
|
|
continue;
|
|
}
|
|
|
|
if enforce_context && !allowed_set.contains(normalized.as_str()) {
|
|
let reason = InvalidReferenceReason::NotInContext;
|
|
result.reason_stats.record(&reason);
|
|
result.invalid_refs.push(InvalidReference {
|
|
raw,
|
|
normalized: Some(normalized),
|
|
reason,
|
|
});
|
|
continue;
|
|
}
|
|
|
|
match lookup_reference_for_user(&normalized, &target, user_id, db).await? {
|
|
LookupResult::Found => result.valid_refs.push(normalized),
|
|
LookupResult::WrongUser => {
|
|
let reason = InvalidReferenceReason::WrongUser;
|
|
result.reason_stats.record(&reason);
|
|
result.invalid_refs.push(InvalidReference {
|
|
raw,
|
|
normalized: Some(normalized),
|
|
reason,
|
|
});
|
|
}
|
|
LookupResult::NotFound => {
|
|
let reason = InvalidReferenceReason::NotFound;
|
|
result.reason_stats.record(&reason);
|
|
result.invalid_refs.push(InvalidReference {
|
|
raw,
|
|
normalized: Some(normalized),
|
|
reason,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
pub(crate) fn normalize_reference(
|
|
raw: &str,
|
|
) -> Result<(String, ReferenceLookupTarget), InvalidReferenceReason> {
|
|
let trimmed = raw.trim();
|
|
if trimmed.is_empty() {
|
|
return Err(InvalidReferenceReason::Empty);
|
|
}
|
|
|
|
let (candidate, target) = if let Some((prefix, rest)) = trimmed.split_once(':') {
|
|
let lookup_target = if prefix.eq_ignore_ascii_case("knowledge_entity") {
|
|
ReferenceLookupTarget::KnowledgeEntity
|
|
} else if prefix.eq_ignore_ascii_case("text_chunk") {
|
|
ReferenceLookupTarget::TextChunk
|
|
} else {
|
|
return Err(InvalidReferenceReason::UnsupportedPrefix);
|
|
};
|
|
|
|
(rest.trim(), lookup_target)
|
|
} else {
|
|
(trimmed, ReferenceLookupTarget::Any)
|
|
};
|
|
|
|
if candidate.is_empty() {
|
|
return Err(InvalidReferenceReason::MalformedUuid);
|
|
}
|
|
|
|
Uuid::parse_str(candidate)
|
|
.map(|uuid| (uuid.to_string(), target))
|
|
.map_err(|_| InvalidReferenceReason::MalformedUuid)
|
|
}
|
|
|
|
/// Result of checking a single record against a user.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LookupResult {
    /// The record exists and belongs to the requesting user.
    Found,
    /// The record exists but belongs to a different user.
    WrongUser,
    /// No record with the given ID exists.
    NotFound,
}
|
|
|
|
async fn lookup_reference_for_user(
|
|
id: &str,
|
|
target: &ReferenceLookupTarget,
|
|
user_id: &str,
|
|
db: &SurrealDbClient,
|
|
) -> Result<LookupResult, AppError> {
|
|
match target {
|
|
ReferenceLookupTarget::TextChunk => lookup_single_type::<TextChunk>(id, user_id, db).await,
|
|
ReferenceLookupTarget::KnowledgeEntity => {
|
|
lookup_single_type::<KnowledgeEntity>(id, user_id, db).await
|
|
}
|
|
ReferenceLookupTarget::Any => {
|
|
let chunk_result = lookup_single_type::<TextChunk>(id, user_id, db).await?;
|
|
if chunk_result == LookupResult::Found {
|
|
return Ok(LookupResult::Found);
|
|
}
|
|
|
|
let entity_result = lookup_single_type::<KnowledgeEntity>(id, user_id, db).await?;
|
|
if entity_result == LookupResult::Found {
|
|
return Ok(LookupResult::Found);
|
|
}
|
|
|
|
if chunk_result == LookupResult::WrongUser || entity_result == LookupResult::WrongUser {
|
|
return Ok(LookupResult::WrongUser);
|
|
}
|
|
|
|
Ok(LookupResult::NotFound)
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn lookup_single_type<T>(
|
|
id: &str,
|
|
user_id: &str,
|
|
db: &SurrealDbClient,
|
|
) -> Result<LookupResult, AppError>
|
|
where
|
|
T: StoredObject + for<'de> serde::Deserialize<'de> + HasUserId,
|
|
{
|
|
let item = db.get_item::<T>(id).await?;
|
|
Ok(match item {
|
|
Some(item) if item.user_id() == user_id => LookupResult::Found,
|
|
Some(_) => LookupResult::WrongUser,
|
|
None => LookupResult::NotFound,
|
|
})
|
|
}
|
|
|
|
/// Stored records that expose the ID of the user who owns them.
///
/// Lets `lookup_single_type` perform the ownership check generically.
trait HasUserId {
    /// Returns the ID of the owning user.
    fn user_id(&self) -> &str;
}
|
|
|
|
impl HasUserId for TextChunk {
    /// Returns the chunk's `user_id` field.
    fn user_id(&self) -> &str {
        &self.user_id
    }
}
|
|
|
|
impl HasUserId for KnowledgeEntity {
    /// Returns the entity's `user_id` field.
    fn user_id(&self) -> &str {
        &self.user_id
    }
}
|
|
|
|
#[cfg(test)]
// Test code deliberately uses `expect`, direct indexing and one-element slices
// built from clones to keep the assertions short.
#[allow(
    clippy::cloned_ref_to_slice_refs,
    clippy::expect_used,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use common::storage::types::knowledge_entity::KnowledgeEntityType;
    use surrealdb::engine::any::connect;

    /// Creates an in-memory database with a unique namespace/database name and
    /// applies the migrations.
    async fn setup_test_db() -> SurrealDbClient {
        let client = connect("mem://")
            .await
            .expect("failed to create in-memory surrealdb client");
        let namespace = format!("test_ns_{}", Uuid::new_v4());
        let database = format!("test_db_{}", Uuid::new_v4());
        client
            .use_ns(namespace)
            .use_db(database)
            .await
            .expect("failed to select namespace/db");

        let db = SurrealDbClient { client };
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        db
    }

    #[tokio::test]
    async fn valid_uuid_exists_and_belongs_to_user() {
        let db = setup_test_db().await;
        let user_id = "user-a";
        let entity = KnowledgeEntity::new(
            "source-1".to_string(),
            "Entity A".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.to_string(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to store entity");

        let result =
            validate_references(user_id, vec![entity.id.clone()], &[entity.id.clone()], &db)
                .await
                .expect("validation should not fail");

        assert_eq!(result.valid_refs, vec![entity.id]);
        assert!(result.invalid_refs.is_empty());
        assert_eq!(result.reason_stats.total, 1);
    }

    #[tokio::test]
    async fn valid_uuid_exists_but_wrong_user_is_rejected() {
        let db = setup_test_db().await;
        let entity = KnowledgeEntity::new(
            "source-1".to_string(),
            "Entity B".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            "other-user".to_string(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to store entity");

        let result =
            validate_references("user-a", vec![entity.id.clone()], &[entity.id.clone()], &db)
                .await
                .expect("validation should not fail");

        assert!(result.valid_refs.is_empty());
        assert_eq!(result.invalid_refs.len(), 1);
        assert_eq!(
            result.invalid_refs[0].reason,
            InvalidReferenceReason::WrongUser
        );
        assert_eq!(result.reason_stats.wrong_user, 1);
    }

    #[tokio::test]
    async fn malformed_uuid_is_rejected() {
        let db = setup_test_db().await;
        let result = validate_references(
            "user-a",
            vec!["not-a-uuid".to_string()],
            &["not-a-uuid".to_string()],
            &db,
        )
        .await
        .expect("validation should not fail");

        assert!(result.valid_refs.is_empty());
        assert_eq!(result.invalid_refs.len(), 1);
        assert_eq!(
            result.invalid_refs[0].reason,
            InvalidReferenceReason::MalformedUuid
        );
        assert_eq!(result.reason_stats.malformed_uuid, 1);
    }

    #[tokio::test]
    async fn mixed_duplicates_are_deduped() {
        let db = setup_test_db().await;
        let user_id = "user-a";

        let first = KnowledgeEntity::new(
            "source-1".to_string(),
            "Entity 1".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.to_string(),
        );
        let second = KnowledgeEntity::new(
            "source-2".to_string(),
            "Entity 2".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.to_string(),
        );
        db.store_item(first.clone())
            .await
            .expect("failed to store first entity");
        db.store_item(second.clone())
            .await
            .expect("failed to store second entity");

        // The bare and the prefixed spelling of `first` normalize to the same
        // ID, so the prefixed one counts as a duplicate.
        let refs = vec![
            first.id.clone(),
            format!("knowledge_entity:{}", first.id),
            second.id.clone(),
            second.id.clone(),
        ];

        let allowed = vec![first.id.clone(), second.id.clone()];
        let result = validate_references(user_id, refs, &allowed, &db)
            .await
            .expect("validation should not fail");

        assert_eq!(result.reason_stats.total, 4);
        assert_eq!(result.reason_stats.duplicate, 2);
        assert_eq!(result.valid_refs, vec![first.id, second.id]);
        assert_eq!(result.invalid_refs.len(), 2);
        assert!(result
            .invalid_refs
            .iter()
            .all(|entry| entry.reason == InvalidReferenceReason::Duplicate));
    }

    #[tokio::test]
    async fn bare_uuid_prefers_chunk_lookup_before_entity() {
        let db = setup_test_db().await;
        let user_id = "user-a";
        let chunk = TextChunk::new(
            "source-1".to_string(),
            "Chunk body".to_string(),
            user_id.to_string(),
        );
        db.store_item(chunk.clone())
            .await
            .expect("failed to store chunk");

        let result = validate_references(user_id, vec![chunk.id.clone()], &[chunk.id.clone()], &db)
            .await
            .expect("validation should not fail");

        assert_eq!(result.valid_refs, vec![chunk.id]);
        assert!(result.invalid_refs.is_empty());
    }
}
|