diff --git a/composite-retrieval/src/lib.rs b/composite-retrieval/src/lib.rs index e3e2ece..00a451d 100644 --- a/composite-retrieval/src/lib.rs +++ b/composite-retrieval/src/lib.rs @@ -41,12 +41,14 @@ const GRAPH_SCORE_DECAY: f32 = 0.75; const GRAPH_SEED_MIN_SCORE: f32 = 0.4; const GRAPH_VECTOR_INHERITANCE: f32 = 0.6; +// Captures a supporting chunk plus its fused retrieval score for downstream prompts. #[derive(Debug, Clone)] pub struct RetrievedChunk { pub chunk: TextChunk, pub score: f32, } +// Final entity representation returned to callers, enriched with ranked chunks. #[derive(Debug, Clone)] pub struct RetrievedEntity { pub entity: KnowledgeEntity, @@ -114,6 +116,7 @@ pub(crate) async fn retrieve_entities_with_embedding( let mut entity_candidates: HashMap> = HashMap::new(); let mut chunk_candidates: HashMap> = HashMap::new(); + // Collate raw retrieval results so each ID accumulates all available signals. merge_scored_by_id(&mut entity_candidates, vector_entities); merge_scored_by_id(&mut entity_candidates, fts_entities); merge_scored_by_id(&mut chunk_candidates, vector_chunks); @@ -190,6 +193,7 @@ pub(crate) async fn retrieve_entities_with_embedding( .collect(); if filtered_entities.len() < FALLBACK_MIN_RESULTS { + // Low recall scenarios still benefit from some context; take the top N regardless of score. filtered_entities = entity_results .into_iter() .take(FALLBACK_MIN_RESULTS) @@ -220,6 +224,7 @@ pub(crate) async fn retrieve_entities_with_embedding( Ok(assemble_results(filtered_entities, chunk_values)) } +// Minimal record used while seeding graph expansion so we can retain the original fused score. #[derive(Clone)] struct GraphSeed { id: String, @@ -250,6 +255,7 @@ async fn enrich_entities_from_graph( return Ok(()); } + // Prioritise the strongest seeds so we explore the most grounded context first. seeds.sort_by(|a, b| { b.fused .partial_cmp(&a.fused) @@ -279,6 +285,7 @@ async fn enrich_entities_from_graph( continue; } + // Fold neighbors back into the candidate map and let them inherit attenuated signal. for neighbor in neighbors { if neighbor.id == seed.id { continue; diff --git a/composite-retrieval/src/vector.rs b/composite-retrieval/src/vector.rs index 56bf2ef..d2aaae7 100644 --- a/composite-retrieval/src/vector.rs +++ b/composite-retrieval/src/vector.rs @@ -9,7 +9,7 @@ use common::{ use serde::Deserialize; use surrealdb::sql::Thing; -use crate::scoring::{distance_to_similarity, Scored}; +use crate::scoring::{clamp_unit, distance_to_similarity, Scored}; /// Compares vectors and retrieves a number of items from the specified table. /// @@ -110,10 +110,45 @@ where .map(|item| (item.get_id().to_owned(), item)) .collect(); + let mut min_distance = f32::MAX; + let mut max_distance = f32::MIN; + + for row in &distance_rows { + if let Some(distance) = row.distance { + if distance.is_finite() { + if distance < min_distance { + min_distance = distance; + } + if distance > max_distance { + max_distance = distance; + } + } + } + } + + let normalize = min_distance.is_finite() + && max_distance.is_finite() + && (max_distance - min_distance).abs() > f32::EPSILON; + let mut scored = Vec::with_capacity(distance_rows.len()); for row in distance_rows { if let Some(item) = item_map.remove(&row.id) { - let similarity = row.distance.map(distance_to_similarity).unwrap_or_default(); + let similarity = row + .distance + .map(|distance| { + if normalize { + let span = max_distance - min_distance; + if span.abs() < f32::EPSILON { + 1.0 + } else { + let normalized = 1.0 - ((distance - min_distance) / span); + clamp_unit(normalized) + } + } else { + distance_to_similarity(distance) + } + }) + .unwrap_or_default(); scored.push(Scored::new(item).with_vector_score(similarity)); } }