use std::{ collections::{HashMap, HashSet}, fmt::Write, fs, path::{Path, PathBuf}, }; use anyhow::{Context, Result, anyhow}; use chrono::{DateTime, Utc}; use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tracing::{info, warn}; use crate::{ args::Config, datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind}, }; mod beir; mod build; use build::{BuildParams, mix_seed}; const SLICE_VERSION: u32 = 2; pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0; #[derive(Debug, Clone)] pub struct SliceConfig<'a> { pub cache_dir: &'a Path, pub force_convert: bool, pub explicit_slice: Option<&'a str>, pub limit: Option, pub corpus_limit: Option, pub slice_seed: u64, pub llm_mode: bool, pub negative_multiplier: f32, pub require_verified_chunks: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SliceManifest { pub version: u32, pub slice_id: String, pub dataset_id: String, pub dataset_label: String, pub dataset_source: String, pub includes_unanswerable: bool, #[serde(default = "default_require_verified_chunks")] pub require_verified_chunks: bool, pub seed: u64, pub requested_limit: Option, pub requested_corpus: usize, pub generated_at: DateTime, pub case_count: usize, pub positive_paragraphs: usize, pub negative_paragraphs: usize, pub total_paragraphs: usize, pub negative_multiplier: f32, pub cases: Vec, pub paragraphs: Vec, } fn default_require_verified_chunks() -> bool { false } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SliceCaseEntry { pub question_id: String, pub paragraph_id: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SliceParagraphEntry { pub id: String, pub kind: SliceParagraphKind, #[serde(default, skip_serializing_if = "Option::is_none")] pub shard_path: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "kind", rename_all = "snake_case")] pub enum SliceParagraphKind { Positive { question_ids: Vec }, Negative, } pub fn paragraph_storage_key(paragraph_id: &str) -> String { sanitize_identifier(paragraph_id) } pub(crate) fn default_shard_path(paragraph_id: &str) -> String { let sanitized = paragraph_storage_key(paragraph_id); format!("paragraphs/{sanitized}.json") } fn sanitize_identifier(input: &str) -> String { let mut sanitized = String::with_capacity(input.len()); for ch in input.chars() { if ch.is_ascii_alphanumeric() { sanitized.push(ch); } else { sanitized.push('-'); } } let trimmed = sanitized.trim_matches('-').to_string(); if trimmed.is_empty() { let mut hasher = Sha256::new(); hasher.update(input.as_bytes()); let digest = hasher.finalize(); digest .iter() .take(6) .fold(String::with_capacity(12), |mut s, b| { let _ = write!(s, "{b:02x}"); s }) } else { trimmed } } #[derive(Debug, Clone)] pub struct ResolvedSlice<'a> { pub manifest: SliceManifest, pub path: PathBuf, pub paragraphs: Vec<&'a ConvertedParagraph>, pub cases: Vec>, } #[derive(Debug, Clone)] pub struct SliceWindow<'a> { pub offset: usize, pub length: usize, pub total_cases: usize, pub cases: Vec>, positive_paragraph_ids: Vec, } impl SliceWindow<'_> { pub fn positive_ids(&self) -> impl Iterator { self.positive_paragraph_ids .iter() .map(std::string::String::as_str) } } #[derive(Debug, Clone)] pub struct CaseRef<'a> { pub paragraph: &'a ConvertedParagraph, pub question: &'a ConvertedQuestion, } struct DatasetIndex { paragraph_by_id: HashMap, question_by_id: HashMap, } impl DatasetIndex { fn build(dataset: &ConvertedDataset) -> Self { let mut paragraph_by_id = HashMap::new(); let mut question_by_id = HashMap::new(); for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() { paragraph_by_id.insert(paragraph.id.clone(), p_idx); for (q_idx, question) in paragraph.questions.iter().enumerate() { question_by_id.insert(question.id.clone(), (p_idx, q_idx)); } } Self { paragraph_by_id, question_by_id, } } fn paragraph<'a>( &self, dataset: &'a ConvertedDataset, id: &str, ) -> Result<&'a ConvertedParagraph> { let idx = self .paragraph_by_id .get(id) .ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?; dataset .paragraphs .get(*idx) .ok_or_else(|| anyhow!("paragraph index out of bounds")) } fn question<'a>( &self, dataset: &'a ConvertedDataset, question_id: &str, ) -> Result<(&'a ConvertedParagraph, &'a ConvertedQuestion)> { let (p_idx, q_idx) = self .question_by_id .get(question_id) .ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?; let paragraph = dataset .paragraphs .get(*p_idx) .ok_or_else(|| anyhow!("paragraph index out of bounds for question '{question_id}'"))?; let question = paragraph .questions .get(*q_idx) .ok_or_else(|| anyhow!("slice maps question '{question_id}' to missing index"))?; Ok((paragraph, question)) } } #[derive(Debug, Serialize)] struct SliceKey<'a> { dataset_id: &'a str, includes_unanswerable: bool, require_verified_chunks: bool, requested_corpus: usize, seed: u64, } #[allow(clippy::too_many_lines)] pub fn resolve_slice<'a>( dataset: &'a ConvertedDataset, config: &SliceConfig<'_>, ) -> Result> { let index = DatasetIndex::build(dataset); if let Some(slice_arg) = config.explicit_slice { let path = explicit_slice_path(dataset, config, slice_arg); if path.exists() { let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?; let resolved = manifest_to_resolved(dataset, &index, manifest, path)?; info!( slice = %resolved.manifest.slice_id, path = %resolved.path.display(), cases = resolved.manifest.case_count, positives = resolved.manifest.positive_paragraphs, negatives = resolved.manifest.negative_paragraphs, "Using explicitly selected slice" ); return Ok(resolved); } let resolved = materialize_slice_ledger(dataset, config, &index, slice_arg, path)?; info!( slice = %resolved.manifest.slice_id, path = %resolved.path.display(), cases = resolved.manifest.case_count, positives = resolved.manifest.positive_paragraphs, negatives = resolved.manifest.negative_paragraphs, "Built catalog slice ledger" ); return Ok(resolved); } let requested_corpus = config .corpus_limit .unwrap_or(dataset.paragraphs.len()) .min(dataset.paragraphs.len()) .max(1); let key = SliceKey { dataset_id: dataset.metadata.id.as_str(), includes_unanswerable: config.llm_mode, require_verified_chunks: config.require_verified_chunks, requested_corpus, seed: config.slice_seed, }; let slice_id = compute_slice_id(&key)?; let base = config .cache_dir .join("slices") .join(dataset.metadata.id.as_str()); let path = base.join(format!("{slice_id}.json")); materialize_slice_ledger(dataset, config, &index, &slice_id, path) } #[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)] pub fn select_window<'a>( resolved: &'a ResolvedSlice<'a>, offset: usize, limit: Option, ) -> Result> { let total = resolved.manifest.case_count; if total == 0 { return Err(anyhow!( "slice '{}' contains no cases", resolved.manifest.slice_id )); } if offset >= total { return Err(anyhow!( "slice offset {offset} exceeds available cases ({total})", )); } let available = total - offset; let requested = limit.unwrap_or(available).max(1); let length = requested.min(available); let cases = resolved.cases[offset..offset + length].to_vec(); let mut seen = HashSet::new(); let mut positive_ids = Vec::new(); for case in &cases { if seen.insert(case.paragraph.id.as_str()) { positive_ids.push(case.paragraph.id.clone()); } } Ok(SliceWindow { offset, length, total_cases: total, cases, positive_paragraph_ids: positive_ids, }) } #[allow(dead_code)] pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result> { select_window(resolved, 0, None) } fn explicit_slice_path( dataset: &ConvertedDataset, config: &SliceConfig<'_>, slice_arg: &str, ) -> PathBuf { let explicit_path = Path::new(slice_arg); if explicit_path.exists() { explicit_path.to_path_buf() } else { config .cache_dir .join("slices") .join(dataset.metadata.id.as_str()) .join(format!("{slice_arg}.json")) } } #[allow(clippy::too_many_lines)] fn materialize_slice_ledger<'a>( dataset: &'a ConvertedDataset, config: &SliceConfig<'_>, index: &DatasetIndex, slice_id: &str, path: PathBuf, ) -> Result> { let requested_corpus = config .corpus_limit .unwrap_or(dataset.paragraphs.len()) .min(dataset.paragraphs.len()) .max(1); let total_questions = dataset .paragraphs .iter() .map(|p| p.questions.len()) .sum::() .max(1); let requested_limit = config .limit .unwrap_or(total_questions) .min(total_questions) .max(1); let mut manifest = if !config.force_convert && path.exists() { match read_manifest(&path) { Ok(manifest) if manifest.dataset_id == dataset.metadata.id => { if manifest.includes_unanswerable != config.llm_mode { warn!( slice = manifest.slice_id, path = %path.display(), expected = config.llm_mode, found = manifest.includes_unanswerable, "Slice manifest includes_unanswerable mismatch; regenerating" ); None } else if manifest.require_verified_chunks != config.require_verified_chunks { warn!( slice = manifest.slice_id, path = %path.display(), expected = config.require_verified_chunks, found = manifest.require_verified_chunks, "Slice manifest verified-chunk requirement mismatch; regenerating" ); None } else { Some(manifest) } } Ok(manifest) => { warn!( slice = manifest.slice_id, path = %path.display(), loaded_dataset = %manifest.dataset_id, expected = %dataset.metadata.id, "Slice manifest targets different dataset; regenerating" ); None } Err(err) => { warn!( path = %path.display(), error = %err, "Failed to read cached slice; regenerating" ); None } } } else { None }; let params = BuildParams { include_impossible: config.llm_mode, base_seed: config.slice_seed, rng_seed: mix_seed(dataset.metadata.id.as_str(), config.slice_seed), }; if manifest .as_ref() .is_some_and(|manifest| manifest.version != SLICE_VERSION) { warn!( slice = manifest.as_ref().map_or("unknown", |m| m.slice_id.as_str()), found = manifest.as_ref().map_or(0, |m| m.version), expected = SLICE_VERSION, "Slice manifest version mismatch; regenerating" ); manifest = None; } let mut manifest = manifest.unwrap_or_else(|| { empty_manifest( dataset, slice_id.to_string(), ¶ms, requested_corpus, config.negative_multiplier, config.require_verified_chunks, config.limit, ) }); manifest.requested_limit = config.limit; manifest.requested_corpus = requested_corpus; manifest.negative_multiplier = config.negative_multiplier; manifest.includes_unanswerable = config.llm_mode; manifest.require_verified_chunks = config.require_verified_chunks; let mut changed = ensure_shard_paths(&mut manifest); changed |= ensure_case_capacity(dataset, &mut manifest, ¶ms, requested_limit)?; refresh_manifest_stats(&mut manifest); let desired_negatives = desired_negative_target( manifest.positive_paragraphs, requested_corpus, dataset.paragraphs.len(), config.negative_multiplier, ); changed |= ensure_negative_pool( dataset, &mut manifest, ¶ms, desired_negatives, requested_corpus, ); refresh_manifest_stats(&mut manifest); if changed { manifest.generated_at = Utc::now(); write_manifest(&path, &manifest)?; info!( slice = %manifest.slice_id, path = %path.display(), cases = manifest.case_count, positives = manifest.positive_paragraphs, negatives = manifest.negative_paragraphs, "Updated dataset slice ledger" ); } else { info!( slice = %manifest.slice_id, path = %path.display(), cases = manifest.case_count, positives = manifest.positive_paragraphs, negatives = manifest.negative_paragraphs, "Reusing cached slice ledger" ); } manifest_to_resolved(dataset, index, manifest, path) } fn load_explicit_slice( dataset: &ConvertedDataset, index: &DatasetIndex, config: &SliceConfig<'_>, slice_arg: &str, ) -> Result<(PathBuf, SliceManifest)> { let candidate_path = explicit_slice_path(dataset, config, slice_arg); let manifest = read_manifest(&candidate_path) .with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?; if manifest.dataset_id != dataset.metadata.id { return Err(anyhow!( "slice '{}' targets dataset '{}', but '{}' is loaded", manifest.slice_id, manifest.dataset_id, dataset.metadata.id )); } if manifest.includes_unanswerable != config.llm_mode { return Err(anyhow!( "slice '{}' includes_unanswerable mismatch (expected {}, found {})", manifest.slice_id, config.llm_mode, manifest.includes_unanswerable )); } if manifest.require_verified_chunks != config.require_verified_chunks { return Err(anyhow!( "slice '{}' verified-chunk requirement mismatch (expected {}, found {})", manifest.slice_id, config.require_verified_chunks, manifest.require_verified_chunks )); } // Validate the manifest before returning. manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?; Ok((candidate_path, manifest)) } fn empty_manifest( dataset: &ConvertedDataset, slice_id: String, params: &BuildParams, requested_corpus: usize, negative_multiplier: f32, require_verified_chunks: bool, requested_limit: Option, ) -> SliceManifest { SliceManifest { version: SLICE_VERSION, slice_id, dataset_id: dataset.metadata.id.clone(), dataset_label: dataset.metadata.label.clone(), dataset_source: dataset.source.clone(), includes_unanswerable: params.include_impossible, require_verified_chunks, seed: params.base_seed, requested_limit, requested_corpus, negative_multiplier, generated_at: Utc::now(), case_count: 0, positive_paragraphs: 0, negative_paragraphs: 0, total_paragraphs: 0, cases: Vec::new(), paragraphs: Vec::new(), } } #[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)] fn ensure_case_capacity( dataset: &ConvertedDataset, manifest: &mut SliceManifest, params: &BuildParams, target_cases: usize, ) -> Result { if manifest.case_count >= target_cases { return Ok(false); } let question_refs = ordered_question_refs(dataset, params, target_cases)?; let mut existing_questions: HashSet = manifest .cases .iter() .map(|case| case.question_id.clone()) .collect(); let mut paragraph_positions: HashMap = manifest .paragraphs .iter() .enumerate() .map(|(idx, entry)| (entry.id.clone(), idx)) .collect(); let mut changed = false; for (p_idx, q_idx) in question_refs { if manifest.case_count >= target_cases { break; } let paragraph = &dataset.paragraphs[p_idx]; let question = ¶graph.questions[q_idx]; if !existing_questions.insert(question.id.clone()) { continue; } if let Some(idx) = paragraph_positions.get(paragraph.id.as_str()).copied() { match &mut manifest.paragraphs[idx].kind { SliceParagraphKind::Positive { question_ids } => { if !question_ids.contains(&question.id) { question_ids.push(question.id.clone()); } } SliceParagraphKind::Negative => { manifest.paragraphs[idx].kind = SliceParagraphKind::Positive { question_ids: vec![question.id.clone()], }; } } } else { manifest.paragraphs.push(SliceParagraphEntry { id: paragraph.id.clone(), kind: SliceParagraphKind::Positive { question_ids: vec![question.id.clone()], }, shard_path: Some(default_shard_path(¶graph.id)), }); let idx = manifest.paragraphs.len() - 1; paragraph_positions.insert(paragraph.id.clone(), idx); } manifest.cases.push(SliceCaseEntry { question_id: question.id.clone(), paragraph_id: paragraph.id.clone(), }); manifest.case_count += 1; changed = true; } if manifest.case_count < target_cases { return Err(anyhow!( "only {}/{} eligible questions available for dataset {}", manifest.case_count, target_cases, dataset.metadata.id )); } Ok(changed) } fn ordered_question_refs( dataset: &ConvertedDataset, params: &BuildParams, target_cases: usize, ) -> Result> { if dataset.metadata.id == DatasetKind::Beir.id() { return beir::ordered_question_refs_beir(dataset, params, target_cases); } let mut question_refs = Vec::new(); for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() { for (q_idx, question) in paragraph.questions.iter().enumerate() { let include = if params.include_impossible { true } else { !question.is_impossible && !question.answers.is_empty() }; if include { question_refs.push((p_idx, q_idx)); } } } if question_refs.is_empty() { return Err(anyhow!( "no eligible questions found for dataset {}; cannot build slice", dataset.metadata.id )); } let mut rng = StdRng::seed_from_u64(params.rng_seed); question_refs.shuffle(&mut rng); Ok(question_refs) } #[allow(clippy::indexing_slicing)] fn ensure_negative_pool( dataset: &ConvertedDataset, manifest: &mut SliceManifest, params: &BuildParams, target_negatives: usize, requested_corpus: usize, ) -> bool { let current_negatives = manifest .paragraphs .iter() .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative)) .count(); if current_negatives >= target_negatives { return false; } let positive_ids: HashSet = manifest .paragraphs .iter() .filter_map(|entry| match entry.kind { SliceParagraphKind::Positive { .. } => Some(entry.id.clone()), SliceParagraphKind::Negative => None, }) .collect(); let mut negative_ids: HashSet = manifest .paragraphs .iter() .filter_map(|entry| match entry.kind { SliceParagraphKind::Negative => Some(entry.id.clone()), SliceParagraphKind::Positive { .. } => None, }) .collect(); let negative_seed = mix_seed( &format!("{}::negatives", dataset.metadata.id), params.base_seed, ); let candidates = ordered_negative_indices(dataset, &positive_ids, negative_seed); let mut added = false; for idx in candidates { if negative_ids.len() >= target_negatives { break; } let paragraph = &dataset.paragraphs[idx]; if negative_ids.contains(paragraph.id.as_str()) || positive_ids.contains(paragraph.id.as_str()) { continue; } manifest.paragraphs.push(SliceParagraphEntry { id: paragraph.id.clone(), kind: SliceParagraphKind::Negative, shard_path: Some(default_shard_path(¶graph.id)), }); negative_ids.insert(paragraph.id.clone()); added = true; } if negative_ids.len() < target_negatives { warn!( dataset = %dataset.metadata.id, desired = target_negatives, available = negative_ids.len(), requested_corpus, "Insufficient negative paragraphs to satisfy multiplier" ); } added } fn ordered_negative_indices( dataset: &ConvertedDataset, positive_ids: &HashSet, rng_seed: u64, ) -> Vec { let mut candidates: Vec = dataset .paragraphs .iter() .enumerate() .filter_map(|(idx, paragraph)| { if positive_ids.contains(paragraph.id.as_str()) { None } else { Some(idx) } }) .collect(); let mut rng = StdRng::seed_from_u64(rng_seed); candidates.shuffle(&mut rng); candidates } fn refresh_manifest_stats(manifest: &mut SliceManifest) { manifest.case_count = manifest.cases.len(); manifest.positive_paragraphs = manifest .paragraphs .iter() .filter(|entry| matches!(entry.kind, SliceParagraphKind::Positive { .. })) .count(); manifest.negative_paragraphs = manifest .paragraphs .iter() .filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative)) .count(); manifest.total_paragraphs = manifest.paragraphs.len(); } fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool { let mut changed = false; for entry in &mut manifest.paragraphs { if entry.shard_path.is_none() { entry.shard_path = Some(default_shard_path(&entry.id)); changed = true; } } changed } #[allow( clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::cast_sign_loss )] fn desired_negative_target( positive_count: usize, requested_corpus: usize, dataset_paragraphs: usize, multiplier: f32, ) -> usize { if positive_count == 0 { return 0; } let ratio = multiplier.max(0.0); let mut desired = ((positive_count as f32) * ratio).ceil() as usize; let max_total = requested_corpus.min(dataset_paragraphs).max(positive_count); let max_negatives = max_total.saturating_sub(positive_count); desired = desired.min(max_negatives); desired } fn manifest_to_resolved<'a>( dataset: &'a ConvertedDataset, index: &DatasetIndex, manifest: SliceManifest, path: PathBuf, ) -> Result> { if manifest.version != SLICE_VERSION { return Err(anyhow!( "slice version {} does not match expected {}", manifest.version, SLICE_VERSION )); } let mut paragraphs = Vec::with_capacity(manifest.paragraphs.len()); for entry in &manifest.paragraphs { let paragraph = index.paragraph(dataset, &entry.id)?; if let SliceParagraphKind::Positive { question_ids } = &entry.kind { for question_id in question_ids { let (linked_paragraph, _) = index.question(dataset, question_id)?; if linked_paragraph.id != entry.id { return Err(anyhow!( "slice question '{}' expected paragraph '{}', found '{}'", question_id, entry.id, linked_paragraph.id )); } } } paragraphs.push(paragraph); } let mut cases = Vec::with_capacity(manifest.cases.len()); for entry in &manifest.cases { let (paragraph, question) = index.question(dataset, &entry.question_id)?; if paragraph.id != entry.paragraph_id { return Err(anyhow!( "slice case '{}' expected paragraph '{}', found '{}'", entry.question_id, entry.paragraph_id, paragraph.id )); } cases.push(CaseRef { paragraph, question, }); } if cases.is_empty() { return Err(anyhow!( "slice '{}' contains no cases after validation", manifest.slice_id )); } Ok(ResolvedSlice { manifest, path, paragraphs, cases, }) } fn compute_slice_id(key: &SliceKey<'_>) -> Result { let payload = serde_json::to_vec(key).context("SliceKey serialisation failed")?; let mut hasher = Sha256::new(); hasher.update(payload); let digest = hasher.finalize(); Ok(digest .iter() .take(16) .fold(String::with_capacity(32), |mut s, b| { let _ = write!(s, "{b:02x}"); s })) } pub fn read_manifest_if_exists(path: &Path) -> Result> { if !path.exists() { return Ok(None); } read_manifest(path).map(Some) } pub fn cached_manifest_path(config: &crate::args::Config) -> Option { let slice_arg = config.slice.as_deref()?; let explicit_path = Path::new(slice_arg); if explicit_path.exists() { return Some(explicit_path.to_path_buf()); } Some( config .cache_dir .join("slices") .join(config.dataset.id()) .join(format!("{slice_arg}.json")), ) } pub fn manifest_is_complete(manifest: &SliceManifest, config: &SliceConfig<'_>) -> bool { let requested_limit = config.limit.unwrap_or(manifest.case_count.max(1)).max(1); if manifest.case_count < requested_limit { return false; } let requested_corpus = config .corpus_limit .unwrap_or(manifest.total_paragraphs.max(1)) .max(1); let desired_negatives = desired_negative_target( manifest.positive_paragraphs, requested_corpus, manifest .total_paragraphs .max(manifest.positive_paragraphs.max(1)), config.negative_multiplier, ); manifest.negative_paragraphs >= desired_negatives } fn read_manifest(path: &Path) -> Result { let raw = fs::read_to_string(path) .with_context(|| format!("reading slice manifest {}", path.display()))?; let manifest: SliceManifest = serde_json::from_str(&raw) .with_context(|| format!("parsing slice manifest {}", path.display()))?; Ok(manifest) } fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> { if let Some(parent) = path.parent() { fs::create_dir_all(parent) .with_context(|| format!("creating slice directory {}", parent.display()))?; } let json = serde_json::to_vec_pretty(manifest).context("serialising slice manifest to JSON")?; fs::write(path, json).with_context(|| format!("writing slice manifest {}", path.display()))?; Ok(()) } pub fn ledger_target(config: &Config) -> Option { match (config.slice_grow, config.limit) { (Some(grow), Some(limit)) => Some(limit.max(grow)), (Some(grow), None) => Some(grow), (None, limit) => limit, } } /// Grow the slice ledger to contain the target number of cases. pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> { let ledger_limit = ledger_target(config); let slice_settings = slice_config_with_limit(config, ledger_limit); let slice = resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?; info!( slice = slice.manifest.slice_id.as_str(), cases = slice.manifest.case_count, positives = slice.manifest.positive_paragraphs, negatives = slice.manifest.negative_paragraphs, total_paragraphs = slice.manifest.total_paragraphs, "Slice ledger ready" ); println!( "Slice `{}` now contains {} questions ({} positives, {} negatives)", slice.manifest.slice_id, slice.manifest.case_count, slice.manifest.positive_paragraphs, slice.manifest.negative_paragraphs ); Ok(()) } pub fn slice_config_with_limit(config: &Config, limit_override: Option) -> SliceConfig<'_> { SliceConfig { cache_dir: config.cache_dir.as_path(), force_convert: config.force_convert, explicit_slice: config.slice.as_deref(), limit: limit_override.or(config.limit), corpus_limit: config.corpus_limit, slice_seed: config.slice_seed, llm_mode: config.llm_mode, negative_multiplier: config.negative_multiplier, require_verified_chunks: config.retrieval.require_verified_chunks, } } #[cfg(test)] mod tests { use super::*; use crate::datasets::{ ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, DatasetMetadata, }; use tempfile::tempdir; fn sample_dataset() -> ConvertedDataset { let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false); ConvertedDataset { generated_at: Utc::now(), metadata, source: "test-source".to_string(), paragraphs: vec![ ConvertedParagraph { id: "p1".to_string(), title: "Alpha".to_string(), context: "Alpha context".to_string(), questions: vec![ConvertedQuestion { id: "q1".to_string(), question: "What is alpha?".to_string(), answers: vec!["Alpha".to_string()], is_impossible: false, }], }, ConvertedParagraph { id: "p2".to_string(), title: "Beta".to_string(), context: "Beta context".to_string(), questions: vec![ConvertedQuestion { id: "q2".to_string(), question: "What is beta?".to_string(), answers: vec!["Beta".to_string()], is_impossible: false, }], }, ConvertedParagraph { id: "p3".to_string(), title: "Gamma".to_string(), context: "Gamma context".to_string(), questions: vec![ConvertedQuestion { id: "q3".to_string(), question: "What is gamma?".to_string(), answers: vec!["Gamma".to_string()], is_impossible: false, }], }, ], } } #[test] fn resolve_slice_reuses_cached_manifest() -> Result<()> { let dataset = sample_dataset(); let temp = tempdir().context("creating temp directory")?; let mut config = SliceConfig { cache_dir: temp.path(), force_convert: false, explicit_slice: None, limit: Some(2), corpus_limit: Some(3), slice_seed: 0x5eed_2025, llm_mode: false, negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER, require_verified_chunks: true, }; let first = resolve_slice(&dataset, &config)?; assert!(first.path.exists()); let initial_generated = first.manifest.generated_at; let second = resolve_slice(&dataset, &config)?; assert_eq!(first.manifest.slice_id, second.manifest.slice_id); assert_eq!(initial_generated, second.manifest.generated_at); config.force_convert = true; let third = resolve_slice(&dataset, &config)?; assert_eq!(first.manifest.slice_id, third.manifest.slice_id); assert_ne!(third.manifest.generated_at, initial_generated); Ok(()) } #[test] #[allow(clippy::indexing_slicing)] fn select_window_yields_expected_cases() -> Result<()> { let dataset = sample_dataset(); let temp = tempdir().context("creating temp directory")?; let config = SliceConfig { cache_dir: temp.path(), force_convert: false, explicit_slice: None, limit: Some(3), corpus_limit: Some(3), slice_seed: 0x5eed_2025, llm_mode: false, negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER, require_verified_chunks: true, }; let resolved = resolve_slice(&dataset, &config)?; let window = select_window(&resolved, 1, Some(1))?; assert_eq!(window.offset, 1); assert_eq!(window.length, 1); assert_eq!(window.total_cases, resolved.manifest.case_count); assert_eq!(window.cases.len(), 1); let positive_ids: Vec<&str> = window.positive_ids().collect(); assert_eq!(positive_ids.len(), 1); assert!( resolved .manifest .paragraphs .iter() .any(|entry| entry.id == positive_ids[0]) ); Ok(()) } #[test] #[allow(clippy::indexing_slicing)] fn beir_mix_balances_and_rebalances() -> Result<()> { let mut paragraphs = Vec::new(); let counts = [ ("fever", 1usize), ("fiqa", 2usize), ("hotpotqa", 1usize), ("nfcorpus", 0usize), ("quora", 3usize), ("trec-covid", 2usize), ]; for (prefix, count) in counts { for idx in 0..count { let q_id = format!("{prefix}-q{idx}"); paragraphs.push(ConvertedParagraph { id: format!("{prefix}-p{idx}"), title: format!("{prefix} title"), context: format!("{prefix} context {idx}"), questions: vec![ConvertedQuestion { id: q_id, question: format!("{prefix} question {idx}"), answers: vec!["answer".to_string()], is_impossible: false, }], }); } } let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false); let dataset = ConvertedDataset { generated_at: Utc::now(), metadata, source: "beir-mix".to_string(), paragraphs, }; let params = BuildParams { include_impossible: false, base_seed: 0xAA, rng_seed: 0xBB, }; let refs = beir::ordered_question_refs_beir(&dataset, ¶ms, 8)?; let mut per_prefix: HashMap = HashMap::new(); for (p_idx, q_idx) in refs { let question = &dataset.paragraphs[p_idx].questions[q_idx]; let prefix = beir::question_prefix(&question.id).unwrap_or("unknown"); *per_prefix.entry(prefix.to_string()).or_default() += 1; } assert_eq!(per_prefix.get("fever").copied().unwrap_or(0), 1); assert_eq!(per_prefix.get("fiqa").copied().unwrap_or(0), 2); assert_eq!(per_prefix.get("hotpotqa").copied().unwrap_or(0), 1); assert_eq!(per_prefix.get("nfcorpus").copied().unwrap_or(0), 0); assert_eq!(per_prefix.get("quora").copied().unwrap_or(0), 2); assert_eq!(per_prefix.get("trec-covid").copied().unwrap_or(0), 2); Ok(()) } }