Files
minne/evaluations/src/slice/mod.rs
T
2026-06-24 22:02:31 +02:00

1191 lines
37 KiB
Rust

use std::{
collections::{HashMap, HashSet},
fmt::Write,
fs,
path::{Path, PathBuf},
};
use anyhow::{Context, Result, anyhow};
use chrono::{DateTime, Utc};
use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tracing::{info, warn};
use crate::{
args::Config,
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
};
mod beir;
mod build;
use build::{BuildParams, mix_seed};
/// Schema version stamped into every slice manifest. Cached manifests carrying a
/// different version are regenerated by `materialize_slice_ledger`, and
/// `manifest_to_resolved` rejects them outright.
const SLICE_VERSION: u32 = 2;
/// Default number of negative paragraphs (paragraphs with no selected question)
/// to sample per positive paragraph. Used by the tests in this file; presumably
/// also the CLI default — confirm in `args`.
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
/// Settings that control how a dataset slice is resolved, built, or grown.
#[derive(Debug, Clone)]
pub struct SliceConfig<'a> {
/// Root cache directory; ledgers live under `<cache_dir>/slices/<dataset_id>/`.
pub cache_dir: &'a Path,
/// When true, a cached ledger at the target path is ignored and rebuilt.
pub force_convert: bool,
/// Slice id, or path to a manifest file, selected explicitly by the user.
pub explicit_slice: Option<&'a str>,
/// Requested number of cases (questions); `None` asks for as many as available.
pub limit: Option<usize>,
/// Caps the paragraph count used when sizing the negative pool (positives are
/// never dropped to honour it); clamped to the dataset size.
pub corpus_limit: Option<usize>,
/// Seed for the deterministic question/negative shuffles; also part of the
/// implicit slice id.
pub slice_seed: u64,
/// When true, impossible/unanswered questions are eligible as cases.
pub llm_mode: bool,
/// Target number of negative paragraphs per positive paragraph.
pub negative_multiplier: f32,
/// Recorded in the manifest and the implicit slice id; a cached ledger whose
/// flag differs is regenerated. Its effect on retrieval is defined elsewhere.
pub require_verified_chunks: bool,
}
/// On-disk ledger describing one dataset slice: the sampled cases plus the
/// positive and negative paragraphs that make up its corpus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceManifest {
/// Schema version; compared against `SLICE_VERSION`.
pub version: u32,
/// Identifier of the slice (hash-derived for implicit slices, user-chosen otherwise).
pub slice_id: String,
/// Dataset id, label and source copied from the dataset when the ledger was created.
pub dataset_id: String,
pub dataset_label: String,
pub dataset_source: String,
/// Whether impossible/unanswered questions were eligible (`SliceConfig::llm_mode`).
pub includes_unanswerable: bool,
// Manifests written before this field existed deserialize with `false`.
#[serde(default = "default_require_verified_chunks")]
pub require_verified_chunks: bool,
/// Base seed (`SliceConfig::slice_seed`) used for sampling.
pub seed: u64,
/// `SliceConfig::limit` at the last materialization.
pub requested_limit: Option<usize>,
/// Clamped corpus size at the last materialization.
pub requested_corpus: usize,
/// Timestamp of the last change that was written to disk.
pub generated_at: DateTime<Utc>,
/// Derived counters, recomputed by `refresh_manifest_stats` whenever the ledger
/// is materialized (but read verbatim from explicitly loaded manifests).
pub case_count: usize,
pub positive_paragraphs: usize,
pub negative_paragraphs: usize,
pub total_paragraphs: usize,
/// Negatives-per-positive ratio used at the last materialization.
pub negative_multiplier: f32,
/// Cases in sampling order; windows are taken from this order.
pub cases: Vec<SliceCaseEntry>,
/// Positive and negative paragraphs in the slice corpus.
pub paragraphs: Vec<SliceParagraphEntry>,
}
/// Serde default for `SliceManifest::require_verified_chunks`.
fn default_require_verified_chunks() -> bool {
false
}
/// One evaluation case: a question id and the id of the paragraph that owns it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceCaseEntry {
pub question_id: String,
pub paragraph_id: String,
}
/// A paragraph included in the slice corpus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceParagraphEntry {
/// Paragraph id as it appears in the converted dataset.
pub id: String,
/// Whether the paragraph backs selected questions (positive) or is a negative.
pub kind: SliceParagraphKind,
/// Relative storage path (`paragraphs/<storage-key>.json` by default); omitted
/// from JSON when `None` and back-filled by `ensure_shard_paths`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub shard_path: Option<String>,
}
/// Role of a paragraph in the slice. Serialized internally tagged under a
/// `"kind"` key with snake_case variant names (`"positive"` / `"negative"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SliceParagraphKind {
/// Owns the listed selected questions.
Positive { question_ids: Vec<String> },
/// Sampled from paragraphs that own no selected question.
Negative,
}
/// Filesystem-safe key for storing a paragraph, derived from its id by
/// `sanitize_identifier`.
pub fn paragraph_storage_key(paragraph_id: &str) -> String {
    sanitize_identifier(paragraph_id)
}

