feat: reranking with fastembed added

This commit is contained in:
Per Stark
2025-10-27 13:05:10 +01:00
parent a0e9387c76
commit 72578296db
25 changed files with 1586 additions and 202 deletions

View File

@@ -1,5 +1,6 @@
# Changelog
## Unreleased
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
## Version 0.2.5 (2025-10-24)
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships

1024
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -56,18 +56,56 @@ base64 = "0.22.1"
object_store = { version = "0.11.2" }
bytes = "1.7.1"
state-machines = "0.2.0"
fastembed = "5.2.0"
[profile.dist]
inherits = "release"
lto = "thin"
[workspace.lints.clippy]
perf = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
nursery = { level = "warn", priority = -1 }
cargo = { level = "warn", priority = -1 }
# Performance-focused lints
perf = { level = "warn", priority = -1 }
vec_init_then_push = "warn"
large_stack_frames = "warn"
redundant_allocation = "warn"
single_char_pattern = "warn"
string_extend_chars = "warn"
format_in_format_args = "warn"
slow_vector_initialization = "warn"
inefficient_to_string = "warn"
implicit_clone = "warn"
redundant_clone = "warn"
needless_question_mark = "allow"
single_call_fn = "allow"
# Security-focused lints
integer_arithmetic = "warn"
indexing_slicing = "warn"
unwrap_used = "warn"
expect_used = "warn"
panic = "warn"
unimplemented = "warn"
todo = "warn"
# Async/Network lints
async_yields_async = "warn"
await_holding_invalid_state = "warn"
rc_buffer = "warn"
# Maintainability-focused lints
cargo = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
clone_on_ref_ptr = "warn"
float_cmp = "warn"
manual_string_new = "warn"
uninlined_format_args = "warn"
unused_self = "warn"
must_use_candidate = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "warn"
module_name_repetitions = "warn"
wildcard_dependencies = "warn"
missing_docs_in_private_items = "warn"
# Allow noisy lints that don't add value for this project
manual_must_use = "allow"
needless_raw_string_hashes = "allow"
multiple_bound_locations = "allow"

View File

