mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-22 08:48:30 +02:00
retrieval: hybrid search, linear fusion
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
use super::types::StoredObject;
|
use super::types::StoredObject;
|
||||||
use crate::{
|
use crate::error::AppError;
|
||||||
error::AppError,
|
|
||||||
};
|
|
||||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||||
use axum_session_surreal::SessionSurrealPool;
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
|||||||
@@ -84,11 +84,11 @@ async fn ensure_runtime_indexes_inner(
|
|||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
create_fts_analyzer(db).await?;
|
create_fts_analyzer(db).await?;
|
||||||
|
|
||||||
let fts_tasks = fts_index_specs().into_iter().map(|spec| async move {
|
for spec in fts_index_specs() {
|
||||||
if index_exists(db, spec.table, spec.index_name).await? {
|
if index_exists(db, spec.table, spec.index_name).await? {
|
||||||
return Ok(());
|
continue;
|
||||||
}
|
}
|
||||||
|
// We need to create these sequentially otherwise SurrealDB errors with read/write clash
|
||||||
create_index_with_polling(
|
create_index_with_polling(
|
||||||
db,
|
db,
|
||||||
spec.definition(),
|
spec.definition(),
|
||||||
@@ -96,48 +96,43 @@ async fn ensure_runtime_indexes_inner(
|
|||||||
spec.table,
|
spec.table,
|
||||||
Some(spec.table),
|
Some(spec.table),
|
||||||
)
|
)
|
||||||
.await
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||||
|
match hnsw_index_state(db, &spec, embedding_dimension).await? {
|
||||||
|
HnswIndexState::Missing => {
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition_if_not_exists(embedding_dimension),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
HnswIndexState::Matches => Ok(()),
|
||||||
|
HnswIndexState::Different(existing) => {
|
||||||
|
info!(
|
||||||
|
index = spec.index_name,
|
||||||
|
table = spec.table,
|
||||||
|
existing_dimension = existing,
|
||||||
|
target_dimension = embedding_dimension,
|
||||||
|
"Overwriting HNSW index to match new embedding dimension"
|
||||||
|
);
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition_overwrite(embedding_dimension),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let hnsw_tasks = hnsw_index_specs()
|
try_join_all(hnsw_tasks).await.map(|_| ())?;
|
||||||
.into_iter()
|
|
||||||
.map(|spec| async move {
|
|
||||||
match hnsw_index_state(db, &spec, embedding_dimension).await? {
|
|
||||||
HnswIndexState::Missing => {
|
|
||||||
create_index_with_polling(
|
|
||||||
db,
|
|
||||||
spec.definition_if_not_exists(embedding_dimension),
|
|
||||||
spec.index_name,
|
|
||||||
spec.table,
|
|
||||||
Some(spec.table),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
HnswIndexState::Matches => Ok(()),
|
|
||||||
HnswIndexState::Different(existing) => {
|
|
||||||
info!(
|
|
||||||
index = spec.index_name,
|
|
||||||
table = spec.table,
|
|
||||||
existing_dimension = existing,
|
|
||||||
target_dimension = embedding_dimension,
|
|
||||||
"Overwriting HNSW index to match new embedding dimension"
|
|
||||||
);
|
|
||||||
create_index_with_polling(
|
|
||||||
db,
|
|
||||||
spec.definition_overwrite(embedding_dimension),
|
|
||||||
spec.index_name,
|
|
||||||
spec.table,
|
|
||||||
Some(spec.table),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
futures::try_join!(
|
|
||||||
async { try_join_all(fts_tasks).await.map(|_| ()) },
|
|
||||||
async { try_join_all(hnsw_tasks).await.map(|_| ()) },
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -204,20 +199,48 @@ fn extract_dimension(definition: &str) -> Option<u64> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||||
let analyzer_query = format!(
|
// Prefer snowball stemming when supported; fall back to ascii-only when the filter
|
||||||
|
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
||||||
|
// an existing analyzer definition.
|
||||||
|
let snowball_query = format!(
|
||||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||||
TOKENIZERS class
|
TOKENIZERS class
|
||||||
FILTERS lowercase, ascii, snowball(english);",
|
FILTERS lowercase, ascii, snowball(english);",
|
||||||
analyzer = FTS_ANALYZER_NAME
|
analyzer = FTS_ANALYZER_NAME
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = db
|
match db.client.query(snowball_query).await {
|
||||||
.client
|
Ok(res) => {
|
||||||
.query(analyzer_query)
|
if res.check().is_ok() {
|
||||||
.await
|
return Ok(());
|
||||||
.context("creating FTS analyzer")?;
|
}
|
||||||
|
warn!(
|
||||||
|
"Snowball analyzer check failed; attempting ascii fallback definition (analyzer: {})",
|
||||||
|
FTS_ANALYZER_NAME
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
error = %err,
|
||||||
|
"Snowball analyzer creation errored; attempting ascii fallback definition"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let fallback_query = format!(
|
||||||
|
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||||
|
TOKENIZERS class
|
||||||
|
FILTERS lowercase, ascii;",
|
||||||
|
analyzer = FTS_ANALYZER_NAME
|
||||||
|
);
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(fallback_query)
|
||||||
|
.await
|
||||||
|
.context("creating fallback FTS analyzer")?
|
||||||
|
.check()
|
||||||
|
.context("failed to create fallback FTS analyzer")?;
|
||||||
|
|
||||||
res.check().context("failed to create FTS analyzer")?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,13 +258,38 @@ async fn create_index_with_polling(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = db
|
let mut attempts = 0;
|
||||||
.client
|
const MAX_ATTEMPTS: usize = 3;
|
||||||
.query(definition)
|
loop {
|
||||||
.await
|
attempts += 1;
|
||||||
.with_context(|| format!("creating index {index_name} on table {table}"))?;
|
let res = db
|
||||||
res.check()
|
.client
|
||||||
.with_context(|| format!("index definition failed for {index_name} on {table}"))?;
|
.query(definition.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("creating index {index_name} on table {table}"))?;
|
||||||
|
match res.check() {
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(err) => {
|
||||||
|
let msg = err.to_string();
|
||||||
|
let conflict = msg.contains("read or write conflict");
|
||||||
|
warn!(
|
||||||
|
index = %index_name,
|
||||||
|
table = %table,
|
||||||
|
error = ?err,
|
||||||
|
attempt = attempts,
|
||||||
|
definition = %definition,
|
||||||
|
"Index definition failed"
|
||||||
|
);
|
||||||
|
if conflict && attempts < MAX_ATTEMPTS {
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return Err(err).with_context(|| {
|
||||||
|
format!("index definition failed for {index_name} on {table}")
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
index = %index_name,
|
index = %index_name,
|
||||||
|
|||||||
@@ -53,9 +53,9 @@ impl SystemSettings {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::storage::indexes::ensure_runtime_indexes;
|
||||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use crate::storage::indexes::ensure_runtime_indexes;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ stored_object!(TextChunk, "text_chunk", {
|
|||||||
user_id: String
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
/// Vector search result including hydrated chunk.
|
/// Search result including hydrated chunk.
|
||||||
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
|
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||||
pub struct TextChunkVectorResult {
|
pub struct TextChunkSearchResult {
|
||||||
pub chunk: TextChunk,
|
pub chunk: TextChunk,
|
||||||
pub score: f32,
|
pub score: f32,
|
||||||
}
|
}
|
||||||
@@ -97,7 +97,7 @@ impl TextChunk {
|
|||||||
query_embedding: Vec<f32>,
|
query_embedding: Vec<f32>,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<TextChunkVectorResult>, AppError> {
|
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Row {
|
struct Row {
|
||||||
chunk_id: TextChunk,
|
chunk_id: TextChunk,
|
||||||
@@ -132,13 +132,85 @@ impl TextChunk {
|
|||||||
|
|
||||||
Ok(rows
|
Ok(rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|r| TextChunkVectorResult {
|
.map(|r| TextChunkSearchResult {
|
||||||
chunk: r.chunk_id,
|
chunk: r.chunk_id,
|
||||||
score: r.score,
|
score: r.score,
|
||||||
})
|
})
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Full-text search over text chunks using the BM25 FTS index.
|
||||||
|
pub async fn fts_search(
|
||||||
|
take: usize,
|
||||||
|
terms: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Row {
|
||||||
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
|
id: String,
|
||||||
|
#[serde(deserialize_with = "deserialize_datetime")]
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
#[serde(deserialize_with = "deserialize_datetime")]
|
||||||
|
updated_at: DateTime<Utc>,
|
||||||
|
source_id: String,
|
||||||
|
chunk: String,
|
||||||
|
user_id: String,
|
||||||
|
score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
created_at,
|
||||||
|
updated_at,
|
||||||
|
source_id,
|
||||||
|
chunk,
|
||||||
|
user_id,
|
||||||
|
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END AS score
|
||||||
|
FROM {chunk_table}
|
||||||
|
WHERE chunk @0@ $terms
|
||||||
|
AND user_id = $user_id
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit;
|
||||||
|
"#,
|
||||||
|
chunk_table = Self::table_name(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut response = db
|
||||||
|
.query(&sql)
|
||||||
|
.bind(("terms", terms.to_owned()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.bind(("limit", take as i64))
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||||
|
|
||||||
|
response = response.check().map_err(AppError::Database)?;
|
||||||
|
|
||||||
|
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| {
|
||||||
|
let chunk = TextChunk {
|
||||||
|
id: r.id,
|
||||||
|
created_at: r.created_at,
|
||||||
|
updated_at: r.updated_at,
|
||||||
|
source_id: r.source_id,
|
||||||
|
chunk: r.chunk,
|
||||||
|
user_id: r.user_id,
|
||||||
|
};
|
||||||
|
|
||||||
|
TextChunkSearchResult {
|
||||||
|
chunk,
|
||||||
|
score: r.score,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
|
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
|
||||||
///
|
///
|
||||||
/// This is a costly operation that should be run in the background. It performs these steps:
|
/// This is a costly operation that should be run in the background. It performs these steps:
|
||||||
@@ -252,6 +324,26 @@ mod tests {
|
|||||||
use surrealdb::RecordId;
|
use surrealdb::RecordId;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
|
||||||
|
let snowball_sql = r#"
|
||||||
|
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||||
|
"#;
|
||||||
|
|
||||||
|
if let Err(err) = db.client.query(snowball_sql).await {
|
||||||
|
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
|
||||||
|
let fallback_sql = r#"
|
||||||
|
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||||
|
"#;
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(fallback_sql)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_chunk_creation() {
|
async fn test_text_chunk_creation() {
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
@@ -435,7 +527,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("redefine index");
|
.expect("redefine index");
|
||||||
|
|
||||||
let results: Vec<TextChunkVectorResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -467,7 +559,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("store");
|
.expect("store");
|
||||||
|
|
||||||
let results: Vec<TextChunkVectorResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -503,7 +595,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("store chunk2");
|
.expect("store chunk2");
|
||||||
|
|
||||||
let results: Vec<TextChunkVectorResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -513,4 +605,105 @@ mod tests {
|
|||||||
assert_eq!(results[1].chunk.id, chunk1.id);
|
assert_eq!(results[1].chunk.id, chunk1.id);
|
||||||
assert!(results[0].score >= results[1].score);
|
assert!(results[0].score >= results[1].score);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fts_search_returns_empty_when_no_chunks() {
|
||||||
|
let namespace = "fts_chunk_ns_empty";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
db.apply_migrations().await.expect("migrations");
|
||||||
|
ensure_chunk_fts_index(&db).await;
|
||||||
|
db.rebuild_indexes().await.expect("rebuild indexes");
|
||||||
|
|
||||||
|
let results = TextChunk::fts_search(5, "hello", &db, "user")
|
||||||
|
.await
|
||||||
|
.expect("fts search");
|
||||||
|
|
||||||
|
assert!(results.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fts_search_single_result() {
|
||||||
|
let namespace = "fts_chunk_ns_single";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
db.apply_migrations().await.expect("migrations");
|
||||||
|
ensure_chunk_fts_index(&db).await;
|
||||||
|
|
||||||
|
let user_id = "fts_user";
|
||||||
|
let chunk = TextChunk::new(
|
||||||
|
"fts_src".to_string(),
|
||||||
|
"rustaceans love rust".to_string(),
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
db.store_item(chunk.clone()).await.expect("store chunk");
|
||||||
|
db.rebuild_indexes().await.expect("rebuild indexes");
|
||||||
|
|
||||||
|
let results = TextChunk::fts_search(3, "rust", &db, user_id)
|
||||||
|
.await
|
||||||
|
.expect("fts search");
|
||||||
|
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].chunk.id, chunk.id);
|
||||||
|
assert!(results[0].score.is_finite(), "expected a finite FTS score");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fts_search_orders_by_score_and_filters_user() {
|
||||||
|
let namespace = "fts_chunk_ns_order";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
db.apply_migrations().await.expect("migrations");
|
||||||
|
ensure_chunk_fts_index(&db).await;
|
||||||
|
|
||||||
|
let user_id = "fts_user_order";
|
||||||
|
let high_score_chunk = TextChunk::new(
|
||||||
|
"src1".to_string(),
|
||||||
|
"apple apple apple pie recipe".to_string(),
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
let low_score_chunk = TextChunk::new(
|
||||||
|
"src2".to_string(),
|
||||||
|
"apple tart".to_string(),
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
let other_user_chunk = TextChunk::new(
|
||||||
|
"src3".to_string(),
|
||||||
|
"apple orchard guide".to_string(),
|
||||||
|
"other_user".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.store_item(high_score_chunk.clone())
|
||||||
|
.await
|
||||||
|
.expect("store high score chunk");
|
||||||
|
db.store_item(low_score_chunk.clone())
|
||||||
|
.await
|
||||||
|
.expect("store low score chunk");
|
||||||
|
db.store_item(other_user_chunk)
|
||||||
|
.await
|
||||||
|
.expect("store other user chunk");
|
||||||
|
db.rebuild_indexes().await.expect("rebuild indexes");
|
||||||
|
|
||||||
|
let results = TextChunk::fts_search(3, "apple", &db, user_id)
|
||||||
|
.await
|
||||||
|
.expect("fts search");
|
||||||
|
|
||||||
|
assert_eq!(results.len(), 2);
|
||||||
|
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
|
||||||
|
assert!(
|
||||||
|
ids.contains(&high_score_chunk.id.as_str())
|
||||||
|
&& ids.contains(&low_score_chunk.id.as_str()),
|
||||||
|
"expected only the two chunks for the same user"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
results[0].score >= results[1].score,
|
||||||
|
"expected results ordered by descending score"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -376,9 +376,8 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() {
|
|||||||
let services = Arc::new(MockServices::new(user_id));
|
let services = Arc::new(MockServices::new(user_id));
|
||||||
let mut config = pipeline_config();
|
let mut config = pipeline_config();
|
||||||
config.chunk_only = true;
|
config.chunk_only = true;
|
||||||
let pipeline =
|
let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone())
|
||||||
IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone())
|
.expect("pipeline");
|
||||||
.expect("pipeline");
|
|
||||||
|
|
||||||
let task = reserve_task(
|
let task = reserve_task(
|
||||||
&db,
|
&db,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
use crate::scoring::FusionWeights;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum RetrievalStrategy {
|
pub enum RetrievalStrategy {
|
||||||
@@ -64,6 +66,12 @@ pub struct RetrievalTuning {
|
|||||||
pub rerank_scores_only: bool,
|
pub rerank_scores_only: bool,
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
/// Optional fusion weights for hybrid search. If None, uses default weights.
|
||||||
|
pub fusion_weights: Option<FusionWeights>,
|
||||||
|
/// Normalize vector similarity scores before fusion (default: true)
|
||||||
|
pub normalize_vector_scores: bool,
|
||||||
|
/// Normalize FTS (BM25) scores before fusion (default: true)
|
||||||
|
pub normalize_fts_scores: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalTuning {
|
impl Default for RetrievalTuning {
|
||||||
@@ -88,6 +96,12 @@ impl Default for RetrievalTuning {
|
|||||||
rerank_scores_only: false,
|
rerank_scores_only: false,
|
||||||
rerank_keep_top: 8,
|
rerank_keep_top: 8,
|
||||||
chunk_result_cap: 5,
|
chunk_result_cap: 5,
|
||||||
|
fusion_weights: None,
|
||||||
|
// Vector scores (cosine similarity) are already in [0,1] range
|
||||||
|
// Normalization only helps when there's significant variation
|
||||||
|
normalize_vector_scores: false,
|
||||||
|
// FTS scores (BM25) are unbounded, normalization helps more
|
||||||
|
normalize_fts_scores: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -593,38 +593,158 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
|||||||
debug!("Collecting vector chunk candidates for revised strategy");
|
debug!("Collecting vector chunk candidates for revised strategy");
|
||||||
let embedding = ctx.ensure_embedding()?.clone();
|
let embedding = ctx.ensure_embedding()?.clone();
|
||||||
let tuning = &ctx.config.tuning;
|
let tuning = &ctx.config.tuning;
|
||||||
let weights = FusionWeights::default();
|
let weights = tuning.fusion_weights.unwrap_or_else(FusionWeights::default);
|
||||||
|
let fts_take = tuning.chunk_fts_take;
|
||||||
|
|
||||||
let mut vector_chunks: Vec<Scored<TextChunk>> = TextChunk::vector_search(
|
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||||
tuning.chunk_vector_take,
|
TextChunk::vector_search(
|
||||||
embedding,
|
tuning.chunk_vector_take,
|
||||||
ctx.db_client,
|
embedding,
|
||||||
&ctx.user_id,
|
ctx.db_client,
|
||||||
)
|
&ctx.user_id,
|
||||||
.await?
|
),
|
||||||
.into_iter()
|
async {
|
||||||
.map(|row| {
|
if fts_take == 0 {
|
||||||
let mut scored = Scored::new(row.chunk).with_vector_score(row.score);
|
Ok(Vec::new())
|
||||||
|
} else {
|
||||||
|
TextChunk::fts_search(fts_take, &ctx.input_text, ctx.db_client, &ctx.user_id).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut merged: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||||
|
let vector_candidates = vector_rows.len();
|
||||||
|
let fts_candidates = fts_rows.len();
|
||||||
|
|
||||||
|
// Collect vector results
|
||||||
|
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Collect FTS results
|
||||||
|
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Merge by ID first (before normalization)
|
||||||
|
merge_scored_by_id(&mut merged, vector_scored);
|
||||||
|
merge_scored_by_id(&mut merged, fts_scored);
|
||||||
|
|
||||||
|
let mut vector_chunks: Vec<Scored<TextChunk>> = merged.into_values().collect();
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
total_merged = vector_chunks.len(),
|
||||||
|
vector_only = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.scores.fts.is_none())
|
||||||
|
.count(),
|
||||||
|
fts_only = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.scores.vector.is_none())
|
||||||
|
.count(),
|
||||||
|
both_signals = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||||
|
.count(),
|
||||||
|
"Merged chunk candidates before normalization"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Normalize scores AFTER merging, on the final merged set
|
||||||
|
// This ensures we normalize all vector scores together and all FTS scores together
|
||||||
|
// for the actual candidates that will be fused
|
||||||
|
if tuning.normalize_vector_scores && !vector_chunks.is_empty() {
|
||||||
|
let before_sample: Vec<f32> = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.scores.vector)
|
||||||
|
.take(5)
|
||||||
|
.collect();
|
||||||
|
normalize_vector_scores(&mut vector_chunks);
|
||||||
|
let after_sample: Vec<f32> = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.scores.vector)
|
||||||
|
.take(5)
|
||||||
|
.collect();
|
||||||
|
debug!(
|
||||||
|
vector_before = ?before_sample,
|
||||||
|
vector_after = ?after_sample,
|
||||||
|
"Vector score normalization applied"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if tuning.normalize_fts_scores && !vector_chunks.is_empty() {
|
||||||
|
let before_sample: Vec<f32> = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.scores.fts)
|
||||||
|
.take(5)
|
||||||
|
.collect();
|
||||||
|
normalize_fts_scores_in_merged(&mut vector_chunks);
|
||||||
|
let after_sample: Vec<f32> = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.scores.fts)
|
||||||
|
.take(5)
|
||||||
|
.collect();
|
||||||
|
debug!(
|
||||||
|
fts_before = ?before_sample,
|
||||||
|
fts_after = ?after_sample,
|
||||||
|
"FTS score normalization applied"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fuse scores after normalization
|
||||||
|
for scored in &mut vector_chunks {
|
||||||
let fused = fuse_scores(&scored.scores, weights);
|
let fused = fuse_scores(&scored.scores, weights);
|
||||||
scored.update_fused(fused);
|
scored.update_fused(fused);
|
||||||
scored
|
}
|
||||||
})
|
|
||||||
.collect();
|
// Filter out FTS-only chunks if they're likely to be low quality
|
||||||
|
// (when overlap is low, FTS-only chunks are usually noise)
|
||||||
|
// Always keep chunks with vector scores (vector-only or both signals)
|
||||||
|
let fts_only_count = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.scores.vector.is_none())
|
||||||
|
.count();
|
||||||
|
let both_count = vector_chunks
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.scores.vector.is_some() && c.scores.fts.is_some())
|
||||||
|
.count();
|
||||||
|
|
||||||
|
// If we have very low overlap (few chunks with both signals), filter out FTS-only chunks
|
||||||
|
// They're likely diluting the good vector results
|
||||||
|
// This preserves vector-only chunks and golden chunks (both signals)
|
||||||
|
if fts_only_count > 0 && both_count < 3 {
|
||||||
|
let before_filter = vector_chunks.len();
|
||||||
|
vector_chunks.retain(|c| c.scores.vector.is_some());
|
||||||
|
let after_filter = vector_chunks.len();
|
||||||
|
debug!(
|
||||||
|
fts_only_filtered = before_filter - after_filter,
|
||||||
|
both_signals_preserved = both_count,
|
||||||
|
"Filtered out FTS-only chunks due to low overlap, preserved golden chunks"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
fusion_weights = ?weights,
|
||||||
|
top_fused_scores = ?vector_chunks.iter().take(5).map(|c| c.fused).collect::<Vec<_>>(),
|
||||||
|
"Fused scores after normalization"
|
||||||
|
);
|
||||||
|
|
||||||
if ctx.diagnostics_enabled() {
|
if ctx.diagnostics_enabled() {
|
||||||
ctx.record_collect_candidates(CollectCandidatesStats {
|
ctx.record_collect_candidates(CollectCandidatesStats {
|
||||||
vector_entity_candidates: 0,
|
vector_entity_candidates: 0,
|
||||||
vector_chunk_candidates: vector_chunks.len(),
|
vector_chunk_candidates: vector_candidates,
|
||||||
fts_entity_candidates: 0,
|
fts_entity_candidates: 0,
|
||||||
fts_chunk_candidates: 0,
|
fts_chunk_candidates: fts_candidates,
|
||||||
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
||||||
chunk.scores.vector.unwrap_or(0.0)
|
chunk.scores.vector.unwrap_or(0.0)
|
||||||
}),
|
}),
|
||||||
fts_chunk_scores: Vec::new(),
|
fts_chunk_scores: sample_scores(&vector_chunks, |chunk| {
|
||||||
|
chunk.scores.fts.unwrap_or(0.0)
|
||||||
|
}),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
vector_chunks.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
sort_by_fused_desc(&mut vector_chunks);
|
||||||
ctx.revised_chunk_values = vector_chunks;
|
ctx.revised_chunk_values = vector_chunks;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -668,13 +788,6 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
|||||||
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
debug!("Assembling chunk-only retrieval results");
|
debug!("Assembling chunk-only retrieval results");
|
||||||
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
|
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
|
||||||
let question_terms = extract_keywords(&ctx.input_text);
|
|
||||||
rank_chunks_by_combined_score(
|
|
||||||
&mut chunk_values,
|
|
||||||
&question_terms,
|
|
||||||
ctx.config.tuning.lexical_match_weight,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Limit how many chunks we return to keep context size reasonable.
|
// Limit how many chunks we return to keep context size reasonable.
|
||||||
let limit = ctx
|
let limit = ctx
|
||||||
.config
|
.config
|
||||||
@@ -682,7 +795,13 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
|||||||
.chunk_result_cap
|
.chunk_result_cap
|
||||||
.max(1)
|
.max(1)
|
||||||
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
.min(ctx.config.tuning.chunk_vector_take.max(1));
|
||||||
|
|
||||||
if chunk_values.len() > limit {
|
if chunk_values.len() > limit {
|
||||||
|
println!(
|
||||||
|
"We removed chunks! we had {:?}, now going for {:?}",
|
||||||
|
chunk_values.len(),
|
||||||
|
limit
|
||||||
|
);
|
||||||
chunk_values.truncate(limit);
|
chunk_values.truncate(limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -847,6 +966,89 @@ fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_vector_scores<T>(results: &mut [Scored<T>]) {
|
||||||
|
// Only normalize scores for items that actually have vector scores
|
||||||
|
let items_with_scores: Vec<(usize, f32)> = results
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(idx, candidate)| candidate.scores.vector.map(|score| (idx, score)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if items_with_scores.len() < 2 {
|
||||||
|
// Don't normalize if we have 0 or 1 scores - nothing to normalize against
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
|
||||||
|
|
||||||
|
// For cosine similarity scores (already in [0,1]), use a gentler normalization
|
||||||
|
// that preserves more of the original distribution
|
||||||
|
// Only normalize if the range is significant (more than 0.1 difference)
|
||||||
|
let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
|
||||||
|
let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
|
||||||
|
let range = max - min;
|
||||||
|
|
||||||
|
if range < 0.1 {
|
||||||
|
// Scores are too similar, don't normalize (would compress too much)
|
||||||
|
debug!(
|
||||||
|
vector_score_range = range,
|
||||||
|
min = min,
|
||||||
|
max = max,
|
||||||
|
"Skipping vector normalization - scores too similar"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let normalized = min_max_normalize(&raw_scores);
|
||||||
|
|
||||||
|
for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
|
||||||
|
results[*idx].scores.vector = Some(normalized_score);
|
||||||
|
results[*idx].update_fused(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_fts_scores_in_merged<T>(results: &mut [Scored<T>]) {
|
||||||
|
// Only normalize scores for items that actually have FTS scores
|
||||||
|
let items_with_scores: Vec<(usize, f32)> = results
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(idx, candidate)| candidate.scores.fts.map(|score| (idx, score)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if items_with_scores.len() < 2 {
|
||||||
|
// Don't normalize if we have 0 or 1 scores - nothing to normalize against
|
||||||
|
// Single FTS score would become 1.0, which doesn't help
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw_scores: Vec<f32> = items_with_scores.iter().map(|(_, score)| *score).collect();
|
||||||
|
|
||||||
|
// BM25 scores can be negative or very high, so normalization is more important
|
||||||
|
// But check if we have enough variation to normalize
|
||||||
|
let min = raw_scores.iter().fold(f32::MAX, |a, &b| a.min(b));
|
||||||
|
let max = raw_scores.iter().fold(f32::MIN, |a, &b| a.max(b));
|
||||||
|
let range = max - min;
|
||||||
|
|
||||||
|
// For BM25, even small differences can be meaningful, but if all scores are
|
||||||
|
// very similar, normalization won't help
|
||||||
|
if range < 0.01 {
|
||||||
|
debug!(
|
||||||
|
fts_score_range = range,
|
||||||
|
min = min,
|
||||||
|
max = max,
|
||||||
|
"Skipping FTS normalization - scores too similar"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let normalized = min_max_normalize(&raw_scores);
|
||||||
|
|
||||||
|
for ((idx, _), normalized_score) in items_with_scores.iter().zip(normalized.into_iter()) {
|
||||||
|
results[*idx].scores.fts = Some(normalized_score);
|
||||||
|
results[*idx].update_fused(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||||
where
|
where
|
||||||
T: StoredObject,
|
T: StoredObject,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::cmp::Ordering;
|
use std::cmp::Ordering;
|
||||||
|
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::StoredObject;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Holds optional subscores gathered from different retrieval signals.
|
/// Holds optional subscores gathered from different retrieval signals.
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
@@ -48,7 +49,7 @@ impl<T> Scored<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Weights used for linear score fusion.
|
/// Weights used for linear score fusion.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||||
pub struct FusionWeights {
|
pub struct FusionWeights {
|
||||||
pub vector: f32,
|
pub vector: f32,
|
||||||
pub fts: f32,
|
pub fts: f32,
|
||||||
@@ -58,11 +59,14 @@ pub struct FusionWeights {
|
|||||||
|
|
||||||
impl Default for FusionWeights {
|
impl Default for FusionWeights {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
|
// Default weights favor vector search, which typically performs better
|
||||||
|
// FTS is used as a complement when there's good overlap
|
||||||
|
// Higher multi_bonus to heavily favor chunks with both signals (the "golden chunk")
|
||||||
Self {
|
Self {
|
||||||
vector: 0.5,
|
vector: 0.8,
|
||||||
fts: 0.3,
|
fts: 0.2,
|
||||||
graph: 0.2,
|
graph: 0.2,
|
||||||
multi_bonus: 0.02,
|
multi_bonus: 0.3, // Increased to boost chunks with both signals
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -134,8 +138,19 @@ pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
|||||||
.chain(scores.fts.iter())
|
.chain(scores.fts.iter())
|
||||||
.chain(scores.graph.iter())
|
.chain(scores.graph.iter())
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
|
// Boost chunks with multiple signals (especially vector + FTS, the "golden chunk")
|
||||||
if signals_present >= 2 {
|
if signals_present >= 2 {
|
||||||
fused += weights.multi_bonus;
|
// For chunks with both vector and FTS, give a significant boost
|
||||||
|
// This helps identify the "golden chunk" that appears in both searches
|
||||||
|
if scores.vector.is_some() && scores.fts.is_some() {
|
||||||
|
// Multiplicative boost: multiply by (1 + bonus) to scale with the base score
|
||||||
|
// This ensures high-scoring golden chunks get boosted more than low-scoring ones
|
||||||
|
fused = fused * (1.0 + weights.multi_bonus);
|
||||||
|
} else {
|
||||||
|
// For other multi-signal combinations (e.g., vector + graph), use additive bonus
|
||||||
|
fused += weights.multi_bonus;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
clamp_unit(fused)
|
clamp_unit(fused)
|
||||||
|
|||||||
Reference in New Issue
Block a user