retrieval simplified

This commit is contained in:
Per Stark
2025-12-09 20:35:42 +01:00
parent a8d10f265c
commit a090a8c76e
55 changed files with 469 additions and 1208 deletions

View File

@@ -28,19 +28,14 @@ fn default_ingestion_cache_dir() -> PathBuf {
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
-#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
#[value(rename_all = "lowercase")]
pub enum EmbeddingBackend {
    Hashed,
+    #[default]
    FastEmbed,
}
-impl Default for EmbeddingBackend {
-    fn default() -> Self {
-        Self::FastEmbed
-    }
-}
impl std::fmt::Display for EmbeddingBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
@@ -109,7 +104,7 @@ pub struct RetrievalSettings {
    pub require_verified_chunks: bool,
    /// Select the retrieval pipeline strategy
-    #[arg(long, default_value_t = RetrievalStrategy::Initial)]
+    #[arg(long, default_value_t = RetrievalStrategy::Default)]
    pub strategy: RetrievalStrategy,
}
@@ -130,7 +125,7 @@ impl Default for RetrievalSettings {
            chunk_rrf_use_vector: None,
            chunk_rrf_use_fts: None,
            require_verified_chunks: true,
-            strategy: RetrievalStrategy::Initial,
+            strategy: RetrievalStrategy::Default,
        }
    }
}
@@ -378,11 +373,7 @@ impl Config {
        self.summary_sample = self.sample.max(1);
        // Handle retrieval settings
-        if self.llm_mode {
-            self.retrieval.require_verified_chunks = false;
-        } else {
-            self.retrieval.require_verified_chunks = true;
-        }
+        self.retrieval.require_verified_chunks = !self.llm_mode;
        if self.dataset == DatasetKind::Beir {
            self.negative_multiplier = 9.0;

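Note on the EmbeddingBackend hunk above: since Rust 1.62, `#[derive(Default)]` works on enums when exactly one unit variant carries `#[default]`, which is what replaces the hand-written `impl Default` here. A minimal standalone sketch of the pattern (the clap `ValueEnum` derive from the real code is omitted to keep it dependency-free):

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum EmbeddingBackend {
    Hashed,
    #[default]
    FastEmbed,
}

fn main() {
    // The derived impl returns the #[default] variant.
    assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
}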
View File

@@ -14,13 +14,13 @@ pub use store::{
};
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
-    let mut tuning = ingestion_pipeline::IngestionTuning::default();
-    tuning.chunk_min_tokens = config.ingest.ingest_chunk_min_tokens;
-    tuning.chunk_max_tokens = config.ingest.ingest_chunk_max_tokens;
-    tuning.chunk_overlap_tokens = config.ingest.ingest_chunk_overlap_tokens;
    ingestion_pipeline::IngestionConfig {
-        tuning,
+        tuning: ingestion_pipeline::IngestionTuning {
+            chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
+            chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
+            chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
+            ..Default::default()
+        },
        chunk_only: config.ingest.ingest_chunks_only,
    }
}

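The rewrite above swaps mutate-after-default for functional record update: name the fields you care about, let `..Default::default()` fill the rest, and the `mut` binding disappears. A self-contained sketch with a stand-in `Tuning` struct (the real type is `ingestion_pipeline::IngestionTuning`):

#[derive(Debug, Default)]
struct Tuning {
    chunk_min_tokens: usize,
    chunk_max_tokens: usize,
    chunk_overlap_tokens: usize,
}

fn main() {
    // Struct-update syntax: explicit fields first, defaults for the rest.
    let tuning = Tuning {
        chunk_min_tokens: 64,
        chunk_max_tokens: 512,
        ..Default::default()
    };
    println!("{tuning:?}");
}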
View File

@@ -106,6 +106,7 @@ struct IngestionStats {
    negative_ingested: usize,
}
+#[allow(clippy::too_many_arguments)]
pub async fn ensure_corpus(
    dataset: &ConvertedDataset,
    slice: &ResolvedSlice<'_>,
@@ -337,11 +338,9 @@ pub async fn ensure_corpus(
        });
    }
-    for record in &mut records {
-        if let Some(ref mut entry) = record {
-            if entry.dirty {
-                store.persist(&entry.shard)?;
-            }
-        }
-    }
+    for entry in records.iter_mut().flatten() {
+        if entry.dirty {
+            store.persist(&entry.shard)?;
+        }
+    }
@@ -403,6 +402,7 @@ pub async fn ensure_corpus(
    Ok(handle)
}
+#[allow(clippy::too_many_arguments)]
async fn ingest_paragraph_batch(
    dataset: &ConvertedDataset,
    targets: &[IngestRequest<'_>],
@@ -430,8 +430,10 @@ async fn ingest_paragraph_batch(
        .await
        .context("applying migrations for ingestion")?;
-    let mut app_config = AppConfig::default();
-    app_config.storage = StorageKind::Memory;
+    let app_config = AppConfig {
+        storage: StorageKind::Memory,
+        ..Default::default()
+    };
    let backend: DynStore = Arc::new(InMemory::new());
    let storage = StorageManager::with_backend(backend, StorageKind::Memory);
@@ -444,8 +446,7 @@ async fn ingest_paragraph_batch(
        storage,
        embedding.clone(),
        pipeline_config,
-    )
-    .await?;
+    )?;
    let pipeline = Arc::new(pipeline);
    let mut shards = Vec::with_capacity(targets.len());
@@ -454,7 +455,7 @@ async fn ingest_paragraph_batch(
        info!(
            batch = batch_index,
            batch_size = batch.len(),
-            total_batches = (targets.len() + batch_size - 1) / batch_size,
+            total_batches = targets.len().div_ceil(batch_size),
            "Ingesting paragraph batch"
        );
        let model_clone = embedding_model.clone();
@@ -486,6 +487,7 @@ async fn ingest_paragraph_batch(
    Ok(shards)
}
+#[allow(clippy::too_many_arguments)]
async fn ingest_single_paragraph(
    pipeline: Arc<IngestionPipeline>,
    request: IngestRequest<'_>,

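Two standard-library shortcuts drive this file's changes: `Iterator::flatten` erases the `if let Some(...)` layer when looping over `Option`s, and `usize::div_ceil` (stable since Rust 1.73) replaces the `(n + d - 1) / d` rounding idiom. A sketch with a stand-in `Entry` type:

struct Entry {
    dirty: bool,
}

fn main() {
    let mut records: Vec<Option<Entry>> = vec![Some(Entry { dirty: true }), None];
    // flatten() yields a &mut Entry for each Some, skipping the Nones.
    for entry in records.iter_mut().flatten() {
        if entry.dirty {
            entry.dirty = false; // stand-in for store.persist(&entry.shard)?
        }
    }
    // div_ceil rounds up: 10 items in batches of 3 -> 4 batches.
    assert_eq!(10_usize.div_ceil(3), 4);
}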
View File

@@ -481,6 +481,7 @@ impl ParagraphShardStore {
}
impl ParagraphShard {
+    #[allow(clippy::too_many_arguments)]
    pub fn new(
        paragraph: &ConvertedParagraph,
        shard_path: String,
@@ -674,10 +675,8 @@ async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
        let slice = &batches[start..group_end];
        let mut query = db.client.query("BEGIN TRANSACTION;");
-        let mut bind_index = 0usize;
-        for batch in slice {
+        for (bind_index, batch) in slice.iter().enumerate() {
            let name = format!("{prefix}{bind_index}");
-            bind_index += 1;
            query = query
                .query(format!("{} ${};", statement.as_ref(), name))
                .bind((name, batch.items.clone()));
@@ -702,7 +701,7 @@ async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
    let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
-    let result = (|| async {
+    let result = async {
        execute_batched_inserts(
            db,
            format!("INSERT INTO {}", TextContent::table_name()),
@@ -752,7 +751,7 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
        .await?;
        Ok(())
-    })()
+    }
    .await;
    if result.is_err() {
@@ -778,7 +777,6 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
#[cfg(test)]
mod tests {
    use super::*;
-    use crate::db_helpers::change_embedding_length_in_hnsw_indexes;
    use chrono::Utc;
    use common::storage::types::knowledge_entity::KnowledgeEntityType;
    use uuid::Uuid;
@@ -905,9 +903,6 @@ mod tests {
        db.apply_migrations()
            .await
            .expect("apply migrations for memory db");
-        change_embedding_length_in_hnsw_indexes(&db, 3)
-            .await
-            .expect("set embedding index dimension for test");
        let manifest = build_manifest();
        seed_manifest_into_db(&db, &manifest)

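Two mechanical cleanups in this file: `enumerate` supplies the bind counter that was previously incremented by hand, and a plain `async { ... }.await` replaces the immediately-invoked async closure `(|| async { ... })().await`. A compile-ready sketch of both (it needs an async runtime such as tokio to actually call):

async fn bind_batches(items: &[&str]) -> Result<usize, String> {
    // The async block is awaited directly; an early `return` exits the block.
    let result = async {
        let mut bound = 0;
        // enumerate() replaces `let mut bind_index = 0; ... bind_index += 1;`.
        for (bind_index, item) in items.iter().enumerate() {
            let name = format!("b{bind_index}");
            if item.is_empty() {
                return Err(format!("nothing to bind for {name}"));
            }
            bound += 1;
        }
        Ok(bound)
    }
    .await;
    result
}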
View File

@@ -245,8 +245,9 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
    catalog.dataset(kind.id())
}
-#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum DatasetKind {
+    #[default]
    SquadV2,
    NaturalQuestions,
    Beir,
@@ -368,12 +369,6 @@ impl std::fmt::Display for DatasetKind {
    }
}
-impl Default for DatasetKind {
-    fn default() -> Self {
-        Self::SquadV2
-    }
-}
impl FromStr for DatasetKind {
    type Err = anyhow::Error;

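Same derive-based default as in args, and worth keeping because clap's `default_value_t` can reuse it, so the CLI default never drifts from `DatasetKind::default()`. A hedged sketch assuming clap 4 with the derive feature; the kebab-case display strings are illustrative, not copied from the repo (`default_value_t` requires `Display`, which this crate implements elsewhere):

use clap::{Parser, ValueEnum};

#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
enum DatasetKind {
    #[default]
    SquadV2,
    NaturalQuestions,
}

impl std::fmt::Display for DatasetKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::SquadV2 => "squad-v2",
            Self::NaturalQuestions => "natural-questions",
        })
    }
}

#[derive(Parser)]
struct Args {
    // The derived Default picks the CLI default too.
    #[arg(long, value_enum, default_value_t = DatasetKind::default())]
    dataset: DatasetKind,
}

fn main() {
    println!("{}", Args::parse().dataset);
}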
View File

@@ -36,13 +36,14 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
    Ok(())
}
-// Test helper to force index dimension change
-pub async fn change_embedding_length_in_hnsw_indexes(
-    db: &SurrealDbClient,
-    dimension: usize,
-) -> Result<()> {
-    recreate_indexes(db, dimension).await
-}
+// // Test helper to force index dimension change
+// #[allow(dead_code)]
+// pub async fn change_embedding_length_in_hnsw_indexes(
+//     db: &SurrealDbClient,
+//     dimension: usize,
+// ) -> Result<()> {
+//     recreate_indexes(db, dimension).await
+// }
#[cfg(test)]
mod tests {

View File

@@ -86,6 +86,7 @@ pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
}
/// Determine if we can reuse an existing namespace based on cached state.
+#[allow(clippy::too_many_arguments)]
pub(crate) async fn can_reuse_namespace(
    db: &SurrealDbClient,
    descriptor: &snapshot::Descriptor,
@@ -213,7 +214,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
        timezone: "UTC".to_string(),
    };
-    if let Some(existing) = db.get_item::<User>(&user.get_id()).await? {
+    if let Some(existing) = db.get_item::<User>(user.get_id()).await? {
        return Ok(existing);
    }

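The dropped `&` in `db.get_item::<User>(user.get_id())` is clippy's `needless_borrow`: if the method already hands back a reference, another `&` only builds a double reference that deref coercion unwinds again. A sketch assuming `get_id` returns `&str` (the repo's actual return type may differ):

struct User {
    id: String,
}

impl User {
    fn get_id(&self) -> &str {
        &self.id
    }
}

fn lookup(id: &str) -> bool {
    !id.is_empty()
}

fn main() {
    let user = User { id: "eval-user".into() };
    // `lookup(&user.get_id())` would pass a &&str; the extra & is noise.
    assert!(lookup(user.get_id()));
}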
View File

@@ -154,7 +154,7 @@ impl<'a> EvaluationContext<'a> {
    }
    pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) {
-        let elapsed = duration.as_millis() as u128;
+        let elapsed = duration.as_millis();
        match stage {
            EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed,
            EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed,

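`Duration::as_millis` is declared `pub const fn as_millis(&self) -> u128`, so the trailing `as u128` here (and in the later prepare-corpus, prepare-namespace, and run-queries hunks) was a same-type cast that clippy flags as `unnecessary_cast`. For the record:

use std::time::Duration;

fn main() {
    // Already u128; `as u128` would be a no-op.
    let elapsed: u128 = Duration::from_secs(2).as_millis();
    assert_eq!(elapsed, 2000);
}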
View File

@@ -21,9 +21,7 @@ pub async fn run_evaluation(
    let machine = stages::prepare_namespace(machine, &mut ctx).await?;
    let machine = stages::run_queries(machine, &mut ctx).await?;
    let machine = stages::summarize(machine, &mut ctx).await?;
-    let machine = stages::finalize(machine, &mut ctx).await?;
-    drop(machine);
+    let _ = stages::finalize(machine, &mut ctx).await?;
    Ok(ctx.into_summary())
}

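`let _ = expr` drops the result as soon as the statement ends, so it says in one line what bind-then-`drop` said in two; the `?` still propagates errors first. A sketch:

struct Machine;

fn finalize() -> Result<Machine, String> {
    Ok(Machine)
}

fn main() -> Result<(), String> {
    // Was: let machine = finalize()?; drop(machine);
    let _ = finalize()?;
    Ok(())
}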
View File

@@ -113,7 +113,7 @@ pub(crate) async fn prepare_corpus(
        .metadata
        .ingestion_fingerprint
        .clone();
-    let ingestion_duration_ms = ingestion_timer.elapsed().as_millis() as u128;
+    let ingestion_duration_ms = ingestion_timer.elapsed().as_millis();
    info!(
        cache = %corpus_handle.path.display(),
        reused_ingestion = corpus_handle.reused_ingestion,

View File

@@ -119,7 +119,7 @@ pub(crate) async fn prepare_namespace(
    corpus::seed_manifest_into_db(ctx.db(), &manifest_for_seed)
        .await
        .context("seeding ingestion corpus from manifest")?;
-    namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
+    namespace_seed_ms = Some(seed_start.elapsed().as_millis());
    // Recreate indexes AFTER data is loaded (correct bulk loading pattern)
    if indexes_disabled {

View File

@@ -50,8 +50,10 @@ pub(crate) async fn run_queries(
        None
    };
-    let mut retrieval_config = RetrievalConfig::default();
-    retrieval_config.strategy = config.retrieval.strategy;
+    let mut retrieval_config = RetrievalConfig {
+        strategy: config.retrieval.strategy,
+        ..Default::default()
+    };
    retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
    if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
        retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
@@ -213,7 +215,7 @@ pub(crate) async fn run_queries(
            .with_context(|| format!("running pipeline for question {}", question_id))?;
        (outcome.results, None, outcome.stage_timings)
    };
-    let query_latency = query_start.elapsed().as_millis() as u128;
+    let query_latency = query_start.elapsed().as_millis();
    let candidates = adapt_strategy_output(result_output);
    let mut retrieved = Vec::new();

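Unlike the `AppConfig` hunk earlier, this one keeps the `mut` binding: construction moves to struct-update syntax, but the nested `tuning` fields are still adjusted afterwards. A sketch with stand-in types (`strategy` is simplified to a `u8`):

#[derive(Default)]
struct Tuning {
    rerank_keep_top: usize,
    fallback_min_results: usize,
}

#[derive(Default)]
struct RetrievalConfig {
    strategy: u8, // stand-in for the real RetrievalStrategy enum
    tuning: Tuning,
}

fn main() {
    // `mut` stays because nested fields are tweaked after construction.
    let mut cfg = RetrievalConfig {
        strategy: 1,
        ..Default::default()
    };
    cfg.tuning.rerank_keep_top = 10;
    if cfg.tuning.fallback_min_results < cfg.tuning.rerank_keep_top {
        cfg.tuning.fallback_min_results = cfg.tuning.rerank_keep_top;
    }
    println!("strategy {} keeps top {}", cfg.strategy, cfg.tuning.rerank_keep_top);
}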
View File

@@ -436,8 +436,8 @@ pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a
    select_window(resolved, 0, None)
}
-fn load_explicit_slice<'a>(
-    dataset: &'a ConvertedDataset,
+fn load_explicit_slice(
+    dataset: &ConvertedDataset,
    index: &DatasetIndex,
    config: &SliceConfig<'_>,
    slice_arg: &str,
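
The last hunk is clippy's `needless_lifetimes`: once the return type no longer borrows for `'a`, elision covers the parameters and the named lifetime can go. A generic sketch (the real function's return type is truncated in the hunk, so this is illustrative only):

struct Dataset {
    name: String,
}

// Before: fn describe<'a>(dataset: &'a Dataset) -> String
// After: the output owns its data, so no named lifetime is needed.
fn describe(dataset: &Dataset) -> String {
    format!("slice of {}", dataset.name)
}

fn main() {
    let d = Dataset { name: "squad".into() };
    println!("{}", describe(&d));
}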