@@ -98,6 +98,23 @@ The graph visualization shows:
- Relationships as connections (manually defined, AI-discovered, or suggested)
- Interactive navigation for discovery and editing
### Optional FastEmbed Reranking
Minne ships with an opt-in reranking stage powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs). When enabled, the hybrid retrieval results are rescored with a lightweight cross-encoder before being returned to chat or ingestion flows. In practice this often means more relevant results, boosting answer quality and downstream enrichment.
⚠️ **Resource notes**
- Enabling reranking downloads and caches ~1.1GB of model data on first startup (cached under `<data_dir>/fastembed/reranker` by default).
- Initialization takes longer while warming the cache, and each query consumes extra CPU. The default pool size (2) is tuned for a singe user setup, but could work with a pool size on 1 as well.
- The feature is disabled by default. Set `reranking_enabled: true` (or `RERANKING_ENABLED=true`) if youre comfortable with the additional footprint.
Example configuration:
```yaml
reranking_enabled: true
reranking_pool_size: 2
fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults to .fastembed_cache
```
## Tech Stack
- **Backend:** Rust with Axum framework and Server-Side Rendering (SSR)
@@ -125,6 +142,10 @@ Minne can be configured using environment variables or a `config.yaml` file. Env
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`)
- `DATA_DIR`: Directory to store local data (e.g., `./data`)
- `OPENAI_BASE_URL`: Base URL for custom AI providers (like Ollama)
- `RERANKING_ENABLED` / `reranking_enabled`: Set to `true` to enable the FastEmbed reranking stage (default `false`)
- `RERANKING_POOL_SIZE` / `reranking_pool_size`: Maximum concurrent reranker workers (defaults to `2`)
- `FASTEMBED_CACHE_DIR` / `fastembed_cache_dir`: Directory for cached FastEmbed models (defaults to `<data_dir>/fastembed/reranker`)
- `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` / `fastembed_show_download_progress`: Show model download progress when warming the cache (default `true`)
### Example config.yaml

View File

@@ -214,6 +214,7 @@ mod tests {
openai_base_url: "..".into(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
..Default::default()
}
}

View File

@@ -270,12 +270,29 @@ impl FileInfo {
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
use crate::utils::config::{AppConfig, PdfIngestMode::LlmFirst, StorageKind};
use axum::http::HeaderMap;
use axum_typed_multipart::FieldMetadata;
use std::io::Write;
use tempfile::NamedTempFile;
fn test_config(data_dir: &str) -> AppConfig {
AppConfig {
data_dir: data_dir.to_string(),
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
..Default::default()
}
}
/// Creates a test temporary file with the given content
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
@@ -314,19 +331,7 @@ mod tests {
// Create a FileInfo instance with data_dir in /tmp
let user_id = "test_user";
let config = AppConfig {
data_dir: "/tmp/minne_test_data".to_string(), // Using /tmp which is typically on a different filesystem
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let config = test_config("/tmp/minne_test_data");
// Test file creation
let file_info = FileInfo::new(field_data, &db, user_id, &config)
@@ -375,19 +380,7 @@ mod tests {
// Create a FileInfo instance with data_dir in /tmp
let user_id = "test_user";
let config = AppConfig {
data_dir: "/tmp/minne_test_data".to_string(),
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let config = test_config("/tmp/minne_test_data");
// Store the original file
let original_file_info = FileInfo::new(field_data, &db, user_id, &config)
@@ -432,19 +425,7 @@ mod tests {
// Create a FileInfo instance
let user_id = "test_user";
let config = AppConfig {
data_dir: "./data".to_string(),
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let config = test_config("./data");
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
// We can't fully test persistence to disk in unit tests,
@@ -490,19 +471,7 @@ mod tests {
let file_name = "original.txt";
let user_id = "test_user";
let config = AppConfig {
data_dir: "./data".to_string(),
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let config = test_config("./data");
let field_data1 = create_test_file(content, file_name);
let original_file_info = FileInfo::new(field_data1, &db, user_id, &config)
@@ -655,19 +624,7 @@ mod tests {
// Create and persist a test file via FileInfo::new
let user_id = "user123";
let cfg = AppConfig {
data_dir: "./data".to_string(),
openai_api_key: "".to_string(),
surrealdb_address: "".to_string(),
surrealdb_username: "".to_string(),
surrealdb_password: "".to_string(),
surrealdb_namespace: "".to_string(),
surrealdb_database: "".to_string(),
http_port: 0,
openai_base_url: "".to_string(),
storage: crate::utils::config::StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let cfg = test_config("./data");
let temp = create_test_file(b"test content", "test_file.txt");
let file_info = FileInfo::new(temp, &db, user_id, &cfg)
.await
@@ -710,19 +667,7 @@ mod tests {
let result = FileInfo::delete_by_id(
"nonexistent_id",
&db,
&AppConfig {
data_dir: "./data".to_string(),
openai_api_key: "".to_string(),
surrealdb_address: "".to_string(),
surrealdb_username: "".to_string(),
surrealdb_password: "".to_string(),
surrealdb_namespace: "".to_string(),
surrealdb_database: "".to_string(),
http_port: 0,
openai_base_url: "".to_string(),
storage: crate::utils::config::StorageKind::Local,
pdf_ingest_mode: LlmFirst,
},
&test_config("./data"),
)
.await;
@@ -813,19 +758,7 @@ mod tests {
// Create a FileInfo instance with a custom data directory
let user_id = "test_user";
let custom_data_dir = "/tmp/minne_custom_data_dir";
let config = AppConfig {
data_dir: custom_data_dir.to_string(),
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let config = test_config(custom_data_dir);
// Test file creation
let file_info = FileInfo::new(field_data, &db, user_id, &config)

View File

@@ -42,6 +42,16 @@ pub struct AppConfig {
pub storage: StorageKind,
#[serde(default = "default_pdf_ingest_mode")]
pub pdf_ingest_mode: PdfIngestMode,
#[serde(default = "default_reranking_enabled")]
pub reranking_enabled: bool,
#[serde(default)]
pub reranking_pool_size: Option<usize>,
#[serde(default)]
pub fastembed_cache_dir: Option<String>,
#[serde(default)]
pub fastembed_show_download_progress: Option<bool>,
#[serde(default)]
pub fastembed_max_length: Option<usize>,
}
fn default_data_dir() -> String {
@@ -52,6 +62,33 @@ fn default_base_url() -> String {
"https://api.openai.com/v1".to_string()
}
fn default_reranking_enabled() -> bool {
false
}
impl Default for AppConfig {
fn default() -> Self {
Self {
openai_api_key: String::new(),
surrealdb_address: String::new(),
surrealdb_username: String::new(),
surrealdb_password: String::new(),
surrealdb_namespace: String::new(),
surrealdb_database: String::new(),
data_dir: default_data_dir(),
http_port: 0,
openai_base_url: default_base_url(),
storage: default_storage_kind(),
pdf_ingest_mode: default_pdf_ingest_mode(),
reranking_enabled: default_reranking_enabled(),
reranking_pool_size: None,
fastembed_cache_dir: None,
fastembed_show_download_progress: None,
fastembed_max_length: None,
}
}
}
pub fn get_config() -> Result<AppConfig, ConfigError> {
let config = Config::builder()
.add_source(File::with_name("config").required(false))

View File

@@ -19,6 +19,7 @@ surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
uuid = { workspace = true }
fastembed = { workspace = true }
common = { path = "../common", features = ["test-utils"] }
state-machines = { workspace = true }

View File

@@ -8,19 +8,14 @@ use async_openai::{
};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
message::{format_history, Message},
system_settings::SystemSettings,
},
storage::types::{
message::{format_history, Message},
system_settings::SystemSettings,
},
};
use serde::Deserialize;
use serde_json::Value;
use crate::{retrieve_entities, retrieved_entities_to_json};
use super::answer_retrieval_helper::get_query_response_schema;
#[derive(Debug, Deserialize)]
@@ -36,53 +31,12 @@ pub struct LLMResponseFormat {
pub references: Vec<Reference>,
}
/// Orchestrates query processing and returns an answer with references
///
/// Takes a query and uses the provided clients to generate an answer with supporting references.
///
/// # Arguments
///
/// * `surreal_db_client` - Client for `SurrealDB` interactions
/// * `openai_client` - Client for `OpenAI` API calls
/// * `query` - The user's query string
/// * `user_id` - The user's id
///
/// # Returns
///
/// Returns a tuple of the answer and its references, or an API error
#[derive(Debug)]
pub struct Answer {
pub content: String,
pub references: Vec<String>,
}
pub async fn get_answer_with_references(
surreal_db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,
) -> Result<Answer, AppError> {
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
let settings = SystemSettings::get_current(surreal_db_client).await?;
let entities_json = retrieved_entities_to_json(&entities);
let user_message = create_user_message(&entities_json, query);
let request = create_chat_request(user_message, &settings)?;
let response = openai_client.chat().create(request).await?;
let llm_response = process_llm_response(response).await?;
Ok(Answer {
content: llm_response.answer,
references: llm_response
.references
.into_iter()
.map(|r| r.reference)
.collect(),
})
}
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r"

View File

@@ -3,6 +3,7 @@ pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod pipeline;
pub mod reranking;
pub mod scoring;
pub mod vector;
@@ -13,6 +14,7 @@ use common::{
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
};
use reranking::RerankerLease;
use tracing::instrument;
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
@@ -39,6 +41,7 @@ pub async fn retrieve_entities(
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input_text: &str,
user_id: &str,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
pipeline::run_pipeline(
db_client,
@@ -46,6 +49,7 @@ pub async fn retrieve_entities(
input_text,
user_id,
RetrievalConfig::default(),
reranker,
)
.await
}
@@ -142,6 +146,7 @@ mod tests {
"Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
None,
)
.await
.expect("Hybrid retrieval failed");
@@ -232,6 +237,7 @@ mod tests {
"Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
None,
)
.await
.expect("Hybrid retrieval failed");

View File

@@ -17,6 +17,9 @@ pub struct RetrievalTuning {
pub graph_score_decay: f32,
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
pub rerank_blend_weight: f32,
pub rerank_scores_only: bool,
pub rerank_keep_top: usize,
}
impl Default for RetrievalTuning {
@@ -36,6 +39,9 @@ impl Default for RetrievalTuning {
graph_score_decay: 0.75,
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65,
rerank_scores_only: false,
rerank_keep_top: 8,
}
}
}

View File

@@ -4,7 +4,7 @@ mod state;
pub use config::{RetrievalConfig, RetrievalTuning};
use crate::RetrievedEntity;
use crate::{reranking::RerankerLease, RetrievedEntity};
use async_openai::Client;
use common::{error::AppError, storage::db::SurrealDbClient};
use tracing::info;
@@ -16,6 +16,7 @@ pub async fn run_pipeline(
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let input_chars = input_text.chars().count();
@@ -35,11 +36,13 @@ pub async fn run_pipeline(
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
Ok(results)
@@ -53,6 +56,7 @@ pub async fn run_pipeline_with_embedding(
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let mut ctx = stages::PipelineContext::with_embedding(
@@ -62,11 +66,13 @@ pub async fn run_pipeline_with_embedding(
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
Ok(results)

View File

@@ -7,6 +7,7 @@ use common::{
},
utils::embedding::generate_embedding,
};
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
use state_machines::core::GuardError;
use std::collections::{HashMap, HashSet};
@@ -15,6 +16,7 @@ use tracing::{debug, instrument, warn};
use crate::{
fts::find_items_by_fts,
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
reranking::RerankerLease,
scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
@@ -27,6 +29,7 @@ use super::{
config::RetrievalConfig,
state::{
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
Reranked,
},
};
@@ -41,6 +44,7 @@ pub struct PipelineContext<'a> {
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
}
impl<'a> PipelineContext<'a> {
@@ -50,6 +54,7 @@ impl<'a> PipelineContext<'a> {
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
Self {
db_client,
@@ -62,6 +67,7 @@ impl<'a> PipelineContext<'a> {
chunk_candidates: HashMap::new(),
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
reranker,
}
}
@@ -73,8 +79,16 @@ impl<'a> PipelineContext<'a> {
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
let mut ctx = Self::new(db_client, openai_client, input_text, user_id, config);
let mut ctx = Self::new(
db_client,
openai_client,
input_text,
user_id,
config,
reranker,
);
ctx.query_embedding = Some(query_embedding);
ctx
}
@@ -327,9 +341,58 @@ pub async fn attach_chunks(
}
#[instrument(level = "trace", skip_all)]
pub fn assemble(
pub async fn rerank(
machine: HybridRetrievalMachine<(), ChunksAttached>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
let mut applied = false;
if let Some(reranker) = ctx.reranker.as_ref() {
if ctx.filtered_entities.len() > 1 {
let documents = build_rerank_documents(ctx, ctx.config.tuning.max_chunks_per_entity);
if documents.len() > 1 {
match reranker.rerank(&ctx.input_text, documents).await {
Ok(results) if !results.is_empty() => {
apply_rerank_results(ctx, results);
applied = true;
}
Ok(_) => {
debug!("Reranker returned no results; retaining original ordering");
}
Err(err) => {
warn!(
error = %err,
"Reranking failed; continuing with original ordering"
);
}
}
} else {
debug!(
document_count = documents.len(),
"Skipping reranking stage; insufficient document context"
);
}
} else {
debug!("Skipping reranking stage; less than two entities available");
}
} else {
debug!("No reranker lease provided; skipping reranking stage");
}
if applied {
debug!("Applied reranking adjustments to candidate ordering");
}
machine
.rerank()
.map_err(|(_, guard)| map_guard_error("rerank", guard))
}
#[instrument(level = "trace", skip_all)]
pub fn assemble(
machine: HybridRetrievalMachine<(), Reranked>,
ctx: &mut PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
@@ -561,6 +624,113 @@ async fn enrich_chunks_from_entities(
Ok(())
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();
}
let mut chunk_by_source: HashMap<&str, Vec<&Scored<TextChunk>>> = HashMap::new();
for chunk in &ctx.chunk_values {
chunk_by_source
.entry(chunk.item.source_id.as_str())
.or_default()
.push(chunk);
}
ctx.filtered_entities
.iter()
.map(|entity| {
let mut doc = format!(
"Name: {}\nType: {:?}\nDescription: {}\n",
entity.item.name, entity.item.entity_type, entity.item.description
);
if let Some(chunks) = chunk_by_source.get(entity.item.source_id.as_str()) {
let mut chunk_refs = chunks.clone();
chunk_refs.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut header_added = false;
for chunk in chunk_refs.into_iter().take(max_chunks_per_entity.max(1)) {
let snippet = chunk.item.chunk.trim();
if snippet.is_empty() {
continue;
}
if !header_added {
doc.push_str("Chunks:\n");
header_added = true;
}
doc.push_str("- ");
doc.push_str(snippet);
doc.push('\n');
}
}
doc
})
.collect()
}
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
if results.is_empty() || ctx.filtered_entities.is_empty() {
return;
}
let mut remaining: Vec<Option<Scored<KnowledgeEntity>>> =
std::mem::take(&mut ctx.filtered_entities)
.into_iter()
.map(Some)
.collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = ctx.config.tuning.rerank_scores_only;
let blend = if use_only {
1.0
} else {
clamp_unit(ctx.config.tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<KnowledgeEntity>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
let original = candidate.fused;
let blended = if use_only {
clamp_unit(normalized)
} else {
clamp_unit(original * (1.0 - blend) + normalized * blend)
};
candidate.update_fused(blended);
reranked.push(candidate);
}
} else {
warn!(
result_index = result.index,
"Reranker returned out-of-range index; skipping"
);
}
if reranked.len() == remaining.len() {
break;
}
}
for slot in remaining.into_iter() {
if let Some(candidate) = slot {
reranked.push(candidate);
}
}
ctx.filtered_entities = reranked;
let keep_top = ctx.config.tuning.rerank_keep_top;
if keep_top > 0 && ctx.filtered_entities.len() > keep_top {
ctx.filtered_entities.truncate(keep_top);
}
}
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
let chars = text.chars().count().max(1);
(chars / avg_chars_per_token).max(1)

View File

@@ -4,18 +4,20 @@ state_machine! {
name: HybridRetrievalMachine,
state: HybridRetrievalState,
initial: Ready,
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Completed, Failed],
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
events {
embed { transition: { from: Ready, to: Embedded } }
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
assemble { transition: { from: ChunksAttached, to: Completed } }
rerank { transition: { from: ChunksAttached, to: Reranked } }
assemble { transition: { from: Reranked, to: Completed } }
abort {
transition: { from: Ready, to: Failed }
transition: { from: CandidatesLoaded, to: Failed }
transition: { from: GraphExpanded, to: Failed }
transition: { from: ChunksAttached, to: Failed }
transition: { from: Reranked, to: Failed }
}
}
}

View File

@@ -0,0 +1,170 @@
use std::{
env, fs,
path::{Path, PathBuf},
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
thread::available_parallelism,
};
use common::{error::AppError, utils::config::AppConfig};
use fastembed::{RerankInitOptions, RerankResult, TextRerank};
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
use tracing::debug;
static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);
fn pick_engine_index(pool_len: usize) -> usize {
let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
n % pool_len
}
pub struct RerankerPool {
engines: Vec<Arc<Mutex<TextRerank>>>,
semaphore: Arc<Semaphore>,
}
impl RerankerPool {
/// Build the pool at startup.
/// `pool_size` controls max parallel reranks.
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
Self::new_with_options(pool_size, RerankInitOptions::default())
}
fn new_with_options(
pool_size: usize,
init_options: RerankInitOptions,
) -> Result<Arc<Self>, AppError> {
if pool_size == 0 {
return Err(AppError::Validation(
"RERANKING_POOL_SIZE must be greater than zero".to_string(),
));
}
fs::create_dir_all(&init_options.cache_dir)?;
let mut engines = Vec::with_capacity(pool_size);
for x in 0..pool_size {
debug!("Creating reranking engine: {x}");
let model = TextRerank::try_new(init_options.clone())
.map_err(|e| AppError::InternalError(e.to_string()))?;
engines.push(Arc::new(Mutex::new(model)));
}
Ok(Arc::new(Self {
engines,
semaphore: Arc::new(Semaphore::new(pool_size)),
}))
}
/// Initialize a pool using application configuration.
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
if !config.reranking_enabled {
return Ok(None);
}
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
let init_options = build_rerank_init_options(config)?;
Self::new_with_options(pool_size, init_options).map(Some)
}
/// Check out capacity + pick an engine.
/// This returns a lease that can perform rerank().
pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
// Acquire a permit. This enforces backpressure.
let permit = self
.semaphore
.clone()
.acquire_owned()
.await
.expect("semaphore closed");
// Pick an engine.
// This is naive: just pick based on a simple modulo counter.
// We use an atomic counter to avoid always choosing index 0.
let idx = pick_engine_index(self.engines.len());
let engine = self.engines[idx].clone();
RerankerLease {
_permit: permit,
engine,
}
}
}
fn default_pool_size() -> usize {
available_parallelism()
.map(|value| value.get().min(2))
.unwrap_or(2)
.max(1)
}
fn is_truthy(value: &str) -> bool {
matches!(
value.trim().to_ascii_lowercase().as_str(),
"1" | "true" | "yes" | "on"
)
}
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
let mut options = RerankInitOptions::default();
let cache_dir = config
.fastembed_cache_dir
.as_ref()
.map(PathBuf::from)
.or_else(|| env::var("RERANKING_CACHE_DIR").ok().map(PathBuf::from))
.or_else(|| env::var("FASTEMBED_CACHE_DIR").ok().map(PathBuf::from))
.unwrap_or_else(|| {
Path::new(&config.data_dir)
.join("fastembed")
.join("reranker")
});
fs::create_dir_all(&cache_dir)?;
options.cache_dir = cache_dir;
let show_progress = config
.fastembed_show_download_progress
.or_else(|| env_bool("RERANKING_SHOW_DOWNLOAD_PROGRESS"))
.or_else(|| env_bool("FASTEMBED_SHOW_DOWNLOAD_PROGRESS"))
.unwrap_or(true);
options.show_download_progress = show_progress;
if let Some(max_length) = config.fastembed_max_length.or_else(|| {
env::var("RERANKING_MAX_LENGTH")
.ok()
.and_then(|value| value.parse().ok())
}) {
options.max_length = max_length;
}
Ok(options)
}
fn env_bool(key: &str) -> Option<bool> {
env::var(key).ok().map(|value| is_truthy(&value))
}
/// Active lease on a single TextRerank instance.
pub struct RerankerLease {
// When this drops the semaphore permit is released.
_permit: OwnedSemaphorePermit,
engine: Arc<Mutex<TextRerank>>,
}
impl RerankerLease {
pub async fn rerank(
&self,
query: &str,
documents: Vec<String>,
) -> Result<Vec<RerankResult>, AppError> {
// Lock this specific engine so we get &mut TextRerank
let mut guard = self.engine.lock().await;
guard
.rerank(query.to_owned(), documents, false, None)
.map_err(|e| AppError::InternalError(e.to_string()))
}
}

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
use common::storage::db::SurrealDbClient;
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
use composite_retrieval::reranking::RerankerPool;
use std::sync::Arc;
use tracing::debug;
@@ -13,6 +14,7 @@ pub struct HtmlState {
pub templates: Arc<TemplateEngine>,
pub session_store: Arc<SessionStoreType>,
pub config: AppConfig,
pub reranker_pool: Option<Arc<RerankerPool>>,
}
impl HtmlState {
@@ -21,6 +23,7 @@ impl HtmlState {
openai_client: Arc<OpenAIClientType>,
session_store: Arc<SessionStoreType>,
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
) -> Result<Self, Box<dyn std::error::Error>> {
let template_engine = create_template_engine!("templates");
debug!("Template engine created for html_router.");
@@ -31,6 +34,7 @@ impl HtmlState {
session_store,
templates: Arc::new(template_engine),
config,
reranker_pool,
})
}
}

View File

@@ -118,11 +118,17 @@ pub async fn get_response_stream(
};
// 2. Retrieve knowledge entities
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
None => None,
};
let entities = match retrieve_entities(
&state.db,
&state.openai_client,
&user_message.content,
&user.id,
rerank_lease,
)
.await
{

View File

@@ -195,8 +195,19 @@ pub async fn suggest_knowledge_relationships(
if !query_parts.is_empty() {
let query = query_parts.join(" ");
if let Ok(results) =
retrieve_entities(&state.db, &state.openai_client, &query, &user.id).await
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
None => None,
};
if let Ok(results) = retrieve_entities(
&state.db,
&state.openai_client,
&query,
&user.id,
rerank_lease,
)
.await
{
for RetrievedEntity { entity, score, .. } in results {
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {

View File

@@ -26,6 +26,7 @@ use common::{
},
utils::config::AppConfig,
};
use composite_retrieval::reranking::RerankerPool;
use tracing::{debug, info, warn};
use self::{
@@ -45,9 +46,14 @@ impl IngestionPipeline {
db: Arc<SurrealDbClient>,
openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
) -> Result<Self, AppError> {
let services =
DefaultPipelineServices::new(db.clone(), openai_client.clone(), config.clone());
let services = DefaultPipelineServices::new(
db.clone(),
openai_client.clone(),
config.clone(),
reranker_pool,
);
Self::with_services(db, IngestionConfig::default(), Arc::new(services))
}

View File

@@ -18,7 +18,9 @@ use common::{
},
utils::{config::AppConfig, embedding::generate_embedding},
};
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
use composite_retrieval::{
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievedEntity,
};
use text_splitter::TextSplitter;
use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content};
@@ -62,6 +64,7 @@ pub struct DefaultPipelineServices {
db: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
}
impl DefaultPipelineServices {
@@ -69,11 +72,13 @@ impl DefaultPipelineServices {
db: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
) -> Self {
Self {
db,
openai_client,
config,
reranker_pool,
}
}
@@ -151,7 +156,19 @@ impl PipelineServices for DefaultPipelineServices {
content.text, content.category, content.context
);
retrieve_entities(&self.db, &self.openai_client, &input_text, &content.user_id).await
let rerank_lease = match &self.reranker_pool {
Some(pool) => Some(pool.checkout().await),
None => None,
};
retrieve_entities(
&self.db,
&self.openai_client,
&input_text,
&content.user_id,
rerank_lease,
)
.await
}
async fn run_enrichment(

View File

@@ -25,6 +25,7 @@ ingestion-pipeline = { path = "../ingestion-pipeline" }
api-router = { path = "../api-router" }
html-router = { path = "../html-router" }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
[dev-dependencies]
tower = "0.5"

View File

@@ -1,6 +1,7 @@
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router};
use common::{storage::db::SurrealDbClient, utils::config::get_config};
use composite_retrieval::reranking::RerankerPool;
use html_router::{html_routes, html_state::HtmlState};
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use std::sync::Arc;
@@ -43,8 +44,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_api_base(&config.openai_base_url),
));
let html_state =
HtmlState::new_with_resources(db, openai_client, session_store, config.clone())?;
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
let html_state = HtmlState::new_with_resources(
db,
openai_client,
session_store,
config.clone(),
reranker_pool.clone(),
)?;
let api_state = ApiState {
db: html_state.db.clone(),
@@ -102,9 +110,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_api_base(&config.openai_base_url),
));
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(worker_db.clone(), openai_client.clone(), config.clone())
.await
.unwrap(),
IngestionPipeline::new(
worker_db.clone(),
openai_client.clone(),
config.clone(),
reranker_pool.clone(),
)
.await
.unwrap(),
);
info!("Starting worker process");
@@ -152,6 +165,7 @@ mod tests {
openai_base_url: "https://example.com".into(),
storage: StorageKind::Local,
pdf_ingest_mode: PdfIngestMode::LlmFirst,
..Default::default()
}
}
@@ -181,9 +195,14 @@ mod tests {
.with_api_base(&config.openai_base_url),
));
let html_state =
HtmlState::new_with_resources(db.clone(), openai_client, session_store, config.clone())
.expect("failed to build html state");
let html_state = HtmlState::new_with_resources(
db.clone(),
openai_client,
session_store,
config.clone(),
None,
)
.expect("failed to build html state");
let api_state = ApiState {
db: html_state.db.clone(),

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router};
use common::{storage::db::SurrealDbClient, utils::config::get_config};
use composite_retrieval::reranking::RerankerPool;
use html_router::{html_routes, html_state::HtmlState};
use tracing::info;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
@@ -41,8 +42,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_api_base(&config.openai_base_url),
));
let html_state =
HtmlState::new_with_resources(db, openai_client, session_store, config.clone())?;
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
let html_state = HtmlState::new_with_resources(
db,
openai_client,
session_store,
config.clone(),
reranker_pool,
)?;
let api_state = ApiState {
db: html_state.db.clone(),

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use common::{storage::db::SurrealDbClient, utils::config::get_config};
use composite_retrieval::reranking::RerankerPool;
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
@@ -32,8 +33,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_api_base(&config.openai_base_url),
));
let ingestion_pipeline =
Arc::new(IngestionPipeline::new(db.clone(), openai_client.clone(), config).await?);
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(db.clone(), openai_client.clone(), config, reranker_pool).await?,
);
run_worker_loop(db, ingestion_pipeline).await
}