mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-14 13:09:47 +02:00
feat: reranking with fastembed added
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
|
||||
## Version 0.2.5 (2025-10-24)
|
||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||
|
||||
1024
Cargo.lock
generated
1024
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
50
Cargo.toml
50
Cargo.toml
@@ -56,18 +56,56 @@ base64 = "0.22.1"
|
||||
object_store = { version = "0.11.2" }
|
||||
bytes = "1.7.1"
|
||||
state-machines = "0.2.0"
|
||||
fastembed = "5.2.0"
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
lto = "thin"
|
||||
|
||||
[workspace.lints.clippy]
|
||||
perf = { level = "warn", priority = -1 }
|
||||
pedantic = { level = "warn", priority = -1 }
|
||||
nursery = { level = "warn", priority = -1 }
|
||||
cargo = { level = "warn", priority = -1 }
|
||||
# Performance-focused lints
|
||||
perf = { level = "warn", priority = -1 }
|
||||
vec_init_then_push = "warn"
|
||||
large_stack_frames = "warn"
|
||||
redundant_allocation = "warn"
|
||||
single_char_pattern = "warn"
|
||||
string_extend_chars = "warn"
|
||||
format_in_format_args = "warn"
|
||||
slow_vector_initialization = "warn"
|
||||
inefficient_to_string = "warn"
|
||||
implicit_clone = "warn"
|
||||
redundant_clone = "warn"
|
||||
|
||||
needless_question_mark = "allow"
|
||||
single_call_fn = "allow"
|
||||
# Security-focused lints
|
||||
integer_arithmetic = "warn"
|
||||
indexing_slicing = "warn"
|
||||
unwrap_used = "warn"
|
||||
expect_used = "warn"
|
||||
panic = "warn"
|
||||
unimplemented = "warn"
|
||||
todo = "warn"
|
||||
|
||||
# Async/Network lints
|
||||
async_yields_async = "warn"
|
||||
await_holding_invalid_state = "warn"
|
||||
rc_buffer = "warn"
|
||||
|
||||
# Maintainability-focused lints
|
||||
cargo = { level = "warn", priority = -1 }
|
||||
pedantic = { level = "warn", priority = -1 }
|
||||
clone_on_ref_ptr = "warn"
|
||||
float_cmp = "warn"
|
||||
manual_string_new = "warn"
|
||||
uninlined_format_args = "warn"
|
||||
unused_self = "warn"
|
||||
must_use_candidate = "allow"
|
||||
missing_errors_doc = "allow"
|
||||
missing_panics_doc = "warn"
|
||||
module_name_repetitions = "warn"
|
||||
wildcard_dependencies = "warn"
|
||||
missing_docs_in_private_items = "warn"
|
||||
|
||||
# Allow noisy lints that don't add value for this project
|
||||
manual_must_use = "allow"
|
||||
needless_raw_string_hashes = "allow"
|
||||
multiple_bound_locations = "allow"
|
||||
|
||||
21
README.md
21
README.md
@@ -98,6 +98,23 @@ The graph visualization shows:
|
||||
- Relationships as connections (manually defined, AI-discovered, or suggested)
|
||||
- Interactive navigation for discovery and editing
|
||||
|
||||
### Optional FastEmbed Reranking
|
||||
|
||||
Minne ships with an opt-in reranking stage powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs). When enabled, the hybrid retrieval results are rescored with a lightweight cross-encoder before being returned to chat or ingestion flows. In practice this often means more relevant results, boosting answer quality and downstream enrichment.
|
||||
|
||||
⚠️ **Resource notes**
|
||||
- Enabling reranking downloads and caches ~1.1 GB of model data on first startup (cached under `<data_dir>/fastembed/reranker` by default).
|
||||
- Initialization takes longer while warming the cache, and each query consumes extra CPU. The default pool size (2) is tuned for a single-user setup, but a pool size of 1 works as well.
|
||||
- The feature is disabled by default. Set `reranking_enabled: true` (or `RERANKING_ENABLED=true`) if you’re comfortable with the additional footprint.
|
||||
|
||||
Example configuration:
|
||||
|
||||
```yaml
|
||||
reranking_enabled: true
|
||||
reranking_pool_size: 2
|
||||
fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults to <data_dir>/fastembed/reranker
|
||||
```
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Backend:** Rust with Axum framework and Server-Side Rendering (SSR)
|
||||
@@ -125,6 +142,10 @@ Minne can be configured using environment variables or a `config.yaml` file. Env
|
||||
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`)
|
||||
- `DATA_DIR`: Directory to store local data (e.g., `./data`)
|
||||
- `OPENAI_BASE_URL`: Base URL for custom AI providers (like Ollama)
|
||||
- `RERANKING_ENABLED` / `reranking_enabled`: Set to `true` to enable the FastEmbed reranking stage (default `false`)
|
||||
- `RERANKING_POOL_SIZE` / `reranking_pool_size`: Maximum concurrent reranker workers (defaults to `2`)
|
||||
- `FASTEMBED_CACHE_DIR` / `fastembed_cache_dir`: Directory for cached FastEmbed models (defaults to `<data_dir>/fastembed/reranker`)
|
||||
- `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` / `fastembed_show_download_progress`: Show model download progress when warming the cache (default `true`)
|
||||
|
||||
### Example config.yaml
|
||||
|
||||
|
||||
@@ -214,6 +214,7 @@ mod tests {
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -270,12 +270,29 @@ impl FileInfo {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
||||
use crate::utils::config::{AppConfig, PdfIngestMode::LlmFirst, StorageKind};
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
fn test_config(data_dir: &str) -> AppConfig {
|
||||
AppConfig {
|
||||
data_dir: data_dir.to_string(),
|
||||
openai_api_key: "test_key".to_string(),
|
||||
surrealdb_address: "test_address".to_string(),
|
||||
surrealdb_username: "test_user".to_string(),
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a test temporary file with the given content
|
||||
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
|
||||
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
|
||||
@@ -314,19 +331,7 @@ mod tests {
|
||||
|
||||
// Create a FileInfo instance with data_dir in /tmp
|
||||
let user_id = "test_user";
|
||||
let config = AppConfig {
|
||||
data_dir: "/tmp/minne_test_data".to_string(), // Using /tmp which is typically on a different filesystem
|
||||
openai_api_key: "test_key".to_string(),
|
||||
surrealdb_address: "test_address".to_string(),
|
||||
surrealdb_username: "test_user".to_string(),
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
};
|
||||
let config = test_config("/tmp/minne_test_data");
|
||||
|
||||
// Test file creation
|
||||
let file_info = FileInfo::new(field_data, &db, user_id, &config)
|
||||
@@ -375,19 +380,7 @@ mod tests {
|
||||
|
||||
// Create a FileInfo instance with data_dir in /tmp
|
||||
let user_id = "test_user";
|
||||
let config = AppConfig {
|
||||
data_dir: "/tmp/minne_test_data".to_string(),
|
||||
openai_api_key: "test_key".to_string(),
|
||||
surrealdb_address: "test_address".to_string(),
|
||||
surrealdb_username: "test_user".to_string(),
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
};
|
||||
let config = test_config("/tmp/minne_test_data");
|
||||
|
||||
// Store the original file
|
||||
let original_file_info = FileInfo::new(field_data, &db, user_id, &config)
|
||||
@@ -432,19 +425,7 @@ mod tests {
|
||||
|
||||
// Create a FileInfo instance
|
||||
let user_id = "test_user";
|
||||
let config = AppConfig {
|
||||
data_dir: "./data".to_string(),
|
||||
openai_api_key: "test_key".to_string(),
|
||||
surrealdb_address: "test_address".to_string(),
|
||||
surrealdb_username: "test_user".to_string(),
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
};
|
||||
let config = test_config("./data");
|
||||
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
|
||||
|
||||
// We can't fully test persistence to disk in unit tests,
|
||||
@@ -490,19 +471,7 @@ mod tests {
|
||||
let file_name = "original.txt";
|
||||
let user_id = "test_user";
|
||||
|
||||
let config = AppConfig {
|
||||
data_dir: "./data".to_string(),
|
||||
openai_api_key: "test_key".to_string(),
|
||||
surrealdb_address: "test_address".to_string(),
|
||||
surrealdb_username: "test_user".to_string(),
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
};
|
||||
let config = test_config("./data");
|
||||
|
||||
let field_data1 = create_test_file(content, file_name);
|
||||
let original_file_info = FileInfo::new(field_data1, &db, user_id, &config)
|
||||
@@ -655,19 +624,7 @@ mod tests {
|
||||
|
||||
// Create and persist a test file via FileInfo::new
|
||||
let user_id = "user123";
|
||||
let cfg = AppConfig {
|
||||
data_dir: "./data".to_string(),
|
||||
openai_api_key: "".to_string(),
|
||||
surrealdb_address: "".to_string(),
|
||||
surrealdb_username: "".to_string(),
|
||||
surrealdb_password: "".to_string(),
|
||||
surrealdb_namespace: "".to_string(),
|
||||
surrealdb_database: "".to_string(),
|
||||
http_port: 0,
|
||||
openai_base_url: "".to_string(),
|
||||
storage: crate::utils::config::StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
};
|
||||
let cfg = test_config("./data");
|
||||
let temp = create_test_file(b"test content", "test_file.txt");
|
||||
let file_info = FileInfo::new(temp, &db, user_id, &cfg)
|
||||
.await
|
||||
@@ -710,19 +667,7 @@ mod tests {
|
||||
let result = FileInfo::delete_by_id(
|
||||
"nonexistent_id",
|
||||
&db,
|
||||
&AppConfig {
|
||||
data_dir: "./data".to_string(),
|
||||
openai_api_key: "".to_string(),
|
||||
surrealdb_address: "".to_string(),
|
||||
surrealdb_username: "".to_string(),
|
||||
surrealdb_password: "".to_string(),
|
||||
surrealdb_namespace: "".to_string(),
|
||||
surrealdb_database: "".to_string(),
|
||||
http_port: 0,
|
||||
openai_base_url: "".to_string(),
|
||||
storage: crate::utils::config::StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
},
|
||||
&test_config("./data"),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -813,19 +758,7 @@ mod tests {
|
||||
// Create a FileInfo instance with a custom data directory
|
||||
let user_id = "test_user";
|
||||
let custom_data_dir = "/tmp/minne_custom_data_dir";
|
||||
let config = AppConfig {
|
||||
data_dir: custom_data_dir.to_string(),
|
||||
openai_api_key: "test_key".to_string(),
|
||||
surrealdb_address: "test_address".to_string(),
|
||||
surrealdb_username: "test_user".to_string(),
|
||||
surrealdb_password: "test_pass".to_string(),
|
||||
surrealdb_namespace: "test_ns".to_string(),
|
||||
surrealdb_database: "test_db".to_string(),
|
||||
http_port: 3000,
|
||||
openai_base_url: "..".to_string(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
};
|
||||
let config = test_config(custom_data_dir);
|
||||
|
||||
// Test file creation
|
||||
let file_info = FileInfo::new(field_data, &db, user_id, &config)
|
||||
|
||||
@@ -42,6 +42,16 @@ pub struct AppConfig {
|
||||
pub storage: StorageKind,
|
||||
#[serde(default = "default_pdf_ingest_mode")]
|
||||
pub pdf_ingest_mode: PdfIngestMode,
|
||||
#[serde(default = "default_reranking_enabled")]
|
||||
pub reranking_enabled: bool,
|
||||
#[serde(default)]
|
||||
pub reranking_pool_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub fastembed_cache_dir: Option<String>,
|
||||
#[serde(default)]
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
}
|
||||
|
||||
fn default_data_dir() -> String {
|
||||
@@ -52,6 +62,33 @@ fn default_base_url() -> String {
|
||||
"https://api.openai.com/v1".to_string()
|
||||
}
|
||||
|
||||
fn default_reranking_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
openai_api_key: String::new(),
|
||||
surrealdb_address: String::new(),
|
||||
surrealdb_username: String::new(),
|
||||
surrealdb_password: String::new(),
|
||||
surrealdb_namespace: String::new(),
|
||||
surrealdb_database: String::new(),
|
||||
data_dir: default_data_dir(),
|
||||
http_port: 0,
|
||||
openai_base_url: default_base_url(),
|
||||
storage: default_storage_kind(),
|
||||
pdf_ingest_mode: default_pdf_ingest_mode(),
|
||||
reranking_enabled: default_reranking_enabled(),
|
||||
reranking_pool_size: None,
|
||||
fastembed_cache_dir: None,
|
||||
fastembed_show_download_progress: None,
|
||||
fastembed_max_length: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
let config = Config::builder()
|
||||
.add_source(File::with_name("config").required(false))
|
||||
|
||||
@@ -19,6 +19,7 @@ surrealdb = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
state-machines = { workspace = true }
|
||||
|
||||
@@ -8,19 +8,14 @@ use async_openai::{
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{retrieve_entities, retrieved_entities_to_json};
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -36,53 +31,12 @@ pub struct LLMResponseFormat {
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
/// Orchestrates query processing and returns an answer with references
|
||||
///
|
||||
/// Takes a query and uses the provided clients to generate an answer with supporting references.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `surreal_db_client` - Client for `SurrealDB` interactions
|
||||
/// * `openai_client` - Client for `OpenAI` API calls
|
||||
/// * `query` - The user's query string
|
||||
/// * `user_id` - The user's id
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a tuple of the answer and its references, or an API error
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn get_answer_with_references(
|
||||
surreal_db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Answer, AppError> {
|
||||
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
|
||||
let settings = SystemSettings::get_current(surreal_db_client).await?;
|
||||
|
||||
let entities_json = retrieved_entities_to_json(&entities);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message, &settings)?;
|
||||
let response = openai_client.chat().create(request).await?;
|
||||
|
||||
let llm_response = process_llm_response(response).await?;
|
||||
|
||||
Ok(Answer {
|
||||
content: llm_response.answer,
|
||||
references: llm_response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r"
|
||||
|
||||
@@ -3,6 +3,7 @@ pub mod answer_retrieval_helper;
|
||||
pub mod fts;
|
||||
pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod reranking;
|
||||
pub mod scoring;
|
||||
pub mod vector;
|
||||
|
||||
@@ -13,6 +14,7 @@ use common::{
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
};
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
|
||||
@@ -39,6 +41,7 @@ pub async fn retrieve_entities(
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
pipeline::run_pipeline(
|
||||
db_client,
|
||||
@@ -46,6 +49,7 @@ pub async fn retrieve_entities(
|
||||
input_text,
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
reranker,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -142,6 +146,7 @@ mod tests {
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
@@ -232,6 +237,7 @@ mod tests {
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
@@ -17,6 +17,9 @@ pub struct RetrievalTuning {
|
||||
pub graph_score_decay: f32,
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
pub rerank_blend_weight: f32,
|
||||
pub rerank_scores_only: bool,
|
||||
pub rerank_keep_top: usize,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
@@ -36,6 +39,9 @@ impl Default for RetrievalTuning {
|
||||
graph_score_decay: 0.75,
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
rerank_scores_only: false,
|
||||
rerank_keep_top: 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ mod state;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
|
||||
use crate::RetrievedEntity;
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use tracing::info;
|
||||
@@ -16,6 +16,7 @@ pub async fn run_pipeline(
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let input_chars = input_text.chars().count();
|
||||
@@ -35,11 +36,13 @@ pub async fn run_pipeline(
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
|
||||
Ok(results)
|
||||
@@ -53,6 +56,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let mut ctx = stages::PipelineContext::with_embedding(
|
||||
@@ -62,11 +66,13 @@ pub async fn run_pipeline_with_embedding(
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
|
||||
Ok(results)
|
||||
|
||||
@@ -7,6 +7,7 @@ use common::{
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use state_machines::core::GuardError;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@@ -15,6 +16,7 @@ use tracing::{debug, instrument, warn};
|
||||
use crate::{
|
||||
fts::find_items_by_fts,
|
||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||
reranking::RerankerLease,
|
||||
scoring::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
@@ -27,6 +29,7 @@ use super::{
|
||||
config::RetrievalConfig,
|
||||
state::{
|
||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||
Reranked,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -41,6 +44,7 @@ pub struct PipelineContext<'a> {
|
||||
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
@@ -50,6 +54,7 @@ impl<'a> PipelineContext<'a> {
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
@@ -62,6 +67,7 @@ impl<'a> PipelineContext<'a> {
|
||||
chunk_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
reranker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,8 +79,16 @@ impl<'a> PipelineContext<'a> {
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
let mut ctx = Self::new(db_client, openai_client, input_text, user_id, config);
|
||||
let mut ctx = Self::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
@@ -327,9 +341,58 @@ pub async fn attach_chunks(
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble(
|
||||
pub async fn rerank(
|
||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
|
||||
let mut applied = false;
|
||||
|
||||
if let Some(reranker) = ctx.reranker.as_ref() {
|
||||
if ctx.filtered_entities.len() > 1 {
|
||||
let documents = build_rerank_documents(ctx, ctx.config.tuning.max_chunks_per_entity);
|
||||
|
||||
if documents.len() > 1 {
|
||||
match reranker.rerank(&ctx.input_text, documents).await {
|
||||
Ok(results) if !results.is_empty() => {
|
||||
apply_rerank_results(ctx, results);
|
||||
applied = true;
|
||||
}
|
||||
Ok(_) => {
|
||||
debug!("Reranker returned no results; retaining original ordering");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
"Reranking failed; continuing with original ordering"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
document_count = documents.len(),
|
||||
"Skipping reranking stage; insufficient document context"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
debug!("Skipping reranking stage; less than two entities available");
|
||||
}
|
||||
} else {
|
||||
debug!("No reranker lease provided; skipping reranking stage");
|
||||
}
|
||||
|
||||
if applied {
|
||||
debug!("Applied reranking adjustments to candidate ordering");
|
||||
}
|
||||
|
||||
machine
|
||||
.rerank()
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble(
|
||||
machine: HybridRetrievalMachine<(), Reranked>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
@@ -561,6 +624,113 @@ async fn enrich_chunks_from_entities(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
|
||||
if ctx.filtered_entities.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut chunk_by_source: HashMap<&str, Vec<&Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in &ctx.chunk_values {
|
||||
chunk_by_source
|
||||
.entry(chunk.item.source_id.as_str())
|
||||
.or_default()
|
||||
.push(chunk);
|
||||
}
|
||||
|
||||
ctx.filtered_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
let mut doc = format!(
|
||||
"Name: {}\nType: {:?}\nDescription: {}\n",
|
||||
entity.item.name, entity.item.entity_type, entity.item.description
|
||||
);
|
||||
|
||||
if let Some(chunks) = chunk_by_source.get(entity.item.source_id.as_str()) {
|
||||
let mut chunk_refs = chunks.clone();
|
||||
chunk_refs.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut header_added = false;
|
||||
for chunk in chunk_refs.into_iter().take(max_chunks_per_entity.max(1)) {
|
||||
let snippet = chunk.item.chunk.trim();
|
||||
if snippet.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !header_added {
|
||||
doc.push_str("Chunks:\n");
|
||||
header_added = true;
|
||||
}
|
||||
doc.push_str("- ");
|
||||
doc.push_str(snippet);
|
||||
doc.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
doc
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
|
||||
if results.is_empty() || ctx.filtered_entities.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<KnowledgeEntity>>> =
|
||||
std::mem::take(&mut ctx.filtered_entities)
|
||||
.into_iter()
|
||||
.map(Some)
|
||||
.collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = ctx.config.tuning.rerank_scores_only;
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
clamp_unit(ctx.config.tuning.rerank_blend_weight)
|
||||
};
|
||||
let mut reranked: Vec<Scored<KnowledgeEntity>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
let original = candidate.fused;
|
||||
let blended = if use_only {
|
||||
clamp_unit(normalized)
|
||||
} else {
|
||||
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||
};
|
||||
candidate.update_fused(blended);
|
||||
reranked.push(candidate);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
result_index = result.index,
|
||||
"Reranker returned out-of-range index; skipping"
|
||||
);
|
||||
}
|
||||
if reranked.len() == remaining.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.filtered_entities = reranked;
|
||||
let keep_top = ctx.config.tuning.rerank_keep_top;
|
||||
if keep_top > 0 && ctx.filtered_entities.len() > keep_top {
|
||||
ctx.filtered_entities.truncate(keep_top);
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
|
||||
@@ -4,18 +4,20 @@ state_machine! {
|
||||
name: HybridRetrievalMachine,
|
||||
state: HybridRetrievalState,
|
||||
initial: Ready,
|
||||
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Completed, Failed],
|
||||
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
|
||||
events {
|
||||
embed { transition: { from: Ready, to: Embedded } }
|
||||
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
|
||||
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
|
||||
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
|
||||
assemble { transition: { from: ChunksAttached, to: Completed } }
|
||||
rerank { transition: { from: ChunksAttached, to: Reranked } }
|
||||
assemble { transition: { from: Reranked, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: CandidatesLoaded, to: Failed }
|
||||
transition: { from: GraphExpanded, to: Failed }
|
||||
transition: { from: ChunksAttached, to: Failed }
|
||||
transition: { from: Reranked, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
170
composite-retrieval/src/reranking/mod.rs
Normal file
170
composite-retrieval/src/reranking/mod.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use std::{
|
||||
env, fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
thread::available_parallelism,
|
||||
};
|
||||
|
||||
use common::{error::AppError, utils::config::AppConfig};
|
||||
use fastembed::{RerankInitOptions, RerankResult, TextRerank};
|
||||
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
|
||||
use tracing::debug;
|
||||
|
||||
/// Process-wide round-robin counter used to spread rerank requests across
/// the engines of a pool.
static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);

/// Returns the next engine slot in round-robin order.
///
/// Each call bumps the shared counter and wraps it into `0..pool_len`.
/// Relaxed ordering suffices: the counter only needs to rotate selections,
/// not synchronize with any other memory.
fn pick_engine_index(pool_len: usize) -> usize {
    NEXT_ENGINE.fetch_add(1, Ordering::Relaxed) % pool_len
}
|
||||
|
||||
pub struct RerankerPool {
|
||||
engines: Vec<Arc<Mutex<TextRerank>>>,
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl RerankerPool {
|
||||
/// Build the pool at startup.
|
||||
/// `pool_size` controls max parallel reranks.
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
|
||||
Self::new_with_options(pool_size, RerankInitOptions::default())
|
||||
}
|
||||
|
||||
fn new_with_options(
|
||||
pool_size: usize,
|
||||
init_options: RerankInitOptions,
|
||||
) -> Result<Arc<Self>, AppError> {
|
||||
if pool_size == 0 {
|
||||
return Err(AppError::Validation(
|
||||
"RERANKING_POOL_SIZE must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
fs::create_dir_all(&init_options.cache_dir)?;
|
||||
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for x in 0..pool_size {
|
||||
debug!("Creating reranking engine: {x}");
|
||||
let model = TextRerank::try_new(init_options.clone())
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
engines.push(Arc::new(Mutex::new(model)));
|
||||
}
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
engines,
|
||||
semaphore: Arc::new(Semaphore::new(pool_size)),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Initialize a pool using application configuration.
|
||||
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
|
||||
if !config.reranking_enabled {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
|
||||
|
||||
let init_options = build_rerank_init_options(config)?;
|
||||
Self::new_with_options(pool_size, init_options).map(Some)
|
||||
}
|
||||
|
||||
/// Check out capacity + pick an engine.
|
||||
/// This returns a lease that can perform rerank().
|
||||
pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
|
||||
// Acquire a permit. This enforces backpressure.
|
||||
let permit = self
|
||||
.semaphore
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("semaphore closed");
|
||||
|
||||
// Pick an engine.
|
||||
// This is naive: just pick based on a simple modulo counter.
|
||||
// We use an atomic counter to avoid always choosing index 0.
|
||||
let idx = pick_engine_index(self.engines.len());
|
||||
let engine = self.engines[idx].clone();
|
||||
|
||||
RerankerLease {
|
||||
_permit: permit,
|
||||
engine,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Default number of reranker engines: the machine's available parallelism,
/// capped at 2 to bound memory use (each engine holds a full model).
///
/// `available_parallelism()` returns a `NonZeroUsize`, so `get().min(2)` is
/// already at least 1 — the previous trailing `.max(1)` was dead code. Falls
/// back to 2 when parallelism cannot be queried.
fn default_pool_size() -> usize {
    available_parallelism()
        .map(|value| value.get().min(2))
        .unwrap_or(2)
}
|
||||
|
||||
/// Case-insensitive check for a "truthy" environment-variable value.
///
/// Accepts `1`, `true`, `yes` and `on` (after trimming surrounding
/// whitespace); everything else — including the empty string — is false.
fn is_truthy(value: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    let normalized = value.trim().to_ascii_lowercase();
    TRUTHY.contains(&normalized.as_str())
}
|
||||
|
||||
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
|
||||
let mut options = RerankInitOptions::default();
|
||||
|
||||
let cache_dir = config
|
||||
.fastembed_cache_dir
|
||||
.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| env::var("RERANKING_CACHE_DIR").ok().map(PathBuf::from))
|
||||
.or_else(|| env::var("FASTEMBED_CACHE_DIR").ok().map(PathBuf::from))
|
||||
.unwrap_or_else(|| {
|
||||
Path::new(&config.data_dir)
|
||||
.join("fastembed")
|
||||
.join("reranker")
|
||||
});
|
||||
fs::create_dir_all(&cache_dir)?;
|
||||
options.cache_dir = cache_dir;
|
||||
|
||||
let show_progress = config
|
||||
.fastembed_show_download_progress
|
||||
.or_else(|| env_bool("RERANKING_SHOW_DOWNLOAD_PROGRESS"))
|
||||
.or_else(|| env_bool("FASTEMBED_SHOW_DOWNLOAD_PROGRESS"))
|
||||
.unwrap_or(true);
|
||||
options.show_download_progress = show_progress;
|
||||
|
||||
if let Some(max_length) = config.fastembed_max_length.or_else(|| {
|
||||
env::var("RERANKING_MAX_LENGTH")
|
||||
.ok()
|
||||
.and_then(|value| value.parse().ok())
|
||||
}) {
|
||||
options.max_length = max_length;
|
||||
}
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
fn env_bool(key: &str) -> Option<bool> {
|
||||
env::var(key).ok().map(|value| is_truthy(&value))
|
||||
}
|
||||
|
||||
/// Active lease on a single TextRerank instance.
///
/// Obtained from `RerankerPool::checkout()`. Holding the lease reserves one
/// pool slot; dropping it returns the slot to the pool.
pub struct RerankerLease {
    // RAII guard: when this drops the semaphore permit is released,
    // freeing a rerank slot for the next `checkout()` caller.
    _permit: OwnedSemaphorePermit,
    // The specific engine this lease is allowed to lock and use.
    engine: Arc<Mutex<TextRerank>>,
}
|
||||
|
||||
impl RerankerLease {
|
||||
pub async fn rerank(
|
||||
&self,
|
||||
query: &str,
|
||||
documents: Vec<String>,
|
||||
) -> Result<Vec<RerankResult>, AppError> {
|
||||
// Lock this specific engine so we get &mut TextRerank
|
||||
let mut guard = self.engine.lock().await;
|
||||
|
||||
guard
|
||||
.rerank(query.to_owned(), documents, false, None)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
@@ -13,6 +14,7 @@ pub struct HtmlState {
|
||||
pub templates: Arc<TemplateEngine>,
|
||||
pub session_store: Arc<SessionStoreType>,
|
||||
pub config: AppConfig,
|
||||
pub reranker_pool: Option<Arc<RerankerPool>>,
|
||||
}
|
||||
|
||||
impl HtmlState {
|
||||
@@ -21,6 +23,7 @@ impl HtmlState {
|
||||
openai_client: Arc<OpenAIClientType>,
|
||||
session_store: Arc<SessionStoreType>,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let template_engine = create_template_engine!("templates");
|
||||
debug!("Template engine created for html_router.");
|
||||
@@ -31,6 +34,7 @@ impl HtmlState {
|
||||
session_store,
|
||||
templates: Arc::new(template_engine),
|
||||
config,
|
||||
reranker_pool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,11 +118,17 @@ pub async fn get_response_stream(
|
||||
};
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let entities = match retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -195,8 +195,19 @@ pub async fn suggest_knowledge_relationships(
|
||||
|
||||
if !query_parts.is_empty() {
|
||||
let query = query_parts.join(" ");
|
||||
if let Ok(results) =
|
||||
retrieve_entities(&state.db, &state.openai_client, &query, &user.id).await
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Ok(results) = retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&query,
|
||||
&user.id,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
for RetrievedEntity { entity, score, .. } in results {
|
||||
if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS {
|
||||
|
||||
@@ -26,6 +26,7 @@ use common::{
|
||||
},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use self::{
|
||||
@@ -45,9 +46,14 @@ impl IngestionPipeline {
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
) -> Result<Self, AppError> {
|
||||
let services =
|
||||
DefaultPipelineServices::new(db.clone(), openai_client.clone(), config.clone());
|
||||
let services = DefaultPipelineServices::new(
|
||||
db.clone(),
|
||||
openai_client.clone(),
|
||||
config.clone(),
|
||||
reranker_pool,
|
||||
);
|
||||
|
||||
Self::with_services(db, IngestionConfig::default(), Arc::new(services))
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ use common::{
|
||||
},
|
||||
utils::{config::AppConfig, embedding::generate_embedding},
|
||||
};
|
||||
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
|
||||
use composite_retrieval::{
|
||||
reranking::RerankerPool, retrieve_entities, retrieved_entities_to_json, RetrievedEntity,
|
||||
};
|
||||
use text_splitter::TextSplitter;
|
||||
|
||||
use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content};
|
||||
@@ -62,6 +64,7 @@ pub struct DefaultPipelineServices {
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
}
|
||||
|
||||
impl DefaultPipelineServices {
|
||||
@@ -69,11 +72,13 @@ impl DefaultPipelineServices {
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db,
|
||||
openai_client,
|
||||
config,
|
||||
reranker_pool,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +156,19 @@ impl PipelineServices for DefaultPipelineServices {
|
||||
content.text, content.category, content.context
|
||||
);
|
||||
|
||||
retrieve_entities(&self.db, &self.openai_client, &input_text, &content.user_id).await
|
||||
let rerank_lease = match &self.reranker_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
retrieve_entities(
|
||||
&self.db,
|
||||
&self.openai_client,
|
||||
&input_text,
|
||||
&content.user_id,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_enrichment(
|
||||
|
||||
@@ -25,6 +25,7 @@ ingestion-pipeline = { path = "../ingestion-pipeline" }
|
||||
api-router = { path = "../api-router" }
|
||||
html-router = { path = "../html-router" }
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
|
||||
[dev-dependencies]
|
||||
tower = "0.5"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use api_router::{api_routes_v1, api_state::ApiState};
|
||||
use axum::{extract::FromRef, Router};
|
||||
use common::{storage::db::SurrealDbClient, utils::config::get_config};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use std::sync::Arc;
|
||||
@@ -43,8 +44,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
let html_state =
|
||||
HtmlState::new_with_resources(db, openai_client, session_store, config.clone())?;
|
||||
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
|
||||
|
||||
let html_state = HtmlState::new_with_resources(
|
||||
db,
|
||||
openai_client,
|
||||
session_store,
|
||||
config.clone(),
|
||||
reranker_pool.clone(),
|
||||
)?;
|
||||
|
||||
let api_state = ApiState {
|
||||
db: html_state.db.clone(),
|
||||
@@ -102,9 +110,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
let ingestion_pipeline = Arc::new(
|
||||
IngestionPipeline::new(worker_db.clone(), openai_client.clone(), config.clone())
|
||||
.await
|
||||
.unwrap(),
|
||||
IngestionPipeline::new(
|
||||
worker_db.clone(),
|
||||
openai_client.clone(),
|
||||
config.clone(),
|
||||
reranker_pool.clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
info!("Starting worker process");
|
||||
@@ -152,6 +165,7 @@ mod tests {
|
||||
openai_base_url: "https://example.com".into(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,9 +195,14 @@ mod tests {
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
let html_state =
|
||||
HtmlState::new_with_resources(db.clone(), openai_client, session_store, config.clone())
|
||||
.expect("failed to build html state");
|
||||
let html_state = HtmlState::new_with_resources(
|
||||
db.clone(),
|
||||
openai_client,
|
||||
session_store,
|
||||
config.clone(),
|
||||
None,
|
||||
)
|
||||
.expect("failed to build html state");
|
||||
|
||||
let api_state = ApiState {
|
||||
db: html_state.db.clone(),
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use api_router::{api_routes_v1, api_state::ApiState};
|
||||
use axum::{extract::FromRef, Router};
|
||||
use common::{storage::db::SurrealDbClient, utils::config::get_config};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
@@ -41,8 +42,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
let html_state =
|
||||
HtmlState::new_with_resources(db, openai_client, session_store, config.clone())?;
|
||||
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
|
||||
|
||||
let html_state = HtmlState::new_with_resources(
|
||||
db,
|
||||
openai_client,
|
||||
session_store,
|
||||
config.clone(),
|
||||
reranker_pool,
|
||||
)?;
|
||||
|
||||
let api_state = ApiState {
|
||||
db: html_state.db.clone(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{storage::db::SurrealDbClient, utils::config::get_config};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
@@ -32,8 +33,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_api_base(&config.openai_base_url),
|
||||
));
|
||||
|
||||
let ingestion_pipeline =
|
||||
Arc::new(IngestionPipeline::new(db.clone(), openai_client.clone(), config).await?);
|
||||
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
|
||||
|
||||
let ingestion_pipeline = Arc::new(
|
||||
IngestionPipeline::new(db.clone(), openai_client.clone(), config, reranker_pool).await?,
|
||||
);
|
||||
|
||||
run_worker_loop(db, ingestion_pipeline).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user