retrieval simplified

Per Stark
2025-12-09 20:35:42 +01:00
parent a8d10f265c
commit a090a8c76e
55 changed files with 469 additions and 1208 deletions

View File

@@ -51,6 +51,24 @@ pub fn create_user_message(entities_json: &Value, query: &str) -> String {
)
}
/// Convert chunk-based retrieval results to JSON format for LLM context
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}
serde_json::json!(chunks
.iter()
.map(|chunk| {
serde_json::json!({
"content": chunk.chunk.chunk,
"source_id": chunk.chunk.source_id,
"score": round_score(chunk.score),
})
})
.collect::<Vec<_>>())
}
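As a reference for readers of this hunk, a minimal standalone sketch of the rounding behavior and the JSON shape produced above (the `main` harness is illustrative only, not part of the commit):

fn round_score(value: f32) -> f64 {
    (f64::from(value) * 1000.0).round() / 1000.0
}

fn main() {
    // Scores are rounded to three decimal places, keeping the LLM context compact.
    assert_eq!(round_score(0.123_456), 0.123);
    assert_eq!(round_score(0.999_9), 1.0);
    // Each retrieved chunk then serializes as:
    // {"content": "...", "source_id": "...", "score": 0.123}
}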
pub fn create_user_message_with_history(
entities_json: &Value,
history: &[Message],

View File

@@ -1,268 +0,0 @@
use std::collections::HashMap;
use serde::Deserialize;
use tracing::debug;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
};
use crate::scoring::Scored;
use common::storage::types::file_info::deserialize_flexible_id;
use surrealdb::sql::Thing;
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
fts_score: Option<f32>,
}
/// Executes a full-text search query against SurrealDB and returns scored results.
///
/// The function expects FTS indexes to exist for the provided table. Currently supports
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
pub async fn find_items_by_fts<T>(
take: usize,
query: &str,
db_client: &SurrealDbClient,
table: &str,
user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
let (filter_clause, score_clause) = match table {
"knowledge_entity" => (
"(name @0@ $terms OR description @1@ $terms)",
"(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
(IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
),
"text_chunk" => (
"(chunk @0@ $terms)",
"IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
),
_ => {
return Err(AppError::Validation(format!(
"FTS not configured for table '{table}'"
)))
}
};
let sql = format!(
"SELECT id, {score_clause} AS fts_score \
FROM {table} \
WHERE {filter_clause} \
AND user_id = $user_id \
ORDER BY fts_score DESC \
LIMIT $limit",
table = table,
filter_clause = filter_clause,
score_clause = score_clause
);
debug!(
table = table,
limit = take,
"Executing FTS query with filter clause: {}",
filter_clause
);
let mut response = db_client
.query(sql)
.bind(("terms", query.to_owned()))
.bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64))
.await?;
let score_rows: Vec<FtsScoreRow> = response.take(0)?;
if score_rows.is_empty() {
return Ok(Vec::new());
}
let ids: Vec<String> = score_rows.iter().map(|row| row.id.clone()).collect();
let thing_ids: Vec<Thing> = ids
.iter()
.map(|id| Thing::from((table, id.as_str())))
.collect();
let mut items_response = db_client
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", table.to_owned()))
.bind(("things", thing_ids.clone()))
.bind(("user_id", user_id.to_owned()))
.await?;
let items: Vec<T> = items_response.take(0)?;
let mut item_map: HashMap<String, T> = items
.into_iter()
.map(|item| (item.get_id().to_owned(), item))
.collect();
let mut results = Vec::with_capacity(score_rows.len());
for row in score_rows {
if let Some(item) = item_map.remove(&row.id) {
let score = row.fts_score.unwrap_or_default();
results.push(Scored::new(item).with_fts_score(score));
}
}
Ok(results)
}
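For reference, the statement the format! above renders for the `knowledge_entity` branch (whitespace reflowed for readability; the `text_chunk` branch differs only in its single-field filter and score clause):

    SELECT id,
           (IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) +
           (IF search::score(1) != NONE THEN search::score(1) ELSE 0 END) AS fts_score
    FROM knowledge_entity
    WHERE (name @0@ $terms OR description @1@ $terms)
      AND user_id = $user_id
    ORDER BY fts_score DESC
    LIMIT $limit

The two-step shape (score query first, then a second fetch by id) presumably keeps the `search::score(N)` calls confined to the statement that owns the matching `@N@` operators.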
#[cfg(test)]
mod tests {
use super::*;
use common::storage::indexes::ensure_runtime_indexes;
use common::storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
text_chunk::TextChunk,
StoredObject,
};
use uuid::Uuid;
#[tokio::test]
async fn fts_preserves_single_field_score_for_name() {
let namespace = "fts_test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
ensure_runtime_indexes(&db, 1536)
.await
.expect("failed to build runtime indexes");
let user_id = "user_fts";
let entity = KnowledgeEntity::new(
"source_a".into(),
"Rustacean handbook".into(),
"completely unrelated description".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("failed to insert entity");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results = find_items_by_fts::<KnowledgeEntity>(
5,
"rustacean",
&db,
KnowledgeEntity::table_name(),
user_id,
)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when only the name matched"
);
}
#[tokio::test]
async fn fts_preserves_single_field_score_for_description() {
let namespace = "fts_test_ns_desc";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
ensure_runtime_indexes(&db, 1536)
.await
.expect("failed to build runtime indexes");
let user_id = "user_fts_desc";
let entity = KnowledgeEntity::new(
"source_b".into(),
"neutral name".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("failed to insert entity");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results = find_items_by_fts::<KnowledgeEntity>(
5,
"async",
&db,
KnowledgeEntity::table_name(),
user_id,
)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when only the description matched"
);
}
#[tokio::test]
async fn fts_preserves_scores_for_text_chunks() {
let namespace = "fts_test_ns_chunks";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
ensure_runtime_indexes(&db, 1536)
.await
.expect("failed to build runtime indexes");
let user_id = "user_fts_chunk";
let chunk = TextChunk::new(
"source_chunk".into(),
"GraphQL documentation reference".into(),
user_id.into(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.0; 1536], &db)
.await
.expect("failed to insert chunk");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results =
find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when chunk field matched"
);
}
}

View File

@@ -10,54 +10,17 @@ use common::storage::{
},
};
/// Retrieves database entries that match a specific source identifier.
/// Find entities related to the given entity via graph relationships.
///
/// This function queries the database for all records in a specified table that have
/// a matching `source_id` field. It's commonly used to find related entities or
/// track the origin of database entries.
/// Queries the `relates_to` edge table for all relationships involving the entity,
/// then fetches and returns the neighboring entities.
///
/// # Arguments
///
/// * `source_id` - The identifier to search for in the database
/// * `table_name` - The name of the table to search in
/// * `db_client` - The `SurrealDB` client instance for database operations
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
/// * `Err(Error)` - An error if the database query fails
///
/// # Errors
///
/// This function will return an `Error` if:
/// * The database query fails to execute
/// * The results cannot be deserialized into type `T`
pub async fn find_entities_by_source_ids<T>(
source_ids: Vec<String>,
table_name: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
T: for<'de> serde::Deserialize<'de>,
{
let query =
"SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id";
/// * `db` - Database client
/// * `entity_id` - ID of the entity to find neighbors for
/// * `user_id` - User ID for access control
/// * `limit` - Maximum number of neighbors to return
db.query(query)
.bind(("table", table_name.to_owned()))
.bind(("source_ids", source_ids))
.bind(("user_id", user_id.to_owned()))
.await?
.take(0)
}
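The replacement path routes through find_entities_by_relationship_by_id, whose body falls outside this hunk. A hypothetical SurrealQL shape for the neighbor lookup described by the new doc comment above; the `relates_to` edge and the parameter names come from that doc comment, but the query shape itself is an assumption, not taken from the commit:

    SELECT <->relates_to<->knowledge_entity AS neighbors
    FROM type::thing('knowledge_entity', $entity_id)
    WHERE user_id = $user_id
    LIMIT $limit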
/// Find entities by their relationship to the id
pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient,
entity_id: &str,
@@ -153,154 +116,8 @@ mod tests {
use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use common::storage::types::StoredObject;
use uuid::Uuid;
#[tokio::test]
async fn test_find_entities_by_source_ids() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
// Create some test entities with different source_ids
let source_id1 = "source123".to_string();
let source_id2 = "source456".to_string();
let source_id3 = "source789".to_string();
let entity_type = KnowledgeEntityType::Document;
let user_id = "user123".to_string();
// Entity with source_id1
let entity1 = KnowledgeEntity::new(
source_id1.clone(),
"Entity 1".to_string(),
"Description 1".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
// Entity with source_id2
let entity2 = KnowledgeEntity::new(
source_id2.clone(),
"Entity 2".to_string(),
"Description 2".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
// Another entity with source_id1
let entity3 = KnowledgeEntity::new(
source_id1.clone(),
"Entity 3".to_string(),
"Description 3".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
// Entity with source_id3
let entity4 = KnowledgeEntity::new(
source_id3.clone(),
"Entity 4".to_string(),
"Description 4".to_string(),
entity_type.clone(),
None,
user_id.clone(),
);
// Store all entities
db.store_item(entity1.clone())
.await
.expect("Failed to store entity 1");
db.store_item(entity2.clone())
.await
.expect("Failed to store entity 2");
db.store_item(entity3.clone())
.await
.expect("Failed to store entity 3");
db.store_item(entity4.clone())
.await
.expect("Failed to store entity 4");
// Test finding entities by multiple source_ids
let source_ids = vec![source_id1.clone(), source_id2.clone()];
let found_entities: Vec<KnowledgeEntity> =
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
.await
.expect("Failed to find entities by source_ids");
// Should find 3 entities (2 with source_id1, 1 with source_id2)
assert_eq!(
found_entities.len(),
3,
"Should find 3 entities with the specified source_ids"
);
// Check that entities with source_id1 and source_id2 are found
let found_source_ids: Vec<String> =
found_entities.iter().map(|e| e.source_id.clone()).collect();
assert!(
found_source_ids.contains(&source_id1),
"Should find entities with source_id1"
);
assert!(
found_source_ids.contains(&source_id2),
"Should find entities with source_id2"
);
assert!(
!found_source_ids.contains(&source_id3),
"Should not find entities with source_id3"
);
// Test finding entities by a single source_id
let single_source_id = vec![source_id1.clone()];
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
single_source_id,
KnowledgeEntity::table_name(),
&user_id,
&db,
)
.await
.expect("Failed to find entities by single source_id");
// Should find 2 entities with source_id1
assert_eq!(
found_entities.len(),
2,
"Should find 2 entities with source_id1"
);
// Check that all found entities have source_id1
for entity in found_entities {
assert_eq!(
entity.source_id, source_id1,
"All found entities should have source_id1"
);
}
// Test finding entities with non-existent source_id
let non_existent_source_id = vec!["non_existent_source".to_string()];
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
non_existent_source_id,
KnowledgeEntity::table_name(),
&user_id,
&db,
)
.await
.expect("Failed to find entities by non-existent source_id");
// Should find 0 entities
assert_eq!(
found_entities.len(),
0,
"Should find 0 entities with non-existent source_id"
);
}
#[tokio::test]
async fn test_find_entities_by_relationship_by_id() {

View File

@@ -1,6 +1,6 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod pipeline;
pub mod reranking;
@@ -70,11 +70,7 @@ mod tests {
use super::*;
use async_openai::Client;
use common::storage::indexes::ensure_runtime_indexes;
use common::storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
};
use common::storage::types::text_chunk::TextChunk;
use pipeline::{RetrievalConfig, RetrievalStrategy};
use uuid::Uuid;
@@ -82,14 +78,6 @@ mod tests {
vec![0.9, 0.1, 0.0]
}
fn entity_embedding_high() -> Vec<f32> {
vec![0.8, 0.2, 0.0]
}
fn entity_embedding_low() -> Vec<f32> {
vec![0.1, 0.9, 0.0]
}
fn chunk_embedding_primary() -> Vec<f32> {
vec![0.85, 0.15, 0.0]
}
@@ -113,41 +101,19 @@ mod tests {
.await
.expect("failed to build runtime indexes");
db.query(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding;
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
)
.await
.expect("Failed to configure indices");
db
}
#[tokio::test]
async fn test_retrieve_entities_with_embedding_basic_flow() {
async fn test_default_strategy_retrieves_chunks() {
let db = setup_test_db().await;
let user_id = "test_user";
let entity = KnowledgeEntity::new(
"source_1".into(),
"Rust async guide".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
let chunk = TextChunk::new(
entity.source_id.clone(),
"source_1".into(),
"Tokio uses cooperative scheduling for fairness.".into(),
user_id.into(),
);
KnowledgeEntity::store_with_embedding(entity.clone(), entity_embedding_high(), &db)
.await
.expect("Failed to store entity");
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
.await
.expect("Failed to store chunk");
@@ -164,64 +130,32 @@ mod tests {
None,
)
.await
.expect("Hybrid retrieval failed");
.expect("Default strategy retrieval failed");
let entities = match results {
StrategyOutput::Entities(items) => items,
other => panic!("expected entity results, got {:?}", other),
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk results, got {:?}", other),
};
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
assert!(
!entities.is_empty(),
"Expected at least one retrieval result"
);
let top = &entities[0];
assert!(
top.entity.name.contains("Rust"),
"Expected Rust entity to be ranked first"
);
assert!(
!top.chunks.is_empty(),
"Expected Rust entity to include supporting chunks"
chunks[0].chunk.chunk.contains("Tokio"),
"Expected chunk about Tokio"
);
}
#[tokio::test]
async fn test_graph_relationship_enriches_results() {
async fn test_default_strategy_returns_chunks_from_multiple_sources() {
let db = setup_test_db().await;
let user_id = "graph_user";
let primary = KnowledgeEntity::new(
"primary_source".into(),
"Async Rust patterns".into(),
"Explores async runtimes and scheduling strategies.".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
let neighbor = KnowledgeEntity::new(
"neighbor_source".into(),
"Tokio Scheduler Deep Dive".into(),
"Details on Tokio's cooperative scheduler.".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
KnowledgeEntity::store_with_embedding(primary.clone(), entity_embedding_high(), &db)
.await
.expect("Failed to store primary entity");
KnowledgeEntity::store_with_embedding(neighbor.clone(), entity_embedding_low(), &db)
.await
.expect("Failed to store neighbor entity");
let user_id = "multi_source_user";
let primary_chunk = TextChunk::new(
primary.source_id.clone(),
"primary_source".into(),
"Rust async tasks use Tokio's cooperative scheduler.".into(),
user_id.into(),
);
let neighbor_chunk = TextChunk::new(
neighbor.source_id.clone(),
let secondary_chunk = TextChunk::new(
"secondary_source".into(),
"Tokio's scheduler manages task fairness across executors.".into(),
user_id.into(),
);
@@ -229,23 +163,11 @@ mod tests {
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
.await
.expect("Failed to store primary chunk");
TextChunk::store_with_embedding(neighbor_chunk, chunk_embedding_secondary(), &db)
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db)
.await
.expect("Failed to store neighbor chunk");
.expect("Failed to store secondary chunk");
let openai_client = Client::new();
let relationship = KnowledgeRelationship::new(
primary.id.clone(),
neighbor.id.clone(),
user_id.into(),
"relationship_source".into(),
"references".into(),
);
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
@@ -257,35 +179,23 @@ mod tests {
None,
)
.await
.expect("Hybrid retrieval failed");
.expect("Default strategy retrieval failed");
let entities = match results {
StrategyOutput::Entities(items) => items,
other => panic!("expected entity results, got {:?}", other),
let chunks = match results {
StrategyOutput::Chunks(items) => items,
other => panic!("expected chunk results, got {:?}", other),
};
let mut neighbor_entry = None;
for entity in &entities {
if entity.entity.id == neighbor.id {
neighbor_entry = Some(entity.clone());
}
}
println!("{:?}", entities);
let neighbor_entry =
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
assert!(
neighbor_entry.score > 0.2,
"Graph-enriched entity should have a meaningful fused score"
chunks.iter().any(|c| c.chunk.source_id == "primary_source"),
"Should include primary source chunk"
);
assert!(
neighbor_entry
.chunks
chunks
.iter()
.all(|chunk| chunk.chunk.source_id == neighbor.source_id),
"Neighbor entity should surface its own supporting chunks"
.any(|c| c.chunk.source_id == "secondary_source"),
"Should include secondary source chunk"
);
}
@@ -311,7 +221,7 @@ mod tests {
.await
.expect("Failed to store chunk two");
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised);
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,

View File

@@ -6,15 +6,17 @@ use crate::scoring::FusionWeights;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
Initial,
Revised,
/// Primary hybrid chunk retrieval for search/chat (formerly Revised)
Default,
/// Entity retrieval for suggesting relationships when creating manual entities
RelationshipSuggestion,
/// Entity retrieval for context during content ingestion
Ingestion,
}
impl Default for RetrievalStrategy {
fn default() -> Self {
Self::Initial
Self::Default
}
}
@@ -23,8 +25,16 @@ impl std::str::FromStr for RetrievalStrategy {
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value.to_ascii_lowercase().as_str() {
"initial" => Ok(Self::Initial),
"revised" => Ok(Self::Revised),
"default" => Ok(Self::Default),
// Backward compatibility: treat "initial" and "revised" as "default"
"initial" | "revised" => {
tracing::warn!(
"Retrieval strategy '{}' is deprecated. Use 'default' instead. \
The 'initial' strategy has been removed in favor of the simpler hybrid chunk retrieval.",
value
);
Ok(Self::Default)
}
"relationship_suggestion" => Ok(Self::RelationshipSuggestion),
"ingestion" => Ok(Self::Ingestion),
other => Err(format!("unknown retrieval strategy '{other}'")),
@@ -35,8 +45,7 @@ impl std::str::FromStr for RetrievalStrategy {
impl fmt::Display for RetrievalStrategy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let label = match self {
RetrievalStrategy::Initial => "initial",
RetrievalStrategy::Revised => "revised",
RetrievalStrategy::Default => "default",
RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion",
RetrievalStrategy::Ingestion => "ingestion",
};
@@ -136,7 +145,7 @@ pub struct RetrievalConfig {
impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self {
Self {
strategy: RetrievalStrategy::Initial,
strategy: RetrievalStrategy::Default,
tuning,
}
}
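A small usage sketch of the renamed strategy, combining the FromStr compatibility shim with the config constructor shown above (assumes `Self::Err = String`, matching the `format!` error in the match arm):

use std::str::FromStr;

fn demo() {
    // New canonical name.
    assert_eq!(RetrievalStrategy::from_str("default"), Ok(RetrievalStrategy::Default));
    // Legacy names still parse, logging a deprecation warning.
    assert_eq!(RetrievalStrategy::from_str("revised"), Ok(RetrievalStrategy::Default));
    assert_eq!(RetrievalStrategy::from_str("initial"), Ok(RetrievalStrategy::Default));
    assert!(RetrievalStrategy::from_str("bogus").is_err());
    // Config side, as exercised by the updated pipeline tests.
    let _config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
}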

View File

@@ -17,9 +17,7 @@ use std::time::{Duration, Instant};
use tracing::info;
use stages::PipelineContext;
use strategies::{
IngestionDriver, InitialStrategyDriver, RelationshipSuggestionDriver, RevisedStrategyDriver,
};
use strategies::{DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver};
// Export StrategyOutput publicly from this module
// (it's defined in lib.rs but we re-export it here)
@@ -132,25 +130,8 @@ pub async fn run_pipeline(
);
match config.strategy {
RetrievalStrategy::Initial => {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
@@ -214,25 +195,8 @@ pub async fn run_pipeline_with_embedding(
reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
match config.strategy {
RetrievalStrategy::Initial => {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
@@ -301,29 +265,8 @@ pub async fn run_pipeline_with_embedding_with_metrics(
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy {
RetrievalStrategy::Initial => {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Entities(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
@@ -361,29 +304,8 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy {
RetrievalStrategy::Initial => {
let driver = InitialStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
true,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Entities(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
RetrievalStrategy::Default => {
let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,

View File

@@ -12,13 +12,13 @@ use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
collections::HashMap,
};
use tracing::{debug, instrument, warn};
use crate::{
fts::find_items_by_fts,
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
graph::find_entities_by_relationship_by_id,
reranking::RerankerLease,
scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion,
@@ -45,7 +45,6 @@ pub struct PipelineContext<'a> {
pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>,
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<TextChunk>>,
@@ -75,7 +74,6 @@ impl<'a> PipelineContext<'a> {
config,
query_embedding: None,
entity_candidates: HashMap::new(),
chunk_candidates: HashMap::new(),
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
revised_chunk_values: Vec::new(),
@@ -209,20 +207,6 @@ impl PipelineStage for GraphExpansionStage {
}
}
#[derive(Debug, Clone, Copy)]
pub struct ChunkAttachStage;
#[async_trait]
impl PipelineStage for ChunkAttachStage {
fn kind(&self) -> StageKind {
StageKind::ChunkAttach
}
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
attach_chunks(ctx).await
}
}
#[derive(Debug, Clone, Copy)]
pub struct RerankStage;
@@ -324,75 +308,68 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
let weights = FusionWeights::default();
let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!(
let (vector_entity_results, fts_entity_results) = tokio::try_join!(
KnowledgeEntity::vector_search(
tuning.entity_vector_take,
embedding.clone(),
ctx.db_client,
&ctx.user_id,
),
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
&ctx.user_id,
),
find_items_by_fts(
tuning.entity_fts_take,
&ctx.input_text,
KnowledgeEntity::search(
ctx.db_client,
"knowledge_entity",
&ctx.input_text,
&ctx.user_id,
),
find_items_by_fts(
tuning.chunk_fts_take,
&ctx.input_text,
ctx.db_client,
"text_chunk",
&ctx.user_id
),
tuning.entity_fts_take,
)
)?;
#[allow(clippy::useless_conversion)]
let vector_entities: Vec<Scored<KnowledgeEntity>> = vector_entity_results
.into_iter()
.map(|row| Scored::new(row.entity).with_vector_score(row.score))
.collect();
let vector_chunks: Vec<Scored<TextChunk>> = vector_chunk_results
let mut fts_entities: Vec<Scored<KnowledgeEntity>> = fts_entity_results
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.map(|res| {
let entity = KnowledgeEntity {
id: res.id,
created_at: res.created_at,
updated_at: res.updated_at,
source_id: res.source_id,
name: res.name,
description: res.description,
entity_type: res.entity_type,
metadata: res.metadata,
user_id: res.user_id,
};
Scored::new(entity).with_fts_score(res.score)
})
.collect();
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts"
);
if ctx.diagnostics_enabled() {
ctx.record_collect_candidates(CollectCandidatesStats {
vector_entity_candidates: vector_entities.len(),
vector_chunk_candidates: vector_chunks.len(),
vector_chunk_candidates: 0,
fts_entity_candidates: fts_entities.len(),
fts_chunk_candidates: fts_chunks.len(),
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| {
chunk.scores.vector.unwrap_or(0.0)
}),
fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
fts_chunk_candidates: 0,
vector_chunk_scores: Vec::new(),
fts_chunk_scores: Vec::new(),
});
}
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
Ok(())
}
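normalize_fts_scores is implemented in the scoring module and not shown in this hunk; conceptually it rescales raw full-text scores into a unit range before fusion with vector scores. A minimal sketch of that min-max step, assuming it mirrors the imported scoring::min_max_normalize:

fn min_max(scores: &mut [f32]) {
    let min = scores.iter().copied().fold(f32::INFINITY, f32::min);
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Leave scores untouched when the slice is empty or all values are identical.
    if max > min {
        for s in scores.iter_mut() {
            *s = (*s - min) / (max - min);
        }
    }
}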
@@ -467,82 +444,6 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError>
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Attaching chunks to surviving entities");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
let chunk_candidates_before = ctx.chunk_candidates.len();
let chunk_sources_considered = chunk_by_source.len();
backfill_entities_from_chunks(
&mut ctx.entity_candidates,
&chunk_by_source,
ctx.db_client,
&ctx.user_id,
weights,
)
.await?;
boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
ctx.entity_candidates.values().cloned().collect();
sort_by_fused_desc(&mut entity_results);
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
.iter()
.filter(|candidate| candidate.fused >= tuning.score_threshold)
.cloned()
.collect();
if filtered_entities.len() < tuning.fallback_min_results {
filtered_entities = entity_results
.into_iter()
.take(tuning.fallback_min_results)
.collect();
}
ctx.filtered_entities = filtered_entities;
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
for chunk in chunk_results {
chunk_by_id.insert(chunk.item.id.clone(), chunk);
}
enrich_chunks_from_entities(
&mut chunk_by_id,
&ctx.filtered_entities,
ctx.db_client,
&ctx.user_id,
weights,
)
.await?;
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
if ctx.diagnostics_enabled() {
ctx.record_chunk_enrichment(ChunkEnrichmentStats {
filtered_entity_count: ctx.filtered_entities.len(),
fallback_min_results: tuning.fallback_min_results,
chunk_sources_considered,
chunk_candidates_before_enrichment: chunk_candidates_before,
chunk_candidates_after_enrichment: chunk_values.len(),
top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused),
});
}
ctx.chunk_values = chunk_values;
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
@@ -960,142 +861,6 @@ where
}
}
fn group_chunks_by_source(
chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.values() {
by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk.clone());
}
by_source
}
async fn backfill_entities_from_chunks(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
let mut missing_sources = Vec::new();
for source_id in chunk_by_source.keys() {
if !entity_candidates
.values()
.any(|entity| entity.item.source_id == *source_id)
{
missing_sources.push(source_id.clone());
}
}
if missing_sources.is_empty() {
return Ok(());
}
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
missing_sources.clone(),
"knowledge_entity",
user_id,
db_client,
)
.await
.unwrap_or_default();
if related_entities.is_empty() {
warn!("expected related entities for missing chunk sources, but none were found");
}
for entity in related_entities {
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
entity_candidates.insert(entity.id.clone(), scored);
}
}
Ok(())
}
fn boost_entities_with_chunks(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
weights: FusionWeights,
) {
for entity in entity_candidates.values_mut() {
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
if best_chunk_score > 0.0 {
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
entity.scores.vector = Some(boosted);
let fused = fuse_scores(&entity.scores, weights);
entity.update_fused(fused);
}
}
}
}
async fn enrich_chunks_from_entities(
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
entities: &[Scored<KnowledgeEntity>],
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
source_ids.insert(entity.item.source_id.clone());
}
if source_ids.is_empty() {
return Ok(());
}
let chunks = find_entities_by_source_ids::<TextChunk>(
source_ids.into_iter().collect(),
"text_chunk",
user_id,
db_client,
)
.await?;
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
for entity in entities {
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
}
for chunk in chunks {
let entry = chunk_candidates
.entry(chunk.id.clone())
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
let entity_score = entity_score_lookup
.get(&chunk.source_id)
.copied()
.unwrap_or(0.0);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
}
Ok(())
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();

View File

@@ -1,50 +1,24 @@
use super::{
stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage,
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
RerankStage,
AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
},
BoxedStage, StrategyDriver,
};
use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError;
pub struct InitialStrategyDriver;
impl InitialStrategyDriver {
pub struct DefaultStrategyDriver;
impl DefaultStrategyDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for InitialStrategyDriver {
type Output = Vec<RetrievedEntity>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
Box::new(ChunkAttachStage),
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
Ok(ctx.take_entity_results())
}
}
pub struct RevisedStrategyDriver;
impl RevisedStrategyDriver {
pub fn new() -> Self {
Self
}
}
impl StrategyDriver for RevisedStrategyDriver {
impl StrategyDriver for DefaultStrategyDriver {
type Output = Vec<RetrievedChunk>;
fn stages(&self) -> Vec<BoxedStage> {