mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-27 11:51:37 +01:00
benchmarks: v2
Minor refactor
This commit is contained in:
@@ -8,7 +8,7 @@ anyhow = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
ingestion-pipeline = { path = "../ingestion-pipeline" }
|
||||
futures = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
125
eval/src/args.rs
125
eval/src/args.rs
@@ -4,6 +4,7 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use retrieval_pipeline::RetrievalStrategy;
|
||||
|
||||
use crate::datasets::DatasetKind;
|
||||
|
||||
@@ -35,6 +36,41 @@ impl std::str::FromStr for EmbeddingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunable knobs for the retrieval pipeline used during evaluation.
/// `Option` fields mean "use the pipeline's own default".
#[derive(Debug, Clone)]
pub struct RetrievalSettings {
    // Lower bound on chunk size, in characters.
    pub chunk_min_chars: usize,
    // Upper bound on chunk size, in characters.
    pub chunk_max_chars: usize,
    // Result count taken from the vector search stage (presumably; set via --chunk-vector-take).
    pub chunk_vector_take: Option<usize>,
    // Result count taken from the full-text search stage (set via --chunk-fts-take).
    pub chunk_fts_take: Option<usize>,
    // Token budget estimate for chunk assembly.
    pub chunk_token_budget: Option<usize>,
    // Average characters per token used when estimating the budget.
    pub chunk_avg_chars_per_token: Option<usize>,
    // Cap on the number of chunks attached per entity.
    pub max_chunks_per_entity: Option<usize>,
    // Whether to rerank retrieved chunks at all.
    pub rerank: bool,
    // Candidate pool size fed to the reranker.
    pub rerank_pool_size: usize,
    // Number of reranked candidates kept.
    pub rerank_keep_top: usize,
    // Only keep verified chunks (disabled by --llm-mode in arg parsing).
    pub require_verified_chunks: bool,
    // Which retrieval pipeline strategy to run (initial/revised).
    pub strategy: RetrievalStrategy,
}
|
||||
|
||||
impl Default for RetrievalSettings {
    /// Defaults mirror the CLI defaults documented in the eval help text
    /// (e.g. rerank pool 16, keep top 10).
    fn default() -> Self {
        Self {
            chunk_min_chars: 500,
            chunk_max_chars: 2_000,
            chunk_vector_take: None,
            chunk_fts_take: None,
            chunk_token_budget: None,
            chunk_avg_chars_per_token: None,
            max_chunks_per_entity: None,
            rerank: true,
            rerank_pool_size: 16,
            rerank_keep_top: 10,
            require_verified_chunks: true,
            strategy: RetrievalStrategy::Initial,
        }
    }
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub convert_only: bool,
|
||||
@@ -49,21 +85,14 @@ pub struct Config {
|
||||
pub limit: Option<usize>,
|
||||
pub summary_sample: usize,
|
||||
pub full_context: bool,
|
||||
pub chunk_min_chars: usize,
|
||||
pub chunk_max_chars: usize,
|
||||
pub chunk_vector_take: Option<usize>,
|
||||
pub chunk_fts_take: Option<usize>,
|
||||
pub chunk_token_budget: Option<usize>,
|
||||
pub chunk_avg_chars_per_token: Option<usize>,
|
||||
pub max_chunks_per_entity: Option<usize>,
|
||||
pub rerank: bool,
|
||||
pub rerank_pool_size: usize,
|
||||
pub rerank_keep_top: usize,
|
||||
pub retrieval: RetrievalSettings,
|
||||
pub concurrency: usize,
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
pub embedding_model: Option<String>,
|
||||
pub cache_dir: PathBuf,
|
||||
pub ingestion_cache_dir: PathBuf,
|
||||
pub ingestion_batch_size: usize,
|
||||
pub ingestion_max_retries: usize,
|
||||
pub refresh_embeddings_only: bool,
|
||||
pub detailed_report: bool,
|
||||
pub slice: Option<String>,
|
||||
@@ -105,21 +134,14 @@ impl Default for Config {
|
||||
limit: Some(200),
|
||||
summary_sample: 5,
|
||||
full_context: false,
|
||||
chunk_min_chars: 500,
|
||||
chunk_max_chars: 2_000,
|
||||
chunk_vector_take: None,
|
||||
chunk_fts_take: None,
|
||||
chunk_token_budget: None,
|
||||
chunk_avg_chars_per_token: None,
|
||||
max_chunks_per_entity: None,
|
||||
rerank: true,
|
||||
rerank_pool_size: 16,
|
||||
rerank_keep_top: 10,
|
||||
retrieval: RetrievalSettings::default(),
|
||||
concurrency: 4,
|
||||
embedding_backend: EmbeddingBackend::FastEmbed,
|
||||
embedding_model: None,
|
||||
cache_dir: PathBuf::from("eval/cache"),
|
||||
ingestion_cache_dir: PathBuf::from("eval/cache/ingested"),
|
||||
ingestion_batch_size: 5,
|
||||
ingestion_max_retries: 3,
|
||||
refresh_embeddings_only: false,
|
||||
detailed_report: false,
|
||||
slice: None,
|
||||
@@ -176,6 +198,7 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
"--force" | "--refresh" => config.force_convert = true,
|
||||
"--llm-mode" => {
|
||||
config.llm_mode = true;
|
||||
config.retrieval.require_verified_chunks = false;
|
||||
}
|
||||
"--dataset" => {
|
||||
let value = take_value("--dataset", &mut args)?;
|
||||
@@ -279,14 +302,14 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
let parsed = value.parse::<usize>().with_context(|| {
|
||||
format!("failed to parse --chunk-min value '{value}' as usize")
|
||||
})?;
|
||||
config.chunk_min_chars = parsed.max(1);
|
||||
config.retrieval.chunk_min_chars = parsed.max(1);
|
||||
}
|
||||
"--chunk-max" => {
|
||||
let value = take_value("--chunk-max", &mut args)?;
|
||||
let parsed = value.parse::<usize>().with_context(|| {
|
||||
format!("failed to parse --chunk-max value '{value}' as usize")
|
||||
})?;
|
||||
config.chunk_max_chars = parsed.max(1);
|
||||
config.retrieval.chunk_max_chars = parsed.max(1);
|
||||
}
|
||||
"--chunk-vector-take" => {
|
||||
let value = take_value("--chunk-vector-take", &mut args)?;
|
||||
@@ -296,7 +319,7 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
if parsed == 0 {
|
||||
return Err(anyhow!("--chunk-vector-take must be greater than zero"));
|
||||
}
|
||||
config.chunk_vector_take = Some(parsed);
|
||||
config.retrieval.chunk_vector_take = Some(parsed);
|
||||
}
|
||||
"--chunk-fts-take" => {
|
||||
let value = take_value("--chunk-fts-take", &mut args)?;
|
||||
@@ -306,7 +329,7 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
if parsed == 0 {
|
||||
return Err(anyhow!("--chunk-fts-take must be greater than zero"));
|
||||
}
|
||||
config.chunk_fts_take = Some(parsed);
|
||||
config.retrieval.chunk_fts_take = Some(parsed);
|
||||
}
|
||||
"--chunk-token-budget" => {
|
||||
let value = take_value("--chunk-token-budget", &mut args)?;
|
||||
@@ -316,7 +339,7 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
if parsed == 0 {
|
||||
return Err(anyhow!("--chunk-token-budget must be greater than zero"));
|
||||
}
|
||||
config.chunk_token_budget = Some(parsed);
|
||||
config.retrieval.chunk_token_budget = Some(parsed);
|
||||
}
|
||||
"--chunk-token-chars" => {
|
||||
let value = take_value("--chunk-token-chars", &mut args)?;
|
||||
@@ -326,7 +349,14 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
if parsed == 0 {
|
||||
return Err(anyhow!("--chunk-token-chars must be greater than zero"));
|
||||
}
|
||||
config.chunk_avg_chars_per_token = Some(parsed);
|
||||
config.retrieval.chunk_avg_chars_per_token = Some(parsed);
|
||||
}
|
||||
"--retrieval-strategy" => {
|
||||
let value = take_value("--retrieval-strategy", &mut args)?;
|
||||
let parsed = value.parse::<RetrievalStrategy>().map_err(|err| {
|
||||
anyhow!("failed to parse --retrieval-strategy value '{value}': {err}")
|
||||
})?;
|
||||
config.retrieval.strategy = parsed;
|
||||
}
|
||||
"--max-chunks-per-entity" => {
|
||||
let value = take_value("--max-chunks-per-entity", &mut args)?;
|
||||
@@ -336,7 +366,7 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
if parsed == 0 {
|
||||
return Err(anyhow!("--max-chunks-per-entity must be greater than zero"));
|
||||
}
|
||||
config.max_chunks_per_entity = Some(parsed);
|
||||
config.retrieval.max_chunks_per_entity = Some(parsed);
|
||||
}
|
||||
"--embedding" => {
|
||||
let value = take_value("--embedding", &mut args)?;
|
||||
@@ -354,6 +384,23 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
let value = take_value("--ingestion-cache-dir", &mut args)?;
|
||||
config.ingestion_cache_dir = PathBuf::from(value);
|
||||
}
|
||||
"--ingestion-batch-size" => {
|
||||
let value = take_value("--ingestion-batch-size", &mut args)?;
|
||||
let parsed = value.parse::<usize>().with_context(|| {
|
||||
format!("failed to parse --ingestion-batch-size value '{value}' as usize")
|
||||
})?;
|
||||
if parsed == 0 {
|
||||
return Err(anyhow!("--ingestion-batch-size must be greater than zero"));
|
||||
}
|
||||
config.ingestion_batch_size = parsed;
|
||||
}
|
||||
"--ingestion-max-retries" => {
|
||||
let value = take_value("--ingestion-max-retries", &mut args)?;
|
||||
let parsed = value.parse::<usize>().with_context(|| {
|
||||
format!("failed to parse --ingestion-max-retries value '{value}' as usize")
|
||||
})?;
|
||||
config.ingestion_max_retries = parsed;
|
||||
}
|
||||
"--negative-multiplier" => {
|
||||
let value = take_value("--negative-multiplier", &mut args)?;
|
||||
let parsed = value.parse::<f32>().with_context(|| {
|
||||
@@ -367,21 +414,21 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
config.negative_multiplier = parsed;
|
||||
}
|
||||
"--no-rerank" => {
|
||||
config.rerank = false;
|
||||
config.retrieval.rerank = false;
|
||||
}
|
||||
"--rerank-pool" => {
|
||||
let value = take_value("--rerank-pool", &mut args)?;
|
||||
let parsed = value.parse::<usize>().with_context(|| {
|
||||
format!("failed to parse --rerank-pool value '{value}' as usize")
|
||||
})?;
|
||||
config.rerank_pool_size = parsed.max(1);
|
||||
config.retrieval.rerank_pool_size = parsed.max(1);
|
||||
}
|
||||
"--rerank-keep" => {
|
||||
let value = take_value("--rerank-keep", &mut args)?;
|
||||
let parsed = value.parse::<usize>().with_context(|| {
|
||||
format!("failed to parse --rerank-keep value '{value}' as usize")
|
||||
})?;
|
||||
config.rerank_keep_top = parsed.max(1);
|
||||
config.retrieval.rerank_keep_top = parsed.max(1);
|
||||
}
|
||||
"--concurrency" => {
|
||||
let value = take_value("--concurrency", &mut args)?;
|
||||
@@ -451,15 +498,15 @@ pub fn parse() -> Result<ParsedArgs> {
|
||||
}
|
||||
}
|
||||
|
||||
if config.chunk_min_chars >= config.chunk_max_chars {
|
||||
if config.retrieval.chunk_min_chars >= config.retrieval.chunk_max_chars {
|
||||
return Err(anyhow!(
|
||||
"--chunk-min must be less than --chunk-max (got {} >= {})",
|
||||
config.chunk_min_chars,
|
||||
config.chunk_max_chars
|
||||
config.retrieval.chunk_min_chars,
|
||||
config.retrieval.chunk_max_chars
|
||||
));
|
||||
}
|
||||
|
||||
if config.rerank && config.rerank_pool_size == 0 {
|
||||
if config.retrieval.rerank && config.retrieval.rerank_pool_size == 0 {
|
||||
return Err(anyhow!(
|
||||
"--rerank-pool must be greater than zero when reranking is enabled"
|
||||
));
|
||||
@@ -578,14 +625,20 @@ OPTIONS:
|
||||
Override chunk token budget estimate for assembly (default: 10000).
|
||||
--chunk-token-chars <int>
|
||||
Override average characters per token used for budgeting (default: 4).
|
||||
--retrieval-strategy <initial|revised>
|
||||
Select the retrieval pipeline strategy (default: initial).
|
||||
--max-chunks-per-entity <int>
|
||||
Override maximum chunks attached per entity (default: 4).
|
||||
--embedding <name> Embedding backend: 'fastembed' (default) or 'hashed'.
|
||||
--embedding-model <code>
|
||||
FastEmbed model code (defaults to crate preset when omitted).
|
||||
--cache-dir <path> Directory for embedding caches (default: eval/cache).
|
||||
--ingestion-cache-dir <path>
|
||||
Directory for ingestion corpora caches (default: eval/cache/ingested).
|
||||
--ingestion-cache-dir <path>
|
||||
Directory for ingestion corpora caches (default: eval/cache/ingested).
|
||||
--ingestion-batch-size <int>
|
||||
Number of paragraphs to ingest concurrently (default: 5).
|
||||
--ingestion-max-retries <int>
|
||||
Maximum retries for ingestion failures per paragraph (default: 3).
|
||||
--negative-multiplier <float>
|
||||
Target negative-to-positive paragraph ratio for slice growth (default: 4.0).
|
||||
--refresh-embeddings Recompute embeddings for cached corpora without re-running ingestion.
|
||||
|
||||
1003
eval/src/datasets.rs
1003
eval/src/datasets.rs
File diff suppressed because it is too large
Load Diff
493
eval/src/datasets/mod.rs
Normal file
493
eval/src/datasets/mod.rs
Normal file
@@ -0,0 +1,493 @@
|
||||
mod nq;
|
||||
mod squad;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
fs::{self},
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
|
||||
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
|
||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||
|
||||
/// In-memory view of the dataset manifest: datasets keyed by id plus a flat
/// slice-id index for cross-dataset slice lookup.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetCatalog {
    // BTreeMap keeps dataset iteration order deterministic (sorted by id).
    datasets: BTreeMap<String, DatasetEntry>,
    // Slice ids are globally unique; each maps to its owning dataset + index.
    slices: HashMap<String, SliceLocation>,
    // Dataset id used when the caller does not pick one explicitly.
    default_dataset: String,
}
|
||||
|
||||
/// One dataset from the manifest, with resolved on-disk paths.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetEntry {
    // Descriptive metadata copied/derived from the manifest entry.
    pub metadata: DatasetMetadata,
    // Absolute (or crate-root-resolved) path to the raw source file.
    pub raw_path: PathBuf,
    // Path where the converted JSON form is (re)generated.
    pub converted_path: PathBuf,
    // Whether unanswerable questions are included for this dataset.
    pub include_unanswerable: bool,
    // Slices declared under this dataset, in manifest order.
    pub slices: Vec<SliceEntry>,
}
|
||||
|
||||
/// A named subset of a dataset, as declared in the manifest.
/// `Option` fields fall back to the dataset-level behavior when `None`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SliceEntry {
    // Globally unique slice id.
    pub id: String,
    // Id of the dataset this slice belongs to.
    pub dataset_id: String,
    pub label: String,
    pub description: Option<String>,
    // Cap on the number of questions taken.
    pub limit: Option<usize>,
    // Cap on the corpus size.
    pub corpus_limit: Option<usize>,
    // Per-slice override for including unanswerable questions.
    pub include_unanswerable: Option<bool>,
    // Seed for deterministic sampling, if any.
    pub seed: Option<u64>,
}
|
||||
|
||||
/// Reverse-index record: where a slice id lives inside the catalog.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct SliceLocation {
    // Owning dataset id (key into `DatasetCatalog::datasets`).
    dataset_id: String,
    // Position within that dataset's `slices` vector.
    slice_index: usize,
}
|
||||
|
||||
/// Top-level shape of `manifest.yaml`.
#[derive(Debug, Deserialize)]
struct ManifestFile {
    // Optional; the catalog falls back to the first dataset by id.
    default_dataset: Option<String>,
    datasets: Vec<ManifestDataset>,
}
|
||||
|
||||
/// One dataset record in `manifest.yaml`.
#[derive(Debug, Deserialize)]
struct ManifestDataset {
    id: String,
    label: String,
    category: String,
    // Defaults to `label` when omitted (see DatasetCatalog::load).
    #[serde(default)]
    entity_suffix: Option<String>,
    // Defaults to `id` when omitted.
    #[serde(default)]
    source_prefix: Option<String>,
    // Raw and converted file paths, relative to the crate root or absolute.
    raw: String,
    converted: String,
    #[serde(default)]
    include_unanswerable: bool,
    #[serde(default)]
    slices: Vec<ManifestSlice>,
}
|
||||
|
||||
/// One slice record nested under a dataset in `manifest.yaml`.
/// All optional fields default to `None`/absent via serde.
#[derive(Debug, Deserialize)]
struct ManifestSlice {
    id: String,
    label: String,
    #[serde(default)]
    description: Option<String>,
    #[serde(default)]
    limit: Option<usize>,
    #[serde(default)]
    corpus_limit: Option<usize>,
    #[serde(default)]
    include_unanswerable: Option<bool>,
    #[serde(default)]
    seed: Option<u64>,
}
|
||||
|
||||
impl DatasetCatalog {
    /// Reads and validates `manifest.yaml`, building the catalog.
    ///
    /// Fails when the manifest cannot be read/parsed, when a dataset's raw
    /// file is missing, or when a slice id is declared more than once. A
    /// missing *converted* file is only a warning, because conversion
    /// regenerates it on demand.
    pub fn load() -> Result<Self> {
        let manifest_raw = fs::read_to_string(MANIFEST_PATH)
            .with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?;
        let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw)
            .with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?;

        // Relative manifest paths are resolved against the crate root.
        let root = Path::new(env!("CARGO_MANIFEST_DIR"));
        let mut datasets = BTreeMap::new();
        let mut slices = HashMap::new();

        for dataset in manifest.datasets {
            let raw_path = resolve_path(root, &dataset.raw);
            let converted_path = resolve_path(root, &dataset.converted);

            if !raw_path.exists() {
                bail!(
                    "dataset '{}' raw file missing at {}",
                    dataset.id,
                    raw_path.display()
                );
            }
            if !converted_path.exists() {
                // Non-fatal: the converted file is a cache artifact.
                warn!(
                    "dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
                    dataset.id,
                    converted_path.display()
                );
            }

            // entity_suffix / source_prefix fall back to label / id.
            let metadata = DatasetMetadata {
                id: dataset.id.clone(),
                label: dataset.label.clone(),
                category: dataset.category.clone(),
                entity_suffix: dataset
                    .entity_suffix
                    .clone()
                    .unwrap_or_else(|| dataset.label.clone()),
                source_prefix: dataset
                    .source_prefix
                    .clone()
                    .unwrap_or_else(|| dataset.id.clone()),
                include_unanswerable: dataset.include_unanswerable,
                context_token_limit: None,
            };

            let mut entry_slices = Vec::with_capacity(dataset.slices.len());

            for (index, manifest_slice) in dataset.slices.into_iter().enumerate() {
                // Slice ids must be unique across ALL datasets, since the
                // `slices` index is flat.
                if slices.contains_key(&manifest_slice.id) {
                    bail!(
                        "slice '{}' defined multiple times in manifest",
                        manifest_slice.id
                    );
                }
                entry_slices.push(SliceEntry {
                    id: manifest_slice.id.clone(),
                    dataset_id: dataset.id.clone(),
                    label: manifest_slice.label,
                    description: manifest_slice.description,
                    limit: manifest_slice.limit,
                    corpus_limit: manifest_slice.corpus_limit,
                    include_unanswerable: manifest_slice.include_unanswerable,
                    seed: manifest_slice.seed,
                });
                slices.insert(
                    manifest_slice.id,
                    SliceLocation {
                        dataset_id: dataset.id.clone(),
                        slice_index: index,
                    },
                );
            }

            datasets.insert(
                metadata.id.clone(),
                DatasetEntry {
                    metadata,
                    raw_path,
                    converted_path,
                    include_unanswerable: dataset.include_unanswerable,
                    slices: entry_slices,
                },
            );
        }

        // Fall back to the first dataset (BTreeMap order, i.e. smallest id)
        // when the manifest names no default.
        let default_dataset = manifest
            .default_dataset
            .or_else(|| datasets.keys().next().cloned())
            .ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;

        Ok(Self {
            datasets,
            slices,
            default_dataset,
        })
    }

    /// Process-wide catalog, loaded lazily on first access.
    pub fn global() -> Result<&'static Self> {
        DATASET_CATALOG.get_or_try_init(Self::load)
    }

    /// Looks up a dataset entry by manifest id.
    pub fn dataset(&self, id: &str) -> Result<&DatasetEntry> {
        self.datasets
            .get(id)
            .ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
    }

    /// Entry for the manifest's default dataset.
    #[allow(dead_code)]
    pub fn default_dataset(&self) -> Result<&DatasetEntry> {
        self.dataset(&self.default_dataset)
    }

    /// Resolves a slice id to its dataset entry and slice definition.
    ///
    /// The latter two errors indicate internal inconsistency between the
    /// `slices` index and `datasets`, which `load` should prevent.
    #[allow(dead_code)]
    pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
        let location = self
            .slices
            .get(slice_id)
            .ok_or_else(|| anyhow!("unknown slice '{slice_id}' in manifest"))?;
        let dataset = self
            .datasets
            .get(&location.dataset_id)
            .ok_or_else(|| anyhow!("slice '{slice_id}' references missing dataset"))?;
        let slice = dataset
            .slices
            .get(location.slice_index)
            .ok_or_else(|| anyhow!("slice index out of bounds for '{slice_id}'"))?;
        Ok((dataset, slice))
    }
}
|
||||
|
||||
/// Resolves a manifest path against `root`.
///
/// `Path::join` is documented to return `value` unchanged when it is
/// absolute, so the original explicit `is_absolute` branch was a redundant
/// reimplementation of that behavior.
fn resolve_path(root: &Path, value: &str) -> PathBuf {
    root.join(value)
}
|
||||
|
||||
/// Free-function convenience wrapper around the lazily-initialised global
/// catalog.
pub fn catalog() -> Result<&'static DatasetCatalog> {
    DatasetCatalog::global()
}
|
||||
|
||||
/// Looks up the manifest entry corresponding to a built-in dataset kind.
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
    let catalog = catalog()?;
    catalog.dataset(kind.id())
}
|
||||
|
||||
/// Built-in dataset families the eval harness knows how to convert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatasetKind {
    // SQuAD v2.0 (manifest id "squad-v2").
    SquadV2,
    // Natural Questions dev set (manifest id "natural-questions-dev").
    NaturalQuestions,
}
|
||||
|
||||
impl DatasetKind {
    /// Manifest id for this kind.
    pub fn id(self) -> &'static str {
        match self {
            Self::SquadV2 => "squad-v2",
            Self::NaturalQuestions => "natural-questions-dev",
        }
    }

    /// Human-readable label (fallback when the manifest entry is missing).
    pub fn label(self) -> &'static str {
        match self {
            Self::SquadV2 => "SQuAD v2.0",
            Self::NaturalQuestions => "Natural Questions (dev)",
        }
    }

    /// Category label (fallback when the manifest entry is missing).
    pub fn category(self) -> &'static str {
        match self {
            Self::SquadV2 => "SQuAD v2.0",
            Self::NaturalQuestions => "Natural Questions",
        }
    }

    /// Suffix appended to entity names (fallback value).
    pub fn entity_suffix(self) -> &'static str {
        match self {
            Self::SquadV2 => "SQuAD",
            Self::NaturalQuestions => "Natural Questions",
        }
    }

    /// Prefix used for source identifiers (fallback value).
    pub fn source_prefix(self) -> &'static str {
        match self {
            Self::SquadV2 => "squad",
            Self::NaturalQuestions => "nq",
        }
    }

    /// Raw dataset path from the manifest.
    ///
    /// Panics when the manifest has no entry for this kind — callers treat
    /// a missing manifest entry as an unrecoverable setup error.
    pub fn default_raw_path(self) -> PathBuf {
        dataset_entry_for_kind(self)
            .map(|entry| entry.raw_path.clone())
            .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
    }

    /// Converted dataset path from the manifest.
    ///
    /// Panics when the manifest has no entry for this kind (see above).
    pub fn default_converted_path(self) -> PathBuf {
        dataset_entry_for_kind(self)
            .map(|entry| entry.converted_path.clone())
            .unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
    }
}
|
||||
|
||||
impl Default for DatasetKind {
    /// SQuAD v2 is the default dataset when none is requested.
    fn default() -> Self {
        Self::SquadV2
    }
}
|
||||
|
||||
impl FromStr for DatasetKind {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2),
|
||||
"nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => {
|
||||
Ok(Self::NaturalQuestions)
|
||||
}
|
||||
other => {
|
||||
anyhow::bail!("unknown dataset '{other}'. Expected 'squad' or 'natural-questions'.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata stamped into every converted dataset file; used both for
/// display and for cache-validity checks in `ensure_converted`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
    pub id: String,
    pub label: String,
    pub category: String,
    pub entity_suffix: String,
    pub source_prefix: String,
    // serde defaults keep older converted files (without these fields)
    // deserializable.
    #[serde(default)]
    pub include_unanswerable: bool,
    #[serde(default)]
    pub context_token_limit: Option<usize>,
}
|
||||
|
||||
impl DatasetMetadata {
    /// Builds metadata for a dataset kind.
    ///
    /// Prefers the values declared in the manifest; when the manifest has
    /// no entry for `kind` (or fails to load), falls back to the hard-coded
    /// values on `DatasetKind`.
    pub fn for_kind(
        kind: DatasetKind,
        include_unanswerable: bool,
        context_token_limit: Option<usize>,
    ) -> Self {
        if let Ok(entry) = dataset_entry_for_kind(kind) {
            return Self {
                id: entry.metadata.id.clone(),
                label: entry.metadata.label.clone(),
                category: entry.metadata.category.clone(),
                entity_suffix: entry.metadata.entity_suffix.clone(),
                source_prefix: entry.metadata.source_prefix.clone(),
                include_unanswerable,
                context_token_limit,
            };
        }

        // Manifest lookup failed — use the built-in fallback values.
        Self {
            id: kind.id().to_string(),
            label: kind.label().to_string(),
            category: kind.category().to_string(),
            entity_suffix: kind.entity_suffix().to_string(),
            source_prefix: kind.source_prefix().to_string(),
            include_unanswerable,
            context_token_limit,
        }
    }
}
|
||||
|
||||
/// serde default for `ConvertedDataset::metadata`: the default dataset kind
/// with answerable-only questions and no context token limit.
fn default_metadata() -> DatasetMetadata {
    DatasetMetadata::for_kind(DatasetKind::default(), false, None)
}
|
||||
|
||||
/// On-disk converted dataset format shared by all dataset kinds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedDataset {
    // When the conversion was produced.
    pub generated_at: DateTime<Utc>,
    // Defaulted for legacy files written before metadata existed.
    #[serde(default = "default_metadata")]
    pub metadata: DatasetMetadata,
    // Display form of the raw input path.
    pub source: String,
    pub paragraphs: Vec<ConvertedParagraph>,
}
|
||||
|
||||
/// One context paragraph with its associated questions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedParagraph {
    pub id: String,
    pub title: String,
    // The passage text questions are answered against.
    pub context: String,
    pub questions: Vec<ConvertedQuestion>,
}
|
||||
|
||||
/// One question with its accepted answer strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedQuestion {
    pub id: String,
    pub question: String,
    // Empty for unanswerable questions.
    pub answers: Vec<String>,
    // SQuAD v2 terminology: true when the question has no answer in context.
    pub is_impossible: bool,
}
|
||||
|
||||
/// Converts a raw dataset file into the shared `ConvertedDataset` form,
/// dispatching to the kind-specific converter.
pub fn convert(
    raw_path: &Path,
    dataset: DatasetKind,
    include_unanswerable: bool,
    context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
    let paragraphs = match dataset {
        DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
        DatasetKind::NaturalQuestions => {
            nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
        }
    };

    // convert_nq ignores its context-token-limit parameter, so NQ metadata
    // records None.
    // NOTE(review): if a caller ever passes Some(limit) for NQ, the metadata
    // comparison in ensure_converted will never match and the dataset will be
    // reconverted on every run — confirm callers pass None for NQ.
    let metadata_limit = match dataset {
        DatasetKind::NaturalQuestions => None,
        _ => context_token_limit,
    };

    Ok(ConvertedDataset {
        generated_at: Utc::now(),
        metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
        source: raw_path.display().to_string(),
        paragraphs,
    })
}
|
||||
|
||||
fn ensure_parent(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Serialises a converted dataset as pretty-printed JSON, creating parent
/// directories as needed.
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
    ensure_parent(converted_path)?;
    let json =
        serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
    fs::write(converted_path, json)
        .with_context(|| format!("writing converted dataset to {}", converted_path.display()))
}
|
||||
|
||||
/// Reads a converted dataset from disk, repairing fields that legacy files
/// may have left empty.
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
    let raw = fs::read_to_string(converted_path)
        .with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
    let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
        .with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
    // Legacy files may carry blank metadata — substitute the defaults.
    if dataset.metadata.id.trim().is_empty() {
        dataset.metadata = default_metadata();
    }
    // Likewise fill a missing source with the file's own path.
    if dataset.source.is_empty() {
        dataset.source = converted_path.display().to_string();
    }
    Ok(dataset)
}
|
||||
|
||||
/// Returns a converted dataset, reconverting from the raw file when needed.
///
/// Reconversion happens when `force` is set, when the converted file is
/// missing, when it cannot be read/parsed, or when its metadata does not
/// match the requested kind / unanswerable setting / context token limit
/// (i.e. the cached conversion was produced with different options).
pub fn ensure_converted(
    dataset_kind: DatasetKind,
    raw_path: &Path,
    converted_path: &Path,
    force: bool,
    include_unanswerable: bool,
    context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
    if force || !converted_path.exists() {
        let dataset = convert(
            raw_path,
            dataset_kind,
            include_unanswerable,
            context_token_limit,
        )?;
        write_converted(&dataset, converted_path)?;
        return Ok(dataset);
    }

    match read_converted(converted_path) {
        // Cache hit: the stored metadata matches the requested options.
        Ok(dataset)
            if dataset.metadata.id == dataset_kind.id()
                && dataset.metadata.include_unanswerable == include_unanswerable
                && dataset.metadata.context_token_limit == context_token_limit =>
        {
            Ok(dataset)
        }
        // Unreadable or stale cache: silently reconvert and overwrite.
        _ => {
            let dataset = convert(
                raw_path,
                dataset_kind,
                include_unanswerable,
                context_token_limit,
            )?;
            write_converted(&dataset, converted_path)?;
            Ok(dataset)
        }
    }
}
|
||||
|
||||
/// Fixed deterministic base timestamp (2023-01-01T00:00:00Z).
/// The unwrap is safe: this constant date is always valid.
pub fn base_timestamp() -> DateTime<Utc> {
    Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
}
|
||||
234
eval/src/datasets/nq.rs
Normal file
234
eval/src/datasets/nq.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use std::{
|
||||
collections::BTreeSet,
|
||||
fs::File,
|
||||
io::{BufRead, BufReader},
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
/// Converts a Natural Questions JSONL file into `ConvertedParagraph`s.
///
/// Each input line is one example; each kept example becomes one paragraph
/// (whole-document context) holding exactly one question. Examples are
/// skipped when: the context is empty; a short answer cannot be found in
/// the assembled context (tokenisation mismatch); or — unless
/// `include_unanswerable` is set — the example has no short answer.
/// `_context_token_limit` is currently unused by this converter.
pub fn convert_nq(
    raw_path: &Path,
    include_unanswerable: bool,
    _context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
    // Deserialisation shapes for the NQ JSONL schema. dead_code allows are
    // for fields we parse but never read.
    #[allow(dead_code)]
    #[derive(Debug, Deserialize)]
    struct NqExample {
        question_text: String,
        document_title: String,
        example_id: i64,
        document_tokens: Vec<NqToken>,
        long_answer_candidates: Vec<NqLongAnswerCandidate>,
        annotations: Vec<NqAnnotation>,
    }

    #[derive(Debug, Deserialize)]
    struct NqToken {
        token: String,
        // Markup tokens are excluded from assembled text.
        #[serde(default)]
        html_token: bool,
    }

    #[allow(dead_code)]
    #[derive(Debug, Deserialize)]
    struct NqLongAnswerCandidate {
        start_token: i32,
        end_token: i32,
    }

    #[allow(dead_code)]
    #[derive(Debug, Deserialize)]
    struct NqAnnotation {
        short_answers: Vec<NqShortAnswer>,
        #[serde(default)]
        yes_no_answer: String,
        long_answer: NqLongAnswer,
    }

    #[derive(Debug, Deserialize)]
    struct NqShortAnswer {
        start_token: i32,
        end_token: i32,
    }

    #[allow(dead_code)]
    #[derive(Debug, Deserialize)]
    struct NqLongAnswer {
        candidate_index: i32,
    }

    // Reassembles detokenised text for tokens[start..end): skips HTML and
    // empty tokens, and attaches punctuation/contractions to the previous
    // word instead of inserting a space.
    fn join_tokens(tokens: &[NqToken], start: usize, end: usize) -> String {
        let mut buffer = String::new();
        let end = end.min(tokens.len());
        for token in tokens.iter().skip(start).take(end.saturating_sub(start)) {
            if token.html_token {
                continue;
            }
            let text = token.token.trim();
            if text.is_empty() {
                continue;
            }
            // True when the token should glue onto the previous word
            // (closing punctuation or an English contraction suffix).
            let attach = matches!(
                text,
                "," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "…" | "..."
            ) || text.starts_with('\'')
                || text == "n't"
                || text == "'s"
                || text == "'re"
                || text == "'ve"
                || text == "'d"
                || text == "'ll";

            if buffer.is_empty() || attach {
                buffer.push_str(text);
            } else {
                buffer.push(' ');
                buffer.push_str(text);
            }
        }

        buffer.trim().to_string()
    }

    let file = File::open(raw_path).with_context(|| {
        format!(
            "opening Natural Questions dataset at {}",
            raw_path.display()
        )
    })?;
    let reader = BufReader::new(file);

    let mut paragraphs = Vec::new();
    for (line_idx, line) in reader.lines().enumerate() {
        let line = line.with_context(|| {
            format!(
                "reading Natural Questions line {} from {}",
                line_idx + 1,
                raw_path.display()
            )
        })?;
        if line.trim().is_empty() {
            continue;
        }
        let example: NqExample = serde_json::from_str(&line).with_context(|| {
            format!(
                "parsing Natural Questions JSON (line {}) at {}",
                line_idx + 1,
                raw_path.display()
            )
        })?;

        // Collect answers from all annotations: validated short-answer token
        // spans, plus explicit yes/no answers.
        let mut answer_texts: Vec<String> = Vec::new();
        let mut short_answer_texts: Vec<String> = Vec::new();
        let mut has_short_or_yesno = false;
        let mut has_short_answer = false;
        for annotation in &example.annotations {
            for short in &annotation.short_answers {
                // Reject degenerate or out-of-range spans.
                if short.start_token < 0 || short.end_token <= short.start_token {
                    continue;
                }
                let start = short.start_token as usize;
                let end = short.end_token as usize;
                if start >= example.document_tokens.len() || end > example.document_tokens.len() {
                    continue;
                }
                let text = join_tokens(&example.document_tokens, start, end);
                if !text.is_empty() {
                    answer_texts.push(text.clone());
                    short_answer_texts.push(text);
                    has_short_or_yesno = true;
                    has_short_answer = true;
                }
            }

            match annotation
                .yes_no_answer
                .trim()
                .to_ascii_lowercase()
                .as_str()
            {
                "yes" => {
                    answer_texts.push("yes".to_string());
                    has_short_or_yesno = true;
                }
                "no" => {
                    answer_texts.push("no".to_string());
                    has_short_or_yesno = true;
                }
                _ => {}
            }
        }

        let mut answers = dedupe_strings(answer_texts);
        let is_unanswerable = !has_short_or_yesno || answers.is_empty();
        if is_unanswerable {
            if !include_unanswerable {
                continue;
            }
            // Unanswerable questions are stored with an empty answer list.
            answers.clear();
        }

        let paragraph_id = format!("nq-{}", example.example_id);
        let question_id = format!("nq-{}", example.example_id);

        // Context is the full detokenised document.
        let context = join_tokens(&example.document_tokens, 0, example.document_tokens.len());
        if context.is_empty() {
            continue;
        }

        // Sanity check: every short answer must appear (case-insensitively)
        // in the assembled context, otherwise exact-match scoring would be
        // impossible for this example.
        if has_short_answer && !short_answer_texts.is_empty() {
            let normalized_context = context.to_ascii_lowercase();
            let missing_answer = short_answer_texts.iter().any(|answer| {
                let needle = answer.trim().to_ascii_lowercase();
                !needle.is_empty() && !normalized_context.contains(&needle)
            });
            if missing_answer {
                warn!(
                    question_id = %question_id,
                    "Skipping Natural Questions example because answers were not found in the assembled context"
                );
                continue;
            }
        }

        if !include_unanswerable && (!has_short_answer || short_answer_texts.is_empty()) {
            // yes/no-only questions are excluded by default unless --llm-mode is used
            continue;
        }

        let question = ConvertedQuestion {
            id: question_id,
            question: example.question_text.trim().to_string(),
            answers,
            is_impossible: is_unanswerable,
        };

        paragraphs.push(ConvertedParagraph {
            id: paragraph_id,
            title: example.document_title.trim().to_string(),
            context,
            questions: vec![question],
        });
    }

    Ok(paragraphs)
}
|
||||
|
||||
/// Normalizes a sequence of answer strings: each value is trimmed,
/// empty results are discarded, and the survivors are returned
/// deduplicated in sorted order (courtesy of `BTreeSet`).
fn dedupe_strings<I>(values: I) -> Vec<String>
where
    I: IntoIterator<Item = String>,
{
    let unique: BTreeSet<String> = values
        .into_iter()
        .filter_map(|value| {
            let trimmed = value.trim();
            // Keep only non-empty trimmed values; `then` returns None for empties.
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        })
        .collect();
    unique.into_iter().collect()
}
|
||||
107
eval/src/datasets/squad.rs
Normal file
107
eval/src/datasets/squad.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use std::{collections::BTreeSet, fs, path::Path};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
pub fn convert_squad(raw_path: &Path) -> Result<Vec<ConvertedParagraph>> {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SquadDataset {
|
||||
data: Vec<SquadArticle>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SquadArticle {
|
||||
title: String,
|
||||
paragraphs: Vec<SquadParagraph>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SquadParagraph {
|
||||
context: String,
|
||||
qas: Vec<SquadQuestion>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SquadQuestion {
|
||||
id: String,
|
||||
question: String,
|
||||
answers: Vec<SquadAnswer>,
|
||||
#[serde(default)]
|
||||
is_impossible: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SquadAnswer {
|
||||
text: String,
|
||||
}
|
||||
|
||||
let raw = fs::read_to_string(raw_path)
|
||||
.with_context(|| format!("reading raw SQuAD dataset at {}", raw_path.display()))?;
|
||||
let parsed: SquadDataset = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing SQuAD dataset at {}", raw_path.display()))?;
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
for (article_idx, article) in parsed.data.into_iter().enumerate() {
|
||||
for (paragraph_idx, paragraph) in article.paragraphs.into_iter().enumerate() {
|
||||
let mut questions = Vec::new();
|
||||
for qa in paragraph.qas {
|
||||
let answers = dedupe_strings(qa.answers.into_iter().map(|answer| answer.text));
|
||||
questions.push(ConvertedQuestion {
|
||||
id: qa.id,
|
||||
question: qa.question.trim().to_string(),
|
||||
answers,
|
||||
is_impossible: qa.is_impossible,
|
||||
});
|
||||
}
|
||||
|
||||
let paragraph_id =
|
||||
format!("{}-{}", slugify(&article.title, article_idx), paragraph_idx);
|
||||
|
||||
paragraphs.push(ConvertedParagraph {
|
||||
id: paragraph_id,
|
||||
title: article.title.trim().to_string(),
|
||||
context: paragraph.context.trim().to_string(),
|
||||
questions,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(paragraphs)
|
||||
}
|
||||
|
||||
/// Trims every input string, drops the ones that become empty, and returns
/// the remaining unique values in sorted order (a `BTreeSet` provides both
/// deduplication and the ordering).
fn dedupe_strings<I>(values: I) -> Vec<String>
where
    I: IntoIterator<Item = String>,
{
    let unique: BTreeSet<String> = values
        .into_iter()
        .filter_map(|value| {
            let trimmed = value.trim();
            // Empty-after-trim values contribute nothing.
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        })
        .collect();
    unique.into_iter().collect()
}
|
||||
|
||||
/// Produces a slug from `input`: ASCII alphanumerics are lowercased and
/// kept, any other run of characters collapses into a single `-`, and
/// leading/trailing dashes are stripped. When nothing survives (e.g. an
/// empty or all-punctuation title), falls back to `article-{fallback_idx}`
/// so the result is never empty.
fn slugify(input: &str, fallback_idx: usize) -> String {
    let mut out = String::with_capacity(input.len());
    for ch in input.chars().map(|c| c.to_ascii_lowercase()) {
        if ch.is_ascii_alphanumeric() {
            out.push(ch);
        } else if !out.is_empty() && !out.ends_with('-') {
            // Collapse separator runs and never start with a dash.
            out.push('-');
        }
    }
    // At most one trailing dash can exist because runs are collapsed.
    while out.ends_with('-') {
        out.pop();
    }

    if out.is_empty() {
        return format!("article-{fallback_idx}");
    }
    out
}
|
||||
@@ -1,12 +1,10 @@
|
||||
mod pipeline;
|
||||
mod types;
|
||||
|
||||
pub use pipeline::run_evaluation;
|
||||
pub use types::*;
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
path::Path,
|
||||
time::Duration,
|
||||
};
|
||||
use std::{collections::HashMap, path::Path, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
@@ -17,10 +15,8 @@ use common::{
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::pipeline as retrieval_pipeline;
|
||||
use composite_retrieval::pipeline::PipelineStageTimings;
|
||||
use composite_retrieval::pipeline::RetrievalTuning;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use retrieval_pipeline::RetrievalTuning;
|
||||
use serde::Deserialize;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{info, warn};
|
||||
|
||||
@@ -33,178 +29,6 @@ use crate::{
|
||||
snapshot::{self, DbSnapshotState},
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EvaluationSummary {
|
||||
pub generated_at: DateTime<Utc>,
|
||||
pub k: usize,
|
||||
pub limit: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub run_label: Option<String>,
|
||||
pub total_cases: usize,
|
||||
pub correct: usize,
|
||||
pub precision: f64,
|
||||
pub correct_at_1: usize,
|
||||
pub correct_at_2: usize,
|
||||
pub correct_at_3: usize,
|
||||
pub precision_at_1: f64,
|
||||
pub precision_at_2: f64,
|
||||
pub precision_at_3: f64,
|
||||
pub duration_ms: u128,
|
||||
pub dataset_id: String,
|
||||
pub dataset_label: String,
|
||||
pub dataset_includes_unanswerable: bool,
|
||||
pub dataset_source: String,
|
||||
pub slice_id: String,
|
||||
pub slice_seed: u64,
|
||||
pub slice_total_cases: usize,
|
||||
pub slice_window_offset: usize,
|
||||
pub slice_window_length: usize,
|
||||
pub slice_cases: usize,
|
||||
pub slice_positive_paragraphs: usize,
|
||||
pub slice_negative_paragraphs: usize,
|
||||
pub slice_total_paragraphs: usize,
|
||||
pub slice_negative_multiplier: f32,
|
||||
pub namespace_reused: bool,
|
||||
pub corpus_paragraphs: usize,
|
||||
pub ingestion_cache_path: String,
|
||||
pub ingestion_reused: bool,
|
||||
pub ingestion_embeddings_reused: bool,
|
||||
pub ingestion_fingerprint: String,
|
||||
pub positive_paragraphs_reused: usize,
|
||||
pub negative_paragraphs_reused: usize,
|
||||
pub latency_ms: LatencyStats,
|
||||
pub perf: PerformanceTimings,
|
||||
pub embedding_backend: String,
|
||||
pub embedding_model: Option<String>,
|
||||
pub embedding_dimension: usize,
|
||||
pub rerank_enabled: bool,
|
||||
pub rerank_pool_size: Option<usize>,
|
||||
pub rerank_keep_top: usize,
|
||||
pub concurrency: usize,
|
||||
pub detailed_report: bool,
|
||||
pub chunk_vector_take: usize,
|
||||
pub chunk_fts_take: usize,
|
||||
pub chunk_token_budget: usize,
|
||||
pub chunk_avg_chars_per_token: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub cases: Vec<CaseSummary>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CaseSummary {
|
||||
pub question_id: String,
|
||||
pub question: String,
|
||||
pub paragraph_id: String,
|
||||
pub paragraph_title: String,
|
||||
pub expected_source: String,
|
||||
pub answers: Vec<String>,
|
||||
pub matched: bool,
|
||||
pub entity_match: bool,
|
||||
pub chunk_text_match: bool,
|
||||
pub chunk_id_match: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub match_rank: Option<usize>,
|
||||
pub latency_ms: u128,
|
||||
pub retrieved: Vec<RetrievedSummary>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LatencyStats {
|
||||
pub avg: f64,
|
||||
pub p50: u128,
|
||||
pub p95: u128,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct StageLatencyBreakdown {
|
||||
pub collect_candidates: LatencyStats,
|
||||
pub graph_expansion: LatencyStats,
|
||||
pub chunk_attach: LatencyStats,
|
||||
pub rerank: LatencyStats,
|
||||
pub assemble: LatencyStats,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize)]
|
||||
pub struct EvaluationStageTimings {
|
||||
pub prepare_slice_ms: u128,
|
||||
pub prepare_db_ms: u128,
|
||||
pub prepare_corpus_ms: u128,
|
||||
pub prepare_namespace_ms: u128,
|
||||
pub run_queries_ms: u128,
|
||||
pub summarize_ms: u128,
|
||||
pub finalize_ms: u128,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PerformanceTimings {
|
||||
pub openai_base_url: String,
|
||||
pub ingestion_ms: u128,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub namespace_seed_ms: Option<u128>,
|
||||
pub evaluation_stage_ms: EvaluationStageTimings,
|
||||
pub stage_latency: StageLatencyBreakdown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RetrievedSummary {
|
||||
pub rank: usize,
|
||||
pub entity_id: String,
|
||||
pub source_id: String,
|
||||
pub entity_name: String,
|
||||
pub score: f32,
|
||||
pub matched: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub entity_description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub entity_category: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chunk_text_match: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chunk_id_match: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(crate) struct CaseDiagnostics {
|
||||
question_id: String,
|
||||
question: String,
|
||||
paragraph_id: String,
|
||||
paragraph_title: String,
|
||||
expected_source: String,
|
||||
expected_chunk_ids: Vec<String>,
|
||||
answers: Vec<String>,
|
||||
entity_match: bool,
|
||||
chunk_text_match: bool,
|
||||
chunk_id_match: bool,
|
||||
failure_reasons: Vec<String>,
|
||||
missing_expected_chunk_ids: Vec<String>,
|
||||
attached_chunk_ids: Vec<String>,
|
||||
retrieved: Vec<EntityDiagnostics>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pipeline: Option<retrieval_pipeline::PipelineDiagnostics>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct EntityDiagnostics {
|
||||
rank: usize,
|
||||
entity_id: String,
|
||||
source_id: String,
|
||||
name: String,
|
||||
score: f32,
|
||||
entity_match: bool,
|
||||
chunk_text_match: bool,
|
||||
chunk_id_match: bool,
|
||||
chunks: Vec<ChunkDiagnosticsEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChunkDiagnosticsEntry {
|
||||
chunk_id: String,
|
||||
score: f32,
|
||||
contains_answer: bool,
|
||||
expected_chunk: bool,
|
||||
snippet: String,
|
||||
}
|
||||
|
||||
pub(crate) struct SeededCase {
|
||||
question_id: String,
|
||||
question: String,
|
||||
@@ -213,6 +37,8 @@ pub(crate) struct SeededCase {
|
||||
paragraph_id: String,
|
||||
paragraph_title: String,
|
||||
expected_chunk_ids: Vec<String>,
|
||||
is_impossible: bool,
|
||||
has_verified_chunks: bool,
|
||||
}
|
||||
|
||||
pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<SeededCase> {
|
||||
@@ -221,10 +47,15 @@ pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<Seed
|
||||
title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone());
|
||||
}
|
||||
|
||||
let include_impossible = manifest.metadata.include_unanswerable;
|
||||
let require_verified_chunks = manifest.metadata.require_verified_chunks;
|
||||
|
||||
manifest
|
||||
.questions
|
||||
.iter()
|
||||
.filter(|question| !question.is_impossible)
|
||||
.filter(|question| {
|
||||
should_include_question(question, include_impossible, require_verified_chunks)
|
||||
})
|
||||
.map(|question| {
|
||||
let title = title_map
|
||||
.get(question.paragraph_id.as_str())
|
||||
@@ -238,66 +69,25 @@ pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<Seed
|
||||
paragraph_id: question.paragraph_id.clone(),
|
||||
paragraph_title: title,
|
||||
expected_chunk_ids: question.matching_chunk_ids.clone(),
|
||||
is_impossible: question.is_impossible,
|
||||
has_verified_chunks: !question.matching_chunk_ids.is_empty(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn text_contains_answer(text: &str, answers: &[String]) -> bool {
|
||||
if answers.is_empty() {
|
||||
return true;
|
||||
fn should_include_question(
|
||||
question: &ingest::CorpusQuestion,
|
||||
include_impossible: bool,
|
||||
require_verified_chunks: bool,
|
||||
) -> bool {
|
||||
if !include_impossible && question.is_impossible {
|
||||
return false;
|
||||
}
|
||||
let haystack = text.to_ascii_lowercase();
|
||||
answers.iter().any(|needle| haystack.contains(needle))
|
||||
}
|
||||
|
||||
pub(crate) fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
|
||||
if latencies.is_empty() {
|
||||
return LatencyStats {
|
||||
avg: 0.0,
|
||||
p50: 0,
|
||||
p95: 0,
|
||||
};
|
||||
if require_verified_chunks && question.matching_chunk_ids.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let mut sorted = latencies.to_vec();
|
||||
sorted.sort_unstable();
|
||||
let sum: u128 = sorted.iter().copied().sum();
|
||||
let avg = sum as f64 / (sorted.len() as f64);
|
||||
let p50 = percentile(&sorted, 0.50);
|
||||
let p95 = percentile(&sorted, 0.95);
|
||||
LatencyStats { avg, p50, p95 }
|
||||
}
|
||||
|
||||
pub(crate) fn build_stage_latency_breakdown(
|
||||
samples: &[PipelineStageTimings],
|
||||
) -> StageLatencyBreakdown {
|
||||
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
|
||||
where
|
||||
F: Fn(&PipelineStageTimings) -> u128,
|
||||
{
|
||||
samples.iter().map(selector).collect()
|
||||
}
|
||||
|
||||
StageLatencyBreakdown {
|
||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.collect_candidates_ms
|
||||
})),
|
||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.graph_expansion_ms
|
||||
})),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
|
||||
}
|
||||
}
|
||||
|
||||
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
|
||||
if sorted.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let clamped = fraction.clamp(0.0, 1.0);
|
||||
let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize;
|
||||
sorted[idx.min(sorted.len() - 1)]
|
||||
true
|
||||
}
|
||||
|
||||
pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
@@ -345,16 +135,16 @@ pub(crate) fn apply_dataset_tuning_overrides(
|
||||
return;
|
||||
}
|
||||
|
||||
if config.chunk_vector_take.is_none() {
|
||||
if config.retrieval.chunk_vector_take.is_none() {
|
||||
tuning.chunk_vector_take = tuning.chunk_vector_take.max(80);
|
||||
}
|
||||
if config.chunk_fts_take.is_none() {
|
||||
if config.retrieval.chunk_fts_take.is_none() {
|
||||
tuning.chunk_fts_take = tuning.chunk_fts_take.max(80);
|
||||
}
|
||||
if config.chunk_token_budget.is_none() {
|
||||
if config.retrieval.chunk_token_budget.is_none() {
|
||||
tuning.token_budget_estimate = tuning.token_budget_estimate.max(20_000);
|
||||
}
|
||||
if config.max_chunks_per_entity.is_none() {
|
||||
if config.retrieval.max_chunks_per_entity.is_none() {
|
||||
tuning.max_chunks_per_entity = tuning.max_chunks_per_entity.max(12);
|
||||
}
|
||||
if tuning.lexical_match_weight < 0.25 {
|
||||
@@ -362,92 +152,6 @@ pub(crate) fn apply_dataset_tuning_overrides(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_case_diagnostics(
|
||||
summary: &CaseSummary,
|
||||
expected_chunk_ids: &[String],
|
||||
answers_lower: &[String],
|
||||
entities: &[composite_retrieval::RetrievedEntity],
|
||||
pipeline_stats: Option<retrieval_pipeline::PipelineDiagnostics>,
|
||||
) -> CaseDiagnostics {
|
||||
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
|
||||
let mut seen_chunks: HashSet<String> = HashSet::new();
|
||||
let mut attached_chunk_ids = Vec::new();
|
||||
let mut entity_diagnostics = Vec::new();
|
||||
|
||||
for (idx, entity) in entities.iter().enumerate() {
|
||||
let mut chunk_entries = Vec::new();
|
||||
for chunk in &entity.chunks {
|
||||
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
|
||||
let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
|
||||
seen_chunks.insert(chunk.chunk.id.clone());
|
||||
attached_chunk_ids.push(chunk.chunk.id.clone());
|
||||
chunk_entries.push(ChunkDiagnosticsEntry {
|
||||
chunk_id: chunk.chunk.id.clone(),
|
||||
score: chunk.score,
|
||||
contains_answer,
|
||||
expected_chunk,
|
||||
snippet: chunk_preview(&chunk.chunk.chunk),
|
||||
});
|
||||
}
|
||||
entity_diagnostics.push(EntityDiagnostics {
|
||||
rank: idx + 1,
|
||||
entity_id: entity.entity.id.clone(),
|
||||
source_id: entity.entity.source_id.clone(),
|
||||
name: entity.entity.name.clone(),
|
||||
score: entity.score,
|
||||
entity_match: entity.entity.source_id == summary.expected_source,
|
||||
chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
|
||||
chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
|
||||
chunks: chunk_entries,
|
||||
});
|
||||
}
|
||||
|
||||
let missing_expected_chunk_ids = expected_chunk_ids
|
||||
.iter()
|
||||
.filter(|id| !seen_chunks.contains(id.as_str()))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut failure_reasons = Vec::new();
|
||||
if !summary.entity_match {
|
||||
failure_reasons.push("entity_miss".to_string());
|
||||
}
|
||||
if !summary.chunk_text_match {
|
||||
failure_reasons.push("chunk_text_missing".to_string());
|
||||
}
|
||||
if !summary.chunk_id_match {
|
||||
failure_reasons.push("chunk_id_missing".to_string());
|
||||
}
|
||||
if !missing_expected_chunk_ids.is_empty() {
|
||||
failure_reasons.push("expected_chunk_absent".to_string());
|
||||
}
|
||||
|
||||
CaseDiagnostics {
|
||||
question_id: summary.question_id.clone(),
|
||||
question: summary.question.clone(),
|
||||
paragraph_id: summary.paragraph_id.clone(),
|
||||
paragraph_title: summary.paragraph_title.clone(),
|
||||
expected_source: summary.expected_source.clone(),
|
||||
expected_chunk_ids: expected_chunk_ids.to_vec(),
|
||||
answers: summary.answers.clone(),
|
||||
entity_match: summary.entity_match,
|
||||
chunk_text_match: summary.chunk_text_match,
|
||||
chunk_id_match: summary.chunk_id_match,
|
||||
failure_reasons,
|
||||
missing_expected_chunk_ids,
|
||||
attached_chunk_ids,
|
||||
retrieved: entity_diagnostics,
|
||||
pipeline: pipeline_stats,
|
||||
}
|
||||
}
|
||||
|
||||
fn chunk_preview(text: &str) -> String {
|
||||
text.chars()
|
||||
.take(200)
|
||||
.collect::<String>()
|
||||
.replace('\n', " ")
|
||||
}
|
||||
|
||||
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
||||
args::ensure_parent(path)?;
|
||||
let mut file = tokio::fs::File::create(path)
|
||||
@@ -765,3 +469,118 @@ pub(crate) async fn load_or_init_system_settings(
|
||||
Err(err) => Err(err).context("loading system settings"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusParagraph, CorpusQuestion};
|
||||
use chrono::Utc;
|
||||
use common::storage::types::text_content::TextContent;
|
||||
|
||||
fn sample_manifest() -> CorpusManifest {
|
||||
let paragraphs = vec![
|
||||
CorpusParagraph {
|
||||
paragraph_id: "p1".to_string(),
|
||||
title: "Alpha".to_string(),
|
||||
text_content: TextContent::new(
|
||||
"alpha context".to_string(),
|
||||
None,
|
||||
"test".to_string(),
|
||||
None,
|
||||
None,
|
||||
"user".to_string(),
|
||||
),
|
||||
entities: Vec::new(),
|
||||
relationships: Vec::new(),
|
||||
chunks: Vec::new(),
|
||||
},
|
||||
CorpusParagraph {
|
||||
paragraph_id: "p2".to_string(),
|
||||
title: "Beta".to_string(),
|
||||
text_content: TextContent::new(
|
||||
"beta context".to_string(),
|
||||
None,
|
||||
"test".to_string(),
|
||||
None,
|
||||
None,
|
||||
"user".to_string(),
|
||||
),
|
||||
entities: Vec::new(),
|
||||
relationships: Vec::new(),
|
||||
chunks: Vec::new(),
|
||||
},
|
||||
];
|
||||
let questions = vec![
|
||||
CorpusQuestion {
|
||||
question_id: "q1".to_string(),
|
||||
paragraph_id: "p1".to_string(),
|
||||
text_content_id: "tc-alpha".to_string(),
|
||||
question_text: "What is Alpha?".to_string(),
|
||||
answers: vec!["Alpha".to_string()],
|
||||
is_impossible: false,
|
||||
matching_chunk_ids: vec!["chunk-alpha".to_string()],
|
||||
},
|
||||
CorpusQuestion {
|
||||
question_id: "q2".to_string(),
|
||||
paragraph_id: "p1".to_string(),
|
||||
text_content_id: "tc-alpha".to_string(),
|
||||
question_text: "Unanswerable?".to_string(),
|
||||
answers: Vec::new(),
|
||||
is_impossible: true,
|
||||
matching_chunk_ids: Vec::new(),
|
||||
},
|
||||
CorpusQuestion {
|
||||
question_id: "q3".to_string(),
|
||||
paragraph_id: "p2".to_string(),
|
||||
text_content_id: "tc-beta".to_string(),
|
||||
question_text: "Where is Beta?".to_string(),
|
||||
answers: vec!["Beta".to_string()],
|
||||
is_impossible: false,
|
||||
matching_chunk_ids: Vec::new(),
|
||||
},
|
||||
];
|
||||
CorpusManifest {
|
||||
version: 1,
|
||||
metadata: CorpusMetadata {
|
||||
dataset_id: "ds".to_string(),
|
||||
dataset_label: "Dataset".to_string(),
|
||||
slice_id: "slice".to_string(),
|
||||
include_unanswerable: true,
|
||||
require_verified_chunks: true,
|
||||
ingestion_fingerprint: "fp".to_string(),
|
||||
embedding_backend: "test".to_string(),
|
||||
embedding_model: None,
|
||||
embedding_dimension: 3,
|
||||
converted_checksum: "chk".to_string(),
|
||||
generated_at: Utc::now(),
|
||||
paragraph_count: paragraphs.len(),
|
||||
question_count: questions.len(),
|
||||
},
|
||||
paragraphs,
|
||||
questions,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cases_respect_mode_filters() {
|
||||
let mut manifest = sample_manifest();
|
||||
manifest.metadata.include_unanswerable = false;
|
||||
manifest.metadata.require_verified_chunks = true;
|
||||
|
||||
let strict_cases = cases_from_manifest(&manifest);
|
||||
assert_eq!(strict_cases.len(), 1);
|
||||
assert_eq!(strict_cases[0].question_id, "q1");
|
||||
assert_eq!(strict_cases[0].paragraph_title, "Alpha");
|
||||
|
||||
let mut llm_manifest = manifest.clone();
|
||||
llm_manifest.metadata.include_unanswerable = true;
|
||||
llm_manifest.metadata.require_verified_chunks = false;
|
||||
|
||||
let llm_cases = cases_from_manifest(&llm_manifest);
|
||||
let ids: Vec<_> = llm_cases
|
||||
.iter()
|
||||
.map(|case| case.question_id.as_str())
|
||||
.collect();
|
||||
assert_eq!(ids, vec!["q1", "q2", "q3"]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
};
|
||||
use composite_retrieval::{
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{PipelineStageTimings, RetrievalConfig},
|
||||
reranking::RerankerPool,
|
||||
};
|
||||
@@ -52,6 +52,7 @@ pub(super) struct EvaluationContext<'a> {
|
||||
pub eval_user: Option<User>,
|
||||
pub corpus_handle: Option<ingest::CorpusHandle>,
|
||||
pub cases: Vec<SeededCase>,
|
||||
pub filtered_questions: usize,
|
||||
pub stage_latency_samples: Vec<PipelineStageTimings>,
|
||||
pub latencies: Vec<u128>,
|
||||
pub diagnostics_output: Vec<CaseDiagnostics>,
|
||||
@@ -94,6 +95,7 @@ impl<'a> EvaluationContext<'a> {
|
||||
eval_user: None,
|
||||
corpus_handle: None,
|
||||
cases: Vec::new(),
|
||||
filtered_questions: 0,
|
||||
stage_latency_samples: Vec::new(),
|
||||
latencies: Vec::new(),
|
||||
diagnostics_output: Vec::new(),
|
||||
|
||||
@@ -128,13 +128,28 @@ pub(crate) async fn prepare_namespace(
|
||||
let user = ensure_eval_user(ctx.db()).await?;
|
||||
ctx.eval_user = Some(user);
|
||||
|
||||
let cases = cases_from_manifest(&ctx.corpus_handle().manifest);
|
||||
let corpus_handle = ctx.corpus_handle();
|
||||
let total_manifest_questions = corpus_handle.manifest.questions.len();
|
||||
let cases = cases_from_manifest(&corpus_handle.manifest);
|
||||
let include_impossible = corpus_handle.manifest.metadata.include_unanswerable;
|
||||
let require_verified_chunks = corpus_handle.manifest.metadata.require_verified_chunks;
|
||||
let filtered = total_manifest_questions.saturating_sub(cases.len());
|
||||
if filtered > 0 {
|
||||
info!(
|
||||
filtered_questions = filtered,
|
||||
total_questions = total_manifest_questions,
|
||||
includes_impossible = include_impossible,
|
||||
require_verified_chunks = require_verified_chunks,
|
||||
"Filtered questions not eligible for this evaluation mode (impossible or unverifiable)"
|
||||
);
|
||||
}
|
||||
if cases.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no answerable questions found in converted dataset for evaluation"
|
||||
"no eligible questions found in converted dataset for evaluation (consider --llm-mode or refreshing ingestion data)"
|
||||
));
|
||||
}
|
||||
ctx.cases = cases;
|
||||
ctx.filtered_questions = filtered;
|
||||
ctx.namespace_reused = namespace_reused;
|
||||
ctx.namespace_seed_ms = namespace_seed_ms;
|
||||
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::{collections::HashSet, sync::Arc, time::Instant};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::stream::{self, StreamExt, TryStreamExt};
|
||||
use futures::stream::{self, StreamExt};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::eval::{
|
||||
apply_dataset_tuning_overrides, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
CaseSummary, RetrievedSummary,
|
||||
adapt_strategy_output, apply_dataset_tuning_overrides, build_case_diagnostics,
|
||||
text_contains_answer, CaseDiagnostics, CaseSummary, RetrievedSummary,
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{self, PipelineStageTimings, RetrievalConfig},
|
||||
reranking::RerankerPool,
|
||||
};
|
||||
use composite_retrieval::pipeline::{self, PipelineStageTimings, RetrievalConfig};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::super::{
|
||||
@@ -38,30 +40,34 @@ pub(crate) async fn run_queries(
|
||||
let total_cases = ctx.cases.len();
|
||||
let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate();
|
||||
|
||||
let rerank_pool = if config.rerank {
|
||||
Some(RerankerPool::new(config.rerank_pool_size).context("initialising reranker pool")?)
|
||||
let rerank_pool = if config.retrieval.rerank {
|
||||
Some(
|
||||
RerankerPool::new(config.retrieval.rerank_pool_size)
|
||||
.context("initialising reranker pool")?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut retrieval_config = RetrievalConfig::default();
|
||||
retrieval_config.tuning.rerank_keep_top = config.rerank_keep_top;
|
||||
if retrieval_config.tuning.fallback_min_results < config.rerank_keep_top {
|
||||
retrieval_config.tuning.fallback_min_results = config.rerank_keep_top;
|
||||
retrieval_config.strategy = config.retrieval.strategy;
|
||||
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
|
||||
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
|
||||
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
|
||||
}
|
||||
if let Some(value) = config.chunk_vector_take {
|
||||
if let Some(value) = config.retrieval.chunk_vector_take {
|
||||
retrieval_config.tuning.chunk_vector_take = value;
|
||||
}
|
||||
if let Some(value) = config.chunk_fts_take {
|
||||
if let Some(value) = config.retrieval.chunk_fts_take {
|
||||
retrieval_config.tuning.chunk_fts_take = value;
|
||||
}
|
||||
if let Some(value) = config.chunk_token_budget {
|
||||
if let Some(value) = config.retrieval.chunk_token_budget {
|
||||
retrieval_config.tuning.token_budget_estimate = value;
|
||||
}
|
||||
if let Some(value) = config.chunk_avg_chars_per_token {
|
||||
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
|
||||
retrieval_config.tuning.avg_chars_per_token = value;
|
||||
}
|
||||
if let Some(value) = config.max_chunks_per_entity {
|
||||
if let Some(value) = config.retrieval.max_chunks_per_entity {
|
||||
retrieval_config.tuning.max_chunks_per_entity = value;
|
||||
}
|
||||
|
||||
@@ -69,9 +75,11 @@ pub(crate) async fn run_queries(
|
||||
|
||||
let active_tuning = retrieval_config.tuning.clone();
|
||||
let effective_chunk_vector = config
|
||||
.retrieval
|
||||
.chunk_vector_take
|
||||
.unwrap_or(active_tuning.chunk_vector_take);
|
||||
let effective_chunk_fts = config
|
||||
.retrieval
|
||||
.chunk_fts_take
|
||||
.unwrap_or(active_tuning.chunk_fts_take);
|
||||
|
||||
@@ -83,11 +91,11 @@ pub(crate) async fn run_queries(
|
||||
.limit
|
||||
.unwrap_or(ctx.window_total_cases),
|
||||
negative_multiplier = %slice_settings.negative_multiplier,
|
||||
rerank_enabled = config.rerank,
|
||||
rerank_pool_size = config.rerank_pool_size,
|
||||
rerank_keep_top = config.rerank_keep_top,
|
||||
chunk_min = config.chunk_min_chars,
|
||||
chunk_max = config.chunk_max_chars,
|
||||
rerank_enabled = config.retrieval.rerank,
|
||||
rerank_pool_size = config.retrieval.rerank_pool_size,
|
||||
rerank_keep_top = config.retrieval.rerank_keep_top,
|
||||
chunk_min = config.retrieval.chunk_min_chars,
|
||||
chunk_max = config.retrieval.chunk_max_chars,
|
||||
chunk_vector_take = effective_chunk_vector,
|
||||
chunk_fts_take = effective_chunk_fts,
|
||||
chunk_token_budget = active_tuning.token_budget_estimate,
|
||||
@@ -122,12 +130,7 @@ pub(crate) async fn run_queries(
|
||||
let db = ctx.db().clone();
|
||||
let openai_client = ctx.openai_client();
|
||||
|
||||
let results: Vec<(
|
||||
usize,
|
||||
CaseSummary,
|
||||
Option<CaseDiagnostics>,
|
||||
PipelineStageTimings,
|
||||
)> = stream::iter(cases_iter)
|
||||
let raw_results = stream::iter(cases_iter)
|
||||
.map(move |(idx, case)| {
|
||||
let db = db.clone();
|
||||
let openai_client = openai_client.clone();
|
||||
@@ -152,6 +155,8 @@ pub(crate) async fn run_queries(
|
||||
paragraph_id,
|
||||
paragraph_title,
|
||||
expected_chunk_ids,
|
||||
is_impossible,
|
||||
has_verified_chunks,
|
||||
} = case;
|
||||
let query_start = Instant::now();
|
||||
|
||||
@@ -165,7 +170,7 @@ pub(crate) async fn run_queries(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let (results, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
|
||||
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
|
||||
&db,
|
||||
&openai_client,
|
||||
@@ -194,26 +199,27 @@ pub(crate) async fn run_queries(
|
||||
};
|
||||
let query_latency = query_start.elapsed().as_millis() as u128;
|
||||
|
||||
let candidates = adapt_strategy_output(result_output);
|
||||
let mut retrieved = Vec::new();
|
||||
let mut match_rank = None;
|
||||
let answers_lower: Vec<String> =
|
||||
answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
|
||||
let expected_chunk_ids_set: HashSet<&str> =
|
||||
expected_chunk_ids.iter().map(|id| id.as_str()).collect();
|
||||
let chunk_id_required = !expected_chunk_ids_set.is_empty();
|
||||
let chunk_id_required = has_verified_chunks;
|
||||
let mut entity_hit = false;
|
||||
let mut chunk_text_hit = false;
|
||||
let mut chunk_id_hit = !chunk_id_required;
|
||||
|
||||
for (idx_entity, entity) in results.iter().enumerate() {
|
||||
for (idx_entity, candidate) in candidates.iter().enumerate() {
|
||||
if idx_entity >= config.k {
|
||||
break;
|
||||
}
|
||||
let entity_match = entity.entity.source_id == expected_source;
|
||||
let entity_match = candidate.source_id == expected_source;
|
||||
if entity_match {
|
||||
entity_hit = true;
|
||||
}
|
||||
let chunk_text_for_entity = entity
|
||||
let chunk_text_for_entity = candidate
|
||||
.chunks
|
||||
.iter()
|
||||
.any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower));
|
||||
@@ -221,8 +227,8 @@ pub(crate) async fn run_queries(
|
||||
chunk_text_hit = true;
|
||||
}
|
||||
let chunk_id_for_entity = if chunk_id_required {
|
||||
expected_chunk_ids_set.contains(entity.entity.source_id.as_str())
|
||||
|| entity.chunks.iter().any(|chunk| {
|
||||
expected_chunk_ids_set.contains(candidate.source_id.as_str())
|
||||
|| candidate.chunks.iter().any(|chunk| {
|
||||
expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
|
||||
})
|
||||
} else {
|
||||
@@ -236,9 +242,11 @@ pub(crate) async fn run_queries(
|
||||
match_rank = Some(idx_entity + 1);
|
||||
}
|
||||
let detail_fields = if config.detailed_report {
|
||||
let description = candidate.entity_description.clone();
|
||||
let category = candidate.entity_category.clone();
|
||||
(
|
||||
Some(entity.entity.description.clone()),
|
||||
Some(format!("{:?}", entity.entity.entity_type)),
|
||||
description,
|
||||
category,
|
||||
Some(chunk_text_for_entity),
|
||||
Some(chunk_id_for_entity),
|
||||
)
|
||||
@@ -247,10 +255,10 @@ pub(crate) async fn run_queries(
|
||||
};
|
||||
retrieved.push(RetrievedSummary {
|
||||
rank: idx_entity + 1,
|
||||
entity_id: entity.entity.id.clone(),
|
||||
source_id: entity.entity.source_id.clone(),
|
||||
entity_name: entity.entity.name.clone(),
|
||||
score: entity.score,
|
||||
entity_id: candidate.entity_id.clone(),
|
||||
source_id: candidate.source_id.clone(),
|
||||
entity_name: candidate.entity_name.clone(),
|
||||
score: candidate.score,
|
||||
matched: success,
|
||||
entity_description: detail_fields.0,
|
||||
entity_category: detail_fields.1,
|
||||
@@ -271,6 +279,8 @@ pub(crate) async fn run_queries(
|
||||
entity_match: entity_hit,
|
||||
chunk_text_match: chunk_text_hit,
|
||||
chunk_id_match: chunk_id_hit,
|
||||
is_impossible,
|
||||
has_verified_chunks,
|
||||
match_rank,
|
||||
latency_ms: query_latency,
|
||||
retrieved,
|
||||
@@ -281,7 +291,7 @@ pub(crate) async fn run_queries(
|
||||
&summary,
|
||||
&expected_chunk_ids,
|
||||
&answers_lower,
|
||||
&results,
|
||||
&candidates,
|
||||
pipeline_diagnostics,
|
||||
))
|
||||
} else {
|
||||
@@ -300,8 +310,18 @@ pub(crate) async fn run_queries(
|
||||
}
|
||||
})
|
||||
.buffer_unordered(concurrency)
|
||||
.try_collect()
|
||||
.await?;
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
let mut results = Vec::with_capacity(raw_results.len());
|
||||
for result in raw_results {
|
||||
match result {
|
||||
Ok(val) => results.push(val),
|
||||
Err(err) => {
|
||||
tracing::error!(error = ?err, "Query execution failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut ordered = results;
|
||||
ordered.sort_by_key(|(idx, ..)| *idx);
|
||||
|
||||
@@ -42,7 +42,18 @@ pub(crate) async fn summarize(
|
||||
let mut correct_at_1 = 0usize;
|
||||
let mut correct_at_2 = 0usize;
|
||||
let mut correct_at_3 = 0usize;
|
||||
let mut retrieval_cases = 0usize;
|
||||
let mut llm_cases = 0usize;
|
||||
let mut llm_answered = 0usize;
|
||||
for summary in &summaries {
|
||||
if summary.is_impossible {
|
||||
llm_cases += 1;
|
||||
if summary.matched {
|
||||
llm_answered += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
retrieval_cases += 1;
|
||||
if summary.matched {
|
||||
correct += 1;
|
||||
if let Some(rank) = summary.match_rank {
|
||||
@@ -62,25 +73,31 @@ pub(crate) async fn summarize(
|
||||
let latency_stats = compute_latency_stats(&latencies);
|
||||
let stage_latency = build_stage_latency_breakdown(&stage_latency_samples);
|
||||
|
||||
let precision = if total_cases == 0 {
|
||||
let retrieval_precision = if retrieval_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct as f64) / (total_cases as f64)
|
||||
(correct as f64) / (retrieval_cases as f64)
|
||||
};
|
||||
let precision_at_1 = if total_cases == 0 {
|
||||
let llm_precision = if llm_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_1 as f64) / (total_cases as f64)
|
||||
(llm_answered as f64) / (llm_cases as f64)
|
||||
};
|
||||
let precision_at_2 = if total_cases == 0 {
|
||||
let precision = retrieval_precision;
|
||||
let precision_at_1 = if retrieval_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_2 as f64) / (total_cases as f64)
|
||||
(correct_at_1 as f64) / (retrieval_cases as f64)
|
||||
};
|
||||
let precision_at_3 = if total_cases == 0 {
|
||||
let precision_at_2 = if retrieval_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_3 as f64) / (total_cases as f64)
|
||||
(correct_at_2 as f64) / (retrieval_cases as f64)
|
||||
};
|
||||
let precision_at_3 = if retrieval_cases == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(correct_at_3 as f64) / (retrieval_cases as f64)
|
||||
};
|
||||
|
||||
let active_tuning = ctx
|
||||
@@ -119,6 +136,15 @@ pub(crate) async fn summarize(
|
||||
dataset_label: dataset.metadata.label.clone(),
|
||||
dataset_includes_unanswerable: dataset.metadata.include_unanswerable,
|
||||
dataset_source: dataset.source.clone(),
|
||||
includes_impossible_cases: slice.manifest.includes_unanswerable,
|
||||
require_verified_chunks: slice.manifest.require_verified_chunks,
|
||||
filtered_questions: ctx.filtered_questions,
|
||||
retrieval_cases,
|
||||
retrieval_correct: correct,
|
||||
retrieval_precision,
|
||||
llm_cases,
|
||||
llm_answered,
|
||||
llm_precision,
|
||||
slice_id: slice.manifest.slice_id.clone(),
|
||||
slice_seed: slice.manifest.seed,
|
||||
slice_total_cases: slice.manifest.case_count,
|
||||
@@ -146,11 +172,15 @@ pub(crate) async fn summarize(
|
||||
embedding_backend: ctx.embedding_provider().backend_label().to_string(),
|
||||
embedding_model: ctx.embedding_provider().model_code(),
|
||||
embedding_dimension: ctx.embedding_provider().dimension(),
|
||||
rerank_enabled: config.rerank,
|
||||
rerank_pool_size: ctx.rerank_pool.as_ref().map(|_| config.rerank_pool_size),
|
||||
rerank_keep_top: config.rerank_keep_top,
|
||||
rerank_enabled: config.retrieval.rerank,
|
||||
rerank_pool_size: ctx
|
||||
.rerank_pool
|
||||
.as_ref()
|
||||
.map(|_| config.retrieval.rerank_pool_size),
|
||||
rerank_keep_top: config.retrieval.rerank_keep_top,
|
||||
concurrency: config.concurrency.max(1),
|
||||
detailed_report: config.detailed_report,
|
||||
retrieval_strategy: config.retrieval.strategy.to_string(),
|
||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||
chunk_token_budget: active_tuning.token_budget_estimate,
|
||||
|
||||
396
eval/src/eval/types.rs
Normal file
396
eval/src/eval/types.rs
Normal file
@@ -0,0 +1,396 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use retrieval_pipeline::{
|
||||
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Top-level, serialisable report for one evaluation run.
///
/// Bundles the headline precision metrics with full provenance (dataset,
/// slice, ingestion cache, embedding backend, retrieval tuning) so a saved
/// report is self-describing.
#[derive(Debug, Serialize)]
pub struct EvaluationSummary {
    pub generated_at: DateTime<Utc>,
    /// Retrieval cutoff: only the top-k candidates per case are scored.
    pub k: usize,
    pub limit: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_label: Option<String>,
    // Overall outcome across every case in the run.
    pub total_cases: usize,
    pub correct: usize,
    pub precision: f64,
    // Rank-sensitive hit counts / precision at cutoffs 1..3.
    pub correct_at_1: usize,
    pub correct_at_2: usize,
    pub correct_at_3: usize,
    pub precision_at_1: f64,
    pub precision_at_2: f64,
    pub precision_at_3: f64,
    pub duration_ms: u128,
    // Dataset provenance.
    pub dataset_id: String,
    pub dataset_label: String,
    pub dataset_includes_unanswerable: bool,
    pub dataset_source: String,
    pub includes_impossible_cases: bool,
    pub require_verified_chunks: bool,
    pub filtered_questions: usize,
    // Split metrics: answerable (retrieval) vs impossible (LLM-judged) cases.
    pub retrieval_cases: usize,
    pub retrieval_correct: usize,
    pub retrieval_precision: f64,
    pub llm_cases: usize,
    pub llm_answered: usize,
    pub llm_precision: f64,
    // Slice provenance (which window of the dataset was evaluated).
    pub slice_id: String,
    pub slice_seed: u64,
    pub slice_total_cases: usize,
    pub slice_window_offset: usize,
    pub slice_window_length: usize,
    pub slice_cases: usize,
    pub slice_positive_paragraphs: usize,
    pub slice_negative_paragraphs: usize,
    pub slice_total_paragraphs: usize,
    pub slice_negative_multiplier: f32,
    pub namespace_reused: bool,
    // Ingestion / corpus cache state for reproducibility.
    pub corpus_paragraphs: usize,
    pub ingestion_cache_path: String,
    pub ingestion_reused: bool,
    pub ingestion_embeddings_reused: bool,
    pub ingestion_fingerprint: String,
    pub positive_paragraphs_reused: usize,
    pub negative_paragraphs_reused: usize,
    // Latency and per-stage performance data.
    pub latency_ms: LatencyStats,
    pub perf: PerformanceTimings,
    // Embedding backend configuration used for this run.
    pub embedding_backend: String,
    pub embedding_model: Option<String>,
    pub embedding_dimension: usize,
    // Rerank configuration; pool size is None when no rerank pool was built.
    pub rerank_enabled: bool,
    pub rerank_pool_size: Option<usize>,
    pub rerank_keep_top: usize,
    pub concurrency: usize,
    pub detailed_report: bool,
    pub retrieval_strategy: String,
    // Effective chunk-retrieval tuning values.
    pub chunk_vector_take: usize,
    pub chunk_fts_take: usize,
    pub chunk_token_budget: usize,
    pub chunk_avg_chars_per_token: usize,
    pub max_chunks_per_entity: usize,
    /// Per-case details, in original case order.
    pub cases: Vec<CaseSummary>,
}
|
||||
|
||||
/// Outcome of a single evaluation question.
#[derive(Debug, Serialize)]
pub struct CaseSummary {
    pub question_id: String,
    pub question: String,
    pub paragraph_id: String,
    pub paragraph_title: String,
    /// Source id the retrieval is expected to surface.
    pub expected_source: String,
    pub answers: Vec<String>,
    /// Overall pass/fail for the case.
    pub matched: bool,
    // Individual match criteria that feed into `matched`.
    pub entity_match: bool,
    pub chunk_text_match: bool,
    pub chunk_id_match: bool,
    /// True for SQuAD-style unanswerable questions.
    pub is_impossible: bool,
    pub has_verified_chunks: bool,
    /// 1-based rank of the first matching candidate, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub match_rank: Option<usize>,
    pub latency_ms: u128,
    /// Top-k retrieved candidates in rank order.
    pub retrieved: Vec<RetrievedSummary>,
}
|
||||
|
||||
/// Summary statistics over a set of latency samples (milliseconds).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
    /// Arithmetic mean.
    pub avg: f64,
    /// Median (nearest-rank).
    pub p50: u128,
    /// 95th percentile (nearest-rank).
    pub p95: u128,
}
|
||||
|
||||
/// Latency statistics for each retrieval-pipeline stage.
#[derive(Debug, Clone, Serialize)]
pub struct StageLatencyBreakdown {
    pub embed: LatencyStats,
    pub collect_candidates: LatencyStats,
    pub graph_expansion: LatencyStats,
    pub chunk_attach: LatencyStats,
    pub rerank: LatencyStats,
    pub assemble: LatencyStats,
}
|
||||
|
||||
/// Wall-clock duration (ms) of each phase of the evaluation harness itself.
#[derive(Debug, Default, Clone, Serialize)]
pub struct EvaluationStageTimings {
    pub prepare_slice_ms: u128,
    pub prepare_db_ms: u128,
    pub prepare_corpus_ms: u128,
    pub prepare_namespace_ms: u128,
    pub run_queries_ms: u128,
    pub summarize_ms: u128,
    pub finalize_ms: u128,
}
|
||||
|
||||
/// Performance section of the evaluation report.
#[derive(Debug, Serialize)]
pub struct PerformanceTimings {
    /// Base URL of the OpenAI-compatible endpoint used for the run.
    pub openai_base_url: String,
    pub ingestion_ms: u128,
    /// Present only when the namespace had to be seeded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace_seed_ms: Option<u128>,
    pub evaluation_stage_ms: EvaluationStageTimings,
    pub stage_latency: StageLatencyBreakdown,
}
|
||||
|
||||
/// One retrieved candidate as it appears in the per-case report.
///
/// The optional fields are populated only when `detailed_report` is enabled.
#[derive(Debug, Serialize)]
pub struct RetrievedSummary {
    /// 1-based rank within the candidate list.
    pub rank: usize,
    pub entity_id: String,
    pub source_id: String,
    pub entity_name: String,
    pub score: f32,
    /// Whether this candidate satisfied the case's match criteria.
    pub matched: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_category: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chunk_text_match: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chunk_id_match: Option<bool>,
}
|
||||
|
||||
/// Uniform candidate representation used by the scorer, regardless of
/// whether the retrieval strategy produced entities or bare chunks.
#[derive(Debug, Clone)]
pub struct EvaluationCandidate {
    pub entity_id: String,
    pub source_id: String,
    pub entity_name: String,
    pub entity_description: Option<String>,
    pub entity_category: Option<String>,
    pub score: f32,
    /// Chunks attached to this candidate (a single chunk for chunk-mode hits).
    pub chunks: Vec<RetrievedChunk>,
}
|
||||
|
||||
impl EvaluationCandidate {
|
||||
fn from_entity(entity: RetrievedEntity) -> Self {
|
||||
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
|
||||
Self {
|
||||
entity_id: entity.entity.id.clone(),
|
||||
source_id: entity.entity.source_id.clone(),
|
||||
entity_name: entity.entity.name.clone(),
|
||||
entity_description: Some(entity.entity.description.clone()),
|
||||
entity_category,
|
||||
score: entity.score,
|
||||
chunks: entity.chunks,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_chunk(chunk: RetrievedChunk) -> Self {
|
||||
let snippet = chunk_snippet(&chunk.chunk.chunk);
|
||||
Self {
|
||||
entity_id: chunk.chunk.id.clone(),
|
||||
source_id: chunk.chunk.source_id.clone(),
|
||||
entity_name: chunk.chunk.source_id.clone(),
|
||||
entity_description: Some(snippet),
|
||||
entity_category: Some("Chunk".to_string()),
|
||||
score: chunk.score,
|
||||
chunks: vec![chunk],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
|
||||
match output {
|
||||
StrategyOutput::Entities(entities) => entities
|
||||
.into_iter()
|
||||
.map(EvaluationCandidate::from_entity)
|
||||
.collect(),
|
||||
StrategyOutput::Chunks(chunks) => chunks
|
||||
.into_iter()
|
||||
.map(EvaluationCandidate::from_chunk)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Detailed failure-analysis record for a single case (diagnostics mode).
#[derive(Debug, Serialize)]
pub struct CaseDiagnostics {
    pub question_id: String,
    pub question: String,
    pub paragraph_id: String,
    pub paragraph_title: String,
    pub expected_source: String,
    pub expected_chunk_ids: Vec<String>,
    pub answers: Vec<String>,
    // Match flags copied from the case summary.
    pub entity_match: bool,
    pub chunk_text_match: bool,
    pub chunk_id_match: bool,
    /// Named reasons derived from the unmet match flags.
    pub failure_reasons: Vec<String>,
    /// Expected chunk ids that were never attached to any candidate.
    pub missing_expected_chunk_ids: Vec<String>,
    /// All chunk ids attached across all candidates, in traversal order.
    pub attached_chunk_ids: Vec<String>,
    pub retrieved: Vec<EntityDiagnostics>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pipeline: Option<PipelineDiagnostics>,
}
|
||||
|
||||
/// Per-candidate diagnostics row inside [`CaseDiagnostics`].
#[derive(Debug, Serialize)]
pub struct EntityDiagnostics {
    /// 1-based rank within the candidate list.
    pub rank: usize,
    pub entity_id: String,
    pub source_id: String,
    pub name: String,
    pub score: f32,
    pub entity_match: bool,
    /// True when any attached chunk's text contains an answer string.
    pub chunk_text_match: bool,
    /// True when any attached chunk id is among the expected chunk ids.
    pub chunk_id_match: bool,
    pub chunks: Vec<ChunkDiagnosticsEntry>,
}
|
||||
|
||||
/// Diagnostics for one chunk attached to a retrieved candidate.
#[derive(Debug, Serialize)]
pub struct ChunkDiagnosticsEntry {
    pub chunk_id: String,
    pub score: f32,
    pub contains_answer: bool,
    /// Whether this chunk id was listed among the case's expected chunks.
    pub expected_chunk: bool,
    /// Truncated preview of the chunk text.
    pub snippet: String,
}
|
||||
|
||||
/// Returns true when `text` contains any of the answer strings.
///
/// The haystack is ASCII-lowercased before matching; callers are expected to
/// pass already-lowercased needles. An empty answer list counts as a match.
pub fn text_contains_answer(text: &str, answers: &[String]) -> bool {
    if answers.is_empty() {
        return true;
    }
    let haystack = text.to_ascii_lowercase();
    answers
        .iter()
        .any(|needle| haystack.contains(needle.as_str()))
}
|
||||
|
||||
/// Trims `text` and truncates it to at most 160 characters, appending "..."
/// when truncation occurred.
///
/// Uses a single pass over the characters: `char_indices().nth(MAX_CHARS)`
/// either proves the string fits (None) or yields the byte offset at which
/// to cut — the original implementation counted all chars and then pushed
/// the first 160 one by one.
fn chunk_snippet(text: &str) -> String {
    const MAX_CHARS: usize = 160;
    let trimmed = text.trim();
    match trimmed.char_indices().nth(MAX_CHARS) {
        // Fewer than MAX_CHARS + 1 characters: return as-is.
        None => trimmed.to_string(),
        // `cut` is the byte offset of character #MAX_CHARS (0-based), i.e. a
        // valid char boundary right after the first MAX_CHARS characters.
        Some((cut, _)) => format!("{}...", &trimmed[..cut]),
    }
}
|
||||
|
||||
pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
|
||||
if latencies.is_empty() {
|
||||
return LatencyStats {
|
||||
avg: 0.0,
|
||||
p50: 0,
|
||||
p95: 0,
|
||||
};
|
||||
}
|
||||
let mut sorted = latencies.to_vec();
|
||||
sorted.sort_unstable();
|
||||
let sum: u128 = sorted.iter().copied().sum();
|
||||
let avg = sum as f64 / (sorted.len() as f64);
|
||||
let p50 = percentile(&sorted, 0.50);
|
||||
let p95 = percentile(&sorted, 0.95);
|
||||
LatencyStats { avg, p50, p95 }
|
||||
}
|
||||
|
||||
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
|
||||
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
|
||||
where
|
||||
F: Fn(&PipelineStageTimings) -> u128,
|
||||
{
|
||||
samples.iter().map(selector).collect()
|
||||
}
|
||||
|
||||
StageLatencyBreakdown {
|
||||
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms)),
|
||||
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.collect_candidates_ms
|
||||
})),
|
||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.graph_expansion_ms
|
||||
})),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Nearest-rank percentile over an already-sorted slice.
///
/// `fraction` is clamped to [0, 1]; an empty slice yields 0.
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
    if sorted.is_empty() {
        return 0;
    }
    let last = sorted.len() - 1;
    let position = (fraction.clamp(0.0, 1.0) * last as f64).round() as usize;
    sorted[position.min(last)]
}
|
||||
|
||||
pub fn build_case_diagnostics(
|
||||
summary: &CaseSummary,
|
||||
expected_chunk_ids: &[String],
|
||||
answers_lower: &[String],
|
||||
candidates: &[EvaluationCandidate],
|
||||
pipeline_stats: Option<PipelineDiagnostics>,
|
||||
) -> CaseDiagnostics {
|
||||
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
|
||||
let mut seen_chunks: HashSet<String> = HashSet::new();
|
||||
let mut attached_chunk_ids = Vec::new();
|
||||
let mut entity_diagnostics = Vec::new();
|
||||
|
||||
for (idx, candidate) in candidates.iter().enumerate() {
|
||||
let mut chunk_entries = Vec::new();
|
||||
for chunk in &candidate.chunks {
|
||||
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
|
||||
let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
|
||||
seen_chunks.insert(chunk.chunk.id.clone());
|
||||
attached_chunk_ids.push(chunk.chunk.id.clone());
|
||||
chunk_entries.push(ChunkDiagnosticsEntry {
|
||||
chunk_id: chunk.chunk.id.clone(),
|
||||
score: chunk.score,
|
||||
contains_answer,
|
||||
expected_chunk,
|
||||
snippet: chunk_snippet(&chunk.chunk.chunk),
|
||||
});
|
||||
}
|
||||
entity_diagnostics.push(EntityDiagnostics {
|
||||
rank: idx + 1,
|
||||
entity_id: candidate.entity_id.clone(),
|
||||
source_id: candidate.source_id.clone(),
|
||||
name: candidate.entity_name.clone(),
|
||||
score: candidate.score,
|
||||
entity_match: candidate.source_id == summary.expected_source,
|
||||
chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
|
||||
chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
|
||||
chunks: chunk_entries,
|
||||
});
|
||||
}
|
||||
|
||||
let missing_expected_chunk_ids = expected_chunk_ids
|
||||
.iter()
|
||||
.filter(|id| !seen_chunks.contains(id.as_str()))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut failure_reasons = Vec::new();
|
||||
if !summary.entity_match {
|
||||
failure_reasons.push("entity_miss".to_string());
|
||||
}
|
||||
if !summary.chunk_text_match {
|
||||
failure_reasons.push("chunk_text_missing".to_string());
|
||||
}
|
||||
if !summary.chunk_id_match {
|
||||
failure_reasons.push("chunk_id_missing".to_string());
|
||||
}
|
||||
if !missing_expected_chunk_ids.is_empty() {
|
||||
failure_reasons.push("expected_chunk_absent".to_string());
|
||||
}
|
||||
|
||||
CaseDiagnostics {
|
||||
question_id: summary.question_id.clone(),
|
||||
question: summary.question.clone(),
|
||||
paragraph_id: summary.paragraph_id.clone(),
|
||||
paragraph_title: summary.paragraph_title.clone(),
|
||||
expected_source: summary.expected_source.clone(),
|
||||
expected_chunk_ids: expected_chunk_ids.to_vec(),
|
||||
answers: summary.answers.clone(),
|
||||
entity_match: summary.entity_match,
|
||||
chunk_text_match: summary.chunk_text_match,
|
||||
chunk_id_match: summary.chunk_id_match,
|
||||
failure_reasons,
|
||||
missing_expected_chunk_ids,
|
||||
attached_chunk_ids,
|
||||
retrieved: entity_diagnostics,
|
||||
pipeline: pipeline_stats,
|
||||
}
|
||||
}
|
||||
72
eval/src/ingest/config.rs
Normal file
72
eval/src/ingest/config.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{args::Config, embedding::EmbeddingProvider};
|
||||
|
||||
/// Settings that control how the ingested-corpus cache is (re)built.
#[derive(Debug, Clone)]
pub struct CorpusCacheConfig {
    /// Directory holding the cached ingestion artifacts.
    pub ingestion_cache_dir: PathBuf,
    /// Ignore any cached state and re-ingest from scratch.
    pub force_refresh: bool,
    /// Keep cached ingestion output but recompute embeddings only.
    pub refresh_embeddings_only: bool,
    pub ingestion_batch_size: usize,
    pub ingestion_max_retries: usize,
}
|
||||
|
||||
impl CorpusCacheConfig {
|
||||
pub fn new(
|
||||
ingestion_cache_dir: impl Into<PathBuf>,
|
||||
force_refresh: bool,
|
||||
refresh_embeddings_only: bool,
|
||||
ingestion_batch_size: usize,
|
||||
ingestion_max_retries: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
ingestion_cache_dir: ingestion_cache_dir.into(),
|
||||
force_refresh,
|
||||
refresh_embeddings_only,
|
||||
ingestion_batch_size,
|
||||
ingestion_max_retries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Abstraction over embedding backends used while building the corpus cache.
#[async_trait]
pub trait CorpusEmbeddingProvider: Send + Sync {
    /// Human-readable backend identifier (recorded in report metadata).
    fn backend_label(&self) -> &str;
    /// Concrete model code, when the backend exposes one.
    fn model_code(&self) -> Option<String>;
    /// Dimensionality of the produced embedding vectors.
    fn dimension(&self) -> usize;
    /// Embeds a batch of texts.
    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
}
|
||||
|
||||
// Trivial delegation: the concrete `EmbeddingProvider` already exposes the
// same inherent methods, so the trait impl just forwards to them (fully
// qualified to avoid infinite recursion into the trait method).
#[async_trait]
impl CorpusEmbeddingProvider for EmbeddingProvider {
    fn backend_label(&self) -> &str {
        EmbeddingProvider::backend_label(self)
    }

    fn model_code(&self) -> Option<String> {
        EmbeddingProvider::model_code(self)
    }

    fn dimension(&self) -> usize {
        EmbeddingProvider::dimension(self)
    }

    async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        EmbeddingProvider::embed_batch(self, texts).await
    }
}
|
||||
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
CorpusCacheConfig::new(
|
||||
config.ingestion_cache_dir.clone(),
|
||||
config.force_convert || config.slice_reset_ingestion,
|
||||
config.refresh_embeddings_only,
|
||||
config.ingestion_batch_size,
|
||||
config.ingestion_max_retries,
|
||||
)
|
||||
}
|
||||
}
|
||||
10
eval/src/ingest/mod.rs
Normal file
10
eval/src/ingest/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
mod config;
|
||||
mod orchestrator;
|
||||
mod store;
|
||||
|
||||
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
|
||||
pub use orchestrator::ensure_corpus;
|
||||
pub use store::{
|
||||
seed_manifest_into_db, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
|
||||
ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
};
|
||||
@@ -1,23 +1,21 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
io::{BufReader, Read},
|
||||
path::{Path, PathBuf},
|
||||
io::Read,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use chrono::Utc;
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
store::{DynStore, StorageManager},
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk, text_content::TextContent,
|
||||
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
|
||||
},
|
||||
},
|
||||
utils::config::{AppConfig, StorageKind},
|
||||
@@ -30,274 +28,19 @@ use tracing::{info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
|
||||
embedding::EmbeddingProvider,
|
||||
slices::{self, ResolvedSlice, SliceParagraphKind},
|
||||
};
|
||||
|
||||
const MANIFEST_VERSION: u32 = 1;
|
||||
use crate::ingest::{
|
||||
CorpusCacheConfig, CorpusEmbeddingProvider, CorpusHandle, CorpusManifest, CorpusMetadata,
|
||||
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
};
|
||||
|
||||
const INGESTION_SPEC_VERSION: u32 = 1;
|
||||
const INGESTION_MAX_RETRIES: usize = 3;
|
||||
const INGESTION_BATCH_SIZE: usize = 5;
|
||||
const PARAGRAPH_SHARD_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CorpusCacheConfig {
|
||||
pub ingestion_cache_dir: PathBuf,
|
||||
pub force_refresh: bool,
|
||||
pub refresh_embeddings_only: bool,
|
||||
}
|
||||
|
||||
impl CorpusCacheConfig {
|
||||
pub fn new(
|
||||
ingestion_cache_dir: impl Into<PathBuf>,
|
||||
force_refresh: bool,
|
||||
refresh_embeddings_only: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
ingestion_cache_dir: ingestion_cache_dir.into(),
|
||||
force_refresh,
|
||||
refresh_embeddings_only,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait CorpusEmbeddingProvider: Send + Sync {
|
||||
fn backend_label(&self) -> &str;
|
||||
fn model_code(&self) -> Option<String>;
|
||||
fn dimension(&self) -> usize;
|
||||
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
|
||||
}
|
||||
|
||||
type OpenAIClient = Client<async_openai::config::OpenAIConfig>;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusManifest {
|
||||
pub version: u32,
|
||||
pub metadata: CorpusMetadata,
|
||||
pub paragraphs: Vec<CorpusParagraph>,
|
||||
pub questions: Vec<CorpusQuestion>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusMetadata {
|
||||
pub dataset_id: String,
|
||||
pub dataset_label: String,
|
||||
pub slice_id: String,
|
||||
pub include_unanswerable: bool,
|
||||
pub ingestion_fingerprint: String,
|
||||
pub embedding_backend: String,
|
||||
pub embedding_model: Option<String>,
|
||||
pub embedding_dimension: usize,
|
||||
pub converted_checksum: String,
|
||||
pub generated_at: DateTime<Utc>,
|
||||
pub paragraph_count: usize,
|
||||
pub question_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusParagraph {
|
||||
pub paragraph_id: String,
|
||||
pub title: String,
|
||||
pub text_content: TextContent,
|
||||
pub entities: Vec<KnowledgeEntity>,
|
||||
pub relationships: Vec<KnowledgeRelationship>,
|
||||
pub chunks: Vec<TextChunk>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusQuestion {
|
||||
pub question_id: String,
|
||||
pub paragraph_id: String,
|
||||
pub text_content_id: String,
|
||||
pub question_text: String,
|
||||
pub answers: Vec<String>,
|
||||
pub is_impossible: bool,
|
||||
pub matching_chunk_ids: Vec<String>,
|
||||
}
|
||||
|
||||
pub struct CorpusHandle {
|
||||
pub manifest: CorpusManifest,
|
||||
pub path: PathBuf,
|
||||
pub reused_ingestion: bool,
|
||||
pub reused_embeddings: bool,
|
||||
pub positive_reused: usize,
|
||||
pub positive_ingested: usize,
|
||||
pub negative_reused: usize,
|
||||
pub negative_ingested: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct ParagraphShard {
|
||||
version: u32,
|
||||
paragraph_id: String,
|
||||
shard_path: String,
|
||||
ingestion_fingerprint: String,
|
||||
ingested_at: DateTime<Utc>,
|
||||
title: String,
|
||||
text_content: TextContent,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
chunks: Vec<TextChunk>,
|
||||
#[serde(default)]
|
||||
question_bindings: HashMap<String, Vec<String>>,
|
||||
#[serde(default)]
|
||||
embedding_backend: String,
|
||||
#[serde(default)]
|
||||
embedding_model: Option<String>,
|
||||
#[serde(default)]
|
||||
embedding_dimension: usize,
|
||||
}
|
||||
|
||||
struct ParagraphShardStore {
|
||||
base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl ParagraphShardStore {
|
||||
fn new(base_dir: PathBuf) -> Self {
|
||||
Self { base_dir }
|
||||
}
|
||||
|
||||
fn ensure_base_dir(&self) -> Result<()> {
|
||||
fs::create_dir_all(&self.base_dir)
|
||||
.with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
|
||||
}
|
||||
|
||||
fn resolve(&self, relative: &str) -> PathBuf {
|
||||
self.base_dir.join(relative)
|
||||
}
|
||||
|
||||
fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
|
||||
let path = self.resolve(relative);
|
||||
let file = match fs::File::open(&path) {
|
||||
Ok(file) => file,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||
Err(err) => {
|
||||
return Err(err).with_context(|| format!("opening shard {}", path.display()))
|
||||
}
|
||||
};
|
||||
let reader = BufReader::new(file);
|
||||
let mut shard: ParagraphShard = serde_json::from_reader(reader)
|
||||
.with_context(|| format!("parsing shard {}", path.display()))?;
|
||||
if shard.version != PARAGRAPH_SHARD_VERSION {
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
version = shard.version,
|
||||
expected = PARAGRAPH_SHARD_VERSION,
|
||||
"Skipping shard due to version mismatch"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
if shard.ingestion_fingerprint != fingerprint {
|
||||
return Ok(None);
|
||||
}
|
||||
shard.shard_path = relative.to_string();
|
||||
Ok(Some(shard))
|
||||
}
|
||||
|
||||
fn persist(&self, shard: &ParagraphShard) -> Result<()> {
|
||||
let path = self.resolve(&shard.shard_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating shard dir {}", parent.display()))?;
|
||||
}
|
||||
let tmp_path = path.with_extension("json.tmp");
|
||||
let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
|
||||
fs::write(&tmp_path, &body)
|
||||
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
|
||||
fs::rename(&tmp_path, &path)
|
||||
.with_context(|| format!("renaming shard tmp {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CorpusEmbeddingProvider for EmbeddingProvider {
|
||||
fn backend_label(&self) -> &str {
|
||||
EmbeddingProvider::backend_label(self)
|
||||
}
|
||||
|
||||
fn model_code(&self) -> Option<String> {
|
||||
EmbeddingProvider::model_code(self)
|
||||
}
|
||||
|
||||
fn dimension(&self) -> usize {
|
||||
EmbeddingProvider::dimension(self)
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
EmbeddingProvider::embed_batch(self, texts).await
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
CorpusCacheConfig::new(
|
||||
config.ingestion_cache_dir.clone(),
|
||||
config.force_convert || config.slice_reset_ingestion,
|
||||
config.refresh_embeddings_only,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ParagraphShard {
|
||||
fn new(
|
||||
paragraph: &ConvertedParagraph,
|
||||
shard_path: String,
|
||||
ingestion_fingerprint: &str,
|
||||
text_content: TextContent,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
chunks: Vec<TextChunk>,
|
||||
embedding_backend: &str,
|
||||
embedding_model: Option<String>,
|
||||
embedding_dimension: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
version: PARAGRAPH_SHARD_VERSION,
|
||||
paragraph_id: paragraph.id.clone(),
|
||||
shard_path,
|
||||
ingestion_fingerprint: ingestion_fingerprint.to_string(),
|
||||
ingested_at: Utc::now(),
|
||||
title: paragraph.title.clone(),
|
||||
text_content,
|
||||
entities,
|
||||
relationships,
|
||||
chunks,
|
||||
question_bindings: HashMap::new(),
|
||||
embedding_backend: embedding_backend.to_string(),
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_corpus_paragraph(&self) -> CorpusParagraph {
|
||||
CorpusParagraph {
|
||||
paragraph_id: self.paragraph_id.clone(),
|
||||
title: self.title.clone(),
|
||||
text_content: self.text_content.clone(),
|
||||
entities: self.entities.clone(),
|
||||
relationships: self.relationships.clone(),
|
||||
chunks: self.chunks.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_question_binding(
|
||||
&mut self,
|
||||
question: &ConvertedQuestion,
|
||||
) -> Result<(Vec<String>, bool)> {
|
||||
if let Some(existing) = self.question_bindings.get(&question.id) {
|
||||
return Ok((existing.clone(), false));
|
||||
}
|
||||
let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?;
|
||||
self.question_bindings
|
||||
.insert(question.id.clone(), chunk_ids.clone());
|
||||
Ok((chunk_ids, true))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ParagraphShardRecord {
|
||||
shard: ParagraphShard,
|
||||
@@ -390,6 +133,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
store.ensure_base_dir()?;
|
||||
|
||||
let positive_set: HashSet<&str> = window.positive_ids().collect();
|
||||
let require_verified_chunks = slice.manifest.require_verified_chunks;
|
||||
let embedding_backend_label = embedding.backend_label().to_string();
|
||||
let embedding_model_code = embedding.model_code();
|
||||
let embedding_dimension = embedding.dimension();
|
||||
@@ -487,6 +231,8 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
&embedding_backend_label,
|
||||
embedding_model_code.clone(),
|
||||
embedding_dimension,
|
||||
cache.ingestion_batch_size,
|
||||
cache.ingestion_max_retries,
|
||||
)
|
||||
.await
|
||||
.context("ingesting missing slice paragraphs")?;
|
||||
@@ -548,6 +294,12 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
let (chunk_ids, updated) = match record.shard.ensure_question_binding(case.question) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
if require_verified_chunks {
|
||||
return Err(err).context(format!(
|
||||
"locating answer text for question '{}' in paragraph '{}'",
|
||||
case.question.id, case.paragraph.id
|
||||
));
|
||||
}
|
||||
warn!(
|
||||
question_id = %case.question.id,
|
||||
paragraph_id = %case.paragraph.id,
|
||||
@@ -591,6 +343,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
dataset_label: dataset.metadata.label.clone(),
|
||||
slice_id: slice.manifest.slice_id.clone(),
|
||||
include_unanswerable: slice.manifest.includes_unanswerable,
|
||||
require_verified_chunks: slice.manifest.require_verified_chunks,
|
||||
ingestion_fingerprint: ingestion_fingerprint.clone(),
|
||||
embedding_backend: embedding.backend_label().to_string(),
|
||||
embedding_model: embedding.model_code(),
|
||||
@@ -681,6 +434,8 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
embedding_backend: &str,
|
||||
embedding_model: Option<String>,
|
||||
embedding_dimension: usize,
|
||||
batch_size: usize,
|
||||
max_retries: usize,
|
||||
) -> Result<Vec<ParagraphShard>> {
|
||||
if targets.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
@@ -704,7 +459,7 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
db,
|
||||
openai.clone(),
|
||||
app_config,
|
||||
None::<Arc<composite_retrieval::reranking::RerankerPool>>,
|
||||
None::<Arc<retrieval_pipeline::reranking::RerankerPool>>,
|
||||
storage,
|
||||
)
|
||||
.await?;
|
||||
@@ -712,11 +467,11 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
|
||||
let mut shards = Vec::with_capacity(targets.len());
|
||||
let category = dataset.metadata.category.clone();
|
||||
for (batch_index, batch) in targets.chunks(INGESTION_BATCH_SIZE).enumerate() {
|
||||
for (batch_index, batch) in targets.chunks(batch_size).enumerate() {
|
||||
info!(
|
||||
batch = batch_index,
|
||||
batch_size = batch.len(),
|
||||
total_batches = (targets.len() + INGESTION_BATCH_SIZE - 1) / INGESTION_BATCH_SIZE,
|
||||
total_batches = (targets.len() + batch_size - 1) / batch_size,
|
||||
"Ingesting paragraph batch"
|
||||
);
|
||||
let model_clone = embedding_model.clone();
|
||||
@@ -734,6 +489,7 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
backend_clone.clone(),
|
||||
model_clone.clone(),
|
||||
embedding_dimension,
|
||||
max_retries,
|
||||
)
|
||||
});
|
||||
let batch_results: Vec<ParagraphShard> = try_join_all(tasks)
|
||||
@@ -755,10 +511,11 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
|
||||
embedding_backend: String,
|
||||
embedding_model: Option<String>,
|
||||
embedding_dimension: usize,
|
||||
max_retries: usize,
|
||||
) -> Result<ParagraphShard> {
|
||||
let paragraph = request.paragraph;
|
||||
let mut last_err: Option<anyhow::Error> = None;
|
||||
for attempt in 1..=INGESTION_MAX_RETRIES {
|
||||
for attempt in 1..=max_retries {
|
||||
let payload = IngestionPayload::Text {
|
||||
text: paragraph.context.clone(),
|
||||
context: paragraph.title.clone(),
|
||||
@@ -801,7 +558,7 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
|
||||
warn!(
|
||||
paragraph_id = %paragraph.id,
|
||||
attempt,
|
||||
max_attempts = INGESTION_MAX_RETRIES,
|
||||
max_attempts = max_retries,
|
||||
error = ?err,
|
||||
"ingestion attempt failed for paragraph; retrying"
|
||||
);
|
||||
@@ -815,49 +572,6 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
|
||||
.context(format!("running ingestion for paragraph {}", paragraph.id)))
|
||||
}
|
||||
|
||||
fn validate_answers(
|
||||
content: &TextContent,
|
||||
chunks: &[TextChunk],
|
||||
question: &ConvertedQuestion,
|
||||
) -> Result<Vec<String>> {
|
||||
if question.is_impossible || question.answers.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut matches = std::collections::BTreeSet::new();
|
||||
let mut found_any = false;
|
||||
let haystack = content.text.to_ascii_lowercase();
|
||||
let haystack_norm = normalize_answer_text(&haystack);
|
||||
for answer in &question.answers {
|
||||
let needle: String = answer.to_ascii_lowercase();
|
||||
let needle_norm = normalize_answer_text(&needle);
|
||||
let text_match = haystack.contains(&needle)
|
||||
|| (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
|
||||
if text_match {
|
||||
found_any = true;
|
||||
}
|
||||
for chunk in chunks {
|
||||
let chunk_text = chunk.chunk.to_ascii_lowercase();
|
||||
let chunk_norm = normalize_answer_text(&chunk_text);
|
||||
if chunk_text.contains(&needle)
|
||||
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
|
||||
{
|
||||
matches.insert(chunk.id.clone());
|
||||
found_any = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found_any {
|
||||
Err(anyhow!(
|
||||
"expected answer for question '{}' was not found in ingested content",
|
||||
question.id
|
||||
))
|
||||
} else {
|
||||
Ok(matches.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn build_ingestion_fingerprint(
|
||||
dataset: &ConvertedDataset,
|
||||
slice: &ResolvedSlice<'_>,
|
||||
@@ -894,107 +608,3 @@ fn compute_file_checksum(path: &Path) -> Result<String> {
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||
for paragraph in &manifest.paragraphs {
|
||||
db.store_item(paragraph.text_content.clone())
|
||||
.await
|
||||
.context("storing text_content from manifest")?;
|
||||
for entity in ¶graph.entities {
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.context("storing knowledge_entity from manifest")?;
|
||||
}
|
||||
for relationship in ¶graph.relationships {
|
||||
relationship
|
||||
.store_relationship(db)
|
||||
.await
|
||||
.context("storing knowledge_relationship from manifest")?;
|
||||
}
|
||||
for chunk in ¶graph.chunks {
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.context("storing text_chunk from manifest")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn normalize_answer_text(text: &str) -> String {
|
||||
text.chars()
|
||||
.map(|ch| {
|
||||
if ch.is_alphanumeric() || ch.is_whitespace() {
|
||||
ch.to_ascii_lowercase()
|
||||
} else {
|
||||
' '
|
||||
}
|
||||
})
|
||||
.collect::<String>()
|
||||
.split_whitespace()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::datasets::ConvertedQuestion;
|
||||
|
||||
fn mock_text_content() -> TextContent {
|
||||
TextContent {
|
||||
id: "tc1".into(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
text: "alpha beta gamma".into(),
|
||||
file_info: None,
|
||||
url_info: None,
|
||||
context: Some("ctx".into()),
|
||||
category: "cat".into(),
|
||||
user_id: "user".into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_chunk(id: &str, text: &str) -> TextChunk {
|
||||
TextChunk {
|
||||
id: id.into(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
source_id: "src".into(),
|
||||
chunk: text.into(),
|
||||
embedding: vec![],
|
||||
user_id: "user".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_answers_passes_when_present() {
|
||||
let content = mock_text_content();
|
||||
let chunk = mock_chunk("chunk1", "alpha chunk");
|
||||
let question = ConvertedQuestion {
|
||||
id: "q1".into(),
|
||||
question: "?".into(),
|
||||
answers: vec!["Alpha".into()],
|
||||
is_impossible: false,
|
||||
};
|
||||
let matches = validate_answers(&content, &[chunk], &question).expect("answers match");
|
||||
assert_eq!(matches, vec!["chunk1".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_answers_fails_when_missing() {
|
||||
let question = ConvertedQuestion {
|
||||
id: "q1".into(),
|
||||
question: "?".into(),
|
||||
answers: vec!["delta".into()],
|
||||
is_impossible: false,
|
||||
};
|
||||
let err = validate_answers(
|
||||
&mock_text_content(),
|
||||
&[mock_chunk("chunk", "alpha")],
|
||||
&question,
|
||||
)
|
||||
.expect_err("missing answer should fail");
|
||||
assert!(err.to_string().contains("not found"));
|
||||
}
|
||||
}
|
||||
299
eval/src/ingest/store.rs
Normal file
299
eval/src/ingest/store.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
use std::{collections::HashMap, fs, io::BufReader, path::PathBuf};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk, text_content::TextContent,
|
||||
},
|
||||
};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
pub const MANIFEST_VERSION: u32 = 1;
|
||||
pub const PARAGRAPH_SHARD_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusManifest {
|
||||
pub version: u32,
|
||||
pub metadata: CorpusMetadata,
|
||||
pub paragraphs: Vec<CorpusParagraph>,
|
||||
pub questions: Vec<CorpusQuestion>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusMetadata {
|
||||
pub dataset_id: String,
|
||||
pub dataset_label: String,
|
||||
pub slice_id: String,
|
||||
pub include_unanswerable: bool,
|
||||
#[serde(default)]
|
||||
pub require_verified_chunks: bool,
|
||||
pub ingestion_fingerprint: String,
|
||||
pub embedding_backend: String,
|
||||
pub embedding_model: Option<String>,
|
||||
pub embedding_dimension: usize,
|
||||
pub converted_checksum: String,
|
||||
pub generated_at: DateTime<Utc>,
|
||||
pub paragraph_count: usize,
|
||||
pub question_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusParagraph {
|
||||
pub paragraph_id: String,
|
||||
pub title: String,
|
||||
pub text_content: TextContent,
|
||||
pub entities: Vec<KnowledgeEntity>,
|
||||
pub relationships: Vec<KnowledgeRelationship>,
|
||||
pub chunks: Vec<TextChunk>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusQuestion {
|
||||
pub question_id: String,
|
||||
pub paragraph_id: String,
|
||||
pub text_content_id: String,
|
||||
pub question_text: String,
|
||||
pub answers: Vec<String>,
|
||||
pub is_impossible: bool,
|
||||
pub matching_chunk_ids: Vec<String>,
|
||||
}
|
||||
|
||||
pub struct CorpusHandle {
|
||||
pub manifest: CorpusManifest,
|
||||
pub path: PathBuf,
|
||||
pub reused_ingestion: bool,
|
||||
pub reused_embeddings: bool,
|
||||
pub positive_reused: usize,
|
||||
pub positive_ingested: usize,
|
||||
pub negative_reused: usize,
|
||||
pub negative_ingested: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ParagraphShard {
|
||||
pub version: u32,
|
||||
pub paragraph_id: String,
|
||||
pub shard_path: String,
|
||||
pub ingestion_fingerprint: String,
|
||||
pub ingested_at: DateTime<Utc>,
|
||||
pub title: String,
|
||||
pub text_content: TextContent,
|
||||
pub entities: Vec<KnowledgeEntity>,
|
||||
pub relationships: Vec<KnowledgeRelationship>,
|
||||
pub chunks: Vec<TextChunk>,
|
||||
#[serde(default)]
|
||||
pub question_bindings: HashMap<String, Vec<String>>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: String,
|
||||
#[serde(default)]
|
||||
pub embedding_model: Option<String>,
|
||||
#[serde(default)]
|
||||
pub embedding_dimension: usize,
|
||||
}
|
||||
|
||||
pub struct ParagraphShardStore {
|
||||
base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl ParagraphShardStore {
|
||||
pub fn new(base_dir: PathBuf) -> Self {
|
||||
Self { base_dir }
|
||||
}
|
||||
|
||||
pub fn ensure_base_dir(&self) -> Result<()> {
|
||||
fs::create_dir_all(&self.base_dir)
|
||||
.with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
|
||||
}
|
||||
|
||||
fn resolve(&self, relative: &str) -> PathBuf {
|
||||
self.base_dir.join(relative)
|
||||
}
|
||||
|
||||
pub fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
|
||||
let path = self.resolve(relative);
|
||||
let file = match fs::File::open(&path) {
|
||||
Ok(file) => file,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||
Err(err) => {
|
||||
return Err(err).with_context(|| format!("opening shard {}", path.display()))
|
||||
}
|
||||
};
|
||||
let reader = BufReader::new(file);
|
||||
let mut shard: ParagraphShard = serde_json::from_reader(reader)
|
||||
.with_context(|| format!("parsing shard {}", path.display()))?;
|
||||
if shard.version != PARAGRAPH_SHARD_VERSION {
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
version = shard.version,
|
||||
expected = PARAGRAPH_SHARD_VERSION,
|
||||
"Skipping shard due to version mismatch"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
if shard.ingestion_fingerprint != fingerprint {
|
||||
return Ok(None);
|
||||
}
|
||||
shard.shard_path = relative.to_string();
|
||||
Ok(Some(shard))
|
||||
}
|
||||
|
||||
pub fn persist(&self, shard: &ParagraphShard) -> Result<()> {
|
||||
let path = self.resolve(&shard.shard_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating shard dir {}", parent.display()))?;
|
||||
}
|
||||
let tmp_path = path.with_extension("json.tmp");
|
||||
let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
|
||||
fs::write(&tmp_path, &body)
|
||||
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
|
||||
fs::rename(&tmp_path, &path)
|
||||
.with_context(|| format!("renaming shard tmp {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ParagraphShard {
|
||||
pub fn new(
|
||||
paragraph: &ConvertedParagraph,
|
||||
shard_path: String,
|
||||
ingestion_fingerprint: &str,
|
||||
text_content: TextContent,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
chunks: Vec<TextChunk>,
|
||||
embedding_backend: &str,
|
||||
embedding_model: Option<String>,
|
||||
embedding_dimension: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
version: PARAGRAPH_SHARD_VERSION,
|
||||
paragraph_id: paragraph.id.clone(),
|
||||
shard_path,
|
||||
ingestion_fingerprint: ingestion_fingerprint.to_string(),
|
||||
ingested_at: Utc::now(),
|
||||
title: paragraph.title.clone(),
|
||||
text_content,
|
||||
entities,
|
||||
relationships,
|
||||
chunks,
|
||||
question_bindings: HashMap::new(),
|
||||
embedding_backend: embedding_backend.to_string(),
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_corpus_paragraph(&self) -> CorpusParagraph {
|
||||
CorpusParagraph {
|
||||
paragraph_id: self.paragraph_id.clone(),
|
||||
title: self.title.clone(),
|
||||
text_content: self.text_content.clone(),
|
||||
entities: self.entities.clone(),
|
||||
relationships: self.relationships.clone(),
|
||||
chunks: self.chunks.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ensure_question_binding(
|
||||
&mut self,
|
||||
question: &ConvertedQuestion,
|
||||
) -> Result<(Vec<String>, bool)> {
|
||||
if let Some(existing) = self.question_bindings.get(&question.id) {
|
||||
return Ok((existing.clone(), false));
|
||||
}
|
||||
let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?;
|
||||
self.question_bindings
|
||||
.insert(question.id.clone(), chunk_ids.clone());
|
||||
Ok((chunk_ids, true))
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_answers(
|
||||
content: &TextContent,
|
||||
chunks: &[TextChunk],
|
||||
question: &ConvertedQuestion,
|
||||
) -> Result<Vec<String>> {
|
||||
if question.is_impossible || question.answers.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut matches = std::collections::BTreeSet::new();
|
||||
let mut found_any = false;
|
||||
let haystack = content.text.to_ascii_lowercase();
|
||||
let haystack_norm = normalize_answer_text(&haystack);
|
||||
for answer in &question.answers {
|
||||
let needle: String = answer.to_ascii_lowercase();
|
||||
let needle_norm = normalize_answer_text(&needle);
|
||||
let text_match = haystack.contains(&needle)
|
||||
|| (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
|
||||
if text_match {
|
||||
found_any = true;
|
||||
}
|
||||
for chunk in chunks {
|
||||
let chunk_text = chunk.chunk.to_ascii_lowercase();
|
||||
let chunk_norm = normalize_answer_text(&chunk_text);
|
||||
if chunk_text.contains(&needle)
|
||||
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
|
||||
{
|
||||
matches.insert(chunk.id.clone());
|
||||
found_any = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found_any {
|
||||
Err(anyhow!(
|
||||
"expected answer for question '{}' was not found in ingested content",
|
||||
question.id
|
||||
))
|
||||
} else {
|
||||
Ok(matches.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_answer_text(text: &str) -> String {
|
||||
text.chars()
|
||||
.map(|ch| {
|
||||
if ch.is_alphanumeric() || ch.is_whitespace() {
|
||||
ch.to_ascii_lowercase()
|
||||
} else {
|
||||
' '
|
||||
}
|
||||
})
|
||||
.collect::<String>()
|
||||
.split_whitespace()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||
for paragraph in &manifest.paragraphs {
|
||||
db.upsert_item(paragraph.text_content.clone())
|
||||
.await
|
||||
.context("storing text_content from manifest")?;
|
||||
for entity in ¶graph.entities {
|
||||
db.upsert_item(entity.clone())
|
||||
.await
|
||||
.context("storing knowledge_entity from manifest")?;
|
||||
}
|
||||
for relationship in ¶graph.relationships {
|
||||
relationship
|
||||
.store_relationship(db)
|
||||
.await
|
||||
.context("storing knowledge_relationship from manifest")?;
|
||||
}
|
||||
for chunk in ¶graph.chunks {
|
||||
db.upsert_item(chunk.clone())
|
||||
.await
|
||||
.context("storing text_chunk from manifest")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -194,17 +194,34 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
)
|
||||
})?;
|
||||
|
||||
println!(
|
||||
"[{}] Precision@{k}: {precision:.3} ({correct}/{total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
|
||||
summary.dataset_label,
|
||||
k = summary.k,
|
||||
precision = summary.precision,
|
||||
correct = summary.correct,
|
||||
total = summary.total_cases,
|
||||
json = report_paths.json.display(),
|
||||
md = report_paths.markdown.display(),
|
||||
perf = perf_log_path.display()
|
||||
);
|
||||
if summary.llm_cases > 0 {
|
||||
println!(
|
||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) + LLM: {llm_answered}/{llm_total} ({llm_precision:.3}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
|
||||
summary.dataset_label,
|
||||
k = summary.k,
|
||||
precision = summary.precision,
|
||||
correct = summary.correct,
|
||||
retrieval_total = summary.retrieval_cases,
|
||||
llm_answered = summary.llm_answered,
|
||||
llm_total = summary.llm_cases,
|
||||
llm_precision = summary.llm_precision,
|
||||
json = report_paths.json.display(),
|
||||
md = report_paths.markdown.display(),
|
||||
perf = perf_log_path.display()
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
|
||||
summary.dataset_label,
|
||||
k = summary.k,
|
||||
precision = summary.precision,
|
||||
correct = summary.correct,
|
||||
retrieval_total = summary.retrieval_cases,
|
||||
json = report_paths.json.display(),
|
||||
md = report_paths.markdown.display(),
|
||||
perf = perf_log_path.display()
|
||||
);
|
||||
}
|
||||
|
||||
if parsed.config.perf_log_console {
|
||||
perf::print_console_summary(&summary);
|
||||
|
||||
@@ -19,6 +19,7 @@ struct PerformanceLogEntry {
|
||||
dataset_id: String,
|
||||
dataset_label: String,
|
||||
run_label: Option<String>,
|
||||
retrieval_strategy: String,
|
||||
slice_id: String,
|
||||
slice_seed: u64,
|
||||
slice_window_offset: usize,
|
||||
@@ -27,6 +28,10 @@ struct PerformanceLogEntry {
|
||||
total_cases: usize,
|
||||
correct: usize,
|
||||
precision: f64,
|
||||
retrieval_cases: usize,
|
||||
llm_cases: usize,
|
||||
llm_answered: usize,
|
||||
llm_precision: f64,
|
||||
k: usize,
|
||||
openai_base_url: String,
|
||||
ingestion: IngestionPerf,
|
||||
@@ -87,7 +92,7 @@ impl PerformanceLogEntry {
|
||||
rerank_enabled: summary.rerank_enabled,
|
||||
rerank_pool_size: summary.rerank_pool_size,
|
||||
rerank_keep_top: summary.rerank_keep_top,
|
||||
evaluated_cases: summary.total_cases,
|
||||
evaluated_cases: summary.retrieval_cases,
|
||||
};
|
||||
|
||||
Self {
|
||||
@@ -95,6 +100,7 @@ impl PerformanceLogEntry {
|
||||
dataset_id: summary.dataset_id.clone(),
|
||||
dataset_label: summary.dataset_label.clone(),
|
||||
run_label: summary.run_label.clone(),
|
||||
retrieval_strategy: summary.retrieval_strategy.clone(),
|
||||
slice_id: summary.slice_id.clone(),
|
||||
slice_seed: summary.slice_seed,
|
||||
slice_window_offset: summary.slice_window_offset,
|
||||
@@ -103,6 +109,10 @@ impl PerformanceLogEntry {
|
||||
total_cases: summary.total_cases,
|
||||
correct: summary.correct,
|
||||
precision: summary.precision,
|
||||
retrieval_cases: summary.retrieval_cases,
|
||||
llm_cases: summary.llm_cases,
|
||||
llm_answered: summary.llm_answered,
|
||||
llm_precision: summary.llm_precision,
|
||||
k: summary.k,
|
||||
openai_base_url: summary.perf.openai_base_url.clone(),
|
||||
ingestion,
|
||||
@@ -162,6 +172,13 @@ pub fn write_perf_logs(
|
||||
|
||||
pub fn print_console_summary(summary: &EvaluationSummary) {
|
||||
let perf = &summary.perf;
|
||||
println!(
|
||||
"[perf] retrieval strategy={} | rerank={} (pool {:?}, keep {})",
|
||||
summary.retrieval_strategy,
|
||||
summary.rerank_enabled,
|
||||
summary.rerank_pool_size,
|
||||
summary.rerank_keep_top
|
||||
);
|
||||
println!(
|
||||
"[perf] ingestion={}ms | namespace_seed={}",
|
||||
perf.ingestion_ms,
|
||||
@@ -169,7 +186,8 @@ pub fn print_console_summary(summary: &EvaluationSummary) {
|
||||
);
|
||||
let stage = &perf.stage_latency;
|
||||
println!(
|
||||
"[perf] stage avg ms → collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
|
||||
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
|
||||
stage.embed.avg,
|
||||
stage.collect_candidates.avg,
|
||||
stage.graph_expansion.avg,
|
||||
stage.chunk_attach.avg,
|
||||
@@ -212,6 +230,7 @@ mod tests {
|
||||
|
||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
||||
crate::eval::StageLatencyBreakdown {
|
||||
embed: sample_latency(),
|
||||
collect_candidates: sample_latency(),
|
||||
graph_expansion: sample_latency(),
|
||||
chunk_attach: sample_latency(),
|
||||
@@ -252,6 +271,15 @@ mod tests {
|
||||
dataset_label: "SQuAD v2".into(),
|
||||
dataset_includes_unanswerable: false,
|
||||
dataset_source: "dev".into(),
|
||||
includes_impossible_cases: false,
|
||||
require_verified_chunks: true,
|
||||
filtered_questions: 0,
|
||||
retrieval_cases: 2,
|
||||
retrieval_correct: 1,
|
||||
retrieval_precision: 0.5,
|
||||
llm_cases: 0,
|
||||
llm_answered: 0,
|
||||
llm_precision: 0.0,
|
||||
slice_id: "slice123".into(),
|
||||
slice_seed: 42,
|
||||
slice_total_cases: 400,
|
||||
@@ -285,6 +313,7 @@ mod tests {
|
||||
rerank_pool_size: Some(4),
|
||||
rerank_keep_top: 10,
|
||||
concurrency: 2,
|
||||
retrieval_strategy: "initial".into(),
|
||||
detailed_report: false,
|
||||
chunk_vector_take: 20,
|
||||
chunk_fts_take: 20,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,5 +23,6 @@ pub fn slice_config_with_limit<'a>(
|
||||
slice_seed: config.slice_seed,
|
||||
llm_mode: config.llm_mode,
|
||||
negative_multiplier: config.negative_multiplier,
|
||||
require_verified_chunks: config.retrieval.require_verified_chunks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ pub struct SliceConfig<'a> {
|
||||
pub slice_seed: u64,
|
||||
pub llm_mode: bool,
|
||||
pub negative_multiplier: f32,
|
||||
pub require_verified_chunks: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -36,6 +37,8 @@ pub struct SliceManifest {
|
||||
pub dataset_label: String,
|
||||
pub dataset_source: String,
|
||||
pub includes_unanswerable: bool,
|
||||
#[serde(default = "default_require_verified_chunks")]
|
||||
pub require_verified_chunks: bool,
|
||||
pub seed: u64,
|
||||
pub requested_limit: Option<usize>,
|
||||
pub requested_corpus: usize,
|
||||
@@ -49,6 +52,10 @@ pub struct SliceManifest {
|
||||
pub paragraphs: Vec<SliceParagraphEntry>,
|
||||
}
|
||||
|
||||
fn default_require_verified_chunks() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SliceCaseEntry {
|
||||
pub question_id: String,
|
||||
@@ -184,6 +191,7 @@ impl DatasetIndex {
|
||||
struct SliceKey<'a> {
|
||||
dataset_id: &'a str,
|
||||
includes_unanswerable: bool,
|
||||
require_verified_chunks: bool,
|
||||
requested_corpus: usize,
|
||||
seed: u64,
|
||||
}
|
||||
@@ -222,7 +230,8 @@ pub fn resolve_slice<'a>(
|
||||
.max(1);
|
||||
let key = SliceKey {
|
||||
dataset_id: dataset.metadata.id.as_str(),
|
||||
includes_unanswerable: dataset.metadata.include_unanswerable,
|
||||
includes_unanswerable: config.llm_mode,
|
||||
require_verified_chunks: config.require_verified_chunks,
|
||||
requested_corpus,
|
||||
seed: config.slice_seed,
|
||||
};
|
||||
@@ -248,13 +257,24 @@ pub fn resolve_slice<'a>(
|
||||
let mut manifest = if !config.force_convert && path.exists() {
|
||||
match read_manifest(&path) {
|
||||
Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
|
||||
if manifest.includes_unanswerable != dataset.metadata.include_unanswerable {
|
||||
if manifest.includes_unanswerable != config.llm_mode {
|
||||
warn!(
|
||||
slice = manifest.slice_id,
|
||||
path = %path.display(),
|
||||
expected = config.llm_mode,
|
||||
found = manifest.includes_unanswerable,
|
||||
"Slice manifest includes_unanswerable mismatch; regenerating"
|
||||
);
|
||||
None
|
||||
} else if manifest.require_verified_chunks != config.require_verified_chunks {
|
||||
warn!(
|
||||
slice = manifest.slice_id,
|
||||
path = %path.display(),
|
||||
expected = config.require_verified_chunks,
|
||||
found = manifest.require_verified_chunks,
|
||||
"Slice manifest verified-chunk requirement mismatch; regenerating"
|
||||
);
|
||||
None
|
||||
} else {
|
||||
Some(manifest)
|
||||
}
|
||||
@@ -312,6 +332,7 @@ pub fn resolve_slice<'a>(
|
||||
¶ms,
|
||||
requested_corpus,
|
||||
config.negative_multiplier,
|
||||
config.require_verified_chunks,
|
||||
config.limit,
|
||||
)
|
||||
});
|
||||
@@ -319,6 +340,8 @@ pub fn resolve_slice<'a>(
|
||||
manifest.requested_limit = config.limit;
|
||||
manifest.requested_corpus = requested_corpus;
|
||||
manifest.negative_multiplier = config.negative_multiplier;
|
||||
manifest.includes_unanswerable = config.llm_mode;
|
||||
manifest.require_verified_chunks = config.require_verified_chunks;
|
||||
|
||||
let mut changed = ensure_shard_paths(&mut manifest);
|
||||
|
||||
@@ -439,6 +462,22 @@ fn load_explicit_slice<'a>(
|
||||
dataset.metadata.id
|
||||
));
|
||||
}
|
||||
if manifest.includes_unanswerable != config.llm_mode {
|
||||
return Err(anyhow!(
|
||||
"slice '{}' includes_unanswerable mismatch (expected {}, found {})",
|
||||
manifest.slice_id,
|
||||
config.llm_mode,
|
||||
manifest.includes_unanswerable
|
||||
));
|
||||
}
|
||||
if manifest.require_verified_chunks != config.require_verified_chunks {
|
||||
return Err(anyhow!(
|
||||
"slice '{}' verified-chunk requirement mismatch (expected {}, found {})",
|
||||
manifest.slice_id,
|
||||
config.require_verified_chunks,
|
||||
manifest.require_verified_chunks
|
||||
));
|
||||
}
|
||||
|
||||
// Validate the manifest before returning.
|
||||
manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?;
|
||||
@@ -452,6 +491,7 @@ fn empty_manifest(
|
||||
params: &BuildParams,
|
||||
requested_corpus: usize,
|
||||
negative_multiplier: f32,
|
||||
require_verified_chunks: bool,
|
||||
requested_limit: Option<usize>,
|
||||
) -> SliceManifest {
|
||||
SliceManifest {
|
||||
@@ -460,7 +500,8 @@ fn empty_manifest(
|
||||
dataset_id: dataset.metadata.id.clone(),
|
||||
dataset_label: dataset.metadata.label.clone(),
|
||||
dataset_source: dataset.source.clone(),
|
||||
includes_unanswerable: dataset.metadata.include_unanswerable,
|
||||
includes_unanswerable: params.include_impossible,
|
||||
require_verified_chunks,
|
||||
seed: params.base_seed,
|
||||
requested_limit,
|
||||
requested_corpus,
|
||||
@@ -891,6 +932,7 @@ mod tests {
|
||||
slice_seed: 0x5eed_2025,
|
||||
llm_mode: false,
|
||||
negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
|
||||
require_verified_chunks: true,
|
||||
};
|
||||
|
||||
let first = resolve_slice(&dataset, &config)?;
|
||||
@@ -922,6 +964,7 @@ mod tests {
|
||||
slice_seed: 0x5eed_2025,
|
||||
llm_mode: false,
|
||||
negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
|
||||
require_verified_chunks: true,
|
||||
};
|
||||
let resolved = resolve_slice(&dataset, &config)?;
|
||||
let window = select_window(&resolved, 1, Some(1))?;
|
||||
|
||||
@@ -54,9 +54,9 @@ impl Descriptor {
|
||||
embedding_backend: embedding_provider.backend_label().to_string(),
|
||||
embedding_model: embedding_provider.model_code(),
|
||||
embedding_dimension: embedding_provider.dimension(),
|
||||
chunk_min_chars: config.chunk_min_chars,
|
||||
chunk_max_chars: config.chunk_max_chars,
|
||||
rerank_enabled: config.rerank,
|
||||
chunk_min_chars: config.retrieval.chunk_min_chars,
|
||||
chunk_max_chars: config.retrieval.chunk_max_chars,
|
||||
rerank_enabled: config.retrieval.rerank,
|
||||
};
|
||||
|
||||
let dir = config
|
||||
|
||||
Reference in New Issue
Block a user