retrieval: hybrid search, linear fusion

This commit is contained in:
Per Stark
2025-12-04 12:48:59 +01:00
parent dd881efbf9
commit d3fa3be3e5
8 changed files with 570 additions and 101 deletions

View File

@@ -1,7 +1,5 @@
use super::types::StoredObject;
use crate::{
error::AppError,
};
use crate::error::AppError;
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use futures::Stream;

View File

@@ -84,11 +84,11 @@ async fn ensure_runtime_indexes_inner(
) -> Result<()> {
create_fts_analyzer(db).await?;
let fts_tasks = fts_index_specs().into_iter().map(|spec| async move {
for spec in fts_index_specs() {
if index_exists(db, spec.table, spec.index_name).await? {
return Ok(());
continue;
}
// We need to create these sequentially otherwise SurrealDB errors with read/write clash
create_index_with_polling(
db,
spec.definition(),
@@ -96,48 +96,43 @@ async fn ensure_runtime_indexes_inner(
spec.table,
Some(spec.table),
)
.await
.await?;
}
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
match hnsw_index_state(db, &spec, embedding_dimension).await? {
HnswIndexState::Missing => {
create_index_with_polling(
db,
spec.definition_if_not_exists(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
}
HnswIndexState::Matches => Ok(()),
HnswIndexState::Different(existing) => {
info!(
index = spec.index_name,
table = spec.table,
existing_dimension = existing,
target_dimension = embedding_dimension,
"Overwriting HNSW index to match new embedding dimension"
);
create_index_with_polling(
db,
spec.definition_overwrite(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
}
}
});
let hnsw_tasks = hnsw_index_specs()
.into_iter()
.map(|spec| async move {
match hnsw_index_state(db, &spec, embedding_dimension).await? {
HnswIndexState::Missing => {
create_index_with_polling(
db,
spec.definition_if_not_exists(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
}
HnswIndexState::Matches => Ok(()),
HnswIndexState::Different(existing) => {
info!(
index = spec.index_name,
table = spec.table,
existing_dimension = existing,
target_dimension = embedding_dimension,
"Overwriting HNSW index to match new embedding dimension"
);
create_index_with_polling(
db,
spec.definition_overwrite(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
}
}
});
futures::try_join!(
async { try_join_all(fts_tasks).await.map(|_| ()) },
async { try_join_all(hnsw_tasks).await.map(|_| ()) },
)?;
try_join_all(hnsw_tasks).await.map(|_| ())?;
Ok(())
}
@@ -204,20 +199,48 @@ fn extract_dimension(definition: &str) -> Option<u64> {
}
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
let analyzer_query = format!(
// Prefer snowball stemming when supported; fall back to ascii-only when the filter
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
// an existing analyzer definition.
let snowball_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);",
analyzer = FTS_ANALYZER_NAME
);
let res = db
.client
.query(analyzer_query)
.await
.context("creating FTS analyzer")?;
match db.client.query(snowball_query).await {
Ok(res) => {
if res.check().is_ok() {
return Ok(());
}
warn!(
"Snowball analyzer check failed; attempting ascii fallback definition (analyzer: {})",
FTS_ANALYZER_NAME
);
}
Err(err) => {
warn!(
error = %err,
"Snowball analyzer creation errored; attempting ascii fallback definition"
);
}
}
let fallback_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
TOKENIZERS class
FILTERS lowercase, ascii;",
analyzer = FTS_ANALYZER_NAME
);
db.client
.query(fallback_query)
.await
.context("creating fallback FTS analyzer")?
.check()
.context("failed to create fallback FTS analyzer")?;
res.check().context("failed to create FTS analyzer")?;
Ok(())
}
@@ -235,13 +258,38 @@ async fn create_index_with_polling(
None => None,
};
let res = db
.client
.query(definition)
.await
.with_context(|| format!("creating index {index_name} on table {table}"))?;
res.check()
.with_context(|| format!("index definition failed for {index_name} on {table}"))?;
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 3;
loop {
attempts += 1;
let res = db
.client
.query(definition.clone())
.await
.with_context(|| format!("creating index {index_name} on table {table}"))?;
match res.check() {
Ok(_) => break,
Err(err) => {
let msg = err.to_string();
let conflict = msg.contains("read or write conflict");
warn!(
index = %index_name,
table = %table,
error = ?err,
attempt = attempts,
definition = %definition,
"Index definition failed"
);
if conflict && attempts < MAX_ATTEMPTS {
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
return Err(err).with_context(|| {
format!("index definition failed for {index_name} on {table}")
});
}
}
}
info!(
index = %index_name,

View File

@@ -53,9 +53,9 @@ impl SystemSettings {
#[cfg(test)]
mod tests {
use crate::storage::indexes::ensure_runtime_indexes;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
use async_openai::Client;
use crate::storage::indexes::ensure_runtime_indexes;
use super::*;
use uuid::Uuid;

View File

@@ -17,9 +17,9 @@ stored_object!(TextChunk, "text_chunk", {
user_id: String
});
/// Vector search result including hydrated chunk.
/// Search result including hydrated chunk.
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct TextChunkVectorResult {
pub struct TextChunkSearchResult {
pub chunk: TextChunk,
pub score: f32,
}
@@ -97,7 +97,7 @@ impl TextChunk {
query_embedding: Vec<f32>,
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<TextChunkVectorResult>, AppError> {
) -> Result<Vec<TextChunkSearchResult>, AppError> {
#[derive(Deserialize)]
struct Row {
chunk_id: TextChunk,
@@ -132,13 +132,85 @@ impl TextChunk {
Ok(rows
.into_iter()
.map(|r| TextChunkVectorResult {
.map(|r| TextChunkSearchResult {
chunk: r.chunk_id,
score: r.score,
})
.collect())
}
/// Full-text search over text chunks using the BM25 FTS index.
///
/// Matches `terms` against the `chunk` field, restricted to rows whose `user_id`
/// equals the given `user_id`, and returns at most `take` hits ordered by
/// descending score. Rows for which the database reports no search score are
/// given a score of `0` by the query itself.
///
/// # Errors
///
/// Returns `AppError::InternalError` when the query cannot be executed, and
/// `AppError::Database` when a statement in the response failed or the rows
/// could not be deserialized.
pub async fn fts_search(
    take: usize,
    terms: &str,
    db: &SurrealDbClient,
    user_id: &str,
) -> Result<Vec<TextChunkSearchResult>, AppError> {
    // Flat projection of the selected columns plus the computed score.
    #[derive(Deserialize)]
    struct Row {
        #[serde(deserialize_with = "deserialize_flexible_id")]
        id: String,
        #[serde(deserialize_with = "deserialize_datetime")]
        created_at: DateTime<Utc>,
        #[serde(deserialize_with = "deserialize_datetime")]
        updated_at: DateTime<Utc>,
        source_id: String,
        chunk: String,
        user_id: String,
        score: f32,
    }

    let sql = format!(
        r#"
        SELECT
            id,
            created_at,
            updated_at,
            source_id,
            chunk,
            user_id,
            IF search::score(0) != NONE THEN search::score(0) ELSE 0 END AS score
        FROM {chunk_table}
        WHERE chunk @0@ $terms
          AND user_id = $user_id
        ORDER BY score DESC
        LIMIT $limit;
        "#,
        chunk_table = Self::table_name(),
    );

    // A plain `as` cast would wrap very large `take` values into a negative LIMIT;
    // saturate at `i64::MAX` instead.
    let limit = i64::try_from(take).unwrap_or(i64::MAX);

    let mut response = db
        .query(&sql)
        .bind(("terms", terms.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", limit))
        .await
        .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;

    // Surface per-statement failures before attempting to read rows.
    response = response.check().map_err(AppError::Database)?;

    let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;

    Ok(rows
        .into_iter()
        .map(|r| TextChunkSearchResult {
            chunk: TextChunk {
                id: r.id,
                created_at: r.created_at,
                updated_at: r.updated_at,
                source_id: r.source_id,
                chunk: r.chunk,
                user_id: r.user_id,
            },
            score: r.score,
        })
        .collect())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
///
/// This is a costly operation that should be run in the background. It performs these steps:
@@ -252,6 +324,26 @@ mod tests {
use surrealdb::RecordId;
use uuid::Uuid;
/// Test helper that defines the analyzer and BM25 index used by `fts_search`.
///
/// Tries a snowball-stemming analyzer first and falls back to an ascii-only
/// analyzer when that definition fails.
///
/// # Panics
///
/// Panics when the fallback definition fails as well; the message carries both
/// the fallback error and the original snowball error.
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
    let snowball_sql = r#"
        DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
        DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
    "#;

    // `query().await` only reports transport-level failures; statement errors
    // (such as an unsupported snowball filter) are surfaced by `check()`.
    let snowball_outcome = db
        .client
        .query(snowball_sql)
        .await
        .and_then(|response| response.check());

    if let Err(snowball_err) = snowball_outcome {
        // Fall back to ascii-only analyzer when snowball is unavailable in the build.
        let fallback_sql = r#"
            DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
            DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
        "#;

        let fallback_outcome = db
            .client
            .query(fallback_sql)
            .await
            .and_then(|response| response.check());

        if let Err(fallback_err) = fallback_outcome {
            panic!(
                "define chunk fts index fallback: {fallback_err} (snowball attempt failed with: {snowball_err})"
            );
        }
    }
}
#[tokio::test]
async fn test_text_chunk_creation() {
let source_id = "source123".to_string();
@@ -435,7 +527,7 @@ mod tests {
.await
.expect("redefine index");
let results: Vec<TextChunkVectorResult> =
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.unwrap();
@@ -467,7 +559,7 @@ mod tests {
.await
.expect("store");
let results: Vec<TextChunkVectorResult> =
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.unwrap();
@@ -503,7 +595,7 @@ mod tests {
.await
.expect("store chunk2");
let results: Vec<TextChunkVectorResult> =
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.unwrap();
@@ -513,4 +605,105 @@ mod tests {
assert_eq!(results[1].chunk.id, chunk1.id);
assert!(results[0].score >= results[1].score);
}
/// Searching a freshly migrated, empty database yields no hits.
#[tokio::test]
async fn test_fts_search_returns_empty_when_no_chunks() {
    // A random database name keeps this test isolated from the others.
    let db_name = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("fts_chunk_ns_empty", &db_name)
        .await
        .expect("Failed to start in-memory surrealdb");

    db.apply_migrations().await.expect("migrations");
    ensure_chunk_fts_index(&db).await;
    db.rebuild_indexes().await.expect("rebuild indexes");

    let hits = TextChunk::fts_search(5, "hello", &db, "user")
        .await
        .expect("fts search");

    assert!(hits.is_empty());
}
/// A single stored chunk that matches the term is returned with a finite score.
#[tokio::test]
async fn test_fts_search_single_result() {
    let db_name = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("fts_chunk_ns_single", &db_name)
        .await
        .expect("Failed to start in-memory surrealdb");
    db.apply_migrations().await.expect("migrations");
    ensure_chunk_fts_index(&db).await;

    let user_id = "fts_user";
    let stored = TextChunk::new(
        "fts_src".to_string(),
        "rustaceans love rust".to_string(),
        user_id.to_string(),
    );
    // Remember the id up front so the chunk itself can be moved into the store call.
    let expected_id = stored.id.clone();
    db.store_item(stored).await.expect("store chunk");
    db.rebuild_indexes().await.expect("rebuild indexes");

    let hits = TextChunk::fts_search(3, "rust", &db, user_id)
        .await
        .expect("fts search");

    assert_eq!(hits.len(), 1);
    let only_hit = &hits[0];
    assert_eq!(only_hit.chunk.id, expected_id);
    assert!(only_hit.score.is_finite(), "expected a finite FTS score");
}
/// Only the requesting user's chunks are returned, and they come back in
/// non-increasing score order.
#[tokio::test]
async fn test_fts_search_orders_by_score_and_filters_user() {
    let db_name = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("fts_chunk_ns_order", &db_name)
        .await
        .expect("Failed to start in-memory surrealdb");
    db.apply_migrations().await.expect("migrations");
    ensure_chunk_fts_index(&db).await;

    let user_id = "fts_user_order";
    let owner = || user_id.to_string();

    // Two chunks for the searching user, one for somebody else.
    let high_score_chunk = TextChunk::new(
        "src1".to_string(),
        "apple apple apple pie recipe".to_string(),
        owner(),
    );
    let low_score_chunk = TextChunk::new("src2".to_string(), "apple tart".to_string(), owner());
    let other_user_chunk = TextChunk::new(
        "src3".to_string(),
        "apple orchard guide".to_string(),
        "other_user".to_string(),
    );

    let high_id = high_score_chunk.id.clone();
    let low_id = low_score_chunk.id.clone();

    db.store_item(high_score_chunk)
        .await
        .expect("store high score chunk");
    db.store_item(low_score_chunk)
        .await
        .expect("store low score chunk");
    db.store_item(other_user_chunk)
        .await
        .expect("store other user chunk");
    db.rebuild_indexes().await.expect("rebuild indexes");

    let hits = TextChunk::fts_search(3, "apple", &db, user_id)
        .await
        .expect("fts search");

    assert_eq!(hits.len(), 2);

    let returned_ids: Vec<&str> = hits.iter().map(|hit| hit.chunk.id.as_str()).collect();
    let both_present = [high_id.as_str(), low_id.as_str()]
        .iter()
        .all(|id| returned_ids.contains(id));
    assert!(
        both_present,
        "expected only the two chunks for the same user"
    );

    let (first, second) = (&hits[0], &hits[1]);
    assert!(
        first.score >= second.score,
        "expected results ordered by descending score"
    );
}
}