/// Default relative shard location for a paragraph: `paragraphs/<key>.json`.
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
    let mut shard = String::from("paragraphs/");
    shard.push_str(&paragraph_storage_key(paragraph_id));
    shard.push_str(".json");
    shard
}
/// Maps an arbitrary id to a filesystem-safe token.
///
/// Every non-ASCII-alphanumeric character becomes `-`, and leading/trailing
/// dashes are trimmed. If nothing survives, the first 6 bytes of the SHA-256
/// of `input` are returned as 12 lowercase hex digits instead.
///
/// Note: distinct ids can collide (e.g. `a.b` and `a-b`).
fn sanitize_identifier(input: &str) -> String {
    let dashed: String = input
        .chars()
        .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
        .collect();
    let core = dashed.trim_matches('-');
    if !core.is_empty() {
        return core.to_string();
    }
    let digest = Sha256::digest(input.as_bytes());
    let mut hex = String::with_capacity(12);
    for byte in digest.iter().take(6) {
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
/// A manifest validated against a dataset, with its entries resolved to
/// borrowed dataset records.
#[derive(Debug, Clone)]
pub struct ResolvedSlice<'a> {
pub manifest: SliceManifest,
/// Location of the manifest file.
pub path: PathBuf,
/// Paragraphs in manifest order.
pub paragraphs: Vec<&'a ConvertedParagraph>,
/// Cases in manifest order.
pub cases: Vec<CaseRef<'a>>,
}
/// A contiguous run of cases selected from a resolved slice.
#[derive(Debug, Clone)]
pub struct SliceWindow<'a> {
/// Index of the first case in the window.
pub offset: usize,
/// Number of cases in the window.
pub length: usize,
/// Number of cases in the whole slice.
pub total_cases: usize,
pub cases: Vec<CaseRef<'a>>,
// Distinct paragraph ids of `cases`, in first-seen order.
positive_paragraph_ids: Vec<String>,
}
impl SliceWindow<'_> {
    /// Iterates over the distinct paragraph ids referenced by this window's
    /// cases, in the order they first appear.
    pub fn positive_ids(&self) -> impl Iterator<Item = &str> {
        self.positive_paragraph_ids.iter().map(String::as_str)
    }
}
/// A selected question together with the paragraph that owns it.
#[derive(Debug, Clone)]
pub struct CaseRef<'a> {
pub paragraph: &'a ConvertedParagraph,
pub question: &'a ConvertedQuestion,
}
/// Id → position lookup tables over a `ConvertedDataset`.
struct DatasetIndex {
// Paragraph id -> index into `dataset.paragraphs`.
paragraph_by_id: HashMap<String, usize>,
// Question id -> (paragraph index, question index within that paragraph).
question_by_id: HashMap<String, (usize, usize)>,
}
impl DatasetIndex {
    /// Indexes every paragraph and question of `dataset` by id. When ids repeat,
    /// the later occurrence wins.
    fn build(dataset: &ConvertedDataset) -> Self {
        let mut index = Self {
            paragraph_by_id: HashMap::new(),
            question_by_id: HashMap::new(),
        };
        for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
            index.paragraph_by_id.insert(paragraph.id.clone(), p_idx);
            index.question_by_id.extend(
                paragraph
                    .questions
                    .iter()
                    .enumerate()
                    .map(|(q_idx, question)| (question.id.clone(), (p_idx, q_idx))),
            );
        }
        index
    }

    /// Looks up the paragraph with id `id` in `dataset`.
    ///
    /// # Errors
    /// Fails if the id is unknown or the stored index is out of range.
    fn paragraph<'a>(
        &self,
        dataset: &'a ConvertedDataset,
        id: &str,
    ) -> Result<&'a ConvertedParagraph> {
        let Some(&position) = self.paragraph_by_id.get(id) else {
            return Err(anyhow!("slice references unknown paragraph '{id}'"));
        };
        dataset
            .paragraphs
            .get(position)
            .ok_or_else(|| anyhow!("paragraph index out of bounds"))
    }

    /// Looks up a question by id and returns it with its owning paragraph.
    ///
    /// # Errors
    /// Fails if the id is unknown or either stored index is out of range.
    fn question<'a>(
        &self,
        dataset: &'a ConvertedDataset,
        question_id: &str,
    ) -> Result<(&'a ConvertedParagraph, &'a ConvertedQuestion)> {
        let Some(&(p_pos, q_pos)) = self.question_by_id.get(question_id) else {
            return Err(anyhow!("slice references unknown question '{question_id}'"));
        };
        let Some(paragraph) = dataset.paragraphs.get(p_pos) else {
            return Err(anyhow!(
                "paragraph index out of bounds for question '{question_id}'"
            ));
        };
        let Some(question) = paragraph.questions.get(q_pos) else {
            return Err(anyhow!(
                "slice maps question '{question_id}' to missing index"
            ));
        };
        Ok((paragraph, question))
    }
}
/// Inputs hashed by `compute_slice_id` to derive the id of an implicit slice.
///
/// The key is serialized with `serde_json`, and the SHA-256 of that JSON becomes
/// the id. Field names and declaration order are therefore part of the id:
/// renaming or reordering fields changes every derived id and orphans cached
/// ledgers. `limit` is not part of the key, so different limits share (and
/// grow) the same ledger.
#[derive(Debug, Serialize)]
struct SliceKey<'a> {
dataset_id: &'a str,
includes_unanswerable: bool,
require_verified_chunks: bool,
requested_corpus: usize,
seed: u64,
}
/// Resolves the slice to evaluate for `dataset`.
///
/// With `config.explicit_slice` set, an existing manifest (given by path, or by
/// id under the cache directory) is loaded and validated as-is. If no such
/// manifest exists yet, a ledger is built under that id.
///
/// Otherwise the slice id is a hash of the dataset id, the mode flags, the
/// clamped corpus size and the seed. The ledger at
/// `<cache_dir>/slices/<dataset_id>/<slice_id>.json` is then reused or grown.
///
/// # Errors
/// Propagates manifest I/O, validation and slice-building failures.
#[allow(clippy::too_many_lines)]
pub fn resolve_slice<'a>(
    dataset: &'a ConvertedDataset,
    config: &SliceConfig<'_>,
) -> Result<ResolvedSlice<'a>> {
    let index = DatasetIndex::build(dataset);

    let Some(slice_arg) = config.explicit_slice else {
        // Implicit slice: its identity is fully determined by the SliceKey.
        let paragraph_total = dataset.paragraphs.len();
        let requested_corpus = config
            .corpus_limit
            .unwrap_or(paragraph_total)
            .min(paragraph_total)
            .max(1);
        let slice_id = compute_slice_id(&SliceKey {
            dataset_id: dataset.metadata.id.as_str(),
            includes_unanswerable: config.llm_mode,
            require_verified_chunks: config.require_verified_chunks,
            requested_corpus,
            seed: config.slice_seed,
        })?;
        let ledger_path = config
            .cache_dir
            .join("slices")
            .join(dataset.metadata.id.as_str())
            .join(format!("{slice_id}.json"));
        return materialize_slice_ledger(dataset, config, &index, &slice_id, ledger_path);
    };

    let candidate = explicit_slice_path(dataset, config, slice_arg);
    if !candidate.exists() {
        // Nothing on disk yet: build a fresh ledger under the requested id.
        let built = materialize_slice_ledger(dataset, config, &index, slice_arg, candidate)?;
        info!(
            slice = %built.manifest.slice_id,
            path = %built.path.display(),
            cases = built.manifest.case_count,
            positives = built.manifest.positive_paragraphs,
            negatives = built.manifest.negative_paragraphs,
            "Built catalog slice ledger"
        );
        return Ok(built);
    }

    let (manifest_path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
    let loaded = manifest_to_resolved(dataset, &index, manifest, manifest_path)?;
    info!(
        slice = %loaded.manifest.slice_id,
        path = %loaded.path.display(),
        cases = loaded.manifest.case_count,
        positives = loaded.manifest.positive_paragraphs,
        negatives = loaded.manifest.negative_paragraphs,
        "Using explicitly selected slice"
    );
    Ok(loaded)
}
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
pub fn select_window<'a>(
resolved: &'a ResolvedSlice<'a>,
offset: usize,
limit: Option<usize>,
) -> Result<SliceWindow<'a>> {
let total = resolved.manifest.case_count;
if total == 0 {
return Err(anyhow!(
"slice '{}' contains no cases",
resolved.manifest.slice_id
));
}
if offset >= total {
return Err(anyhow!(
"slice offset {offset} exceeds available cases ({total})",
));
}
let available = total - offset;
let requested = limit.unwrap_or(available).max(1);
let length = requested.min(available);
let cases = resolved.cases[offset..offset + length].to_vec();
let mut seen = HashSet::new();
let mut positive_ids = Vec::new();
for case in &cases {
if seen.insert(case.paragraph.id.as_str()) {
positive_ids.push(case.paragraph.id.clone());
}
}
Ok(SliceWindow {
offset,
length,
total_cases: total,
cases,
positive_paragraph_ids: positive_ids,
})
}
#[allow(dead_code)]
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
select_window(resolved, 0, None)
}
/// Location of an explicitly selected manifest.
///
/// `slice_arg` is used verbatim when it names an existing path. Otherwise it is
/// treated as a slice id under `<cache_dir>/slices/<dataset_id>/<id>.json`.
fn explicit_slice_path(
    dataset: &ConvertedDataset,
    config: &SliceConfig<'_>,
    slice_arg: &str,
) -> PathBuf {
    let direct = Path::new(slice_arg);
    if !direct.exists() {
        let mut cached = config.cache_dir.join("slices");
        cached.push(dataset.metadata.id.as_str());
        cached.push(format!("{slice_arg}.json"));
        return cached;
    }
    direct.to_path_buf()
}
#[allow(clippy::too_many_lines)]
fn materialize_slice_ledger<'a>(
dataset: &'a ConvertedDataset,
config: &SliceConfig<'_>,
index: &DatasetIndex,
slice_id: &str,
path: PathBuf,
) -> Result<ResolvedSlice<'a>> {
let requested_corpus = config
.corpus_limit
.unwrap_or(dataset.paragraphs.len())
.min(dataset.paragraphs.len())
.max(1);
let total_questions = dataset
.paragraphs
.iter()
.map(|p| p.questions.len())
.sum::<usize>()
.max(1);
let requested_limit = config
.limit
.unwrap_or(total_questions)
.min(total_questions)
.max(1);
let mut manifest = if !config.force_convert && path.exists() {
match read_manifest(&path) {
Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
if manifest.includes_unanswerable != config.llm_mode {
warn!(
slice = manifest.slice_id,
path = %path.display(),
expected = config.llm_mode,
found = manifest.includes_unanswerable,
"Slice manifest includes_unanswerable mismatch; regenerating"
);
None
} else if manifest.require_verified_chunks != config.require_verified_chunks {
warn!(
slice = manifest.slice_id,
path = %path.display(),
expected = config.require_verified_chunks,
found = manifest.require_verified_chunks,
"Slice manifest verified-chunk requirement mismatch; regenerating"
);
None
} else {
Some(manifest)
}
}
Ok(manifest) => {
warn!(
slice = manifest.slice_id,
path = %path.display(),
loaded_dataset = %manifest.dataset_id,
expected = %dataset.metadata.id,
"Slice manifest targets different dataset; regenerating"
);
None
}
Err(err) => {
warn!(
path = %path.display(),
error = %err,
"Failed to read cached slice; regenerating"
);
None
}
}
} else {
None
};
let params = BuildParams {
include_impossible: config.llm_mode,
base_seed: config.slice_seed,
rng_seed: mix_seed(dataset.metadata.id.as_str(), config.slice_seed),
};
if manifest
.as_ref()
.is_some_and(|manifest| manifest.version != SLICE_VERSION)
{
warn!(
slice = manifest.as_ref().map_or("unknown", |m| m.slice_id.as_str()),
found = manifest.as_ref().map_or(0, |m| m.version),
expected = SLICE_VERSION,
"Slice manifest version mismatch; regenerating"
);
manifest = None;
}
let mut manifest = manifest.unwrap_or_else(|| {
empty_manifest(
dataset,
slice_id.to_string(),
&params,
requested_corpus,
config.negative_multiplier,
config.require_verified_chunks,
config.limit,
)
});
manifest.requested_limit = config.limit;
manifest.requested_corpus = requested_corpus;
manifest.negative_multiplier = config.negative_multiplier;
manifest.includes_unanswerable = config.llm_mode;
manifest.require_verified_chunks = config.require_verified_chunks;
let mut changed = ensure_shard_paths(&mut manifest);
changed |= ensure_case_capacity(dataset, &mut manifest, &params, requested_limit)?;
refresh_manifest_stats(&mut manifest);
let desired_negatives = desired_negative_target(
manifest.positive_paragraphs,
requested_corpus,
dataset.paragraphs.len(),
config.negative_multiplier,
);
changed |= ensure_negative_pool(
dataset,
&mut manifest,
&params,
desired_negatives,
requested_corpus,
);
refresh_manifest_stats(&mut manifest);
if changed {
manifest.generated_at = Utc::now();
write_manifest(&path, &manifest)?;
info!(
slice = %manifest.slice_id,
path = %path.display(),
cases = manifest.case_count,
positives = manifest.positive_paragraphs,
negatives = manifest.negative_paragraphs,
"Updated dataset slice ledger"
);
} else {
info!(
slice = %manifest.slice_id,
path = %path.display(),
cases = manifest.case_count,
positives = manifest.positive_paragraphs,
negatives = manifest.negative_paragraphs,
"Reusing cached slice ledger"
);
}
manifest_to_resolved(dataset, index, manifest, path)
}
/// Loads an explicitly selected manifest and checks it against the current run.
///
/// The manifest may be given by path or by id under the cache directory. It
/// must match the loaded dataset and the current mode flags. It is also
/// resolved once against `dataset` so that dangling references are reported
/// here; that resolved value is discarded.
///
/// # Errors
/// Fails when the manifest cannot be read or parsed, targets another dataset,
/// disagrees on `includes_unanswerable` / `require_verified_chunks`, or does
/// not validate.
fn load_explicit_slice(
    dataset: &ConvertedDataset,
    index: &DatasetIndex,
    config: &SliceConfig<'_>,
    slice_arg: &str,
) -> Result<(PathBuf, SliceManifest)> {
    let manifest_path = explicit_slice_path(dataset, config, slice_arg);
    let manifest = read_manifest(&manifest_path)
        .with_context(|| format!("reading slice manifest at {}", manifest_path.display()))?;

    let incompatibility = if manifest.dataset_id != dataset.metadata.id {
        Some(anyhow!(
            "slice '{}' targets dataset '{}', but '{}' is loaded",
            manifest.slice_id,
            manifest.dataset_id,
            dataset.metadata.id
        ))
    } else if manifest.includes_unanswerable != config.llm_mode {
        Some(anyhow!(
            "slice '{}' includes_unanswerable mismatch (expected {}, found {})",
            manifest.slice_id,
            config.llm_mode,
            manifest.includes_unanswerable
        ))
    } else if manifest.require_verified_chunks != config.require_verified_chunks {
        Some(anyhow!(
            "slice '{}' verified-chunk requirement mismatch (expected {}, found {})",
            manifest.slice_id,
            config.require_verified_chunks,
            manifest.require_verified_chunks
        ))
    } else {
        None
    };
    if let Some(err) = incompatibility {
        return Err(err);
    }

    // Resolve once purely for validation; the caller resolves again for use.
    manifest_to_resolved(dataset, index, manifest.clone(), manifest_path.clone())?;
    Ok((manifest_path, manifest))
}
/// Creates a fresh, empty ledger for `dataset`.
///
/// The ledger has no cases or paragraphs and all counters at zero. Mode flags,
/// seed and requested sizes are taken from the arguments.
fn empty_manifest(
    dataset: &ConvertedDataset,
    slice_id: String,
    params: &BuildParams,
    requested_corpus: usize,
    negative_multiplier: f32,
    require_verified_chunks: bool,
    requested_limit: Option<usize>,
) -> SliceManifest {
    let metadata = &dataset.metadata;
    SliceManifest {
        version: SLICE_VERSION,
        slice_id,
        dataset_id: metadata.id.clone(),
        dataset_label: metadata.label.clone(),
        dataset_source: dataset.source.clone(),
        includes_unanswerable: params.include_impossible,
        require_verified_chunks,
        seed: params.base_seed,
        requested_limit,
        requested_corpus,
        negative_multiplier,
        generated_at: Utc::now(),
        cases: Vec::new(),
        paragraphs: Vec::new(),
        case_count: 0,
        positive_paragraphs: 0,
        negative_paragraphs: 0,
        total_paragraphs: 0,
    }
}
/// Appends cases to `manifest` until it holds `target_cases`.
///
/// Questions are drawn in the seeded order from `ordered_question_refs`, and
/// questions already present are skipped. Each new case's paragraph is added
/// as a positive entry, or appended to / promoted from an existing entry.
///
/// Returns `Ok(true)` if any case was added and `Ok(false)` if the manifest
/// already had enough cases (judged by its stored `case_count`).
///
/// # Errors
/// Fails if fewer than `target_cases` eligible questions are available, or if
/// `ordered_question_refs` fails.
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
fn ensure_case_capacity(
dataset: &ConvertedDataset,
manifest: &mut SliceManifest,
params: &BuildParams,
target_cases: usize,
) -> Result<bool> {
if manifest.case_count >= target_cases {
return Ok(false);
}
let question_refs = ordered_question_refs(dataset, params, target_cases)?;
// Question ids already in the ledger; `insert` below doubles as the dedup check.
let mut existing_questions: HashSet<String> = manifest
.cases
.iter()
.map(|case| case.question_id.clone())
.collect();
// Paragraph id -> position in `manifest.paragraphs`, so entries are updated in place.
let mut paragraph_positions: HashMap<String, usize> = manifest
.paragraphs
.iter()
.enumerate()
.map(|(idx, entry)| (entry.id.clone(), idx))
.collect();
let mut changed = false;
for (p_idx, q_idx) in question_refs {
if manifest.case_count >= target_cases {
break;
}
// Indices come from `ordered_question_refs` over this same dataset.
let paragraph = &dataset.paragraphs[p_idx];
let question = &paragraph.questions[q_idx];
if !existing_questions.insert(question.id.clone()) {
continue;
}
if let Some(idx) = paragraph_positions.get(paragraph.id.as_str()).copied() {
match &mut manifest.paragraphs[idx].kind {
SliceParagraphKind::Positive { question_ids } => {
if !question_ids.contains(&question.id) {
question_ids.push(question.id.clone());
}
}
// A previously sampled negative now owns a selected question: promote it.
SliceParagraphKind::Negative => {
manifest.paragraphs[idx].kind = SliceParagraphKind::Positive {
question_ids: vec![question.id.clone()],
};
}
}
} else {
manifest.paragraphs.push(SliceParagraphEntry {
id: paragraph.id.clone(),
kind: SliceParagraphKind::Positive {
question_ids: vec![question.id.clone()],
},
shard_path: Some(default_shard_path(&paragraph.id)),
});
let idx = manifest.paragraphs.len() - 1;
paragraph_positions.insert(paragraph.id.clone(), idx);
}
manifest.cases.push(SliceCaseEntry {
question_id: question.id.clone(),
paragraph_id: paragraph.id.clone(),
});
// Kept in step with `cases` so the loop guard above stays accurate.
manifest.case_count += 1;
changed = true;
}
if manifest.case_count < target_cases {
return Err(anyhow!(
"only {}/{} eligible questions available for dataset {}",
manifest.case_count,
target_cases,
dataset.metadata.id
));
}
Ok(changed)
}
/// Returns `(paragraph index, question index)` pairs for eligible questions, in
/// a deterministic shuffled order seeded by `params.rng_seed`.
///
/// BEIR datasets are delegated to `beir::ordered_question_refs_beir`. For other
/// datasets, a question is eligible when `params.include_impossible` is set, or
/// when it is possible and has at least one answer.
///
/// # Errors
/// Fails when no question is eligible, or when the BEIR ordering fails.
fn ordered_question_refs(
    dataset: &ConvertedDataset,
    params: &BuildParams,
    target_cases: usize,
) -> Result<Vec<(usize, usize)>> {
    if dataset.metadata.id == DatasetKind::Beir.id() {
        return beir::ordered_question_refs_beir(dataset, params, target_cases);
    }
    let include_all = params.include_impossible;
    let mut eligible: Vec<(usize, usize)> = dataset
        .paragraphs
        .iter()
        .enumerate()
        .flat_map(|(p_idx, paragraph)| {
            paragraph
                .questions
                .iter()
                .enumerate()
                .filter_map(move |(q_idx, question)| {
                    let answerable = !question.is_impossible && !question.answers.is_empty();
                    (include_all || answerable).then_some((p_idx, q_idx))
                })
        })
        .collect();
    if eligible.is_empty() {
        return Err(anyhow!(
            "no eligible questions found for dataset {}; cannot build slice",
            dataset.metadata.id
        ));
    }
    eligible.shuffle(&mut StdRng::seed_from_u64(params.rng_seed));
    Ok(eligible)
}
/// Adds negative paragraphs to `manifest` until it holds `target_negatives`.
///
/// Candidates are non-positive paragraphs in a deterministic shuffled order,
/// seeded by `mix_seed("<dataset_id>::negatives", base_seed)`. Existing
/// negatives are kept and never removed. If too few candidates exist, a
/// warning is logged and the pool stays short.
///
/// Returns whether any negative was added. `requested_corpus` is only reported
/// in that warning.
#[allow(clippy::indexing_slicing)]
fn ensure_negative_pool(
dataset: &ConvertedDataset,
manifest: &mut SliceManifest,
params: &BuildParams,
target_negatives: usize,
requested_corpus: usize,
) -> bool {
let current_negatives = manifest
.paragraphs
.iter()
.filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative))
.count();
if current_negatives >= target_negatives {
return false;
}
let positive_ids: HashSet<String> = manifest
.paragraphs
.iter()
.filter_map(|entry| match entry.kind {
SliceParagraphKind::Positive { .. } => Some(entry.id.clone()),
SliceParagraphKind::Negative => None,
})
.collect();
let mut negative_ids: HashSet<String> = manifest
.paragraphs
.iter()
.filter_map(|entry| match entry.kind {
SliceParagraphKind::Negative => Some(entry.id.clone()),
SliceParagraphKind::Positive { .. } => None,
})
.collect();
// Separate seed stream from question sampling, so growing one does not reshuffle the other.
let negative_seed = mix_seed(
&format!("{}::negatives", dataset.metadata.id),
params.base_seed,
);
let candidates = ordered_negative_indices(dataset, &positive_ids, negative_seed);
let mut added = false;
for idx in candidates {
if negative_ids.len() >= target_negatives {
break;
}
// Indices come from `ordered_negative_indices` over this same dataset.
let paragraph = &dataset.paragraphs[idx];
// Skip existing negatives. The positive check is defensive: candidates already exclude positives.
if negative_ids.contains(paragraph.id.as_str())
|| positive_ids.contains(paragraph.id.as_str())
{
continue;
}
manifest.paragraphs.push(SliceParagraphEntry {
id: paragraph.id.clone(),
kind: SliceParagraphKind::Negative,
shard_path: Some(default_shard_path(&paragraph.id)),
});
negative_ids.insert(paragraph.id.clone());
added = true;
}
if negative_ids.len() < target_negatives {
warn!(
dataset = %dataset.metadata.id,
desired = target_negatives,
available = negative_ids.len(),
requested_corpus,
"Insufficient negative paragraphs to satisfy multiplier"
);
}
added
}
fn ordered_negative_indices(
dataset: &ConvertedDataset,
positive_ids: &HashSet<String>,
rng_seed: u64,
) -> Vec<usize> {
let mut candidates: Vec<usize> = dataset
.paragraphs
.iter()
.enumerate()
.filter_map(|(idx, paragraph)| {
if positive_ids.contains(paragraph.id.as_str()) {
None
} else {
Some(idx)
}
})
.collect();
let mut rng = StdRng::seed_from_u64(rng_seed);
candidates.shuffle(&mut rng);
candidates
}
/// Recomputes the derived counters of `manifest` from its `cases` and
/// `paragraphs` lists.
fn refresh_manifest_stats(manifest: &mut SliceManifest) {
    let mut positives = 0_usize;
    let mut negatives = 0_usize;
    for entry in &manifest.paragraphs {
        match entry.kind {
            SliceParagraphKind::Positive { .. } => positives = positives.saturating_add(1),
            SliceParagraphKind::Negative => negatives = negatives.saturating_add(1),
        }
    }
    manifest.case_count = manifest.cases.len();
    manifest.positive_paragraphs = positives;
    manifest.negative_paragraphs = negatives;
    manifest.total_paragraphs = manifest.paragraphs.len();
}
/// Fills in the default shard path for every paragraph entry that lacks one.
/// Returns `true` if at least one entry was updated.
fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool {
    manifest
        .paragraphs
        .iter_mut()
        .filter(|entry| entry.shard_path.is_none())
        .fold(false, |_, entry| {
            entry.shard_path = Some(default_shard_path(&entry.id));
            true
        })
}
/// Number of negative paragraphs the slice should contain.
///
/// The target is `ceil(positive_count * multiplier)`, with negative multipliers
/// treated as zero. It is capped so that positives plus negatives stay within
/// `min(requested_corpus, dataset_paragraphs)`; positives alone may exceed that
/// bound, in which case the target is 0. With no positives the target is 0.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
fn desired_negative_target(
    positive_count: usize,
    requested_corpus: usize,
    dataset_paragraphs: usize,
    multiplier: f32,
) -> usize {
    match positive_count {
        0 => 0,
        _ => {
            let corpus_ceiling = requested_corpus.min(dataset_paragraphs).max(positive_count);
            let headroom = corpus_ceiling.saturating_sub(positive_count);
            // Float-to-int `as` saturates, so huge products clamp to usize::MAX.
            let wanted = (positive_count as f32 * multiplier.max(0.0)).ceil() as usize;
            headroom.min(wanted)
        }
    }
}
fn manifest_to_resolved<'a>(
dataset: &'a ConvertedDataset,
index: &DatasetIndex,
manifest: SliceManifest,
path: PathBuf,
) -> Result<ResolvedSlice<'a>> {
if manifest.version != SLICE_VERSION {
return Err(anyhow!(
"slice version {} does not match expected {}",
manifest.version,
SLICE_VERSION
));
}
let mut paragraphs = Vec::with_capacity(manifest.paragraphs.len());
for entry in &manifest.paragraphs {
let paragraph = index.paragraph(dataset, &entry.id)?;
if let SliceParagraphKind::Positive { question_ids } = &entry.kind {
for question_id in question_ids {
let (linked_paragraph, _) = index.question(dataset, question_id)?;
if linked_paragraph.id != entry.id {
return Err(anyhow!(
"slice question '{}' expected paragraph '{}', found '{}'",
question_id,
entry.id,
linked_paragraph.id
));
}
}
}
paragraphs.push(paragraph);
}
let mut cases = Vec::with_capacity(manifest.cases.len());
for entry in &manifest.cases {
let (paragraph, question) = index.question(dataset, &entry.question_id)?;
if paragraph.id != entry.paragraph_id {
return Err(anyhow!(
"slice case '{}' expected paragraph '{}', found '{}'",
entry.question_id,
entry.paragraph_id,
paragraph.id
));
}
cases.push(CaseRef {
paragraph,
question,
});
}
if cases.is_empty() {
return Err(anyhow!(
"slice '{}' contains no cases after validation",
manifest.slice_id
));
}
Ok(ResolvedSlice {
manifest,
path,
paragraphs,
cases,
})
}
/// Derives an implicit slice id from the key.
///
/// The key is serialized to JSON, and the id is the first 16 bytes of its
/// SHA-256 digest as 32 lowercase hex digits.
///
/// # Errors
/// Fails only if the key cannot be serialized.
fn compute_slice_id(key: &SliceKey<'_>) -> Result<String> {
    let payload = serde_json::to_vec(key).context("SliceKey serialisation failed")?;
    let digest = Sha256::digest(payload.as_slice());
    let mut id = String::with_capacity(32);
    for byte in digest.iter().take(16) {
        let _ = write!(id, "{byte:02x}");
    }
    Ok(id)
}
/// Reads the manifest at `path`, or returns `Ok(None)` if the path does not exist.
///
/// # Errors
/// Fails if the file exists but cannot be read or parsed.
pub fn read_manifest_if_exists(path: &Path) -> Result<Option<SliceManifest>> {
    if path.exists() {
        Ok(Some(read_manifest(path)?))
    } else {
        Ok(None)
    }
}
/// Path of the explicitly selected slice manifest, if one was selected.
///
/// An existing file path is returned as-is. Otherwise the argument is treated
/// as a slice id under `<cache_dir>/slices/<dataset_id>/<id>.json`.
pub fn cached_manifest_path(config: &crate::args::Config) -> Option<PathBuf> {
    let slice_arg = config.slice.as_deref()?;
    let direct = Path::new(slice_arg);
    if direct.exists() {
        return Some(direct.to_path_buf());
    }
    let mut cached = config.cache_dir.join("slices");
    cached.push(config.dataset.id());
    cached.push(format!("{slice_arg}.json"));
    Some(cached)
}
/// Whether `manifest` already satisfies `config` without growing.
///
/// The manifest must hold at least `config.limit` cases (or at least one case
/// when no limit is set). Its negatives must also meet the target from
/// `desired_negative_target`.
///
/// NOTE(review): the dataset size passed to `desired_negative_target` is the
/// manifest's own `total_paragraphs`. When the stored counters are consistent
/// (total = positives + negatives), the negative target can therefore never
/// exceed `negative_paragraphs`. In practice only the case count is checked;
/// confirm whether the real dataset size should be passed instead.
pub fn manifest_is_complete(manifest: &SliceManifest, config: &SliceConfig<'_>) -> bool {
    let wanted_cases = config
        .limit
        .unwrap_or_else(|| manifest.case_count.max(1))
        .max(1);
    if manifest.case_count < wanted_cases {
        return false;
    }
    let corpus = config
        .corpus_limit
        .unwrap_or_else(|| manifest.total_paragraphs.max(1))
        .max(1);
    let paragraph_ceiling = manifest
        .total_paragraphs
        .max(manifest.positive_paragraphs.max(1));
    let wanted_negatives = desired_negative_target(
        manifest.positive_paragraphs,
        corpus,
        paragraph_ceiling,
        config.negative_multiplier,
    );
    manifest.negative_paragraphs >= wanted_negatives
}
/// Reads and parses the JSON slice manifest at `path`.
///
/// # Errors
/// Fails if the file cannot be read or is not a valid manifest.
fn read_manifest(path: &Path) -> Result<SliceManifest> {
    let contents = fs::read_to_string(path)
        .with_context(|| format!("reading slice manifest {}", path.display()))?;
    serde_json::from_str::<SliceManifest>(&contents)
        .with_context(|| format!("parsing slice manifest {}", path.display()))
}
fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating slice directory {}", parent.display()))?;
}
let json = serde_json::to_vec_pretty(manifest).context("serialising slice manifest to JSON")?;
fs::write(path, json).with_context(|| format!("writing slice manifest {}", path.display()))?;
Ok(())
}
/// Number of cases the ledger should hold.
///
/// This is the larger of `config.slice_grow` and `config.limit` when both are
/// set, whichever one is set otherwise, or `None` when neither is.
pub fn ledger_target(config: &Config) -> Option<usize> {
    let Some(grow) = config.slice_grow else {
        return config.limit;
    };
    Some(config.limit.map_or(grow, |limit| limit.max(grow)))
}
/// Grows the slice ledger to contain the target number of cases from
/// [`ledger_target`], then logs and prints the resulting counts.
///
/// # Errors
/// Fails if the slice cannot be resolved or grown.
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
    let settings = slice_config_with_limit(config, ledger_target(config));
    let resolved = resolve_slice(dataset, &settings).context("resolving dataset slice")?;
    let manifest = &resolved.manifest;
    info!(
        slice = manifest.slice_id.as_str(),
        cases = manifest.case_count,
        positives = manifest.positive_paragraphs,
        negatives = manifest.negative_paragraphs,
        total_paragraphs = manifest.total_paragraphs,
        "Slice ledger ready"
    );
    println!(
        "Slice `{}` now contains {} questions ({} positives, {} negatives)",
        manifest.slice_id,
        manifest.case_count,
        manifest.positive_paragraphs,
        manifest.negative_paragraphs
    );
    Ok(())
}
/// Builds a [`SliceConfig`] from the CLI `Config`.
///
/// `limit_override`, when `Some`, takes precedence over `config.limit`.
pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
    let limit = limit_override.or(config.limit);
    SliceConfig {
        explicit_slice: config.slice.as_deref(),
        cache_dir: config.cache_dir.as_path(),
        force_convert: config.force_convert,
        limit,
        corpus_limit: config.corpus_limit,
        slice_seed: config.slice_seed,
        llm_mode: config.llm_mode,
        negative_multiplier: config.negative_multiplier,
        require_verified_chunks: config.retrieval.require_verified_chunks,
    }
}
// Unit tests for slice resolution, windowing and BEIR question balancing.
#[cfg(test)]
mod tests {
use super::*;
use crate::datasets::{
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, DatasetMetadata,
};
use tempfile::tempdir;
// Three SQuAD-v2-style paragraphs (p1..p3), each owning one answerable question (q1..q3).
fn sample_dataset() -> ConvertedDataset {
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false);
ConvertedDataset {
generated_at: Utc::now(),
metadata,
source: "test-source".to_string(),
paragraphs: vec![
ConvertedParagraph {
id: "p1".to_string(),
title: "Alpha".to_string(),
context: "Alpha context".to_string(),
questions: vec![ConvertedQuestion {
id: "q1".to_string(),
question: "What is alpha?".to_string(),
answers: vec!["Alpha".to_string()],
is_impossible: false,
}],
},
ConvertedParagraph {
id: "p2".to_string(),
title: "Beta".to_string(),
context: "Beta context".to_string(),
questions: vec![ConvertedQuestion {
id: "q2".to_string(),
question: "What is beta?".to_string(),
answers: vec!["Beta".to_string()],
is_impossible: false,
}],
},
ConvertedParagraph {
id: "p3".to_string(),
title: "Gamma".to_string(),
context: "Gamma context".to_string(),
questions: vec![ConvertedQuestion {
id: "q3".to_string(),
question: "What is gamma?".to_string(),
answers: vec!["Gamma".to_string()],
is_impossible: false,
}],
},
],
}
}
// The second resolve must reuse the cached ledger unchanged (same generated_at).
// `force_convert` rebuilds it under the same id with a new timestamp.
#[test]
fn resolve_slice_reuses_cached_manifest() -> Result<()> {
let dataset = sample_dataset();
let temp = tempdir().context("creating temp directory")?;
let mut config = SliceConfig {
cache_dir: temp.path(),
force_convert: false,
explicit_slice: None,
limit: Some(2),
corpus_limit: Some(3),
slice_seed: 0x5eed_2025,
llm_mode: false,
negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
require_verified_chunks: true,
};
let first = resolve_slice(&dataset, &config)?;
assert!(first.path.exists());
let initial_generated = first.manifest.generated_at;
let second = resolve_slice(&dataset, &config)?;
assert_eq!(first.manifest.slice_id, second.manifest.slice_id);
assert_eq!(initial_generated, second.manifest.generated_at);
config.force_convert = true;
let third = resolve_slice(&dataset, &config)?;
assert_eq!(first.manifest.slice_id, third.manifest.slice_id);
assert_ne!(third.manifest.generated_at, initial_generated);
Ok(())
}
// A one-case window at offset 1 reports its bounds and exactly one positive id
// that is present in the manifest.
#[test]
#[allow(clippy::indexing_slicing)]
fn select_window_yields_expected_cases() -> Result<()> {
let dataset = sample_dataset();
let temp = tempdir().context("creating temp directory")?;
let config = SliceConfig {
cache_dir: temp.path(),
force_convert: false,
explicit_slice: None,
limit: Some(3),
corpus_limit: Some(3),
slice_seed: 0x5eed_2025,
llm_mode: false,
negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
require_verified_chunks: true,
};
let resolved = resolve_slice(&dataset, &config)?;
let window = select_window(&resolved, 1, Some(1))?;
assert_eq!(window.offset, 1);
assert_eq!(window.length, 1);
assert_eq!(window.total_cases, resolved.manifest.case_count);
assert_eq!(window.cases.len(), 1);
let positive_ids: Vec<&str> = window.positive_ids().collect();
assert_eq!(positive_ids.len(), 1);
assert!(
resolved
.manifest
.paragraphs
.iter()
.any(|entry| entry.id == positive_ids[0])
);
Ok(())
}
// Nine questions spread unevenly across BEIR sub-datasets (nfcorpus has none).
// Requesting 8 takes every question except one from quora, the largest group.
#[test]
#[allow(clippy::indexing_slicing)]
fn beir_mix_balances_and_rebalances() -> Result<()> {
let mut paragraphs = Vec::new();
let counts = [
("fever", 1usize),
("fiqa", 2usize),
("hotpotqa", 1usize),
("nfcorpus", 0usize),
("quora", 3usize),
("trec-covid", 2usize),
];
for (prefix, count) in counts {
for idx in 0..count {
let q_id = format!("{prefix}-q{idx}");
paragraphs.push(ConvertedParagraph {
id: format!("{prefix}-p{idx}"),
title: format!("{prefix} title"),
context: format!("{prefix} context {idx}"),
questions: vec![ConvertedQuestion {
id: q_id,
question: format!("{prefix} question {idx}"),
answers: vec!["answer".to_string()],
is_impossible: false,
}],
});
}
}
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false);
let dataset = ConvertedDataset {
generated_at: Utc::now(),
metadata,
source: "beir-mix".to_string(),
paragraphs,
};
let params = BuildParams {
include_impossible: false,
base_seed: 0xAA,
rng_seed: 0xBB,
};
let refs = beir::ordered_question_refs_beir(&dataset, &params, 8)?;
let mut per_prefix: HashMap<String, usize> = HashMap::new();
for (p_idx, q_idx) in refs {
let question = &dataset.paragraphs[p_idx].questions[q_idx];
let prefix = beir::question_prefix(&question.id).unwrap_or("unknown");
*per_prefix.entry(prefix.to_string()).or_default() += 1;
}
assert_eq!(per_prefix.get("fever").copied().unwrap_or(0), 1);
assert_eq!(per_prefix.get("fiqa").copied().unwrap_or(0), 2);
assert_eq!(per_prefix.get("hotpotqa").copied().unwrap_or(0), 1);
assert_eq!(per_prefix.get("nfcorpus").copied().unwrap_or(0), 0);
assert_eq!(per_prefix.get("quora").copied().unwrap_or(0), 2);
assert_eq!(per_prefix.get("trec-covid").copied().unwrap_or(0), 2);
Ok(())
}
}