mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-01 18:41:37 +02:00
evals: eval crate overhaul, simplification and performance improvements
This commit is contained in:
+38
-143
@@ -1,6 +1,10 @@
|
||||
mod beir;
|
||||
mod beir_mix;
|
||||
mod checksum;
|
||||
mod loader;
|
||||
mod nq;
|
||||
mod squad;
|
||||
mod store;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
@@ -20,38 +24,31 @@ const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"
|
||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct DatasetCatalog {
|
||||
datasets: BTreeMap<String, DatasetEntry>,
|
||||
slices: HashMap<String, SliceLocation>,
|
||||
default_dataset: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct DatasetEntry {
|
||||
pub metadata: DatasetMetadata,
|
||||
pub raw_path: PathBuf,
|
||||
pub converted_path: PathBuf,
|
||||
pub include_unanswerable: bool,
|
||||
pub slices: Vec<SliceEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SliceEntry {
|
||||
pub id: String,
|
||||
pub dataset_id: String,
|
||||
pub label: String,
|
||||
pub description: Option<String>,
|
||||
pub limit: Option<usize>,
|
||||
pub corpus_limit: Option<usize>,
|
||||
pub include_unanswerable: Option<bool>,
|
||||
pub seed: Option<u64>,
|
||||
pub negative_multiplier: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
struct SliceLocation {
|
||||
dataset_id: String,
|
||||
slice_index: usize,
|
||||
@@ -59,7 +56,6 @@ struct SliceLocation {
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ManifestFile {
|
||||
default_dataset: Option<String>,
|
||||
datasets: Vec<ManifestDataset>,
|
||||
}
|
||||
|
||||
@@ -81,6 +77,7 @@ struct ManifestDataset {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ManifestSlice {
|
||||
id: String,
|
||||
label: String,
|
||||
@@ -94,6 +91,8 @@ struct ManifestSlice {
|
||||
include_unanswerable: Option<bool>,
|
||||
#[serde(default)]
|
||||
seed: Option<u64>,
|
||||
#[serde(default)]
|
||||
negative_multiplier: Option<f32>,
|
||||
}
|
||||
|
||||
impl DatasetCatalog {
|
||||
@@ -111,18 +110,19 @@ impl DatasetCatalog {
|
||||
let raw_path = resolve_path(root, &dataset.raw);
|
||||
let converted_path = resolve_path(root, &dataset.converted);
|
||||
|
||||
if !raw_path.exists() {
|
||||
if !raw_path.exists() && dataset.id != "beir" {
|
||||
bail!(
|
||||
"dataset '{}' raw file missing at {}",
|
||||
dataset.id,
|
||||
raw_path.display()
|
||||
);
|
||||
}
|
||||
if !converted_path.exists() {
|
||||
let store_dir = store::store_dir_for(&converted_path);
|
||||
if !converted_path.exists() && !store_dir.join("meta.json").is_file() {
|
||||
warn!(
|
||||
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
|
||||
"dataset '{}' converted store missing at {}; the next conversion run will regenerate it",
|
||||
dataset.id,
|
||||
converted_path.display()
|
||||
store_dir.display()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -139,7 +139,6 @@ impl DatasetCatalog {
|
||||
.clone()
|
||||
.unwrap_or_else(|| dataset.id.clone()),
|
||||
include_unanswerable: dataset.include_unanswerable,
|
||||
context_token_limit: None,
|
||||
};
|
||||
|
||||
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
|
||||
@@ -154,12 +153,11 @@ impl DatasetCatalog {
|
||||
entry_slices.push(SliceEntry {
|
||||
id: manifest_slice.id.clone(),
|
||||
dataset_id: dataset.id.clone(),
|
||||
label: manifest_slice.label,
|
||||
description: manifest_slice.description,
|
||||
limit: manifest_slice.limit,
|
||||
corpus_limit: manifest_slice.corpus_limit,
|
||||
include_unanswerable: manifest_slice.include_unanswerable,
|
||||
seed: manifest_slice.seed,
|
||||
negative_multiplier: manifest_slice.negative_multiplier,
|
||||
});
|
||||
slices.insert(
|
||||
manifest_slice.id,
|
||||
@@ -176,22 +174,16 @@ impl DatasetCatalog {
|
||||
metadata,
|
||||
raw_path,
|
||||
converted_path,
|
||||
include_unanswerable: dataset.include_unanswerable,
|
||||
slices: entry_slices,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let default_dataset = manifest
|
||||
.default_dataset
|
||||
.or_else(|| datasets.keys().next().cloned())
|
||||
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
|
||||
if datasets.is_empty() {
|
||||
bail!("dataset manifest does not include any datasets");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
datasets,
|
||||
slices,
|
||||
default_dataset,
|
||||
})
|
||||
Ok(Self { datasets, slices })
|
||||
}
|
||||
|
||||
pub fn global() -> Result<&'static Self> {
|
||||
@@ -204,12 +196,6 @@ impl DatasetCatalog {
|
||||
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
|
||||
self.dataset(&self.default_dataset)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
|
||||
let location = self
|
||||
.slices
|
||||
@@ -236,20 +222,29 @@ fn resolve_path(root: &Path, value: &str) -> PathBuf {
|
||||
}
|
||||
}
|
||||
|
||||
pub use checksum::store_aggregate_checksum;
|
||||
pub use beir_mix::{
|
||||
beir_subset_store_summary, beir_subset_stores_ready, mix_content_checksum,
|
||||
};
|
||||
pub use loader::{prebuild_catalog_slices, prepare_dataset};
|
||||
pub use store::{
|
||||
content_checksum_for_layout, detect_layout, store_dir_for, write_sharded, ConvertedLayout,
|
||||
};
|
||||
|
||||
pub fn catalog() -> Result<&'static DatasetCatalog> {
|
||||
DatasetCatalog::global()
|
||||
}
|
||||
|
||||
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
pub(crate) fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
let catalog = catalog()?;
|
||||
catalog.dataset(kind.id())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum, Default)]
|
||||
pub enum DatasetKind {
|
||||
#[default]
|
||||
SquadV2,
|
||||
NaturalQuestions,
|
||||
#[default]
|
||||
Beir,
|
||||
#[value(name = "fever")]
|
||||
Fever,
|
||||
@@ -416,16 +411,10 @@ pub struct DatasetMetadata {
|
||||
pub source_prefix: String,
|
||||
#[serde(default)]
|
||||
pub include_unanswerable: bool,
|
||||
#[serde(default)]
|
||||
pub context_token_limit: Option<usize>,
|
||||
}
|
||||
|
||||
impl DatasetMetadata {
|
||||
pub fn for_kind(
|
||||
kind: DatasetKind,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Self {
|
||||
pub fn for_kind(kind: DatasetKind, include_unanswerable: bool) -> Self {
|
||||
if let Ok(entry) = dataset_entry_for_kind(kind) {
|
||||
return Self {
|
||||
id: entry.metadata.id.clone(),
|
||||
@@ -434,7 +423,6 @@ impl DatasetMetadata {
|
||||
entity_suffix: entry.metadata.entity_suffix.clone(),
|
||||
source_prefix: entry.metadata.source_prefix.clone(),
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -445,13 +433,12 @@ impl DatasetMetadata {
|
||||
entity_suffix: kind.entity_suffix().to_string(),
|
||||
source_prefix: kind.source_prefix().to_string(),
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_metadata() -> DatasetMetadata {
|
||||
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
|
||||
DatasetMetadata::for_kind(DatasetKind::default(), false)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -483,14 +470,15 @@ pub fn convert(
|
||||
raw_path: &Path,
|
||||
dataset: DatasetKind,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
let paragraphs = match dataset {
|
||||
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
|
||||
DatasetKind::NaturalQuestions => {
|
||||
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
|
||||
DatasetKind::NaturalQuestions => nq::convert_nq(raw_path, include_unanswerable)?,
|
||||
DatasetKind::Beir => {
|
||||
bail!(
|
||||
"BEIR mix is prepared via slice-first subset stores; use prepare_beir_mix instead of convert"
|
||||
);
|
||||
}
|
||||
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
|
||||
DatasetKind::Fever
|
||||
| DatasetKind::Fiqa
|
||||
| DatasetKind::HotpotQa
|
||||
@@ -501,11 +489,6 @@ pub fn convert(
|
||||
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
|
||||
};
|
||||
|
||||
let metadata_limit = match dataset {
|
||||
DatasetKind::NaturalQuestions => None,
|
||||
_ => context_token_limit,
|
||||
};
|
||||
|
||||
let generated_at = match dataset {
|
||||
DatasetKind::Beir
|
||||
| DatasetKind::Fever
|
||||
@@ -526,100 +509,12 @@ pub fn convert(
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at,
|
||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable),
|
||||
source: source_label,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_beir_mix(
|
||||
include_unanswerable: bool,
|
||||
_context_token_limit: Option<usize>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
if include_unanswerable {
|
||||
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
for subset in BEIR_DATASETS {
|
||||
let entry = dataset_entry_for_kind(subset)?;
|
||||
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||
paragraphs.extend(subset_paragraphs);
|
||||
}
|
||||
|
||||
Ok(paragraphs)
|
||||
}
|
||||
|
||||
fn ensure_parent(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
|
||||
ensure_parent(converted_path)?;
|
||||
let json =
|
||||
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
|
||||
fs::write(converted_path, json)
|
||||
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
|
||||
}
|
||||
|
||||
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
|
||||
let raw = fs::read_to_string(converted_path)
|
||||
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
|
||||
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
|
||||
if dataset.metadata.id.trim().is_empty() {
|
||||
dataset.metadata = default_metadata();
|
||||
}
|
||||
if dataset.source.is_empty() {
|
||||
dataset.source = converted_path.display().to_string();
|
||||
}
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
pub fn ensure_converted(
|
||||
dataset_kind: DatasetKind,
|
||||
raw_path: &Path,
|
||||
converted_path: &Path,
|
||||
force: bool,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
if force || !converted_path.exists() {
|
||||
let dataset = convert(
|
||||
raw_path,
|
||||
dataset_kind,
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
)?;
|
||||
write_converted(&dataset, converted_path)?;
|
||||
return Ok(dataset);
|
||||
}
|
||||
|
||||
match read_converted(converted_path) {
|
||||
Ok(dataset)
|
||||
if dataset.metadata.id == dataset_kind.id()
|
||||
&& dataset.metadata.include_unanswerable == include_unanswerable
|
||||
&& dataset.metadata.context_token_limit == context_token_limit =>
|
||||
{
|
||||
Ok(dataset)
|
||||
}
|
||||
_ => {
|
||||
let dataset = convert(
|
||||
raw_path,
|
||||
dataset_kind,
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
)?;
|
||||
write_converted(&dataset, converted_path)?;
|
||||
Ok(dataset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base_timestamp() -> DateTime<Utc> {
|
||||
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user