evals: eval crate overhaul, simplification and performance improvements

This commit is contained in:
Per Stark
2026-06-17 19:23:11 +02:00
parent 3aca33569d
commit dbf8c91b1f
53 changed files with 2852 additions and 1831 deletions
+106 -26
View File
@@ -1,5 +1,5 @@
use std::{
collections::{BTreeMap, HashMap},
collections::{BTreeMap, HashMap, HashSet},
fs::File,
io::{BufRead, BufReader},
path::{Path, PathBuf},
@@ -47,20 +47,71 @@ struct QrelEntry {
score: i32,
}
/// Convert only documents that appear in qrels (the BEIR evaluation closed world).
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
convert_beir_documents(raw_dir, dataset, None)
}
/// Convert a subset of qrels-world documents. `doc_ids` use corpus ids (unprefixed).
#[allow(
clippy::too_many_lines,
clippy::arithmetic_side_effects,
clippy::indexing_slicing
)]
pub fn convert_beir_documents(
raw_dir: &Path,
dataset: DatasetKind,
doc_ids: Option<&HashSet<String>>,
) -> Result<Vec<ConvertedParagraph>> {
let corpus_path = raw_dir.join("corpus.jsonl");
let queries_path = raw_dir.join("queries.jsonl");
let qrels_path = resolve_qrels_path(raw_dir)?;
let corpus = load_corpus(&corpus_path)?;
let queries = load_queries(&queries_path)?;
let qrels = load_qrels(&qrels_path)?;
let mut paragraphs = Vec::with_capacity(corpus.len());
let mut qrels_doc_ids = HashSet::new();
for entries in qrels.values() {
for entry in entries {
qrels_doc_ids.insert(entry.doc_id.clone());
}
}
let target_doc_ids: HashSet<String> = match doc_ids {
Some(ids) => ids
.iter()
.filter(|id| qrels_doc_ids.contains(*id))
.cloned()
.collect(),
None => qrels_doc_ids.clone(),
};
if target_doc_ids.is_empty() {
return Err(anyhow!(
"no qrels documents to convert for {} at {}",
dataset.id(),
raw_dir.display()
));
}
let corpus = load_corpus_filtered(&corpus_path, &target_doc_ids)?;
let mut doc_ids_sorted: Vec<String> = target_doc_ids.into_iter().collect();
doc_ids_sorted.sort();
let mut paragraphs = Vec::with_capacity(doc_ids_sorted.len());
let mut paragraph_index = HashMap::new();
for (doc_id, entry) in &corpus {
for doc_id in &doc_ids_sorted {
let Some(entry) = corpus.get(doc_id) else {
warn!(
doc_id = %doc_id,
dataset = %dataset.id(),
"Skipping qrels document missing from corpus"
);
continue;
};
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
let paragraph = ConvertedParagraph {
id: paragraph_id.clone(),
@@ -87,6 +138,12 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
continue;
};
if let Some(filter) = doc_ids {
if !filter.contains(&best.doc_id) {
continue;
}
}
let Some(&paragraph_slot) = paragraph_index.get(&best.doc_id) else {
missing_docs += 1;
warn!(
@@ -106,7 +163,6 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
);
continue;
};
let answers = vec![snippet];
let question_id = format!("{}-{query_id}", dataset.source_prefix());
paragraphs[paragraph_slot]
@@ -114,7 +170,7 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
.push(ConvertedQuestion {
id: question_id,
question: query.text.clone(),
answers,
answers: vec![snippet],
is_impossible: false,
});
}
@@ -122,13 +178,23 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
if missing_queries + missing_docs + skipped_answers > 0 {
warn!(
missing_queries,
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
missing_docs,
skipped_answers,
dataset = %dataset.id(),
"Skipped some BEIR qrels entries during conversion"
);
}
Ok(paragraphs)
}
pub fn corpus_doc_id(paragraph_id: &str, dataset: DatasetKind) -> Option<String> {
let prefix = format!("{}-", dataset.source_prefix());
paragraph_id
.strip_prefix(&prefix)
.map(str::to_string)
}
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
let qrels_dir = raw_dir.join("qrels");
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
@@ -148,7 +214,10 @@ fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
}
#[allow(clippy::arithmetic_side_effects)]
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
fn load_corpus_filtered(
path: &Path,
doc_ids: &HashSet<String>,
) -> Result<BTreeMap<String, BeirParagraph>> {
let file =
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
let reader = BufReader::new(file);
@@ -167,6 +236,9 @@ fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
path.display()
)
})?;
if !doc_ids.contains(&corpus_row.id) {
continue;
}
let title = corpus_row.title.unwrap_or_else(|| corpus_row.id.clone());
let text = corpus_row.text.unwrap_or_default();
let context = build_context(&title, &text);
@@ -296,10 +368,8 @@ mod tests {
use std::fs;
use tempfile::tempdir;
#[test]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
fn converts_basic_beir_layout() {
let dir = tempdir().unwrap();
#[allow(clippy::unwrap_used)]
fn write_fixture(dir: &tempfile::TempDir) {
let corpus = r#"
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
{"_id":"d2","title":"Doc 2","text":"Second document content."}
@@ -313,24 +383,34 @@ mod tests {
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
fs::create_dir_all(dir.path().join("qrels")).unwrap();
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
}
#[test]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
fn converts_qrels_world_only() {
let dir = tempdir().unwrap();
write_fixture(&dir);
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
assert_eq!(paragraphs.len(), 2);
let doc_one = paragraphs
.iter()
.find(|p| p.id == "fever-d1")
.expect("missing paragraph for d1");
assert_eq!(paragraphs.len(), 1);
let doc_one = &paragraphs[0];
assert_eq!(doc_one.id, "fever-d1");
assert_eq!(doc_one.questions.len(), 1);
let question = &doc_one.questions[0];
assert_eq!(question.id, "fever-q1");
assert!(!question.answers.is_empty());
assert!(doc_one.context.contains(&question.answers[0]));
assert_eq!(doc_one.questions[0].id, "fever-q1");
}
let doc_two = paragraphs
.iter()
.find(|p| p.id == "fever-d2")
.expect("missing paragraph for d2");
assert!(doc_two.questions.is_empty());
#[test]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
fn converts_filtered_doc_ids() {
let dir = tempdir().unwrap();
write_fixture(&dir);
let mut ids = HashSet::new();
ids.insert("d1".to_string());
let paragraphs =
convert_beir_documents(dir.path(), DatasetKind::Fever, Some(&ids)).unwrap();
assert_eq!(paragraphs.len(), 1);
assert_eq!(paragraphs[0].id, "fever-d1");
}
}
+262
View File
@@ -0,0 +1,262 @@
use std::collections::{HashMap, HashSet};
use anyhow::{anyhow, Context, Result};
use sha2::{Digest, Sha256};
use tracing::info;
use super::{
beir,
checksum::hash_file,
store::{
self, build_dataset_from_catalog, paragraph_path, read_meta, store_dir_for,
upsert_sharded_paragraphs, write_sharded,
},
BEIR_DATASETS, ConvertedDataset, DatasetKind, DatasetMetadata,
};
use crate::{
args::Config,
slice,
};
pub fn subset_for_paragraph_id(paragraph_id: &str) -> Option<DatasetKind> {
let mut kinds: Vec<DatasetKind> = BEIR_DATASETS.to_vec();
kinds.sort_by_key(|kind| std::cmp::Reverse(kind.source_prefix().len()));
for kind in kinds {
let prefix = format!("{}-", kind.source_prefix());
if paragraph_id.starts_with(&prefix) {
return Some(kind);
}
}
None
}
pub fn build_beir_mix_qrels_dataset(include_unanswerable: bool) -> Result<ConvertedDataset> {
if include_unanswerable {
tracing::warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
}
let mut paragraphs = Vec::new();
for subset in BEIR_DATASETS {
let entry = super::dataset_entry_for_kind(subset)?;
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
paragraphs.extend(subset_paragraphs);
}
Ok(ConvertedDataset {
generated_at: super::base_timestamp(),
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, include_unanswerable),
source: "beir-mix".to_string(),
paragraphs,
})
}
pub fn prepare_beir_mix(config: &Config) -> Result<super::loader::LoadedDataset> {
let virtual_ds = build_beir_mix_qrels_dataset(config.llm_mode)?;
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
let resolved = slice::resolve_slice(&virtual_ds, &slice_config).context(
"resolving BEIR mix slice ledger (check --slice and --limit match your intent)",
)?;
let unique: HashSet<String> = resolved
.manifest
.paragraphs
.iter()
.map(|entry| entry.id.clone())
.collect();
materialize_subset_stores(&unique, config.force_convert)?;
let dataset = load_beir_mix_from_subsets(&unique)?;
let checksum = mix_content_checksum(&unique)?;
info!(
slice = resolved.manifest.slice_id.as_str(),
paragraphs = unique.len(),
checksum = %checksum,
"Prepared BEIR mix from per-subset converted stores"
);
Ok(super::loader::LoadedDataset {
dataset,
content_checksum: checksum,
partial: true,
})
}
pub fn materialize_subset_stores(
paragraph_ids: &HashSet<String>,
force: bool,
) -> Result<()> {
let mut by_subset: HashMap<DatasetKind, Vec<String>> = HashMap::new();
for paragraph_id in paragraph_ids {
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
})?;
by_subset.entry(kind).or_default().push(paragraph_id.clone());
}
for (kind, ids) in by_subset {
let entry = super::dataset_entry_for_kind(kind)?;
let store_dir = store_dir_for(&entry.converted_path);
let existing = if store_dir.join("meta.json").is_file() {
store::load_paragraph_ids_set(&store_dir)?
} else {
HashSet::new()
};
let missing: Vec<String> = if force {
ids
} else {
ids.into_iter()
.filter(|paragraph_id| !existing.contains(paragraph_id))
.collect()
};
if missing.is_empty() {
continue;
}
let corpus_ids: HashSet<String> = missing
.iter()
.filter_map(|paragraph_id| beir::corpus_doc_id(paragraph_id, kind))
.collect();
let paragraphs = beir::convert_beir_documents(
&entry.raw_path,
kind,
Some(&corpus_ids),
)?;
if store_dir.join("meta.json").is_file() {
upsert_sharded_paragraphs(&store_dir, &paragraphs)?;
} else {
let question_count = paragraphs
.iter()
.map(|paragraph| paragraph.questions.len())
.sum::<usize>();
let dataset = ConvertedDataset {
generated_at: super::base_timestamp(),
metadata: DatasetMetadata::for_kind(kind, false),
source: entry.raw_path.display().to_string(),
paragraphs,
};
write_sharded(&dataset, &store_dir)?;
info!(
subset = kind.id(),
store = %store_dir.display(),
paragraphs = dataset.paragraphs.len(),
questions = question_count,
"Created subset converted store for BEIR mix"
);
}
}
Ok(())
}
pub fn load_beir_mix_from_subsets(paragraph_ids: &HashSet<String>) -> Result<ConvertedDataset> {
let mut by_subset: HashMap<DatasetKind, HashSet<String>> = HashMap::new();
for paragraph_id in paragraph_ids {
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
})?;
by_subset
.entry(kind)
.or_default()
.insert(paragraph_id.clone());
}
let mut paragraphs = Vec::with_capacity(paragraph_ids.len());
for (kind, subset_ids) in by_subset {
let entry = super::dataset_entry_for_kind(kind)?;
let store_dir = store_dir_for(&entry.converted_path);
let partial = build_dataset_from_catalog(&store_dir, &subset_ids)?;
paragraphs.extend(partial.paragraphs);
}
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
Ok(ConvertedDataset {
generated_at: super::base_timestamp(),
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, false),
source: "beir-mix".to_string(),
paragraphs,
})
}
pub fn mix_content_checksum(paragraph_ids: &HashSet<String>) -> Result<String> {
let mut ids: Vec<String> = paragraph_ids.iter().cloned().collect();
ids.sort();
let mut hasher = Sha256::new();
for paragraph_id in ids {
let kind = subset_for_paragraph_id(&paragraph_id)
.ok_or_else(|| anyhow!("unknown BEIR subset for paragraph '{paragraph_id}'"))?;
let entry = super::dataset_entry_for_kind(kind)?;
let store_dir = store_dir_for(&entry.converted_path);
let path = paragraph_path(&store_dir, &paragraph_id);
if !path.is_file() {
return Err(anyhow!(
"missing converted paragraph {} at {}",
paragraph_id,
path.display()
));
}
hasher.update(paragraph_id.as_bytes());
hasher.update([0]);
hasher.update(hash_file(&path)?.as_bytes());
}
Ok(format!("{:x}", hasher.finalize()))
}
pub fn beir_subset_stores_ready(paragraph_ids: &HashSet<String>) -> Result<bool> {
for paragraph_id in paragraph_ids {
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
})?;
let entry = super::dataset_entry_for_kind(kind)?;
let store_dir = store_dir_for(&entry.converted_path);
if !store_dir.join("meta.json").is_file() {
return Ok(false);
}
if !paragraph_path(&store_dir, paragraph_id).is_file() {
return Ok(false);
}
}
Ok(true)
}
pub fn beir_subset_store_summary() -> Result<Vec<(String, usize, usize)>> {
let mut summary = Vec::new();
for kind in BEIR_DATASETS {
let entry = super::dataset_entry_for_kind(kind)?;
let store_dir = store_dir_for(&entry.converted_path);
if store_dir.join("meta.json").is_file() {
let meta = read_meta(&store_dir)?;
summary.push((kind.id().to_string(), meta.paragraph_count, meta.question_count));
}
}
Ok(summary)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn routes_prefixed_paragraph_ids() {
assert_eq!(
subset_for_paragraph_id("fever-doc-1"),
Some(DatasetKind::Fever)
);
assert_eq!(
subset_for_paragraph_id("nq-beir-doc-1"),
Some(DatasetKind::NqBeir)
);
assert_eq!(
subset_for_paragraph_id("trec-covid-doc-1"),
Some(DatasetKind::TrecCovid)
);
assert!(subset_for_paragraph_id("unknown-doc").is_none());
}
}
+216
View File
@@ -0,0 +1,216 @@
use std::{
fs::{self, File},
io::Read,
path::Path,
};
#[cfg(test)]
use std::path::PathBuf;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
const SIDECAR_VERSION: u32 = 1;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChecksumSidecar {
pub version: u32,
pub sha256: String,
pub size_bytes: u64,
#[serde(default)]
pub modified_unix_secs: u64,
}
impl ChecksumSidecar {
#[cfg(test)]
pub fn sidecar_path(content_path: &Path) -> PathBuf {
content_path.with_extension("sha256")
}
#[cfg(test)]
pub fn is_valid_for(&self, content_path: &Path) -> bool {
if self.version != SIDECAR_VERSION {
return false;
}
let Ok(metadata) = fs::metadata(content_path) else {
return false;
};
if metadata.len() != self.size_bytes {
return false;
}
if self.modified_unix_secs != 0 {
let Ok(modified) = metadata.modified() else {
return true;
};
let Ok(secs) = modified.duration_since(std::time::UNIX_EPOCH) else {
return true;
};
if secs.as_secs() != self.modified_unix_secs {
return false;
}
}
true
}
}
#[allow(clippy::indexing_slicing)]
pub fn hash_file(path: &Path) -> Result<String> {
let mut file =
File::open(path).with_context(|| format!("opening file {} for checksum", path.display()))?;
let mut hasher = Sha256::new();
let mut buffer = vec![0u8; 65_536];
loop {
let read = file
.read(&mut buffer)
.with_context(|| format!("reading {} for checksum", path.display()))?;
if read == 0 {
break;
}
hasher.update(&buffer[..read]);
}
Ok(format!("{:x}", hasher.finalize()))
}
pub fn read_sidecar(path: &Path) -> Result<Option<ChecksumSidecar>> {
if !path.exists() {
return Ok(None);
}
let raw = fs::read_to_string(path)
.with_context(|| format!("reading checksum sidecar {}", path.display()))?;
let sidecar: ChecksumSidecar = serde_json::from_str(&raw)
.with_context(|| format!("parsing checksum sidecar {}", path.display()))?;
Ok(Some(sidecar))
}
#[cfg(test)]
pub fn write_sidecar(content_path: &Path, sha256: &str) -> Result<()> {
let metadata = fs::metadata(content_path)
.with_context(|| format!("reading metadata for {}", content_path.display()))?;
let modified_unix_secs = metadata
.modified()
.ok()
.and_then(|time| time.duration_since(std::time::UNIX_EPOCH).ok())
.map_or(0, |duration| duration.as_secs());
let sidecar = ChecksumSidecar {
version: SIDECAR_VERSION,
sha256: sha256.to_string(),
size_bytes: metadata.len(),
modified_unix_secs,
};
let path = ChecksumSidecar::sidecar_path(content_path);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating checksum sidecar directory {}", parent.display()))?;
}
let blob = serde_json::to_vec_pretty(&sidecar).context("serialising checksum sidecar")?;
fs::write(&path, blob)
.with_context(|| format!("writing checksum sidecar {}", path.display()))?;
Ok(())
}
#[cfg(test)]
pub fn content_checksum(content_path: &Path) -> Result<String> {
let sidecar_path = ChecksumSidecar::sidecar_path(content_path);
if let Some(sidecar) = read_sidecar(&sidecar_path)? {
if sidecar.is_valid_for(content_path) {
return Ok(sidecar.sha256);
}
}
let sha256 = hash_file(content_path)?;
write_sidecar(content_path, &sha256)?;
Ok(sha256)
}
pub fn store_aggregate_checksum(store_dir: &Path) -> Result<String> {
let marker = store_dir.join("checksum.sha256");
let meta = store_dir.join("meta.json");
if marker.is_file() && meta.is_file() {
if let (Ok(marker_meta), Ok(meta_meta)) = (marker.metadata(), meta.metadata()) {
if marker_meta
.modified()
.ok()
.zip(meta_meta.modified().ok())
.is_some_and(|(marker_modified, meta_modified)| marker_modified >= meta_modified)
{
if let Some(sidecar) = read_sidecar(&marker)? {
return Ok(sidecar.sha256);
}
}
}
}
let mut entries = Vec::new();
collect_store_files(store_dir, store_dir, &mut entries)?;
entries.sort();
let mut hasher = Sha256::new();
for relative in &entries {
let path = store_dir.join(relative);
if path == marker {
continue;
}
hasher.update(relative.as_bytes());
hasher.update([0]);
let file_hash = hash_file(&path)?;
hasher.update(file_hash.as_bytes());
}
let digest = format!("{:x}", hasher.finalize());
let sidecar = ChecksumSidecar {
version: SIDECAR_VERSION,
sha256: digest.clone(),
size_bytes: entries.len() as u64,
modified_unix_secs: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs()),
};
if let Some(parent) = marker.parent() {
fs::create_dir_all(parent)?;
}
fs::write(&marker, serde_json::to_vec_pretty(&sidecar)?)?;
Ok(digest)
}
fn collect_store_files(base: &Path, current: &Path, entries: &mut Vec<String>) -> Result<()> {
for entry in fs::read_dir(current)? {
let entry = entry?;
let path = entry.path();
if path.file_name().is_some_and(|name| name == "checksum.sha256") {
continue;
}
if path.is_dir() {
collect_store_files(base, &path, entries)?;
} else if path.is_file() {
let relative = path
.strip_prefix(base)
.unwrap_or(&path)
.to_string_lossy()
.replace('\\', "/");
entries.push(relative);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn sidecar_round_trip() -> Result<()> {
let dir = tempdir()?;
let file = dir.path().join("sample.json");
fs::write(&file, br#"{"hello":"world"}"#)?;
let first = content_checksum(&file)?;
let second = content_checksum(&file)?;
assert_eq!(first, second);
fs::write(&file, br#"{"hello":"world!"}"#)?;
let third = content_checksum(&file)?;
assert_ne!(first, third);
Ok(())
}
}
+197
View File
@@ -0,0 +1,197 @@
use std::collections::HashSet;
use anyhow::{Context, Result};
use tracing::info;
use super::{
catalog,
store::{
self, build_dataset_from_catalog, detect_layout, read_meta, store_dir_for, write_sharded,
ConvertedLayout,
},
ConvertedDataset, DatasetKind,
};
use crate::{
args::Config,
slice::{self, SliceConfig},
};
#[derive(Debug, Clone)]
pub struct LoadedDataset {
pub dataset: ConvertedDataset,
pub content_checksum: String,
pub partial: bool,
}
pub fn prepare_dataset(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
if dataset_kind == DatasetKind::Beir {
return super::beir_mix::prepare_beir_mix(config);
}
let converted_path = &config.converted_dataset_path;
let layout = detect_layout(converted_path);
let store_dir = store_dir_for(converted_path);
if layout == ConvertedLayout::Missing || config.force_convert {
return convert_and_load(dataset_kind, config);
}
load_from_store(dataset_kind, config, &store_dir, true)
}
fn convert_and_load(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
let dataset = super::convert(
config.raw_dataset_path.as_path(),
dataset_kind,
config.llm_mode,
)
.with_context(|| format!("converting {} dataset", dataset_kind.label()))?;
let store_dir = store_dir_for(&config.converted_dataset_path);
write_sharded(&dataset, &store_dir)?;
prebuild_catalog_slices(&dataset, config)?;
let checksum = crate::datasets::store_aggregate_checksum(&store_dir)?;
Ok(LoadedDataset {
dataset,
content_checksum: checksum,
partial: false,
})
}
fn load_from_store(
dataset_kind: DatasetKind,
config: &Config,
store_dir: &std::path::Path,
allow_partial: bool,
) -> Result<LoadedDataset> {
let checksum = crate::datasets::store_aggregate_checksum(store_dir)?;
let meta = read_meta(store_dir)?;
validate_metadata_fields(&meta.metadata, dataset_kind, config)?;
if allow_partial {
if let Some(paragraph_ids) = slice_paragraph_ids_for_fast_path(config)? {
let unique: HashSet<String> = paragraph_ids.into_iter().collect();
info!(
paragraphs = unique.len(),
store = %store_dir.display(),
"Loading slice-addressed paragraphs from sharded converted store"
);
let dataset = build_dataset_from_catalog(store_dir, &unique)?;
return Ok(LoadedDataset {
dataset,
content_checksum: checksum,
partial: true,
});
}
}
info!(
store = %store_dir.display(),
paragraphs = meta.paragraph_count,
"Loading full sharded converted store"
);
let dataset = store::load_sharded_full(store_dir)?;
Ok(LoadedDataset {
dataset,
content_checksum: checksum,
partial: false,
})
}
fn slice_paragraph_ids_for_fast_path(config: &Config) -> Result<Option<Vec<String>>> {
let Some(manifest_path) = slice::cached_manifest_path(config) else {
return Ok(None);
};
let Some(manifest) = slice::read_manifest_if_exists(&manifest_path)? else {
return Ok(None);
};
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
if !slice::manifest_is_complete(&manifest, &slice_config) {
return Ok(None);
}
Ok(Some(
manifest
.paragraphs
.iter()
.map(|entry| entry.id.clone())
.collect(),
))
}
fn validate_metadata_fields(
metadata: &super::DatasetMetadata,
dataset_kind: DatasetKind,
config: &Config,
) -> Result<()> {
if metadata.id != dataset_kind.id() {
anyhow::bail!(
"converted dataset targets '{}', expected '{}'",
metadata.id,
dataset_kind.id()
);
}
if metadata.include_unanswerable != config.llm_mode {
anyhow::bail!(
"converted dataset include_unanswerable mismatch (expected {}, found {})",
config.llm_mode,
metadata.include_unanswerable
);
}
Ok(())
}
pub fn prebuild_catalog_slices(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
let catalog = catalog()?;
let entry = catalog.dataset(dataset.metadata.id.as_str())?;
if entry.slices.is_empty() {
return Ok(());
}
info!(
dataset = dataset.metadata.id.as_str(),
slices = entry.slices.len(),
"Prebuilding catalog slice ledgers"
);
for slice_entry in &entry.slices {
let slice_config = slice_config_for_catalog_entry(config, slice_entry);
match slice::resolve_slice(dataset, &slice_config) {
Ok(resolved) => info!(
slice = resolved.manifest.slice_id.as_str(),
cases = resolved.manifest.case_count,
positives = resolved.manifest.positive_paragraphs,
negatives = resolved.manifest.negative_paragraphs,
"Prebuilt catalog slice ledger"
),
Err(err) => tracing::warn!(
slice = slice_entry.id.as_str(),
error = %err,
"Failed to prebuild catalog slice ledger"
),
}
}
Ok(())
}
fn slice_config_for_catalog_entry<'a>(
config: &'a Config,
slice_entry: &'a super::SliceEntry,
) -> SliceConfig<'a> {
SliceConfig {
cache_dir: config.cache_dir.as_path(),
force_convert: config.force_convert,
explicit_slice: Some(slice_entry.id.as_str()),
limit: slice_entry.limit,
corpus_limit: slice_entry.corpus_limit,
slice_seed: slice_entry.seed.unwrap_or(config.slice_seed),
llm_mode: slice_entry
.include_unanswerable
.unwrap_or(config.llm_mode),
negative_multiplier: slice_entry
.negative_multiplier
.unwrap_or(config.negative_multiplier),
require_verified_chunks: config.retrieval.require_verified_chunks,
}
}
+38 -143
View File
@@ -1,6 +1,10 @@
mod beir;
mod beir_mix;
mod checksum;
mod loader;
mod nq;
mod squad;
mod store;
use std::{
collections::{BTreeMap, HashMap},
@@ -20,38 +24,31 @@ const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetCatalog {
datasets: BTreeMap<String, DatasetEntry>,
slices: HashMap<String, SliceLocation>,
default_dataset: String,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetEntry {
pub metadata: DatasetMetadata,
pub raw_path: PathBuf,
pub converted_path: PathBuf,
pub include_unanswerable: bool,
pub slices: Vec<SliceEntry>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SliceEntry {
pub id: String,
pub dataset_id: String,
pub label: String,
pub description: Option<String>,
pub limit: Option<usize>,
pub corpus_limit: Option<usize>,
pub include_unanswerable: Option<bool>,
pub seed: Option<u64>,
pub negative_multiplier: Option<f32>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct SliceLocation {
dataset_id: String,
slice_index: usize,
@@ -59,7 +56,6 @@ struct SliceLocation {
#[derive(Debug, Deserialize)]
struct ManifestFile {
default_dataset: Option<String>,
datasets: Vec<ManifestDataset>,
}
@@ -81,6 +77,7 @@ struct ManifestDataset {
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ManifestSlice {
id: String,
label: String,
@@ -94,6 +91,8 @@ struct ManifestSlice {
include_unanswerable: Option<bool>,
#[serde(default)]
seed: Option<u64>,
#[serde(default)]
negative_multiplier: Option<f32>,
}
impl DatasetCatalog {
@@ -111,18 +110,19 @@ impl DatasetCatalog {
let raw_path = resolve_path(root, &dataset.raw);
let converted_path = resolve_path(root, &dataset.converted);
if !raw_path.exists() {
if !raw_path.exists() && dataset.id != "beir" {
bail!(
"dataset '{}' raw file missing at {}",
dataset.id,
raw_path.display()
);
}
if !converted_path.exists() {
let store_dir = store::store_dir_for(&converted_path);
if !converted_path.exists() && !store_dir.join("meta.json").is_file() {
warn!(
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
"dataset '{}' converted store missing at {}; the next conversion run will regenerate it",
dataset.id,
converted_path.display()
store_dir.display()
);
}
@@ -139,7 +139,6 @@ impl DatasetCatalog {
.clone()
.unwrap_or_else(|| dataset.id.clone()),
include_unanswerable: dataset.include_unanswerable,
context_token_limit: None,
};
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
@@ -154,12 +153,11 @@ impl DatasetCatalog {
entry_slices.push(SliceEntry {
id: manifest_slice.id.clone(),
dataset_id: dataset.id.clone(),
label: manifest_slice.label,
description: manifest_slice.description,
limit: manifest_slice.limit,
corpus_limit: manifest_slice.corpus_limit,
include_unanswerable: manifest_slice.include_unanswerable,
seed: manifest_slice.seed,
negative_multiplier: manifest_slice.negative_multiplier,
});
slices.insert(
manifest_slice.id,
@@ -176,22 +174,16 @@ impl DatasetCatalog {
metadata,
raw_path,
converted_path,
include_unanswerable: dataset.include_unanswerable,
slices: entry_slices,
},
);
}
let default_dataset = manifest
.default_dataset
.or_else(|| datasets.keys().next().cloned())
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
if datasets.is_empty() {
bail!("dataset manifest does not include any datasets");
}
Ok(Self {
datasets,
slices,
default_dataset,
})
Ok(Self { datasets, slices })
}
pub fn global() -> Result<&'static Self> {
@@ -204,12 +196,6 @@ impl DatasetCatalog {
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
}
#[allow(dead_code)]
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
self.dataset(&self.default_dataset)
}
#[allow(dead_code)]
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
let location = self
.slices
@@ -236,20 +222,29 @@ fn resolve_path(root: &Path, value: &str) -> PathBuf {
}
}
pub use checksum::store_aggregate_checksum;
pub use beir_mix::{
beir_subset_store_summary, beir_subset_stores_ready, mix_content_checksum,
};
pub use loader::{prebuild_catalog_slices, prepare_dataset};
pub use store::{
content_checksum_for_layout, detect_layout, store_dir_for, write_sharded, ConvertedLayout,
};
pub fn catalog() -> Result<&'static DatasetCatalog> {
DatasetCatalog::global()
}
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
pub(crate) fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
let catalog = catalog()?;
catalog.dataset(kind.id())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum, Default)]
pub enum DatasetKind {
#[default]
SquadV2,
NaturalQuestions,
#[default]
Beir,
#[value(name = "fever")]
Fever,
@@ -416,16 +411,10 @@ pub struct DatasetMetadata {
pub source_prefix: String,
#[serde(default)]
pub include_unanswerable: bool,
#[serde(default)]
pub context_token_limit: Option<usize>,
}
impl DatasetMetadata {
pub fn for_kind(
kind: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Self {
pub fn for_kind(kind: DatasetKind, include_unanswerable: bool) -> Self {
if let Ok(entry) = dataset_entry_for_kind(kind) {
return Self {
id: entry.metadata.id.clone(),
@@ -434,7 +423,6 @@ impl DatasetMetadata {
entity_suffix: entry.metadata.entity_suffix.clone(),
source_prefix: entry.metadata.source_prefix.clone(),
include_unanswerable,
context_token_limit,
};
}
@@ -445,13 +433,12 @@ impl DatasetMetadata {
entity_suffix: kind.entity_suffix().to_string(),
source_prefix: kind.source_prefix().to_string(),
include_unanswerable,
context_token_limit,
}
}
}
fn default_metadata() -> DatasetMetadata {
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
DatasetMetadata::for_kind(DatasetKind::default(), false)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -483,14 +470,15 @@ pub fn convert(
raw_path: &Path,
dataset: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
let paragraphs = match dataset {
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
DatasetKind::NaturalQuestions => {
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
DatasetKind::NaturalQuestions => nq::convert_nq(raw_path, include_unanswerable)?,
DatasetKind::Beir => {
bail!(
"BEIR mix is prepared via slice-first subset stores; use prepare_beir_mix instead of convert"
);
}
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
DatasetKind::Fever
| DatasetKind::Fiqa
| DatasetKind::HotpotQa
@@ -501,11 +489,6 @@ pub fn convert(
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
};
let metadata_limit = match dataset {
DatasetKind::NaturalQuestions => None,
_ => context_token_limit,
};
let generated_at = match dataset {
DatasetKind::Beir
| DatasetKind::Fever
@@ -526,100 +509,12 @@ pub fn convert(
Ok(ConvertedDataset {
generated_at,
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable),
source: source_label,
paragraphs,
})
}
fn convert_beir_mix(
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
if include_unanswerable {
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
}
let mut paragraphs = Vec::new();
for subset in BEIR_DATASETS {
let entry = dataset_entry_for_kind(subset)?;
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
paragraphs.extend(subset_paragraphs);
}
Ok(paragraphs)
}
fn ensure_parent(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating parent directory for {}", path.display()))?;
}
Ok(())
}
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
ensure_parent(converted_path)?;
let json =
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
fs::write(converted_path, json)
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
}
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
let raw = fs::read_to_string(converted_path)
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
if dataset.metadata.id.trim().is_empty() {
dataset.metadata = default_metadata();
}
if dataset.source.is_empty() {
dataset.source = converted_path.display().to_string();
}
Ok(dataset)
}
pub fn ensure_converted(
dataset_kind: DatasetKind,
raw_path: &Path,
converted_path: &Path,
force: bool,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
if force || !converted_path.exists() {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
return Ok(dataset);
}
match read_converted(converted_path) {
Ok(dataset)
if dataset.metadata.id == dataset_kind.id()
&& dataset.metadata.include_unanswerable == include_unanswerable
&& dataset.metadata.context_token_limit == context_token_limit =>
{
Ok(dataset)
}
_ => {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
Ok(dataset)
}
}
}
pub fn base_timestamp() -> DateTime<Utc> {
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
}
+1 -5
View File
@@ -16,11 +16,7 @@ use super::{ConvertedParagraph, ConvertedQuestion};
clippy::arithmetic_side_effects,
clippy::cast_sign_loss
)]
pub fn convert_nq(
raw_path: &Path,
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
pub fn convert_nq(raw_path: &Path, include_unanswerable: bool) -> Result<Vec<ConvertedParagraph>> {
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqExample {
+410
View File
@@ -0,0 +1,410 @@
use std::{
collections::{HashMap, HashSet},
fs::{self, File, OpenOptions},
io::{BufRead, BufReader, Write},
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tracing::info;
use super::{
checksum::store_aggregate_checksum,
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetMetadata,
};
use crate::slice;
pub const SHARDED_STORE_VERSION: u32 = 1;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardedMeta {
pub version: u32,
pub generated_at: DateTime<Utc>,
pub metadata: DatasetMetadata,
pub source: String,
pub paragraph_count: usize,
pub question_count: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct QuestionRecord {
paragraph_id: String,
#[serde(flatten)]
question: ConvertedQuestion,
}
#[derive(Debug, Clone)]
pub struct QuestionCatalog {
pub entries: Vec<QuestionRecord>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvertedLayout {
ShardedStore,
Missing,
}
pub fn store_dir_for(converted_path: &Path) -> PathBuf {
converted_path
.parent()
.unwrap_or_else(|| Path::new("."))
.join(
converted_path
.file_stem()
.map_or_else(|| "dataset".to_string(), |stem| stem.to_string_lossy().into()),
)
}
pub fn detect_layout(converted_path: &Path) -> ConvertedLayout {
let store_dir = store_dir_for(converted_path);
if store_dir.join("meta.json").is_file() {
ConvertedLayout::ShardedStore
} else {
ConvertedLayout::Missing
}
}
fn paragraph_file_name(paragraph_id: &str) -> String {
format!("{}.json", slice::paragraph_storage_key(paragraph_id))
}
pub fn paragraph_path(store_dir: &Path, paragraph_id: &str) -> PathBuf {
store_dir
.join("paragraphs")
.join(paragraph_file_name(paragraph_id))
}
pub fn write_sharded(dataset: &ConvertedDataset, store_dir: &Path) -> Result<String> {
if store_dir.exists() {
fs::remove_dir_all(store_dir)
.with_context(|| format!("clearing sharded store {}", store_dir.display()))?;
}
fs::create_dir_all(store_dir.join("paragraphs"))
.with_context(|| format!("creating sharded store {}", store_dir.display()))?;
let question_count = dataset
.paragraphs
.iter()
.map(|paragraph| paragraph.questions.len())
.sum::<usize>();
let meta = ShardedMeta {
version: SHARDED_STORE_VERSION,
generated_at: dataset.generated_at,
metadata: dataset.metadata.clone(),
source: dataset.source.clone(),
paragraph_count: dataset.paragraphs.len(),
question_count,
};
let meta_path = store_dir.join("meta.json");
fs::write(
&meta_path,
serde_json::to_vec_pretty(&meta).context("serialising sharded store metadata")?,
)
.with_context(|| format!("writing sharded metadata {}", meta_path.display()))?;
let mut questions_file = File::create(store_dir.join("questions.jsonl"))
.context("creating questions.jsonl for sharded store")?;
let mut paragraph_ids_file = File::create(store_dir.join("paragraph_ids.jsonl"))
.context("creating paragraph_ids.jsonl for sharded store")?;
for paragraph in &dataset.paragraphs {
writeln!(paragraph_ids_file, "{}", paragraph.id)
.context("writing paragraph id to paragraph_ids.jsonl")?;
for question in &paragraph.questions {
let record = QuestionRecord {
paragraph_id: paragraph.id.clone(),
question: question.clone(),
};
serde_json::to_writer(&mut questions_file, &record)
.context("writing question record to questions.jsonl")?;
questions_file.write_all(b"\n")?;
}
let path = paragraph_path(store_dir, &paragraph.id);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(
&path,
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
)
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
}
let digest = store_aggregate_checksum(store_dir)?;
info!(
store = %store_dir.display(),
paragraphs = dataset.paragraphs.len(),
questions = question_count,
checksum = %digest,
"Wrote sharded converted dataset"
);
Ok(digest)
}
pub fn read_meta(store_dir: &Path) -> Result<ShardedMeta> {
let path = store_dir.join("meta.json");
let raw = fs::read_to_string(&path)
.with_context(|| format!("reading sharded metadata {}", path.display()))?;
serde_json::from_str(&raw)
.with_context(|| format!("parsing sharded metadata {}", path.display()))
}
pub fn content_checksum_for_layout(converted_path: &Path) -> Result<String> {
match detect_layout(converted_path) {
ConvertedLayout::ShardedStore => {
crate::datasets::store_aggregate_checksum(&store_dir_for(converted_path))
}
ConvertedLayout::Missing => Err(anyhow!(
"converted dataset missing at {}",
converted_path.display()
)),
}
}
fn load_paragraph(store_dir: &Path, paragraph_id: &str) -> Result<ConvertedParagraph> {
let path = paragraph_path(store_dir, paragraph_id);
let raw = fs::read(&path)
.with_context(|| format!("reading sharded paragraph {}", path.display()))?;
serde_json::from_slice(&raw)
.with_context(|| format!("parsing sharded paragraph {}", path.display()))
}
fn load_paragraphs(store_dir: &Path, paragraph_ids: &[String]) -> Result<Vec<ConvertedParagraph>> {
paragraph_ids
.iter()
.map(|paragraph_id| load_paragraph(store_dir, paragraph_id))
.collect()
}
pub fn load_sharded_partial(store_dir: &Path, paragraph_ids: &[String]) -> Result<ConvertedDataset> {
let meta = read_meta(store_dir)?;
let mut paragraphs = load_paragraphs(store_dir, paragraph_ids)?;
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
Ok(ConvertedDataset {
generated_at: meta.generated_at,
metadata: meta.metadata,
source: meta.source,
paragraphs,
})
}
pub fn load_sharded_full(store_dir: &Path) -> Result<ConvertedDataset> {
let meta = read_meta(store_dir)?;
let ids = load_paragraph_ids(store_dir)?;
let paragraphs = load_paragraphs(store_dir, &ids)?;
Ok(ConvertedDataset {
generated_at: meta.generated_at,
metadata: meta.metadata,
source: meta.source,
paragraphs,
})
}
pub fn load_paragraph_ids_set(store_dir: &Path) -> Result<HashSet<String>> {
Ok(load_paragraph_ids(store_dir)?.into_iter().collect())
}
#[allow(clippy::arithmetic_side_effects)]
pub fn upsert_sharded_paragraphs(
store_dir: &Path,
paragraphs: &[ConvertedParagraph],
) -> Result<()> {
if paragraphs.is_empty() {
return Ok(());
}
if !store_dir.join("meta.json").is_file() {
return Err(anyhow!(
"cannot upsert into missing sharded store at {}",
store_dir.display()
));
}
fs::create_dir_all(store_dir.join("paragraphs"))
.with_context(|| format!("creating paragraphs directory in {}", store_dir.display()))?;
let existing = load_paragraph_ids_set(store_dir)?;
let questions_path = store_dir.join("questions.jsonl");
let mut questions_file = OpenOptions::new()
.create(true)
.append(true)
.open(&questions_path)
.with_context(|| format!("opening question catalog {}", questions_path.display()))?;
let mut ids_file = None;
let mut new_paragraphs = 0usize;
let mut new_questions = 0usize;
for paragraph in paragraphs {
let is_new = !existing.contains(&paragraph.id);
let path = paragraph_path(store_dir, &paragraph.id);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(
&path,
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
)
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
if is_new {
if ids_file.is_none() {
ids_file = Some(
OpenOptions::new()
.create(true)
.append(true)
.open(store_dir.join("paragraph_ids.jsonl"))
.context("opening paragraph_ids.jsonl for append")?,
);
}
if let Some(file) = ids_file.as_mut() {
writeln!(file, "{}", paragraph.id).context("appending paragraph id")?;
}
new_paragraphs += 1;
for question in &paragraph.questions {
let record = QuestionRecord {
paragraph_id: paragraph.id.clone(),
question: question.clone(),
};
serde_json::to_writer(&mut questions_file, &record)
.context("writing question record to questions.jsonl")?;
questions_file.write_all(b"\n")?;
new_questions += 1;
}
}
}
if new_paragraphs > 0 || new_questions > 0 {
let meta = read_meta(store_dir)?;
let updated = ShardedMeta {
paragraph_count: meta.paragraph_count + new_paragraphs,
question_count: meta.question_count + new_questions,
..meta
};
fs::write(
store_dir.join("meta.json"),
serde_json::to_vec_pretty(&updated).context("serialising updated sharded metadata")?,
)?;
store_aggregate_checksum(store_dir)?;
info!(
store = %store_dir.display(),
new_paragraphs,
new_questions,
"Upserted paragraphs into sharded converted store"
);
}
Ok(())
}
pub fn load_paragraph_ids(store_dir: &Path) -> Result<Vec<String>> {
let path = store_dir.join("paragraph_ids.jsonl");
let file = File::open(&path)
.with_context(|| format!("opening paragraph id index {}", path.display()))?;
let reader = BufReader::new(file);
reader
.lines()
.map(|line| {
line.context("reading paragraph id index line")
.and_then(|value| {
let trimmed = value.trim();
if trimmed.is_empty() {
Err(anyhow!("empty paragraph id in index"))
} else {
Ok(trimmed.to_string())
}
})
})
.collect()
}
pub fn load_question_catalog(store_dir: &Path) -> Result<QuestionCatalog> {
let path = store_dir.join("questions.jsonl");
let file = File::open(&path)
.with_context(|| format!("opening question catalog {}", path.display()))?;
let reader = BufReader::new(file);
let mut entries = Vec::new();
for line in reader.lines() {
let line = line.context("reading question catalog line")?;
if line.trim().is_empty() {
continue;
}
let record: QuestionRecord = serde_json::from_str(&line)
.context("parsing question catalog record")?;
entries.push(record);
}
Ok(QuestionCatalog { entries })
}
pub fn build_dataset_from_catalog(
store_dir: &Path,
paragraph_ids: &HashSet<String>,
) -> Result<ConvertedDataset> {
let catalog = load_question_catalog(store_dir)?;
let mut questions_by_paragraph: HashMap<String, Vec<ConvertedQuestion>> = HashMap::new();
for entry in catalog.entries {
if paragraph_ids.contains(&entry.paragraph_id) {
questions_by_paragraph
.entry(entry.paragraph_id.clone())
.or_default()
.push(entry.question);
}
}
let mut dataset = load_sharded_partial(
store_dir,
&paragraph_ids.iter().cloned().collect::<Vec<_>>(),
)?;
for paragraph in &mut dataset.paragraphs {
if let Some(questions) = questions_by_paragraph.remove(&paragraph.id) {
paragraph.questions = questions;
} else {
paragraph.questions.clear();
}
}
Ok(dataset)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::datasets::{DatasetKind, DatasetMetadata};
fn sample_dataset() -> ConvertedDataset {
ConvertedDataset {
generated_at: Utc::now(),
metadata: DatasetMetadata::for_kind(DatasetKind::SquadV2, false),
source: "test".to_string(),
paragraphs: vec![ConvertedParagraph {
id: "p1".to_string(),
title: "Title".to_string(),
context: "Body".to_string(),
questions: vec![ConvertedQuestion {
id: "q1".to_string(),
question: "Question?".to_string(),
answers: vec!["Answer".to_string()],
is_impossible: false,
}],
}],
}
}
#[test]
#[allow(clippy::indexing_slicing)]
fn sharded_round_trip() -> Result<()> {
let dir = tempfile::tempdir()?;
let store_dir = dir.path().join("sample");
let dataset = sample_dataset();
write_sharded(&dataset, &store_dir)?;
let loaded = load_sharded_full(&store_dir)?;
assert_eq!(loaded.paragraphs.len(), 1);
assert_eq!(loaded.paragraphs[0].questions[0].id, "q1");
Ok(())
}
}