mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-25 02:08:30 +02:00
dataset: beir
This commit is contained in:
@@ -31,3 +31,108 @@ datasets:
|
|||||||
corpus_limit: 2000
|
corpus_limit: 2000
|
||||||
include_unanswerable: false
|
include_unanswerable: false
|
||||||
seed: 0x5eed2025
|
seed: 0x5eed2025
|
||||||
|
- id: beir
|
||||||
|
label: "BEIR mix"
|
||||||
|
category: "BEIR"
|
||||||
|
entity_suffix: "BEIR"
|
||||||
|
source_prefix: "beir"
|
||||||
|
raw: "data/raw/beir"
|
||||||
|
converted: "data/converted/beir-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: beir-mix-600
|
||||||
|
label: "BEIR mix (600)"
|
||||||
|
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID"
|
||||||
|
limit: 600
|
||||||
|
corpus_limit: 6000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: fever
|
||||||
|
label: "FEVER (BEIR)"
|
||||||
|
category: "FEVER"
|
||||||
|
entity_suffix: "FEVER"
|
||||||
|
source_prefix: "fever"
|
||||||
|
raw: "data/raw/fever"
|
||||||
|
converted: "data/converted/fever-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: fever-test-200
|
||||||
|
label: "FEVER test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: fiqa
|
||||||
|
label: "FiQA-2018 (BEIR)"
|
||||||
|
category: "FiQA-2018"
|
||||||
|
entity_suffix: "FiQA"
|
||||||
|
source_prefix: "fiqa"
|
||||||
|
raw: "data/raw/fiqa"
|
||||||
|
converted: "data/converted/fiqa-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: fiqa-test-200
|
||||||
|
label: "FiQA test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: hotpotqa
|
||||||
|
label: "HotpotQA (BEIR)"
|
||||||
|
category: "HotpotQA"
|
||||||
|
entity_suffix: "HotpotQA"
|
||||||
|
source_prefix: "hotpotqa"
|
||||||
|
raw: "data/raw/hotpotqa"
|
||||||
|
converted: "data/converted/hotpotqa-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: hotpotqa-test-200
|
||||||
|
label: "HotpotQA test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: nfcorpus
|
||||||
|
label: "NFCorpus (BEIR)"
|
||||||
|
category: "NFCorpus"
|
||||||
|
entity_suffix: "NFCorpus"
|
||||||
|
source_prefix: "nfcorpus"
|
||||||
|
raw: "data/raw/nfcorpus"
|
||||||
|
converted: "data/converted/nfcorpus-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: nfcorpus-test-200
|
||||||
|
label: "NFCorpus test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: quora
|
||||||
|
label: "Quora (IR)"
|
||||||
|
category: "Quora"
|
||||||
|
entity_suffix: "Quora"
|
||||||
|
source_prefix: "quora"
|
||||||
|
raw: "data/raw/quora"
|
||||||
|
converted: "data/converted/quora-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: quora-test-200
|
||||||
|
label: "Quora test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
- id: trec-covid
|
||||||
|
label: "TREC-COVID (BEIR)"
|
||||||
|
category: "TREC-COVID"
|
||||||
|
entity_suffix: "TREC-COVID"
|
||||||
|
source_prefix: "trec-covid"
|
||||||
|
raw: "data/raw/trec-covid"
|
||||||
|
converted: "data/converted/trec-covid-minne.json"
|
||||||
|
include_unanswerable: false
|
||||||
|
slices:
|
||||||
|
- id: trec-covid-test-200
|
||||||
|
label: "TREC-COVID test (200)"
|
||||||
|
description: "200-case slice from BEIR test qrels"
|
||||||
|
limit: 200
|
||||||
|
corpus_limit: 5000
|
||||||
|
seed: 0x5eed2025
|
||||||
|
|||||||
@@ -347,6 +347,10 @@ impl Config {
|
|||||||
self.retrieval.require_verified_chunks = true;
|
self.retrieval.require_verified_chunks = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.dataset == DatasetKind::Beir {
|
||||||
|
self.negative_multiplier = 9.0;
|
||||||
|
}
|
||||||
|
|
||||||
// Validations
|
// Validations
|
||||||
if self.ingest_chunk_min_tokens == 0
|
if self.ingest_chunk_min_tokens == 0
|
||||||
|| self.ingest_chunk_min_tokens >= self.ingest_chunk_max_tokens
|
|| self.ingest_chunk_min_tokens >= self.ingest_chunk_max_tokens
|
||||||
|
|||||||
341
eval/src/datasets/beir.rs
Normal file
341
eval/src/datasets/beir.rs
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
use std::{
|
||||||
|
collections::{BTreeMap, HashMap},
|
||||||
|
fs::File,
|
||||||
|
io::{BufRead, BufReader},
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Context, Result};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use super::{ConvertedParagraph, ConvertedQuestion, DatasetKind};
|
||||||
|
|
||||||
|
const ANSWER_SNIPPET_CHARS: usize = 240;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct BeirCorpusRow {
|
||||||
|
#[serde(rename = "_id")]
|
||||||
|
id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
title: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct BeirQueryRow {
|
||||||
|
#[serde(rename = "_id")]
|
||||||
|
id: String,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A corpus document after loading: its title plus the text used as paragraph context.
#[derive(Debug, Clone)]
struct BeirParagraph {
    /// Title copied into the converted paragraph.
    title: String,
    /// Context text produced by `build_context` from the document's title and body.
    context: String,
}
|
||||||
|
|
||||||
|
/// A loaded query, keyed externally by its BEIR query id.
#[derive(Debug, Clone)]
struct BeirQuery {
    /// Question text with surrounding whitespace already removed by `load_queries`.
    text: String,
}
|
||||||
|
|
||||||
|
/// One relevance judgement read from a qrels file: a document id and its integer score.
#[derive(Debug, Clone)]
struct QrelEntry {
    /// Identifier of the judged corpus document (second column of the qrels line).
    doc_id: String,
    /// Relevance score parsed from the third column of the qrels line.
    score: i32,
}
|
||||||
|
|
||||||
|
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
|
||||||
|
let corpus_path = raw_dir.join("corpus.jsonl");
|
||||||
|
let queries_path = raw_dir.join("queries.jsonl");
|
||||||
|
let qrels_path = resolve_qrels_path(raw_dir)?;
|
||||||
|
|
||||||
|
let corpus = load_corpus(&corpus_path)?;
|
||||||
|
let queries = load_queries(&queries_path)?;
|
||||||
|
let qrels = load_qrels(&qrels_path)?;
|
||||||
|
|
||||||
|
let mut paragraphs = Vec::with_capacity(corpus.len());
|
||||||
|
let mut paragraph_index = HashMap::new();
|
||||||
|
|
||||||
|
for (doc_id, entry) in corpus.iter() {
|
||||||
|
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
|
||||||
|
let paragraph = ConvertedParagraph {
|
||||||
|
id: paragraph_id.clone(),
|
||||||
|
title: entry.title.clone(),
|
||||||
|
context: entry.context.clone(),
|
||||||
|
questions: Vec::new(),
|
||||||
|
};
|
||||||
|
paragraph_index.insert(doc_id.clone(), paragraphs.len());
|
||||||
|
paragraphs.push(paragraph);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut missing_queries = 0usize;
|
||||||
|
let mut missing_docs = 0usize;
|
||||||
|
let mut skipped_answers = 0usize;
|
||||||
|
|
||||||
|
for (query_id, entries) in qrels {
|
||||||
|
let query = match queries.get(&query_id) {
|
||||||
|
Some(query) => query,
|
||||||
|
None => {
|
||||||
|
missing_queries += 1;
|
||||||
|
warn!(query_id = %query_id, "Skipping qrels entry for missing query");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let best = match select_best_doc(&entries) {
|
||||||
|
Some(entry) => entry,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let paragraph_slot = match paragraph_index.get(&best.doc_id) {
|
||||||
|
Some(slot) => *slot,
|
||||||
|
None => {
|
||||||
|
missing_docs += 1;
|
||||||
|
warn!(
|
||||||
|
query_id = %query_id,
|
||||||
|
doc_id = %best.doc_id,
|
||||||
|
"Skipping qrels entry referencing missing corpus document"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let answer = answer_snippet(¶graphs[paragraph_slot].context);
|
||||||
|
let answers = match answer {
|
||||||
|
Some(snippet) => vec![snippet],
|
||||||
|
None => {
|
||||||
|
skipped_answers += 1;
|
||||||
|
warn!(
|
||||||
|
query_id = %query_id,
|
||||||
|
doc_id = %best.doc_id,
|
||||||
|
"Skipping query because no non-empty answer snippet could be derived"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let question_id = format!("{}-{query_id}", dataset.source_prefix());
|
||||||
|
paragraphs[paragraph_slot]
|
||||||
|
.questions
|
||||||
|
.push(ConvertedQuestion {
|
||||||
|
id: question_id,
|
||||||
|
question: query.text.clone(),
|
||||||
|
answers,
|
||||||
|
is_impossible: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if missing_queries + missing_docs + skipped_answers > 0 {
|
||||||
|
warn!(
|
||||||
|
missing_queries,
|
||||||
|
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(paragraphs)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
||||||
|
let qrels_dir = raw_dir.join("qrels");
|
||||||
|
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
|
||||||
|
|
||||||
|
for name in candidates {
|
||||||
|
let candidate = qrels_dir.join(name);
|
||||||
|
if candidate.exists() {
|
||||||
|
return Ok(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow!(
|
||||||
|
"No qrels file found under {}; expected one of {:?}",
|
||||||
|
qrels_dir.display(),
|
||||||
|
candidates
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||||
|
let file =
|
||||||
|
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let mut corpus = BTreeMap::new();
|
||||||
|
|
||||||
|
for (idx, line) in reader.lines().enumerate() {
|
||||||
|
let raw = line
|
||||||
|
.with_context(|| format!("reading corpus line {} from {}", idx + 1, path.display()))?;
|
||||||
|
if raw.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let row: BeirCorpusRow = serde_json::from_str(&raw).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"parsing corpus JSON on line {} from {}",
|
||||||
|
idx + 1,
|
||||||
|
path.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let title = row.title.unwrap_or_else(|| row.id.clone());
|
||||||
|
let text = row.text.unwrap_or_default();
|
||||||
|
let context = build_context(&title, &text);
|
||||||
|
|
||||||
|
if context.is_empty() {
|
||||||
|
warn!(doc_id = %row.id, "Skipping empty corpus document");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
corpus.insert(row.id, BeirParagraph { title, context });
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(corpus)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_queries(path: &Path) -> Result<BTreeMap<String, BeirQuery>> {
|
||||||
|
let file = File::open(path)
|
||||||
|
.with_context(|| format!("opening BEIR queries file at {}", path.display()))?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let mut queries = BTreeMap::new();
|
||||||
|
|
||||||
|
for (idx, line) in reader.lines().enumerate() {
|
||||||
|
let raw = line
|
||||||
|
.with_context(|| format!("reading query line {} from {}", idx + 1, path.display()))?;
|
||||||
|
if raw.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let row: BeirQueryRow = serde_json::from_str(&raw).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"parsing query JSON on line {} from {}",
|
||||||
|
idx + 1,
|
||||||
|
path.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
queries.insert(
|
||||||
|
row.id,
|
||||||
|
BeirQuery {
|
||||||
|
text: row.text.trim().to_string(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(queries)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_qrels(path: &Path) -> Result<BTreeMap<String, Vec<QrelEntry>>> {
|
||||||
|
let file =
|
||||||
|
File::open(path).with_context(|| format!("opening BEIR qrels at {}", path.display()))?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let mut qrels: BTreeMap<String, Vec<QrelEntry>> = BTreeMap::new();
|
||||||
|
|
||||||
|
for (idx, line) in reader.lines().enumerate() {
|
||||||
|
let raw = line
|
||||||
|
.with_context(|| format!("reading qrels line {} from {}", idx + 1, path.display()))?;
|
||||||
|
let trimmed = raw.trim();
|
||||||
|
if trimmed.is_empty() || trimmed.starts_with("query-id") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let mut parts = trimmed.split_whitespace();
|
||||||
|
let query_id = parts
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow!("missing query id on line {}", idx + 1))?;
|
||||||
|
let doc_id = parts
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow!("missing document id on line {}", idx + 1))?;
|
||||||
|
let score_raw = parts
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow!("missing score on line {}", idx + 1))?;
|
||||||
|
let score: i32 = score_raw.parse().with_context(|| {
|
||||||
|
format!(
|
||||||
|
"parsing qrels score '{}' on line {} from {}",
|
||||||
|
score_raw,
|
||||||
|
idx + 1,
|
||||||
|
path.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
qrels
|
||||||
|
.entry(query_id.to_string())
|
||||||
|
.or_default()
|
||||||
|
.push(QrelEntry {
|
||||||
|
doc_id: doc_id.to_string(),
|
||||||
|
score,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(qrels)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Picks the document a query should be attached to from its qrels entries.
///
/// Only entries with a strictly positive score are considered. A score of zero or below
/// is a "judged but not relevant" mark in TREC-style qrels, so picking such a document
/// would pair a question with a paragraph that does not answer it. Among the remaining
/// entries the highest score wins; ties are broken in favour of the lexicographically
/// smallest `doc_id`, so the choice does not depend on file order.
///
/// Returns `None` when `entries` is empty or holds no positively scored document.
fn select_best_doc(entries: &[QrelEntry]) -> Option<&QrelEntry> {
    entries
        .iter()
        .filter(|entry| entry.score > 0)
        // The id comparison is reversed so that `max_by` prefers the smaller id on a tie.
        .max_by(|a, b| a.score.cmp(&b.score).then_with(|| b.doc_id.cmp(&a.doc_id)))
}
|
||||||
|
|
||||||
|
fn answer_snippet(text: &str) -> Option<String> {
|
||||||
|
let trimmed = text.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let snippet: String = trimmed.chars().take(ANSWER_SNIPPET_CHARS).collect();
|
||||||
|
let snippet = snippet.trim();
|
||||||
|
if snippet.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(snippet.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Joins a document's title and body into a single context string.
///
/// Both parts are trimmed. When both are non-empty they are separated by a blank line;
/// when only one is non-empty it is returned alone; when both are empty the result is an
/// empty string.
fn build_context(title: &str, text: &str) -> String {
    let heading = title.trim();
    let body = text.trim();

    if heading.is_empty() {
        // Covers the "both empty" case too: an empty body yields an empty string.
        return body.to_string();
    }
    if body.is_empty() {
        return heading.to_string();
    }

    format!("{heading}\n\n{body}")
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes a two-document corpus, one query and one qrels row into a temporary BEIR
    /// layout, then checks that the query lands on the judged document only.
    #[test]
    fn converts_basic_beir_layout() {
        let dir = tempdir().unwrap();
        let root = dir.path();

        let corpus = r#"
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
{"_id":"d2","title":"Doc 2","text":"Second document content."}
"#;
        let queries = r#"
{"_id":"q1","text":"What is in doc one?"}
"#;
        let qrels = "query-id\tcorpus-id\tscore\nq1\td1\t2\n";

        fs::create_dir_all(root.join("qrels")).unwrap();
        fs::write(root.join("corpus.jsonl"), corpus.trim()).unwrap();
        fs::write(root.join("queries.jsonl"), queries.trim()).unwrap();
        fs::write(root.join("qrels/test.tsv"), qrels).unwrap();

        let converted = convert_beir(root, DatasetKind::Fever).unwrap();
        assert_eq!(converted.len(), 2);

        // The judged document carries the single question.
        let with_question = converted
            .iter()
            .find(|p| p.id == "fever-d1")
            .expect("missing paragraph for d1");
        assert_eq!(with_question.questions.len(), 1);

        let first = &with_question.questions[0];
        assert_eq!(first.id, "fever-q1");
        assert!(!first.answers.is_empty());
        assert!(with_question.context.contains(&first.answers[0]));

        // The unjudged document stays question-free.
        let without_question = converted
            .iter()
            .find(|p| p.id == "fever-d2")
            .expect("missing paragraph for d2");
        assert!(without_question.questions.is_empty());
    }
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
mod beir;
|
||||||
mod nq;
|
mod nq;
|
||||||
mod squad;
|
mod squad;
|
||||||
|
|
||||||
@@ -10,10 +11,10 @@ use std::{
|
|||||||
|
|
||||||
use anyhow::{anyhow, bail, Context, Result};
|
use anyhow::{anyhow, bail, Context, Result};
|
||||||
use chrono::{DateTime, TimeZone, Utc};
|
use chrono::{DateTime, TimeZone, Utc};
|
||||||
|
use clap::ValueEnum;
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use clap::ValueEnum;
|
|
||||||
|
|
||||||
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
|
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
|
||||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||||
@@ -248,6 +249,19 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
|||||||
pub enum DatasetKind {
|
pub enum DatasetKind {
|
||||||
SquadV2,
|
SquadV2,
|
||||||
NaturalQuestions,
|
NaturalQuestions,
|
||||||
|
Beir,
|
||||||
|
#[value(name = "fever")]
|
||||||
|
Fever,
|
||||||
|
#[value(name = "fiqa")]
|
||||||
|
Fiqa,
|
||||||
|
#[value(name = "hotpotqa", alias = "hotpot-qa")]
|
||||||
|
HotpotQa,
|
||||||
|
#[value(name = "nfcorpus", alias = "nf-corpus")]
|
||||||
|
Nfcorpus,
|
||||||
|
#[value(name = "quora")]
|
||||||
|
Quora,
|
||||||
|
#[value(name = "trec-covid", alias = "treccovid", alias = "trec_covid")]
|
||||||
|
TrecCovid,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatasetKind {
|
impl DatasetKind {
|
||||||
@@ -255,6 +269,13 @@ impl DatasetKind {
|
|||||||
match self {
|
match self {
|
||||||
Self::SquadV2 => "squad-v2",
|
Self::SquadV2 => "squad-v2",
|
||||||
Self::NaturalQuestions => "natural-questions-dev",
|
Self::NaturalQuestions => "natural-questions-dev",
|
||||||
|
Self::Beir => "beir",
|
||||||
|
Self::Fever => "fever",
|
||||||
|
Self::Fiqa => "fiqa",
|
||||||
|
Self::HotpotQa => "hotpotqa",
|
||||||
|
Self::Nfcorpus => "nfcorpus",
|
||||||
|
Self::Quora => "quora",
|
||||||
|
Self::TrecCovid => "trec-covid",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,6 +283,13 @@ impl DatasetKind {
|
|||||||
match self {
|
match self {
|
||||||
Self::SquadV2 => "SQuAD v2.0",
|
Self::SquadV2 => "SQuAD v2.0",
|
||||||
Self::NaturalQuestions => "Natural Questions (dev)",
|
Self::NaturalQuestions => "Natural Questions (dev)",
|
||||||
|
Self::Beir => "BEIR mix",
|
||||||
|
Self::Fever => "FEVER (BEIR)",
|
||||||
|
Self::Fiqa => "FiQA-2018 (BEIR)",
|
||||||
|
Self::HotpotQa => "HotpotQA (BEIR)",
|
||||||
|
Self::Nfcorpus => "NFCorpus (BEIR)",
|
||||||
|
Self::Quora => "Quora (IR)",
|
||||||
|
Self::TrecCovid => "TREC-COVID (BEIR)",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,6 +297,13 @@ impl DatasetKind {
|
|||||||
match self {
|
match self {
|
||||||
Self::SquadV2 => "SQuAD v2.0",
|
Self::SquadV2 => "SQuAD v2.0",
|
||||||
Self::NaturalQuestions => "Natural Questions",
|
Self::NaturalQuestions => "Natural Questions",
|
||||||
|
Self::Beir => "BEIR",
|
||||||
|
Self::Fever => "FEVER",
|
||||||
|
Self::Fiqa => "FiQA-2018",
|
||||||
|
Self::HotpotQa => "HotpotQA",
|
||||||
|
Self::Nfcorpus => "NFCorpus",
|
||||||
|
Self::Quora => "Quora",
|
||||||
|
Self::TrecCovid => "TREC-COVID",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,6 +311,13 @@ impl DatasetKind {
|
|||||||
match self {
|
match self {
|
||||||
Self::SquadV2 => "SQuAD",
|
Self::SquadV2 => "SQuAD",
|
||||||
Self::NaturalQuestions => "Natural Questions",
|
Self::NaturalQuestions => "Natural Questions",
|
||||||
|
Self::Beir => "BEIR",
|
||||||
|
Self::Fever => "FEVER",
|
||||||
|
Self::Fiqa => "FiQA",
|
||||||
|
Self::HotpotQa => "HotpotQA",
|
||||||
|
Self::Nfcorpus => "NFCorpus",
|
||||||
|
Self::Quora => "Quora",
|
||||||
|
Self::TrecCovid => "TREC-COVID",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,6 +325,13 @@ impl DatasetKind {
|
|||||||
match self {
|
match self {
|
||||||
Self::SquadV2 => "squad",
|
Self::SquadV2 => "squad",
|
||||||
Self::NaturalQuestions => "nq",
|
Self::NaturalQuestions => "nq",
|
||||||
|
Self::Beir => "beir",
|
||||||
|
Self::Fever => "fever",
|
||||||
|
Self::Fiqa => "fiqa",
|
||||||
|
Self::HotpotQa => "hotpotqa",
|
||||||
|
Self::Nfcorpus => "nfcorpus",
|
||||||
|
Self::Quora => "quora",
|
||||||
|
Self::TrecCovid => "trec-covid",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,13 +369,29 @@ impl FromStr for DatasetKind {
|
|||||||
"nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => {
|
"nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => {
|
||||||
Ok(Self::NaturalQuestions)
|
Ok(Self::NaturalQuestions)
|
||||||
}
|
}
|
||||||
|
"beir" => Ok(Self::Beir),
|
||||||
|
"fever" => Ok(Self::Fever),
|
||||||
|
"fiqa" | "fiqa-2018" => Ok(Self::Fiqa),
|
||||||
|
"hotpotqa" | "hotpot-qa" => Ok(Self::HotpotQa),
|
||||||
|
"nfcorpus" | "nf-corpus" => Ok(Self::Nfcorpus),
|
||||||
|
"quora" => Ok(Self::Quora),
|
||||||
|
"trec-covid" | "treccovid" | "trec_covid" => Ok(Self::TrecCovid),
|
||||||
other => {
|
other => {
|
||||||
anyhow::bail!("unknown dataset '{other}'. Expected 'squad' or 'natural-questions'.")
|
anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const BEIR_DATASETS: [DatasetKind; 6] = [
|
||||||
|
DatasetKind::Fever,
|
||||||
|
DatasetKind::Fiqa,
|
||||||
|
DatasetKind::HotpotQa,
|
||||||
|
DatasetKind::Nfcorpus,
|
||||||
|
DatasetKind::Quora,
|
||||||
|
DatasetKind::TrecCovid,
|
||||||
|
];
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct DatasetMetadata {
|
pub struct DatasetMetadata {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -410,6 +475,13 @@ pub fn convert(
|
|||||||
DatasetKind::NaturalQuestions => {
|
DatasetKind::NaturalQuestions => {
|
||||||
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
|
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
|
||||||
}
|
}
|
||||||
|
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
|
||||||
|
DatasetKind::Fever
|
||||||
|
| DatasetKind::Fiqa
|
||||||
|
| DatasetKind::HotpotQa
|
||||||
|
| DatasetKind::Nfcorpus
|
||||||
|
| DatasetKind::Quora
|
||||||
|
| DatasetKind::TrecCovid => beir::convert_beir(raw_path, dataset)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let metadata_limit = match dataset {
|
let metadata_limit = match dataset {
|
||||||
@@ -417,14 +489,37 @@ pub fn convert(
|
|||||||
_ => context_token_limit,
|
_ => context_token_limit,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let source_label = match dataset {
|
||||||
|
DatasetKind::Beir => "beir-mix".to_string(),
|
||||||
|
_ => raw_path.display().to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
Ok(ConvertedDataset {
|
Ok(ConvertedDataset {
|
||||||
generated_at: Utc::now(),
|
generated_at: Utc::now(),
|
||||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
||||||
source: raw_path.display().to_string(),
|
source: source_label,
|
||||||
paragraphs,
|
paragraphs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_beir_mix(
|
||||||
|
include_unanswerable: bool,
|
||||||
|
_context_token_limit: Option<usize>,
|
||||||
|
) -> Result<Vec<ConvertedParagraph>> {
|
||||||
|
if include_unanswerable {
|
||||||
|
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut paragraphs = Vec::new();
|
||||||
|
for subset in BEIR_DATASETS {
|
||||||
|
let entry = dataset_entry_for_kind(subset)?;
|
||||||
|
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||||
|
paragraphs.extend(subset_paragraphs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(paragraphs)
|
||||||
|
}
|
||||||
|
|
||||||
fn ensure_parent(path: &Path) -> Result<()> {
|
fn ensure_parent(path: &Path) -> Result<()> {
|
||||||
if let Some(parent) = path.parent() {
|
if let Some(parent) = path.parent() {
|
||||||
fs::create_dir_all(parent)
|
fs::create_dir_all(parent)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ use futures::stream::{self, StreamExt};
|
|||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use crate::eval::{
|
use crate::eval::{
|
||||||
adapt_strategy_output, build_case_diagnostics,
|
adapt_strategy_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||||
text_contains_answer, CaseDiagnostics, CaseSummary, RetrievedSummary,
|
CaseSummary, RetrievedSummary,
|
||||||
};
|
};
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
pipeline::{self, PipelineStageTimings, RetrievalConfig},
|
pipeline::{self, PipelineStageTimings, RetrievalConfig},
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
|
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
|
||||||
db_helpers::change_embedding_length_in_hnsw_indexes,
|
|
||||||
slices::{self, ResolvedSlice, SliceParagraphKind},
|
slices::{self, ResolvedSlice, SliceParagraphKind},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -417,10 +416,6 @@ async fn ingest_paragraph_batch(
|
|||||||
.await
|
.await
|
||||||
.context("applying migrations for ingestion")?;
|
.context("applying migrations for ingestion")?;
|
||||||
|
|
||||||
change_embedding_length_in_hnsw_indexes(&db, embedding_dimension)
|
|
||||||
.await
|
|
||||||
.context("failed setting new hnsw length")?;
|
|
||||||
|
|
||||||
let mut app_config = AppConfig::default();
|
let mut app_config = AppConfig::default();
|
||||||
app_config.storage = StorageKind::Memory;
|
app_config.storage = StorageKind::Memory;
|
||||||
let backend: DynStore = Arc::new(InMemory::new());
|
let backend: DynStore = Arc::new(InMemory::new());
|
||||||
|
|||||||
@@ -93,7 +93,6 @@ async fn async_main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Clap handles help automatically, so we don't need to check for it manually
|
// Clap handles help automatically, so we don't need to check for it manually
|
||||||
|
|
||||||
|
|
||||||
if parsed.config.inspect_question.is_some() {
|
if parsed.config.inspect_question.is_some() {
|
||||||
inspection::inspect_question(&parsed.config).await?;
|
inspection::inspect_question(&parsed.config).await?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|||||||
@@ -145,6 +145,8 @@ mod tests {
|
|||||||
precision_at_1: 0.5,
|
precision_at_1: 0.5,
|
||||||
precision_at_2: 0.5,
|
precision_at_2: 0.5,
|
||||||
precision_at_3: 0.5,
|
precision_at_3: 0.5,
|
||||||
|
mrr: 0.0,
|
||||||
|
average_ndcg: 0.0,
|
||||||
duration_ms: 1234,
|
duration_ms: 1234,
|
||||||
dataset_id: "squad-v2".into(),
|
dataset_id: "squad-v2".into(),
|
||||||
dataset_label: "SQuAD v2".into(),
|
dataset_label: "SQuAD v2".into(),
|
||||||
@@ -192,18 +194,17 @@ mod tests {
|
|||||||
rerank_pool_size: Some(4),
|
rerank_pool_size: Some(4),
|
||||||
rerank_keep_top: 10,
|
rerank_keep_top: 10,
|
||||||
concurrency: 2,
|
concurrency: 2,
|
||||||
retrieval_strategy: "initial".into(),
|
|
||||||
detailed_report: false,
|
detailed_report: false,
|
||||||
|
retrieval_strategy: "initial".into(),
|
||||||
|
chunk_result_cap: 5,
|
||||||
ingest_chunk_min_tokens: 256,
|
ingest_chunk_min_tokens: 256,
|
||||||
ingest_chunk_max_tokens: 512,
|
ingest_chunk_max_tokens: 512,
|
||||||
ingest_chunk_overlap_tokens: 50,
|
|
||||||
ingest_chunks_only: false,
|
ingest_chunks_only: false,
|
||||||
|
ingest_chunk_overlap_tokens: 50,
|
||||||
chunk_vector_take: 20,
|
chunk_vector_take: 20,
|
||||||
chunk_fts_take: 20,
|
chunk_fts_take: 20,
|
||||||
chunk_avg_chars_per_token: 4,
|
chunk_avg_chars_per_token: 4,
|
||||||
max_chunks_per_entity: 4,
|
max_chunks_per_entity: 4,
|
||||||
average_ndcg: 0.0,
|
|
||||||
mrr: 0.0,
|
|
||||||
cases: Vec::new(),
|
cases: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,6 +88,10 @@ pub struct RetrievalSection {
|
|||||||
pub rerank_pool_size: Option<usize>,
|
pub rerank_pool_size: Option<usize>,
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
#[serde(default)]
|
||||||
|
pub chunk_vector_take: usize,
|
||||||
|
#[serde(default)]
|
||||||
|
pub chunk_fts_take: usize,
|
||||||
pub ingest_chunk_min_tokens: usize,
|
pub ingest_chunk_min_tokens: usize,
|
||||||
pub ingest_chunk_max_tokens: usize,
|
pub ingest_chunk_max_tokens: usize,
|
||||||
pub ingest_chunk_overlap_tokens: usize,
|
pub ingest_chunk_overlap_tokens: usize,
|
||||||
@@ -202,6 +206,8 @@ impl EvaluationReport {
|
|||||||
rerank_pool_size: summary.rerank_pool_size,
|
rerank_pool_size: summary.rerank_pool_size,
|
||||||
rerank_keep_top: summary.rerank_keep_top,
|
rerank_keep_top: summary.rerank_keep_top,
|
||||||
chunk_result_cap: summary.chunk_result_cap,
|
chunk_result_cap: summary.chunk_result_cap,
|
||||||
|
chunk_vector_take: summary.chunk_vector_take,
|
||||||
|
chunk_fts_take: summary.chunk_fts_take,
|
||||||
ingest_chunk_min_tokens: summary.ingest_chunk_min_tokens,
|
ingest_chunk_min_tokens: summary.ingest_chunk_min_tokens,
|
||||||
ingest_chunk_max_tokens: summary.ingest_chunk_max_tokens,
|
ingest_chunk_max_tokens: summary.ingest_chunk_max_tokens,
|
||||||
ingest_chunk_overlap_tokens: summary.ingest_chunk_overlap_tokens,
|
ingest_chunk_overlap_tokens: summary.ingest_chunk_overlap_tokens,
|
||||||
@@ -467,10 +473,7 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
|||||||
report.retrieval.precision_at_2,
|
report.retrieval.precision_at_2,
|
||||||
report.retrieval.precision_at_3
|
report.retrieval.precision_at_3
|
||||||
));
|
));
|
||||||
md.push_str(&format!(
|
md.push_str(&format!("| MRR | {:.3} |\\n", report.retrieval.mrr));
|
||||||
"| MRR | {:.3} |\\n",
|
|
||||||
report.retrieval.mrr
|
|
||||||
));
|
|
||||||
md.push_str(&format!(
|
md.push_str(&format!(
|
||||||
"| NDCG | {:.3} |\\n",
|
"| NDCG | {:.3} |\\n",
|
||||||
report.retrieval.average_ndcg
|
report.retrieval.average_ndcg
|
||||||
@@ -632,7 +635,9 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
|||||||
if report.detailed_report {
|
if report.detailed_report {
|
||||||
md.push_str("All LLM-only cases matched within the evaluation window.\\n");
|
md.push_str("All LLM-only cases matched within the evaluation window.\\n");
|
||||||
} else {
|
} else {
|
||||||
md.push_str("LLM-only cases omitted. Re-run with `--detailed-report` to see samples.\\n");
|
md.push_str(
|
||||||
|
"LLM-only cases omitted. Re-run with `--detailed-report` to see samples.\\n",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
md.push_str("| Question ID | Answered | Match Rank | Top Retrieved |\\n");
|
md.push_str("| Question ID | Answered | Match Rank | Top Retrieved |\\n");
|
||||||
@@ -851,6 +856,8 @@ fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
|||||||
rerank_pool_size: entry.rerank_pool_size,
|
rerank_pool_size: entry.rerank_pool_size,
|
||||||
rerank_keep_top: entry.rerank_keep_top,
|
rerank_keep_top: entry.rerank_keep_top,
|
||||||
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
||||||
|
chunk_vector_take: 0,
|
||||||
|
chunk_fts_take: 0,
|
||||||
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
||||||
ingest_chunk_max_tokens: entry.ingest_chunk_max_tokens.unwrap_or(512),
|
ingest_chunk_max_tokens: entry.ingest_chunk_max_tokens.unwrap_or(512),
|
||||||
ingest_chunk_overlap_tokens: entry.ingest_chunk_overlap_tokens.unwrap_or(50),
|
ingest_chunk_overlap_tokens: entry.ingest_chunk_overlap_tokens.unwrap_or(50),
|
||||||
@@ -1126,8 +1133,7 @@ mod tests {
|
|||||||
let tmp = tempdir().unwrap();
|
let tmp = tempdir().unwrap();
|
||||||
let summary = sample_summary(false);
|
let summary = sample_summary(false);
|
||||||
|
|
||||||
let outcome =
|
let outcome = write_reports(&summary, tmp.path(), 5).expect("writing consolidated reports");
|
||||||
write_reports(&summary, tmp.path(), 5).expect("writing consolidated reports");
|
|
||||||
let contents =
|
let contents =
|
||||||
std::fs::read_to_string(&outcome.history_path).expect("reading evaluations history");
|
std::fs::read_to_string(&outcome.history_path).expect("reading evaluations history");
|
||||||
let entries: Vec<EvaluationReport> =
|
let entries: Vec<EvaluationReport> =
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet, VecDeque},
|
||||||
fs,
|
fs,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
@@ -11,7 +11,9 @@ use serde::{Deserialize, Serialize};
|
|||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion};
|
use crate::datasets::{
|
||||||
|
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, BEIR_DATASETS,
|
||||||
|
};
|
||||||
|
|
||||||
const SLICE_VERSION: u32 = 2;
|
const SLICE_VERSION: u32 = 2;
|
||||||
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
||||||
@@ -526,7 +528,7 @@ fn ensure_case_capacity(
|
|||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
let question_refs = ordered_question_refs(dataset, params)?;
|
let question_refs = ordered_question_refs(dataset, params, target_cases)?;
|
||||||
let mut existing_questions: HashSet<String> = manifest
|
let mut existing_questions: HashSet<String> = manifest
|
||||||
.cases
|
.cases
|
||||||
.iter()
|
.iter()
|
||||||
@@ -599,7 +601,12 @@ fn ensure_case_capacity(
|
|||||||
fn ordered_question_refs(
|
fn ordered_question_refs(
|
||||||
dataset: &ConvertedDataset,
|
dataset: &ConvertedDataset,
|
||||||
params: &BuildParams,
|
params: &BuildParams,
|
||||||
|
target_cases: usize,
|
||||||
) -> Result<Vec<(usize, usize)>> {
|
) -> Result<Vec<(usize, usize)>> {
|
||||||
|
if dataset.metadata.id == DatasetKind::Beir.id() {
|
||||||
|
return ordered_question_refs_beir(dataset, params, target_cases);
|
||||||
|
}
|
||||||
|
|
||||||
let mut question_refs = Vec::new();
|
let mut question_refs = Vec::new();
|
||||||
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||||
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||||
@@ -626,6 +633,170 @@ fn ordered_question_refs(
|
|||||||
Ok(question_refs)
|
Ok(question_refs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ordered_question_refs_beir(
|
||||||
|
dataset: &ConvertedDataset,
|
||||||
|
params: &BuildParams,
|
||||||
|
target_cases: usize,
|
||||||
|
) -> Result<Vec<(usize, usize)>> {
|
||||||
|
let prefixes: Vec<&str> = BEIR_DATASETS
|
||||||
|
.iter()
|
||||||
|
.map(|kind| kind.source_prefix())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
||||||
|
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||||
|
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||||
|
let include = if params.include_impossible {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
!question.is_impossible && !question.answers.is_empty()
|
||||||
|
};
|
||||||
|
if !include {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(prefix) = question_prefix(&question.id) else {
|
||||||
|
warn!(
|
||||||
|
question_id = %question.id,
|
||||||
|
"Skipping BEIR question without expected prefix"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if !prefixes.contains(&prefix) {
|
||||||
|
warn!(
|
||||||
|
question_id = %question.id,
|
||||||
|
prefix = %prefix,
|
||||||
|
"Skipping BEIR question with unknown subset prefix"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if grouped.values().all(|entries| entries.is_empty()) {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"no eligible BEIR questions found; cannot build slice"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
for prefix in &prefixes {
|
||||||
|
if let Some(entries) = grouped.get_mut(prefix) {
|
||||||
|
let seed = mix_seed(
|
||||||
|
&format!("{}::{prefix}", dataset.metadata.id),
|
||||||
|
params.base_seed,
|
||||||
|
);
|
||||||
|
let mut rng = StdRng::seed_from_u64(seed);
|
||||||
|
entries.shuffle(&mut rng);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dataset_count = prefixes.len().max(1);
|
||||||
|
let base_quota = target_cases / dataset_count;
|
||||||
|
let mut remainder = target_cases % dataset_count;
|
||||||
|
|
||||||
|
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
||||||
|
for prefix in &prefixes {
|
||||||
|
let mut quota = base_quota;
|
||||||
|
if remainder > 0 {
|
||||||
|
quota += 1;
|
||||||
|
remainder -= 1;
|
||||||
|
}
|
||||||
|
quotas.insert(*prefix, quota);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
||||||
|
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
||||||
|
let mut shortfall = 0usize;
|
||||||
|
|
||||||
|
for prefix in &prefixes {
|
||||||
|
let available = grouped.get(prefix).map(|v| v.len()).unwrap_or(0);
|
||||||
|
let quota = *quotas.get(prefix).unwrap_or(&0);
|
||||||
|
let take = quota.min(available);
|
||||||
|
let missing = quota.saturating_sub(take);
|
||||||
|
shortfall += missing;
|
||||||
|
take_counts.insert(*prefix, take);
|
||||||
|
spare_slots.insert(*prefix, available.saturating_sub(take));
|
||||||
|
}
|
||||||
|
|
||||||
|
while shortfall > 0 {
|
||||||
|
let mut allocated = false;
|
||||||
|
for prefix in &prefixes {
|
||||||
|
if shortfall == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
||||||
|
if spare == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Some(count) = take_counts.get_mut(prefix) {
|
||||||
|
*count += 1;
|
||||||
|
}
|
||||||
|
spare_slots.insert(*prefix, spare - 1);
|
||||||
|
shortfall = shortfall.saturating_sub(1);
|
||||||
|
allocated = true;
|
||||||
|
}
|
||||||
|
if !allocated {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
||||||
|
let mut total_selected = 0usize;
|
||||||
|
for prefix in &prefixes {
|
||||||
|
let take = *take_counts.get(prefix).unwrap_or(&0);
|
||||||
|
let mut deque = VecDeque::new();
|
||||||
|
if let Some(entries) = grouped.get(prefix) {
|
||||||
|
for item in entries.iter().take(take) {
|
||||||
|
deque.push_back(*item);
|
||||||
|
total_selected += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
queues.push(deque);
|
||||||
|
}
|
||||||
|
|
||||||
|
if total_selected < target_cases {
|
||||||
|
warn!(
|
||||||
|
requested = target_cases,
|
||||||
|
available = total_selected,
|
||||||
|
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = Vec::with_capacity(total_selected);
|
||||||
|
loop {
|
||||||
|
let mut progressed = false;
|
||||||
|
for queue in queues.iter_mut() {
|
||||||
|
if let Some(item) = queue.pop_front() {
|
||||||
|
output.push(item);
|
||||||
|
progressed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !progressed {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if output.is_empty() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"no eligible BEIR questions found; cannot build slice"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn question_prefix(question_id: &str) -> Option<&'static str> {
|
||||||
|
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
||||||
|
if let Some(rest) = question_id.strip_prefix(prefix) {
|
||||||
|
if rest.starts_with('-') {
|
||||||
|
return Some(prefix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
fn ensure_negative_pool(
|
fn ensure_negative_pool(
|
||||||
dataset: &ConvertedDataset,
|
dataset: &ConvertedDataset,
|
||||||
manifest: &mut SliceManifest,
|
manifest: &mut SliceManifest,
|
||||||
@@ -981,4 +1152,65 @@ mod tests {
|
|||||||
.any(|entry| entry.id == positive_ids[0]));
|
.any(|entry| entry.id == positive_ids[0]));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks that the BEIR mix splits the target evenly and hands unfilled quota to subsets
/// that still have spare questions.
#[test]
fn beir_mix_balances_and_rebalances() -> Result<()> {
    // Questions available per subset; `nfcorpus` is left empty on purpose so its quota has
    // to be redistributed.
    let counts = [
        ("fever", 1usize),
        ("fiqa", 2usize),
        ("hotpotqa", 1usize),
        ("nfcorpus", 0usize),
        ("quora", 3usize),
        ("trec-covid", 2usize),
    ];

    let mut paragraphs = Vec::new();
    for (prefix, count) in counts {
        for idx in 0..count {
            let q_id = format!("{prefix}-q{idx}");
            paragraphs.push(ConvertedParagraph {
                id: format!("{prefix}-p{idx}"),
                title: format!("{prefix} title"),
                context: format!("{prefix} context {idx}"),
                questions: vec![ConvertedQuestion {
                    id: q_id,
                    question: format!("{prefix} question {idx}"),
                    answers: vec!["answer".to_string()],
                    is_impossible: false,
                }],
            });
        }
    }

    let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false, None);
    let dataset = ConvertedDataset {
        generated_at: Utc::now(),
        metadata,
        source: "beir-mix".to_string(),
        paragraphs,
    };

    let params = BuildParams {
        include_impossible: false,
        base_seed: 0xAA,
        rng_seed: 0xBB,
    };

    // 8 cases over 6 subsets gives quotas 2,2,1,1,1,1. `fever` and `nfcorpus` each fall one
    // short, and those two slots move to `quora` and `trec-covid`, the only subsets with
    // spare questions.
    let refs = ordered_question_refs_beir(&dataset, &params, 8)?;
    assert_eq!(refs.len(), 8);

    let mut per_prefix: HashMap<String, usize> = HashMap::new();
    for (p_idx, q_idx) in refs {
        let question = &dataset.paragraphs[p_idx].questions[q_idx];
        let prefix = question_prefix(&question.id).unwrap_or("unknown");
        *per_prefix.entry(prefix.to_string()).or_default() += 1;
    }

    assert_eq!(per_prefix.get("fever").copied().unwrap_or(0), 1);
    assert_eq!(per_prefix.get("fiqa").copied().unwrap_or(0), 2);
    assert_eq!(per_prefix.get("hotpotqa").copied().unwrap_or(0), 1);
    assert_eq!(per_prefix.get("nfcorpus").copied().unwrap_or(0), 0);
    assert_eq!(per_prefix.get("quora").copied().unwrap_or(0), 2);
    assert_eq!(per_prefix.get("trec-covid").copied().unwrap_or(0), 2);

    Ok(())
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user