// Mirror of https://github.com/perstarkse/minne.git
// (synced 2026-06-25 03:16:26 +02:00; 1191 lines, 37 KiB, Rust)
use std::{
|
|
collections::{HashMap, HashSet},
|
|
fmt::Write,
|
|
fs,
|
|
path::{Path, PathBuf},
|
|
};
|
|
|
|
use anyhow::{Context, Result, anyhow};
|
|
use chrono::{DateTime, Utc};
|
|
use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};
|
|
use serde::{Deserialize, Serialize};
|
|
use sha2::{Digest, Sha256};
|
|
use tracing::{info, warn};
|
|
|
|
use crate::{
|
|
args::Config,
|
|
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
|
|
};
|
|
|
|
mod beir;
|
|
mod build;
|
|
|
|
use build::{BuildParams, mix_seed};
|
|
|
|
/// On-disk format version of slice manifests. A cached manifest whose
/// `version` differs is regenerated, and one loaded explicitly is rejected.
const SLICE_VERSION: u32 = 2;

/// Default number of negative (distractor) paragraphs requested per positive
/// paragraph when a slice ledger is built or grown.
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
|
|
|
/// Options controlling how a dataset slice is located, built, or grown.
#[derive(Debug, Clone)]
pub struct SliceConfig<'a> {
    /// Root cache directory; automatically named ledgers live under
    /// `slices/<dataset id>/` inside it.
    pub cache_dir: &'a Path,
    /// Ignore any cached ledger and start a fresh one.
    pub force_convert: bool,
    /// Slice id, or path to a manifest file, selected explicitly by the user.
    pub explicit_slice: Option<&'a str>,
    /// Requested number of cases (`None` asks for every available question).
    pub limit: Option<usize>,
    /// Cap on the number of paragraphs negatives may fill up to
    /// (`None` means the whole dataset).
    pub corpus_limit: Option<usize>,
    /// Seed for the deterministic question and negative orderings.
    pub slice_seed: u64,
    /// When set, unanswerable questions are eligible as cases.
    pub llm_mode: bool,
    /// Negative paragraphs requested per positive paragraph.
    pub negative_multiplier: f32,
    /// Stored on the manifest; a cached manifest with a different value is
    /// not reused.
    pub require_verified_chunks: bool,
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SliceManifest {
|
|
pub version: u32,
|
|
pub slice_id: String,
|
|
pub dataset_id: String,
|
|
pub dataset_label: String,
|
|
pub dataset_source: String,
|
|
pub includes_unanswerable: bool,
|
|
#[serde(default = "default_require_verified_chunks")]
|
|
pub require_verified_chunks: bool,
|
|
pub seed: u64,
|
|
pub requested_limit: Option<usize>,
|
|
pub requested_corpus: usize,
|
|
pub generated_at: DateTime<Utc>,
|
|
pub case_count: usize,
|
|
pub positive_paragraphs: usize,
|
|
pub negative_paragraphs: usize,
|
|
pub total_paragraphs: usize,
|
|
pub negative_multiplier: f32,
|
|
pub cases: Vec<SliceCaseEntry>,
|
|
pub paragraphs: Vec<SliceParagraphEntry>,
|
|
}
|
|
|
|
/// Serde default for `SliceManifest::require_verified_chunks`: manifests that
/// lack the field are treated as not requiring verified chunks.
fn default_require_verified_chunks() -> bool {
    bool::default()
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SliceCaseEntry {
|
|
pub question_id: String,
|
|
pub paragraph_id: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SliceParagraphEntry {
|
|
pub id: String,
|
|
pub kind: SliceParagraphKind,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub shard_path: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "kind", rename_all = "snake_case")]
|
|
pub enum SliceParagraphKind {
|
|
Positive { question_ids: Vec<String> },
|
|
Negative,
|
|
}
|
|
|
|
pub fn paragraph_storage_key(paragraph_id: &str) -> String {
|
|
sanitize_identifier(paragraph_id)
|
|
}
|
|
|
|
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
|
|
let sanitized = paragraph_storage_key(paragraph_id);
|
|
format!("paragraphs/{sanitized}.json")
|
|
}
|
|
|
|
fn sanitize_identifier(input: &str) -> String {
|
|
let mut sanitized = String::with_capacity(input.len());
|
|
for ch in input.chars() {
|
|
if ch.is_ascii_alphanumeric() {
|
|
sanitized.push(ch);
|
|
} else {
|
|
sanitized.push('-');
|
|
}
|
|
}
|
|
let trimmed = sanitized.trim_matches('-').to_string();
|
|
if trimmed.is_empty() {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(input.as_bytes());
|
|
let digest = hasher.finalize();
|
|
digest
|
|
.iter()
|
|
.take(6)
|
|
.fold(String::with_capacity(12), |mut s, b| {
|
|
let _ = write!(s, "{b:02x}");
|
|
s
|
|
})
|
|
} else {
|
|
trimmed
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ResolvedSlice<'a> {
|
|
pub manifest: SliceManifest,
|
|
pub path: PathBuf,
|
|
pub paragraphs: Vec<&'a ConvertedParagraph>,
|
|
pub cases: Vec<CaseRef<'a>>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct SliceWindow<'a> {
|
|
pub offset: usize,
|
|
pub length: usize,
|
|
pub total_cases: usize,
|
|
pub cases: Vec<CaseRef<'a>>,
|
|
positive_paragraph_ids: Vec<String>,
|
|
}
|
|
|
|
impl SliceWindow<'_> {
|
|
pub fn positive_ids(&self) -> impl Iterator<Item = &str> {
|
|
self.positive_paragraph_ids
|
|
.iter()
|
|
.map(std::string::String::as_str)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct CaseRef<'a> {
|
|
pub paragraph: &'a ConvertedParagraph,
|
|
pub question: &'a ConvertedQuestion,
|
|
}
|
|
|
|
/// Id lookup tables over a `ConvertedDataset`.
struct DatasetIndex {
    /// Paragraph id -> index into `dataset.paragraphs`.
    paragraph_by_id: HashMap<String, usize>,
    /// Question id -> (paragraph index, question index within that paragraph).
    question_by_id: HashMap<String, (usize, usize)>,
}
|
|
|
|
impl DatasetIndex {
|
|
fn build(dataset: &ConvertedDataset) -> Self {
|
|
let mut paragraph_by_id = HashMap::new();
|
|
let mut question_by_id = HashMap::new();
|
|
|
|
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
|
paragraph_by_id.insert(paragraph.id.clone(), p_idx);
|
|
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
|
question_by_id.insert(question.id.clone(), (p_idx, q_idx));
|
|
}
|
|
}
|
|
|
|
Self {
|
|
paragraph_by_id,
|
|
question_by_id,
|
|
}
|
|
}
|
|
|
|
fn paragraph<'a>(
|
|
&self,
|
|
dataset: &'a ConvertedDataset,
|
|
id: &str,
|
|
) -> Result<&'a ConvertedParagraph> {
|
|
let idx = self
|
|
.paragraph_by_id
|
|
.get(id)
|
|
.ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?;
|
|
dataset
|
|
.paragraphs
|
|
.get(*idx)
|
|
.ok_or_else(|| anyhow!("paragraph index out of bounds"))
|
|
}
|
|
|
|
fn question<'a>(
|
|
&self,
|
|
dataset: &'a ConvertedDataset,
|
|
question_id: &str,
|
|
) -> Result<(&'a ConvertedParagraph, &'a ConvertedQuestion)> {
|
|
let (p_idx, q_idx) = self
|
|
.question_by_id
|
|
.get(question_id)
|
|
.ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?;
|
|
let paragraph = dataset
|
|
.paragraphs
|
|
.get(*p_idx)
|
|
.ok_or_else(|| anyhow!("paragraph index out of bounds for question '{question_id}'"))?;
|
|
let question = paragraph
|
|
.questions
|
|
.get(*q_idx)
|
|
.ok_or_else(|| anyhow!("slice maps question '{question_id}' to missing index"))?;
|
|
Ok((paragraph, question))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct SliceKey<'a> {
|
|
dataset_id: &'a str,
|
|
includes_unanswerable: bool,
|
|
require_verified_chunks: bool,
|
|
requested_corpus: usize,
|
|
seed: u64,
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
pub fn resolve_slice<'a>(
|
|
dataset: &'a ConvertedDataset,
|
|
config: &SliceConfig<'_>,
|
|
) -> Result<ResolvedSlice<'a>> {
|
|
let index = DatasetIndex::build(dataset);
|
|
|
|
if let Some(slice_arg) = config.explicit_slice {
|
|
let path = explicit_slice_path(dataset, config, slice_arg);
|
|
if path.exists() {
|
|
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
|
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
|
info!(
|
|
slice = %resolved.manifest.slice_id,
|
|
path = %resolved.path.display(),
|
|
cases = resolved.manifest.case_count,
|
|
positives = resolved.manifest.positive_paragraphs,
|
|
negatives = resolved.manifest.negative_paragraphs,
|
|
"Using explicitly selected slice"
|
|
);
|
|
return Ok(resolved);
|
|
}
|
|
let resolved = materialize_slice_ledger(dataset, config, &index, slice_arg, path)?;
|
|
info!(
|
|
slice = %resolved.manifest.slice_id,
|
|
path = %resolved.path.display(),
|
|
cases = resolved.manifest.case_count,
|
|
positives = resolved.manifest.positive_paragraphs,
|
|
negatives = resolved.manifest.negative_paragraphs,
|
|
"Built catalog slice ledger"
|
|
);
|
|
return Ok(resolved);
|
|
}
|
|
|
|
let requested_corpus = config
|
|
.corpus_limit
|
|
.unwrap_or(dataset.paragraphs.len())
|
|
.min(dataset.paragraphs.len())
|
|
.max(1);
|
|
let key = SliceKey {
|
|
dataset_id: dataset.metadata.id.as_str(),
|
|
includes_unanswerable: config.llm_mode,
|
|
require_verified_chunks: config.require_verified_chunks,
|
|
requested_corpus,
|
|
seed: config.slice_seed,
|
|
};
|
|
let slice_id = compute_slice_id(&key)?;
|
|
let base = config
|
|
.cache_dir
|
|
.join("slices")
|
|
.join(dataset.metadata.id.as_str());
|
|
let path = base.join(format!("{slice_id}.json"));
|
|
materialize_slice_ledger(dataset, config, &index, &slice_id, path)
|
|
}
|
|
|
|
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
|
pub fn select_window<'a>(
|
|
resolved: &'a ResolvedSlice<'a>,
|
|
offset: usize,
|
|
limit: Option<usize>,
|
|
) -> Result<SliceWindow<'a>> {
|
|
let total = resolved.manifest.case_count;
|
|
if total == 0 {
|
|
return Err(anyhow!(
|
|
"slice '{}' contains no cases",
|
|
resolved.manifest.slice_id
|
|
));
|
|
}
|
|
if offset >= total {
|
|
return Err(anyhow!(
|
|
"slice offset {offset} exceeds available cases ({total})",
|
|
));
|
|
}
|
|
let available = total - offset;
|
|
let requested = limit.unwrap_or(available).max(1);
|
|
let length = requested.min(available);
|
|
let cases = resolved.cases[offset..offset + length].to_vec();
|
|
let mut seen = HashSet::new();
|
|
let mut positive_ids = Vec::new();
|
|
for case in &cases {
|
|
if seen.insert(case.paragraph.id.as_str()) {
|
|
positive_ids.push(case.paragraph.id.clone());
|
|
}
|
|
}
|
|
Ok(SliceWindow {
|
|
offset,
|
|
length,
|
|
total_cases: total,
|
|
cases,
|
|
positive_paragraph_ids: positive_ids,
|
|
})
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
|
select_window(resolved, 0, None)
|
|
}
|
|
|
|
fn explicit_slice_path(
|
|
dataset: &ConvertedDataset,
|
|
config: &SliceConfig<'_>,
|
|
slice_arg: &str,
|
|
) -> PathBuf {
|
|
let explicit_path = Path::new(slice_arg);
|
|
if explicit_path.exists() {
|
|
explicit_path.to_path_buf()
|
|
} else {
|
|
config
|
|
.cache_dir
|
|
.join("slices")
|
|
.join(dataset.metadata.id.as_str())
|
|
.join(format!("{slice_arg}.json"))
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
fn materialize_slice_ledger<'a>(
|
|
dataset: &'a ConvertedDataset,
|
|
config: &SliceConfig<'_>,
|
|
index: &DatasetIndex,
|
|
slice_id: &str,
|
|
path: PathBuf,
|
|
) -> Result<ResolvedSlice<'a>> {
|
|
let requested_corpus = config
|
|
.corpus_limit
|
|
.unwrap_or(dataset.paragraphs.len())
|
|
.min(dataset.paragraphs.len())
|
|
.max(1);
|
|
|
|
let total_questions = dataset
|
|
.paragraphs
|
|
.iter()
|
|
.map(|p| p.questions.len())
|
|
.sum::<usize>()
|
|
.max(1);
|
|
let requested_limit = config
|
|
.limit
|
|
.unwrap_or(total_questions)
|
|
.min(total_questions)
|
|
.max(1);
|
|
|
|
let mut manifest = if !config.force_convert && path.exists() {
|
|
match read_manifest(&path) {
|
|
Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
|
|
if manifest.includes_unanswerable != config.llm_mode {
|
|
warn!(
|
|
slice = manifest.slice_id,
|
|
path = %path.display(),
|
|
expected = config.llm_mode,
|
|
found = manifest.includes_unanswerable,
|
|
"Slice manifest includes_unanswerable mismatch; regenerating"
|
|
);
|
|
None
|
|
} else if manifest.require_verified_chunks != config.require_verified_chunks {
|
|
warn!(
|
|
slice = manifest.slice_id,
|
|
path = %path.display(),
|
|
expected = config.require_verified_chunks,
|
|
found = manifest.require_verified_chunks,
|
|
"Slice manifest verified-chunk requirement mismatch; regenerating"
|
|
);
|
|
None
|
|
} else {
|
|
Some(manifest)
|
|
}
|
|
}
|
|
Ok(manifest) => {
|
|
warn!(
|
|
slice = manifest.slice_id,
|
|
path = %path.display(),
|
|
loaded_dataset = %manifest.dataset_id,
|
|
expected = %dataset.metadata.id,
|
|
"Slice manifest targets different dataset; regenerating"
|
|
);
|
|
None
|
|
}
|
|
Err(err) => {
|
|
warn!(
|
|
path = %path.display(),
|
|
error = %err,
|
|
"Failed to read cached slice; regenerating"
|
|
);
|
|
None
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let params = BuildParams {
|
|
include_impossible: config.llm_mode,
|
|
base_seed: config.slice_seed,
|
|
rng_seed: mix_seed(dataset.metadata.id.as_str(), config.slice_seed),
|
|
};
|
|
|
|
if manifest
|
|
.as_ref()
|
|
.is_some_and(|manifest| manifest.version != SLICE_VERSION)
|
|
{
|
|
warn!(
|
|
slice = manifest.as_ref().map_or("unknown", |m| m.slice_id.as_str()),
|
|
found = manifest.as_ref().map_or(0, |m| m.version),
|
|
expected = SLICE_VERSION,
|
|
"Slice manifest version mismatch; regenerating"
|
|
);
|
|
manifest = None;
|
|
}
|
|
|
|
let mut manifest = manifest.unwrap_or_else(|| {
|
|
empty_manifest(
|
|
dataset,
|
|
slice_id.to_string(),
|
|
¶ms,
|
|
requested_corpus,
|
|
config.negative_multiplier,
|
|
config.require_verified_chunks,
|
|
config.limit,
|
|
)
|
|
});
|
|
|
|
manifest.requested_limit = config.limit;
|
|
manifest.requested_corpus = requested_corpus;
|
|
manifest.negative_multiplier = config.negative_multiplier;
|
|
manifest.includes_unanswerable = config.llm_mode;
|
|
manifest.require_verified_chunks = config.require_verified_chunks;
|
|
|
|
let mut changed = ensure_shard_paths(&mut manifest);
|
|
|
|
changed |= ensure_case_capacity(dataset, &mut manifest, ¶ms, requested_limit)?;
|
|
refresh_manifest_stats(&mut manifest);
|
|
|
|
let desired_negatives = desired_negative_target(
|
|
manifest.positive_paragraphs,
|
|
requested_corpus,
|
|
dataset.paragraphs.len(),
|
|
config.negative_multiplier,
|
|
);
|
|
changed |= ensure_negative_pool(
|
|
dataset,
|
|
&mut manifest,
|
|
¶ms,
|
|
desired_negatives,
|
|
requested_corpus,
|
|
);
|
|
refresh_manifest_stats(&mut manifest);
|
|
|
|
if changed {
|
|
manifest.generated_at = Utc::now();
|
|
write_manifest(&path, &manifest)?;
|
|
info!(
|
|
slice = %manifest.slice_id,
|
|
path = %path.display(),
|
|
cases = manifest.case_count,
|
|
positives = manifest.positive_paragraphs,
|
|
negatives = manifest.negative_paragraphs,
|
|
"Updated dataset slice ledger"
|
|
);
|
|
} else {
|
|
info!(
|
|
slice = %manifest.slice_id,
|
|
path = %path.display(),
|
|
cases = manifest.case_count,
|
|
positives = manifest.positive_paragraphs,
|
|
negatives = manifest.negative_paragraphs,
|
|
"Reusing cached slice ledger"
|
|
);
|
|
}
|
|
|
|
manifest_to_resolved(dataset, index, manifest, path)
|
|
}
|
|
|
|
fn load_explicit_slice(
|
|
dataset: &ConvertedDataset,
|
|
index: &DatasetIndex,
|
|
config: &SliceConfig<'_>,
|
|
slice_arg: &str,
|
|
) -> Result<(PathBuf, SliceManifest)> {
|
|
let candidate_path = explicit_slice_path(dataset, config, slice_arg);
|
|
|
|
let manifest = read_manifest(&candidate_path)
|
|
.with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;
|
|
|
|
if manifest.dataset_id != dataset.metadata.id {
|
|
return Err(anyhow!(
|
|
"slice '{}' targets dataset '{}', but '{}' is loaded",
|
|
manifest.slice_id,
|
|
manifest.dataset_id,
|
|
dataset.metadata.id
|
|
));
|
|
}
|
|
if manifest.includes_unanswerable != config.llm_mode {
|
|
return Err(anyhow!(
|
|
"slice '{}' includes_unanswerable mismatch (expected {}, found {})",
|
|
manifest.slice_id,
|
|
config.llm_mode,
|
|
manifest.includes_unanswerable
|
|
));
|
|
}
|
|
if manifest.require_verified_chunks != config.require_verified_chunks {
|
|
return Err(anyhow!(
|
|
"slice '{}' verified-chunk requirement mismatch (expected {}, found {})",
|
|
manifest.slice_id,
|
|
config.require_verified_chunks,
|
|
manifest.require_verified_chunks
|
|
));
|
|
}
|
|
|
|
// Validate the manifest before returning.
|
|
manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?;
|
|
|
|
Ok((candidate_path, manifest))
|
|
}
|
|
|
|
fn empty_manifest(
|
|
dataset: &ConvertedDataset,
|
|
slice_id: String,
|
|
params: &BuildParams,
|
|
requested_corpus: usize,
|
|
negative_multiplier: f32,
|
|
require_verified_chunks: bool,
|
|
requested_limit: Option<usize>,
|
|
) -> SliceManifest {
|
|
SliceManifest {
|
|
version: SLICE_VERSION,
|
|
slice_id,
|
|
dataset_id: dataset.metadata.id.clone(),
|
|
dataset_label: dataset.metadata.label.clone(),
|
|
dataset_source: dataset.source.clone(),
|
|
includes_unanswerable: params.include_impossible,
|
|
require_verified_chunks,
|
|
seed: params.base_seed,
|
|
requested_limit,
|
|
requested_corpus,
|
|
negative_multiplier,
|
|
generated_at: Utc::now(),
|
|
case_count: 0,
|
|
positive_paragraphs: 0,
|
|
negative_paragraphs: 0,
|
|
total_paragraphs: 0,
|
|
cases: Vec::new(),
|
|
paragraphs: Vec::new(),
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
|
fn ensure_case_capacity(
|
|
dataset: &ConvertedDataset,
|
|
manifest: &mut SliceManifest,
|
|
params: &BuildParams,
|
|
target_cases: usize,
|
|
) -> Result<bool> {
|
|
if manifest.case_count >= target_cases {
|
|
return Ok(false);
|
|
}
|
|
|
|
let question_refs = ordered_question_refs(dataset, params, target_cases)?;
|
|
let mut existing_questions: HashSet<String> = manifest
|
|
.cases
|
|
.iter()
|
|
.map(|case| case.question_id.clone())
|
|
.collect();
|
|
let mut paragraph_positions: HashMap<String, usize> = manifest
|
|
.paragraphs
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(idx, entry)| (entry.id.clone(), idx))
|
|
.collect();
|
|
|
|
let mut changed = false;
|
|
|
|
for (p_idx, q_idx) in question_refs {
|
|
if manifest.case_count >= target_cases {
|
|
break;
|
|
}
|
|
let paragraph = &dataset.paragraphs[p_idx];
|
|
let question = ¶graph.questions[q_idx];
|
|
if !existing_questions.insert(question.id.clone()) {
|
|
continue;
|
|
}
|
|
|
|
if let Some(idx) = paragraph_positions.get(paragraph.id.as_str()).copied() {
|
|
match &mut manifest.paragraphs[idx].kind {
|
|
SliceParagraphKind::Positive { question_ids } => {
|
|
if !question_ids.contains(&question.id) {
|
|
question_ids.push(question.id.clone());
|
|
}
|
|
}
|
|
SliceParagraphKind::Negative => {
|
|
manifest.paragraphs[idx].kind = SliceParagraphKind::Positive {
|
|
question_ids: vec![question.id.clone()],
|
|
};
|
|
}
|
|
}
|
|
} else {
|
|
manifest.paragraphs.push(SliceParagraphEntry {
|
|
id: paragraph.id.clone(),
|
|
kind: SliceParagraphKind::Positive {
|
|
question_ids: vec![question.id.clone()],
|
|
},
|
|
shard_path: Some(default_shard_path(¶graph.id)),
|
|
});
|
|
let idx = manifest.paragraphs.len() - 1;
|
|
paragraph_positions.insert(paragraph.id.clone(), idx);
|
|
}
|
|
|
|
manifest.cases.push(SliceCaseEntry {
|
|
question_id: question.id.clone(),
|
|
paragraph_id: paragraph.id.clone(),
|
|
});
|
|
manifest.case_count += 1;
|
|
changed = true;
|
|
}
|
|
|
|
if manifest.case_count < target_cases {
|
|
return Err(anyhow!(
|
|
"only {}/{} eligible questions available for dataset {}",
|
|
manifest.case_count,
|
|
target_cases,
|
|
dataset.metadata.id
|
|
));
|
|
}
|
|
|
|
Ok(changed)
|
|
}
|
|
|
|
fn ordered_question_refs(
|
|
dataset: &ConvertedDataset,
|
|
params: &BuildParams,
|
|
target_cases: usize,
|
|
) -> Result<Vec<(usize, usize)>> {
|
|
if dataset.metadata.id == DatasetKind::Beir.id() {
|
|
return beir::ordered_question_refs_beir(dataset, params, target_cases);
|
|
}
|
|
|
|
let mut question_refs = Vec::new();
|
|
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
|
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
|
let include = if params.include_impossible {
|
|
true
|
|
} else {
|
|
!question.is_impossible && !question.answers.is_empty()
|
|
};
|
|
if include {
|
|
question_refs.push((p_idx, q_idx));
|
|
}
|
|
}
|
|
}
|
|
|
|
if question_refs.is_empty() {
|
|
return Err(anyhow!(
|
|
"no eligible questions found for dataset {}; cannot build slice",
|
|
dataset.metadata.id
|
|
));
|
|
}
|
|
|
|
let mut rng = StdRng::seed_from_u64(params.rng_seed);
|
|
question_refs.shuffle(&mut rng);
|
|
Ok(question_refs)
|
|
}
|
|
|
|
#[allow(clippy::indexing_slicing)]
|
|
fn ensure_negative_pool(
|
|
dataset: &ConvertedDataset,
|
|
manifest: &mut SliceManifest,
|
|
params: &BuildParams,
|
|
target_negatives: usize,
|
|
requested_corpus: usize,
|
|
) -> bool {
|
|
let current_negatives = manifest
|
|
.paragraphs
|
|
.iter()
|
|
.filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative))
|
|
.count();
|
|
if current_negatives >= target_negatives {
|
|
return false;
|
|
}
|
|
|
|
let positive_ids: HashSet<String> = manifest
|
|
.paragraphs
|
|
.iter()
|
|
.filter_map(|entry| match entry.kind {
|
|
SliceParagraphKind::Positive { .. } => Some(entry.id.clone()),
|
|
SliceParagraphKind::Negative => None,
|
|
})
|
|
.collect();
|
|
let mut negative_ids: HashSet<String> = manifest
|
|
.paragraphs
|
|
.iter()
|
|
.filter_map(|entry| match entry.kind {
|
|
SliceParagraphKind::Negative => Some(entry.id.clone()),
|
|
SliceParagraphKind::Positive { .. } => None,
|
|
})
|
|
.collect();
|
|
|
|
let negative_seed = mix_seed(
|
|
&format!("{}::negatives", dataset.metadata.id),
|
|
params.base_seed,
|
|
);
|
|
let candidates = ordered_negative_indices(dataset, &positive_ids, negative_seed);
|
|
let mut added = false;
|
|
for idx in candidates {
|
|
if negative_ids.len() >= target_negatives {
|
|
break;
|
|
}
|
|
let paragraph = &dataset.paragraphs[idx];
|
|
if negative_ids.contains(paragraph.id.as_str())
|
|
|| positive_ids.contains(paragraph.id.as_str())
|
|
{
|
|
continue;
|
|
}
|
|
manifest.paragraphs.push(SliceParagraphEntry {
|
|
id: paragraph.id.clone(),
|
|
kind: SliceParagraphKind::Negative,
|
|
shard_path: Some(default_shard_path(¶graph.id)),
|
|
});
|
|
negative_ids.insert(paragraph.id.clone());
|
|
added = true;
|
|
}
|
|
|
|
if negative_ids.len() < target_negatives {
|
|
warn!(
|
|
dataset = %dataset.metadata.id,
|
|
desired = target_negatives,
|
|
available = negative_ids.len(),
|
|
requested_corpus,
|
|
"Insufficient negative paragraphs to satisfy multiplier"
|
|
);
|
|
}
|
|
|
|
added
|
|
}
|
|
|
|
fn ordered_negative_indices(
|
|
dataset: &ConvertedDataset,
|
|
positive_ids: &HashSet<String>,
|
|
rng_seed: u64,
|
|
) -> Vec<usize> {
|
|
let mut candidates: Vec<usize> = dataset
|
|
.paragraphs
|
|
.iter()
|
|
.enumerate()
|
|
.filter_map(|(idx, paragraph)| {
|
|
if positive_ids.contains(paragraph.id.as_str()) {
|
|
None
|
|
} else {
|
|
Some(idx)
|
|
}
|
|
})
|
|
.collect();
|
|
let mut rng = StdRng::seed_from_u64(rng_seed);
|
|
candidates.shuffle(&mut rng);
|
|
candidates
|
|
}
|
|
|
|
fn refresh_manifest_stats(manifest: &mut SliceManifest) {
|
|
manifest.case_count = manifest.cases.len();
|
|
manifest.positive_paragraphs = manifest
|
|
.paragraphs
|
|
.iter()
|
|
.filter(|entry| matches!(entry.kind, SliceParagraphKind::Positive { .. }))
|
|
.count();
|
|
manifest.negative_paragraphs = manifest
|
|
.paragraphs
|
|
.iter()
|
|
.filter(|entry| matches!(entry.kind, SliceParagraphKind::Negative))
|
|
.count();
|
|
manifest.total_paragraphs = manifest.paragraphs.len();
|
|
}
|
|
|
|
fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool {
|
|
let mut changed = false;
|
|
for entry in &mut manifest.paragraphs {
|
|
if entry.shard_path.is_none() {
|
|
entry.shard_path = Some(default_shard_path(&entry.id));
|
|
changed = true;
|
|
}
|
|
}
|
|
changed
|
|
}
|
|
|
|
/// Number of negative paragraphs to aim for, given `positive_count` positives.
///
/// The raw target is `ceil(positive_count * multiplier)`, where negative or
/// NaN multipliers count as 0. It is capped so that negatives fill at most the
/// room between `positive_count` and `min(requested_corpus, dataset_paragraphs)`.
/// Zero positives always yield zero negatives.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
fn desired_negative_target(
    positive_count: usize,
    requested_corpus: usize,
    dataset_paragraphs: usize,
    multiplier: f32,
) -> usize {
    if positive_count == 0 {
        return 0;
    }
    let corpus_cap = requested_corpus
        .min(dataset_paragraphs)
        .max(positive_count);
    let room = corpus_cap.saturating_sub(positive_count);
    // `f32::max` ignores NaN, so a NaN multiplier becomes 0.0 here.
    let wanted = (positive_count as f32 * multiplier.max(0.0)).ceil() as usize;
    wanted.min(room)
}
|
|
|
|
fn manifest_to_resolved<'a>(
|
|
dataset: &'a ConvertedDataset,
|
|
index: &DatasetIndex,
|
|
manifest: SliceManifest,
|
|
path: PathBuf,
|
|
) -> Result<ResolvedSlice<'a>> {
|
|
if manifest.version != SLICE_VERSION {
|
|
return Err(anyhow!(
|
|
"slice version {} does not match expected {}",
|
|
manifest.version,
|
|
SLICE_VERSION
|
|
));
|
|
}
|
|
|
|
let mut paragraphs = Vec::with_capacity(manifest.paragraphs.len());
|
|
for entry in &manifest.paragraphs {
|
|
let paragraph = index.paragraph(dataset, &entry.id)?;
|
|
if let SliceParagraphKind::Positive { question_ids } = &entry.kind {
|
|
for question_id in question_ids {
|
|
let (linked_paragraph, _) = index.question(dataset, question_id)?;
|
|
if linked_paragraph.id != entry.id {
|
|
return Err(anyhow!(
|
|
"slice question '{}' expected paragraph '{}', found '{}'",
|
|
question_id,
|
|
entry.id,
|
|
linked_paragraph.id
|
|
));
|
|
}
|
|
}
|
|
}
|
|
paragraphs.push(paragraph);
|
|
}
|
|
|
|
let mut cases = Vec::with_capacity(manifest.cases.len());
|
|
for entry in &manifest.cases {
|
|
let (paragraph, question) = index.question(dataset, &entry.question_id)?;
|
|
if paragraph.id != entry.paragraph_id {
|
|
return Err(anyhow!(
|
|
"slice case '{}' expected paragraph '{}', found '{}'",
|
|
entry.question_id,
|
|
entry.paragraph_id,
|
|
paragraph.id
|
|
));
|
|
}
|
|
cases.push(CaseRef {
|
|
paragraph,
|
|
question,
|
|
});
|
|
}
|
|
|
|
if cases.is_empty() {
|
|
return Err(anyhow!(
|
|
"slice '{}' contains no cases after validation",
|
|
manifest.slice_id
|
|
));
|
|
}
|
|
|
|
Ok(ResolvedSlice {
|
|
manifest,
|
|
path,
|
|
paragraphs,
|
|
cases,
|
|
})
|
|
}
|
|
|
|
fn compute_slice_id(key: &SliceKey<'_>) -> Result<String> {
|
|
let payload = serde_json::to_vec(key).context("SliceKey serialisation failed")?;
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(payload);
|
|
let digest = hasher.finalize();
|
|
Ok(digest
|
|
.iter()
|
|
.take(16)
|
|
.fold(String::with_capacity(32), |mut s, b| {
|
|
let _ = write!(s, "{b:02x}");
|
|
s
|
|
}))
|
|
}
|
|
|
|
pub fn read_manifest_if_exists(path: &Path) -> Result<Option<SliceManifest>> {
|
|
if !path.exists() {
|
|
return Ok(None);
|
|
}
|
|
read_manifest(path).map(Some)
|
|
}
|
|
|
|
pub fn cached_manifest_path(config: &crate::args::Config) -> Option<PathBuf> {
|
|
let slice_arg = config.slice.as_deref()?;
|
|
let explicit_path = Path::new(slice_arg);
|
|
if explicit_path.exists() {
|
|
return Some(explicit_path.to_path_buf());
|
|
}
|
|
Some(
|
|
config
|
|
.cache_dir
|
|
.join("slices")
|
|
.join(config.dataset.id())
|
|
.join(format!("{slice_arg}.json")),
|
|
)
|
|
}
|
|
|
|
pub fn manifest_is_complete(manifest: &SliceManifest, config: &SliceConfig<'_>) -> bool {
|
|
let requested_limit = config.limit.unwrap_or(manifest.case_count.max(1)).max(1);
|
|
if manifest.case_count < requested_limit {
|
|
return false;
|
|
}
|
|
|
|
let requested_corpus = config
|
|
.corpus_limit
|
|
.unwrap_or(manifest.total_paragraphs.max(1))
|
|
.max(1);
|
|
let desired_negatives = desired_negative_target(
|
|
manifest.positive_paragraphs,
|
|
requested_corpus,
|
|
manifest
|
|
.total_paragraphs
|
|
.max(manifest.positive_paragraphs.max(1)),
|
|
config.negative_multiplier,
|
|
);
|
|
manifest.negative_paragraphs >= desired_negatives
|
|
}
|
|
|
|
fn read_manifest(path: &Path) -> Result<SliceManifest> {
|
|
let raw = fs::read_to_string(path)
|
|
.with_context(|| format!("reading slice manifest {}", path.display()))?;
|
|
let manifest: SliceManifest = serde_json::from_str(&raw)
|
|
.with_context(|| format!("parsing slice manifest {}", path.display()))?;
|
|
Ok(manifest)
|
|
}
|
|
|
|
fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
|
|
if let Some(parent) = path.parent() {
|
|
fs::create_dir_all(parent)
|
|
.with_context(|| format!("creating slice directory {}", parent.display()))?;
|
|
}
|
|
let json = serde_json::to_vec_pretty(manifest).context("serialising slice manifest to JSON")?;
|
|
fs::write(path, json).with_context(|| format!("writing slice manifest {}", path.display()))?;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn ledger_target(config: &Config) -> Option<usize> {
|
|
match (config.slice_grow, config.limit) {
|
|
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
|
(Some(grow), None) => Some(grow),
|
|
(None, limit) => limit,
|
|
}
|
|
}
|
|
|
|
/// Grow the slice ledger to contain the target number of cases.
|
|
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
|
let ledger_limit = ledger_target(config);
|
|
let slice_settings = slice_config_with_limit(config, ledger_limit);
|
|
let slice = resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
|
info!(
|
|
slice = slice.manifest.slice_id.as_str(),
|
|
cases = slice.manifest.case_count,
|
|
positives = slice.manifest.positive_paragraphs,
|
|
negatives = slice.manifest.negative_paragraphs,
|
|
total_paragraphs = slice.manifest.total_paragraphs,
|
|
"Slice ledger ready"
|
|
);
|
|
println!(
|
|
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
|
slice.manifest.slice_id,
|
|
slice.manifest.case_count,
|
|
slice.manifest.positive_paragraphs,
|
|
slice.manifest.negative_paragraphs
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
|
|
SliceConfig {
|
|
cache_dir: config.cache_dir.as_path(),
|
|
force_convert: config.force_convert,
|
|
explicit_slice: config.slice.as_deref(),
|
|
limit: limit_override.or(config.limit),
|
|
corpus_limit: config.corpus_limit,
|
|
slice_seed: config.slice_seed,
|
|
llm_mode: config.llm_mode,
|
|
negative_multiplier: config.negative_multiplier,
|
|
require_verified_chunks: config.retrieval.require_verified_chunks,
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasets::{
        ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, DatasetMetadata,
    };
    use tempfile::tempdir;

    /// Three single-question SQuAD v2 paragraphs, all answerable.
    fn sample_dataset() -> ConvertedDataset {
        let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false);
        ConvertedDataset {
            generated_at: Utc::now(),
            metadata,
            source: "test-source".to_string(),
            paragraphs: vec![
                ConvertedParagraph {
                    id: "p1".to_string(),
                    title: "Alpha".to_string(),
                    context: "Alpha context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q1".to_string(),
                        question: "What is alpha?".to_string(),
                        answers: vec!["Alpha".to_string()],
                        is_impossible: false,
                    }],
                },
                ConvertedParagraph {
                    id: "p2".to_string(),
                    title: "Beta".to_string(),
                    context: "Beta context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q2".to_string(),
                        question: "What is beta?".to_string(),
                        answers: vec!["Beta".to_string()],
                        is_impossible: false,
                    }],
                },
                ConvertedParagraph {
                    id: "p3".to_string(),
                    title: "Gamma".to_string(),
                    context: "Gamma context".to_string(),
                    questions: vec![ConvertedQuestion {
                        id: "q3".to_string(),
                        question: "What is gamma?".to_string(),
                        answers: vec!["Gamma".to_string()],
                        is_impossible: false,
                    }],
                },
            ],
        }
    }

    #[test]
    fn resolve_slice_reuses_cached_manifest() -> Result<()> {
        let dataset = sample_dataset();
        let temp = tempdir().context("creating temp directory")?;

        let mut config = SliceConfig {
            cache_dir: temp.path(),
            force_convert: false,
            explicit_slice: None,
            limit: Some(2),
            corpus_limit: Some(3),
            slice_seed: 0x5eed_2025,
            llm_mode: false,
            negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
            require_verified_chunks: true,
        };

        let first = resolve_slice(&dataset, &config)?;
        assert!(first.path.exists());
        let initial_generated = first.manifest.generated_at;

        let second = resolve_slice(&dataset, &config)?;
        assert_eq!(first.manifest.slice_id, second.manifest.slice_id);
        assert_eq!(initial_generated, second.manifest.generated_at);

        config.force_convert = true;
        let third = resolve_slice(&dataset, &config)?;
        assert_eq!(first.manifest.slice_id, third.manifest.slice_id);
        assert_ne!(third.manifest.generated_at, initial_generated);

        Ok(())
    }

    #[test]
    #[allow(clippy::indexing_slicing)]
    fn select_window_yields_expected_cases() -> Result<()> {
        let dataset = sample_dataset();
        let temp = tempdir().context("creating temp directory")?;
        let config = SliceConfig {
            cache_dir: temp.path(),
            force_convert: false,
            explicit_slice: None,
            limit: Some(3),
            corpus_limit: Some(3),
            slice_seed: 0x5eed_2025,
            llm_mode: false,
            negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
            require_verified_chunks: true,
        };
        let resolved = resolve_slice(&dataset, &config)?;
        let window = select_window(&resolved, 1, Some(1))?;
        assert_eq!(window.offset, 1);
        assert_eq!(window.length, 1);
        assert_eq!(window.total_cases, resolved.manifest.case_count);
        assert_eq!(window.cases.len(), 1);
        let positive_ids: Vec<&str> = window.positive_ids().collect();
        assert_eq!(positive_ids.len(), 1);
        assert!(
            resolved
                .manifest
                .paragraphs
                .iter()
                .any(|entry| entry.id == positive_ids[0])
        );
        Ok(())
    }

    #[test]
    #[allow(clippy::indexing_slicing)]
    fn beir_mix_balances_and_rebalances() -> Result<()> {
        let mut paragraphs = Vec::new();
        let counts = [
            ("fever", 1usize),
            ("fiqa", 2usize),
            ("hotpotqa", 1usize),
            ("nfcorpus", 0usize),
            ("quora", 3usize),
            ("trec-covid", 2usize),
        ];

        for (prefix, count) in counts {
            for idx in 0..count {
                let q_id = format!("{prefix}-q{idx}");
                paragraphs.push(ConvertedParagraph {
                    id: format!("{prefix}-p{idx}"),
                    title: format!("{prefix} title"),
                    context: format!("{prefix} context {idx}"),
                    questions: vec![ConvertedQuestion {
                        id: q_id,
                        question: format!("{prefix} question {idx}"),
                        answers: vec!["answer".to_string()],
                        is_impossible: false,
                    }],
                });
            }
        }

        let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false);
        let dataset = ConvertedDataset {
            generated_at: Utc::now(),
            metadata,
            source: "beir-mix".to_string(),
            paragraphs,
        };

        let params = BuildParams {
            include_impossible: false,
            base_seed: 0xAA,
            rng_seed: 0xBB,
        };

        let refs = beir::ordered_question_refs_beir(&dataset, &params, 8)?;
        let mut per_prefix: HashMap<String, usize> = HashMap::new();
        for (p_idx, q_idx) in refs {
            let question = &dataset.paragraphs[p_idx].questions[q_idx];
            let prefix = beir::question_prefix(&question.id).unwrap_or("unknown");
            *per_prefix.entry(prefix.to_string()).or_default() += 1;
        }

        assert_eq!(per_prefix.get("fever").copied().unwrap_or(0), 1);
        assert_eq!(per_prefix.get("fiqa").copied().unwrap_or(0), 2);
        assert_eq!(per_prefix.get("hotpotqa").copied().unwrap_or(0), 1);
        assert_eq!(per_prefix.get("nfcorpus").copied().unwrap_or(0), 0);
        assert_eq!(per_prefix.get("quora").copied().unwrap_or(0), 2);
        assert_eq!(per_prefix.get("trec-covid").copied().unwrap_or(0), 2);

        Ok(())
    }
}
|