mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-22 08:48:30 +02:00
retrieval-pipeline: v0
This commit is contained in:
117
retrieval-pipeline/src/answer_retrieval.rs
Normal file
117
retrieval-pipeline/src/answer_retrieval.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
|
||||
/// A single source reference emitted by the LLM alongside its answer.
#[derive(Debug, Deserialize)]
pub struct Reference {
    // Required by the JSON schema and parsed from the structured output,
    // but not read by any code yet — hence the allow.
    #[allow(dead_code)]
    pub reference: String,
}
|
||||
|
||||
/// Shape of the structured JSON payload the LLM is constrained to return
/// (see `get_query_response_schema`).
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
    /// The answer text shown to the user.
    pub answer: String,
    // Deserialized to satisfy the strict schema, but not consumed yet.
    #[allow(dead_code)]
    pub references: Vec<Reference>,
}
|
||||
|
||||
/// Final answer representation with references flattened to plain strings.
#[derive(Debug)]
pub struct Answer {
    /// Answer text.
    pub content: String,
    /// Reference identifiers backing the answer.
    pub references: Vec<String>,
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r"
|
||||
Context Information:
|
||||
==================
|
||||
{entities_json}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{query}
|
||||
"
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_user_message_with_history(
|
||||
entities_json: &Value,
|
||||
history: &[Message],
|
||||
query: &str,
|
||||
) -> String {
|
||||
format!(
|
||||
r"
|
||||
Chat history:
|
||||
==================
|
||||
{}
|
||||
|
||||
Context Information:
|
||||
==================
|
||||
{}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
",
|
||||
format_history(history),
|
||||
entities_json,
|
||||
query
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_chat_request(
|
||||
user_message: String,
|
||||
settings: &SystemSettings,
|
||||
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Query answering AI".into()),
|
||||
name: "query_answering_with_uuids".into(),
|
||||
schema: Some(get_query_response_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model(&settings.query_model)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(settings.query_system_prompt.clone()).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub async fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
|
||||
})
|
||||
})
|
||||
}
|
||||
26
retrieval-pipeline/src/answer_retrieval_helper.rs
Normal file
26
retrieval-pipeline/src/answer_retrieval_helper.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use common::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// System prompt used for query answering; re-exports the shared default.
pub static QUERY_SYSTEM_PROMPT: &str = DEFAULT_QUERY_SYSTEM_PROMPT;
|
||||
|
||||
/// Builds the JSON schema enforced on the LLM's structured output.
///
/// The schema requires an `answer` string and a `references` array of objects
/// each holding a single `reference` string. `additionalProperties: false` is
/// set at both levels so the schema is valid under strict mode.
pub fn get_query_response_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "answer": { "type": "string" },
            "references": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "reference": { "type": "string" },
                    },
                    "required": ["reference"],
                    "additionalProperties": false,
                }
            }
        },
        "required": ["answer", "references"],
        "additionalProperties": false
    })
}
|
||||
265
retrieval-pipeline/src/fts.rs
Normal file
265
retrieval-pipeline/src/fts.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::StoredObject},
|
||||
};
|
||||
|
||||
use crate::scoring::Scored;
|
||||
use common::storage::types::file_info::deserialize_flexible_id;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
/// Row shape for the FTS scoring query: record id plus its combined
/// full-text score.
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
    // Ids may be returned as SurrealDB records or plain strings; the helper
    // normalizes both to a `String`.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // May be absent; callers treat a missing score as 0.
    fts_score: Option<f32>,
}
|
||||
|
||||
/// Executes a full-text search query against SurrealDB and returns scored results.
|
||||
///
|
||||
/// The function expects FTS indexes to exist for the provided table. Currently supports
|
||||
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
|
||||
pub async fn find_items_by_fts<T>(
|
||||
take: usize,
|
||||
query: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
let (filter_clause, score_clause) = match table {
|
||||
"knowledge_entity" => (
|
||||
"(name @0@ $terms OR description @1@ $terms)",
|
||||
"(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
|
||||
(IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
|
||||
),
|
||||
"text_chunk" => (
|
||||
"(chunk @0@ $terms)",
|
||||
"IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
|
||||
),
|
||||
_ => {
|
||||
return Err(AppError::Validation(format!(
|
||||
"FTS not configured for table '{table}'"
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
"SELECT id, {score_clause} AS fts_score \
|
||||
FROM {table} \
|
||||
WHERE {filter_clause} \
|
||||
AND user_id = $user_id \
|
||||
ORDER BY fts_score DESC \
|
||||
LIMIT $limit",
|
||||
table = table,
|
||||
filter_clause = filter_clause,
|
||||
score_clause = score_clause
|
||||
);
|
||||
|
||||
debug!(
|
||||
table = table,
|
||||
limit = take,
|
||||
"Executing FTS query with filter clause: {}",
|
||||
filter_clause
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(sql)
|
||||
.bind(("terms", query.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let score_rows: Vec<FtsScoreRow> = response.take(0)?;
|
||||
|
||||
if score_rows.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids: Vec<String> = score_rows.iter().map(|row| row.id.clone()).collect();
|
||||
let thing_ids: Vec<Thing> = ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((table, id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut items_response = db_client
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", table.to_owned()))
|
||||
.bind(("things", thing_ids.clone()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let items: Vec<T> = items_response.take(0)?;
|
||||
|
||||
let mut item_map: HashMap<String, T> = items
|
||||
.into_iter()
|
||||
.map(|item| (item.get_id().to_owned(), item))
|
||||
.collect();
|
||||
|
||||
let mut results = Vec::with_capacity(score_rows.len());
|
||||
for row in score_rows {
|
||||
if let Some(item) = item_map.remove(&row.id) {
|
||||
let score = row.fts_score.unwrap_or_default();
|
||||
results.push(Scored::new(item).with_fts_score(score));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::{
        knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
        text_chunk::TextChunk,
        StoredObject,
    };
    use uuid::Uuid;

    // Zero vector sized to the production embedding dimension; FTS tests only
    // need the field present, not meaningful.
    fn dummy_embedding() -> Vec<f32> {
        vec![0.0; 1536]
    }

    // A match on `name` alone must still produce a non-NONE combined score.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_name() {
        let namespace = "fts_test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");

        db.apply_migrations()
            .await
            .expect("failed to apply migrations");

        let user_id = "user_fts";
        // Only the name matches the search terms; the description deliberately
        // shares no tokens with the query.
        let entity = KnowledgeEntity::new(
            "source_a".into(),
            "Rustacean handbook".into(),
            "completely unrelated description".into(),
            KnowledgeEntityType::Document,
            None,
            dummy_embedding(),
            user_id.into(),
        );

        db.store_item(entity.clone())
            .await
            .expect("failed to insert entity");

        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");

        let results = find_items_by_fts::<KnowledgeEntity>(
            5,
            "rustacean",
            &db,
            KnowledgeEntity::table_name(),
            user_id,
        )
        .await
        .expect("fts query failed");

        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when only the name matched"
        );
    }

    // Mirror of the test above, matching on `description` only.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_description() {
        let namespace = "fts_test_ns_desc";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");

        db.apply_migrations()
            .await
            .expect("failed to apply migrations");

        let user_id = "user_fts_desc";
        let entity = KnowledgeEntity::new(
            "source_b".into(),
            "neutral name".into(),
            "Detailed notes about async runtimes".into(),
            KnowledgeEntityType::Document,
            None,
            dummy_embedding(),
            user_id.into(),
        );

        db.store_item(entity.clone())
            .await
            .expect("failed to insert entity");

        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");

        let results = find_items_by_fts::<KnowledgeEntity>(
            5,
            "async",
            &db,
            KnowledgeEntity::table_name(),
            user_id,
        )
        .await
        .expect("fts query failed");

        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when only the description matched"
        );
    }

    // `text_chunk` uses a single-field score clause; verify it surfaces too.
    #[tokio::test]
    async fn fts_preserves_scores_for_text_chunks() {
        let namespace = "fts_test_ns_chunks";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("failed to create in-memory surreal");

        db.apply_migrations()
            .await
            .expect("failed to apply migrations");

        let user_id = "user_fts_chunk";
        let chunk = TextChunk::new(
            "source_chunk".into(),
            "GraphQL documentation reference".into(),
            dummy_embedding(),
            user_id.into(),
        );

        db.store_item(chunk.clone())
            .await
            .expect("failed to insert chunk");

        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");

        let results =
            find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
                .await
                .expect("fts query failed");

        assert!(!results.is_empty(), "expected at least one FTS result");
        assert!(
            results[0].scores.fts.is_some(),
            "expected an FTS score when chunk field matched"
        );
    }
}
|
||||
432
retrieval-pipeline/src/graph.rs
Normal file
432
retrieval-pipeline/src/graph.rs
Normal file
@@ -0,0 +1,432 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use surrealdb::{sql::Thing, Error};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
StoredObject,
|
||||
},
|
||||
};
|
||||
|
||||
/// Retrieves database entries that match any of the given source identifiers.
///
/// This function queries the database for all records in a specified table whose
/// `source_id` field is contained in `source_ids` and which belong to `user_id`.
/// It's commonly used to find related entities or track the origin of database
/// entries.
///
/// # Arguments
///
/// * `source_ids` - The source identifiers to search for in the database
/// * `table_name` - The name of the table to search in
/// * `user_id` - Only records owned by this user are returned
/// * `db` - The `SurrealDB` client instance for database operations
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
/// * `Err(Error)` - An error if the database query fails
///
/// # Errors
///
/// This function will return an `Error` if:
/// * The database query fails to execute
/// * The results cannot be deserialized into type `T`
pub async fn find_entities_by_source_ids<T>(
    source_ids: Vec<String>,
    table_name: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
    T: for<'de> serde::Deserialize<'de>,
{
    let query =
        "SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id";

    db.query(query)
        .bind(("table", table_name.to_owned()))
        .bind(("source_ids", source_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?
        .take(0)
}
|
||||
|
||||
/// Finds entities related to `entity_id` through `relates_to` edges, in either
/// direction.
///
/// Neighbor ids are collected in edge order with duplicates removed, the
/// originating entity itself is excluded, and the result is truncated to
/// `limit` entries (a `limit` of 0 means no truncation). The fetched entities
/// are returned in the same order the edges were discovered.
pub async fn find_entities_by_relationship_by_id(
    db: &SurrealDbClient,
    entity_id: &str,
    user_id: &str,
    limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
    // Fetch every edge touching the entity (as source or target) for this user.
    let mut relationships_response = db
        .query(
            "
            SELECT * FROM relates_to
            WHERE metadata.user_id = $user_id
            AND (in = type::thing('knowledge_entity', $entity_id)
            OR out = type::thing('knowledge_entity', $entity_id))
            ",
        )
        .bind(("entity_id", entity_id.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .await?;

    let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
    if relationships.is_empty() {
        return Ok(Vec::new());
    }

    // Collect distinct neighbor ids, preserving the order edges were returned.
    let mut neighbor_ids: Vec<String> = Vec::new();
    let mut seen: HashSet<String> = HashSet::new();
    for rel in relationships {
        if rel.in_ == entity_id {
            // Outgoing edge: the neighbor is the target.
            if seen.insert(rel.out.clone()) {
                neighbor_ids.push(rel.out);
            }
        } else if rel.out == entity_id {
            // Incoming edge: the neighbor is the source.
            if seen.insert(rel.in_.clone()) {
                neighbor_ids.push(rel.in_);
            }
        } else {
            // Defensive: the WHERE clause should make this unreachable, but if
            // an unexpected edge slips through, keep both endpoints.
            if seen.insert(rel.in_.clone()) {
                neighbor_ids.push(rel.in_.clone());
            }
            if seen.insert(rel.out.clone()) {
                neighbor_ids.push(rel.out);
            }
        }
    }

    // Never return the entity itself as its own neighbor.
    neighbor_ids.retain(|id| id != entity_id);

    if neighbor_ids.is_empty() {
        return Ok(Vec::new());
    }

    // limit == 0 is treated as "no limit".
    if limit > 0 && neighbor_ids.len() > limit {
        neighbor_ids.truncate(limit);
    }

    let thing_ids: Vec<Thing> = neighbor_ids
        .iter()
        .map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
        .collect();

    let mut neighbors_response = db
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", KnowledgeEntity::table_name().to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;

    let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
    if neighbors.is_empty() {
        return Ok(Vec::new());
    }

    // Re-order the fetched entities to match the neighbor id order.
    let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
        .into_iter()
        .map(|entity| (entity.id.clone(), entity))
        .collect();

    let mut ordered = Vec::new();
    for id in neighbor_ids {
        if let Some(entity) = neighbor_map.remove(&id) {
            ordered.push(entity);
        }
        if limit > 0 && ordered.len() >= limit {
            break;
        }
    }

    Ok(ordered)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
    use common::storage::types::knowledge_relationship::KnowledgeRelationship;
    use common::storage::types::StoredObject;
    use uuid::Uuid;

    #[tokio::test]
    async fn test_find_entities_by_source_ids() {
        // Setup in-memory database for testing
        let namespace = "test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");

        // Create some test entities with different source_ids
        let source_id1 = "source123".to_string();
        let source_id2 = "source456".to_string();
        let source_id3 = "source789".to_string();

        let entity_type = KnowledgeEntityType::Document;
        let embedding = vec![0.1, 0.2, 0.3];
        let user_id = "user123".to_string();

        // Entity with source_id1
        let entity1 = KnowledgeEntity::new(
            source_id1.clone(),
            "Entity 1".to_string(),
            "Description 1".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Entity with source_id2
        let entity2 = KnowledgeEntity::new(
            source_id2.clone(),
            "Entity 2".to_string(),
            "Description 2".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Another entity with source_id1
        let entity3 = KnowledgeEntity::new(
            source_id1.clone(),
            "Entity 3".to_string(),
            "Description 3".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Entity with source_id3
        let entity4 = KnowledgeEntity::new(
            source_id3.clone(),
            "Entity 4".to_string(),
            "Description 4".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Store all entities
        db.store_item(entity1.clone())
            .await
            .expect("Failed to store entity 1");
        db.store_item(entity2.clone())
            .await
            .expect("Failed to store entity 2");
        db.store_item(entity3.clone())
            .await
            .expect("Failed to store entity 3");
        db.store_item(entity4.clone())
            .await
            .expect("Failed to store entity 4");

        // Test finding entities by multiple source_ids
        let source_ids = vec![source_id1.clone(), source_id2.clone()];
        let found_entities: Vec<KnowledgeEntity> =
            find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
                .await
                .expect("Failed to find entities by source_ids");

        // Should find 3 entities (2 with source_id1, 1 with source_id2)
        assert_eq!(
            found_entities.len(),
            3,
            "Should find 3 entities with the specified source_ids"
        );

        // Check that entities with source_id1 and source_id2 are found
        let found_source_ids: Vec<String> =
            found_entities.iter().map(|e| e.source_id.clone()).collect();
        assert!(
            found_source_ids.contains(&source_id1),
            "Should find entities with source_id1"
        );
        assert!(
            found_source_ids.contains(&source_id2),
            "Should find entities with source_id2"
        );
        assert!(
            !found_source_ids.contains(&source_id3),
            "Should not find entities with source_id3"
        );

        // Test finding entities by a single source_id
        let single_source_id = vec![source_id1.clone()];
        let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
            single_source_id,
            KnowledgeEntity::table_name(),
            &user_id,
            &db,
        )
        .await
        .expect("Failed to find entities by single source_id");

        // Should find 2 entities with source_id1
        assert_eq!(
            found_entities.len(),
            2,
            "Should find 2 entities with source_id1"
        );

        // Check that all found entities have source_id1
        for entity in found_entities {
            assert_eq!(
                entity.source_id, source_id1,
                "All found entities should have source_id1"
            );
        }

        // Test finding entities with non-existent source_id
        let non_existent_source_id = vec!["non_existent_source".to_string()];
        let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
            non_existent_source_id,
            KnowledgeEntity::table_name(),
            &user_id,
            &db,
        )
        .await
        .expect("Failed to find entities by non-existent source_id");

        // Should find 0 entities
        assert_eq!(
            found_entities.len(),
            0,
            "Should find 0 entities with non-existent source_id"
        );
    }

    #[tokio::test]
    async fn test_find_entities_by_relationship_by_id() {
        // Setup in-memory database for testing
        let namespace = "test_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");

        // Create some test entities
        let entity_type = KnowledgeEntityType::Document;
        let embedding = vec![0.1, 0.2, 0.3];
        let user_id = "user123".to_string();

        // Create the central entity we'll query relationships for
        let central_entity = KnowledgeEntity::new(
            "central_source".to_string(),
            "Central Entity".to_string(),
            "Central Description".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Create related entities
        let related_entity1 = KnowledgeEntity::new(
            "related_source1".to_string(),
            "Related Entity 1".to_string(),
            "Related Description 1".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        let related_entity2 = KnowledgeEntity::new(
            "related_source2".to_string(),
            "Related Entity 2".to_string(),
            "Related Description 2".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Create an unrelated entity
        let unrelated_entity = KnowledgeEntity::new(
            "unrelated_source".to_string(),
            "Unrelated Entity".to_string(),
            "Unrelated Description".to_string(),
            entity_type.clone(),
            None,
            embedding.clone(),
            user_id.clone(),
        );

        // Store all entities, keeping the stored copies so we use DB-assigned ids.
        let central_entity = db
            .store_item(central_entity.clone())
            .await
            .expect("Failed to store central entity")
            .unwrap();
        let related_entity1 = db
            .store_item(related_entity1.clone())
            .await
            .expect("Failed to store related entity 1")
            .unwrap();
        let related_entity2 = db
            .store_item(related_entity2.clone())
            .await
            .expect("Failed to store related entity 2")
            .unwrap();
        let _unrelated_entity = db
            .store_item(unrelated_entity.clone())
            .await
            .expect("Failed to store unrelated entity")
            .unwrap();

        // Create relationships
        let source_id = "relationship_source".to_string();

        // Create relationship 1: central -> related1
        let relationship1 = KnowledgeRelationship::new(
            central_entity.id.clone(),
            related_entity1.id.clone(),
            user_id.clone(),
            source_id.clone(),
            "references".to_string(),
        );

        // Create relationship 2: central -> related2
        let relationship2 = KnowledgeRelationship::new(
            central_entity.id.clone(),
            related_entity2.id.clone(),
            user_id.clone(),
            source_id.clone(),
            "contains".to_string(),
        );

        // Store relationships
        relationship1
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship 1");
        relationship2
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship 2");

        // Test finding entities related to the central entity
        let related_entities =
            find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
                .await
                .expect("Failed to find entities by relationship");

        // Check that we found relationships
        assert!(
            related_entities.len() >= 2,
            "Should find related entities in both directions"
        );
    }
}
|
||||
336
retrieval-pipeline/src/lib.rs
Normal file
336
retrieval-pipeline/src/lib.rs
Normal file
@@ -0,0 +1,336 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
pub mod fts;
|
||||
pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod reranking;
|
||||
pub mod scoring;
|
||||
pub mod vector;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
};
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
pub use pipeline::{
|
||||
retrieved_entities_to_json, PipelineDiagnostics, PipelineStageTimings, RetrievalConfig,
|
||||
RetrievalStrategy, RetrievalTuning, StrategyOutput,
|
||||
};
|
||||
|
||||
/// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
    /// The supporting text chunk.
    pub chunk: TextChunk,
    /// Fused score assigned by the retrieval pipeline.
    pub score: f32,
}
|
||||
|
||||
/// Final entity representation returned to callers, enriched with ranked chunks.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
    /// The retrieved knowledge entity.
    pub entity: KnowledgeEntity,
    /// Fused score assigned by the retrieval pipeline.
    pub score: f32,
    /// Supporting chunks associated with this entity.
    pub chunks: Vec<RetrievedChunk>,
}
|
||||
|
||||
/// Primary orchestrator for retrieving `KnowledgeEntity` records related to `input_text`.
///
/// Thin wrapper that delegates to [`pipeline::run_pipeline`] with the given
/// configuration and optional reranker lease.
// NOTE(review): `fields(user_id)` declares the span field but no value is
// recorded here — confirm it is recorded elsewhere or record it explicitly.
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
    db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> {
    pipeline::run_pipeline(
        db_client,
        openai_client,
        input_text,
        user_id,
        config,
        reranker,
    )
    .await
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_openai::Client;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
};
|
||||
use pipeline::{RetrievalConfig, RetrievalStrategy};
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
vec![0.9, 0.1, 0.0]
|
||||
}
|
||||
|
||||
fn entity_embedding_high() -> Vec<f32> {
|
||||
vec![0.8, 0.2, 0.0]
|
||||
}
|
||||
|
||||
fn entity_embedding_low() -> Vec<f32> {
|
||||
vec![0.1, 0.9, 0.0]
|
||||
}
|
||||
|
||||
fn chunk_embedding_primary() -> Vec<f32> {
|
||||
vec![0.85, 0.15, 0.0]
|
||||
}
|
||||
|
||||
fn chunk_embedding_secondary() -> Vec<f32> {
|
||||
vec![0.2, 0.8, 0.0]
|
||||
}
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db.query(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.expect("Failed to configure indices");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retrieve_entities_with_embedding_basic_flow() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "test_user";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_1".into(),
|
||||
"Rust async guide".into(),
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_high(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk = TextChunk::new(
|
||||
entity.source_id.clone(),
|
||||
"Tokio uses cooperative scheduling for fairness.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
let entities = match results {
|
||||
StrategyOutput::Entities(items) => items,
|
||||
other => panic!("expected entity results, got {:?}", other),
|
||||
};
|
||||
|
||||
assert!(
|
||||
!entities.is_empty(),
|
||||
"Expected at least one retrieval result"
|
||||
);
|
||||
let top = &entities[0];
|
||||
assert!(
|
||||
top.entity.name.contains("Rust"),
|
||||
"Expected Rust entity to be ranked first"
|
||||
);
|
||||
assert!(
|
||||
!top.chunks.is_empty(),
|
||||
"Expected Rust entity to include supporting chunks"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
/// Verifies that a graph relationship pulls a low-similarity neighbor entity
/// into the result set and that the neighbor carries only its own chunks.
async fn test_graph_relationship_enriches_results() {
    let db = setup_test_db().await;
    let user_id = "graph_user";

    // Primary entity gets the high embedding (strong vector match for the
    // query); the neighbor deliberately gets the low one so that only graph
    // expansion can surface it.
    let primary = KnowledgeEntity::new(
        "primary_source".into(),
        "Async Rust patterns".into(),
        "Explores async runtimes and scheduling strategies.".into(),
        KnowledgeEntityType::Document,
        None,
        entity_embedding_high(),
        user_id.into(),
    );
    let neighbor = KnowledgeEntity::new(
        "neighbor_source".into(),
        "Tokio Scheduler Deep Dive".into(),
        "Details on Tokio's cooperative scheduler.".into(),
        KnowledgeEntityType::Document,
        None,
        entity_embedding_low(),
        user_id.into(),
    );

    db.store_item(primary.clone())
        .await
        .expect("Failed to store primary entity");
    db.store_item(neighbor.clone())
        .await
        .expect("Failed to store neighbor entity");

    // One supporting chunk per entity, keyed by the entity's source_id.
    let primary_chunk = TextChunk::new(
        primary.source_id.clone(),
        "Rust async tasks use Tokio's cooperative scheduler.".into(),
        chunk_embedding_primary(),
        user_id.into(),
    );
    let neighbor_chunk = TextChunk::new(
        neighbor.source_id.clone(),
        "Tokio's scheduler manages task fairness across executors.".into(),
        chunk_embedding_secondary(),
        user_id.into(),
    );

    db.store_item(primary_chunk)
        .await
        .expect("Failed to store primary chunk");
    db.store_item(neighbor_chunk)
        .await
        .expect("Failed to store neighbor chunk");

    let openai_client = Client::new();
    // Link primary -> neighbor so the graph-expansion stage can traverse it.
    let relationship = KnowledgeRelationship::new(
        primary.id.clone(),
        neighbor.id.clone(),
        user_id.into(),
        "relationship_source".into(),
        "references".into(),
    );
    relationship
        .store_relationship(&db)
        .await
        .expect("Failed to store relationship");

    // Default config uses the Initial (entity-centric) strategy.
    let results = pipeline::run_pipeline_with_embedding(
        &db,
        &openai_client,
        test_embedding(),
        "Rust concurrency async tasks",
        user_id,
        RetrievalConfig::default(),
        None,
    )
    .await
    .expect("Hybrid retrieval failed");

    let entities = match results {
        StrategyOutput::Entities(items) => items,
        other => panic!("expected entity results, got {:?}", other),
    };

    // Locate the graph-enriched neighbor among the retrieved entities.
    let mut neighbor_entry = None;
    for entity in &entities {
        if entity.entity.id == neighbor.id {
            neighbor_entry = Some(entity.clone());
        }
    }

    let neighbor_entry =
        neighbor_entry.expect("Graph-enriched neighbor should appear in results");

    assert!(
        neighbor_entry.score > 0.2,
        "Graph-enriched entity should have a meaningful fused score"
    );
    // Chunks attached to the neighbor must come from the neighbor's own
    // source, not from the primary entity.
    assert!(
        neighbor_entry
            .chunks
            .iter()
            .all(|chunk| chunk.chunk.source_id == neighbor.source_id),
        "Neighbor entity should surface its own supporting chunks"
    );
}
|
||||
|
||||
#[tokio::test]
/// Verifies that the Revised strategy yields chunk-only output
/// (StrategyOutput::Chunks) containing relevant snippets.
async fn test_revised_strategy_returns_chunks() {
    let db = setup_test_db().await;
    let user_id = "chunk_user";
    // Two standalone chunks with no entities at all — the revised strategy
    // must be able to retrieve them directly.
    let chunk_one = TextChunk::new(
        "src_alpha".into(),
        "Tokio tasks execute on worker threads managed by the runtime.".into(),
        chunk_embedding_primary(),
        user_id.into(),
    );
    let chunk_two = TextChunk::new(
        "src_beta".into(),
        "Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
        chunk_embedding_secondary(),
        user_id.into(),
    );

    db.store_item(chunk_one.clone())
        .await
        .expect("Failed to store chunk one");
    db.store_item(chunk_two.clone())
        .await
        .expect("Failed to store chunk two");

    // Explicitly select the chunk-only strategy.
    let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised);
    let openai_client = Client::new();
    let results = pipeline::run_pipeline_with_embedding(
        &db,
        &openai_client,
        test_embedding(),
        "tokio runtime worker behavior",
        user_id,
        config,
        None,
    )
    .await
    .expect("Revised retrieval failed");

    let chunks = match results {
        StrategyOutput::Chunks(items) => items,
        other => panic!("expected chunk output, got {:?}", other),
    };

    assert!(
        !chunks.is_empty(),
        "Revised strategy should return chunk-only responses"
    );
    assert!(
        chunks
            .iter()
            .any(|entry| entry.chunk.chunk.contains("Tokio")),
        "Chunk results should contain relevant snippets"
    );
}
|
||||
}
|
||||
121
retrieval-pipeline/src/pipeline/config.rs
Normal file
121
retrieval-pipeline/src/pipeline/config.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RetrievalStrategy {
|
||||
Initial,
|
||||
Revised,
|
||||
}
|
||||
|
||||
impl Default for RetrievalStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Initial
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RetrievalStrategy {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"initial" => Ok(Self::Initial),
|
||||
"revised" => Ok(Self::Revised),
|
||||
other => Err(format!("unknown retrieval strategy '{other}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RetrievalStrategy {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let label = match self {
|
||||
RetrievalStrategy::Initial => "initial",
|
||||
RetrievalStrategy::Revised => "revised",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
///
/// NOTE(review): field semantics below are inferred from the field names;
/// the consuming code lives in `pipeline/stages` — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
    /// Entity candidates taken from vector search.
    pub entity_vector_take: usize,
    /// Chunk candidates taken from vector search.
    pub chunk_vector_take: usize,
    /// Entity candidates taken from full-text search.
    pub entity_fts_take: usize,
    /// Chunk candidates taken from full-text search.
    pub chunk_fts_take: usize,
    /// Minimum score a candidate must reach to be kept.
    pub score_threshold: f32,
    /// Lower bound on result count before fallback logic kicks in.
    pub fallback_min_results: usize,
    /// Approximate token budget for the assembled context.
    pub token_budget_estimate: usize,
    /// Characters-per-token ratio used when estimating token counts.
    pub avg_chars_per_token: usize,
    /// Cap on supporting chunks attached to a single entity.
    pub max_chunks_per_entity: usize,
    /// Weight given to lexical (FTS) matches.
    pub lexical_match_weight: f32,
    /// Maximum seeds used to start graph traversal.
    pub graph_traversal_seed_limit: usize,
    /// Maximum neighbors pulled in per graph seed.
    pub graph_neighbor_limit: usize,
    /// Decay applied to scores propagated along graph edges.
    pub graph_score_decay: f32,
    /// Minimum score required for a candidate to act as a graph seed.
    pub graph_seed_min_score: f32,
    /// Fraction of a seed's vector score inherited by its neighbors.
    pub graph_vector_inheritance: f32,
    /// Blend weight between reranker output and the fused score.
    pub rerank_blend_weight: f32,
    /// When true, reranker scores replace fused scores entirely.
    pub rerank_scores_only: bool,
    /// Number of top results kept after reranking.
    pub rerank_keep_top: usize,
}

impl Default for RetrievalTuning {
    /// Default tuning used by `RetrievalConfig::default()`.
    fn default() -> Self {
        Self {
            entity_vector_take: 15,
            chunk_vector_take: 20,
            entity_fts_take: 10,
            chunk_fts_take: 20,
            score_threshold: 0.35,
            fallback_min_results: 10,
            token_budget_estimate: 10000,
            avg_chars_per_token: 4,
            max_chunks_per_entity: 4,
            lexical_match_weight: 0.15,
            graph_traversal_seed_limit: 5,
            graph_neighbor_limit: 6,
            graph_score_decay: 0.75,
            graph_seed_min_score: 0.4,
            graph_vector_inheritance: 0.6,
            rerank_blend_weight: 0.65,
            rerank_scores_only: false,
            rerank_keep_top: 8,
        }
    }
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalConfig {
|
||||
pub strategy: RetrievalStrategy,
|
||||
pub tuning: RetrievalTuning,
|
||||
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::Initial,
|
||||
tuning,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_strategy(strategy: RetrievalStrategy) -> Self {
|
||||
Self {
|
||||
strategy,
|
||||
tuning: RetrievalTuning::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tuning(strategy: RetrievalStrategy, tuning: RetrievalTuning) -> Self {
|
||||
Self { strategy, tuning }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: RetrievalStrategy::default(),
|
||||
tuning: RetrievalTuning::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
51
retrieval-pipeline/src/pipeline/diagnostics.rs
Normal file
51
retrieval-pipeline/src/pipeline/diagnostics.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use serde::Serialize;
|
||||
|
||||
/// Captures instrumentation for each hybrid retrieval stage when diagnostics are enabled.
///
/// Each field is `None` when the corresponding stage did not run (or did not
/// record stats) during the request.
#[derive(Debug, Clone, Default, Serialize)]
pub struct PipelineDiagnostics {
    /// Stats from candidate collection (vector + FTS search).
    pub collect_candidates: Option<CollectCandidatesStats>,
    /// Stats from enriching chunk candidates via their entities.
    pub enrich_chunks_from_entities: Option<ChunkEnrichmentStats>,
    /// Stats from the final context-assembly stage.
    pub assemble: Option<AssembleStats>,
}

/// Candidate counts and raw chunk scores produced by vector and FTS search.
#[derive(Debug, Clone, Default, Serialize)]
pub struct CollectCandidatesStats {
    pub vector_entity_candidates: usize,
    pub vector_chunk_candidates: usize,
    pub fts_entity_candidates: usize,
    pub fts_chunk_candidates: usize,
    /// Raw scores of chunk candidates found by vector search.
    pub vector_chunk_scores: Vec<f32>,
    /// Raw scores of chunk candidates found by full-text search.
    pub fts_chunk_scores: Vec<f32>,
}

/// Before/after counts describing how chunk enrichment reshaped the candidate set.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChunkEnrichmentStats {
    pub filtered_entity_count: usize,
    pub fallback_min_results: usize,
    pub chunk_sources_considered: usize,
    pub chunk_candidates_before_enrichment: usize,
    pub chunk_candidates_after_enrichment: usize,
    pub top_chunk_scores: Vec<f32>,
}

/// Token-budget accounting plus per-entity traces for the assembly stage.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AssembleStats {
    pub token_budget_start: usize,
    pub token_budget_spent: usize,
    pub token_budget_remaining: usize,
    /// True when the budget ran out before all candidate chunks were placed.
    pub budget_exhausted: bool,
    pub chunks_selected: usize,
    pub chunks_skipped_due_budget: usize,
    pub entity_count: usize,
    pub entity_traces: Vec<EntityAssemblyTrace>,
}

/// Per-entity record of which chunks were inspected and selected during assembly.
#[derive(Debug, Clone, Default, Serialize)]
pub struct EntityAssemblyTrace {
    pub entity_id: String,
    pub source_id: String,
    pub inspected_candidates: usize,
    pub selected_chunk_ids: Vec<String>,
    pub selected_chunk_scores: Vec<f32>,
    pub skipped_due_budget: usize,
}
|
||||
397
retrieval-pipeline/src/pipeline/mod.rs
Normal file
397
retrieval-pipeline/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
||||
mod config;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod strategies;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalStrategy, RetrievalTuning};
|
||||
pub use diagnostics::{
|
||||
AssembleStats, ChunkEnrichmentStats, CollectCandidatesStats, EntityAssemblyTrace,
|
||||
PipelineDiagnostics,
|
||||
};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use stages::PipelineContext;
|
||||
use strategies::{InitialStrategyDriver, RevisedStrategyDriver};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StrategyOutput {
|
||||
Entities(Vec<RetrievedEntity>),
|
||||
Chunks(Vec<RetrievedChunk>),
|
||||
}
|
||||
|
||||
impl StrategyOutput {
|
||||
pub fn as_entities(&self) -> Option<&[RetrievedEntity]> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_entities(self) -> Option<Vec<RetrievedEntity>> {
|
||||
match self {
|
||||
StrategyOutput::Entities(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_chunks(&self) -> Option<&[RetrievedChunk]> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_chunks(self) -> Option<Vec<RetrievedChunk>> {
|
||||
match self {
|
||||
StrategyOutput::Chunks(items) => Some(items),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A pipeline run's results bundled with optional diagnostics and per-stage timings.
#[derive(Debug)]
pub struct PipelineRunOutput<T> {
    /// The strategy's typed results.
    pub results: T,
    /// Populated only when the run was started with diagnostics enabled.
    pub diagnostics: Option<PipelineDiagnostics>,
    /// Wall-clock milliseconds accumulated per stage kind.
    pub stage_timings: PipelineStageTimings,
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StageKind {
|
||||
Embed,
|
||||
CollectCandidates,
|
||||
GraphExpansion,
|
||||
ChunkAttach,
|
||||
Rerank,
|
||||
Assemble,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize)]
|
||||
pub struct PipelineStageTimings {
|
||||
pub embed_ms: u128,
|
||||
pub collect_candidates_ms: u128,
|
||||
pub graph_expansion_ms: u128,
|
||||
pub chunk_attach_ms: u128,
|
||||
pub rerank_ms: u128,
|
||||
pub assemble_ms: u128,
|
||||
}
|
||||
|
||||
impl PipelineStageTimings {
|
||||
pub fn record(&mut self, kind: StageKind, duration: Duration) {
|
||||
let elapsed = duration.as_millis() as u128;
|
||||
match kind {
|
||||
StageKind::Embed => self.embed_ms += elapsed,
|
||||
StageKind::CollectCandidates => self.collect_candidates_ms += elapsed,
|
||||
StageKind::GraphExpansion => self.graph_expansion_ms += elapsed,
|
||||
StageKind::ChunkAttach => self.chunk_attach_ms += elapsed,
|
||||
StageKind::Rerank => self.rerank_ms += elapsed,
|
||||
StageKind::Assemble => self.assemble_ms += elapsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single stage of the retrieval pipeline, executed in driver-defined order.
#[async_trait]
pub trait PipelineStage: Send + Sync {
    /// Identifies this stage for timing attribution.
    fn kind(&self) -> StageKind;
    /// Runs the stage, reading and mutating shared pipeline state in `ctx`.
    async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError>;
}

/// Owned, object-safe stage handle used to assemble a strategy's stage list.
pub type BoxedStage = Box<dyn PipelineStage + Send + Sync>;

/// Describes a retrieval strategy: which stages run and how typed results are
/// extracted from the context afterwards.
pub trait StrategyDriver {
    /// Final result type produced by [`StrategyDriver::finalize`].
    type Output;

    /// The strategy variant this driver implements.
    fn strategy(&self) -> RetrievalStrategy;
    /// Ordered list of stages to execute.
    fn stages(&self) -> Vec<BoxedStage>;
    /// Hook letting a driver adjust tuning before the run; default is a no-op.
    fn override_tuning(&self, _config: &mut RetrievalConfig) {}

    /// Extracts the strategy's results from the context after all stages ran.
    fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError>;
}
|
||||
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<StrategyOutput, AppError> {
|
||||
if config.strategy == RetrievalStrategy::Initial {
|
||||
let driver = InitialStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
return Ok(StrategyOutput::Entities(run.results));
|
||||
}
|
||||
|
||||
let driver = RevisedStrategyDriver::new();
|
||||
let run = execute_strategy(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(StrategyOutput::Chunks(run.results))
|
||||
}
|
||||
|
||||
/// Runs the retrieval pipeline with a caller-supplied query embedding and
/// returns results together with per-stage timings.
///
/// Diagnostics capture stays disabled (`capture_diagnostics = false`); use
/// `run_pipeline_with_embedding_with_diagnostics` for full instrumentation.
pub async fn run_pipeline_with_embedding_with_metrics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
    // Entity-centric path: wrap results in StrategyOutput::Entities.
    if config.strategy == RetrievalStrategy::Initial {
        let driver = InitialStrategyDriver::new();
        let run = execute_strategy(
            driver,
            db_client,
            openai_client,
            Some(query_embedding),
            input_text,
            user_id,
            config,
            reranker,
            false,
        )
        .await?;
        return Ok(PipelineRunOutput {
            results: StrategyOutput::Entities(run.results),
            diagnostics: run.diagnostics,
            stage_timings: run.stage_timings,
        });
    }

    // Chunk-only (Revised) path.
    let driver = RevisedStrategyDriver::new();
    let run = execute_strategy(
        driver,
        db_client,
        openai_client,
        Some(query_embedding),
        input_text,
        user_id,
        config,
        reranker,
        false,
    )
    .await?;
    Ok(PipelineRunOutput {
        results: StrategyOutput::Chunks(run.results),
        diagnostics: run.diagnostics,
        stage_timings: run.stage_timings,
    })
}
|
||||
|
||||
/// Runs the retrieval pipeline with a caller-supplied query embedding and
/// full diagnostics capture enabled.
///
/// Identical to `run_pipeline_with_embedding_with_metrics` except that
/// `capture_diagnostics` is `true`, so `diagnostics` in the output is
/// populated by the stages that record stats.
pub async fn run_pipeline_with_embedding_with_diagnostics(
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Vec<f32>,
    input_text: &str,
    user_id: &str,
    config: RetrievalConfig,
    reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
    // Entity-centric path: wrap results in StrategyOutput::Entities.
    if config.strategy == RetrievalStrategy::Initial {
        let driver = InitialStrategyDriver::new();
        let run = execute_strategy(
            driver,
            db_client,
            openai_client,
            Some(query_embedding),
            input_text,
            user_id,
            config,
            reranker,
            true,
        )
        .await?;
        return Ok(PipelineRunOutput {
            results: StrategyOutput::Entities(run.results),
            diagnostics: run.diagnostics,
            stage_timings: run.stage_timings,
        });
    }

    // Chunk-only (Revised) path.
    let driver = RevisedStrategyDriver::new();
    let run = execute_strategy(
        driver,
        db_client,
        openai_client,
        Some(query_embedding),
        input_text,
        user_id,
        config,
        reranker,
        true,
    )
    .await?;
    Ok(PipelineRunOutput {
        results: StrategyOutput::Chunks(run.results),
        diagnostics: run.diagnostics,
        stage_timings: run.stage_timings,
    })
}
|
||||
|
||||
/// Serializes retrieved entities (with their supporting chunks) into the JSON
/// shape consumed downstream, rounding scores to three decimal places.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
    let items: Vec<serde_json::Value> = entities
        .iter()
        .map(|entry| {
            // Build the chunk list first so the entity object stays readable.
            let chunks: Vec<serde_json::Value> = entry
                .chunks
                .iter()
                .map(|chunk| {
                    serde_json::json!({
                        "score": round_score(chunk.score),
                        "content": chunk.chunk.chunk
                    })
                })
                .collect();
            serde_json::json!({
                "KnowledgeEntity": {
                    "id": entry.entity.id,
                    "name": entry.entity.name,
                    "description": entry.entity.description,
                    "score": round_score(entry.score),
                    "chunks": chunks
                }
            })
        })
        .collect();
    serde_json::Value::Array(items)
}
|
||||
|
||||
/// Builds a `PipelineContext` — with or without a precomputed query embedding —
/// and runs the given driver's stages over it.
///
/// `capture_diagnostics` toggles per-stage instrumentation in the context.
async fn execute_strategy<D: StrategyDriver>(
    driver: D,
    db_client: &SurrealDbClient,
    openai_client: &Client<async_openai::config::OpenAIConfig>,
    query_embedding: Option<Vec<f32>>,
    input_text: &str,
    user_id: &str,
    mut config: RetrievalConfig,
    reranker: Option<RerankerLease>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
    // Let the driver adjust tuning before the context is built (e.g. the
    // revised strategy zeroes out the entity takes).
    driver.override_tuning(&mut config);
    let ctx = match query_embedding {
        // Caller supplied the embedding — presumably the embed stage can then
        // skip its own embedding call; confirm in stages/mod.rs.
        Some(embedding) => PipelineContext::with_embedding(
            db_client,
            openai_client,
            embedding,
            input_text.to_owned(),
            user_id.to_owned(),
            config,
            reranker,
        ),
        None => PipelineContext::new(
            db_client,
            openai_client,
            input_text.to_owned(),
            user_id.to_owned(),
            config,
            reranker,
        ),
    };

    run_with_driver(driver, ctx, capture_diagnostics).await
}
|
||||
|
||||
/// Executes each of the driver's stages in order, timing every stage, then
/// lets the driver extract its typed results from the context.
async fn run_with_driver<D: StrategyDriver>(
    driver: D,
    mut ctx: PipelineContext<'_>,
    capture_diagnostics: bool,
) -> Result<PipelineRunOutput<D::Output>, AppError> {
    if capture_diagnostics {
        ctx.enable_diagnostics();
    }

    for stage in driver.stages() {
        let start = Instant::now();
        stage.execute(&mut ctx).await?;
        // Attribute wall-clock time to the stage's kind; the context
        // accumulates these into PipelineStageTimings.
        ctx.record_stage_duration(stage.kind(), start.elapsed());
    }

    // Take instrumentation out of the context before finalize consumes the
    // result state.
    let diagnostics = ctx.take_diagnostics();
    let stage_timings = ctx.take_stage_timings();
    let results = driver.finalize(&mut ctx)?;

    Ok(PipelineRunOutput {
        results,
        diagnostics,
        stage_timings,
    })
}
|
||||
|
||||
/// Rounds a score to three decimal places (as f64) for JSON output.
fn round_score(value: f32) -> f64 {
    let scaled = f64::from(value) * 1000.0;
    scaled.round() / 1000.0
}
|
||||
1290
retrieval-pipeline/src/pipeline/stages/mod.rs
Normal file
1290
retrieval-pipeline/src/pipeline/stages/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
75
retrieval-pipeline/src/pipeline/strategies.rs
Normal file
75
retrieval-pipeline/src/pipeline/strategies.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use super::{
|
||||
stages::{
|
||||
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage,
|
||||
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext,
|
||||
RerankStage,
|
||||
},
|
||||
BoxedStage, RetrievalConfig, RetrievalStrategy, StrategyDriver,
|
||||
};
|
||||
use crate::{RetrievedChunk, RetrievedEntity};
|
||||
use common::error::AppError;
|
||||
|
||||
pub struct InitialStrategyDriver;
|
||||
|
||||
impl InitialStrategyDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for InitialStrategyDriver {
|
||||
type Output = Vec<RetrievedEntity>;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy {
|
||||
RetrievalStrategy::Initial
|
||||
}
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(CollectCandidatesStage),
|
||||
Box::new(GraphExpansionStage),
|
||||
Box::new(ChunkAttachStage),
|
||||
Box::new(RerankStage),
|
||||
Box::new(AssembleEntitiesStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_entity_results())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RevisedStrategyDriver;
|
||||
|
||||
impl RevisedStrategyDriver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl StrategyDriver for RevisedStrategyDriver {
|
||||
type Output = Vec<RetrievedChunk>;
|
||||
|
||||
fn strategy(&self) -> RetrievalStrategy {
|
||||
RetrievalStrategy::Revised
|
||||
}
|
||||
|
||||
fn stages(&self) -> Vec<BoxedStage> {
|
||||
vec![
|
||||
Box::new(EmbedStage),
|
||||
Box::new(ChunkVectorStage),
|
||||
Box::new(ChunkRerankStage),
|
||||
Box::new(ChunkAssembleStage),
|
||||
]
|
||||
}
|
||||
|
||||
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
|
||||
Ok(ctx.take_chunk_results())
|
||||
}
|
||||
|
||||
fn override_tuning(&self, config: &mut RetrievalConfig) {
|
||||
config.tuning.entity_vector_take = 0;
|
||||
config.tuning.entity_fts_take = 0;
|
||||
}
|
||||
}
|
||||
170
retrieval-pipeline/src/reranking/mod.rs
Normal file
170
retrieval-pipeline/src/reranking/mod.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use std::{
|
||||
env, fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
thread::available_parallelism,
|
||||
};
|
||||
|
||||
use common::{error::AppError, utils::config::AppConfig};
|
||||
use fastembed::{RerankInitOptions, RerankResult, TextRerank};
|
||||
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
|
||||
use tracing::debug;
|
||||
|
||||
/// Global round-robin counter for distributing checkouts across pool engines.
static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);

/// Returns the next engine index in round-robin order (counter mod pool size).
fn pick_engine_index(pool_len: usize) -> usize {
    NEXT_ENGINE.fetch_add(1, Ordering::Relaxed) % pool_len
}
|
||||
|
||||
/// Fixed-size pool of rerank engines with semaphore-based backpressure.
pub struct RerankerPool {
    /// One mutex-guarded engine per pool slot.
    engines: Vec<Arc<Mutex<TextRerank>>>,
    /// Counts free capacity; one permit per engine, acquired at checkout.
    semaphore: Arc<Semaphore>,
}
|
||||
|
||||
impl RerankerPool {
    /// Build the pool at startup.
    /// `pool_size` controls max parallel reranks.
    pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
        Self::new_with_options(pool_size, RerankInitOptions::default())
    }

    /// Shared constructor: validates the size, ensures the model cache dir
    /// exists, then instantiates `pool_size` independent engines.
    fn new_with_options(
        pool_size: usize,
        init_options: RerankInitOptions,
    ) -> Result<Arc<Self>, AppError> {
        if pool_size == 0 {
            return Err(AppError::Validation(
                "RERANKING_POOL_SIZE must be greater than zero".to_string(),
            ));
        }

        fs::create_dir_all(&init_options.cache_dir)?;

        let mut engines = Vec::with_capacity(pool_size);
        for x in 0..pool_size {
            debug!("Creating reranking engine: {x}");
            // Each engine gets its own model instance (cloned options).
            let model = TextRerank::try_new(init_options.clone())
                .map_err(|e| AppError::InternalError(e.to_string()))?;
            engines.push(Arc::new(Mutex::new(model)));
        }

        Ok(Arc::new(Self {
            engines,
            // One permit per engine: checkout blocks when all are leased.
            semaphore: Arc::new(Semaphore::new(pool_size)),
        }))
    }

    /// Initialize a pool using application configuration.
    /// Returns `Ok(None)` when reranking is disabled in the config.
    pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
        if !config.reranking_enabled {
            return Ok(None);
        }

        let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);

        let init_options = build_rerank_init_options(config)?;
        Self::new_with_options(pool_size, init_options).map(Some)
    }

    /// Check out capacity + pick an engine.
    /// This returns a lease that can perform rerank().
    pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
        // Acquire a permit. This enforces backpressure.
        let permit = self
            .semaphore
            .clone()
            .acquire_owned()
            .await
            // Panics only if the semaphore was closed, which this type never does.
            .expect("semaphore closed");

        // Pick an engine.
        // This is naive: just pick based on a simple modulo counter.
        // We use an atomic counter to avoid always choosing index 0.
        let idx = pick_engine_index(self.engines.len());
        let engine = self.engines[idx].clone();

        RerankerLease {
            _permit: permit,
            engine,
        }
    }
}
|
||||
|
||||
/// Default reranker pool size: available parallelism capped at 2.
///
/// The previous trailing `.max(1)` was dead code: `available_parallelism`
/// yields a `NonZeroUsize` (≥ 1), so `get().min(2)` is already ≥ 1, and the
/// `unwrap_or(2)` fallback is as well.
fn default_pool_size() -> usize {
    available_parallelism()
        .map(|value| value.get().min(2))
        .unwrap_or(2)
}
|
||||
|
||||
/// Interprets common affirmative strings ("1", "true", "yes", "on",
/// case-insensitive, surrounding whitespace ignored) as true.
fn is_truthy(value: &str) -> bool {
    let normalized = value.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "1" | "true" | "yes" | "on" => true,
        _ => false,
    }
}
|
||||
|
||||
/// Builds fastembed rerank init options from app config with env-var fallbacks.
///
/// Resolution order for each setting: explicit config value, then the
/// `RERANKING_*` env var, then the `FASTEMBED_*` env var, then a default.
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
    let mut options = RerankInitOptions::default();

    // Cache dir: config > RERANKING_CACHE_DIR > FASTEMBED_CACHE_DIR
    // > <data_dir>/fastembed/reranker.
    let cache_dir = config
        .fastembed_cache_dir
        .as_ref()
        .map(PathBuf::from)
        .or_else(|| env::var("RERANKING_CACHE_DIR").ok().map(PathBuf::from))
        .or_else(|| env::var("FASTEMBED_CACHE_DIR").ok().map(PathBuf::from))
        .unwrap_or_else(|| {
            Path::new(&config.data_dir)
                .join("fastembed")
                .join("reranker")
        });
    // Ensure the directory exists before handing it to the engine.
    fs::create_dir_all(&cache_dir)?;
    options.cache_dir = cache_dir;

    // Model-download progress display defaults to on.
    let show_progress = config
        .fastembed_show_download_progress
        .or_else(|| env_bool("RERANKING_SHOW_DOWNLOAD_PROGRESS"))
        .or_else(|| env_bool("FASTEMBED_SHOW_DOWNLOAD_PROGRESS"))
        .unwrap_or(true);
    options.show_download_progress = show_progress;

    // Optional max sequence length override; unparsable env values are ignored.
    if let Some(max_length) = config.fastembed_max_length.or_else(|| {
        env::var("RERANKING_MAX_LENGTH")
            .ok()
            .and_then(|value| value.parse().ok())
    }) {
        options.max_length = max_length;
    }

    Ok(options)
}
|
||||
|
||||
fn env_bool(key: &str) -> Option<bool> {
|
||||
env::var(key).ok().map(|value| is_truthy(&value))
|
||||
}
|
||||
|
||||
/// Active lease on a single TextRerank instance.
pub struct RerankerLease {
    // When this drops the semaphore permit is released.
    _permit: OwnedSemaphorePermit,
    /// The pooled engine this lease may lock and use.
    engine: Arc<Mutex<TextRerank>>,
}
|
||||
|
||||
impl RerankerLease {
    /// Reranks `documents` against `query` on the leased engine.
    ///
    /// # Errors
    /// Engine failures are mapped into `AppError::InternalError`.
    pub async fn rerank(
        &self,
        query: &str,
        documents: Vec<String>,
    ) -> Result<Vec<RerankResult>, AppError> {
        // Lock this specific engine so we get &mut TextRerank
        let mut guard = self.engine.lock().await;

        guard
            .rerank(query.to_owned(), documents, false, None)
            .map_err(|e| AppError::InternalError(e.to_string()))
    }
}
|
||||
183
retrieval-pipeline/src/scoring.rs
Normal file
183
retrieval-pipeline/src/scoring.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
|
||||
/// Holds optional subscores gathered from different retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
    /// Full-text-search score, if that signal fired.
    pub fts: Option<f32>,
    /// Vector-similarity score, if that signal fired.
    pub vector: Option<f32>,
    /// Graph-derived score, if that signal fired.
    pub graph: Option<f32>,
}

/// Generic wrapper combining an item with its accumulated retrieval scores.
#[derive(Debug, Clone)]
pub struct Scored<T> {
    /// The wrapped item.
    pub item: T,
    /// Per-signal subscores collected so far.
    pub scores: Scores,
    /// The fused (combined) score; starts at zero until fusion runs.
    pub fused: f32,
}

impl<T> Scored<T> {
    /// Wraps `item` with no subscores and a zero fused score.
    pub fn new(item: T) -> Self {
        Self {
            fused: 0.0,
            scores: Scores::default(),
            item,
        }
    }

    /// Builder: records the vector-similarity subscore.
    pub const fn with_vector_score(mut self, score: f32) -> Self {
        self.scores.vector = Some(score);
        self
    }

    /// Builder: records the full-text-search subscore.
    pub const fn with_fts_score(mut self, score: f32) -> Self {
        self.scores.fts = Some(score);
        self
    }

    /// Builder: records the graph subscore.
    pub const fn with_graph_score(mut self, score: f32) -> Self {
        self.scores.graph = Some(score);
        self
    }

    /// Overwrites the fused score in place.
    pub const fn update_fused(&mut self, fused: f32) {
        self.fused = fused;
    }
}
|
||||
|
||||
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy)]
pub struct FusionWeights {
    /// Weight of the vector-similarity signal.
    pub vector: f32,
    /// Weight of the full-text-search signal.
    pub fts: f32,
    /// Weight of the graph-derived signal.
    pub graph: f32,
    /// Flat bonus added when two or more signals are present.
    pub multi_bonus: f32,
}

impl Default for FusionWeights {
    /// Default blend: vector-dominant (0.5 / 0.3 / 0.2) plus a small
    /// multi-signal bonus.
    fn default() -> Self {
        Self {
            vector: 0.5,
            fts: 0.3,
            graph: 0.2,
            multi_bonus: 0.02,
        }
    }
}
|
||||
|
||||
/// Clamps a value into the unit interval [0, 1]; NaN passes through unchanged
/// (same as `f32::clamp`).
pub const fn clamp_unit(value: f32) -> f32 {
    if value < 0.0 {
        0.0
    } else if value > 1.0 {
        1.0
    } else {
        value
    }
}
|
||||
|
||||
pub fn distance_to_similarity(distance: f32) -> f32 {
|
||||
if !distance.is_finite() {
|
||||
return 0.0;
|
||||
}
|
||||
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
|
||||
}
|
||||
|
||||
/// Min-max normalizes `scores` into [0, 1], mapping non-finite entries to 0.
///
/// Fixes two issues in the previous version: (1) when all finite values were
/// equal it returned `vec![1.0; len]`, assigning 1.0 even to NaN/infinite
/// positions — inconsistent with the general path, which maps them to 0.0;
/// (2) the `!min.is_finite()` guard was unreachable (min/max start from the
/// finite `f32::MAX`/`f32::MIN` and are only updated with finite values).
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    // Track the value range over finite entries only.
    let mut min = f32::MAX;
    let mut max = f32::MIN;
    let mut any_finite = false;
    for s in scores {
        if !s.is_finite() {
            continue;
        }
        any_finite = true;
        min = min.min(*s);
        max = max.max(*s);
    }

    // No finite inputs: nothing can be normalized.
    if !any_finite {
        return vec![0.0; scores.len()];
    }

    // Degenerate range: every finite value is the same; map those to 1.0 and
    // keep non-finite entries at 0.0, matching the general path below.
    if (max - min).abs() < f32::EPSILON {
        return scores
            .iter()
            .map(|s| if s.is_finite() { 1.0 } else { 0.0 })
            .collect();
    }

    scores
        .iter()
        .map(|score| {
            if score.is_finite() {
                ((score - min) / (max - min)).clamp(0.0, 1.0)
            } else {
                0.0
            }
        })
        .collect()
}
|
||||
|
||||
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
let vector = scores.vector.unwrap_or(0.0);
|
||||
let fts = scores.fts.unwrap_or(0.0);
|
||||
let graph = scores.graph.unwrap_or(0.0);
|
||||
|
||||
let mut fused = graph.mul_add(
|
||||
weights.graph,
|
||||
vector.mul_add(weights.vector, fts * weights.fts),
|
||||
);
|
||||
|
||||
let signals_present = scores
|
||||
.vector
|
||||
.iter()
|
||||
.chain(scores.fts.iter())
|
||||
.chain(scores.graph.iter())
|
||||
.count();
|
||||
if signals_present >= 2 {
|
||||
fused += weights.multi_bonus;
|
||||
}
|
||||
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
{
|
||||
for scored in incoming {
|
||||
let id = scored.item.get_id().to_owned();
|
||||
target
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if let Some(score) = scored.scores.vector {
|
||||
existing.scores.vector = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.fts {
|
||||
existing.scores.fts = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.graph {
|
||||
existing.scores.graph = Some(score);
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| Scored {
|
||||
item: scored.item.clone(),
|
||||
scores: scored.scores,
|
||||
fused: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
items.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||
});
|
||||
}
|
||||
218
retrieval-pipeline/src/vector.rs
Normal file
218
retrieval-pipeline/src/vector.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{file_info::deserialize_flexible_id, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
///
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
/// and then deserializes the results into the specified type `T`.
///
/// # Arguments
///
/// * `take` - The number of items to retrieve from the database.
/// * `input_text` - The text to generate embeddings for.
/// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table` - The table to query in the database.
/// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user id of the current user.
///
/// # Returns
///
/// A vector of `Scored<T>` containing the closest matches to the input text,
/// each carrying its vector-similarity subscore. Returns an `AppError` if an
/// error occurs.
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
    take: usize,
    input_text: &str,
    db_client: &SurrealDbClient,
    table: &str,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // Generate embeddings, then delegate to the embedding-based search.
    let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
    find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
        .await
}
|
||||
|
||||
/// Row shape returned by the KNN distance query: the record id plus its
/// reported distance.
#[derive(Debug, Deserialize)]
struct DistanceRow {
    // Record id flattened from a SurrealDB `Thing` into a plain string.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // Raw KNN distance; `None` when the engine did not supply one.
    distance: Option<f32>,
}
|
||||
|
||||
pub async fn find_items_by_vector_similarity_with_embedding<T>(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, vector::distance::knn() AS distance \
|
||||
FROM {table} \
|
||||
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||
LIMIT $limit",
|
||||
table = table,
|
||||
take = take,
|
||||
embedding = embedding_literal
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(closest_query)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let distance_rows: Vec<DistanceRow> = response.take(0)?;
|
||||
|
||||
if distance_rows.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
|
||||
let thing_ids: Vec<Thing> = ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((table, id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut items_response = db_client
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", table.to_owned()))
|
||||
.bind(("things", thing_ids.clone()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let items: Vec<T> = items_response.take(0)?;
|
||||
|
||||
let mut item_map: HashMap<String, T> = items
|
||||
.into_iter()
|
||||
.map(|item| (item.get_id().to_owned(), item))
|
||||
.collect();
|
||||
|
||||
let mut min_distance = f32::MAX;
|
||||
let mut max_distance = f32::MIN;
|
||||
|
||||
for row in &distance_rows {
|
||||
if let Some(distance) = row.distance {
|
||||
if distance.is_finite() {
|
||||
if distance < min_distance {
|
||||
min_distance = distance;
|
||||
}
|
||||
if distance > max_distance {
|
||||
max_distance = distance;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let normalize = min_distance.is_finite()
|
||||
&& max_distance.is_finite()
|
||||
&& (max_distance - min_distance).abs() > f32::EPSILON;
|
||||
|
||||
let mut scored = Vec::with_capacity(distance_rows.len());
|
||||
for row in distance_rows {
|
||||
if let Some(item) = item_map.remove(&row.id) {
|
||||
let similarity = row
|
||||
.distance
|
||||
.map(|distance| {
|
||||
if normalize {
|
||||
let span = max_distance - min_distance;
|
||||
if span.abs() < f32::EPSILON {
|
||||
1.0
|
||||
} else {
|
||||
let normalized = 1.0 - ((distance - min_distance) / span);
|
||||
clamp_unit(normalized)
|
||||
}
|
||||
} else {
|
||||
distance_to_similarity(distance)
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
scored.push(Scored::new(item).with_vector_score(similarity));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
/// A text chunk returned from vector search, used as a retrieval snippet.
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkSnippet {
    /// Id of the chunk record.
    pub id: String,
    /// Id of the source record this chunk was extracted from.
    pub source_id: String,
    /// The chunk text itself.
    pub chunk: String,
}
|
||||
|
||||
/// Row shape returned by the chunk KNN query: snippet fields plus the
/// reported distance.
#[derive(Debug, Deserialize)]
struct ChunkDistanceRow {
    // Raw KNN distance; `None` when the engine did not supply one.
    distance: Option<f32>,
    // Record id flattened from a SurrealDB `Thing` into a plain string.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    // Id of the source record this chunk belongs to.
    pub source_id: String,
    // The chunk text itself.
    pub chunk: String,
}
|
||||
|
||||
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, source_id, chunk, vector::distance::knn() AS distance \
|
||||
FROM text_chunk \
|
||||
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||
LIMIT $limit",
|
||||
take = take,
|
||||
embedding = embedding_literal
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(closest_query)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let rows: Vec<ChunkDistanceRow> = response.take(0)?;
|
||||
|
||||
let mut scored = Vec::with_capacity(rows.len());
|
||||
for row in rows {
|
||||
let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
|
||||
scored.push(
|
||||
Scored::new(ChunkSnippet {
|
||||
id: row.id,
|
||||
source_id: row.source_id,
|
||||
chunk: row.chunk,
|
||||
})
|
||||
.with_vector_score(similarity),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
Reference in New Issue
Block a user