fix: knowledge entity suggestions simplification

This commit is contained in:
Per Stark
2026-05-31 20:23:40 +02:00
parent 3897345ab3
commit b22c351785
9 changed files with 394 additions and 122 deletions
+128 -39
View File
@@ -16,14 +16,19 @@ use serde::{
use common::{
error::AppError,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
user::User,
storage::{
db::SurrealDbClient,
types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
user::User,
},
},
utils::embedding::generate_embedding_with_provider,
utils::embedding::{generate_embedding_with_provider, EmbeddingProvider},
};
use retrieval_pipeline::{
normalize_fts_terms, reciprocal_rank_fusion, RetrievalTuning, RrfConfig, Scored,
};
use retrieval_pipeline;
use tracing::debug;
use uuid::Uuid;
@@ -43,7 +48,6 @@ const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12;
// Relationship type labels known to this module; the first entry equals
// `DEFAULT_RELATIONSHIP_TYPE`.
// NOTE(review): presumably the choices offered in the UI — confirm at the use site.
const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"];
// Relationship type label used as the default.
const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo";
// Upper bound on how many suggested entities are collected per suggestion request.
const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10;
// Retrieval scores below this value are not turned into suggestions.
const SUGGESTION_MIN_SCORE: f32 = 0.5;
// JSON object carrying a `knowledge-graph-refresh` flag.
// NOTE(review): the use site is not visible here; presumably emitted as a
// client-side trigger payload — confirm.
const GRAPH_REFRESH_TRIGGER: &str = r#"{"knowledge-graph-refresh":true}"#;
// (alias, canonical label) pairs for relationship type names.
// NOTE(review): the alias is lowercase, so lookups presumably lowercase their
// input first — confirm at the use site.
const RELATIONSHIP_TYPE_ALIASES: &[(&str, &str)] = &[("relatesto", "RelatedTo")];
@@ -279,38 +283,30 @@ pub async fn suggest_knowledge_relationships(
}
if !query_parts.is_empty() {
let query = query_parts.join(" ");
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => pool.checkout().await,
None => None,
};
let name = form.name.as_deref().unwrap_or("").trim();
let description = form.description.as_deref().unwrap_or("").trim();
let entity_type = form.entity_type.as_deref().map_or(
KnowledgeEntityType::Document,
|value| KnowledgeEntityType::from(value.to_string()),
);
let config = retrieval_pipeline::RetrievalConfig::with_entities();
if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) =
retrieval_pipeline::retrieve(
&state.db,
&state.openai_client,
Some(&*state.embedding_provider),
&query,
&user.id,
config,
rerank_lease,
)
.await
{
for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities {
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
break;
}
if score.is_nan() || score < SUGGESTION_MIN_SCORE {
continue;
}
if !entity_lookup.contains_key(&entity.id) {
continue;
}
suggestion_scores.insert(entity.id.clone(), score);
selected_ids.insert(entity.id.clone());
}
let suggested = suggest_related_entities(
&state.db,
&state.embedding_provider,
&user.id,
DraftEntityQuery {
name,
description,
entity_type,
search_terms: &query_parts.join(" "),
},
&entity_lookup,
)
.await?;
for (id, score) in suggested {
selected_ids.insert(id.clone());
suggestion_scores.insert(id, score);
}
}
@@ -359,6 +355,90 @@ pub struct RelationshipTableRow {
relationship_type_label: String,
}
/// Borrowed description of a draft entity, passed to `suggest_related_entities`
/// to look up existing entities it could be related to.
struct DraftEntityQuery<'a> {
    /// Draft entity name; interpolated into the text that is embedded for the
    /// vector search.
    name: &'a str,
    /// Draft entity description; interpolated into the same embedding text.
    description: &'a str,
    /// Draft entity type; its `Debug` rendering is interpolated into the
    /// embedding text.
    entity_type: KnowledgeEntityType,
    /// Free-text terms that are normalised with `normalize_fts_terms` and used
    /// as the full-text search query.
    search_terms: &'a str,
}
async fn suggest_related_entities(
db: &SurrealDbClient,
embedding_provider: &EmbeddingProvider,
user_id: &str,
draft: DraftEntityQuery<'_>,
entity_lookup: &HashMap<String, KnowledgeEntity>,
) -> Result<HashMap<String, f32>, AppError> {
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
draft.name, draft.description, draft.entity_type
);
let embedding =
generate_embedding_with_provider(embedding_provider, &embedding_input).await?;
let take = MAX_RELATIONSHIP_SUGGESTIONS * 2;
let tuning = RetrievalTuning::default();
let (fts_query, fts_token_count) = normalize_fts_terms(draft.search_terms);
let fts_enabled = tuning.flags.chunk_rrf_use_fts() && !fts_query.is_empty();
let suggestion_min_rrf_score = 1.0 / (tuning.chunk_rrf_k + 1.0);
let (vector_rows, fts_rows) = tokio::try_join!(
KnowledgeEntity::vector_search(take, embedding, db, user_id),
async {
if fts_enabled {
KnowledgeEntity::fts_search(take, &fts_query, db, user_id).await
} else {
Ok(Vec::new())
}
}
)?;
let fts_candidates = fts_rows.len();
let vector_scored: Vec<Scored<KnowledgeEntity>> = vector_rows
.into_iter()
.map(|row| Scored::new(row.entity).with_vector_score(row.score))
.collect();
let fts_scored: Vec<Scored<KnowledgeEntity>> = fts_rows
.into_iter()
.map(|row| Scored::new(row.entity).with_fts_score(row.score))
.collect();
let mut fts_weight = tuning.chunk_rrf_fts_weight;
if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 {
fts_weight *= 1.5;
}
let fused = reciprocal_rank_fusion(
vector_scored,
fts_scored,
RrfConfig {
k: tuning.chunk_rrf_k,
vector_weight: tuning.chunk_rrf_vector_weight,
fts_weight,
use_vector: tuning.flags.chunk_rrf_use_vector(),
use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0,
},
);
let mut suggestions = HashMap::new();
for scored in fused {
if suggestions.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
break;
}
if scored.fused.is_nan() || scored.fused < suggestion_min_rrf_score {
continue;
}
if !entity_lookup.contains_key(&scored.item.id) {
continue;
}
suggestions.insert(scored.item.id, scored.fused);
}
Ok(suggestions)
}
fn build_relationship_options(
entities: Vec<KnowledgeEntity>,
selected_ids: &HashSet<String>,
@@ -618,6 +698,7 @@ impl<'de> Deserialize<'de> for CreateKnowledgeEntityParams {
/// Input parameters for a relationship-suggestion request.
///
/// Deserialized by a hand-written `Deserialize` impl further down in this file.
pub struct SuggestRelationshipsParams {
    /// Name of the entity being edited, if supplied.
    pub name: Option<String>,
    /// Description of the entity being edited, if supplied.
    pub description: Option<String>,
    /// Entity type as a raw string; the `Deserialize` impl stores the trimmed
    /// value, or `None` when the submitted value is blank.
    pub entity_type: Option<String>,
    /// Relationship type, if supplied.
    /// NOTE(review): presumably one of the labels in `RELATIONSHIP_TYPE_OPTIONS`
    /// or an alias — confirm where it is parsed.
    pub relationship_type: Option<String>,
    /// Entity ids carried in the request; starts out empty in the `Deserialize`
    /// impl.
    /// NOTE(review): presumably the ids already selected by the user — confirm.
    pub relationship_ids: Vec<String>,
}
@@ -653,6 +734,7 @@ impl<'de> Deserialize<'de> for SuggestRelationshipsParams {
{
let mut name: Option<String> = None;
let mut description: Option<String> = None;
let mut entity_type: Option<String> = None;
let mut relationship_type: Option<String> = None;
let mut relationship_ids: Vec<String> = Vec::new();
@@ -687,7 +769,13 @@ impl<'de> Deserialize<'de> for SuggestRelationshipsParams {
}
}
Field::EntityType => {
map.next_value::<de::IgnoredAny>()?;
let value: String = map.next_value()?;
let trimmed = value.trim();
if trimmed.is_empty() {
entity_type = None;
} else {
entity_type = Some(trimmed.to_owned());
}
}
Field::RelationshipIds => {
let value: String = map.next_value()?;
@@ -702,6 +790,7 @@ impl<'de> Deserialize<'de> for SuggestRelationshipsParams {
Ok(SuggestRelationshipsParams {
name,
description,
entity_type,
relationship_type,
relationship_ids,
})