diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index 44b67e2..9fa05f4 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -1,7 +1,5 @@ use super::types::StoredObject; -use crate::{ - error::AppError, -}; +use crate::error::AppError; use axum_session::{SessionConfig, SessionError, SessionStore}; use axum_session_surreal::SessionSurrealPool; use futures::Stream; diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index 0cd8ace..0be82ce 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -84,11 +84,11 @@ async fn ensure_runtime_indexes_inner( ) -> Result<()> { create_fts_analyzer(db).await?; - let fts_tasks = fts_index_specs().into_iter().map(|spec| async move { + for spec in fts_index_specs() { if index_exists(db, spec.table, spec.index_name).await? { - return Ok(()); + continue; } - + // We need to create these sequentially otherwise SurrealDB errors with read/write clash create_index_with_polling( db, spec.definition(), @@ -96,48 +96,43 @@ async fn ensure_runtime_indexes_inner( spec.table, Some(spec.table), ) - .await + .await?; + } + + let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move { + match hnsw_index_state(db, &spec, embedding_dimension).await? { + HnswIndexState::Missing => { + create_index_with_polling( + db, + spec.definition_if_not_exists(embedding_dimension), + spec.index_name, + spec.table, + Some(spec.table), + ) + .await + } + HnswIndexState::Matches => Ok(()), + HnswIndexState::Different(existing) => { + info!( + index = spec.index_name, + table = spec.table, + existing_dimension = existing, + target_dimension = embedding_dimension, + "Overwriting HNSW index to match new embedding dimension" + ); + create_index_with_polling( + db, + spec.definition_overwrite(embedding_dimension), + spec.index_name, + spec.table, + Some(spec.table), + ) + .await + } + } }); - let hnsw_tasks = hnsw_index_specs() - .into_iter() - .map(|spec| async move { - match hnsw_index_state(db, &spec, embedding_dimension).await? { - HnswIndexState::Missing => { - create_index_with_polling( - db, - spec.definition_if_not_exists(embedding_dimension), - spec.index_name, - spec.table, - Some(spec.table), - ) - .await - } - HnswIndexState::Matches => Ok(()), - HnswIndexState::Different(existing) => { - info!( - index = spec.index_name, - table = spec.table, - existing_dimension = existing, - target_dimension = embedding_dimension, - "Overwriting HNSW index to match new embedding dimension" - ); - create_index_with_polling( - db, - spec.definition_overwrite(embedding_dimension), - spec.index_name, - spec.table, - Some(spec.table), - ) - .await - } - } - }); - - futures::try_join!( - async { try_join_all(fts_tasks).await.map(|_| ()) }, - async { try_join_all(hnsw_tasks).await.map(|_| ()) }, - )?; + try_join_all(hnsw_tasks).await.map(|_| ())?; Ok(()) } @@ -204,20 +199,48 @@ fn extract_dimension(definition: &str) -> Option { } async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> { - let analyzer_query = format!( + // Prefer snowball stemming when supported; fall back to ascii-only when the filter + // is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering + // an existing analyzer definition. + let snowball_query = format!( "DEFINE ANALYZER IF NOT EXISTS {analyzer} TOKENIZERS class FILTERS lowercase, ascii, snowball(english);", analyzer = FTS_ANALYZER_NAME ); - let res = db - .client - .query(analyzer_query) - .await - .context("creating FTS analyzer")?; + match db.client.query(snowball_query).await { + Ok(res) => { + if res.check().is_ok() { + return Ok(()); + } + warn!( + "Snowball analyzer check failed; attempting ascii fallback definition (analyzer: {})", + FTS_ANALYZER_NAME + ); + } + Err(err) => { + warn!( + error = %err, + "Snowball analyzer creation errored; attempting ascii fallback definition" + ); + } + } + + let fallback_query = format!( + "DEFINE ANALYZER IF NOT EXISTS {analyzer} + TOKENIZERS class + FILTERS lowercase, ascii;", + analyzer = FTS_ANALYZER_NAME + ); + + db.client + .query(fallback_query) + .await + .context("creating fallback FTS analyzer")? + .check() + .context("failed to create fallback FTS analyzer")?; - res.check().context("failed to create FTS analyzer")?; Ok(()) } @@ -235,13 +258,38 @@ async fn create_index_with_polling( None => None, }; - let res = db - .client - .query(definition) - .await - .with_context(|| format!("creating index {index_name} on table {table}"))?; - res.check() - .with_context(|| format!("index definition failed for {index_name} on {table}"))?; + let mut attempts = 0; + const MAX_ATTEMPTS: usize = 3; + loop { + attempts += 1; + let res = db + .client + .query(definition.clone()) + .await + .with_context(|| format!("creating index {index_name} on table {table}"))?; + match res.check() { + Ok(_) => break, + Err(err) => { + let msg = err.to_string(); + let conflict = msg.contains("read or write conflict"); + warn!( + index = %index_name, + table = %table, + error = ?err, + attempt = attempts, + definition = %definition, + "Index definition failed" + ); + if conflict && attempts < MAX_ATTEMPTS { + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } + return Err(err).with_context(|| { + format!("index definition failed for {index_name} on {table}") + }); + } + } + } info!( index = %index_name, diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index b48a588..9df9ccc 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -53,9 +53,9 @@ impl SystemSettings { #[cfg(test)] mod tests { + use crate::storage::indexes::ensure_runtime_indexes; use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}; use async_openai::Client; - use crate::storage::indexes::ensure_runtime_indexes; use super::*; use uuid::Uuid; diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 6ab7df1..219220f 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -17,9 +17,9 @@ stored_object!(TextChunk, "text_chunk", { user_id: String }); -/// Vector search result including hydrated chunk. +/// Search result including hydrated chunk. #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)] -pub struct TextChunkVectorResult { +pub struct TextChunkSearchResult { pub chunk: TextChunk, pub score: f32, } @@ -97,7 +97,7 @@ impl TextChunk { query_embedding: Vec, db: &SurrealDbClient, user_id: &str, - ) -> Result, AppError> { + ) -> Result, AppError> { #[derive(Deserialize)] struct Row { chunk_id: TextChunk, @@ -132,13 +132,85 @@ impl TextChunk { Ok(rows .into_iter() - .map(|r| TextChunkVectorResult { + .map(|r| TextChunkSearchResult { chunk: r.chunk_id, score: r.score, }) .collect()) } + /// Full-text search over text chunks using the BM25 FTS index. + pub async fn fts_search( + take: usize, + terms: &str, + db: &SurrealDbClient, + user_id: &str, + ) -> Result, AppError> { + #[derive(Deserialize)] + struct Row { + #[serde(deserialize_with = "deserialize_flexible_id")] + id: String, + #[serde(deserialize_with = "deserialize_datetime")] + created_at: DateTime, + #[serde(deserialize_with = "deserialize_datetime")] + updated_at: DateTime, + source_id: String, + chunk: String, + user_id: String, + score: f32, + } + + let sql = format!( + r#" + SELECT + id, + created_at, + updated_at, + source_id, + chunk, + user_id, + IF search::score(0) != NONE THEN search::score(0) ELSE 0 END AS score + FROM {chunk_table} + WHERE chunk @0@ $terms + AND user_id = $user_id + ORDER BY score DESC + LIMIT $limit; + "#, + chunk_table = Self::table_name(), + ); + + let mut response = db + .query(&sql) + .bind(("terms", terms.to_owned())) + .bind(("user_id", user_id.to_owned())) + .bind(("limit", take as i64)) + .await + .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; + + response = response.check().map_err(AppError::Database)?; + + let rows: Vec = response.take::>(0).map_err(AppError::Database)?; + + Ok(rows + .into_iter() + .map(|r| { + let chunk = TextChunk { + id: r.id, + created_at: r.created_at, + updated_at: r.updated_at, + source_id: r.source_id, + chunk: r.chunk, + user_id: r.user_id, + }; + + TextChunkSearchResult { + chunk, + score: r.score, + } + }) + .collect()) + } + /// Re-creates embeddings for all text chunks using a safe, atomic transaction. /// /// This is a costly operation that should be run in the background. It performs these steps: @@ -252,6 +324,26 @@ mod tests { use surrealdb::RecordId; use uuid::Uuid; + async fn ensure_chunk_fts_index(db: &SurrealDbClient) { + let snowball_sql = r#" + DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); + DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25; + "#; + + if let Err(err) = db.client.query(snowball_sql).await { + // Fall back to ascii-only analyzer when snowball is unavailable in the build. + let fallback_sql = r#" + DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii; + DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25; + "#; + + db.client + .query(fallback_sql) + .await + .unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}")); + } + } + #[tokio::test] async fn test_text_chunk_creation() { let source_id = "source123".to_string(); @@ -435,7 +527,7 @@ mod tests { .await .expect("redefine index"); - let results: Vec = + let results: Vec = TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") .await .unwrap(); @@ -467,7 +559,7 @@ mod tests { .await .expect("store"); - let results: Vec = + let results: Vec = TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) .await .unwrap(); @@ -503,7 +595,7 @@ mod tests { .await .expect("store chunk2"); - let results: Vec = + let results: Vec = TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) .await .unwrap(); @@ -513,4 +605,105 @@ mod tests { assert_eq!(results[1].chunk.id, chunk1.id); assert!(results[0].score >= results[1].score); } + + #[tokio::test] + async fn test_fts_search_returns_empty_when_no_chunks() { + let namespace = "fts_chunk_ns_empty"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + ensure_chunk_fts_index(&db).await; + db.rebuild_indexes().await.expect("rebuild indexes"); + + let results = TextChunk::fts_search(5, "hello", &db, "user") + .await + .expect("fts search"); + + assert!(results.is_empty()); + } + + #[tokio::test] + async fn test_fts_search_single_result() { + let namespace = "fts_chunk_ns_single"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + ensure_chunk_fts_index(&db).await; + + let user_id = "fts_user"; + let chunk = TextChunk::new( + "fts_src".to_string(), + "rustaceans love rust".to_string(), + user_id.to_string(), + ); + db.store_item(chunk.clone()).await.expect("store chunk"); + db.rebuild_indexes().await.expect("rebuild indexes"); + + let results = TextChunk::fts_search(3, "rust", &db, user_id) + .await + .expect("fts search"); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].chunk.id, chunk.id); + assert!(results[0].score.is_finite(), "expected a finite FTS score"); + } + + #[tokio::test] + async fn test_fts_search_orders_by_score_and_filters_user() { + let namespace = "fts_chunk_ns_order"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + ensure_chunk_fts_index(&db).await; + + let user_id = "fts_user_order"; + let high_score_chunk = TextChunk::new( + "src1".to_string(), + "apple apple apple pie recipe".to_string(), + user_id.to_string(), + ); + let low_score_chunk = TextChunk::new( + "src2".to_string(), + "apple tart".to_string(), + user_id.to_string(), + ); + let other_user_chunk = TextChunk::new( + "src3".to_string(), + "apple orchard guide".to_string(), + "other_user".to_string(), + ); + + db.store_item(high_score_chunk.clone()) + .await + .expect("store high score chunk"); + db.store_item(low_score_chunk.clone()) + .await + .expect("store low score chunk"); + db.store_item(other_user_chunk) + .await + .expect("store other user chunk"); + db.rebuild_indexes().await.expect("rebuild indexes"); + + let results = TextChunk::fts_search(3, "apple", &db, user_id) + .await + .expect("fts search"); + + assert_eq!(results.len(), 2); + let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect(); + assert!( + ids.contains(&high_score_chunk.id.as_str()) + && ids.contains(&low_score_chunk.id.as_str()), + "expected only the two chunks for the same user" + ); + assert!( + results[0].score >= results[1].score, + "expected results ordered by descending score" + ); + } } diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 5f4bb56..4c9d7f3 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -376,9 +376,8 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() { let services = Arc::new(MockServices::new(user_id)); let mut config = pipeline_config(); config.chunk_only = true; - let pipeline = - IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone()) - .expect("pipeline"); + let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone()) + .expect("pipeline"); let task = reserve_task( &db, diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index 3f5e6d4..c12ff02 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; use std::fmt; +use crate::scoring::FusionWeights; + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)] #[serde(rename_all = "snake_case")] pub enum RetrievalStrategy { @@ -64,6 +66,12 @@ pub struct RetrievalTuning { pub rerank_scores_only: bool, pub rerank_keep_top: usize, pub chunk_result_cap: usize, + /// Optional fusion weights for hybrid search. If None, uses default weights. + pub fusion_weights: Option, + /// Normalize vector similarity scores before fusion (default: true) + pub normalize_vector_scores: bool, + /// Normalize FTS (BM25) scores before fusion (default: true) + pub normalize_fts_scores: bool, } impl Default for RetrievalTuning { @@ -88,6 +96,12 @@ impl Default for RetrievalTuning { rerank_scores_only: false, rerank_keep_top: 8, chunk_result_cap: 5, + fusion_weights: None, + // Vector scores (cosine similarity) are already in [0,1] range + // Normalization only helps when there's significant variation + normalize_vector_scores: false, + // FTS scores (BM25) are unbounded, normalization helps more + normalize_fts_scores: true, } } } diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs index 6f19c04..af082a9 100644 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -593,38 +593,158 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), debug!("Collecting vector chunk candidates for revised strategy"); let embedding = ctx.ensure_embedding()?.clone(); let tuning = &ctx.config.tuning; - let weights = FusionWeights::default(); + let weights = tuning.fusion_weights.unwrap_or_else(FusionWeights::default); + let fts_take = tuning.chunk_fts_take; - let mut vector_chunks: Vec> = TextChunk::vector_search( - tuning.chunk_vector_take, - embedding, - ctx.db_client, - &ctx.user_id, - ) - .await? - .into_iter() - .map(|row| { - let mut scored = Scored::new(row.chunk).with_vector_score(row.score); + let (vector_rows, fts_rows) = tokio::try_join!( + TextChunk::vector_search( + tuning.chunk_vector_take, + embedding, + ctx.db_client, + &ctx.user_id, + ), + async { + if fts_take == 0 { + Ok(Vec::new()) + } else { + TextChunk::fts_search(fts_take, &ctx.input_text, ctx.db_client, &ctx.user_id).await + } + } + )?; + + let mut merged: HashMap> = HashMap::new(); + let vector_candidates = vector_rows.len(); + let fts_candidates = fts_rows.len(); + + // Collect vector results + let vector_scored: Vec> = vector_rows + .into_iter() + .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) + .collect(); + + // Collect FTS results + let fts_scored: Vec> = fts_rows + .into_iter() + .map(|row| Scored::new(row.chunk).with_fts_score(row.score)) + .collect(); + + // Merge by ID first (before normalization) + merge_scored_by_id(&mut merged, vector_scored); + merge_scored_by_id(&mut merged, fts_scored); + + let mut vector_chunks: Vec> = merged.into_values().collect(); + + debug!( + total_merged = vector_chunks.len(), + vector_only = vector_chunks + .iter() + .filter(|c| c.scores.fts.is_none()) + .count(), + fts_only = vector_chunks + .iter() + .filter(|c| c.scores.vector.is_none()) + .count(), + both_signals = vector_chunks + .iter() + .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) + .count(), + "Merged chunk candidates before normalization" + ); + + // Normalize scores AFTER merging, on the final merged set + // This ensures we normalize all vector scores together and all FTS scores together + // for the actual candidates that will be fused + if tuning.normalize_vector_scores && !vector_chunks.is_empty() { + let before_sample: Vec = vector_chunks + .iter() + .filter_map(|c| c.scores.vector) + .take(5) + .collect(); + normalize_vector_scores(&mut vector_chunks); + let after_sample: Vec = vector_chunks + .iter() + .filter_map(|c| c.scores.vector) + .take(5) + .collect(); + debug!( + vector_before = ?before_sample, + vector_after = ?after_sample, + "Vector score normalization applied" + ); + } + if tuning.normalize_fts_scores && !vector_chunks.is_empty() { + let before_sample: Vec = vector_chunks + .iter() + .filter_map(|c| c.scores.fts) + .take(5) + .collect(); + normalize_fts_scores_in_merged(&mut vector_chunks); + let after_sample: Vec = vector_chunks + .iter() + .filter_map(|c| c.scores.fts) + .take(5) + .collect(); + debug!( + fts_before = ?before_sample, + fts_after = ?after_sample, + "FTS score normalization applied" + ); + } + + // Fuse scores after normalization + for scored in &mut vector_chunks { let fused = fuse_scores(&scored.scores, weights); scored.update_fused(fused); - scored - }) - .collect(); + } + + // Filter out FTS-only chunks if they're likely to be low quality + // (when overlap is low, FTS-only chunks are usually noise) + // Always keep chunks with vector scores (vector-only or both signals) + let fts_only_count = vector_chunks + .iter() + .filter(|c| c.scores.vector.is_none()) + .count(); + let both_count = vector_chunks + .iter() + .filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some()) + .count(); + + // If we have very low overlap (few chunks with both signals), filter out FTS-only chunks + // They're likely diluting the good vector results + // This preserves vector-only chunks and golden chunks (both signals) + if fts_only_count > 0 && both_count < 3 { + let before_filter = vector_chunks.len(); + vector_chunks.retain(|c| c.scores.vector.is_some()); + let after_filter = vector_chunks.len(); + debug!( + fts_only_filtered = before_filter - after_filter, + both_signals_preserved = both_count, + "Filtered out FTS-only chunks due to low overlap, preserved golden chunks" + ); + } + + debug!( + fusion_weights = ?weights, + top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::>(), + "Fused scores after normalization" + ); if ctx.diagnostics_enabled() { ctx.record_collect_candidates(CollectCandidatesStats { vector_entity_candidates: 0, - vector_chunk_candidates: vector_chunks.len(), + vector_chunk_candidates: vector_candidates, fts_entity_candidates: 0, - fts_chunk_candidates: 0, + fts_chunk_candidates: fts_candidates, vector_chunk_scores: sample_scores(&vector_chunks, |chunk| { chunk.scores.vector.unwrap_or(0.0) }), - fts_chunk_scores: Vec::new(), + fts_chunk_scores: sample_scores(&vector_chunks, |chunk| { + chunk.scores.fts.unwrap_or(0.0) + }), }); } - vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)); + sort_by_fused_desc(&mut vector_chunks); ctx.revised_chunk_values = vector_chunks; Ok(()) @@ -668,13 +788,6 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Assembling chunk-only retrieval results"); let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values); - let question_terms = extract_keywords(&ctx.input_text); - rank_chunks_by_combined_score( - &mut chunk_values, - &question_terms, - ctx.config.tuning.lexical_match_weight, - ); - // Limit how many chunks we return to keep context size reasonable. let limit = ctx .config @@ -682,7 +795,13 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { .chunk_result_cap .max(1) .min(ctx.config.tuning.chunk_vector_take.max(1)); + if chunk_values.len() > limit { + println!( + "We removed chunks! we had {:?}, now going for {:?}", + chunk_values.len(), + limit + ); chunk_values.truncate(limit); } @@ -847,6 +966,89 @@ fn normalize_fts_scores(results: &mut [Scored]) { } } +fn normalize_vector_scores(results: &mut [Scored]) { + // Only normalize scores for items that actually have vector scores + let items_with_scores: Vec<(usize, f32)> = results + .iter() + .enumerate() + .filter_map(|(idx, candidate)| candidate.scores.vector.map(|score| (idx, score))) + .collect(); + + if items_with_scores.len() < 2 { + // Don't normalize if we have 0 or 1 scores - nothing to normalize against + return; + } + + let raw_scores: Vec = items_with_scores.iter().map(|(_, score)| *score).collect(); + + // For cosine similarity scores (already in [0,1]), use a gentler normalization + // that preserves more of the original distribution + // Only normalize if the range is significant (more than 0.1 difference) + let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b)); + let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b)); + let range = max - min; + + if range < 0.1 { + // Scores are too similar, don't normalize (would compress too much) + debug!( + vector_score_range = range, + min = min, + max = max, + "Skipping vector normalization - scores too similar" + ); + return; + } + + let normalized = min_max_normalize(&raw_scores); + + for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) { + results[*idx].scores.vector = Some(normalized_score); + results[*idx].update_fused(0.0); + } +} + +fn normalize_fts_scores_in_merged(results: &mut [Scored]) { + // Only normalize scores for items that actually have FTS scores + let items_with_scores: Vec<(usize, f32)> = results + .iter() + .enumerate() + .filter_map(|(idx, candidate)| candidate.scores.fts.map(|score| (idx, score))) + .collect(); + + if items_with_scores.len() < 2 { + // Don't normalize if we have 0 or 1 scores - nothing to normalize against + // Single FTS score would become 1.0, which doesn't help + return; + } + + let raw_scores: Vec = items_with_scores.iter().map(|(_, score)| *score).collect(); + + // BM25 scores can be negative or very high, so normalization is more important + // But check if we have enough variation to normalize + let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b)); + let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b)); + let range = max - min; + + // For BM25, even small differences can be meaningful, but if all scores are + // very similar, normalization won't help + if range < 0.01 { + debug!( + fts_score_range = range, + min = min, + max = max, + "Skipping FTS normalization - scores too similar" + ); + return; + } + + let normalized = min_max_normalize(&raw_scores); + + for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) { + results[*idx].scores.fts = Some(normalized_score); + results[*idx].update_fused(0.0); + } +} + fn apply_fusion(candidates: &mut HashMap>, weights: FusionWeights) where T: StoredObject, diff --git a/retrieval-pipeline/src/scoring.rs b/retrieval-pipeline/src/scoring.rs index 560c086..458709d 100644 --- a/retrieval-pipeline/src/scoring.rs +++ b/retrieval-pipeline/src/scoring.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering; use common::storage::types::StoredObject; +use serde::{Deserialize, Serialize}; /// Holds optional subscores gathered from different retrieval signals. #[derive(Debug, Clone, Copy, Default)] @@ -48,7 +49,7 @@ impl Scored { } /// Weights used for linear score fusion. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct FusionWeights { pub vector: f32, pub fts: f32, @@ -58,11 +59,14 @@ pub struct FusionWeights { impl Default for FusionWeights { fn default() -> Self { + // Default weights favor vector search, which typically performs better + // FTS is used as a complement when there's good overlap + // Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk") Self { - vector: 0.5, - fts: 0.3, + vector: 0.8, + fts: 0.2, graph: 0.2, - multi_bonus: 0.02, + multi_bonus: 0.3, // Increased to boost chunks with both signals } } } @@ -134,8 +138,19 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 { .chain(scores.fts.iter()) .chain(scores.graph.iter()) .count(); + + // Boost chunks with multiple signals (especially vector + FTS, the "golden chunk") if signals_present >= 2 { - fused += weights.multi_bonus; + // For chunks with both vector and FTS, give a significant boost + // This helps identify the "golden chunk" that appears in both searches + if scores.vector.is_some() && scores.fts.is_some() { + // Multiplicative boost: multiply by (1 + bonus) to scale with the base score + // This ensures high-scoring golden chunks get boosted more than low-scoring ones + fused = fused * (1.0 + weights.multi_bonus); + } else { + // For other multi-signal combinations (e.g., vector + graph), use additive bonus + fused += weights.multi_bonus; + } } clamp_unit(fused)